Source code for symm_learning.nn.transformer.transformer
from __future__ import annotations
import copy
from typing import Literal
import torch
import torch.nn.functional as F
from symm_learning.nn.activation import PositionalAttentionBase
__all__ = [
"TransformerDecoder",
"TransformerDecoderLayer",
"TransformerEncoder",
"TransformerEncoderLayer",
]
class TransformerEncoderLayer(torch.nn.Module):
r"""Transformer encoder layer with optional positional attention.
Given an input sequence :math:`\mathbf{X} \in \mathbb{R}^{B \times T \times D}`, the layer computes:
.. math::
\mathbf{X}' = \mathbf{X} + \operatorname{Dropout}\!\left(
\operatorname{Attn}(\mathbf{X}, \mathbf{X}, \mathbf{X})\right),
\qquad
\mathbf{Y} = \mathbf{X}' + \operatorname{FFN}(\mathbf{X}'),
with layer normalization applied either before each residual branch (``norm_first=True``, pre-norm)
or after (``norm_first=False``, post-norm). The attention operator :math:`\operatorname{Attn}` is either
:class:`torch.nn.MultiheadAttention` or any :class:`~symm_learning.nn.activation.PositionalAttentionBase`
subclass, the latter injecting positional information into query and key streams.
Attributes:
self_attn:
Attention module used for the source self-attention block.
feed_forward_block:
Sequential feed-forward network applied after self-attention.
norm1, norm2:
Normalization layers (``RMSNorm`` or ``LayerNorm``) applied around the residual branches.
norm_first:
Whether normalization is applied before each residual branch.
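Example:
A minimal sketch pairing this layer with a plain :class:`torch.nn.MultiheadAttention`;
``batch_first=True`` is assumed so the attention module matches the batch-first
``(B, T, D)`` inputs this layer expects:
>>> import torch
>>> attn = torch.nn.MultiheadAttention(embed_dim=64, num_heads=4, batch_first=True)
>>> layer = TransformerEncoderLayer(d_model=64, self_attn=attn, dim_feedforward=128)
>>> x = torch.randn(2, 10, 64)  # (B, T, D)
>>> layer(x).shape
torch.Size([2, 10, 64])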
"""
__constants__ = ["norm_first"]  # For jit compilation
def __init__(
self,
d_model: int,
self_attn: torch.nn.MultiheadAttention | PositionalAttentionBase,
dim_feedforward: int = 2048,
dropout: float = 0.1,
activation: torch.nn.Module = torch.nn.GELU(),
layer_norm_eps: float = 1e-5,
norm_first: bool = False,
norm_module: Literal["layernorm", "rmsnorm"] = "rmsnorm",
bias: bool = True,
) -> None:
r"""Initialize the encoder layer.
Args:
d_model (:class:`int`): Model width ``D``.
self_attn (:class:`torch.nn.MultiheadAttention` | :class:`~symm_learning.nn.activation.PositionalAttentionBase`): Self-attention module.
dim_feedforward (:class:`int`): Width of the feed-forward hidden layer. Default: ``2048``.
dropout (:class:`float`): Dropout probability. Default: ``0.1``.
activation (:class:`torch.nn.Module`): Feed-forward activation module. Default: :class:`torch.nn.GELU`.
layer_norm_eps (:class:`float`): Epsilon used by the normalization layers. Default: ``1e-5``.
norm_first (:class:`bool`): If ``True``, apply normalization before each residual branch. Default: ``False``.
norm_module (:class:`str`): Normalization layer type (``'layernorm'`` or ``'rmsnorm'``). Default: ``'rmsnorm'``.
bias (:class:`bool`): If ``True``, use learnable normalization biases. Default: ``True``.
""" # noqa: E501
super().__init__()
self.self_attn = self_attn
assert isinstance(activation, torch.nn.Module), f"activation must be a torch.nn.Module, got {type(activation)}"
self.feed_forward_block = torch.nn.Sequential(
torch.nn.Linear(d_model, dim_feedforward, bias=bias),
activation,
torch.nn.Dropout(dropout),
torch.nn.Linear(dim_feedforward, d_model, bias=bias),
torch.nn.Dropout(dropout),
)
self.norm_first = norm_first
if norm_module == "layernorm":
self.norm1 = torch.nn.LayerNorm(d_model, eps=layer_norm_eps, bias=bias)
self.norm2 = torch.nn.LayerNorm(d_model, eps=layer_norm_eps, bias=bias)
elif norm_module == "rmsnorm":
self.norm1 = torch.nn.RMSNorm(d_model, eps=layer_norm_eps)
self.norm2 = torch.nn.RMSNorm(d_model, eps=layer_norm_eps)
else:
raise ValueError(f"norm_module must be 'layernorm' or 'rmsnorm', got {norm_module}")
self.attn_dropout = torch.nn.Dropout(dropout)
def forward(
self,
src: torch.Tensor,
src_mask: torch.Tensor | None = None,
src_key_padding_mask: torch.Tensor | None = None,
is_causal: bool = False,
*,
src_positions: torch.Tensor | None = None,
src_position_mask: torch.Tensor | None = None,
) -> torch.Tensor:
r"""Apply the encoder layer to a batch-first source sequence.
Args:
src (:class:`torch.Tensor`): Input sequence.
src_mask (:class:`torch.Tensor`, optional): Additive attention mask for the source sequence.
src_key_padding_mask (:class:`torch.Tensor`, optional): Boolean mask for padded source elements.
is_causal (:class:`bool`): Whether to apply a causal attention mask. Default: ``False``.
src_positions (:class:`torch.Tensor`, optional): Absolute positions for the source sequence,
used by positional attention backends.
src_position_mask (:class:`torch.Tensor`, optional): Boolean mask for padded source positions,
used by positional attention backends.
Returns:
:class:`torch.Tensor`: The encoded source sequence.
Shape
-----
- ``src``: ``(B, T, D)``.
- ``src_positions``: ``(T,)`` or ``(B, T)``.
- ``src_position_mask``: boolean mask with the same layout as
``src_positions``.
- Returns: encoded source with shape ``(B, T, D)``.
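Example:
A minimal sketch of masking padded source elements; ``True`` entries in
``src_key_padding_mask`` mark positions that attention should ignore:
>>> import torch
>>> attn = torch.nn.MultiheadAttention(16, 2, batch_first=True)
>>> layer = TransformerEncoderLayer(d_model=16, self_attn=attn)
>>> src = torch.randn(2, 4, 16)
>>> pad = torch.tensor([[False, False, False, True],
...                     [False, False, True, True]])
>>> layer(src, src_key_padding_mask=pad).shape
torch.Size([2, 4, 16])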
"""
src_key_padding_mask = F._canonical_mask(
mask=src_key_padding_mask,
mask_name="src_key_padding_mask",
other_type=F._none_or_dtype(src_mask),
other_name="src_mask",
target_type=src.dtype,
)
src_mask = F._canonical_mask(
mask=src_mask,
mask_name="src_mask",
other_type=None,
other_name="",
target_type=src.dtype,
check_other=False,
)
x = src
if self.norm_first:
x = x + self._sa_block(
self.norm1(x), src_mask, src_key_padding_mask, is_causal, src_positions, src_position_mask
)
x = x + self.feed_forward_block(self.norm2(x))
else:
x = self.norm1(
x + self._sa_block(x, src_mask, src_key_padding_mask, is_causal, src_positions, src_position_mask)
)
x = self.norm2(x + self.feed_forward_block(x))
return x
def _sa_block(
self,
x: torch.Tensor,
attn_mask: torch.Tensor | None,
key_padding_mask: torch.Tensor | None,
is_causal: bool,
positions: torch.Tensor | None,
position_mask: torch.Tensor | None,
) -> torch.Tensor:
"""Apply self-attention with optional positional encoding."""
pos_enc_kwargs = {}
if isinstance(self.self_attn, PositionalAttentionBase):
pos_enc_kwargs.update(
q_positions=positions,
k_positions=positions,
q_position_mask=position_mask,
k_position_mask=position_mask,
)
x = self.self_attn(
query=x,
key=x,
value=x,
attn_mask=attn_mask,
key_padding_mask=key_padding_mask,
need_weights=False,
is_causal=is_causal,
**pos_enc_kwargs,
)[0]
return self.attn_dropout(x)
class TransformerDecoderLayer(torch.nn.Module):
r"""Transformer decoder layer with optional positional self- and cross-attention.
Given a target sequence :math:`\mathbf{X} \in \mathbb{R}^{B \times T_t \times D}` and a memory sequence
:math:`\mathbf{M} \in \mathbb{R}^{B \times T_m \times D}`, the layer computes:
.. math::
\mathbf{X}' &= \mathbf{X} + \operatorname{Dropout}\!\left(
\operatorname{SelfAttn}(\mathbf{X}, \mathbf{X}, \mathbf{X})\right), \\
\mathbf{X}'' &= \mathbf{X}' + \operatorname{Dropout}\!\left(
\operatorname{CrossAttn}(\mathbf{X}', \mathbf{M}, \mathbf{M})\right), \\
\mathbf{Y} &= \mathbf{X}'' + \operatorname{FFN}(\mathbf{X}''),
with layer normalization applied either before each residual branch (``norm_first=True``, pre-norm)
or after (``norm_first=False``, post-norm). Each attention operator is either
:class:`torch.nn.MultiheadAttention` or any :class:`~symm_learning.nn.activation.PositionalAttentionBase`
subclass.
Attributes:
self_attn:
Attention module used for masked target self-attention.
multihead_attn:
Attention module used for target-to-memory cross-attention.
feed_forward_block:
Sequential feed-forward network applied after cross-attention.
norm1, norm2, norm3:
Normalization layers (``RMSNorm`` or ``LayerNorm``) applied around the residual branches.
norm_first:
Whether normalization is applied before each residual branch.
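Example:
A minimal sketch wiring two plain :class:`torch.nn.MultiheadAttention` modules
(``batch_first=True`` assumed, matching the batch-first shapes below):
>>> import torch
>>> sa = torch.nn.MultiheadAttention(64, 4, batch_first=True)
>>> ca = torch.nn.MultiheadAttention(64, 4, batch_first=True)
>>> layer = TransformerDecoderLayer(d_model=64, self_attn=sa, multihead_attn=ca)
>>> tgt, mem = torch.randn(2, 5, 64), torch.randn(2, 7, 64)
>>> layer(tgt, mem).shape
torch.Size([2, 5, 64])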
"""
__constants__ = ["norm_first"] # For jit compilation
def __init__(
self,
d_model: int,
self_attn: torch.nn.MultiheadAttention | PositionalAttentionBase,
multihead_attn: torch.nn.MultiheadAttention | PositionalAttentionBase,
dim_feedforward: int = 2048,
dropout: float = 0.1,
activation: torch.nn.Module = torch.nn.GELU(),
layer_norm_eps: float = 1e-5,
norm_first: bool = False,
norm_module: Literal["layernorm", "rmsnorm"] = "rmsnorm",
bias: bool = True,
) -> None:
r"""Initialize the decoder layer.
Args:
d_model (:class:`int`): Model width ``D``.
self_attn (:class:`torch.nn.MultiheadAttention` | :class:`~symm_learning.nn.activation.PositionalAttentionBase`):
Masked self-attention module.
multihead_attn (:class:`torch.nn.MultiheadAttention` | :class:`~symm_learning.nn.activation.PositionalAttentionBase`):
Cross-attention module.
dim_feedforward (:class:`int`): Width of the feed-forward hidden layer. Default: ``2048``.
dropout (:class:`float`): Dropout probability. Default: ``0.1``.
activation (:class:`torch.nn.Module`): Feed-forward activation module. Default: :class:`torch.nn.GELU`.
layer_norm_eps (:class:`float`): Epsilon used by the normalization layers. Default: ``1e-5``.
norm_first (:class:`bool`): If ``True``, apply normalization before each residual branch. Default: ``False``.
norm_module (:class:`str`): Normalization layer type (``'layernorm'`` or ``'rmsnorm'``). Default: ``'rmsnorm'``.
bias (:class:`bool`): If ``True``, use learnable normalization biases. Default: ``True``.
""" # noqa: E501
super().__init__()
self.self_attn = self_attn
self.multihead_attn = multihead_attn
assert isinstance(activation, torch.nn.Module), f"activation must be a torch.nn.Module, got {type(activation)}"
self.feed_forward_block = torch.nn.Sequential(
torch.nn.Linear(d_model, dim_feedforward, bias=bias),
activation,
torch.nn.Dropout(dropout),
torch.nn.Linear(dim_feedforward, d_model, bias=bias),
torch.nn.Dropout(dropout),
)
self.norm_first = norm_first
if norm_module == "layernorm":
self.norm1 = torch.nn.LayerNorm(d_model, eps=layer_norm_eps, bias=bias)
self.norm2 = torch.nn.LayerNorm(d_model, eps=layer_norm_eps, bias=bias)
self.norm3 = torch.nn.LayerNorm(d_model, eps=layer_norm_eps, bias=bias)
elif norm_module == "rmsnorm":
self.norm1 = torch.nn.RMSNorm(d_model, eps=layer_norm_eps)
self.norm2 = torch.nn.RMSNorm(d_model, eps=layer_norm_eps)
self.norm3 = torch.nn.RMSNorm(d_model, eps=layer_norm_eps)
else:
raise ValueError(f"norm_module must be 'layernorm' or 'rmsnorm', got {norm_module}")
self.attn_dropout = torch.nn.Dropout(dropout)
def forward(
self,
tgt: torch.Tensor,
memory: torch.Tensor,
tgt_mask: torch.Tensor | None = None,
memory_mask: torch.Tensor | None = None,
tgt_key_padding_mask: torch.Tensor | None = None,
memory_key_padding_mask: torch.Tensor | None = None,
tgt_is_causal: bool = False,
memory_is_causal: bool = False,
*,
tgt_positions: torch.Tensor | None = None,
memory_positions: torch.Tensor | None = None,
tgt_position_mask: torch.Tensor | None = None,
memory_position_mask: torch.Tensor | None = None,
) -> torch.Tensor:
r"""Apply the decoder layer to target and memory sequences.
Args:
tgt (:class:`torch.Tensor`): Target sequence.
memory (:class:`torch.Tensor`): Memory sequence from the encoder.
tgt_mask (:class:`torch.Tensor`, optional): Additive attention mask for the target sequence.
memory_mask (:class:`torch.Tensor`, optional): Additive attention mask for the memory sequence.
tgt_key_padding_mask (:class:`torch.Tensor`, optional): Boolean mask for padded target elements.
memory_key_padding_mask (:class:`torch.Tensor`, optional): Boolean mask for padded memory elements.
tgt_is_causal (:class:`bool`): Whether to apply a causal attention mask to the target. Default: ``False``.
memory_is_causal (:class:`bool`): Whether to apply a causal attention mask to the memory. Default: ``False``.
tgt_positions (:class:`torch.Tensor`, optional): Absolute positions for the target sequence.
memory_positions (:class:`torch.Tensor`, optional): Absolute positions for the memory sequence.
tgt_position_mask (:class:`torch.Tensor`, optional): Boolean mask for padded target positions.
memory_position_mask (:class:`torch.Tensor`, optional): Boolean mask for padded memory positions.
Returns:
:class:`torch.Tensor`: The decoded target sequence.
Shape
-----
- ``tgt``: ``(B, T_t, D)``.
- ``memory``: ``(B, T_m, D)``.
- ``tgt_positions``, ``memory_positions``: ``(T,)`` or ``(B, T)``.
- ``tgt_position_mask``, ``memory_position_mask``: boolean masks with
the same layout as their corresponding positions.
- Returns: decoded target with shape ``(B, T_t, D)``.
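Example:
A minimal sketch of causal target self-attention, built with PyTorch's square
subsequent mask helper (the float mask uses ``-inf`` above the diagonal):
>>> import torch
>>> sa = torch.nn.MultiheadAttention(16, 2, batch_first=True)
>>> ca = torch.nn.MultiheadAttention(16, 2, batch_first=True)
>>> layer = TransformerDecoderLayer(16, self_attn=sa, multihead_attn=ca)
>>> tgt, mem = torch.randn(2, 5, 16), torch.randn(2, 7, 16)
>>> tgt_mask = torch.nn.Transformer.generate_square_subsequent_mask(5)
>>> layer(tgt, mem, tgt_mask=tgt_mask, tgt_is_causal=True).shape
torch.Size([2, 5, 16])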
"""
x = tgt
if self.norm_first:
x = x + self._sa_block(
self.norm1(x), tgt_mask, tgt_key_padding_mask, tgt_is_causal, tgt_positions, tgt_position_mask
)
x = x + self._mha_block(
self.norm2(x),
memory,
memory_mask,
memory_key_padding_mask,
memory_is_causal,
tgt_positions,
memory_positions,
tgt_position_mask,
memory_position_mask,
)
x = x + self.feed_forward_block(self.norm3(x))
else:
x = self.norm1(
x + self._sa_block(x, tgt_mask, tgt_key_padding_mask, tgt_is_causal, tgt_positions, tgt_position_mask)
)
x = self.norm2(
x
+ self._mha_block(
x,
memory,
memory_mask,
memory_key_padding_mask,
memory_is_causal,
tgt_positions,
memory_positions,
tgt_position_mask,
memory_position_mask,
)
)
x = self.norm3(x + self.feed_forward_block(x))
return x
def _sa_block(
self,
x: torch.Tensor,
attn_mask: torch.Tensor | None,
key_padding_mask: torch.Tensor | None,
is_causal: bool,
positions: torch.Tensor | None,
position_mask: torch.Tensor | None,
) -> torch.Tensor:
"""Apply masked self-attention with optional positional encoding."""
kwargs = {}
if isinstance(self.self_attn, PositionalAttentionBase):
kwargs.update(
q_positions=positions,
k_positions=positions,
q_position_mask=position_mask,
k_position_mask=position_mask,
)
x = self.self_attn(
query=x,
key=x,
value=x,
attn_mask=attn_mask,
key_padding_mask=key_padding_mask,
need_weights=False,
is_causal=is_causal,
**kwargs,
)[0]
return self.attn_dropout(x)
def _mha_block(
self,
x: torch.Tensor,
mem: torch.Tensor,
attn_mask: torch.Tensor | None,
key_padding_mask: torch.Tensor | None,
is_causal: bool,
q_positions: torch.Tensor | None,
k_positions: torch.Tensor | None,
q_position_mask: torch.Tensor | None,
k_position_mask: torch.Tensor | None,
) -> torch.Tensor:
"""Apply cross-attention with optional positional encoding."""
kwargs = {}
if isinstance(self.multihead_attn, PositionalAttentionBase):
kwargs.update(
q_positions=q_positions,
k_positions=k_positions,
q_position_mask=q_position_mask,
k_position_mask=k_position_mask,
)
x = self.multihead_attn(
query=x,
key=mem,
value=mem,
attn_mask=attn_mask,
key_padding_mask=key_padding_mask,
need_weights=False,
is_causal=is_causal,
**kwargs,
)[0]
return self.attn_dropout(x)
class TransformerEncoder(torch.nn.Module):
r"""Stack encoder layers and apply an optional final normalization.
Attributes:
layers:
Sequential copies of the encoder layer.
norm:
Optional normalization applied after the final layer.
num_layers:
Number of stacked encoder layers.
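Example:
A minimal sketch; the base layer (attention modules included) is deep-copied
``num_layers`` times, so the stacked layers do not share weights:
>>> import torch
>>> attn = torch.nn.MultiheadAttention(32, 4, batch_first=True)
>>> layer = TransformerEncoderLayer(d_model=32, self_attn=attn)
>>> encoder = TransformerEncoder(layer, num_layers=3)
>>> encoder(torch.randn(2, 10, 32)).shape
torch.Size([2, 10, 32])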
"""
def __init__(
self,
encoder_layer: torch.nn.Module,
num_layers: int,
norm: torch.nn.Module | None = None,
enable_nested_tensor: bool = True,
mask_check: bool = True,
) -> None:
r"""Initialize the encoder stack.
Args:
encoder_layer (:class:`torch.nn.Module`): Base layer to replicate.
num_layers (:class:`int`): Number of stacked encoder layers.
norm (:class:`torch.nn.Module`, optional): Final normalization layer.
enable_nested_tensor (:class:`bool`): Preserved for API compatibility. Default: ``True``.
mask_check (:class:`bool`): Preserved for API compatibility. Default: ``True``.
"""
super().__init__()
if num_layers <= 0:
raise ValueError(f"num_layers must be positive, got {num_layers}")
self.layers = torch.nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(num_layers)])
self.num_layers = num_layers
self.norm = norm
self.enable_nested_tensor = enable_nested_tensor
self.use_nested_tensor = False
self.mask_check = mask_check
def forward(
self,
src: torch.Tensor,
mask: torch.Tensor | None = None,
src_key_padding_mask: torch.Tensor | None = None,
is_causal: bool | None = None,
**layer_kwargs,
) -> torch.Tensor:
r"""Apply the encoder stack to a batch-first source sequence.
Shape
-----
- ``src``: ``(B, T, D)``.
- Returns: encoded source with shape ``(B, T, D)``.
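Example:
A minimal sketch; with ``is_causal=None`` (the default), passing a square
subsequent ``mask`` lets the stack infer causal decoding on its own:
>>> import torch
>>> attn = torch.nn.MultiheadAttention(32, 4, batch_first=True)
>>> encoder = TransformerEncoder(TransformerEncoderLayer(32, attn), num_layers=2)
>>> src = torch.randn(2, 6, 32)
>>> mask = torch.nn.Transformer.generate_square_subsequent_mask(6)
>>> encoder(src, mask=mask).shape
torch.Size([2, 6, 32])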
"""
src_key_padding_mask = F._canonical_mask(
mask=src_key_padding_mask,
mask_name="src_key_padding_mask",
other_type=F._none_or_dtype(mask),
other_name="mask",
target_type=src.dtype,
)
mask = F._canonical_mask(
mask=mask,
mask_name="mask",
other_type=None,
other_name="",
target_type=src.dtype,
check_other=False,
)
output = src
seq_len = src.shape[-2]
# A square subsequent mask means the source is being decoded autoregressively.
is_causal = _detect_is_causal_mask(mask, is_causal, seq_len)
for mod in self.layers:
output = mod(
output,
src_mask=mask,
src_key_padding_mask=src_key_padding_mask,
is_causal=is_causal,
**layer_kwargs,
)
if self.norm is not None:
output = self.norm(output)
return output
class TransformerDecoder(torch.nn.Module):
r"""Stack decoder layers and apply an optional final normalization.
Attributes:
layers:
Sequential copies of the decoder layer.
norm:
Optional normalization applied after the final layer.
num_layers:
Number of stacked decoder layers.
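Example:
A minimal sketch; as with the encoder stack, the base layer is deep-copied
``num_layers`` times:
>>> import torch
>>> sa = torch.nn.MultiheadAttention(32, 4, batch_first=True)
>>> ca = torch.nn.MultiheadAttention(32, 4, batch_first=True)
>>> layer = TransformerDecoderLayer(32, self_attn=sa, multihead_attn=ca)
>>> decoder = TransformerDecoder(layer, num_layers=2)
>>> decoder(torch.randn(2, 5, 32), torch.randn(2, 7, 32)).shape
torch.Size([2, 5, 32])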
"""
def __init__(
self,
decoder_layer: torch.nn.Module,
num_layers: int,
norm: torch.nn.Module | None = None,
) -> None:
r"""Initialize the decoder stack.
Args:
decoder_layer (:class:`torch.nn.Module`): Base layer to replicate.
num_layers (:class:`int`): Number of stacked decoder layers.
norm (:class:`torch.nn.Module`, optional): Final normalization layer.
"""
super().__init__()
if num_layers <= 0:
raise ValueError(f"num_layers must be positive, got {num_layers}")
self.layers = torch.nn.ModuleList([copy.deepcopy(decoder_layer) for _ in range(num_layers)])
self.num_layers = num_layers
self.norm = norm
def forward(
self,
tgt: torch.Tensor,
memory: torch.Tensor,
tgt_mask: torch.Tensor | None = None,
memory_mask: torch.Tensor | None = None,
tgt_key_padding_mask: torch.Tensor | None = None,
memory_key_padding_mask: torch.Tensor | None = None,
tgt_is_causal: bool | None = None,
memory_is_causal: bool = False,
**layer_kwargs,
) -> torch.Tensor:
r"""Apply the decoder stack to target and memory sequences.
Shape
-----
- ``tgt``: ``(B, T_t, D)``.
- ``memory``: ``(B, T_m, D)``.
- Returns: decoded target with shape ``(B, T_t, D)``.
"""
output = tgt
seq_len = tgt.shape[-2]
# Only the target mask can imply autoregressive decoding.
tgt_is_causal = _detect_is_causal_mask(tgt_mask, tgt_is_causal, seq_len)
for mod in self.layers:
output = mod(
output,
memory,
tgt_mask=tgt_mask,
memory_mask=memory_mask,
tgt_key_padding_mask=tgt_key_padding_mask,
memory_key_padding_mask=memory_key_padding_mask,
tgt_is_causal=tgt_is_causal,
memory_is_causal=memory_is_causal,
**layer_kwargs,
)
if self.norm is not None:
output = self.norm(output)
return output
def _detect_is_causal_mask(
mask: torch.Tensor | None,
is_causal: bool | None = None,
size: int | None = None,
) -> bool:
"""Infer whether a square attention mask represents causal decoding."""
make_causal = is_causal is True
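# ``is_causal=None`` asks for inference: the mask is treated as causal only
# if it exactly matches a square subsequent mask of the same size.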
if is_causal is None and mask is not None:
causal_mask = torch.nn.Transformer.generate_square_subsequent_mask(
size if size is not None else mask.shape[-2],
device=mask.device,
dtype=mask.dtype,
)
if mask.size() == causal_mask.size():
make_causal = bool((mask == causal_mask).all())
return make_causal