eTransformerDecoderLayer

class eTransformerDecoderLayer(in_rep, nhead, dim_feedforward=2048, dropout=0.1, activation=F.relu, layer_norm_eps=1e-05, batch_first=True, norm_first=True, norm_module='rmsnorm', bias=True, device=None, dtype=None, init_scheme='xavier_uniform')

Bases: Module

Equivariant Transformer decoder layer mirroring torch.nn.TransformerDecoderLayer.

Combines an equivariant self-attention block, an equivariant cross-attention block, and the same eLinear/eLayerNorm feed-forward structure used by the encoder, so that every submodule commutes with the group action while PyTorch's runtime logic stays intact.

The layer defines:

\[\mathbf{f}_{\mathbf{\theta}}: \mathcal{X}^{T_{\mathrm{tgt}}}\times\mathcal{X}^{T_{\mathrm{mem}}} \to \mathcal{X}^{T_{\mathrm{tgt}}}.\]

Functional equivariance constraint (assuming tgt and memory transform under the same representation):

\[\mathbf{f}_{\mathbf{\theta}}(\rho_{\mathcal{X}}(g)\mathbf{tgt},\rho_{\mathcal{X}}(g)\mathbf{mem}) = \rho_{\mathcal{X}}(g)\,\mathbf{f}_{\mathbf{\theta}}(\mathbf{tgt},\mathbf{mem}) \quad \forall g\in\mathbb{G}.\]
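The constraint can be illustrated with a minimal pure-Python toy: the two-element group acting on a 2-channel feature space by swapping channels, and a channel-wise map standing in for the layer. None of this is the library's code; it only demonstrates the identity above.

```python
# Toy illustration of f(rho(g) tgt, rho(g) mem) = rho(g) f(tgt, mem).
# Group: Z/2 acting on 2-channel tokens by swapping the channels.

def rho(seq):
    """Group action: swap the two channels of every token."""
    return [(c1, c0) for (c0, c1) in seq]

def f(tgt, mem):
    """Stand-in for the layer: add the mean memory token to each target
    token, channel-wise. This commutes with any channel permutation."""
    m0 = sum(c0 for c0, _ in mem) / len(mem)
    m1 = sum(c1 for _, c1 in mem) / len(mem)
    return [(c0 + m0, c1 + m1) for (c0, c1) in tgt]

tgt = [(1.0, 2.0), (3.0, 4.0)]
mem = [(0.0, 1.0), (2.0, 3.0), (4.0, 5.0)]

lhs = f(rho(tgt), rho(mem))  # transform the inputs first
rhs = rho(f(tgt, mem))       # transform the output instead
assert lhs == rhs            # the equivariance constraint holds
```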

Create an equivariant Transformer decoder layer.

Parameters:
  • in_rep (Representation) – Input representation \(\rho_{\text{in}}\).

  • nhead (int) – Number of attention heads.

  • dim_feedforward (int) – Hidden dimension of the feedforward network.

  • dropout (float) – Dropout probability.

  • activation (str | Callable[[Tensor], Tensor]) – Activation function: 'relu', 'gelu', or a callable.

  • layer_norm_eps (float) – Epsilon for layer normalization.

  • batch_first (bool) – If True, input/output shape is (B, T, D); otherwise (T, B, D).

  • norm_first (bool) – If True, apply normalization before attention/feedforward.

  • norm_module (Literal['layernorm', 'rmsnorm']) – Normalization layer type ('layernorm' or 'rmsnorm').

  • bias (bool) – Whether to use bias in linear layers.

  • device – Tensor device.

  • dtype – Tensor dtype.

  • init_scheme (str | None) – Initialization scheme for equivariant layers.
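The effect of norm_first can be sketched with placeholder callables. This is an assumption mirroring the control flow of torch.nn.TransformerDecoderLayer, not the library's implementation; `attn` and `norm` are hypothetical stand-ins for a sublayer and its normalization.

```python
# Pre-norm vs. post-norm sublayer ordering, as toggled by norm_first.
def sublayer(x, attn, norm, norm_first):
    if norm_first:
        # pre-norm: normalize, transform, then add the residual
        return x + attn(norm(x))
    # post-norm: add the residual first, then normalize
    return norm(x + attn(x))

# Scalar stand-ins make the difference visible:
pre = sublayer(3.0, lambda v: v + 1.0, lambda v: v * 10.0, norm_first=True)   # 3 + (3*10 + 1) = 34
post = sublayer(3.0, lambda v: v + 1.0, lambda v: v * 10.0, norm_first=False)  # (3 + 4) * 10 = 70
```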

check_equivariance(batch_size=4, tgt_len=3, mem_len=5, samples=20, atol=0.0001, rtol=0.0001)

Quick sanity check ensuring both attention blocks and the full layer are equivariant.

Parameters:
  • batch_size (int) – Batch size of the random test inputs.

  • tgt_len (int) – Target sequence length.

  • mem_len (int) – Memory sequence length.

  • samples (int) – Number of random group elements to sample.

  • atol (float) – Absolute tolerance for the comparison.

  • rtol (float) – Relative tolerance for the comparison.

Return type:

None
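What such a sanity check does can be sketched in pure Python with a toy group action and a stand-in layer (hypothetical code, not the library's implementation): draw random inputs, apply the group action before and after the map, and compare within tolerance.

```python
import random

def swap(seq):
    """Z/2 group action: swap the two channels of every token."""
    return [(c1, c0) for (c0, c1) in seq]

def layer(tgt, mem):
    """Stand-in equivariant map: add the channel-wise memory sums to each target token."""
    s0 = sum(c0 for c0, _ in mem)
    s1 = sum(c1 for _, c1 in mem)
    return [(c0 + s0, c1 + s1) for (c0, c1) in tgt]

def check(samples=20, tgt_len=3, mem_len=5, atol=1e-4):
    rnd = random.Random(0)
    for _ in range(samples):
        tgt = [(rnd.random(), rnd.random()) for _ in range(tgt_len)]
        mem = [(rnd.random(), rnd.random()) for _ in range(mem_len)]
        lhs = layer(swap(tgt), swap(mem))  # act on the inputs
        rhs = swap(layer(tgt, mem))        # act on the output
        assert all(abs(a - b) <= atol
                   for p, q in zip(lhs, rhs) for a, b in zip(p, q))

check()  # raises AssertionError if equivariance is violated
```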
forward(tgt, memory, tgt_mask=None, memory_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None, tgt_is_causal=False, memory_is_causal=False)

Pass the input through the equivariant decoder layer.

Parameters:
  • tgt (Tensor) – target/query tensor of shape (T, B, D) or (B, T, D) matching batch_first. The last dimension must equal in_rep.size.

  • memory (Tensor) – encoder memory tensor of shape (S, B, D) or (B, S, D) (same batch_first). We assume this tensor transforms under the same representation as tgt; i.e., it is typically the output of an equivariant encoder with representation in_rep.

  • tgt_mask (Tensor | None) – optional target attention mask (same semantics as PyTorch’s API).

  • memory_mask (Tensor | None) – optional memory attention mask.

  • tgt_key_padding_mask (Tensor | None) – optional padding mask for the target batch.

  • memory_key_padding_mask (Tensor | None) – optional padding mask for the memory batch.

  • tgt_is_causal (bool) – if True, applies a causal mask to the target self-attention.

  • memory_is_causal (bool) – if True, applies a causal mask to the cross-attention.

Return type:

Tensor
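The masks implied by tgt_is_causal and memory_is_causal follow PyTorch's boolean attention-mask convention, which can be sketched as:

```python
# Boolean causal mask in PyTorch's convention: entry [i][j] is True where
# query position i is NOT allowed to attend to key position j (all j > i).
def causal_mask(T):
    return [[j > i for j in range(T)] for i in range(T)]

mask = causal_mask(3)
# row 0 may attend only to position 0; row 2 may attend to all three positions
```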