eMultiheadAttention

- class eMultiheadAttention(in_rep, num_heads, dropout=0.0, bias=True, batch_first=False, add_bias_kv=False, add_zero_attn=False, device=None, dtype=None, init_scheme='xavier_normal')
Bases: MultiheadAttention

Drop-in replacement for torch.nn.MultiheadAttention that preserves G-equivariance.

This module keeps the runtime logic of PyTorch's implementation untouched: we still rely on the packed in_proj_weight/in_proj_bias for computing queries, keys, and values, and the internal attention kernel (including mask handling, dropout, and softmax) is exactly the stock MultiheadAttention behavior.

Equivariance is achieved by constraining every linear projection involved in the attention block:
- the input projection [Q; K; V] = W_in @ x is treated as a single map from the input representation to three stacked copies of a regular-representation block that aligns with the requested num_heads (enforced via CommutingConstraint);
- the optional stacked bias is projected onto the invariant subspace of that same block via InvariantConstraint;
- the output projection out_proj is constrained to commute with the group action, so that the concatenated value vectors are mapped back into the original feature space equivariantly.
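For intuition on the commutation constraints above, here is a pure-Python illustration independent of this library: a linear map commutes with the regular representation of a cyclic group exactly when its matrix is circulant. The check below mirrors the condition imposed on the projections (W @ rho(g) == rho(g) @ W); all helper names are ours, not part of this API.

```python
def matmul(a, b):
    """Multiply two square matrices given as nested lists."""
    n = len(a)
    return [[sum(a[i][k] * b[k][j] for k in range(n)) for j in range(n)]
            for i in range(n)]

def shift_matrix(n):
    """Permutation matrix of the generator of the cyclic group C_n
    acting on its regular representation (a cyclic shift)."""
    return [[1 if j == (i + 1) % n else 0 for j in range(n)] for i in range(n)]

def commutes(w, g):
    """True iff the linear map w is equivariant: w @ g == g @ w."""
    return matmul(w, g) == matmul(g, w)

n = 3
g = shift_matrix(n)

# A circulant matrix (each row a cyclic shift of the first) commutes
# with the shift, i.e. it is C_3-equivariant.
c = [2, 5, 7]
w_eq = [[c[(j - i) % n] for j in range(n)] for i in range(n)]

# A generic (here diagonal, non-constant) matrix does not commute.
w_bad = [[1, 0, 0], [0, 2, 0], [0, 0, 3]]

print(commutes(w_eq, g))   # True
print(commutes(w_bad, g))  # False
```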
Additionally, we restrict num_heads to divide the number of regular-representation copies present in the input feature space, to avoid splitting irreducible subspaces across heads.

Initialize the equivariant multihead attention.
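The divisibility restriction on num_heads can be sketched as a standalone check; the helper name and its arguments below are hypothetical, not part of this API.

```python
def heads_per_copy_split(num_copies: int, num_heads: int) -> int:
    """Hypothetical sketch of the restriction on num_heads: it must
    divide the number of regular-representation copies in the input
    feature space, so each head receives whole copies and no
    irreducible subspace is split across heads."""
    if num_copies % num_heads != 0:
        raise ValueError(
            f"num_heads={num_heads} must divide the {num_copies} "
            "regular-representation copies in the input representation"
        )
    return num_copies // num_heads  # copies handled by each head

print(heads_per_copy_split(8, 4))  # 4 heads, 2 copies each -> 2
```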
- Parameters:
  - in_rep (Representation) – Representation \(\rho_{\text{in}}\) of the input/output space.
  - num_heads (int) – Number of parallel attention heads.
  - dropout (float) – Dropout probability on attention weights. Default: 0.0.
  - bias (bool) – If True, adds learnable input and output projection biases. Default: True.
  - batch_first (bool) – If True, the input and output tensors are provided as (batch, seq, feature). Default: False.
  - add_bias_kv (bool) – Not supported. Must be False.
  - add_zero_attn (bool) – Not supported. Must be False.
  - device (torch.device, optional) – Parameter factory options.
  - dtype (torch.dtype, optional) – Parameter factory options.
  - init_scheme (str | None, optional) – Initialization scheme for the equivariant linear layers. Default: "xavier_normal".
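For intuition on the bias constraint: for the regular representation of a finite group, the invariant subspace is spanned by the all-ones vector (in the group-element basis), so orthogonally projecting a stacked bias onto it amounts to averaging the entries within each regular-representation block. The sketch below is plain Python, not this library's InvariantConstraint.

```python
def project_bias_invariant(bias, copy_dim):
    """Project a stacked bias onto the invariant subspace of a direct
    sum of regular-representation blocks of size copy_dim.

    Since the invariant subspace of each regular block is spanned by
    the all-ones vector, the orthogonal projection replaces each
    block of the bias by its mean."""
    assert len(bias) % copy_dim == 0, "bias must stack whole copies"
    out = []
    for start in range(0, len(bias), copy_dim):
        block = bias[start:start + copy_dim]
        mean = sum(block) / copy_dim
        out.extend([mean] * copy_dim)
    return out

print(project_bias_invariant([1.0, 3.0, 2.0, 6.0], 2))  # [2.0, 2.0, 4.0, 4.0]
```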