TransformerDecoderLayer#
- class TransformerDecoderLayer(d_model, self_attn, multihead_attn, dim_feedforward=2048, dropout=0.1, activation=GELU(approximate='none'), layer_norm_eps=1e-05, norm_first=False, norm_module='rmsnorm', bias=True)[source]#
Bases: Module

Transformer decoder layer with optional positional self- and cross-attention.
Given a target sequence \(\mathbf{X} \in \mathbb{R}^{B \times T_t \times D}\) and a memory sequence \(\mathbf{M} \in \mathbb{R}^{B \times T_m \times D}\), the layer computes:
\[\begin{split}\mathbf{X}' &= \mathbf{X} + \operatorname{Dropout}\!\left( \operatorname{SelfAttn}(\mathbf{X}, \mathbf{X}, \mathbf{X})\right), \\ \mathbf{X}'' &= \mathbf{X}' + \operatorname{Dropout}\!\left( \operatorname{CrossAttn}(\mathbf{X}', \mathbf{M}, \mathbf{M})\right), \\ \mathbf{Y} &= \mathbf{X}'' + \operatorname{FFN}(\mathbf{X}''),\end{split}\]
with layer normalization applied either before each residual branch (norm_first=True, pre-norm) or after (norm_first=False, post-norm). Each attention operator is either torch.nn.MultiheadAttention or any PositionalAttentionBase subclass. A simplified sketch of this computation appears after the attribute list below.
Attributes:#
- self_attn:
Attention module used for masked target self-attention.
- multihead_attn:
Attention module used for target-to-memory cross-attention.
- feed_forward_block:
Sequential feed-forward network applied after cross-attention.
- norm1, norm2, norm3:
Normalization layers (RMSNorm or LayerNorm) applied around the residual branches.
- norm_first:
Whether normalization is applied before each residual branch.
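The pre-/post-norm behavior can be read directly from the equations above. The following minimal sketch is an illustration only, not the package's implementation; masks, positions, and per-module dropout details are omitted, and whether the memory is also normalized in the pre-norm cross-attention branch follows the actual implementation:

```python
from typing import Callable
from torch import Tensor

def decoder_layer_sketch(
    x: Tensor,                 # target, (B, T_t, D)
    memory: Tensor,            # encoder output, (B, T_m, D)
    self_attn: Callable,       # (q, k, v) -> Tensor
    cross_attn: Callable,      # (q, k, v) -> Tensor
    ffn: Callable,             # Tensor -> Tensor
    norm1, norm2, norm3,       # normalization modules (RMSNorm or LayerNorm)
    dropout: Callable,
    norm_first: bool,
) -> Tensor:
    if norm_first:
        # Pre-norm: normalize the input of each branch, add the raw residual.
        h = norm1(x)
        x = x + dropout(self_attn(h, h, h))
        h = norm2(x)
        x = x + dropout(cross_attn(h, memory, memory))
        x = x + ffn(norm3(x))
    else:
        # Post-norm: add the residual first, then normalize the sum.
        x = norm1(x + dropout(self_attn(x, x, x)))
        x = norm2(x + dropout(cross_attn(x, memory, memory)))
        x = norm3(x + ffn(x))
    return x
```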
Initialize the decoder layer.
- Parameters:
d_model (int) – Model width D.
self_attn (torch.nn.MultiheadAttention | PositionalAttentionBase) – Masked self-attention module.
multihead_attn (torch.nn.MultiheadAttention | PositionalAttentionBase) – Cross-attention module.
dim_feedforward (int) – Width of the feed-forward hidden layer. Default: 2048.
dropout (float) – Dropout probability. Default: 0.1.
activation (str | callable) – Feed-forward activation. Default: torch.nn.GELU(approximate='none').
layer_norm_eps (float) – LayerNorm epsilon. Default: 1e-5.
norm_first (bool) – If True, apply normalization before each residual branch (pre-norm). Default: False.
norm_module (Literal['layernorm', 'rmsnorm']) – Normalization layer type ('layernorm' or 'rmsnorm'). Default: 'rmsnorm'.
bias (bool) – If True, use learnable normalization biases. Default: True.
device (torch.device, optional) – Parameter factory options.
dtype (torch.dtype, optional) – Parameter factory options.
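A minimal construction sketch, assuming the class is imported from this package and that plain torch.nn.MultiheadAttention modules (with batch_first=True, consistent with the (B, T, D) shapes below) are used for both attention blocks; any PositionalAttentionBase subclass could be substituted:

```python
import torch
# from <this package> import TransformerDecoderLayer  # import path depends on the package layout

d_model, nhead = 512, 8

self_attn = torch.nn.MultiheadAttention(d_model, nhead, dropout=0.1, batch_first=True)
cross_attn = torch.nn.MultiheadAttention(d_model, nhead, dropout=0.1, batch_first=True)

layer = TransformerDecoderLayer(
    d_model,
    self_attn,
    cross_attn,
    dim_feedforward=2048,
    dropout=0.1,
    norm_first=True,        # pre-norm
    norm_module="rmsnorm",  # or "layernorm"
)
```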
- forward(tgt, memory, tgt_mask=None, memory_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None, tgt_is_causal=False, memory_is_causal=False, *, tgt_positions=None, memory_positions=None, tgt_position_mask=None, memory_position_mask=None)[source]#
Apply the decoder layer to target and memory sequences.
- Parameters:
tgt (torch.Tensor) – Target sequence.
memory (torch.Tensor) – Memory sequence from the encoder.
tgt_mask (torch.Tensor, optional) – Additive attention mask for the target sequence.
memory_mask (torch.Tensor, optional) – Additive attention mask for the memory sequence.
tgt_key_padding_mask (torch.Tensor, optional) – Boolean mask for padded target elements.
memory_key_padding_mask (torch.Tensor, optional) – Boolean mask for padded memory elements.
tgt_is_causal (bool) – Whether to apply a causal attention mask to the target. Default: False.
memory_is_causal (bool) – Whether to apply a causal attention mask to the memory. Default: False.
tgt_positions (torch.Tensor, optional) – Absolute positions for the target sequence.
memory_positions (torch.Tensor, optional) – Absolute positions for the memory sequence.
tgt_position_mask (torch.Tensor, optional) – Boolean mask for padded target positions.
memory_position_mask (torch.Tensor, optional) – Boolean mask for padded memory positions.
- Returns:
The decoded target sequence.
- Return type:
torch.Tensor
Shape#
- tgt: (B, T_t, D).
- memory: (B, T_m, D).
- tgt_positions, memory_positions: (T,) or (B, T).
- tgt_position_mask, memory_position_mask: boolean masks with the same layout as their corresponding positions.
Returns: decoded target with shape (B, T_t, D).
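A usage sketch for forward, continuing from the construction example above; passing the additive mask via tgt_mask (rather than relying on tgt_is_causal alone) is an assumption based on the parameter descriptions:

```python
B, T_t, T_m, D = 4, 16, 32, 512

tgt = torch.randn(B, T_t, D)      # target sequence, (B, T_t, D)
memory = torch.randn(B, T_m, D)   # encoder output, (B, T_m, D)

# Standard additive causal mask (-inf above the diagonal) for target self-attention.
tgt_mask = torch.nn.Transformer.generate_square_subsequent_mask(T_t)

out = layer(tgt, memory, tgt_mask=tgt_mask)  # tgt_is_causal=True may also be passed as a hint
print(out.shape)  # torch.Size([4, 16, 512])
```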
- Parameters:
d_model (int)
self_attn (torch.nn.MultiheadAttention | PositionalAttentionBase)
multihead_attn (torch.nn.MultiheadAttention | PositionalAttentionBase)
dim_feedforward (int)
dropout (float)
activation (torch.nn.Module)
layer_norm_eps (float)
norm_first (bool)
norm_module (Literal['layernorm', 'rmsnorm'])
bias (bool)