eTransformerDecoderLayer#
- class eTransformerDecoderLayer(in_rep, nhead, dim_feedforward=2048, dropout=0.1, activation=<function relu>, layer_norm_eps=1e-05, batch_first=True, norm_first=True, norm_module='rmsnorm', bias=True, device=None, dtype=None, init_scheme='xavier_uniform')[source]#
Bases: Module

Equivariant Transformer decoder layer mirroring torch.nn.TransformerDecoderLayer. Combines an equivariant self-attention block, an equivariant cross-attention block, and the same eLinear/eLayerNorm feed-forward structure used by the encoder, so every submodule commutes with the group action while keeping PyTorch’s runtime logic intact.

The layer defines:

\[\mathbf{f}_{\mathbf{\theta}}: \mathcal{X}^{T_{\mathrm{tgt}}}\times\mathcal{X}^{T_{\mathrm{mem}}} \to \mathcal{X}^{T_{\mathrm{tgt}}}.\]

Functional equivariance constraint (assuming tgt and memory transform under the same representation):

\[\mathbf{f}_{\mathbf{\theta}}(\rho_{\mathcal{X}}(g)\,\mathbf{tgt},\,\rho_{\mathcal{X}}(g)\,\mathbf{mem}) = \rho_{\mathcal{X}}(g)\,\mathbf{f}_{\mathbf{\theta}}(\mathbf{tgt},\mathbf{mem}) \quad \forall g\in\mathbb{G}.\]

Create an equivariant Transformer decoder layer.
- Parameters:
  - in_rep (Representation) – Input representation \(\rho_{\text{in}}\).
  - nhead (int) – Number of attention heads.
  - dim_feedforward (int) – Hidden dimension of the feedforward network.
  - dropout (float) – Dropout probability.
  - activation (str | Callable[[Tensor], Tensor]) – Activation function ('relu' or 'gelu').
  - layer_norm_eps (float) – Epsilon for layer normalization.
  - batch_first (bool) – If True, input/output shape is (B, T, D).
  - norm_first (bool) – If True, apply normalization before attention/feedforward.
  - norm_module (Literal['layernorm', 'rmsnorm']) – Normalization layer type ('layernorm' or 'rmsnorm').
  - bias (bool) – Whether to use bias in linear layers.
  - device – Tensor device.
  - dtype – Tensor dtype.
  - init_scheme (str | None) – Initialization scheme for equivariant layers.
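The functional equivariance constraint stated above can be verified numerically on a toy example. The sketch below uses a hypothetical stand-in map and a permutation representation acting on the feature channels (both are illustrative assumptions, not this library's classes), and checks that acting with \(\rho(g)\) before or after the map gives the same result:

```python
import numpy as np

rng = np.random.default_rng(0)
T_tgt, T_mem, D = 3, 5, 4

# Hypothetical group action: rho(g) taken to be a D x D permutation
# matrix acting on the feature dimension (an assumption for illustration).
perm = rng.permutation(D)
P = np.eye(D)[perm]

def f(tgt, mem):
    # Toy equivariant map standing in for the decoder layer: tokenwise
    # identity plus mean-pooled memory; channel-wise-shared operations
    # commute with the permutation P.
    return tgt + mem.mean(axis=0, keepdims=True)

tgt = rng.standard_normal((T_tgt, D))
mem = rng.standard_normal((T_mem, D))

# f(rho(g) tgt, rho(g) mem) should equal rho(g) f(tgt, mem).
lhs = f(tgt @ P.T, mem @ P.T)
rhs = f(tgt, mem) @ P.T
assert np.allclose(lhs, rhs)
```

The same identity is what the real layer must satisfy for every group element, with \(\rho(g)\) given by the chosen Representation rather than a bare permutation.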
- check_equivariance(batch_size=4, tgt_len=3, mem_len=5, samples=20, atol=0.0001, rtol=0.0001)[source]#
Quick sanity check ensuring both attention blocks and the full layer are equivariant.
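The idea behind such a check can be sketched as a standalone helper: draw random inputs and random group samples, and compare both sides of the equivariance identity within a tolerance. The helper below is a hypothetical analogue written against a generic map f and a sampler for \(\rho(g)\), not the library's implementation:

```python
import numpy as np

def equivariance_check(f, sample_rho, tgt_len=3, mem_len=5, dim=4,
                       samples=20, atol=1e-4):
    """Sanity-check sketch: for each drawn rho(g), verify
    f(rho(g) tgt, rho(g) mem) == rho(g) f(tgt, mem)."""
    rng = np.random.default_rng(0)
    for _ in range(samples):
        P = sample_rho(rng)                       # matrix for rho(g)
        tgt = rng.standard_normal((tgt_len, dim))
        mem = rng.standard_normal((mem_len, dim))
        lhs = f(tgt @ P.T, mem @ P.T)
        rhs = f(tgt, mem) @ P.T
        if not np.allclose(lhs, rhs, atol=atol):
            return False
    return True

# Toy equivariant map and a permutation representation on 4 channels
# (illustrative assumptions).
toy_f = lambda tgt, mem: tgt + mem.mean(axis=0, keepdims=True)
sample_perm = lambda rng: np.eye(4)[rng.permutation(4)]
ok = equivariance_check(toy_f, sample_perm)
```

The method on the class plays the same role for the layer's self-attention block, cross-attention block, and the layer as a whole, with samples, atol, and rtol controlling the number of draws and the tolerance.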
- forward(tgt, memory, tgt_mask=None, memory_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None, tgt_is_causal=False, memory_is_causal=False)[source]#
Pass the input through the equivariant decoder layer.
- Parameters:
  - tgt (Tensor) – target/query tensor of shape (T, B, D) or (B, T, D) matching batch_first. The last dimension must equal in_rep.size.
  - memory (Tensor) – encoder memory tensor of shape (S, B, D) or (B, S, D) (same batch_first). We assume this tensor transforms under the same representation as tgt; i.e., it is typically the output of an equivariant encoder with representation in_rep.
  - tgt_mask (Tensor | None) – optional target attention mask (same semantics as PyTorch’s API).
  - memory_mask (Tensor | None) – optional memory attention mask.
  - tgt_key_padding_mask (Tensor | None) – optional padding mask for the target batch.
  - memory_key_padding_mask (Tensor | None) – optional padding mask for the memory batch.
  - tgt_is_causal (bool) – if True, applies a causal mask to the target self-attention.
  - memory_is_causal (bool) – if True, applies a causal mask to the cross-attention.
- Return type:
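Since the mask arguments follow PyTorch’s semantics, the causal case corresponds to an additive float mask that is 0 on and below the diagonal and \(-\infty\) above it, so position i can only attend to positions j ≤ i. A minimal sketch of building such a mask by hand (an illustrative helper, not part of this library):

```python
import numpy as np

def causal_mask(T):
    # Additive attention mask: 0 where attention is allowed (j <= i),
    # -inf above the diagonal so position i cannot attend to j > i.
    return np.triu(np.full((T, T), -np.inf), k=1)

mask = causal_mask(4)
```

Passing tgt_is_causal=True is equivalent in effect to supplying such a mask as tgt_mask for the target self-attention.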