TransformerDecoder#

class TransformerDecoder(decoder_layer, num_layers, norm=None)[source]#

Bases: Module

Stack decoder layers and apply an optional final normalization.

Attributes:#

layers:

Stacked copies of the decoder layer, applied in sequence.

norm:

Optional normalization applied after the final layer.

num_layers:

Number of stacked decoder layers.

Initialize the decoder stack.

Parameters:#

  • decoder_layer (torch.nn.Module) – Base layer to replicate.

  • num_layers (int) – Number of stacked decoder layers.

  • norm (torch.nn.Module, optional) – Final normalization layer.
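The constructor replicates `decoder_layer` num_layers times. A minimal, framework-free sketch of that replication pattern (`ToyDecoderLayer` and `build_stack` are hypothetical names; the real class stores `torch.nn.Module` copies):

```python
import copy

class ToyDecoderLayer:
    """Hypothetical stand-in for a decoder layer (a real one is a torch.nn.Module)."""
    def __init__(self, scale=1.0):
        self.scale = scale

    def __call__(self, tgt, memory):
        # A real layer would run self- and cross-attention; here we just scale tgt.
        return [[x * self.scale for x in row] for row in tgt]

def build_stack(decoder_layer, num_layers):
    # Deep-copy the base layer so each stacked layer has independent state.
    return [copy.deepcopy(decoder_layer) for _ in range(num_layers)]

layers = build_stack(ToyDecoderLayer(scale=2.0), num_layers=3)
assert len(layers) == 3
assert layers[0] is not layers[1]  # independent copies, not shared references
```

Deep-copying (rather than reusing one layer object) is what gives each stacked layer its own parameters.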

forward(tgt, memory, tgt_mask=None, memory_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None, tgt_is_causal=None, memory_is_causal=False, **layer_kwargs)[source]#

Apply the decoder stack to target and memory sequences.

Return type:

Tensor


Shape#

  • tgt: (B, T_t, D).

  • memory: (B, T_m, D).

  • output: decoded target with shape (B, T_t, D).
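The forward pass threads the target through every stacked layer, reusing the same memory at each step, then applies the optional final norm. A shape-preserving sketch with hypothetical stand-in layers (nested lists instead of tensors; `IdentityLayer` is not part of the API):

```python
class IdentityLayer:
    # Hypothetical stand-in: a real decoder layer mixes tgt with memory via attention.
    def __call__(self, tgt, memory):
        return tgt

def forward(tgt, memory, layers, norm=None):
    out = tgt
    for layer in layers:      # the same memory is fed to every layer
        out = layer(out, memory)
    if norm is not None:      # optional final normalization
        out = norm(out)
    return out

# tgt: (B=2, T_t=4, D=3), memory: (B=2, T_m=5, D=3) as nested lists
tgt = [[[0.0] * 3 for _ in range(4)] for _ in range(2)]
memory = [[[0.0] * 3 for _ in range(5)] for _ in range(2)]
out = forward(tgt, memory, [IdentityLayer() for _ in range(2)])
assert (len(out), len(out[0]), len(out[0][0])) == (2, 4, 3)  # (B, T_t, D)
```

Note the shape contract from above: the output keeps the target's shape (B, T_t, D) regardless of the memory length T_m.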
