AdditivePosMultiheadAttention#

class AdditivePosMultiheadAttention(embed_dim, num_heads, *, max_len, dropout=0.0, bias=True, device=None, dtype=None)[source]#

Bases: PositionalAttentionBase

Wrap torch.nn.MultiheadAttention and add positional features to Q/K only.

A learned table maps coordinates to an additive update of the query and key streams,

\[\mathbf{q}' = \mathbf{q} + E_\theta(\mathbf{P}_Q), \qquad \mathbf{k}' = \mathbf{k} + E_\theta(\mathbf{P}_K), \qquad \mathbf{v}' = \mathbf{v}.\]

The values are never position-modulated. When E_\theta is the identity (no-op) branch, this is exactly ordinary multi-head attention.
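As a point of reference, here is a minimal sketch of this update in plain PyTorch, assuming E_\theta is a learnable embedding lookup. Only pos_emb and attn mirror the attributes documented below; the class name and everything else is illustrative, not this module's actual implementation:

    import torch
    import torch.nn as nn

    class AdditivePosSketch(nn.Module):
        """Illustrative only: add learned positional embeddings to Q and K."""

        def __init__(self, embed_dim, num_heads, max_len, dropout=0.0):
            super().__init__()
            # E_theta as a table of absolute positional embeddings, shape (max_len, D).
            self.pos_emb = nn.Embedding(max_len, embed_dim)
            self.attn = nn.MultiheadAttention(
                embed_dim, num_heads, dropout=dropout, batch_first=True
            )

        def forward(self, query, key, value, q_positions, k_positions):
            # q' = q + E(P_Q), k' = k + E(P_K); v passes through unchanged.
            q = query + self.pos_emb(q_positions)  # (P,) or (B, P) broadcasts over B
            k = key + self.pos_emb(k_positions)
            return self.attn(q, k, value, need_weights=False)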

Shape#

  • query, key, value: (B, T, D).

  • q_positions, k_positions: (P,) or (B, P).

  • q_position_mask, k_position_mask: boolean masks with the same shape as the corresponding position tensor.

  • Returns: the attention output with shape (B, T, D) and, when need_weights=True, the attention weights (otherwise None).
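A quick shape check against this contract; the import path is a placeholder for wherever the class lives in the surrounding package:

    import torch
    # Placeholder import; adjust to the actual package location.
    # from mypkg.attention import AdditivePosMultiheadAttention

    B, T, D = 2, 16, 64
    attn = AdditivePosMultiheadAttention(embed_dim=D, num_heads=8, max_len=128)
    x = torch.randn(B, T, D)
    pos = torch.arange(T)          # shared (P,) positions, P == T here
    out, weights = attn(x, x, x, q_positions=pos, k_positions=pos)
    assert out.shape == (B, T, D)
    assert weights is None         # need_weights defaults to False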

Attributes#

pos_emb:

Learnable table with shape (max_len, D) storing the absolute positional embeddings.

attn:

Internal torch.nn.MultiheadAttention backend.

embed_dim:

Model width D.

num_heads:

Number of attention heads.

Initialize additive positional attention with a wrapped multi-head backend.

Parameters:

  • embed_dim (int) – Model width D.

  • num_heads (int) – Number of attention heads.

  • max_len (int) – Maximum supported sequence length.

  • dropout (float) – Dropout probability on attention weights. Default: 0.0.

  • bias (bool) – If True, adds learnable input and output projection biases. Default: True.

  • device (torch.device, optional) – Parameter factory options.

  • dtype (torch.dtype, optional) – Parameter factory options.
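For instance, the factory keywords place parameters on the target device and dtype at construction time. A sketch; the shape check assumes pos_emb is a raw parameter tensor as described above (if it is wrapped in an nn.Embedding, inspect pos_emb.weight.shape instead):

    import torch

    attn = AdditivePosMultiheadAttention(
        embed_dim=64,
        num_heads=8,
        max_len=512,
        dropout=0.1,
        bias=True,
        device=torch.device("cpu"),
        dtype=torch.float32,
    )
    # The learned table documented above: (max_len, D).
    assert attn.pos_emb.shape == (512, 64)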

forward(query, key, value, *, q_positions=None, k_positions=None, q_position_mask=None, k_position_mask=None, attn_mask=None, key_padding_mask=None, need_weights=False, is_causal=False)[source]#

Add the positional update to query and key before attention.

The values are passed through unchanged.

Return type:

tuple[Tensor, Tensor | None]

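An illustrative full call with per-batch positions, position masks, and attention weights requested. All values are made up, and the mask convention (True marking valid positions) is an assumption about this module, not a documented guarantee:

    import torch

    B, T, D = 2, 8, 64
    attn = AdditivePosMultiheadAttention(embed_dim=D, num_heads=4, max_len=32)
    q = k = v = torch.randn(B, T, D)

    # Per-batch absolute positions, shape (B, P) with P == T.
    positions = torch.arange(T).expand(B, T)
    # Boolean mask with the same shape as the position tensor;
    # True marks valid positions (an assumption about the convention).
    pos_mask = torch.ones(B, T, dtype=torch.bool)

    out, weights = attn(
        q, k, v,
        q_positions=positions, k_positions=positions,
        q_position_mask=pos_mask, k_position_mask=pos_mask,
        need_weights=True,
    )
    assert out.shape == (B, T, D)
    # weights is a Tensor when need_weights=True (averaged over heads by
    # default in the underlying torch.nn.MultiheadAttention).
    assert weights is not None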