CondTransformer#

class CondTransformer(in_dim, out_dim, cond_dim, in_horizon, cond_horizon, pos_encoding='additive_absolute', num_layers=6, num_attention_heads=6, embedding_dim=768, p_drop_emb=0.1, p_drop_attn=0.1, causal_attn=False, num_cond_layers=0, norm_first=True, norm_module='rmsnorm', **pos_encoding_kwargs)[source]#

Bases: GenCondRegressor

Encoder/decoder Transformer with configurable positional attention.

This module is an encoder/decoder Transformer with num_cond_layers conditioning encoder layers and num_layers decoder layers, following the architecture introduced in Attention Is All You Need by Vaswani, Shazeer, Parmar, Uszkoreit, Jones, Gomez, Kaiser, and Polosukhin (NeurIPS 2017), with two task-specific changes:

  1. the conditioning memory is built from an optimisation/transport-step token together with the conditioning sequence, and

  2. the attention blocks support several positional encoding schemes.

The conditioning stream is assembled as

\[[k, \mathbf{z}_{-(T_z - 1)}, \ldots, \mathbf{z}_{0}],\]

where \(k\) is the inference-time optimisation/transport step token and \(\mathbf{z}_{-(T_z - 1)}, \ldots, \mathbf{z}_{0}\) are the conditioning tokens ordered from oldest to most recent observation. The decoder predicts the target sequence

\[[\mathbf{x}_{0}, \ldots, \mathbf{x}_{T_x - 1}],\]

ordered from the first action to the last action in the predicted horizon.
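
For example, with \(T_z = 3\) and \(T_x = 2\), the conditioning memory is \([k, \mathbf{z}_{-2}, \mathbf{z}_{-1}, \mathbf{z}_{0}]\) and the decoder predicts \([\mathbf{x}_{0}, \mathbf{x}_{1}]\).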

Supported positional encodings are:

"additive_absolute"

Uses AdditivePosMultiheadAttention. A learned table maps absolute integer positions to offsets that are added to the query and key streams before standard multi-head attention.

"additive_relative"

Uses AdditiveRelMultiheadAttention. A learned table maps relative token distances to additive score biases, yielding a time-translation-equivariant attention rule.

"rope"

Uses RoPEMultiheadAttention. Rotary position embeddings are applied per-head to the query and key projections, leaving values untouched (a generic sketch of the rotation follows this list).

"none"

Uses MultiheadAttention with no explicit positional encoding.
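
As a generic illustration of the rotary scheme (not the internals of RoPEMultiheadAttention; the function name is hypothetical), each consecutive pair of query/key channels is rotated by an angle that grows with the position index:

import torch

def rope_rotate(x, base=10000.0):
    # x: (..., T, d) with even d. Generic rotary-embedding sketch: rotate each
    # consecutive channel pair by a position-dependent angle.
    T, d = x.shape[-2], x.shape[-1]
    pos = torch.arange(T, dtype=x.dtype, device=x.device)                            # (T,)
    inv_freq = base ** (-torch.arange(0, d, 2, dtype=x.dtype, device=x.device) / d)  # (d/2,)
    ang = pos[:, None] * inv_freq[None, :]                                           # (T, d/2)
    cos, sin = ang.cos(), ang.sin()
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

Applied to the query and key projections only, the resulting attention scores depend on relative offsets between positions, while the value stream is left untouched.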

Temporal assumptions:

  • \(Z\) must already be ordered in time from past to present.

  • \(X\) must already be ordered from the first predicted action to the last predicted action.

  • For "additive_relative" and "rope", the last conditioning token \(\mathbf{z}_{0}\) and the first action token \(\mathbf{x}_{0}\) are both placed at time index \(0\), so cross-attention is anchored at the present time.

  • The optimisation-step token \(k\) is prepended to the conditioning memory, but it is not treated as part of the observation timeline.
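
A minimal sketch of preparing tensors that respect these assumptions; all names and sizes below are illustrative:

import torch

B, T_z, d_z = 1, 4, 10      # batch, observation horizon, observation dim
T_x, d_x = 16, 2            # action horizon, action dim

# Conditioning: ordered past to present, so obs_history[-1] is the most recent
# observation z_0.
obs_history = [torch.randn(B, d_z) for _ in range(T_z)]
Z = torch.stack(obs_history, dim=1)          # (B, T_z, d_z), oldest ... newest

# Target stream: first predicted action x_0 through the last action x_{T_x - 1}.
X = torch.randn(B, T_x, d_x)                 # (B, T_x, d_x)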

The architecture is implemented as follows:

  • The inference-time optimisation step \(k\) is sinusoidally embedded and prepended as the first conditioning token (a generic sketch of such an embedding follows this list).

  • The observed sequence \(Z\) is linearly embedded and concatenated after the step token.

  • When num_cond_layers > 0 the conditioning tokens pass through a positional-attention encoder; otherwise a lightweight MLP refines them.

  • A positional-attention decoder attends from the input trajectory \(X_k\) to the conditioning memory.
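
For the first step, a generic sinusoidal embedding of the scalar step \(k\) looks roughly as follows; this is a standard sketch, not necessarily the exact embedding used here, and the function name is hypothetical:

import math
import torch

def sinusoidal_step_embedding(k, dim):
    # k: (B,) tensor of optimisation/transport steps; returns (B, dim) features
    # built from sin/cos at geometrically spaced frequencies (dim assumed even).
    half = dim // 2
    freqs = torch.exp(-math.log(10000.0) * torch.arange(half, dtype=torch.float32) / half)
    ang = k.float()[:, None] * freqs[None, :]           # (B, dim/2)
    return torch.cat([ang.sin(), ang.cos()], dim=-1)    # (B, dim)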

Parameters:
  • in_dim (int) – Dimensionality of each element in \(X\).

  • out_dim (int) – Dimensionality of the regressed vector field.

  • cond_dim (int) – Dimensionality of each conditioning element in \(Z\).

  • in_horizon (int) – Maximum length of \(X\).

  • cond_horizon (int) – Maximum length of \(Z\) (excluding the optimisation-step token).

  • pos_encoding (str) – Positional encoding strategy: "additive_absolute", "additive_relative", "rope", or "none".

  • num_layers (int) – Number of Transformer decoder layers.

  • num_attention_heads (int) – Number of attention heads in Multi-Head Attention blocks.

  • embedding_dim (int) – Dimensionality of token embeddings.

  • p_drop_emb (float) – Dropout applied to embeddings.

  • p_drop_attn (float) – Dropout applied inside attention blocks.

  • causal_attn (bool) – Whether to use causal attention in self-attention and cross-attention layers.

  • num_cond_layers (int) – Number of encoder layers dedicated to conditioning tokens.

  • norm_first (bool) – Whether to apply normalization before each residual branch.

  • norm_module (str) – Final and per-layer normalization type: "layernorm" or "rmsnorm".
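
A construction example using the parameters above; the import path is an assumption and should be adapted to the actual package layout, and all sizes are illustrative:

# from <your_package> import CondTransformer

model = CondTransformer(
    in_dim=2,                  # d_x: dimensionality of each element of X
    out_dim=2,                 # d_v: dimensionality of the regressed vector field
    cond_dim=10,               # d_z: dimensionality of each element of Z
    in_horizon=16,             # T_x: maximum length of X
    cond_horizon=4,            # T_z: maximum length of Z (step token excluded)
    pos_encoding="additive_relative",
    num_layers=6,
    num_attention_heads=6,
    embedding_dim=768,
    num_cond_layers=4,         # > 0 enables the conditioning encoder
)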

configure_optimizers(learning_rate=0.0001, weight_decay=0.001, betas=(0.9, 0.95))[source]#

Create optimizer groups separating parameters that receive weight decay from those that don’t.

Parameters:

  • learning_rate (float)

  • weight_decay (float)

  • betas (tuple)
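
Typical usage, assuming the method returns a torch optimizer built from these groups (the return type is inferred from the signature, not stated above):

optimizer = model.configure_optimizers(
    learning_rate=1e-4,
    weight_decay=1e-3,
    betas=(0.9, 0.95),
)
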
forward(X, opt_step, Z)[source]#

Forward pass of the conditional Transformer regressor, approximating \(V_k = f(X_k, Z, k)\).

Parameters:
  • X (Tensor) – The input/data sample composed of a trajectory of T_x points in a d_x-dimensional space. Shape: (B, T_x, d_x), where B is the batch size.

  • opt_step (Tensor | float | int) – The optimisation timestep(s) k at which to evaluate the regressor. Can be a single scalar or a tensor of shape (B,).

  • Z (Tensor) – The conditioning/observation variable composed of T_z points in a d_z-dimensional space. Shape: (B, T_z, d_z), where B is the batch size.

Returns:

The output regression variable of shape (B, T_x, d_v), where d_v corresponds to out_dim.

Return type:

Tensor
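
A minimal call sketch with random tensors matching the shapes above, reusing the model instance constructed earlier (sizes must match the constructor arguments):

import torch

B, T_x, d_x = 8, 16, 2
T_z, d_z = 4, 10

X = torch.randn(B, T_x, d_x)        # trajectory to regress on
Z = torch.randn(B, T_z, d_z)        # conditioning observations, past to present
k = torch.randint(0, 100, (B,))     # one optimisation step per batch element

V = model(X, k, Z)                  # (B, T_x, d_v)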

get_optim_groups(weight_decay=0.001)[source]#

Create optimizer groups separating parameters that receive weight decay from those that don’t.

Parameters:

weight_decay (float)
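
Assuming the method returns the usual list of parameter-group dicts, the groups can be passed straight to a torch optimizer; the optimizer choice and hyperparameters below are illustrative:

import torch

groups = model.get_optim_groups(weight_decay=1e-3)
optimizer = torch.optim.AdamW(groups, lr=1e-4, betas=(0.9, 0.95))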