RoPEMultiheadAttention#

class RoPEMultiheadAttention(embed_dim, num_heads, dropout=0.0, bias=True, rope_base=10000.0, device=None, dtype=None)[source]#

Bases: PositionalAttentionBase

Multi-head attention with rotary position embeddings applied to Q and K.

Query, key, and value tensors are first projected into per-head features. The query and key head embeddings are then rotated according to their positions, block by block in 2D channel pairs, while values are left unchanged:

\[\mathbf{q}' = R(\mathbf{p}_q)\,\mathbf{q}, \qquad \mathbf{k}' = R(\mathbf{p}_k)\,\mathbf{k}, \qquad \mathbf{v}' = \mathbf{v},\]

where \(\mathbf{p}_q\) and \(\mathbf{p}_k\) are the query and key positions and R is the block-wise 2D rotation induced by the sine/cosine tables. The attention scores are then

\[\mathbf{A} = \operatorname{softmax}\left( \frac{\mathbf{q}'\mathbf{k}'^\top}{\sqrt{d_h}} + \mathbf{M}\right), \qquad \mathbf{O} = \mathbf{A}\mathbf{v}'.\]
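
In practice the rotation R is usually applied with the half-split trick rather than an explicit block-diagonal matrix. A minimal sketch, assuming a (B, H, T, d_h) head layout and precomputed cos/sin tables; rotate_half and apply_rope are illustrative names, not part of this module's API:

    import torch

    def rotate_half(x: torch.Tensor) -> torch.Tensor:
        # Pair channel i with channel i + d_h/2 and rotate each 2D pair:
        # (x1, x2) -> (-x2, x1).
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def apply_rope(q, k, cos, sin):
        # q, k: (B, H, T, d_h); cos, sin: (T, d_h), broadcast over batch and heads.
        q_rot = q * cos + rotate_half(q) * sin
        k_rot = k * cos + rotate_half(k) * sin
        return q_rot, k_rot

The rotated tensors then enter the usual scaled dot-product attention given by the formula above.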

Shape#

  • query, key, value: (B, T, D).

  • q_positions, k_positions: (P,) or (B, P).

  • q_position_mask, k_position_mask: boolean masks with the same shape as the corresponding position tensor.

  • Returns: (output, attn_weights) where output has the same leading layout as the input and attn_weights is (B, T_q, T_k) when requested.

Attributes:#

embed_dim:

Total feature width D.

num_heads:

Number of attention heads.

head_dim:

Width of each head, head_dim = embed_dim // num_heads (embed_dim must be divisible by num_heads).

dropout:

Dropout probability applied to attention weights during training.

q_proj, k_proj, v_proj, out_proj:

Learnable linear projections used to form queries, keys, values, and the final output.

rotary_emb:

Helper module that builds the sine and cosine tables used by RoPE.
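
For reference, the standard RoPE spectrum uses the frequencies rope_base^(-2i/d_h) for pair index i. A minimal sketch of such a table builder, assuming the half-split layout shown earlier; build_rope_tables is a hypothetical name, not this attribute's actual interface:

    import torch

    def build_rope_tables(positions: torch.Tensor, head_dim: int, base: float = 10000.0):
        # positions: (T,) position indices (need not be contiguous); head_dim must be even.
        inv_freq = base ** (-torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim)
        angles = positions.float()[:, None] * inv_freq      # (T, d_h / 2)
        angles = torch.cat((angles, angles), dim=-1)        # (T, d_h), half-split layout
        return angles.cos(), angles.sin()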

Initialize the rotary-position attention block.

Parameters:#

embed_dim (int):

Model width D.

num_heads (int):

Number of attention heads.

dropout (float):

Dropout probability on attention weights. Default: 0.0.

bias (bool):

If True, adds learnable input and output projection biases. Default: True.

rope_base (float):

Frequency base used to build the rotary spectrum. Default: 10000.0.

device (torch.device, optional):

Parameter factory options.

dtype (torch.dtype, optional):

Parameter factory options.
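
A minimal construction example using these arguments; the default positions when q_positions and k_positions are omitted (assumed here to be 0, 1, ..., T-1) are an assumption, not documented behavior:

    import torch

    attn = RoPEMultiheadAttention(embed_dim=256, num_heads=8, dropout=0.1)
    x = torch.randn(4, 128, 256)   # (B, T, D)
    out, _ = attn(x, x, x)         # self-attention over 128 positions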

forward(query, key, value, *, q_positions=None, k_positions=None, q_position_mask=None, k_position_mask=None, attn_mask=None, key_padding_mask=None, need_weights=False, is_causal=False)[source]#

Project inputs, apply RoPE to each head, then compute attention.

Shape#

  • query, key, value: see PositionalAttentionBase.

  • Returns: (output, attn_weights) with the same leading layout as the input and optional attention weights when requested.

Return type:

tuple[Tensor, Tensor | None]
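
A sketch of a cross-attention call with explicit, non-contiguous positions, assuming the signature documented above:

    import torch

    attn = RoPEMultiheadAttention(embed_dim=64, num_heads=4)
    q = torch.randn(2, 10, 64)      # (B, T_q, D)
    kv = torch.randn(2, 20, 64)     # (B, T_k, D)
    q_pos = torch.arange(10)        # (P,) positions shared across the batch
    k_pos = torch.arange(0, 40, 2)  # gapped key positions; RoPE only sees relative offsets
    out, w = attn(q, kv, kv,
                  q_positions=q_pos, k_positions=k_pos,
                  need_weights=True)
    # out: (B, T_q, D); w: (B, T_q, T_k) as given in the Shape section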