CondTransformer
- class CondTransformer(in_dim, out_dim, cond_dim, in_horizon, cond_horizon, pos_encoding='additive_absolute', num_layers=6, num_attention_heads=6, embedding_dim=768, p_drop_emb=0.1, p_drop_attn=0.1, causal_attn=False, num_cond_layers=0, norm_first=True, norm_module='rmsnorm', **pos_encoding_kwargs)
Bases: GenCondRegressor
Encoder/decoder Transformer with configurable positional attention.
This module is an encoder/decoder Transformer with num_cond_layers conditioning encoder layers and num_layers decoder layers, following the architecture introduced in Attention Is All You Need by Vaswani, Shazeer, Parmar, Uszkoreit, Jones, Gomez, Kaiser, and Polosukhin (NeurIPS 2017), with two task-specific changes:
- the conditioning memory is built from an optimisation/transport-step token together with the conditioning sequence, and
- the attention blocks support several positional encoding schemes.
The conditioning stream is assembled as
\[[k, \mathbf{z}_{-(T_z - 1)}, \ldots, \mathbf{z}_{0}],\]
where \(k\) is the inference-time optimisation/transport step token and \(\mathbf{z}_{-(T_z - 1)}, \ldots, \mathbf{z}_{0}\) are the conditioning tokens ordered from oldest to most recent observation. The decoder predicts the target sequence
\[[\mathbf{x}_{0}, \ldots, \mathbf{x}_{T_x - 1}],\]
ordered from the first action to the last action in the predicted horizon.
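A minimal shape-level sketch of this layout; the embedding stand-ins below are illustrative, not the module's internals:

    import torch
    import torch.nn as nn

    B, T_z, d_z, D = 4, 8, 7, 768              # batch, cond horizon, cond_dim, embedding_dim
    Z = torch.randn(B, T_z, d_z)               # z_{-(T_z-1)}, ..., z_0: oldest -> most recent

    step_tok = torch.randn(B, 1, D)            # stand-in for the sinusoidal embedding of k
    cond_tok = nn.Linear(d_z, D)(Z)            # stand-in for the linear embedding of Z

    memory = torch.cat([step_tok, cond_tok], dim=1)   # [k, z_{-(T_z-1)}, ..., z_0]
    print(memory.shape)                        # torch.Size([4, 9, 768])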
Supported positional encodings are:
"additive_absolute"Uses
AdditivePosMultiheadAttention. A learned table maps integer positions to additive updates that are added to the query and key streams before standard multi-head attention."additive_relative"Uses
AdditiveRelMultiheadAttention. A learned table maps relative token distances to additive score biases, yielding a time-translation-equivariant attention rule."rope"Uses
RoPEMultiheadAttention. Rotary position embeddings are applied per-head to the query and key projections, leaving values untouched."none"Uses
MultiheadAttentionwith no explicit positional encoding.
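To illustrate what the "rope" option does, a textbook rotary-embedding rotation applied to query/key tensors might look like this (a generic sketch, not the internals of RoPEMultiheadAttention):

    import torch

    def apply_rope(x, base=10000.0):
        # x: (B, heads, T, head_dim) with even head_dim. Rotates (even, odd)
        # dimension pairs by position-dependent angles; values are never touched.
        B, H, T, D = x.shape
        half = D // 2
        freqs = base ** (-torch.arange(half, dtype=x.dtype) / half)      # (half,)
        angles = torch.arange(T, dtype=x.dtype)[:, None] * freqs[None]   # (T, half)
        cos, sin = angles.cos(), angles.sin()
        x1, x2 = x[..., 0::2], x[..., 1::2]
        out = torch.empty_like(x)
        out[..., 0::2] = x1 * cos - x2 * sin
        out[..., 1::2] = x1 * sin + x2 * cos
        return out

    q = torch.randn(2, 6, 16, 64)
    k = torch.randn(2, 6, 16, 64)
    q_rot, k_rot = apply_rope(q), apply_rope(k)   # only queries and keys are rotated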
Temporal assumptions:
- \(Z\) must already be ordered in time from past to present.
- \(X\) must already be ordered from the first predicted action to the last predicted action.
- For "additive_relative" and "rope", the last conditioning token \(\mathbf{z}_{0}\) and the first action token \(\mathbf{x}_{0}\) are both placed at time index \(0\), so cross-attention is anchored at the present time.
- The optimisation-step token \(k\) is prepended to the conditioning memory, but it is not treated as part of the observation timeline.
The architecture is implemented as follows:
- The inference-time optimisation step k is sinusoidally embedded and prepended as the first conditioning token.
- The observed sequence \(Z\) is linearly embedded and concatenated after the step token.
- When num_cond_layers > 0, the conditioning tokens pass through a positional-attention encoder; otherwise a lightweight MLP refines them.
- A positional-attention decoder attends from the input trajectory \(X_k\) to the conditioning memory.
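A schematic of this data flow, using hypothetical sub-module names (the actual attribute names inside CondTransformer may differ):

    import torch

    def forward_sketch(self, X, opt_step, Z):
        step_tok = self.step_embed(opt_step)        # sinusoidal embedding of k -> (B, 1, D)
        cond_tok = self.cond_embed(Z)               # linear embedding of Z    -> (B, T_z, D)
        memory = torch.cat([step_tok, cond_tok], dim=1)

        if self.num_cond_layers > 0:
            memory = self.cond_encoder(memory)      # positional-attention encoder
        else:
            memory = self.cond_mlp(memory)          # lightweight MLP refinement

        x_tok = self.input_embed(X)                 # (B, T_x, D)
        h = self.decoder(x_tok, memory)             # cross-attention from X_k to the memory
        return self.head(h)                         # (B, T_x, out_dim)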
- Parameters:
  - in_dim (int) – Dimensionality of each element in \(X\).
  - out_dim (int) – Dimensionality of the regressed vector field.
  - cond_dim (int) – Dimensionality of each conditioning element in \(Z\).
  - in_horizon (int) – Maximum length of \(X\).
  - cond_horizon (int) – Maximum length of \(Z\) (excluding the optimisation-step token).
  - pos_encoding (str) – Positional encoding strategy: "additive_absolute", "additive_relative", "rope", or "none".
  - num_layers (int) – Number of Transformer decoder layers.
  - num_attention_heads (int) – Number of attention heads in Multi-Head Attention blocks.
  - embedding_dim (int) – Dimensionality of token embeddings.
  - p_drop_emb (float) – Dropout applied to embeddings.
  - p_drop_attn (float) – Dropout applied inside attention blocks.
  - causal_attn (bool) – Whether to use causal attention in self-attention and cross-attention layers.
  - num_cond_layers (int) – Number of encoder layers dedicated to conditioning tokens.
  - norm_first (bool) – Whether to apply normalization before each residual branch.
  - norm_module (str) – Final and per-layer normalization type: "layernorm" or "rmsnorm".
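Example instantiation using the documented signature (the dimension values are arbitrary, and the import path depends on where the package exposes the class):

    model = CondTransformer(
        in_dim=10,            # action dimensionality d_x
        out_dim=10,           # regressed vector-field dimensionality d_v
        cond_dim=7,           # observation dimensionality d_z
        in_horizon=16,        # maximum T_x
        cond_horizon=8,       # maximum T_z (excluding the step token)
        pos_encoding="rope",
        num_layers=6,
        num_attention_heads=6,
        embedding_dim=768,
        num_cond_layers=2,
        norm_module="rmsnorm",
    )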
- configure_optimizers(learning_rate=0.0001, weight_decay=0.001, betas=(0.9, 0.95))
Create optimizer groups separating parameters that receive weight decay from those that don’t.
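The usual convention for such grouping is to exclude biases and normalisation weights from decay; a generic sketch of that convention (not necessarily the exact rule this method applies) is:

    import torch

    decay, no_decay = [], []
    for p in model.parameters():
        if not p.requires_grad:
            continue
        # 1-D parameters (biases, norm gains) typically receive no weight decay
        (no_decay if p.ndim < 2 else decay).append(p)

    optimizer = torch.optim.AdamW(
        [{"params": decay, "weight_decay": 1e-3},
         {"params": no_decay, "weight_decay": 0.0}],
        lr=1e-4, betas=(0.9, 0.95),
    )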
- forward(X, opt_step, Z)
Forward pass of the conditional transformer regressor, approximating \(V_k = f(X_k, Z, k)\).
- Parameters:
  - X (Tensor) – The input/data sample composed of a trajectory of T_x points in a d_x-dimensional space. Shape: (B, T_x, d_x), where B is the batch size.
  - opt_step (Tensor | float | int) – The optimisation timestep(s) k at which to evaluate the regressor. Can be a single scalar or a tensor of shape (B,).
  - Z (Tensor) – The conditioning/observation variable composed of T_z points in a d_z-dimensional space. Shape: (B, T_z, d_z), where B is the batch size.
- Returns:
The output regression variable of shape (B, T_x, d_v).
- Return type:
  Tensor
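A usage sketch with dummy tensors matching the shapes above, continuing the example instantiation shown earlier:

    import torch

    B, T_x, T_z = 4, 16, 8
    X = torch.randn(B, T_x, 10)          # trajectory of T_x points in d_x dimensions
    Z = torch.randn(B, T_z, 7)           # T_z conditioning points, past -> present
    opt_step = torch.full((B,), 10.0)    # one optimisation step k per batch element

    V = model(X, opt_step, Z)            # (B, T_x, d_v) regressed vector field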