AdditiveRelMultiheadAttention#
- class AdditiveRelMultiheadAttention(embed_dim, num_heads, *, max_distance, dropout=0.0, bias=True, device=None, dtype=None)[source]#
Bases: PositionalAttentionBase

Wrap torch.nn.MultiheadAttention and subtract a relative-position bias from the logits. A learned table maps each pairwise relative distance to an additive correction of the attention scores,

\[\mathbf{A}_{ij} = \frac{\mathbf{q}_i^\top \mathbf{k}_j}{\sqrt{d_h}} - \phi_\theta(\mathbf{P}_{Q, i} - \mathbf{P}_{K, j}),\]

while the value stream remains unchanged. Since the correction depends only on relative offsets, the attention module is time-translation equivariant.
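To make the correction concrete, here is a minimal PyTorch sketch of the biased logits; phi stands in for the learned table \(\phi_\theta\), and the function and tensor names are illustrative rather than part of this module's API.

```python
import torch

def biased_logits(q, k, q_pos, k_pos, phi):
    """Illustrative only: scaled dot-product logits minus a relative bias."""
    # q, k: (B, H, T, d_h); q_pos: (T_q,); k_pos: (T_k,).
    d_h = q.size(-1)
    logits = q @ k.transpose(-2, -1) / d_h ** 0.5  # (B, H, T_q, T_k)
    rel = q_pos[:, None] - k_pos[None, :]          # pairwise offsets P_Q,i - P_K,j
    return logits - phi(rel)                       # subtract phi_theta(rel)
```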
Shape#
- query, key, value: (B, T, D).
- q_positions, k_positions: (P,) or (B, P).
- q_position_mask, k_position_mask: boolean masks with the same shape as the corresponding position tensor.

Returns: the attention output and, optionally, the attention weights.
Attributes:#
- rel_bias: Learnable table with shape (2 * max_distance + 1,) storing the scalar bias for each clipped relative offset.
- attn: Internal torch.nn.MultiheadAttention backend.
- embed_dim: Model width D.
- num_heads: Number of attention heads.
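The shape of rel_bias suggests that offsets are clipped to [-max_distance, max_distance] before indexing the table. A minimal sketch of that lookup, assuming this clipping convention (the helper name is hypothetical):

```python
import torch

max_distance = 8
rel_bias = torch.nn.Parameter(torch.zeros(2 * max_distance + 1))

def lookup_bias(rel):
    # Assumed convention: clip offsets to [-max_distance, max_distance],
    # then shift so table indices run over 0 .. 2 * max_distance.
    idx = rel.clamp(-max_distance, max_distance) + max_distance
    return rel_bias[idx]
```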
Initialize relative-bias attention with a wrapped multi-head backend.
- Parameters:
  - embed_dim (int): Model width D.
  - num_heads (int): Number of attention heads.
  - max_distance (int): Maximum relative distance represented explicitly before clipping.
  - dropout (float): Dropout probability on attention weights. Default: 0.0.
  - bias (bool): If True, adds learnable input and output projection biases. Default: True.
  - device (torch.device, optional): Parameter factory options.
  - dtype (torch.dtype, optional): Parameter factory options.
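A construction sketch using the signature above; the import path is a placeholder for wherever the class is defined, and the argument values are arbitrary:

```python
import torch

# Placeholder import path; adjust to your package layout.
# from mypackage.attention import AdditiveRelMultiheadAttention

attn = AdditiveRelMultiheadAttention(
    embed_dim=256,    # model width D
    num_heads=8,
    max_distance=32,  # offsets beyond +/-32 share the edge bias
    dropout=0.1,
)
```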
- forward(query, key, value, *, q_positions=None, k_positions=None, q_position_mask=None, k_position_mask=None, attn_mask=None, key_padding_mask=None, need_weights=False, is_causal=False)[source]#
Subtract the relative-position bias from the attention logits.
Shape#
- query, key, value: see PositionalAttentionBase.

Returns: (output, attn_weights) from the wrapped torch.nn.MultiheadAttention.
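For orientation, a hedged forward call consistent with the Shape section, reusing the attn module constructed above; the tensors are dummy data:

```python
import torch

B, T, D = 2, 16, 256
x = torch.randn(B, T, D)
positions = torch.arange(T)  # shared (P,) positions for queries and keys

out, attn_weights = attn(
    x, x, x,
    q_positions=positions,
    k_positions=positions,
    need_weights=True,
)
print(out.shape)  # expected: torch.Size([2, 16, 256])
```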