eCondTransformerRegressor

class eCondTransformerRegressor(in_rep, cond_rep, out_rep, in_horizon, cond_horizon, num_layers, num_attention_heads, embedding_dim, p_drop_emb=0.1, p_drop_attn=0.1, causal_attn=False, num_cond_layers=0, norm_module='rmsnorm', init_scheme='xavier_uniform')

Bases: GenCondRegressor

Equivariant analogue of the conditional transformer regressor baseline.

This module mirrors CondTransformerRegressor while enforcing equivariance constraints.

Tokens transforming according to in_rep are embedded into a space embedding_rep, built from copies of the regular representation, so that eTransformerEncoderLayer and eTransformerDecoderLayer can be applied directly. Positional encodings and timestep embeddings are projected onto the invariant subspace of embedding_rep, so they can be added to the equivariant tokens without breaking the symmetry.
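The projection onto the invariant subspace can be illustrated with a small numpy sketch. It assumes a cyclic group C_n acting on each regular copy by cyclic shifts, and uses the group average (1/|G|) Σ_g ρ(g) as the projector; the helper names (regular_rep_cyclic, embedding_rep, project_invariant) are hypothetical and are not part of this module's API.

```python
import numpy as np


def regular_rep_cyclic(n: int, k: int) -> np.ndarray:
    """Permutation matrix of g^k in the regular representation of C_n (hypothetical helper)."""
    # Rolling the rows of the identity gives P with (P @ v) == np.roll(v, k).
    return np.roll(np.eye(n), k, axis=0)


def embedding_rep(n: int, k: int, copies: int) -> np.ndarray:
    """Block-diagonal action of g^k on `copies` stacked regular representations."""
    return np.kron(np.eye(copies), regular_rep_cyclic(n, k))


def project_invariant(v: np.ndarray, n: int, copies: int) -> np.ndarray:
    """Group average (1/|G|) * sum_g rho(g) v, the projector onto the invariant subspace."""
    return sum(embedding_rep(n, k, copies) @ v for k in range(n)) / n


n, copies = 4, 3                         # |G| = 4, embedding_dim = 12
rng = np.random.default_rng(0)
token = rng.normal(size=n * copies)      # an equivariant token
pos_enc = rng.normal(size=n * copies)    # arbitrary, non-invariant encoding
pos_inv = project_invariant(pos_enc, n, copies)

for k in range(n):
    rho = embedding_rep(n, k, copies)
    # The projected encoding is fixed by every group element...
    assert np.allclose(rho @ pos_inv, pos_inv)
    # ...so adding it to a token commutes with the group action.
    assert np.allclose(rho @ (token + pos_inv), rho @ token + pos_inv)
```

For copies of the regular representation, this projection simply replaces each length-|G| block by its mean, which is one reason embedding_dim is required to be a multiple of the group order.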

The model defines the map

\[\mathbf{f}_{\mathbf{\theta}}: \mathcal{X}^{T_x} \times \mathcal{Z}^{T_z} \times \mathbb{R} \to \mathcal{Y}^{T_x}.\]

This map satisfies the functional equivariance constraint

\[\mathbf{f}_{\mathbf{\theta}}(\rho_{\mathcal{X}}(g)\mathbf{X}_k,\, \rho_{\mathcal{Z}}(g)\mathbf{Z},\, k) = \rho_{\mathcal{Y}}(g)\,\mathbf{f}_{\mathbf{\theta}}(\mathbf{X}_k,\mathbf{Z},k), \quad \forall g\in\mathbb{G}.\]
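The constraint can be checked numerically by comparing both sides for every group element. The sketch below assumes, for simplicity, that \(\rho_{\mathcal{X}} = \rho_{\mathcal{Z}} = \rho_{\mathcal{Y}}\) is the regular representation of a cyclic group C_n acting token-wise on sequences of shape (T, d); the toy functions are hypothetical stand-ins for the regressor, not the model itself.

```python
import numpy as np


def cyclic_shift(n: int, k: int) -> np.ndarray:
    """Regular representation of g^k in C_n (a permutation matrix)."""
    return np.roll(np.eye(n), k, axis=0)


def act(rho: np.ndarray, tokens: np.ndarray) -> np.ndarray:
    """Apply rho(g) to every token of a sequence of shape (T, d)."""
    return tokens @ rho.T


def is_equivariant(f, n, X, Z, k, atol=1e-9) -> bool:
    """Check f(rho(g) X, rho(g) Z, k) == rho(g) f(X, Z, k) for all g in C_n."""
    for g in range(n):
        rho = cyclic_shift(n, g)
        lhs = f(act(rho, X), act(rho, Z), k)
        rhs = act(rho, f(X, Z, k))
        if not np.allclose(lhs, rhs, atol=atol):
            return False
    return True


def toy_equivariant(X, Z, k):
    # Token-wise linear maps commute with the permutation action;
    # the step k only enters as a scalar (invariant) coefficient.
    return X + k * Z.mean(axis=0)


def toy_broken(X, Z, k):
    # A fixed, non-invariant bias vector breaks the constraint.
    bias = np.arange(X.shape[1], dtype=float)
    return X + k * Z.mean(axis=0) + bias


n, T_x, T_z = 4, 5, 3
rng = np.random.default_rng(1)
X = rng.normal(size=(T_x, n))
Z = rng.normal(size=(T_z, n))
print(is_equivariant(toy_equivariant, n, X, Z, k=0.7))  # True
print(is_equivariant(toy_broken, n, X, Z, k=0.7))       # False
```

The failing example mirrors the motivation for projecting positional and timestep embeddings onto the invariant subspace: any additive term that is not fixed by \(\rho_{\mathcal{Y}}(g)\) breaks equivariance.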

Create an equivariant conditional transformer regressor.

Parameters:
  • in_rep (Representation) – Representation \(\rho_{\text{in}}\) of the input tokens.

  • cond_rep (Representation) – Representation \(\rho_{\text{cond}}\) of the conditioning tokens.

  • out_rep (Representation, optional) – Output representation \(\rho_{\text{out}}\). Defaults to in_rep if None.

  • in_horizon (int) – Maximum length of the input sequence.

  • cond_horizon (int) – Maximum length of the conditioning sequence.

  • num_layers (int) – Number of transformer decoder layers in the main generation trunk.

  • num_attention_heads (int) – Number of attention heads in transformer layers.

  • embedding_dim (int) – Dimension of the embedding space, built from embedding_dim / |G| copies of the regular representation. Must be a multiple of the group order |G|.

  • p_drop_emb (float) – Dropout probability for embeddings.

  • p_drop_attn (float) – Dropout probability inside attention blocks.

  • causal_attn (bool) – Whether to mask future tokens (causal masking).

  • num_cond_layers (int) – Number of transformer encoder layers for processing conditioning tokens. If 0, an eMLP is used instead.

  • norm_module (Literal['layernorm', 'rmsnorm']) – Normalization layer type ('layernorm' or 'rmsnorm').

  • init_scheme (str) – Initialization scheme for equivariant layers.
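Several of the parameters above carry constraints that can be checked up front. The following sketch is a hypothetical validation helper (check_config and causal_mask are not part of this class); it encodes only the documented rules and builds one common form of the mask that causal_attn=True implies, with True marking a blocked (future) position.

```python
import numpy as np

ALLOWED_NORMS = ("layernorm", "rmsnorm")


def check_config(embedding_dim: int, group_order: int, norm_module: str,
                 num_cond_layers: int, p_drop_emb: float, p_drop_attn: float) -> int:
    """Hypothetical check of the documented constraints; returns the number of regular copies."""
    if embedding_dim % group_order != 0:
        raise ValueError(
            f"embedding_dim={embedding_dim} is not a multiple of |G|={group_order}")
    if norm_module not in ALLOWED_NORMS:
        raise ValueError(f"norm_module must be one of {ALLOWED_NORMS}")
    if num_cond_layers < 0:
        raise ValueError("num_cond_layers must be >= 0 (0 selects the eMLP)")
    for name, p in (("p_drop_emb", p_drop_emb), ("p_drop_attn", p_drop_attn)):
        if not 0.0 <= p <= 1.0:
            raise ValueError(f"{name} must lie in [0, 1]")
    return embedding_dim // group_order


def causal_mask(horizon: int) -> np.ndarray:
    """Boolean (horizon, horizon) mask; entry (i, j) is True when j > i (a future token)."""
    return np.triu(np.ones((horizon, horizon), dtype=bool), k=1)


copies = check_config(embedding_dim=128, group_order=8, norm_module="rmsnorm",
                      num_cond_layers=0, p_drop_emb=0.1, p_drop_attn=0.1)
print(copies)  # 16
print(causal_mask(3))  # row i blocks every column j > i
```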

forward(X, opt_step, Z)

Forward pass approximating \(V_k = f(X_k, Z, k)\).

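Parameters:
  • X – Input tokens \(\mathbf{X}_k\), transforming according to in_rep.

  • opt_step – Step \(k\), encoded through the invariant timestep embedding.

  • Z – Conditioning tokens \(\mathbf{Z}\), transforming according to cond_rep.
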
get_optim_groups(weight_decay=0.001)

Todo.

Parameters:

  • weight_decay (float) – Weight decay coefficient.
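The behavior of this method is not yet documented. For orientation only, a common convention for such helpers is to apply weight decay to weight matrices while exempting biases and normalization gains; whether get_optim_groups follows it is an assumption. The sketch below is a hypothetical, framework-free version (split_decay_groups and the name-based rule are illustrative), returning dicts in the {"params": ..., "weight_decay": ...} layout that torch.optim optimizers accept as parameter groups.

```python
from typing import Dict, Iterable, List, Tuple


def split_decay_groups(named_params: Iterable[Tuple[str, object]],
                       weight_decay: float = 1e-3) -> List[Dict]:
    """Hypothetical helper: weights are decayed, biases and norm parameters are not."""
    decay, no_decay = [], []
    for name, param in named_params:
        # Name-based heuristic: bias vectors and normalization gains skip decay.
        if name.endswith("bias") or "norm" in name:
            no_decay.append(param)
        else:
            decay.append(param)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]


# Strings stand in for parameter tensors in this illustration.
named = [
    ("layers.0.attn.weight", "W_attn"),
    ("layers.0.attn.bias", "b_attn"),
    ("layers.0.norm.weight", "g_norm"),
    ("head.weight", "W_head"),
]
groups = split_decay_groups(named, weight_decay=1e-3)
print(groups[0]["params"], groups[1]["params"])
```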

reset_parameters(scheme='xavier_uniform')

Re-initialize all parameters using the given initialization scheme (see init_scheme).

Return type:

None
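The default scheme name refers to Glorot/Xavier uniform initialization, which samples weights from U(-a, a) with a = gain · sqrt(6 / (fan_in + fan_out)), the same bound used by torch.nn.init.xavier_uniform_. The sketch below applies it to a plain dense matrix; how reset_parameters maps the scheme onto the constrained parameters of the equivariant layers is not documented here, so treat this only as an illustration of the rule itself (xavier_uniform is a hypothetical helper).

```python
import numpy as np


def xavier_uniform(fan_out: int, fan_in: int, rng: np.random.Generator,
                   gain: float = 1.0) -> np.ndarray:
    """Glorot/Xavier uniform sample of shape (fan_out, fan_in) for a dense layer."""
    bound = gain * np.sqrt(6.0 / (fan_in + fan_out))
    return rng.uniform(-bound, bound, size=(fan_out, fan_in))


rng = np.random.default_rng(0)
W = xavier_uniform(64, 32, rng)
bound = np.sqrt(6.0 / (32 + 64))  # 0.25
# A uniform distribution on [-a, a] has variance a**2 / 3.
print(W.shape, np.abs(W).max() <= bound, W.var(), bound**2 / 3)
```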