CondTransformerRegressor#

class CondTransformerRegressor(in_dim, out_dim, cond_dim, in_horizon, cond_horizon, num_layers=6, num_attention_heads=6, embedding_dim=768, p_drop_emb=0.1, p_drop_attn=0.1, causal_attn=False, num_cond_layers=0)[source]#

Bases: GenCondRegressor

Transformer-based generative conditional regressor.

The module parameterizes \(\mathbf{f}_{\mathbf{\theta}}(X_k, Z, k)\) with a stack of Transformer blocks. The input trajectory \(X_k\) is first projected into an embedding space and interpreted as the target (tgt) sequence of a standard torch.nn.TransformerDecoder. Conditioning information is packed into the decoder memory stream:

  • The inference-time step k is mapped with a sinusoidal embedding and inserted as the first conditioning token.

  • The observed sequence \(Z\) is linearly embedded, receives learned positional encodings, and is appended after the step token.

  • When num_cond_layers > 0, the conditioning tokens are processed by a Transformer encoder so that the decoder attends to context-aware features; otherwise a lightweight MLP refines the embeddings.

During decoding, self-attention layers refine \(X_k\) internally while cross-attention layers pull information from the conditioning memory, enabling the model to fuse the optimization step, observations, and trajectory features at every layer.
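The sinusoidal embedding of the step k mentioned above can be sketched as follows. This is a minimal illustration assuming the standard Transformer sinusoidal scheme; the exact frequencies used by CondTransformerRegressor may differ:

```python
import math

def sinusoidal_step_embedding(k, embedding_dim):
    """Map a scalar optimization step k to a vector of size embedding_dim.

    Even indices hold sine components, odd indices hold cosine components,
    with geometrically spaced frequencies (standard Transformer scheme;
    assumed here, not taken from the library source).
    """
    half = embedding_dim // 2
    emb = [0.0] * embedding_dim
    for i in range(half):
        freq = math.exp(-math.log(10000.0) * i / half)
        emb[2 * i] = math.sin(k * freq)
        emb[2 * i + 1] = math.cos(k * freq)
    return emb

# The resulting step token is prepended to the embedded observations Z,
# giving a decoder memory of length 1 + T_z.
```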

The model map is:

\[\mathbf{f}_{\mathbf{\theta}}: \mathbb{R}^{d_x \times T_x} \times \mathbb{R}^{d_z \times T_z} \times \mathbb{R} \to \mathbb{R}^{d_v \times T_x}.\]

This is an unconstrained baseline model (no explicit group-equivariance constraint).

Parameters:
  • in_dim (int) – Dimensionality of each element in \(X\).

  • out_dim (int) – Dimensionality of the regressed vector field.

  • cond_dim (int) – Dimensionality of each conditioning element in \(Z\).

  • in_horizon (int) – Maximum length of \(X\).

  • cond_horizon (int) – Maximum length of \(Z\) (excluding the optimization-step token).

  • num_layers (int) – Number of Transformer decoder layers.

  • num_attention_heads (int) – Number of attention heads in Multi-Head Attention blocks.

  • embedding_dim (int) – Dimensionality of token embeddings.

  • p_drop_emb (float) – Dropout applied to embeddings.

  • p_drop_attn (float) – Dropout applied inside attention blocks.

  • causal_attn (bool) – Whether to use causal attention in self-attention and cross-attention layers.

  • num_cond_layers (int) – Number of encoder layers dedicated to conditioning tokens.

configure_optimizers(learning_rate=0.0001, weight_decay=0.001, betas=(0.9, 0.95))[source]#

Creates the optimizer, grouping parameters into those that receive weight decay and those that do not.

Parameters:
  • learning_rate (float) – Learning rate for the optimizer.

  • weight_decay (float) – Weight decay applied to the regularized parameter group.

  • betas (tuple) – Adam-style beta coefficients.
forward(X, opt_step, Z)[source]#

Forward pass of the conditional transformer regressor, approximating \(V_k = \mathbf{f}_{\mathbf{\theta}}(X_k, Z, k)\).

Parameters:
  • X (Tensor) – The input/data sample composed of a trajectory of T_x points in a d_x-dimensional space. Shape: (B, T_x, d_x), where B is the batch size.

  • opt_step (Tensor | float | int) – The optimization timestep(s) k at which to evaluate the regressor. Can be a single scalar or a tensor of shape (B,).

  • Z (Tensor) – The conditioning/observation variable composed of T_z points in a d_z-dimensional space. Shape: (B, T_z, d_z), where B is the batch size.

Returns:

The output regression variable of shape (B, T_x, d_v).

Return type:

Tensor
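The shape contract documented above can be summarized in a small helper. This is purely an illustration of the documented shapes, not part of the library API:

```python
def expected_output_shape(x_shape, z_shape, out_dim):
    """Given X of shape (B, T_x, d_x) and Z of shape (B, T_z, d_z),
    forward() returns a tensor of shape (B, T_x, out_dim):
    one out_dim-dimensional vector per trajectory point.
    """
    b_x, t_x, _ = x_shape
    b_z, _, _ = z_shape
    assert b_x == b_z, "X and Z must share the batch dimension B"
    return (b_x, t_x, out_dim)

# e.g. a batch of 32 trajectories of 16 points in 2-D, conditioned on
# 8 observations in 3-D, regressing a 2-D vector field:
# expected_output_shape((32, 16, 2), (32, 8, 3), out_dim=2) -> (32, 16, 2)
```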

get_optim_groups(weight_decay=0.001)[source]#

Creates optimizer parameter groups, separating parameters that receive weight decay from those that do not.

This long function is unfortunately doing something very simple and is being very defensive: it separates all parameters of the model into two buckets, those that will experience weight decay for regularization and those that won't (biases, and layernorm/embedding weights), and returns the corresponding parameter groups to be passed to the PyTorch optimizer.
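The bucketing described above can be sketched as follows. This is a simplified, name-based illustration; the `split_decay_groups` helper and its matching rules are assumptions for exposition, not the library's actual (type-based) implementation:

```python
def split_decay_groups(named_params, weight_decay=0.001):
    """Partition (name, parameter) pairs into a weight-decayed group and a
    no-decay group.

    Biases and layernorm/embedding weights are exempt from weight decay;
    everything else (e.g. linear-layer weights) is regularized.
    """
    # Name substrings marking parameters exempt from decay (assumed markers).
    no_decay_markers = ("bias", "layernorm", "embedding")
    decay, no_decay = [], []
    for name, param in named_params:
        if any(m in name.lower() for m in no_decay_markers):
            no_decay.append(param)
        else:
            decay.append(param)
    # These two groups are what gets handed to the PyTorch optimizer,
    # e.g. torch.optim.AdamW(groups, lr=..., betas=...).
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]
```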

Parameters:

weight_decay (float)