TransformerDecoder
- class TransformerDecoder(decoder_layer, num_layers, norm=None)
Bases:
Module
Stack decoder layers and apply an optional final normalization.
Attributes:
- layers:
Sequential copies of the decoder layer.
- norm:
Optional normalization applied after the final layer.
- num_layers:
Number of stacked decoder layers.
Initialize the decoder stack.
- Parameters:
decoder_layer (torch.nn.Module): Base layer to replicate.
num_layers (int): Number of stacked decoder layers.
norm (torch.nn.Module, optional): Final normalization layer.
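A minimal sketch of building such a stack. It assumes `torch.nn.TransformerDecoder` as a stand-in, because that class takes the same `(decoder_layer, num_layers, norm=None)` constructor arguments and exposes `layers`, `norm`, and `num_layers`; the class documented here may differ in detail. The sizes (`D = 32`, 4 heads, 3 layers) are arbitrary illustration values.

```python
import torch
from torch import nn

# Assumption: torch.nn.TransformerDecoder stands in for the class documented
# on this page; it shares the (decoder_layer, num_layers, norm) signature.
D = 32  # model width (illustrative value)

# Base layer that the stack replicates.
base_layer = nn.TransformerDecoderLayer(
    d_model=D, nhead=4, dim_feedforward=64, batch_first=True
)

# Three stacked copies, followed by an optional final normalization.
decoder = nn.TransformerDecoder(base_layer, num_layers=3, norm=nn.LayerNorm(D))

print(len(decoder.layers))      # number of stacked copies
print(decoder.num_layers)       # same count, stored as an attribute
print(type(decoder.norm).__name__)
```

In the stand-in, each entry of `layers` is an independent copy of the base layer, so the copies do not share parameters after construction.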
- forward(tgt, memory, tgt_mask=None, memory_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None, tgt_is_causal=None, memory_is_causal=False, **layer_kwargs)
Apply the decoder stack to target and memory sequences.
Shape
- tgt: (B, T_t, D).
- memory: (B, T_m, D).
- Returns: decoded target with shape (B, T_t, D).
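The shape contract above can be checked with a short sketch. It again assumes `torch.nn.TransformerDecoder` (with `batch_first=True` layers) as a stand-in for the class documented here; the stand-in does not accept `**layer_kwargs`, so that argument is not exercised. The dimensions and the hand-built causal mask are illustrative choices, not taken from this page.

```python
import torch
from torch import nn

# Assumption: torch.nn.TransformerDecoder stands in for the documented class.
B, T_t, T_m, D = 2, 5, 7, 32  # batch, target length, memory length, width

layer = nn.TransformerDecoderLayer(
    d_model=D, nhead=4, dim_feedforward=64, dropout=0.0, batch_first=True
)
decoder = nn.TransformerDecoder(layer, num_layers=2).eval()

tgt = torch.randn(B, T_t, D)     # target sequence, (B, T_t, D)
memory = torch.randn(B, T_m, D)  # encoder output, (B, T_m, D)

# Additive causal mask: -inf above the diagonal blocks attention to
# later target positions, 0 elsewhere leaves attention unchanged.
tgt_mask = torch.triu(torch.full((T_t, T_t), float("-inf")), diagonal=1)

with torch.no_grad():
    out = decoder(tgt, memory, tgt_mask=tgt_mask)

print(out.shape)  # matches tgt: (B, T_t, D)
```

The output keeps the target's shape regardless of the memory length `T_m`, since memory is only read through cross-attention.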
- Parameters:
decoder_layer (torch.nn.Module)
num_layers (int)
norm (torch.nn.Module | None)