"""Equivariant and positional multi-head attention modules (symm_learning.nn.activation)."""

from __future__ import annotations

import logging
from abc import ABC, abstractmethod

import torch
import torch.nn.functional as F
from escnn.group import Representation
from torch.nn.utils import parametrize

from symm_learning.linalg import invariant_orthogonal_projector
from symm_learning.nn.linear import eLinear
from symm_learning.nn.module import eModule
from symm_learning.nn.parametrizations import CommutingConstraint, InvariantConstraint
from symm_learning.representation_theory import direct_sum

logger = logging.getLogger(__name__)


class eMultiheadAttention(eModule, torch.nn.MultiheadAttention):
    """Drop-in replacement for :class:`torch.nn.MultiheadAttention` that preserves G-equivariance.

    This module keeps the runtime logic of PyTorch's implementation untouched: we still rely on the packed
    ``in_proj_weight`` / ``in_proj_bias`` for computing queries, keys, and values, and the internal attention
    kernel (including mask handling, dropouts, and softmax) is exactly the stock MultiheadAttention behavior.

    Equivariance is achieved by constraining every linear projection involved in the attention block:

    * the input projection ``[Q; K; V] = W_in @ x`` is treated as a single map from the input representation
      to three stacked copies of a regular-representation block that aligns with the requested ``num_heads``
      (enforced via :class:`~symm_learning.nn.parametrizations.CommutingConstraint`);
    * the optional stacked bias is projected onto the invariant subspace of that same block via
      :class:`~symm_learning.nn.parametrizations.InvariantConstraint`;
    * the output projection ``out_proj`` is constrained to commute with the group action so that the
      concatenated value vectors are mapped back into the original feature space equivariantly.

    Additionally, we restrict ``num_heads`` to divide the number of regular-representation copies present in
    the input feature space to avoid splitting irreducible subspaces across heads.
    """

    def __init__(
        self,
        in_rep: Representation,
        num_heads: int,
        dropout: float = 0.0,
        bias: bool = True,
        add_bias_kv: bool = False,
        add_zero_attn: bool = False,
        device=None,
        dtype=None,
        init_scheme: str | None = "xavier_normal",
    ) -> None:
        r"""Initialize the equivariant multihead attention.

        Args:
            in_rep (:class:`~escnn.group.Representation`): Representation :math:`\rho_{\text{in}}` of the
                input/output space.
            num_heads (:class:`int`): Number of parallel attention heads.
            dropout (:class:`float`): Dropout probability on attention weights. Default: 0.0.
            bias (:class:`bool`): If ``True``, adds learnable input and output projection biases.
                Default: ``True``.
            add_bias_kv (:class:`bool`): **Not supported**. Must be ``False``.
            add_zero_attn (:class:`bool`): **Not supported**. Must be ``False``.
            device (:class:`torch.device`, optional): Parameter factory options.
            dtype (:class:`torch.dtype`, optional): Parameter factory options.
            init_scheme (:class:`str` | :class:`None`, optional): Initialization scheme for the equivariant
                linear layers. Default: ``"xavier_normal"``.
        """
        if num_heads <= 0:
            raise ValueError(f"num_heads must be positive, got {num_heads}")
        if add_bias_kv:
            raise NotImplementedError("Equivariant attention does not support add_bias_kv.")
        if add_zero_attn:
            raise NotImplementedError("Equivariant attention does not support add_zero_attn.")
        G = in_rep.group
        # The feature space must be a stack of regular representations so each head acts on whole copies.
        if in_rep.size % G.order() != 0:
            raise ValueError(f"Input rep dim ({in_rep.size}) must be divisible by the group order ({G.order()}).")
        regular_copies = in_rep.size // G.order()
        if regular_copies % num_heads != 0:
            raise ValueError(f"For input dim {in_rep.size} `num_heads` must divide {in_rep.size}/|G|={regular_copies}")
        super().__init__(
            embed_dim=in_rep.size,
            num_heads=num_heads,
            dropout=dropout,
            bias=bias,
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
            batch_first=True,
            device=device,
            dtype=dtype,
        )
        self.in_rep, self.out_rep = in_rep, in_rep
        self._regular_stack_rep = direct_sum([G.regular_representation] * regular_copies)
        if not self._qkv_same_embed_dim:
            raise ValueError("eMultiheadAttention requires kdim == vdim == embed_dim.")
        # The packed [Q; K; V] projection maps in_rep to three stacked regular-representation blocks.
        stacked_qkv_rep = direct_sum([G.regular_representation] * regular_copies * 3)
        parametrize.register_parametrization(self, "in_proj_weight", CommutingConstraint(in_rep, stacked_qkv_rep))
        if bias and self.in_proj_bias is not None:
            parametrize.register_parametrization(self, "in_proj_bias", InvariantConstraint(stacked_qkv_rep))
        # Replace the stock output projection with an equivariant linear layer.
        self.out_proj = eLinear(in_rep, in_rep, bias=bias, init_scheme=init_scheme).to(device=device, dtype=dtype)
        if init_scheme is not None:
            self.reset_parameters(scheme=init_scheme)

    @torch.no_grad()
    def reset_parameters(self, scheme="xavier_uniform") -> None:
        """Overload parent method to take into account equivariance constraints.

        Args:
            scheme (:class:`str`): Initialization scheme forwarded to the equivariant layers.
                Default: ``"xavier_uniform"``.
        """
        # Before parametrizations are registered (i.e. during the parent __init__), fall back to the
        # stock MultiheadAttention initializer.
        if not hasattr(self, "parametrizations"):
            return super()._reset_parameters()
        logger.debug(f"Resetting parameters of {self.__class__.__name__} with scheme: {scheme}")
        # Reset equivariant linear layers (symm_learning.nn.eLinear)
        self.out_proj.reset_parameters(scheme=scheme)
        for param_name, constraint_list in self.parametrizations.items():
            param = getattr(self, param_name)
            if param.dim() == 2:  # in_proj_weight
                commuting_constraint: CommutingConstraint = constraint_list[0]
                W = commuting_constraint.homo_basis.initialize_params(scheme=scheme, return_dense=True)
                # BUG FIX: previously the local name `param` was rebound (`param = W`), which never
                # touched the module's parameter. Assigning through the attribute routes the value
                # through torch.nn.utils.parametrize, which applies the constraint's ``right_inverse``
                # to store the underlying unconstrained tensor.
                # NOTE(review): assumes CommutingConstraint implements ``right_inverse`` — confirm.
                setattr(self, param_name, W)
                logger.debug(f"Initialized {param_name} with scheme {scheme}")
            elif param.dim() == 1:  # in_proj_bias (invariant subspace)
                # BUG FIX: same local-rebinding issue; zeros are trivially invariant, so assignment
                # through the parametrization is exact.
                setattr(self, param_name, torch.zeros_like(param))
                logger.debug(f"Initialized {param_name} with zeros")
class PositionalAttentionBase(torch.nn.Module, ABC):
    r"""Abstract interface for attention blocks with explicit positional branches.

    The convention in this module is that positional information is provided as a function, defined by a
    submodule, acting on the query and key streams only. Values are left unchanged:

    .. math::
        \tilde{\mathbf{q}} = \mathbf{q} + \phi_\theta(\mathbf{P}_Q), \qquad
        \tilde{\mathbf{k}} = \mathbf{k} + \phi_\theta(\mathbf{P}_K), \qquad
        \tilde{\mathbf{v}} = \mathbf{v}.

    A standard multi-head attention operator is then applied to
    :math:`(\tilde{\mathbf{q}}, \tilde{\mathbf{k}}, \tilde{\mathbf{v}})`. If the positional branch is the
    identity/no-op map, the module reduces to :class:`torch.nn.MultiheadAttention`.

    Concrete implementations may ignore any positional arguments they do not use, but the forward signature
    stays stable so encoder and decoder layers can call all attention backends uniformly.

    Shape
    -----
    - ``query``, ``key``, ``value``: ``(B, P, D)``.
    - ``q_positions``, ``k_positions``: ``(P,)`` or ``(B, P)``. When omitted, they default to
      ``torch.arange(P)`` for the corresponding sequence.
    - ``q_position_mask``, ``k_position_mask``: ``(P,)`` or ``(B, P)`` boolean masks.
    - ``attn_mask``: any attention mask layout accepted by :class:`torch.nn.MultiheadAttention`.
    - ``key_padding_mask``: ``(B, S)`` boolean or additive padding mask.
    - Returns: attention output with the same leading layout as the input, plus optional attention weights.

    Attributes:
    ----------
    This base class does not define any storage beyond the standard :class:`torch.nn.Module` state. Concrete
    subclasses define their own positional encoder and attention parameters.
    """

    @abstractmethod
    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        *,
        q_positions: torch.Tensor | None = None,
        k_positions: torch.Tensor | None = None,
        q_position_mask: torch.Tensor | None = None,
        k_position_mask: torch.Tensor | None = None,
        attn_mask: torch.Tensor | None = None,
        key_padding_mask: torch.Tensor | None = None,
        need_weights: bool = False,
        is_causal: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Apply attention with optional explicit position metadata.

        Shape
        -----
        - ``query``, ``key``, ``value``: see :class:`PositionalAttentionBase`.
        - ``q_positions``, ``k_positions``: position coordinates for the query and key tokens. They may
          differ in cross-attention.
        - Returns: ``(output, attn_weights)`` where ``attn_weights`` is ``None`` unless
          ``need_weights=True``.
        """
        raise NotImplementedError

    def reset_parameters(self) -> None:
        """Initialize parameters to match torch.nn.MultiheadAttention."""
        # Fast path: subclasses that wrap a stock MultiheadAttention delegate to its own initializer.
        if hasattr(self, "attn") and isinstance(self.attn, torch.nn.MultiheadAttention):
            self.attn._reset_parameters()
            return
        # Otherwise initialize whichever of the conventional projection attributes this subclass
        # defines, either as raw parameters (e.g. "q_proj_weight") or as Linear layers (e.g. "q_proj").
        weight_names = ["in_proj_weight", "q_proj_weight", "k_proj_weight", "v_proj_weight"]
        for name in weight_names:
            weight = getattr(self, name, None)
            if weight is not None:
                torch.nn.init.xavier_uniform_(weight)
            elif hasattr(self, name.replace("_weight", "")) and isinstance(
                getattr(self, name.replace("_weight", "")), torch.nn.Linear
            ):
                # "q_proj_weight" -> Linear layer stored as "q_proj"; initialize its weight instead.
                torch.nn.init.xavier_uniform_(getattr(self, name.replace("_weight", "")).weight)
        bias_names = ["in_proj_bias", "bias_k", "bias_v"]
        for name in bias_names:
            bias = getattr(self, name, None)
            if bias is not None:
                torch.nn.init.constant_(bias, 0.0)
            elif (
                # "in_proj_bias" / "bias_k" / "bias_v" -> Linear layer stored as "in_proj" / "k_proj" / "v_proj".
                hasattr(self, name.replace("_bias", "").replace("bias_", "") + "_proj")
                and isinstance(getattr(self, name.replace("_bias", "").replace("bias_", "") + "_proj"), torch.nn.Linear)
                and getattr(self, name.replace("_bias", "").replace("bias_", "") + "_proj").bias is not None
            ):
                torch.nn.init.constant_(
                    getattr(self, name.replace("_bias", "").replace("bias_", "") + "_proj").bias, 0.0
                )
        out_proj = getattr(self, "out_proj", None)
        if out_proj is not None:
            if getattr(out_proj, "weight", None) is not None:
                torch.nn.init.xavier_uniform_(out_proj.weight)
            if getattr(out_proj, "bias", None) is not None:
                torch.nn.init.constant_(out_proj.bias, 0.0)

    @staticmethod
    def _positions_or_arange(
        positions: torch.Tensor | None,
        *,
        seq_len: int,
        device: torch.device,
    ) -> torch.Tensor:
        r"""Use explicit positions when provided, otherwise default to ``0, \ldots, P-1``."""
        if positions is None:
            return torch.arange(seq_len, device=device)
        return positions

    @staticmethod
    def _normalize_positions(
        positions: torch.Tensor,
        position_mask: torch.Tensor | None,
        *,
        batch_size: int,
        seq_len: int,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Normalize position tensors and validate their layout.

        The helper accepts positions with shape ``(P,)`` or ``(B, P)`` and checks that the sequence length
        matches ``seq_len`` and the batch dimension is either ``1`` or ``batch_size``. If a mask is provided,
        it must have the same rank and shape as ``positions`` before any singleton batch expansion. The
        returned mask is expanded together with ``positions`` when needed.
        """
        # Promote (P,) to (1, P) so downstream code only handles the 2-D layout.
        if positions.ndim == 1:
            positions = positions.unsqueeze(0)
        elif positions.ndim != 2:
            raise ValueError(f"Expected positions tensor with shape (P,) or (B, P), got {tuple(positions.shape)}")
        if positions.shape[-1] != seq_len:
            raise ValueError(f"Position length {positions.shape[-1]} does not match sequence length {seq_len}")
        if positions.shape[0] not in (1, batch_size):
            raise ValueError(f"Position batch size {positions.shape[0]} must be 1 or match batch size {batch_size}")
        # Broadcast a shared position row across the batch.
        if positions.shape[0] == 1 and batch_size != 1:
            positions = positions.expand(batch_size, -1)
        if position_mask is None:
            return positions, None
        # Same normalization and validation for the mask.
        if position_mask.ndim == 1:
            position_mask = position_mask.unsqueeze(0)
        elif position_mask.ndim != 2:
            raise ValueError(f"Expected position_mask with shape (P,) or (B, P), got {tuple(position_mask.shape)}")
        if position_mask.shape[-1] != seq_len:
            raise ValueError(f"Position mask length {position_mask.shape[-1]} does not match sequence length {seq_len}")
        if position_mask.shape[0] not in (1, batch_size):
            raise ValueError(
                f"Position mask batch size {position_mask.shape[0]} must be 1 or match batch size {batch_size}"
            )
        if position_mask.shape[0] == 1 and batch_size != 1:
            position_mask = position_mask.expand(batch_size, -1)
        return positions, position_mask
class AdditivePosMultiheadAttention(PositionalAttentionBase):
    r"""Wrap :class:`torch.nn.MultiheadAttention` and add positional features to Q/K only.

    A learned table maps coordinates to an additive update of the query and key streams,

    .. math::
        \mathbf{q}' = \mathbf{q} + E_\theta(\mathbf{P}_Q), \qquad
        \mathbf{k}' = \mathbf{k} + E_\theta(\mathbf{P}_K), \qquad
        \mathbf{v}' = \mathbf{v}.

    The values are never position-modulated. When ``E_\theta`` is the identity/no-op branch, this is exactly
    ordinary multi-head attention.

    Shape
    -----
    - ``query``, ``key``, ``value``: ``(B, T, D)``.
    - ``q_positions``, ``k_positions``: ``(P,)`` or ``(B, P)``.
    - ``q_position_mask``, ``k_position_mask``: boolean masks with the same shape as the corresponding
      position tensor.
    - Returns: the attention output and, optionally, attention weights.

    Attributes:
    ----------
    pos_emb: Learnable table with shape ``(max_len, D)`` storing the absolute positional embeddings.
    attn: Internal :class:`torch.nn.MultiheadAttention` backend.
    embed_dim: Model width ``D``.
    num_heads: Number of attention heads.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        *,
        max_len: int,
        dropout: float = 0.0,
        bias: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        r"""Initialize additive positional attention with a wrapped multi-head backend.

        Args:
            embed_dim (:class:`int`): Model width ``D``.
            num_heads (:class:`int`): Number of attention heads.
            max_len (:class:`int`): Maximum supported sequence length.
            dropout (:class:`float`): Dropout probability on attention weights. Default: 0.0.
            bias (:class:`bool`): If ``True``, adds learnable input and output projection biases.
                Default: ``True``.
            device (:class:`torch.device`, optional): Parameter factory options.
            dtype (:class:`torch.dtype`, optional): Parameter factory options.
        """
        super().__init__()
        if max_len <= 0:
            # BUG FIX: the message previously contained a stray literal ("got 74,810") instead of the
            # offending value; report the actual argument.
            raise ValueError(f"max_len must be positive, got {max_len}")
        self.attn = torch.nn.MultiheadAttention(
            embed_dim,
            num_heads,
            dropout=dropout,
            bias=bias,
            batch_first=True,
            device=device,
            dtype=dtype,
        )
        self.embed_dim = embed_dim
        self.max_len = max_len
        self.num_heads = num_heads
        # Absolute positional embedding table, indexed by integer position; initialized at zero so the
        # module starts out as a plain MultiheadAttention.
        self.pos_emb = torch.nn.Parameter(torch.zeros(max_len, embed_dim, device=device, dtype=dtype))

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        *,
        q_positions: torch.Tensor | None = None,
        k_positions: torch.Tensor | None = None,
        q_position_mask: torch.Tensor | None = None,
        k_position_mask: torch.Tensor | None = None,
        attn_mask: torch.Tensor | None = None,
        key_padding_mask: torch.Tensor | None = None,
        need_weights: bool = False,
        is_causal: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        r"""Add the positional update to query and key before attention.

        The values are passed through unchanged.

        Shape
        -----
        - ``query``, ``key``, ``value``: see :class:`PositionalAttentionBase`.
        - Returns: ``(output, attn_weights)`` from the wrapped :class:`torch.nn.MultiheadAttention`.
        """
        query = query + self._position_update(query, q_positions, q_position_mask)
        key = key + self._position_update(key, k_positions, k_position_mask)
        return self.attn(
            query,
            key,
            value,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=need_weights,
            is_causal=is_causal,
        )

    def _position_update(
        self,
        x: torch.Tensor,
        positions: torch.Tensor | None,
        position_mask: torch.Tensor | None,
    ) -> torch.Tensor:
        """Look up the additive positional embedding for each token of ``x``.

        Masked-out positions contribute a zero update.
        """
        batch_size, seq_len, _ = x.shape
        positions = self._positions_or_arange(positions, seq_len=seq_len, device=x.device)
        positions, position_mask = self._normalize_positions(
            positions,
            position_mask,
            batch_size=batch_size,
            seq_len=seq_len,
        )
        # Temporarily index masked slots at position 0; their embeddings are zeroed out after the lookup.
        encoded_positions = positions.masked_fill(~position_mask, 0) if position_mask is not None else positions
        pos_emb = self.pos_emb[encoded_positions.long()]
        if pos_emb.ndim == 2:
            expected_shape = (seq_len, self.embed_dim)
            if pos_emb.shape != expected_shape:
                raise ValueError(f"Expected positional embedding shape {expected_shape}, got {tuple(pos_emb.shape)}")
            pos_emb = pos_emb.unsqueeze(0)
            if batch_size != 1:
                pos_emb = pos_emb.expand(batch_size, -1, -1)
        else:
            expected_shape = (batch_size, seq_len, self.embed_dim)
            if pos_emb.ndim != 3 or pos_emb.shape != expected_shape:
                raise ValueError(f"Expected positional embedding shape {expected_shape}, got {tuple(pos_emb.shape)}")
        if position_mask is not None:
            pos_emb = pos_emb * position_mask.unsqueeze(-1)
        return pos_emb
class eAdditivePosMultiheadAttention(eModule, PositionalAttentionBase):
    r"""Equivariant additive positional attention with invariant query/key updates.

    The positional embedding table is projected onto the G-invariant subspace of ``in_rep`` before being
    added to the query/key streams, so the additive update commutes with the group action and the wrapped
    :class:`eMultiheadAttention` backend stays equivariant.
    """

    def __init__(
        self,
        in_rep: Representation,
        num_heads: int,
        *,
        max_len: int,
        dropout: float = 0.0,
        bias: bool = True,
        device=None,
        dtype=None,
        init_scheme: str | None = "xavier_normal",
    ) -> None:
        r"""Initialize equivariant additive positional attention.

        Args:
            in_rep (:class:`~escnn.group.Representation`): Representation of the input/output space.
            num_heads (:class:`int`): Number of attention heads.
            max_len (:class:`int`): Maximum supported sequence length.
            dropout (:class:`float`): Dropout probability on attention weights. Default: 0.0.
            bias (:class:`bool`): If ``True``, adds learnable input and output projection biases.
                Default: ``True``.
            device (:class:`torch.device`, optional): Parameter factory options.
            dtype (:class:`torch.dtype`, optional): Parameter factory options.
            init_scheme (:class:`str` | :class:`None`, optional): Initialization scheme for the equivariant
                linear layers. Default: ``"xavier_normal"``.
        """
        super().__init__()
        if not isinstance(max_len, int) or max_len <= 0:
            # BUG FIX: the message previously contained a stray literal ("got 74,810") instead of the
            # offending value; report the actual argument.
            raise ValueError(f"max_len must be a positive integer, got {max_len}")
        self.in_rep, self.out_rep = in_rep, in_rep
        self.embed_dim = in_rep.size
        self.max_len = max_len
        self.num_heads = num_heads
        self.attn = eMultiheadAttention(
            in_rep=in_rep,
            num_heads=num_heads,
            dropout=dropout,
            bias=bias,
            device=device,
            dtype=dtype,
            init_scheme=init_scheme,
        )
        # Orthogonal projector onto the invariant subspace; applied to every looked-up embedding.
        self.register_buffer("invariant_projector", invariant_orthogonal_projector(in_rep))
        self.pos_emb = torch.nn.Parameter(torch.zeros(max_len, in_rep.size, device=device, dtype=dtype))

    def forward(  # noqa: D102
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        *,
        q_positions: torch.Tensor | None = None,
        k_positions: torch.Tensor | None = None,
        q_position_mask: torch.Tensor | None = None,
        k_position_mask: torch.Tensor | None = None,
        attn_mask: torch.Tensor | None = None,
        key_padding_mask: torch.Tensor | None = None,
        need_weights: bool = False,
        is_causal: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        query = query + self._position_update(query, q_positions, q_position_mask)
        key = key + self._position_update(key, k_positions, k_position_mask)
        return self.attn(
            query,
            key,
            value,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=need_weights,
            is_causal=is_causal,
        )

    def _position_update(
        self,
        x: torch.Tensor,
        positions: torch.Tensor | None,
        position_mask: torch.Tensor | None,
    ) -> torch.Tensor:
        """Look up positional embeddings and project them onto the invariant subspace.

        Masked-out positions contribute a zero update.
        """
        batch_size, seq_len, _ = x.shape
        positions = self._positions_or_arange(positions, seq_len=seq_len, device=x.device)
        positions, position_mask = self._normalize_positions(
            positions,
            position_mask,
            batch_size=batch_size,
            seq_len=seq_len,
        )
        # Temporarily index masked slots at position 0; their embeddings are zeroed out after the lookup.
        encoded_positions = positions.masked_fill(~position_mask, 0) if position_mask is not None else positions
        pos_emb = self.pos_emb[encoded_positions.long()]
        if pos_emb.ndim == 2:
            expected_shape = (seq_len, self.embed_dim)
            if pos_emb.shape != expected_shape:
                raise ValueError(f"Expected positional embedding shape {expected_shape}, got {tuple(pos_emb.shape)}")
            pos_emb = pos_emb.unsqueeze(0)
            if batch_size != 1:
                pos_emb = pos_emb.expand(batch_size, -1, -1)
        else:
            expected_shape = (batch_size, seq_len, self.embed_dim)
            if pos_emb.ndim != 3 or pos_emb.shape != expected_shape:
                raise ValueError(f"Expected positional embedding shape {expected_shape}, got {tuple(pos_emb.shape)}")
        # Project every embedding onto the G-invariant subspace so the update is equivariant.
        pos_emb = torch.einsum(
            "ij,...j->...i",
            self.invariant_projector.to(device=pos_emb.device, dtype=pos_emb.dtype),
            pos_emb,
        )
        if position_mask is not None:
            pos_emb = pos_emb * position_mask.unsqueeze(-1)
        return pos_emb

    @torch.no_grad()
    def reset_parameters(self, scheme="xavier_uniform") -> None:  # noqa: D102
        self.attn.reset_parameters(scheme=scheme)
        self.pos_emb.zero_()

    def invalidate_cache(self) -> None:  # noqa: D102
        self.attn.invalidate_cache()
class AdditiveRelMultiheadAttention(PositionalAttentionBase):
    r"""Wrap :class:`torch.nn.MultiheadAttention` and subtract a relative-position bias from the logits.

    A learned table maps pairwise relative distances to an additive correction of the attention scores,

    .. math::
        \mathbf{A}_{ij} = \frac{\mathbf{q}_i^\top \mathbf{k}_j}{\sqrt{d_h}}
        - \phi_\theta(\mathbf{P}_{Q, i} - \mathbf{P}_{K, j}),

    while the value stream remains unchanged. Since the correction depends only on relative offsets, the
    attention module is time-translation equivariant.

    Shape
    -----
    - ``query``, ``key``, ``value``: ``(B, T, D)``.
    - ``q_positions``, ``k_positions``: ``(P,)`` or ``(B, P)``.
    - ``q_position_mask``, ``k_position_mask``: boolean masks with the same shape as the corresponding
      position tensor.
    - Returns: the attention output and, optionally, attention weights.

    Attributes:
    ----------
    rel_bias: Learnable table with shape ``(2 * max_distance + 1,)`` storing the scalar bias for each
        clipped relative offset.
    attn: Internal :class:`torch.nn.MultiheadAttention` backend.
    embed_dim: Model width ``D``.
    num_heads: Number of attention heads.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        *,
        max_distance: int,
        dropout: float = 0.0,
        bias: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        r"""Initialize relative-bias attention with a wrapped multi-head backend.

        Args:
            embed_dim (:class:`int`): Model width ``D``.
            num_heads (:class:`int`): Number of attention heads.
            max_distance (:class:`int`): Maximum relative distance represented explicitly before clipping.
            dropout (:class:`float`): Dropout probability on attention weights. Default: 0.0.
            bias (:class:`bool`): If ``True``, adds learnable input and output projection biases.
                Default: ``True``.
            device (:class:`torch.device`, optional): Parameter factory options.
            dtype (:class:`torch.dtype`, optional): Parameter factory options.
        """
        super().__init__()
        if max_distance < 0:
            raise ValueError(f"max_distance must be non-negative, got {max_distance}")
        self.attn = torch.nn.MultiheadAttention(
            embed_dim,
            num_heads,
            dropout=dropout,
            bias=bias,
            batch_first=True,
            device=device,
            dtype=dtype,
        )
        self.embed_dim = embed_dim
        self.max_distance = max_distance
        self.num_heads = num_heads
        # One scalar bias per clipped offset in [-max_distance, max_distance]; zero-initialized so the
        # module starts out as a plain MultiheadAttention.
        self.rel_bias = torch.nn.Parameter(torch.zeros(2 * max_distance + 1, device=device, dtype=dtype))

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        *,
        q_positions: torch.Tensor | None = None,
        k_positions: torch.Tensor | None = None,
        q_position_mask: torch.Tensor | None = None,
        k_position_mask: torch.Tensor | None = None,
        attn_mask: torch.Tensor | None = None,
        key_padding_mask: torch.Tensor | None = None,
        need_weights: bool = False,
        is_causal: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        r"""Subtract the relative-position bias from the attention logits.

        Shape
        -----
        - ``query``, ``key``, ``value``: see :class:`PositionalAttentionBase`.
        - Returns: ``(output, attn_weights)`` from the wrapped :class:`torch.nn.MultiheadAttention`.
        """
        batch_size = query.shape[0]
        tgt_len = query.shape[1]
        src_len = key.shape[1]
        q_positions = self._positions_or_arange(q_positions, seq_len=tgt_len, device=query.device)
        k_positions = self._positions_or_arange(k_positions, seq_len=src_len, device=key.device)
        rel_bias = self._relative_bias(
            q_positions,
            k_positions,
            q_position_mask,
            k_position_mask,
            batch_size=batch_size,
            tgt_len=tgt_len,
            src_len=src_len,
            target_dtype=query.dtype,
        )
        # Convert any boolean attn_mask to an additive float mask so the bias can be folded into it.
        attn_mask = F._canonical_mask(
            mask=attn_mask,
            mask_name="attn_mask",
            other_type=None,
            other_name="",
            target_type=query.dtype,
            check_other=False,
        )
        # The bias is SUBTRACTED from the logits (see the class docstring).
        if rel_bias is not None:
            attn_mask = -rel_bias if attn_mask is None else attn_mask - rel_bias
        return self.attn(
            query,
            key,
            value,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=need_weights,
            is_causal=is_causal,
        )

    def _relative_bias(
        self,
        q_positions: torch.Tensor | None,
        k_positions: torch.Tensor | None,
        q_position_mask: torch.Tensor | None,
        k_position_mask: torch.Tensor | None,
        *,
        batch_size: int,
        tgt_len: int,
        src_len: int,
        target_dtype: torch.dtype,
    ) -> torch.Tensor | None:
        """Build the additive bias tensor from pairwise relative offsets.

        Returns either a ``(tgt_len, src_len)`` tensor (positions shared across the batch) or a
        ``(batch_size * num_heads, tgt_len, src_len)`` tensor, matching the attn_mask layouts accepted by
        :class:`torch.nn.MultiheadAttention`. Pairs involving a masked position get a zero bias.
        """
        if q_positions is None or k_positions is None:
            return None
        # Decide the cheap shared-across-batch layout BEFORE normalization expands singleton batches.
        positions_shared_across_batch = self._positions_are_shared_across_batch(
            q_positions, q_position_mask
        ) and self._positions_are_shared_across_batch(k_positions, k_position_mask)
        q_positions, q_position_mask = self._normalize_positions(
            q_positions,
            q_position_mask,
            batch_size=batch_size,
            seq_len=tgt_len,
        )
        k_positions, k_position_mask = self._normalize_positions(
            k_positions,
            k_position_mask,
            batch_size=batch_size,
            seq_len=src_len,
        )
        if positions_shared_across_batch:
            # Single (tgt_len, src_len) offset grid shared by all batch elements.
            q_positions = q_positions[0]
            k_positions = k_positions[0]
            rel_positions = q_positions.unsqueeze(-1) - k_positions.unsqueeze(-2)
            pair_mask = None
            if q_position_mask is not None or k_position_mask is not None:
                q_valid = (
                    torch.ones(tgt_len, device=rel_positions.device, dtype=torch.bool)
                    if q_position_mask is None
                    else q_position_mask[0]
                )
                k_valid = (
                    torch.ones(src_len, device=rel_positions.device, dtype=torch.bool)
                    if k_position_mask is None
                    else k_position_mask[0]
                )
                pair_mask = q_valid.unsqueeze(-1) & k_valid.unsqueeze(-2)
        else:
            # Per-batch (batch_size, tgt_len, src_len) offset grid.
            rel_positions = q_positions.unsqueeze(-1) - k_positions.unsqueeze(-2)
            pair_mask = None
            if q_position_mask is not None or k_position_mask is not None:
                q_valid = (
                    torch.ones(batch_size, tgt_len, device=rel_positions.device, dtype=torch.bool)
                    if q_position_mask is None
                    else q_position_mask
                )
                k_valid = (
                    torch.ones(batch_size, src_len, device=rel_positions.device, dtype=torch.bool)
                    if k_position_mask is None
                    else k_position_mask
                )
                pair_mask = q_valid.unsqueeze(-1) & k_valid.unsqueeze(-2)
        rel_bias = self._relative_bias_values(rel_positions).to(dtype=target_dtype)
        # MultiheadAttention expects 3-D masks with a (batch * num_heads) leading dim.
        if rel_bias.ndim == 3:
            rel_bias = rel_bias.repeat_interleave(self.num_heads, dim=0)
        elif not positions_shared_across_batch:
            rel_bias = rel_bias.unsqueeze(0).expand(batch_size * self.num_heads, -1, -1)
        if pair_mask is None:
            return rel_bias
        if rel_bias.ndim == 2:
            return rel_bias * pair_mask.to(dtype=rel_bias.dtype)
        # Broadcast the validity mask to the same (batch * num_heads) leading dim as rel_bias.
        if pair_mask.ndim == 2:
            pair_mask = pair_mask.unsqueeze(0).expand(batch_size * self.num_heads, -1, -1)
        else:
            pair_mask = pair_mask.repeat_interleave(self.num_heads, dim=0)
        return rel_bias * pair_mask.to(dtype=rel_bias.dtype)

    @staticmethod
    def _positions_are_shared_across_batch(
        positions: torch.Tensor,
        position_mask: torch.Tensor | None,
    ) -> bool:
        """Return ``True`` when positions (and mask, if any) carry no per-batch variation."""
        shared_positions = positions.ndim == 1 or (positions.ndim == 2 and positions.shape[0] == 1)
        if position_mask is None:
            return shared_positions
        shared_mask = position_mask.ndim == 1 or (position_mask.ndim == 2 and position_mask.shape[0] == 1)
        return shared_positions and shared_mask

    def _relative_bias_values(self, rel_positions: torch.Tensor) -> torch.Tensor:
        """Look up the bias table, clipping offsets to ``[-max_distance, max_distance]``."""
        clipped_positions = rel_positions.clamp(-self.max_distance, self.max_distance).long() + self.max_distance
        return self.rel_bias[clipped_positions]
class eAdditiveRelMultiheadAttention(eModule, PositionalAttentionBase):
    r"""Equivariant relative-bias attention with an equivariant attention backend.

    Same relative-position bias scheme as :class:`AdditiveRelMultiheadAttention`, but the wrapped backend is
    an :class:`eMultiheadAttention`. The bias is a scalar per logit, so subtracting it does not affect the
    equivariance of the attention map.
    """

    def __init__(
        self,
        in_rep: Representation,
        num_heads: int,
        *,
        max_distance: int,
        dropout: float = 0.0,
        bias: bool = True,
        device=None,
        dtype=None,
        init_scheme: str | None = "xavier_normal",
    ) -> None:
        r"""Initialize equivariant relative-bias attention.

        Args:
            in_rep (:class:`~escnn.group.Representation`): Representation of the input/output space.
            num_heads (:class:`int`): Number of attention heads.
            max_distance (:class:`int`): Maximum relative distance represented explicitly before clipping.
            dropout (:class:`float`): Dropout probability on attention weights. Default: 0.0.
            bias (:class:`bool`): If ``True``, adds learnable input and output projection biases.
                Default: ``True``.
            device (:class:`torch.device`, optional): Parameter factory options.
            dtype (:class:`torch.dtype`, optional): Parameter factory options.
            init_scheme (:class:`str` | :class:`None`, optional): Initialization scheme for the equivariant
                linear layers. Default: ``"xavier_normal"``.
        """
        super().__init__()
        # NOTE: stricter than the non-equivariant counterpart, which allows max_distance == 0.
        if not isinstance(max_distance, int) or max_distance <= 0:
            raise ValueError(f"max_distance must be a positive integer, got {max_distance}")
        self.in_rep, self.out_rep = in_rep, in_rep
        self.embed_dim = in_rep.size
        self.max_distance = max_distance
        self.num_heads = num_heads
        self.attn = eMultiheadAttention(
            in_rep=in_rep,
            num_heads=num_heads,
            dropout=dropout,
            bias=bias,
            device=device,
            dtype=dtype,
            init_scheme=init_scheme,
        )
        # One scalar bias per clipped offset in [-max_distance, max_distance]; zero-initialized.
        self.rel_bias = torch.nn.Parameter(torch.zeros(2 * max_distance + 1, device=device, dtype=dtype))

    def forward(  # noqa: D102
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        *,
        q_positions: torch.Tensor | None = None,
        k_positions: torch.Tensor | None = None,
        q_position_mask: torch.Tensor | None = None,
        k_position_mask: torch.Tensor | None = None,
        attn_mask: torch.Tensor | None = None,
        key_padding_mask: torch.Tensor | None = None,
        need_weights: bool = False,
        is_causal: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        batch_size = query.shape[0]
        tgt_len = query.shape[1]
        src_len = key.shape[1]
        q_positions = self._positions_or_arange(q_positions, seq_len=tgt_len, device=query.device)
        k_positions = self._positions_or_arange(k_positions, seq_len=src_len, device=key.device)
        rel_bias = self._relative_bias(
            q_positions,
            k_positions,
            q_position_mask,
            k_position_mask,
            batch_size=batch_size,
            tgt_len=tgt_len,
            src_len=src_len,
            target_dtype=query.dtype,
        )
        # Convert any boolean attn_mask to an additive float mask so the bias can be folded into it.
        attn_mask = F._canonical_mask(
            mask=attn_mask,
            mask_name="attn_mask",
            other_type=None,
            other_name="",
            target_type=query.dtype,
            check_other=False,
        )
        # The bias is SUBTRACTED from the logits.
        if rel_bias is not None:
            attn_mask = -rel_bias if attn_mask is None else attn_mask - rel_bias
        return self.attn(
            query,
            key,
            value,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=need_weights,
            is_causal=is_causal,
        )

    def _relative_bias(
        self,
        q_positions: torch.Tensor | None,
        k_positions: torch.Tensor | None,
        q_position_mask: torch.Tensor | None,
        k_position_mask: torch.Tensor | None,
        *,
        batch_size: int,
        tgt_len: int,
        src_len: int,
        target_dtype: torch.dtype,
    ) -> torch.Tensor | None:
        """Build the additive bias tensor from pairwise relative offsets.

        Returns either a ``(tgt_len, src_len)`` tensor (positions shared across the batch) or a
        ``(batch_size * num_heads, tgt_len, src_len)`` tensor. Pairs involving a masked position get a zero
        bias.
        """
        if q_positions is None or k_positions is None:
            return None
        # Decide the cheap shared-across-batch layout BEFORE normalization expands singleton batches.
        positions_shared_across_batch = self._positions_are_shared_across_batch(
            q_positions, q_position_mask
        ) and self._positions_are_shared_across_batch(k_positions, k_position_mask)
        q_positions, q_position_mask = self._normalize_positions(
            q_positions,
            q_position_mask,
            batch_size=batch_size,
            seq_len=tgt_len,
        )
        k_positions, k_position_mask = self._normalize_positions(
            k_positions,
            k_position_mask,
            batch_size=batch_size,
            seq_len=src_len,
        )
        if positions_shared_across_batch:
            # Single (tgt_len, src_len) offset grid shared by all batch elements.
            q_positions = q_positions[0]
            k_positions = k_positions[0]
            rel_positions = q_positions.unsqueeze(-1) - k_positions.unsqueeze(-2)
            pair_mask = None
            if q_position_mask is not None or k_position_mask is not None:
                q_valid = (
                    torch.ones(tgt_len, device=rel_positions.device, dtype=torch.bool)
                    if q_position_mask is None
                    else q_position_mask[0]
                )
                k_valid = (
                    torch.ones(src_len, device=rel_positions.device, dtype=torch.bool)
                    if k_position_mask is None
                    else k_position_mask[0]
                )
                pair_mask = q_valid.unsqueeze(-1) & k_valid.unsqueeze(-2)
        else:
            # Per-batch (batch_size, tgt_len, src_len) offset grid.
            rel_positions = q_positions.unsqueeze(-1) - k_positions.unsqueeze(-2)
            pair_mask = None
            if q_position_mask is not None or k_position_mask is not None:
                q_valid = (
                    torch.ones(batch_size, tgt_len, device=rel_positions.device, dtype=torch.bool)
                    if q_position_mask is None
                    else q_position_mask
                )
                k_valid = (
                    torch.ones(batch_size, src_len, device=rel_positions.device, dtype=torch.bool)
                    if k_position_mask is None
                    else k_position_mask
                )
                pair_mask = q_valid.unsqueeze(-1) & k_valid.unsqueeze(-2)
        rel_bias = self._relative_bias_values(rel_positions).to(dtype=target_dtype)
        # MultiheadAttention expects 3-D masks with a (batch * num_heads) leading dim.
        if rel_bias.ndim == 3:
            rel_bias = rel_bias.repeat_interleave(self.num_heads, dim=0)
        elif not positions_shared_across_batch:
            rel_bias = rel_bias.unsqueeze(0).expand(batch_size * self.num_heads, -1, -1)
        if pair_mask is None:
            return rel_bias
        if rel_bias.ndim == 2:
            return rel_bias * pair_mask.to(dtype=rel_bias.dtype)
        # Broadcast the validity mask to the same (batch * num_heads) leading dim as rel_bias.
        if pair_mask.ndim == 2:
            pair_mask = pair_mask.unsqueeze(0).expand(batch_size * self.num_heads, -1, -1)
        else:
            pair_mask = pair_mask.repeat_interleave(self.num_heads, dim=0)
        return rel_bias * pair_mask.to(dtype=rel_bias.dtype)

    @staticmethod
    def _positions_are_shared_across_batch(
        positions: torch.Tensor,
        position_mask: torch.Tensor | None,
    ) -> bool:
        """Return ``True`` when positions (and mask, if any) carry no per-batch variation."""
        shared_positions = positions.ndim == 1 or (positions.ndim == 2 and positions.shape[0] == 1)
        if position_mask is None:
            return shared_positions
        shared_mask = position_mask.ndim == 1 or (position_mask.ndim == 2 and position_mask.shape[0] == 1)
        return shared_positions and shared_mask

    def _relative_bias_values(self, rel_positions: torch.Tensor) -> torch.Tensor:
        """Look up the bias table, clipping offsets to ``[-max_distance, max_distance]``."""
        clipped_positions = rel_positions.clamp(-self.max_distance, self.max_distance).long() + self.max_distance
        return self.rel_bias[clipped_positions]

    @torch.no_grad()
    def reset_parameters(self, scheme="xavier_uniform") -> None:  # noqa: D102
        self.attn.reset_parameters(scheme=scheme)
        self.rel_bias.zero_()

    def invalidate_cache(self) -> None:  # noqa: D102
        self.attn.invalidate_cache()
class RoPEMultiheadAttention(PositionalAttentionBase):
    r"""Multi-head attention whose queries and keys are rotated with RoPE.

    Inputs are first mapped through learnable linear projections and reshaped into
    per-head features. Every head embedding is then rotated position-wise, two
    channels at a time:

    .. math::
        \mathbf{q}_r' = R(\mathbf{P}_Q)\,\mathbf{q}_r, \qquad
        \mathbf{k}_r' = R(\mathbf{P}_K)\,\mathbf{k}_r, \qquad
        \mathbf{v}' = \mathbf{v},

    where ``R`` is the block-wise 2D rotation induced by the sine/cosine tables.
    Scores and outputs follow the standard scaled dot-product rule:

    .. math::
        \mathbf{A} = \operatorname{softmax}\left(
        \frac{\mathbf{q}'\mathbf{k}'^\top}{\sqrt{d_h}} + \mathbf{M}\right), \qquad
        \mathbf{O} = \mathbf{A}\mathbf{v}'.

    Shape
    -----
    - ``query``, ``key``, ``value``: ``(B, T, D)``.
    - ``q_positions``, ``k_positions``: ``(P,)`` or ``(B, P)``.
    - ``q_position_mask``, ``k_position_mask``: boolean masks shaped like the
      matching position tensor.
    - Returns: ``(output, attn_weights)`` where ``output`` keeps the input layout
      and ``attn_weights`` is ``(B, T_q, T_k)`` when requested.

    Attributes:
    ----------
    embed_dim: Total feature width ``D``.
    num_heads: Number of attention heads.
    head_dim: Per-head width, ``head_dim = embed_dim / num_heads``.
    dropout: Dropout probability applied to attention weights during training.
    q_proj, k_proj, v_proj, out_proj: Learnable linear projections for queries,
        keys, values, and the final output.
    rotary_emb: Helper module producing the RoPE sine/cosine tables.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        bias: bool = True,
        rope_base: float = 10000.0,
        device=None,
        dtype=None,
    ) -> None:
        r"""Build the projections and the rotary-embedding helper.

        Args:
            embed_dim (:class:`int`): Model width ``D``.
            num_heads (:class:`int`): Number of attention heads.
            dropout (:class:`float`): Dropout probability on attention weights. Default: 0.0.
            bias (:class:`bool`): If ``True``, adds learnable input and output projection biases.
                Default: ``True``.
            rope_base (:class:`float`): Frequency base used to build the rotary spectrum.
                Default: ``10000.0``.
            device (:class:`torch.device`, optional): Parameter factory options.
            dtype (:class:`torch.dtype`, optional): Parameter factory options.
        """
        super().__init__()
        # Validate the head geometry before allocating any parameters.
        if embed_dim <= 0:
            raise ValueError(f"embed_dim must be positive, got {embed_dim}")
        if num_heads <= 0:
            raise ValueError(f"num_heads must be positive, got {num_heads}")
        if embed_dim % num_heads != 0:
            raise ValueError(
                "embed_dim={embed_dim} must be divisible by num_heads={num_heads} "
                "so each head has an integer dimension".format(embed_dim=embed_dim, num_heads=num_heads)
            )
        per_head = embed_dim // num_heads
        # RoPE rotates channel pairs, so each head needs an even number of channels.
        if per_head % 2 != 0:
            raise ValueError(
                f"head_dim={per_head} must be even because RoPE rotates the entire head embedding in 2D pairs. "
                f"Choose embed_dim and num_heads so embed_dim / num_heads is even."
            )

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = per_head

        factory_kwargs = {"device": device, "dtype": dtype}
        # Separate Q/K/V projections (instead of one packed matrix) plus the output map.
        # NOTE: submodule creation order is preserved so state_dict layout is unchanged.
        self.q_proj = torch.nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
        self.k_proj = torch.nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
        self.v_proj = torch.nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
        self.out_proj = torch.nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
        self.rotary_emb = RotaryEmbedding(per_head, base=rope_base, device=device, dtype=dtype)
def forward(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    *,
    q_positions: torch.Tensor | None = None,
    k_positions: torch.Tensor | None = None,
    q_position_mask: torch.Tensor | None = None,
    k_position_mask: torch.Tensor | None = None,
    attn_mask: torch.Tensor | None = None,
    key_padding_mask: torch.Tensor | None = None,
    need_weights: bool = False,
    is_causal: bool = False,
) -> tuple[torch.Tensor, torch.Tensor | None]:
    r"""Project inputs, apply RoPE to each head, then compute attention.

    Shape
    -----
    - ``query``, ``key``, ``value``: see :class:`PositionalAttentionBase`.
    - Returns: ``(output, attn_weights)`` with the same leading layout as the input and optional
      attention weights when requested.
    """
    # Default positions are 0..T-1 on the input's device.
    q_positions = self._positions_or_arange(q_positions, seq_len=query.shape[1], device=query.device)
    k_positions = self._positions_or_arange(k_positions, seq_len=key.shape[1], device=key.device)

    # Per-head projections; rotate Q and K in position space, V stays untouched.
    q = self._split_heads(self.q_proj(query))
    k = self._split_heads(self.k_proj(key))
    v = self._split_heads(self.v_proj(value))
    q = self.rotary_emb.apply_rope(q, positions=q_positions, position_mask=q_position_mask)
    k = self.rotary_emb.apply_rope(k, positions=k_positions, position_mask=k_position_mask)

    # Canonicalize both masks to the additive float convention used below.
    key_padding_mask = F._canonical_mask(
        mask=key_padding_mask,
        mask_name="key_padding_mask",
        other_type=F._none_or_dtype(attn_mask),
        other_name="attn_mask",
        target_type=query.dtype,
    )
    attn_mask = F._canonical_mask(
        mask=attn_mask,
        mask_name="attn_mask",
        other_type=None,
        other_name="",
        target_type=query.dtype,
        check_other=False,
    )
    merged_mask = self._merge_masks(attn_mask, key_padding_mask, query.shape[0], k.shape[2])

    # BUGFIX: ``F.scaled_dot_product_attention`` raises an error when both
    # ``attn_mask`` and ``is_causal=True`` are supplied. If the caller combines a
    # mask with causal attention, fold the causal structure into the additive
    # mask (strictly-upper-triangular -inf) and disable the flag instead.
    if is_causal and merged_mask is not None:
        tgt_len, src_len = q.shape[2], k.shape[2]
        causal_bias = torch.full(
            (tgt_len, src_len), float("-inf"), device=q.device, dtype=merged_mask.dtype
        ).triu(1)
        merged_mask = merged_mask + causal_bias
        is_causal = False

    if need_weights:
        # Explicit path that materializes (and returns) the attention weights.
        output, attn_weights = self._attention_with_weights(q, k, v, merged_mask, is_causal=is_causal)
    else:
        # Fused kernel path; dropout is only active in training mode.
        output = F.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=merged_mask,
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=is_causal,
        )
        attn_weights = None
    output = self.out_proj(self._merge_heads(output))
    return output, attn_weights
def _split_heads(self, x: torch.Tensor) -> torch.Tensor: batch_size, seq_len, _ = x.shape return x.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) def _merge_heads(self, x: torch.Tensor) -> torch.Tensor: batch_size, _, seq_len, _ = x.shape return x.transpose(1, 2).contiguous().view(batch_size, seq_len, self.embed_dim) def _merge_masks( self, attn_mask: torch.Tensor | None, key_padding_mask: torch.Tensor | None, batch_size: int, src_len: int, ) -> torch.Tensor | None: merged_mask = attn_mask if merged_mask is not None: if merged_mask.ndim == 2: merged_mask = merged_mask.unsqueeze(0).unsqueeze(0) elif merged_mask.ndim == 3: if merged_mask.shape[0] == batch_size * self.num_heads: merged_mask = merged_mask.view(batch_size, self.num_heads, merged_mask.shape[-2], src_len) else: merged_mask = merged_mask.unsqueeze(1) elif merged_mask.ndim != 4: raise ValueError(f"Unsupported attn_mask shape {tuple(merged_mask.shape)}") if key_padding_mask is not None: padding_mask = key_padding_mask[:, None, None, :] merged_mask = padding_mask if merged_mask is None else merged_mask + padding_mask return merged_mask def _attention_with_weights( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask: torch.Tensor | None, *, is_causal: bool, ) -> tuple[torch.Tensor, torch.Tensor]: attn_scores = torch.matmul(q, k.transpose(-2, -1)) * (self.head_dim**-0.5) if is_causal: tgt_len, src_len = attn_scores.shape[-2:] causal_mask = torch.ones(tgt_len, src_len, device=attn_scores.device, dtype=torch.bool).triu(1) attn_scores = attn_scores.masked_fill(causal_mask, float("-inf")) if attn_mask is not None: attn_scores = attn_scores + attn_mask attn_weights = torch.softmax(attn_scores, dim=-1) if self.dropout > 0.0 and self.training: attn_weights = torch.dropout(attn_weights, self.dropout, train=True) return torch.matmul(attn_weights, v), attn_weights.mean(dim=1)
class RotaryEmbedding(torch.nn.Module):
    r"""Builds the cosine and sine lookup tables used by rotary embeddings.

    Shape
    -----
    - ``positions``: ``(P,)`` or ``(B, P)``.
    - Returns: ``(cos, sin)`` with shape ``(P, dim / 2)`` or ``(B, P, dim / 2)``.

    Attributes:
    ----------
    dim: Number of channels rotated by RoPE.
    base: Frequency base of the inverse-frequency spectrum.
    inv_freq: Non-persistent buffer holding the inverse frequencies.
    """

    def __init__(
        self,
        dim: int,
        base: float = 10000.0,
        device=None,
        dtype=None,
    ) -> None:
        r"""Precompute the inverse-frequency spectrum.

        Args:
            dim (:class:`int`): Number of channels rotated by RoPE.
            base (:class:`float`): Frequency base used to build the inverse frequency spectrum.
            device (:class:`torch.device`, optional): Buffer factory options.
            dtype (:class:`torch.dtype`, optional): Buffer factory options.
        """
        super().__init__()
        # Channels are rotated in 2D pairs, so `dim` must be even (and non-zero).
        if dim <= 0 or dim % 2 != 0:
            raise ValueError(f"Rotary dim must be a positive even integer, got {dim}")
        self.dim = dim
        self.base = base
        # inv_freq[i] = base^(-2i / dim): one frequency per channel pair.
        exponents = torch.arange(0, dim, 2, device=device, dtype=dtype) / dim
        self.register_buffer("inv_freq", 1.0 / (base**exponents), persistent=False)
def forward(self, positions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Build the cosine and sine tables for ``positions`` (one column per channel pair)."""
    # Outer product of positions and inverse frequencies via broadcasting.
    theta = positions[..., None] * self.inv_freq
    return torch.cos(theta), torch.sin(theta)
def apply_rope(
    self,
    x: torch.Tensor,
    positions: torch.Tensor,
    position_mask: torch.Tensor | None = None,
) -> torch.Tensor:
    r"""Apply rotary position embeddings to the entire head embedding of ``x``.

    Shape
    -----
    - ``x``: ``(B, H, T, D)``.
    - ``positions``: ``(P,)`` or ``(B, P)``.
    - ``position_mask``: optional boolean mask with the same layout as ``positions``.
    - Returns: ``x`` with every head channel rotated in place.
    """
    if x.ndim != 4:
        raise ValueError(f"Expected x with shape (B, H, T, D), got {tuple(x.shape)}")
    batch_size, _, seq_len, head_dim = x.shape
    # Normalize positions/mask to a batch-aligned layout.
    # NOTE(review): presumably returns (B, T)-shaped tensors — confirm against
    # PositionalAttentionBase._normalize_positions.
    positions, position_mask = PositionalAttentionBase._normalize_positions(
        positions,
        position_mask,
        batch_size=batch_size,
        seq_len=seq_len,
    )
    if position_mask is not None:
        # Use position 0 for masked entries so the tables stay well-defined; the
        # final torch.where below restores the unrotated values there anyway.
        positions = positions.masked_fill(~position_mask, 0)
    if self.dim != head_dim:
        raise ValueError(f"RotaryEmbedding.dim={self.dim} must match the head dimension {head_dim}")
    # (cos, sin) have one column per channel pair; repeat each column twice so the
    # tables broadcast against the full head dimension, and add a head axis.
    cos, sin = self(positions)
    cos = cos.unsqueeze(1).repeat_interleave(2, dim=-1).to(dtype=x.dtype)
    sin = sin.unsqueeze(1).repeat_interleave(2, dim=-1).to(dtype=x.dtype)
    # Apply the 2D rotation block by block across the entire head embedding:
    # (even, odd) -> (even*cos - odd*sin, odd*cos + even*sin).
    x_even = x[..., ::2]
    x_odd = x[..., 1::2]
    x_rotated = x * cos + torch.stack((-x_odd, x_even), dim=-1).flatten(start_dim=-2) * sin
    if position_mask is not None:
        # Leave padded positions unchanged so masking does not inject a phase rotation.
        keep_mask = position_mask.unsqueeze(1).unsqueeze(-1)
        x = torch.where(keep_mask, x_rotated, x)
    else:
        x = x_rotated
    return x