eCondTransformerRegressor
- class eCondTransformerRegressor(in_rep, cond_rep, out_rep, in_horizon, cond_horizon, num_layers, num_attention_heads, embedding_dim, p_drop_emb=0.1, p_drop_attn=0.1, causal_attn=False, num_cond_layers=0, norm_module='rmsnorm', init_scheme='xavier_uniform')
Bases: `GenCondRegressor`

Equivariant analogue of the conditional transformer regressor baseline.

This module mirrors `CondTransformerRegressor` while enforcing equivariance constraints. Tokens transforming according to `in_rep` are embedded into an `embedding_rep` space built from copies of the regular representation, so that `eTransformerEncoderLayer`/`eTransformerDecoderLayer` can be used directly. Positional encodings and timestep embeddings are projected onto the invariant subspace so they can be added to equivariant tokens without breaking symmetry.

The model defines:

\[\mathbf{f}_{\mathbf{\theta}}: \mathcal{X}^{T_x} \times \mathcal{Z}^{T_z} \times \mathbb{R} \to \mathcal{Y}^{T_x}.\]

Functional equivariance constraint:

\[\mathbf{f}_{\mathbf{\theta}}(\rho_{\mathcal{X}}(g)\mathbf{X}_k,\, \rho_{\mathcal{Z}}(g)\mathbf{Z},\, k) = \rho_{\mathcal{Y}}(g)\,\mathbf{f}_{\mathbf{\theta}}(\mathbf{X}_k,\mathbf{Z},k), \quad \forall g\in\mathbb{G}.\]

Create an equivariant conditional transformer regressor.
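The equivariance constraint can be checked numerically on a toy map. The sketch below is plain NumPy, not this library's API: it uses the regular representation of the cyclic group C3, which acts by cyclic permutation matrices, and an elementwise ReLU as the map — pointwise nonlinearities commute with permutation representations, so the constraint holds exactly.

```python
import numpy as np

def regular_rep(n, k):
    """Matrix of the k-th element of the cyclic group C_n acting on the
    regular representation: a cyclic shift of the n channels by k."""
    return np.roll(np.eye(n), k, axis=0)

def f(X):
    """Toy equivariant map: an elementwise ReLU commutes with any
    permutation representation, so f(rho(g) x) == rho(g) f(x)."""
    return np.maximum(X, 0.0)

n = 3
rng = np.random.default_rng(0)
X = rng.normal(size=(5, n))  # 5 tokens, each transforming under the regular rep

for k in range(n):
    rho = regular_rep(n, k)
    lhs = f(X @ rho.T)   # transform each token, then apply the map
    rhs = f(X) @ rho.T   # apply the map, then transform the output
    assert np.allclose(lhs, rhs)
```

A real equivariant transformer layer composes many such commuting operations (equivariant linear maps, pointwise nonlinearities on regular-representation channels), which is why the constraint holds for the full model.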
- Parameters:
  - in_rep (`Representation`) – Representation \(\rho_{\text{in}}\) of the input tokens.
  - cond_rep (`Representation`) – Representation \(\rho_{\text{cond}}\) of the conditioning tokens.
  - out_rep (`Representation`, optional) – Output representation \(\rho_{\text{out}}\). Defaults to `in_rep` if `None`.
  - in_horizon (`int`) – Maximum length of the input sequence.
  - cond_horizon (`int`) – Maximum length of the conditioning sequence.
  - num_layers (`int`) – Number of transformer decoder layers in the main generation trunk.
  - num_attention_heads (`int`) – Number of attention heads in the transformer layers.
  - embedding_dim (`int`) – Dimension of the regular-representation embedding space. Must be a multiple of the group order.
  - p_drop_emb (`float`) – Dropout probability for embeddings.
  - p_drop_attn (`float`) – Dropout probability inside attention blocks.
  - causal_attn (`bool`) – Whether to mask future tokens (causal masking).
  - num_cond_layers (`int`) – Number of transformer encoder layers for processing conditioning tokens. If 0, an `eMLP` is used instead.
  - norm_module (`Literal['layernorm', 'rmsnorm']`) – Normalization layer type (`'layernorm'` or `'rmsnorm'`).
  - init_scheme (`str`) – Initialization scheme for equivariant layers.
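The description above notes that positional encodings and timestep embeddings are projected onto the invariant subspace before being added to equivariant tokens. A minimal NumPy sketch of that idea, using group averaging over the regular representation of C2 (none of these names come from the library):

```python
import numpy as np

# Regular representation of C2: the identity and the swap of the two channels.
rhos = [np.eye(2), np.array([[0.0, 1.0], [1.0, 0.0]])]

# Projector onto the invariant subspace: the average of rho(g) over the group.
P = sum(rhos) / len(rhos)

pe = np.array([0.3, -1.2])  # a raw positional-encoding vector
pe_inv = P @ pe             # its invariant component

# The projected encoding is fixed by every group element ...
for rho in rhos:
    assert np.allclose(rho @ pe_inv, pe_inv)

# ... so adding it to an equivariant token commutes with the group action:
x = np.array([1.0, 2.0])
for rho in rhos:
    assert np.allclose(rho @ (x + pe_inv), rho @ x + pe_inv)
```

Adding a raw (non-invariant) encoding would break the second identity, which is why the projection is required before the additive positional and timestep embeddings.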