eCondTransformer#

class eCondTransformer(in_rep, cond_rep, out_rep, in_horizon, cond_horizon, num_layers, num_attention_heads, embedding_dim, p_drop_emb=0.1, p_drop_attn=0.1, causal_attn=False, num_cond_layers=0, pos_encoding='additive_absolute', norm_first=True, norm_module='rmsnorm', init_scheme='xavier_uniform')[source]#

Bases: eModule, GenCondRegressor

Equivariant encoder/decoder Transformer with configurable positional attention.

Let \(A := \texttt{num\_cond\_layers}\) and \(B := \texttt{num\_layers}\). This module is the equivariant counterpart of CondTransformer: an encoder/decoder Transformer with \(A\) conditioning encoder layers and \(B\) decoder layers, following the architecture of Attention Is All You Need (Vaswani et al., NeurIPS 2017), with every learnable map constrained to respect the prescribed group actions.

The conditioning stream is assembled as

\[[k, \mathbf{z}_{-(T_z - 1)}, \ldots, \mathbf{z}_{0}],\]

where \(k\) is the inference-time optimisation step token and \(\mathbf{z}_{-(T_z - 1)}, \ldots, \mathbf{z}_{0}\) are the conditioning tokens ordered from oldest to most recent observation. The decoder predicts the target sequence

\[[\mathbf{x}_{0}, \ldots, \mathbf{x}_{T_x - 1}],\]

ordered from the first action to the last action in the predicted horizon.
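As a minimal sketch of this token layout (assuming the step token has already been embedded to the same width as the conditioning tokens), the conditioning stream can be assembled by simple concatenation:

```python
import torch

# Hypothetical shapes: z_emb holds the T_z conditioning tokens ordered from
# oldest to most recent; k_emb is the embedded optimisation-step token.
batch, T_z, d = 8, 4, 32
z_emb = torch.randn(batch, T_z, d)
k_emb = torch.randn(batch, 1, d)

# Prepend k to form the conditioning stream [k, z_{-(T_z-1)}, ..., z_0].
memory = torch.cat([k_emb, z_emb], dim=1)  # (batch, T_z + 1, d)
```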

Supported positional encodings are:

"additive_absolute"

Uses eAdditivePosMultiheadAttention. Learned absolute positions are added in the equivariant embedding space before self-attention or cross-attention.

"additive_relative"

Uses eAdditiveRelMultiheadAttention. Learned relative distance biases are injected into attention logits, preserving the time-translation structure of the sequence while keeping the feature maps equivariant.

"none"

Uses eMultiheadAttention with no explicit positional encoding.
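For intuition, the sketch below shows the additive mechanics behind "additive_absolute" in plain (non-equivariant) PyTorch: a learned position table is added to the token embeddings before attention. The equivariant module applies the same idea inside the equivariant embedding space; all names here are illustrative only.

```python
import torch
import torch.nn as nn

# Plain illustration of additive absolute positions (not the equivariant version).
d_model, max_len = 32, 16
pos_table = nn.Parameter(torch.zeros(max_len, d_model))  # learned position table

def add_positions(tokens: torch.Tensor) -> torch.Tensor:
    """Add learned absolute positions; tokens is (batch, T, d_model), T <= max_len."""
    T = tokens.shape[1]
    return tokens + pos_table[:T]
```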

Temporal assumptions:

  • \(Z\) must already be ordered in time from past to present.

  • \(X\) must already be ordered from the first predicted action to the last predicted action.

  • For "additive_relative", the last conditioning token \(\mathbf{z}_{0}\) and the first action token \(\mathbf{x}_{0}\) are both placed at time index \(0\), so cross-attention is anchored at the present time.

  • The optimisation-step token \(k\) is prepended to the conditioning memory, but it is not treated as part of the observation timeline.
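The time-index convention in the last two points can be made concrete. In this sketch (illustrative names only), the relative distances used by "additive_relative" cross-attention are computed from time indices in which \(\mathbf{z}_0\) and \(\mathbf{x}_0\) share index \(0\):

```python
import torch

T_z, T_x = 4, 3
cond_times = torch.arange(-(T_z - 1), 1)  # [-3, -2, -1, 0]: z tokens, past -> present
action_times = torch.arange(T_x)          # [0, 1, 2]: x tokens, first -> last action

# Relative distance per (query, key) pair in cross-attention; the step token
# k sits outside this timeline and is handled separately.
rel = action_times[:, None] - cond_times[None, :]  # shape (T_x, T_z)
```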

Equivariance is enforced by embedding tokens into a representation space built from copies of the regular representation, projecting the scalar step embedding onto the invariant subspace, and using equivariant encoder, decoder, normalization, and head layers throughout. The resulting conditional map satisfies

\[\mathbf{f}_{\mathbf{\theta}}(\rho_{\mathcal{X}}(g)\mathbf{X}_k,\, \rho_{\mathcal{Z}}(g)\mathbf{Z},\, k) = \rho_{\mathcal{Y}}(g)\,\mathbf{f}_{\mathbf{\theta}}(\mathbf{X}_k,\mathbf{Z},k), \qquad \forall g \in \mathbb{G}.\]
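A numerical sanity check of this property can be written as below. Here `rho_X`, `rho_Z`, and `rho_Y` are hypothetical callables applying the action of a group element `g` to the feature dimension, stand-ins for whatever the `in_rep`, `cond_rep`, and `out_rep` representations provide:

```python
import torch

def check_equivariance(model, X, Z, k, g, rho_X, rho_Z, rho_Y) -> bool:
    # f(g.X, g.Z, k) should equal g.f(X, Z, k) up to numerical tolerance.
    lhs = model(rho_X(g, X), k, rho_Z(g, Z))
    rhs = rho_Y(g, model(X, k, Z))
    return torch.allclose(lhs, rhs, atol=1e-5)
```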

Create an equivariant conditional transformer regressor.

Parameters:
  • in_rep (Representation) – Representation \(\rho_{\text{in}}\) of the input tokens.

  • cond_rep (Representation) – Representation \(\rho_{\text{cond}}\) of the conditioning tokens.

  • out_rep (Representation, optional) – Output representation \(\rho_{\text{out}}\). Defaults to in_rep if None.

  • in_horizon (int) – Maximum length of the input sequence.

  • cond_horizon (int) – Maximum length of the conditioning sequence.

  • num_layers (int) – Number of transformer decoder layers in the main generation trunk.

  • num_attention_heads (int) – Number of attention heads in transformer layers.

  • embedding_dim (int) – Dimension of the regular representation embedding space. Must be a multiple of the group order.

  • p_drop_emb (float) – Dropout probability for embeddings.

  • p_drop_attn (float) – Dropout probability inside attention blocks.

  • causal_attn (bool) – Whether to mask future tokens (causal masking).

  • num_cond_layers (int) – Number of transformer encoder layers for processing conditioning tokens. If 0, an eMLP is used instead.

  • pos_encoding (Literal['additive_absolute', 'additive_relative', 'none']) – Positional attention backend ("additive_absolute", "additive_relative", or "none").

  • norm_first (bool) – Whether to apply normalization before each residual branch.

  • norm_module (Literal['layernorm', 'rmsnorm']) – Normalization layer type ('layernorm' or 'rmsnorm').

  • init_scheme (str) – Initialization scheme for equivariant layers.
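A hypothetical construction, assuming Representation objects for the relevant group are already in hand (their exact construction depends on the surrounding symmetry library):

```python
# Illustrative values only; in_rep, cond_rep, out_rep are assumed to exist.
model = eCondTransformer(
    in_rep=in_rep,                     # rho_in for the input tokens
    cond_rep=cond_rep,                 # rho_cond for the conditioning tokens
    out_rep=out_rep,                   # rho_out; falls back to in_rep if None
    in_horizon=16,
    cond_horizon=4,
    num_layers=4,                      # B decoder layers
    num_attention_heads=4,
    embedding_dim=128,                 # must be a multiple of the group order
    num_cond_layers=2,                 # A encoder layers (0 -> eMLP instead)
    pos_encoding="additive_relative",
)
```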

forward(X, opt_step, Z)[source]#

Forward pass approximating \(V_k = f(X_k, Z, k)\).

Parameters:
  • X – Input token sequence \(\mathbf{X}_k\), ordered from the first predicted action to the last.

  • opt_step – Inference-time optimisation step \(k\).

  • Z – Conditioning token sequence \(\mathbf{Z}\), ordered from oldest to most recent observation.
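A call sketch with assumed shapes (the true feature dimensions are fixed by the chosen representations):

```python
# Hypothetical shapes: X is (batch, T_x, dim_in), Z is (batch, T_z, dim_cond),
# opt_step is the integer step index k (or a batch of step indices).
V = model(X, opt_step, Z)  # equivariant estimate of V_k = f(X_k, Z, k)
```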
get_optim_groups(weight_decay=0.001)[source]#

Return parameter groups for the optimizer, configured with the given weight-decay coefficient.

Parameters:

weight_decay (float) – Weight-decay coefficient applied to the returned parameter groups.
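A usage sketch, under the common convention that the returned groups are passed directly to a torch optimizer:

```python
import torch

optimizer = torch.optim.AdamW(
    model.get_optim_groups(weight_decay=1e-3),  # grouped parameters
    lr=3e-4,
)
```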

reset_parameters(scheme='xavier_uniform')[source]#

Re-initialize all parameters of the module.

Parameters:

scheme (str) – Initialization scheme for the equivariant layers.

Return type:

None