PositionalAttentionBase#
- class PositionalAttentionBase(*args, **kwargs)[source]#
Abstract interface for attention blocks with explicit positional branches.
The convention in this module is that positional information is provided as a function, defined by a submodule, acting on the query and key streams only. Values are left unchanged:
\[\tilde{\mathbf{q}} = \mathbf{q} + \phi_\theta(\mathbf{P}_Q), \qquad \tilde{\mathbf{k}} = \mathbf{k} + \phi_\theta(\mathbf{P}_K), \qquad \tilde{\mathbf{v}} = \mathbf{v}.\]
A standard multi-head attention operator is then applied to \((\tilde{\mathbf{q}}, \tilde{\mathbf{k}}, \tilde{\mathbf{v}})\). If the positional branch is the identity/no-op map, the module reduces to torch.nn.MultiheadAttention. Concrete implementations may ignore any positional arguments they do not use, but the forward signature stays stable so encoder and decoder layers can call all attention backends uniformly.
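A minimal sketch of a concrete subclass is shown below. It is not part of this module: the class name LearnedPositionalAttention, the embedding-table positional branch, and the import path are all assumptions chosen for illustration; the only constraint taken from the docs is that the positional branch is added to the query and key streams and the result is passed to torch.nn.MultiheadAttention.

```python
import torch
from torch import nn

# Assumed import; replace with the actual module path for PositionalAttentionBase.
# from mypackage.attention import PositionalAttentionBase


class LearnedPositionalAttention(PositionalAttentionBase):
    """Illustrative subclass: a learned embedding table as the positional branch."""

    def __init__(self, embed_dim, num_heads, max_positions=512):
        super().__init__()
        # phi_theta: maps integer position indices to additive embeddings.
        self.pos_encoder = nn.Embedding(max_positions, embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

    def forward(self, query, key, value, *, q_positions=None, k_positions=None,
                q_position_mask=None, k_position_mask=None, attn_mask=None,
                key_padding_mask=None, need_weights=False, is_causal=False):
        # Default positions: torch.arange(P) for each sequence.
        if q_positions is None:
            q_positions = torch.arange(query.shape[1], device=query.device)
        if k_positions is None:
            k_positions = torch.arange(key.shape[1], device=key.device)
        # Positional branch acts on query and key only; values are left unchanged.
        # This sketch ignores q_position_mask / k_position_mask, which the base
        # class explicitly allows.
        q = query + self.pos_encoder(q_positions)
        k = key + self.pos_encoder(k_positions)
        return self.attn(q, k, value, attn_mask=attn_mask,
                         key_padding_mask=key_padding_mask,
                         need_weights=need_weights, is_causal=is_causal)
```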
Shape#
- query, key, value: (B, P, D).
- q_positions, k_positions: (P,) or (B, P). When omitted, they default to torch.arange(P) for the corresponding sequence.
- q_position_mask, k_position_mask: (P,) or (B, P) boolean masks.
- attn_mask: any attention mask layout accepted by torch.nn.MultiheadAttention.
- key_padding_mask: (B, S) boolean or additive padding mask.
- Returns: attention output with the same leading layout as the input, plus optional attention weights.
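The following usage sketch ties these shapes together; it reuses the hypothetical LearnedPositionalAttention subclass from above and relies on the documented default of torch.arange(P) when positions are omitted.

```python
import torch

B, P, D = 2, 16, 64                       # batch, sequence length, embedding dim
attn = LearnedPositionalAttention(embed_dim=D, num_heads=8)

x = torch.randn(B, P, D)                  # query, key, value: (B, P, D)
# Positions omitted: they default to torch.arange(P) for each sequence.
out, weights = attn(x, x, x, need_weights=True)
print(out.shape)      # torch.Size([2, 16, 64])
print(weights.shape)  # torch.Size([2, 16, 16]); averaged over heads by default
```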
Attributes:#
This base class does not define any storage beyond the standard torch.nn.Module state. Concrete subclasses define their own positional encoder and attention parameters.
- abstractmethod forward(query, key, value, *, q_positions=None, k_positions=None, q_position_mask=None, k_position_mask=None, attn_mask=None, key_padding_mask=None, need_weights=False, is_causal=False)[source]#
Apply attention with optional explicit position metadata.
Shape#
- query, key, value: see PositionalAttentionBase.
- q_positions, k_positions: position coordinates for the query and key tokens. They may differ in cross-attention.
- Returns:
(output, attn_weights), where attn_weights is None unless need_weights=True.
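A cross-attention call can look like the sketch below. It again uses the assumed LearnedPositionalAttention subclass; the stride-2 key positions and the all-False padding mask are arbitrary illustrative values, chosen only to show that query and key streams may differ in length and coordinates.

```python
import torch

B, P_q, P_k, D = 2, 8, 32, 64
attn = LearnedPositionalAttention(embed_dim=D, num_heads=8)

q = torch.randn(B, P_q, D)
kv = torch.randn(B, P_k, D)

q_pos = torch.arange(P_q)                     # (P,) coordinates for the query tokens
k_pos = torch.arange(0, 2 * P_k, 2)           # e.g. key tokens sampled at stride 2
pad = torch.zeros(B, P_k, dtype=torch.bool)   # key_padding_mask: (B, S), True = ignore

out, attn_weights = attn(q, kv, kv, q_positions=q_pos, k_positions=k_pos,
                         key_padding_mask=pad)
print(out.shape)       # torch.Size([2, 8, 64])
print(attn_weights)    # None, since need_weights defaults to False
```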