eTransformerDecoderLayer

class eTransformerDecoderLayer(in_rep, self_attn, multihead_attn, dim_feedforward=2048, dropout=0.1, activation=GELU(approximate='none'), layer_norm_eps=1e-05, norm_first=True, norm_module='rmsnorm', bias=True, init_scheme='xavier_uniform')

Bases: eModule

Equivariant Transformer decoder layer mirroring torch.nn.TransformerDecoderLayer.

Combines an equivariant self-attention block (a plain eMultiheadAttention or any caller-provided PositionalAttentionBase backend), an equivariant cross-attention block, and the same eLinear feed-forward network and equivariant normalization (eRMSNorm or eLayerNorm) used by the encoder. Every submodule commutes with the group action, while the runtime logic of PyTorch’s layer is kept intact.
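Why normalization can sit inside an equivariant block: if the representation is orthogonal, the RMS of a feature vector is invariant under the group action, so dividing by it commutes with ρ(g). The following is a minimal pure-Python sketch of that argument, assuming an orthogonal representation and plain RMS normalization without a learned per-channel gain. The library’s eRMSNorm may handle its affine parameters differently, and the helper names are hypothetical.

```python
import math
import random


def rms_normalize(x, eps=1e-5):
    """Plain RMS normalization of one feature vector (no learned gain; an assumption)."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms for v in x]


def apply_rep(rep, x):
    """Apply a representation matrix (list of rows) to a feature vector."""
    return [sum(r * v for r, v in zip(row, x)) for row in rep]


# Hypothetical orthogonal representation of one group element: a 90-degree
# rotation on the first two channels and the identity on the third.
rho_g = [
    [0.0, -1.0, 0.0],
    [1.0, 0.0, 0.0],
    [0.0, 0.0, 1.0],
]

random.seed(0)
x = [random.gauss(0.0, 1.0) for _ in range(3)]
lhs = rms_normalize(apply_rep(rho_g, x))  # normalize the transformed input
rhs = apply_rep(rho_g, rms_normalize(x))  # transform the normalized input
print(max(abs(a - b) for a, b in zip(lhs, rhs)))
```

Because the norm of ρ(g)x equals the norm of x, both orders of operation agree up to floating-point rounding.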

The layer defines:

\[\mathbf{f}_{\mathbf{\theta}}: \mathcal{X}^{T_{\mathrm{tgt}}}\times\mathcal{X}^{T_{\mathrm{mem}}} \to \mathcal{X}^{T_{\mathrm{tgt}}}.\]

Functional equivariance constraint (assuming tgt and memory transform under the same representation):

\[\mathbf{f}_{\mathbf{\theta}}(\rho_{\mathcal{X}}(g)\mathbf{tgt},\rho_{\mathcal{X}}(g)\mathbf{mem}) = \rho_{\mathcal{X}}(g)\,\mathbf{f}_{\mathbf{\theta}}(\mathbf{tgt},\mathbf{mem}) \quad \forall g\in\mathbb{G}.\]
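To see why attention can satisfy this constraint, suppose ρ is orthogonal and the query/key/value projections commute with ρ(g). Then every score ⟨q, k⟩ is invariant, so the attention weights do not change, and the values carry the transformation through to the output. The sketch below illustrates this under those assumptions. It uses identity projections and a channel-swap representation in place of the library’s eLinear layers, and it is not the library’s implementation.

```python
import math
import random


def apply_rep(rep, x):
    """Apply a representation matrix (list of rows) to a feature vector."""
    return [sum(r * v for r, v in zip(row, x)) for row in rep]


def softmax(scores):
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def toy_cross_attention(tgt, mem):
    """Residual single-head attention with identity projections (hypothetical toy):
    out_t = tgt_t + sum_s softmax_s(<tgt_t, mem_s> / sqrt(D)) * mem_s
    """
    d = len(tgt[0])
    out = []
    for q in tgt:
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in mem]
        weights = softmax(scores)
        attended = [sum(w * k[i] for w, k in zip(weights, mem)) for i in range(d)]
        out.append([qi + ai for qi, ai in zip(q, attended)])
    return out


# Orthogonal representation: swap channels 0 and 1 (a permutation rep of C2).
rho_g = [[0.0, 1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]

random.seed(1)
tgt = [[random.gauss(0, 1) for _ in range(3)] for _ in range(4)]  # T_tgt = 4
mem = [[random.gauss(0, 1) for _ in range(3)] for _ in range(6)]  # T_mem = 6

# f(rho(g) tgt, rho(g) mem) versus rho(g) f(tgt, mem)
lhs = toy_cross_attention([apply_rep(rho_g, t) for t in tgt],
                          [apply_rep(rho_g, m) for m in mem])
rhs = [apply_rep(rho_g, row) for row in toy_cross_attention(tgt, mem)]
max_err = max(abs(a - b) for ra, rb in zip(lhs, rhs) for a, b in zip(ra, rb))
print(max_err)
```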

Create an equivariant Transformer decoder layer.

Parameters:
  • in_rep (Representation) – Input representation \(\rho_{\text{in}}\).

  • self_attn (eMultiheadAttention | PositionalAttentionBase) – Pre-built target self-attention module.

  • multihead_attn (eMultiheadAttention | PositionalAttentionBase) – Pre-built target-to-memory attention module.

  • dim_feedforward (int) – Hidden dimension of the feedforward network.

  • dropout (float) – Dropout probability.

  • activation (Module) – Activation module. Default: torch.nn.GELU().

  • layer_norm_eps (float) – Epsilon for numerical stability in the normalization layers (eRMSNorm or eLayerNorm).

  • norm_first (bool) – If True, apply normalization before the attention and feed-forward blocks (pre-norm); otherwise, apply it after each residual addition (post-norm).

  • norm_module (Literal['layernorm', 'rmsnorm']) – Normalization layer type ('layernorm' or 'rmsnorm').

  • bias (bool) – Whether to use bias in linear layers.

  • device – Tensor device.

  • dtype – Tensor dtype.

  • init_scheme (str | None) – Initialization scheme for equivariant layers.
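The norm_first flag chooses between the pre-norm and post-norm orderings of torch.nn.TransformerDecoderLayer. Here is a minimal sketch of that ordering, assuming, as the description above states, that this layer keeps PyTorch’s runtime logic. The sublayers are passed in as plain callables with hypothetical names, and masks and dropout are left out. The scalar usage only traces the order of operations.

```python
def decoder_layer_step(x, memory, sa_block, mha_block, ff_block,
                       norm1, norm2, norm3, norm_first=True):
    """Residual/normalization ordering of torch.nn.TransformerDecoderLayer.

    sa_block(x): target self-attention; mha_block(x, memory): cross-attention;
    ff_block(x): feed-forward network; norm1..norm3: normalization layers.
    """
    if norm_first:
        # Pre-norm: normalize each sublayer input, then add the residual.
        # Note that memory itself is not normalized here.
        x = x + sa_block(norm1(x))
        x = x + mha_block(norm2(x), memory)
        x = x + ff_block(norm3(x))
    else:
        # Post-norm: add the residual first, then normalize the sum.
        x = norm1(x + sa_block(x))
        x = norm2(x + mha_block(x, memory))
        x = norm3(x + ff_block(x))
    return x


# Scalar stand-ins that make the order of operations visible.
sa = lambda v: 2 * v
mha = lambda v, m: v + m
ff = lambda v: v - 1
half = lambda v: v / 2

pre = decoder_layer_step(1.0, 10.0, sa, mha, ff, half, half, half, norm_first=True)
post = decoder_layer_step(1.0, 10.0, sa, mha, ff, half, half, half, norm_first=False)
print(pre, post)  # 18.5 6.0
```

In the pre-norm path, the residual stream is never normalized directly, which is the usual reason pre-norm is considered easier to train in deep stacks.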

check_equivariance(batch_size=4, tgt_len=3, mem_len=5, samples=20, atol=0.0001, rtol=0.0001)

Quick numerical sanity check that both attention blocks and the full layer are equivariant.

Return type:

None
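The signature points to the usual numerical test: sample random inputs and, for each group element g, compare f(ρ(g)x) with ρ(g)f(x) within atol/rtol. Below is a minimal single-vector sketch of that idea. The helper names, the reading of samples as "number of random inputs", and the torch.allclose-style tolerance rule are all assumptions. The library method checks both attention blocks and the full layer on sequence inputs, and this sketch does not reproduce that.

```python
import random


def allclose(a, b, atol, rtol):
    """Elementwise |a - b| <= atol + rtol * |b| (the torch.allclose criterion)."""
    return all(abs(x - y) <= atol + rtol * abs(y) for x, y in zip(a, b))


def apply_rep(rep, x):
    """Apply a representation matrix (list of rows) to a feature vector."""
    return [sum(r * v for r, v in zip(row, x)) for row in rep]


def check_equivariance_sketch(f, group_reps, dim, samples=20,
                              atol=1e-4, rtol=1e-4, seed=0):
    """Return True if f(rho(g) x) matches rho(g) f(x) for random x and every g."""
    rng = random.Random(seed)
    for _ in range(samples):
        x = [rng.gauss(0.0, 1.0) for _ in range(dim)]
        fx = f(x)
        for rep in group_reps:
            if not allclose(f(apply_rep(rep, x)), apply_rep(rep, fx), atol, rtol):
                return False
    return True


# Group C2 acting on R^2 by swapping the two channels.
identity = [[1.0, 0.0], [0.0, 1.0]]
swap = [[0.0, 1.0], [1.0, 0.0]]

equivariant_f = lambda x: [x[0] ** 3 + x[1], x[1] ** 3 + x[0]]
broken_f = lambda x: [x[0], 0.0]

ok = check_equivariance_sketch(equivariant_f, [identity, swap], dim=2)
bad = check_equivariance_sketch(broken_f, [identity, swap], dim=2)
print(ok, bad)  # True False
```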

forward(tgt, memory, tgt_mask=None, memory_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None, tgt_is_causal=False, memory_is_causal=False, *, tgt_positions=None, memory_positions=None, tgt_position_mask=None, memory_position_mask=None)

Pass the input through the equivariant decoder layer.

Parameters:
  • tgt (Tensor) – Target/query tensor of shape (B, T, D). The last dimension must equal in_rep.size.

  • memory (Tensor) – Encoder memory tensor of shape (B, S, D). It is assumed to transform under the same representation as tgt, i.e., it is typically the output of an equivariant encoder with representation in_rep.

  • tgt_mask (Tensor | None) – Optional target self-attention mask (same semantics as PyTorch’s API).

  • memory_mask (Tensor | None) – Optional memory (cross-attention) mask.

  • tgt_key_padding_mask (Tensor | None) – Optional padding mask for the target batch.

  • memory_key_padding_mask (Tensor | None) – Optional padding mask for the memory batch.

  • tgt_is_causal (bool) – If True, applies a causal mask to the target self-attention.

  • memory_is_causal (bool) – If True, applies a causal mask to the cross-attention.

  • tgt_positions (Tensor | None) – Optional target-token positions for positional attention backends.

  • memory_positions (Tensor | None) – Optional memory-token positions for positional attention backends.

  • tgt_position_mask (Tensor | None) – Optional boolean mask for valid target positions.

  • memory_position_mask (Tensor | None) – Optional boolean mask for valid memory positions.

Return type:

Tensor
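tgt_mask follows PyTorch’s attention-mask semantics, and tgt_is_causal mirrors the PyTorch flag. The following pure-Python sketch builds a square causal mask in both PyTorch conventions, assuming the same conventions carry over to this layer. In the boolean form, True marks a position the query may not attend to. The float form, with -inf above the diagonal and 0.0 elsewhere, matches what torch.nn.Transformer.generate_square_subsequent_mask returns. The function names are hypothetical.

```python
def causal_bool_mask(size):
    """Boolean causal mask: entry [i][j] is True when query i may NOT attend to key j."""
    return [[j > i for j in range(size)] for i in range(size)]


def causal_float_mask(size):
    """Additive causal mask: -inf blocks future keys, 0.0 leaves scores unchanged."""
    return [[float("-inf") if j > i else 0.0 for j in range(size)] for i in range(size)]


for row in causal_bool_mask(3):
    print(row)
# [False, True, True]
# [False, False, True]
# [False, False, False]
```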