eMultiheadAttention

class eMultiheadAttention(in_rep, num_heads, dropout=0.0, bias=True, batch_first=False, add_bias_kv=False, add_zero_attn=False, device=None, dtype=None, init_scheme='xavier_normal')

Bases: MultiheadAttention

Drop-in replacement for torch.nn.MultiheadAttention that preserves G-equivariance.

This module keeps the runtime logic of PyTorch’s implementation untouched: queries, keys, and values are still computed from the packed in_proj_weight / in_proj_bias, and the internal attention kernel (including mask handling, dropout, and softmax) is exactly the stock MultiheadAttention behavior.
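The packed layout mentioned above can be inspected on a stock torch.nn.MultiheadAttention (pure PyTorch, no equivariance involved): in_proj_weight stacks the three projection matrices row-wise as [W_q; W_k; W_v].

```python
import torch
import torch.nn as nn

embed_dim, num_heads = 8, 2
mha = nn.MultiheadAttention(embed_dim, num_heads)

# in_proj_weight has shape (3 * embed_dim, embed_dim): rows 0..E-1 map to Q,
# E..2E-1 to K, 2E..3E-1 to V.
w_q, w_k, w_v = mha.in_proj_weight.chunk(3, dim=0)
b_q, b_k, b_v = mha.in_proj_bias.chunk(3, dim=0)

x = torch.randn(5, 2, embed_dim)  # (seq, batch, feature); batch_first=False
q = x @ w_q.T + b_q               # queries, exactly as the stock kernel computes them
```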

Equivariance is achieved by constraining every linear projection involved in the attention block:

  • the input projection [Q; K; V] = W_in @ x is treated as a single map from the input representation to three stacked copies of a regular-representation block that aligns with the requested num_heads (enforced via CommutingConstraint);

  • the optional stacked bias is projected onto the invariant subspace of that same block via InvariantConstraint;

  • the output projection out_proj is constrained to commute with the group action so that the concatenated value vectors are mapped back into the original feature space equivariantly.
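The two constraints above have a classical linear-algebra reading: commuting maps form the image of the conjugation average over the group, and invariant vectors form the image of the plain group average. A minimal self-contained sketch for one copy of the regular representation of the cyclic group C4 (the constraint classes themselves are not used here; `rho`, `W_eq`, and `b_inv` are illustrative names):

```python
import torch

n = 4  # one copy of the regular representation of the cyclic group C4

def rho(g):
    # permutation matrix of a cyclic shift by g rows
    return torch.roll(torch.eye(n), shifts=g, dims=0)

W = torch.randn(n, n)
b = torch.randn(n)

# CommutingConstraint analogue: average W over conjugation by the group.
# The result commutes with rho(g) for every g.
W_eq = sum(rho(g) @ W @ rho(g).T for g in range(n)) / n

# InvariantConstraint analogue: average b over the group action.
# The result lies in the invariant subspace, i.e. rho(g) @ b_inv == b_inv.
b_inv = sum(rho(g) @ b for g in range(n)) / n
```

Because rho(g) is a permutation matrix, rho(g).T is its inverse, so the conjugation average is exactly the projection onto the intertwiners of the representation.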

Additionally, num_heads is restricted to divide the number of regular-representation copies present in the input feature space, so that no irreducible subspace is split across heads.
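For example, if the input feature space contained 8 copies of the regular representation (a hypothetical count), the admissible head counts would be exactly the divisors of 8:

```python
num_regular_copies = 8  # hypothetical: copies of the regular representation in in_rep
valid_num_heads = [h for h in range(1, num_regular_copies + 1)
                   if num_regular_copies % h == 0]
print(valid_num_heads)  # → [1, 2, 4, 8]
```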

Initialize the equivariant multihead attention.

Parameters:
  • in_rep (Representation) – Representation \(\rho_{\text{in}}\) of the input/output space.

  • num_heads (int) – Number of parallel attention heads.

  • dropout (float) – Dropout probability on attention weights. Default: 0.0.

  • bias (bool) – If True, adds learnable input and output projection biases. Default: True.

  • batch_first (bool) – If True, then the input and output tensors are provided as (batch, seq, feature). Default: False.

  • add_bias_kv (bool) – Not supported. Must be False.

  • add_zero_attn (bool) – Not supported. Must be False.

  • device (torch.device, optional) – Parameter factory options.

  • dtype (torch.dtype, optional) – Parameter factory options.

  • init_scheme (str | None, optional) – Initialization scheme for the equivariant linear layers. Default: "xavier_normal".

reset_parameters(scheme='xavier_uniform')

Overrides the parent method to take the equivariance constraints into account.

Return type:

None
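One plausible reading of a constraint-aware reset, sketched here under the assumption that it applies a standard scheme and then re-projects onto the constraint subspace (the group-averaging projection; `rho` and the C4 setup are illustrative, not the library's API):

```python
import torch

n = 4  # regular representation of the cyclic group C4, as a toy example

def rho(g):
    # permutation matrix of a cyclic shift by g rows
    return torch.roll(torch.eye(n), shifts=g, dims=0)

W = torch.empty(n, n)
torch.nn.init.xavier_uniform_(W)  # standard scheme first...
# ...then enforce equivariance by projecting onto the commuting subspace
W = sum(rho(g) @ W @ rho(g).T for g in range(n)) / n
```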