TransformerDecoderLayer

class TransformerDecoderLayer(d_model, self_attn, multihead_attn, dim_feedforward=2048, dropout=0.1, activation=GELU(approximate='none'), layer_norm_eps=1e-05, norm_first=False, norm_module='rmsnorm', bias=True)

Bases: Module

Transformer decoder layer with optional positional self- and cross-attention.

Given a target sequence \(\mathbf{X} \in \mathbb{R}^{B \times T_t \times D}\) and a memory sequence \(\mathbf{M} \in \mathbb{R}^{B \times T_m \times D}\), the layer computes:

\[\begin{split}\mathbf{X}' &= \mathbf{X} + \operatorname{Dropout}\!\left( \operatorname{SelfAttn}(\mathbf{X}, \mathbf{X}, \mathbf{X})\right), \\ \mathbf{X}'' &= \mathbf{X}' + \operatorname{Dropout}\!\left( \operatorname{CrossAttn}(\mathbf{X}', \mathbf{M}, \mathbf{M})\right), \\ \mathbf{Y} &= \mathbf{X}'' + \operatorname{FFN}(\mathbf{X}''),\end{split}\]

with normalization (RMSNorm by default, or LayerNorm; see norm_module) applied either before each residual branch (norm_first=True, pre-norm) or after each residual addition (norm_first=False, post-norm). Each attention operator is either a torch.nn.MultiheadAttention or an instance of a PositionalAttentionBase subclass.
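The equations above leave the normalization step out. The sketch below spells out where the three norms sit in each mode, assuming the same ordering as PyTorch's own torch.nn.TransformerDecoderLayer (memory is not normalized inside the layer) and assuming a plain Linear, GELU, Linear feed-forward branch. RMSNormSketch and decoder_block are hypothetical helpers written for illustration; they are not part of this library.

```python
import torch
from torch import nn


class RMSNormSketch(nn.Module):
    """Hypothetical RMSNorm: x / sqrt(mean(x**2) + eps) * weight, no bias."""

    def __init__(self, d: int, eps: float = 1e-5) -> None:
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight


def decoder_block(x, m, self_attn, cross_attn, ffn, norms, drop, norm_first):
    """Hypothetical helper: one decoder layer in pre-norm or post-norm order."""
    n1, n2, n3 = norms
    if norm_first:
        # Pre-norm: normalize each branch input; the residual stream stays raw.
        h = n1(x)
        x = x + drop(self_attn(h, h, h, need_weights=False)[0])
        h = n2(x)
        x = x + drop(cross_attn(h, m, m, need_weights=False)[0])
        x = x + ffn(n3(x))
    else:
        # Post-norm: add each branch output to the residual, then normalize the sum.
        x = n1(x + drop(self_attn(x, x, x, need_weights=False)[0]))
        x = n2(x + drop(cross_attn(x, m, m, need_weights=False)[0]))
        x = n3(x + ffn(x))
    return x


torch.manual_seed(0)
B, T_t, T_m, D, H = 2, 5, 7, 16, 4
self_attn = nn.MultiheadAttention(D, H, batch_first=True)
cross_attn = nn.MultiheadAttention(D, H, batch_first=True)
ffn = nn.Sequential(nn.Linear(D, 32), nn.GELU(), nn.Linear(32, D))
norms = (RMSNormSketch(D), RMSNormSketch(D), RMSNormSketch(D))
drop = nn.Dropout(0.0)  # dropout disabled so the two runs are deterministic

x, m = torch.randn(B, T_t, D), torch.randn(B, T_m, D)
y_pre = decoder_block(x, m, self_attn, cross_attn, ffn, norms, drop, norm_first=True)
y_post = decoder_block(x, m, self_attn, cross_attn, ffn, norms, drop, norm_first=False)
print(y_pre.shape, y_post.shape)  # torch.Size([2, 5, 16]) torch.Size([2, 5, 16])
```

In post-norm mode the final norm leaves every output token with a mean square close to 1, whereas pre-norm returns the un-normalized residual stream; pre-norm stacks therefore typically apply one extra norm after the last layer.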

Attributes:

self_attn:

Attention module used for masked target self-attention.

multihead_attn:

Attention module used for target-to-memory cross-attention.

feed_forward_block:

Sequential feed-forward network applied after cross-attention.

norm1, norm2, norm3:

Normalization layers (RMSNorm or LayerNorm) applied around the residual branches.

norm_first:

Whether normalization is applied before each residual branch.
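The two normalization choices differ mainly in whether the mean is removed. As a minimal sketch, with rms_norm as a hypothetical stand-in (the library's own RMSNorm may differ, for example in whether it carries a bias), compared against torch.nn.LayerNorm at its initial weight=1, bias=0:

```python
import torch
from torch import nn


def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # Rescale by the root mean square; no mean subtraction and no bias.
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight


x = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
layer_norm = nn.LayerNorm(4, eps=1e-5)  # weight=1, bias=0 at initialization
ln_out = layer_norm(x)
rms_out = rms_norm(x, torch.ones(4))
print(ln_out)   # zero mean, unit variance
print(rms_out)  # unit mean square, but the mean is not removed

# For zero-mean inputs the two coincide (LayerNorm's variance equals the mean square).
x0 = x - x.mean(-1, keepdim=True)
print(torch.allclose(layer_norm(x0), rms_norm(x0, torch.ones(4)), atol=1e-5))  # True
```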

Initialize the decoder layer.

Parameters:
  • d_model (int) – Model width D.

  • self_attn (torch.nn.MultiheadAttention | PositionalAttentionBase) – Masked self-attention module.

  • multihead_attn (torch.nn.MultiheadAttention | PositionalAttentionBase) – Cross-attention module.

  • dim_feedforward (int) – Width of the feed-forward hidden layer. Default: 2048.

  • dropout (float) – Dropout probability. Default: 0.1.

  • activation (torch.nn.Module) – Feed-forward activation. Default: torch.nn.GELU().

  • layer_norm_eps (float) – Epsilon of the normalization layers (LayerNorm or RMSNorm). Default: 1e-5.

  • norm_first (bool) – If True, apply normalization before each residual branch (pre-norm); otherwise apply it after each residual addition (post-norm). Default: False.

  • norm_module (Literal['layernorm', 'rmsnorm']) – Normalization layer type. Default: 'rmsnorm'.

  • bias (bool) – If True, use learnable normalization biases. Default: True.

  • device (torch.device, optional) – Device on which parameters are created (factory keyword argument).

  • dtype (torch.dtype, optional) – Data type of created parameters (factory keyword argument).
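The layer receives its attention modules already constructed. Because the documented shapes are batch-first, (B, T, D), a plain torch.nn.MultiheadAttention passed in would presumably need embed_dim equal to d_model and batch_first=True (its default is False). This sketch checks that such modules produce the shapes the layer expects; the commented constructor call is illustrative only, since the import path is not given on this page.

```python
import torch
from torch import nn

d_model, nhead = 64, 8  # embed_dim must be divisible by num_heads
B, T_t, T_m = 2, 5, 9

# embed_dim equals d_model; batch_first=True matches the documented (B, T, D) layout.
self_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)
multihead_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)

# Illustrative constructor call (the import path is not shown on this page):
# layer = TransformerDecoderLayer(d_model, self_attn, multihead_attn,
#                                 dim_feedforward=256, norm_first=True,
#                                 norm_module="layernorm")

tgt = torch.randn(B, T_t, d_model)
memory = torch.randn(B, T_m, d_model)
sa_out, _ = self_attn(tgt, tgt, tgt, need_weights=False)
ca_out, _ = multihead_attn(tgt, memory, memory, need_weights=False)
# Cross-attention keeps the target length T_t, not the memory length T_m.
print(sa_out.shape, ca_out.shape)  # torch.Size([2, 5, 64]) torch.Size([2, 5, 64])
```

A PositionalAttentionBase subclass would be passed in the same slots; its constructor is not documented here, so it is not shown.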

forward(tgt, memory, tgt_mask=None, memory_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None, tgt_is_causal=False, memory_is_causal=False, *, tgt_positions=None, memory_positions=None, tgt_position_mask=None, memory_position_mask=None)

Apply the decoder layer to target and memory sequences.

Parameters:
  • tgt (torch.Tensor) – Target sequence.

  • memory (torch.Tensor) – Memory sequence from the encoder.

  • tgt_mask (torch.Tensor, optional) – Additive attention mask for target self-attention.

  • memory_mask (torch.Tensor, optional) – Additive attention mask for cross-attention from target queries to memory keys.

  • tgt_key_padding_mask (torch.Tensor, optional) – Boolean mask for padded target elements.

  • memory_key_padding_mask (torch.Tensor, optional) – Boolean mask for padded memory elements.

  • tgt_is_causal (bool) – Whether to apply a causal attention mask to the target. Default: False.

  • memory_is_causal (bool) – Whether to apply a causal attention mask to the memory. Default: False.

  • tgt_positions (torch.Tensor, optional) – Absolute positions for the target sequence.

  • memory_positions (torch.Tensor, optional) – Absolute positions for the memory sequence.

  • tgt_position_mask (torch.Tensor, optional) – Boolean mask for padded target positions.

  • memory_position_mask (torch.Tensor, optional) – Boolean mask for padded memory positions.

Returns:

The decoded target sequence.

Return type:

torch.Tensor
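The mask arguments can be built with ordinary tensor ops. This sketch assumes the conventions of torch.nn.MultiheadAttention (additive float masks with -inf at blocked entries; boolean key-padding masks with True marking padded keys) and checks their effect on plain MultiheadAttention modules rather than on this layer. In PyTorch's own decoder layer, tgt_is_causal is only a hint that tgt_mask is causal; whether this layer follows that convention is not stated here, so the sketch builds the causal mask explicitly.

```python
import torch
from torch import nn

torch.manual_seed(0)
B, T_t, T_m, D, H = 2, 4, 6, 8, 2

# Additive causal mask for target self-attention: 0 where attention is allowed,
# -inf above the diagonal so position i cannot attend to positions j > i.
tgt_mask = torch.triu(torch.full((T_t, T_t), float("-inf")), diagonal=1)

# Boolean key-padding mask for the memory (assumed convention: True = padded).
memory_lengths = torch.tensor([6, 4])
memory_key_padding_mask = torch.arange(T_m).unsqueeze(0) >= memory_lengths.unsqueeze(1)

# Causality check: perturbing the last target token must not change earlier outputs.
self_attn = nn.MultiheadAttention(D, H, batch_first=True)
x = torch.randn(B, T_t, D)
x2 = x.clone()
x2[:, -1] += 1.0
y1, _ = self_attn(x, x, x, attn_mask=tgt_mask, need_weights=False)
y2, _ = self_attn(x2, x2, x2, attn_mask=tgt_mask, need_weights=False)

# Padding check: perturbing padded memory slots must not change cross-attention output.
cross_attn = nn.MultiheadAttention(D, H, batch_first=True)
mem = torch.randn(B, T_m, D)
mem2 = mem.clone()
mem2[1, 4:] += 1.0  # only touches the padded slots of sample 1
c1, _ = cross_attn(x, mem, mem, key_padding_mask=memory_key_padding_mask, need_weights=False)
c2, _ = cross_attn(x, mem2, mem2, key_padding_mask=memory_key_padding_mask, need_weights=False)

print(torch.allclose(y1[:, :-1], y2[:, :-1], atol=1e-5))  # True
print(torch.allclose(c1, c2, atol=1e-5))  # True
```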

Shape

  • tgt: (B, T_t, D).

  • memory: (B, T_m, D).

  • tgt_positions: (T_t,) or (B, T_t); memory_positions: (T_m,) or (B, T_m).

  • tgt_position_mask, memory_position_mask: boolean masks with the same layout as their corresponding positions.

  • Returns: decoded target with shape (B, T_t, D).
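Positions can either be shared across the batch or given per sample. The sketch below builds both layouts for a left-padded target batch in which each sample's real tokens should start at position 0. It assumes that True in a position mask marks a padded slot, as the descriptions of tgt_position_mask and memory_position_mask suggest; the variable names are illustrative and only the shapes come from the list above.

```python
import torch

B, T_t = 2, 5
lengths = torch.tensor([5, 3])  # number of real (unpadded) tokens per sample

# Shared layout (T_t,): every sample uses positions 0..T_t-1.
shared_positions = torch.arange(T_t)

# Left padding: the first T_t - length slots of each row are padding (True).
pad = torch.arange(T_t).unsqueeze(0) < (T_t - lengths).unsqueeze(1)

# Per-sample layout (B, T_t): count real tokens so each sample starts at 0;
# padded slots get a dummy position of 0 and are flagged by the mask.
batch_positions = (torch.cumsum((~pad).long(), dim=1) - 1).clamp(min=0)

# Position mask: same (B, T_t) layout as batch_positions.
tgt_position_mask = pad

print(batch_positions.tolist())  # [[0, 1, 2, 3, 4], [0, 0, 0, 1, 2]]
```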

Parameters: