AdditivePosMultiheadAttention#
- class AdditivePosMultiheadAttention(embed_dim, num_heads, *, max_len, dropout=0.0, bias=True, device=None, dtype=None)[source]#
Bases: PositionalAttentionBase

Wrap torch.nn.MultiheadAttention and add positional features to Q/K only.

A learned table maps coordinates to an additive update of the query and key streams,

\[\mathbf{q}' = \mathbf{q} + E_\theta(\mathbf{P}_Q), \qquad \mathbf{k}' = \mathbf{k} + E_\theta(\mathbf{P}_K), \qquad \mathbf{v}' = \mathbf{v}.\]

The values are never position-modulated. When \(E_\theta\) is the identity / no-op branch, this is exactly ordinary multi-head attention.

Shape#
- query, key, value: (B, T, D).
- q_positions, k_positions: (P,) or (B, P).
- q_position_mask, k_position_mask: boolean masks with the same shape as the corresponding position tensor.
Returns: the attention output and, optionally, attention weights.
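For orientation, the update above can be reproduced with stock PyTorch pieces. The snippet below is a minimal sketch, not the wrapper's actual implementation: it assumes the learned table \(E_\theta\) is an nn.Embedding with (max_len, D) rows (mirroring the pos_emb attribute listed below) and that positions simply index rows of that table.

```python
import torch
import torch.nn as nn

B, T, D, H, max_len = 2, 5, 16, 4, 32

attn = nn.MultiheadAttention(D, H, batch_first=True)
pos_emb = nn.Embedding(max_len, D)   # stands in for the learned table E_theta, shape (max_len, D)

query = torch.randn(B, T, D)
key = torch.randn(B, T, D)
value = torch.randn(B, T, D)
positions = torch.arange(T)          # (P,), here P == T

# Additive positional update: Q and K are shifted, V is passed through unchanged.
q = query + pos_emb(positions)       # (T, D) broadcasts over the batch dimension
k = key + pos_emb(positions)
out, weights = attn(q, k, value, need_weights=True)
print(out.shape, weights.shape)      # torch.Size([2, 5, 16]) torch.Size([2, 5, 5])
```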
Attributes:#
- pos_emb: Learnable table with shape (max_len, D) storing the absolute positional embeddings.
- attn: Internal torch.nn.MultiheadAttention backend.
- embed_dim: Model width D.
- num_heads: Number of attention heads.
Initialize additive positional attention with a wrapped multi-head backend.
- Parameters:
  - embed_dim: Model width D.
  - num_heads: Number of attention heads.
  - max_len: Maximum supported sequence length.
  - dropout: Dropout probability on attention weights. Default: 0.0.
  - bias: If True, adds learnable input and output projection biases. Default: True.
  - device (torch.device, optional): Parameter factory options.
  - dtype (torch.dtype, optional): Parameter factory options.
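A construction sketch following the signature above; the import path additive_pos_attention is hypothetical and should be replaced with the package that actually provides the class:

```python
# Hypothetical import path; adjust to the package that provides the class.
from additive_pos_attention import AdditivePosMultiheadAttention

layer = AdditivePosMultiheadAttention(
    embed_dim=64,     # model width D
    num_heads=8,
    max_len=128,      # longest sequence the positional table can address
    dropout=0.1,
    bias=True,
)
```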
- forward(query, key, value, *, q_positions=None, k_positions=None, q_position_mask=None, k_position_mask=None, attn_mask=None, key_padding_mask=None, need_weights=False, is_causal=False)[source]#
Add the positional update to query and key before attention.
The values are passed through unchanged.
Shape#
- query, key, value: see PositionalAttentionBase.

Returns: (output, attn_weights) from the wrapped torch.nn.MultiheadAttention.
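A forward-pass sketch with explicit positions; the shapes follow the Shape sections above, and the import path is again hypothetical:

```python
import torch
from additive_pos_attention import AdditivePosMultiheadAttention  # hypothetical path

B, T, D = 2, 10, 64
layer = AdditivePosMultiheadAttention(embed_dim=D, num_heads=8, max_len=128)

x = torch.randn(B, T, D)        # self-attention: query = key = value
positions = torch.arange(T)     # (P,) positions, broadcast across the batch

out, attn_weights = layer(
    x, x, x,
    q_positions=positions,
    k_positions=positions,
    need_weights=True,          # also return the attention weights
)
print(out.shape)                # torch.Size([2, 10, 64])
```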