AdditiveRelMultiheadAttention#

class AdditiveRelMultiheadAttention(embed_dim, num_heads, *, max_distance, dropout=0.0, bias=True, device=None, dtype=None)[source]#

Bases: PositionalAttentionBase

Wrap torch.nn.MultiheadAttention and subtract a relative-position bias from the logits.

A learned table maps pairwise relative distances to an additive correction of the attention scores,

\[\mathbf{A}_{ij} = \frac{\mathbf{q}_i^\top \mathbf{k}_j}{\sqrt{d_h}} - \phi_\theta(\mathbf{P}_{Q, i} - \mathbf{P}_{K, j}),\]

while the value stream remains unchanged. Since the correction depends only on relative offsets, the attention module is time-translation equivariant.
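The clipped-table lookup behind \(\phi_\theta\) can be sketched as follows. The flat table of shape (2 * max_distance + 1,) mirrors the rel_bias attribute documented below; the helper name and the exact shift-and-index scheme are assumptions for illustration:

```python
import torch

def relative_bias(q_pos, k_pos, table, max_distance):
    # Pairwise relative offsets P_{Q,i} - P_{K,j}, clipped to the
    # representable range [-max_distance, max_distance].
    rel = (q_pos[:, None] - k_pos[None, :]).clamp(-max_distance, max_distance)
    # Shift offsets to [0, 2 * max_distance] so they index the flat table.
    return table[(rel + max_distance).long()]

max_distance = 4
table = torch.randn(2 * max_distance + 1)  # mirrors the rel_bias attribute
pos = torch.arange(6)
bias = relative_bias(pos, pos, table, max_distance)
print(bias.shape)  # (6, 6): one scalar correction per (i, j) pair
```

Because only the difference of positions enters the lookup, shifting both position vectors by the same constant leaves the bias matrix unchanged, which is the time-translation equivariance noted above.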

Shape#

  • query, key, value: (B, T, D).

  • q_positions, k_positions: (P,) or (B, P).

  • q_position_mask, k_position_mask: boolean masks with the same shape

    as the corresponding position tensor.

  • Returns: the attention output and, optionally, attention weights.

Attributes:#

rel_bias:

Learnable table with shape (2 * max_distance + 1,) storing the scalar bias for each clipped relative offset.

attn:

Internal torch.nn.MultiheadAttention backend.

embed_dim:

Model width D.

num_heads:

Number of attention heads.

Initialize relative-bias attention with a wrapped multi-head backend.

Parameters:

  • embed_dim (int) – Model width D.

  • num_heads (int) – Number of attention heads.

  • max_distance (int) – Maximum relative distance represented explicitly before clipping.

  • dropout (float) – Dropout probability on attention weights. Default: 0.0.

  • bias (bool) – If True, adds learnable input and output projection biases. Default: True.

  • device (torch.device, optional) – Parameter factory options.

  • dtype (torch.dtype, optional) – Parameter factory options.

forward(query, key, value, *, q_positions=None, k_positions=None, q_position_mask=None, k_position_mask=None, attn_mask=None, key_padding_mask=None, need_weights=False, is_causal=False)[source]#

Subtract the relative-position bias from the attention logits.

Return type:

tuple[Tensor, Tensor | None]

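A minimal end-to-end sketch of the forward contract, built directly on the wrapped backend: since torch.nn.MultiheadAttention *adds* a float attn_mask to the logits, the subtraction in the formula above can be folded in as a mask of -bias. This is an illustrative reimplementation under that assumption, not the module's actual code path:

```python
import torch
import torch.nn as nn

B, T, D, H, max_distance = 2, 6, 16, 4, 4
attn = nn.MultiheadAttention(D, H, batch_first=True)
rel_bias = nn.Parameter(torch.zeros(2 * max_distance + 1))  # hypothetical stand-in

q = k = v = torch.randn(B, T, D)
pos = torch.arange(T)

# Clipped relative offsets indexed into the flat bias table, as in the
# formula: rel_bias plays the role of phi_theta.
rel = (pos[:, None] - pos[None, :]).clamp(-max_distance, max_distance)
bias = rel_bias[rel + max_distance]              # (T, T)

# Subtracting the bias from the logits == adding -bias as a float mask.
out, weights = attn(q, k, v, attn_mask=-bias, need_weights=True)
print(out.shape)  # (2, 6, 16): same (B, T, D) as the inputs
```

With need_weights=True the backend also returns the (head-averaged) attention weights, matching the "output and, optionally, attention weights" return described in the Shape section.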