eTransformerDecoderLayer#
- class eTransformerDecoderLayer(in_rep, self_attn, multihead_attn, dim_feedforward=2048, dropout=0.1, activation=GELU(approximate='none'), layer_norm_eps=1e-05, norm_first=True, norm_module='rmsnorm', bias=True, init_scheme='xavier_uniform')[source]#
Bases: eModule

Equivariant Transformer decoder layer mirroring torch.nn.TransformerDecoderLayer. Combines an equivariant self-attention block (plain eMultiheadAttention or any caller-provided PositionalAttentionBase backend), an equivariant cross-attention block, and the same eLinear / equivariant normalization (eRMSNorm or eLayerNorm) feed-forward structure used by the encoder, so every submodule commutes with the group action while keeping PyTorch's runtime logic intact.

The layer defines:

\[\mathbf{f}_{\mathbf{\theta}}: \mathcal{X}^{T_{\mathrm{tgt}}}\times\mathcal{X}^{T_{\mathrm{mem}}} \to \mathcal{X}^{T_{\mathrm{tgt}}}.\]

Functional equivariance constraint (assuming tgt and memory transform under the same representation):

\[\mathbf{f}_{\mathbf{\theta}}(\rho_{\mathcal{X}}(g)\,\mathbf{tgt},\rho_{\mathcal{X}}(g)\,\mathbf{mem}) = \rho_{\mathcal{X}}(g)\,\mathbf{f}_{\mathbf{\theta}}(\mathbf{tgt},\mathbf{mem}) \quad \forall g\in\mathbb{G}.\]

Create an equivariant Transformer decoder layer.
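The constraint above can be checked numerically on a toy stand-in for the decoder map. The group, representation, and map `f` below are illustrative assumptions (a channel permutation and a pointwise-plus-pooling function), not the library's actual classes; they only show the shape of the identity being enforced.

```python
import numpy as np

rng = np.random.default_rng(0)
B, T_tgt, T_mem, D = 2, 3, 5, 4

# rho(g): a channel-permutation representation acting on the last dimension.
perm = rng.permutation(D)

def rho(x):
    return x[..., perm]

# Toy decoder-like map f(tgt, mem): a pointwise nonlinearity on the target
# plus a uniform (attention-like) pooling over memory tokens. Both act
# identically on every channel, so f commutes with channel permutations.
def f(tgt, mem):
    return np.maximum(tgt, 0.0) + mem.mean(axis=1, keepdims=True)

tgt = rng.standard_normal((B, T_tgt, D))
mem = rng.standard_normal((B, T_mem, D))

lhs = f(rho(tgt), rho(mem))   # transform inputs, then apply f
rhs = rho(f(tgt, mem))        # apply f, then transform the output
assert np.allclose(lhs, rhs)  # the equivariance constraint holds
```

An equivariant layer must satisfy this identity for every group element, not just one sampled permutation; the library's check_equivariance method performs exactly this comparison over many samples.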
- Parameters:
  - in_rep (Representation) – Input representation \(\rho_{\text{in}}\).
  - self_attn (eMultiheadAttention | PositionalAttentionBase) – Pre-built target self-attention module.
  - multihead_attn (eMultiheadAttention | PositionalAttentionBase) – Pre-built target-to-memory attention module.
  - dim_feedforward (int) – Hidden dimension of the feedforward network.
  - dropout (float) – Dropout probability.
  - activation (Module) – Activation module. Default: torch.nn.GELU().
  - layer_norm_eps (float) – Epsilon for layer normalization.
  - norm_first (bool) – If True, apply normalization before attention/feedforward.
  - norm_module (Literal['layernorm', 'rmsnorm']) – Normalization layer type ('layernorm' or 'rmsnorm').
  - bias (bool) – Whether to use bias in linear layers.
  - device – Tensor device.
  - dtype – Tensor dtype.
  - init_scheme (str | None) – Initialization scheme for equivariant layers.
- check_equivariance(batch_size=4, tgt_len=3, mem_len=5, samples=20, atol=0.0001, rtol=0.0001)[source]#
Quick sanity check ensuring both attention blocks and the full layer are equivariant.
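The sampling logic behind such a check can be sketched generically: draw random group elements and random inputs, and compare the two sides of the equivariance constraint within a tolerance. The helper below is an illustrative reimplementation under assumed interfaces (`sample_g`, `rho_fn`), not the library's internal code.

```python
import numpy as np

def check_equivariance(f, rho_fn, sample_g, batch_size=4, tgt_len=3,
                       mem_len=5, dim=4, samples=20, atol=1e-4, rtol=1e-4):
    """Sample random group elements and inputs, then compare
    f(rho(g) tgt, rho(g) mem) against rho(g) f(tgt, mem)."""
    rng = np.random.default_rng(0)
    for _ in range(samples):
        g = sample_g(rng)
        tgt = rng.standard_normal((batch_size, tgt_len, dim))
        mem = rng.standard_normal((batch_size, mem_len, dim))
        lhs = f(rho_fn(g, tgt), rho_fn(g, mem))
        rhs = rho_fn(g, f(tgt, mem))
        if not np.allclose(lhs, rhs, atol=atol, rtol=rtol):
            return False
    return True

# Example: channel permutations with a permutation-equivariant toy map.
sample_g = lambda rng: rng.permutation(4)          # random group element
rho_fn = lambda g, x: x[..., g]                    # its action on features
f = lambda tgt, mem: np.tanh(tgt) + mem.mean(axis=1, keepdims=True)
assert check_equivariance(f, rho_fn, sample_g)
```

The defaults (batch_size=4, tgt_len=3, mem_len=5, samples=20) mirror the method signature above; the tolerances absorb floating-point noise from the two evaluation orders.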
- forward(tgt, memory, tgt_mask=None, memory_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None, tgt_is_causal=False, memory_is_causal=False, *, tgt_positions=None, memory_positions=None, tgt_position_mask=None, memory_position_mask=None)[source]#
Pass the input through the equivariant decoder layer.
- Parameters:
  - tgt (Tensor) – Target/query tensor of shape (B, T, D). The last dimension must equal in_rep.size.
  - memory (Tensor) – Encoder memory tensor of shape (B, S, D). This tensor is assumed to transform under the same representation as tgt; i.e., it is typically the output of an equivariant encoder with representation in_rep.
  - tgt_mask (Tensor | None) – Optional target attention mask (same semantics as PyTorch's API).
  - memory_mask (Tensor | None) – Optional memory attention mask.
  - tgt_key_padding_mask (Tensor | None) – Optional padding mask for the target batch.
  - memory_key_padding_mask (Tensor | None) – Optional padding mask for the memory batch.
  - tgt_is_causal (bool) – If True, applies a causal mask to the target self-attention.
  - memory_is_causal (bool) – If True, applies a causal mask to the cross-attention.
  - tgt_positions (Tensor | None) – Optional target-token positions for positional attention backends.
  - memory_positions (Tensor | None) – Optional memory-token positions for positional attention backends.
  - tgt_position_mask (Tensor | None) – Optional boolean mask for valid target positions.
  - memory_position_mask (Tensor | None) – Optional boolean mask for valid memory positions.
- Return type: