CondTransformerRegressor#
- class CondTransformerRegressor(in_dim, out_dim, cond_dim, in_horizon, cond_horizon, num_layers=6, num_attention_heads=6, embedding_dim=768, p_drop_emb=0.1, p_drop_attn=0.1, causal_attn=False, num_cond_layers=0)[source]#
Bases: GenCondRegressor

Transformer-based generative conditional regressor.
The module parameterizes \(\mathbf{f}_{\mathbf{\theta}}(X_k, Z, k)\) with a stack of Transformer blocks. The input trajectory \(X_k\) is first projected into an embedding space and interpreted as the target (tgt) sequence of a standard
torch.nn.TransformerDecoder. Conditioning information is packed into the decoder memory stream:
- The inference-time step k is mapped with a sinusoidal embedding and inserted as the first conditioning token.
- The observed sequence \(Z\) is linearly embedded, receives learned positional encodings, and is appended after the step token.
When num_cond_layers > 0, the conditioning tokens are processed by a Transformer encoder so that the decoder attends to context-aware features; otherwise, a lightweight MLP refines the embeddings.
During decoding, self-attention layers refine \(X_k\) internally while cross-attention layers pull information from the conditioning memory, enabling the model to fuse optimisation step, observations, and trajectory features at every layer.
The model map is:
\[\mathbf{f}_{\mathbf{\theta}}: \mathbb{R}^{d_x \times T_x} \times \mathbb{R}^{d_z \times T_z} \times \mathbb{R} \to \mathbb{R}^{d_v \times T_x}.\]

This is an unconstrained baseline model (no explicit group-equivariance constraint).
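The sinusoidal step token mentioned above can be sketched as a standard Transformer-style embedding of the scalar step k. This is an illustrative helper, not the library's actual API; the function name and frequency schedule are assumptions:

```python
import math
import torch

def sinusoidal_step_embedding(k: torch.Tensor, dim: int) -> torch.Tensor:
    """Map optimization steps k of shape (B,) to sinusoidal features (B, dim).

    Uses geometrically spaced frequencies, as in the original Transformer
    positional encoding; the first half holds sines, the second half cosines.
    """
    half = dim // 2
    freqs = torch.exp(-math.log(10000.0) * torch.arange(half, dtype=torch.float32) / half)
    args = k.float()[:, None] * freqs[None, :]  # (B, half)
    return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)  # (B, dim)

# Embed three optimization steps into a 16-dimensional token each.
emb = sinusoidal_step_embedding(torch.tensor([0.0, 5.0, 10.0]), 16)
```

For step k = 0 the sine half is all zeros and the cosine half all ones, so distinct steps receive distinct, smoothly varying tokens.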
- Parameters:
  - in_dim (int) – Dimensionality of each element in \(X\).
  - out_dim (int) – Dimensionality of the regressed vector field.
  - cond_dim (int) – Dimensionality of each conditioning element in \(Z\).
  - in_horizon (int) – Maximum length of \(X\).
  - cond_horizon (int) – Maximum length of \(Z\) (excluding the optimization-step token).
  - num_layers (int) – Number of Transformer decoder layers.
  - num_attention_heads (int) – Number of attention heads in Multi-Head Attention blocks.
  - embedding_dim (int) – Dimensionality of token embeddings.
  - p_drop_emb (float) – Dropout applied to embeddings.
  - p_drop_attn (float) – Dropout applied inside attention blocks.
  - causal_attn (bool) – Whether to use causal attention in self-attention and cross-attention layers.
  - num_cond_layers (int) – Number of encoder layers dedicated to conditioning tokens.
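The decoder-with-conditioning-memory layout described above can be sketched with a minimal stand-in built on `torch.nn.TransformerDecoder`. The class name, the linear step embedding (in place of the sinusoidal one), and the omitted dropout/positional encodings are simplifications, not the real implementation:

```python
import torch
import torch.nn as nn

class TinyCondTransformerRegressor(nn.Module):
    """Illustrative stand-in for CondTransformerRegressor (not the real class)."""

    def __init__(self, in_dim, out_dim, cond_dim,
                 embedding_dim=32, num_layers=2, num_attention_heads=4):
        super().__init__()
        self.in_emb = nn.Linear(in_dim, embedding_dim)      # embed trajectory X_k (tgt)
        self.cond_emb = nn.Linear(cond_dim, embedding_dim)  # embed observations Z
        self.step_emb = nn.Linear(1, embedding_dim)         # stand-in for the sinusoidal step token
        layer = nn.TransformerDecoderLayer(
            d_model=embedding_dim, nhead=num_attention_heads, batch_first=True)
        self.decoder = nn.TransformerDecoder(layer, num_layers=num_layers)
        self.head = nn.Linear(embedding_dim, out_dim)       # project back to the vector field

    def forward(self, X, opt_step, Z):
        tgt = self.in_emb(X)                                        # (B, T_x, E)
        step_tok = self.step_emb(opt_step.float()[:, None, None])   # (B, 1, E)
        # Step token first, then the embedded observations, as the decoder memory.
        memory = torch.cat([step_tok, self.cond_emb(Z)], dim=1)     # (B, 1 + T_z, E)
        return self.head(self.decoder(tgt, memory))                 # (B, T_x, d_v)

model = TinyCondTransformerRegressor(in_dim=3, out_dim=3, cond_dim=5)
V = model(torch.randn(2, 8, 3), torch.tensor([4, 9]), torch.randn(2, 6, 5))
```

Self-attention refines the T_x trajectory tokens while cross-attention reads the 1 + T_z conditioning tokens, mirroring the memory layout described above.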
- configure_optimizers(learning_rate=0.0001, weight_decay=0.001, betas=(0.9, 0.95))[source]#
Creates optimizer parameter groups, separating the parameters that receive weight decay from those that don't.
- forward(X, opt_step, Z)[source]#
Forward pass of the conditional transformer regressor, approximating \(V_k = \mathbf{f}_{\mathbf{\theta}}(X_k, Z, k)\).
- Parameters:
  - X (Tensor) – The input/data sample composed of a trajectory of T_x points in a d_x-dimensional space. Shape: (B, T_x, d_x), where B is the batch size.
  - opt_step (Tensor | float | int) – The optimization timestep(s) k at which to evaluate the regressor. Can be a single scalar or a tensor of shape (B,).
  - Z (Tensor) – The conditioning/observation variable composed of T_z points in a d_z-dimensional space. Shape: (B, T_z, d_z), where B is the batch size.
- Returns:
  The output regression variable of shape (B, T_x, d_v).
- Return type:
- get_optim_groups(weight_decay=0.001)[source]#
Creates optimizer parameter groups, separating the parameters that receive weight decay from those that don't.
This long function is unfortunately doing something very simple and is being very defensive: we separate all parameters of the model into two buckets, those that will experience weight decay for regularization and those that won't (biases, and layernorm/embedding weights), and then return the resulting parameter groups for a PyTorch optimizer.
- Parameters:
weight_decay (float)
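The bucketing described above can be sketched as follows; this is a rough approximation of the minGPT-style grouping (the real method may classify modules differently), with the hypothetical helper then handed to `torch.optim.AdamW` as `configure_optimizers` would do:

```python
import torch
import torch.nn as nn

def get_optim_groups(model: nn.Module, weight_decay: float = 1e-3):
    """Split parameters into a decayed bucket (e.g. Linear weights) and an
    undecayed bucket (biases, LayerNorm/Embedding weights)."""
    decay, no_decay = [], []
    for module in model.modules():
        for name, p in module.named_parameters(recurse=False):
            if name.endswith("bias") or isinstance(module, (nn.LayerNorm, nn.Embedding)):
                no_decay.append(p)  # biases and norm/embedding weights: no decay
            else:
                decay.append(p)     # matmul-style weights: apply decay
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]

# Toy model: two Linear layers (weights decayed) and a LayerNorm (undecayed).
model = nn.Sequential(nn.Linear(4, 8), nn.LayerNorm(8), nn.Linear(8, 2))
groups = get_optim_groups(model, weight_decay=1e-3)
opt = torch.optim.AdamW(groups, lr=1e-4, betas=(0.9, 0.95))
```

Passing per-group `weight_decay` entries to `AdamW` overrides its default for each group, which is why the no-decay bucket can safely carry `weight_decay=0.0` while sharing the global learning rate.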