eAffine

class eAffine(in_rep, bias=True, learnable=True, init_scheme='identity')

Bases: Module

Equivariant affine map with per-irrep scales and invariant bias.

Let \(\mathbf{x}\in\mathcal{X}\) with representation

\[\rho_{\mathcal{X}} = \mathbf{Q}\left( \bigoplus_{k\in[1,n_{\text{iso}}]} \bigoplus_{i\in[1,n_k]} \hat{\rho}_k \right)\mathbf{Q}^T.\]

This module applies

\[\mathbf{y} = \mathbf{Q}\,\mathbf{D}_{\alpha}\,\mathbf{Q}^T\mathbf{x} + \mathbf{b},\]

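where \(\mathbf{D}_{\alpha}\) is diagonal in the irrep-spectral basis and constant across the dimensions of each irrep copy (one scale \(\alpha_{k,i}\) per copy), while \(\mathbf{b}\in\mathrm{Fix}(\rho_{\mathcal{X}})\), i.e. \(\mathbf{Q}^T\mathbf{b}\) is nonzero only on trivial-irrep components. Because \(\mathbf{D}_{\alpha}\) acts as a scalar on each irrep copy, it commutes with \(\bigoplus_{k,i}\hat{\rho}_k(g)\), and \(\rho_{\mathcal{X}}(g)\mathbf{b}=\mathbf{b}\). Therefore: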

\[\rho_{\mathcal{X}}(g)\mathbf{y} = \operatorname{eAffine}\!\left(\rho_{\mathcal{X}}(g)\mathbf{x}\right) \quad \forall g\in\mathbb{G}.\]
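The equivariance argument can be checked numerically. The NumPy sketch below is a self-contained toy, not the library's implementation: it assumes the group C4 acting on \(\mathbb{R}^5\) as trivial ⊕ two copies of the 90° rotation irrep, a random orthogonal \(\mathbf{Q}\), and hypothetical helper names (rot, block_diag, rho, e_affine).

```python
import numpy as np

rng = np.random.default_rng(0)


def rot(theta):
    # 2D rotation irrep of C4 evaluated at angle theta
    c, s = np.cos(theta), np.sin(theta)
    return np.array([[c, -s], [s, c]])


def block_diag(*blocks):
    # Direct sum of square blocks
    n = sum(b.shape[0] for b in blocks)
    out = np.zeros((n, n))
    i = 0
    for b in blocks:
        k = b.shape[0]
        out[i:i + k, i:i + k] = b
        i += k
    return out


# Assumed toy representation: Q (trivial ⊕ rot ⊕ rot) Q^T with orthogonal Q
Q, _ = np.linalg.qr(rng.standard_normal((5, 5)))


def rho(k):
    # Group element g = rotation by k * 90 degrees
    R = rot(k * np.pi / 2)
    return Q @ block_diag(np.eye(1), R, R) @ Q.T


# One scale per irrep copy, repeated over that copy's dimensions [1, 2, 2]
alpha = np.array([0.5, 2.0, -1.5])
D = np.diag(np.repeat(alpha, [1, 2, 2]))
# Invariant bias: supported only on the trivial spectral coordinate
b = Q @ np.array([0.7, 0.0, 0.0, 0.0, 0.0])


def e_affine(x):
    return Q @ D @ Q.T @ x + b


x = rng.standard_normal(5)
for k in range(4):
    g = rho(k)
    assert np.allclose(g @ e_affine(x), e_affine(g @ x))
print("equivariant for all g in C4")
```

Moving the bias onto a non-trivial spectral coordinate breaks the identity, which is why \(\mathbf{b}\) is restricted to \(\mathrm{Fix}(\rho_{\mathcal{X}})\).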

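The degrees of freedom (DoFs) are the scales \(\alpha_{k,i}\) and the invariant-bias coefficients. When learnable=True, they are trainable parameters. When learnable=False, they are supplied at call time through scale_dof and bias_dof (FiLM-style conditioning).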

Parameters:
  • in_rep (Representation) – Representation \(\rho_{\mathcal{X}}\) of the input space; the output space carries the same representation.

  • bias (bool) – If True, include an invariant bias; this only has an effect when in_rep contains the trivial irrep. Default: True.

  • learnable (bool) – If False, no parameters are registered, and scale_dof/bias_dof must be passed at call time. Default: True.

  • init_scheme (Optional[Literal['identity', 'random']]) – Initialization scheme for the learnable DoFs. Set to None to skip initialization (e.g. when loading weights). Ignored when learnable=False.

Shape:
  • Input: (..., D) with D = in_rep.size.

  • scale_dof (optional): (..., num_scale_dof), where num_scale_dof is the number of irrep copies in in_rep (one scale per copy). Required when learnable=False.

  • bias_dof (optional): (..., num_bias_dof) when bias=True.

  • Output: (..., D).
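The DoF counts follow from the irrep decomposition of in_rep. A minimal bookkeeping sketch in plain Python, assuming num_scale_dof counts irrep copies (one \(\alpha_{k,i}\) each) and num_bias_dof counts trivial copies; IrrepCopy and dof_counts are hypothetical helpers, and returning 0 bias DoFs when bias=False is also an assumption.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class IrrepCopy:
    name: str      # irrep label (illustrative only)
    dim: int       # dimension of this irrep copy
    trivial: bool  # True for the trivial (invariant) irrep


def dof_counts(irreps, bias=True):
    """Return (D, num_scale_dof, num_bias_dof) for a list of irrep copies."""
    D = sum(ir.dim for ir in irreps)          # D = in_rep.size
    num_scale_dof = len(irreps)               # one scale per irrep copy
    # Assumption: one bias DoF per (1-dimensional) trivial copy, none if bias=False
    num_bias_dof = sum(1 for ir in irreps if ir.trivial) if bias else 0
    return D, num_scale_dof, num_bias_dof


irreps = [
    IrrepCopy("trivial", 1, True),
    IrrepCopy("trivial", 1, True),
    IrrepCopy("rot90", 2, False),
    IrrepCopy("rot90", 2, False),
]
print(dof_counts(irreps))  # (6, 4, 2)
```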

Note

Runtime behavior depends on the module mode. In training mode (model.train()), the affine map is recomputed on every forward pass. In inference mode (model.eval()) with learnable=True, the dense matrix \(\mathbf{Q}\mathbf{D}_{\alpha}\mathbf{Q}^T\) (and the optional invariant bias) is cached and reused until the parameters change or invalidate_cache() is called, which avoids rebuilding it on every call.

Unlike eLinear, this module is not a strict symmetry-agnostic drop-in affine block: its parameters are irrep-structured, and they may also be provided externally (FiLM-style, via scale_dof/bias_dof).
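The caching rule can be illustrated with a toy class. This is a hypothetical sketch (CachedAffine, set_scales and n_expansions are illustrative names, and the bias is omitted), assuming explicit invalidation whenever the scales change or the mode switches; how the library itself detects parameter changes is not described here.

```python
import numpy as np


class CachedAffine:
    """Toy linear map that caches Q diag(s) Q^T in eval mode only."""

    def __init__(self, Q, scales):
        self.Q = Q
        self.scales = np.asarray(scales, dtype=float)  # spectral scale per dimension
        self.training = True
        self._cache = None
        self.n_expansions = 0  # how many times the dense matrix was built

    def _expand(self):
        self.n_expansions += 1
        return self.Q @ np.diag(self.scales) @ self.Q.T

    def invalidate_cache(self):
        self._cache = None

    def train(self, mode=True):
        # Switching mode drops the cache so eval never reuses a stale matrix
        self.training = mode
        self.invalidate_cache()
        return self

    def eval(self):
        return self.train(False)

    def set_scales(self, scales):
        # Parameters changed: the cached matrix is stale
        self.scales = np.asarray(scales, dtype=float)
        self.invalidate_cache()

    def __call__(self, x):
        if self.training:
            A = self._expand()  # rebuilt on every call in training mode
        else:
            if self._cache is None:
                self._cache = self._expand()
            A = self._cache
        return x @ A.T  # rows of x are vectors: y_i = A x_i


Q, _ = np.linalg.qr(np.random.default_rng(0).standard_normal((3, 3)))
m = CachedAffine(Q, [1.0, 2.0, 2.0]).eval()
x = np.ones((4, 3))
m(x)
m(x)
print(m.n_expansions)  # 1: second call reused the cache
m.set_scales([1.0, 3.0, 3.0])
m(x)
print(m.n_expansions)  # 2: rebuilt after the parameter change
```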

rep_x

Representation \(\rho_{\mathcal{X}}\) of the feature space.

Type:

Representation

Q

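Change-of-basis matrix \(\mathbf{Q}\) whose columns are the irrep-spectral basis vectors expressed in the original basis; it maps spectral coordinates back to the original basis.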

Type:

Tensor

Q_inv

Inverse change-of-basis matrix (\(\mathbf{Q}^T\) in the formulas above); it maps original-basis coordinates to the irrep-spectral basis.

Type:

Tensor

bias_module

Optional module handling the invariant bias.

Type:

InvariantBias | None

broadcast_spectral_scale_and_bias(scale_dof, bias_dof=None, input_shape=None)

Broadcast the provided DoFs into spectral (irrep-basis) scale and bias vectors.

Parameters:
  • scale_dof (Tensor) – Per-irrep scale coefficients shaped (..., num_scale_dof).

  • bias_dof (Tensor | None) – Invariant-bias coefficients shaped (..., num_bias_dof) when bias=True.

  • input_shape (Size | None) – Shape of the input tensor when learnable=False. Used to validate DoF shapes.

Returns:

Pair (spectral_scale, spectral_bias) where spectral_scale is shaped (..., rep_x.size) and spectral_bias is shaped (..., rep_x.size) (or None when bias=False).

Return type:

tuple[torch.Tensor, torch.Tensor | None]
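The broadcasting step can be sketched with NumPy, assuming the spectral basis lists irrep copies in order, each trivial copy is one-dimensional, and bias_dof holds one entry per trivial copy. broadcast_spectral is a hypothetical stand-in, not the library's method.

```python
import numpy as np


def broadcast_spectral(scale_dof, bias_dof, copy_dims, trivial_mask):
    """Expand per-copy DoFs to per-dimension spectral vectors.

    scale_dof:    (..., n_copies), one scale per irrep copy.
    bias_dof:     (..., n_trivial) or None, one bias per trivial copy.
    copy_dims:    dimension of each irrep copy, in spectral order.
    trivial_mask: True where the copy is the trivial irrep.
    """
    copy_dims = np.asarray(copy_dims)
    # Repeat each copy's scale over that copy's dimensions -> (..., D)
    spectral_scale = np.repeat(scale_dof, copy_dims, axis=-1)
    if bias_dof is None:
        return spectral_scale, None
    # Start index of every copy; keep only the (1-dimensional) trivial ones
    starts = np.concatenate([[0], np.cumsum(copy_dims)[:-1]])
    trivial_idx = starts[np.asarray(trivial_mask)]
    spectral_bias = np.zeros(spectral_scale.shape)
    spectral_bias[..., trivial_idx] = bias_dof  # zero on non-trivial dims
    return spectral_scale, spectral_bias


scale = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])  # batch of 2, three copies
bias = np.array([[0.1], [0.2]])                       # one trivial copy
s, b = broadcast_spectral(scale, bias, copy_dims=[1, 2, 2],
                          trivial_mask=[True, False, False])
print(s.tolist())  # [[1.0, 2.0, 2.0, 3.0, 3.0], [4.0, 5.0, 5.0, 6.0, 6.0]]
print(b.tolist())  # [[0.1, 0.0, 0.0, 0.0, 0.0], [0.2, 0.0, 0.0, 0.0, 0.0]]
```

Per-sample (FiLM-style) DoFs broadcast the same way, since both np.repeat and the index assignment act only on the last axis.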

expand_affine()

Expand the per-irrep scales into the dense matrix \(\mathbf{Q}\mathbf{D}_{\alpha}\mathbf{Q}^T\) in the original basis.

Return type:

Tensor
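For an orthogonal \(\mathbf{Q}\) (so that Q_inv equals \(\mathbf{Q}^T\), as in the class formula), the dense expansion gives the same result as rotating into the spectral basis, scaling elementwise, and rotating back. A NumPy sketch with illustrative values (the copy dimensions [1, 2, 2] and the scales are assumptions):

```python
import numpy as np

rng = np.random.default_rng(1)
D = 5
Q, _ = np.linalg.qr(rng.standard_normal((D, D)))  # orthogonal change of basis
Q_inv = Q.T
# Per-copy scales repeated over copy dimensions [1, 2, 2]
spectral_scale = np.repeat([0.5, 2.0, -1.0], [1, 2, 2])

# Dense expansion in the original basis: a single (D, D) matrix
A = Q @ np.diag(spectral_scale) @ Q_inv

# Spectral path: to irrep basis, scale elementwise, back to original basis
x = rng.standard_normal((8, D))  # batch of 8 row vectors
y_spectral = ((x @ Q_inv.T) * spectral_scale) @ Q.T
y_dense = x @ A.T
assert np.allclose(y_dense, y_spectral)
```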

forward(input, scale_dof=None, bias_dof=None)

Apply the equivariant affine transform.

When learnable=False, scale_dof (and bias_dof, if bias=True) must be provided.

Parameters:
  • input (Tensor) – Tensor \(\mathbf{x}\) whose last dimension matches in_rep.size.

  • scale_dof (Tensor | None) – Per-irrep scaling DoFs with last dimension num_scale_dof and leading dimensions matching input.shape[:-1]. Required when learnable=False.

  • bias_dof (Tensor | None) – Invariant-bias DoFs with last dimension num_bias_dof and leading dimensions matching input.shape[:-1]. Required when learnable=False and bias=True.

Return type:

Tensor
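The shape contract for call-time DoFs can be sketched as a hypothetical validator (check_dof_shapes is not a library function), assuming leading dimensions must match exactly rather than broadcast:

```python
def check_dof_shapes(input_shape, scale_shape, bias_shape,
                     num_scale_dof, num_bias_dof, bias=True):
    """Raise ValueError if DoF shapes violate the documented contract."""
    lead = tuple(input_shape[:-1])  # leading (batch) dims of the input
    expected_scale = lead + (num_scale_dof,)
    if scale_shape is None:
        raise ValueError("scale_dof is required when learnable=False")
    if tuple(scale_shape) != expected_scale:
        raise ValueError(
            f"scale_dof must have shape {expected_scale}, got {tuple(scale_shape)}")
    if bias:
        expected_bias = lead + (num_bias_dof,)
        if bias_shape is None:
            raise ValueError("bias_dof is required when learnable=False and bias=True")
        if tuple(bias_shape) != expected_bias:
            raise ValueError(
                f"bias_dof must have shape {expected_bias}, got {tuple(bias_shape)}")


# Input (32, 10, D=6) with 4 irrep copies, 2 of them trivial
check_dof_shapes((32, 10, 6), (32, 10, 4), (32, 10, 2),
                 num_scale_dof=4, num_bias_dof=2)
try:
    check_dof_shapes((32, 10, 6), (32, 4), (32, 10, 2),
                     num_scale_dof=4, num_bias_dof=2)
except ValueError as err:
    print(err)  # scale_dof must have shape (32, 10, 4), got (32, 4)
```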

invalidate_cache()

Clear cached affine map so it is recomputed on next use.

Return type:

None

load_state_dict(state_dict, strict=True)

Load parameters and invalidate cached affine expansion.

Parameters:

strict (bool)

property num_bias_dof: int

Number of bias degrees of freedom (length of bias_dof), associated with the trivial (invariant) irreps of in_rep.

property num_scale_dof: int

Number of per-irrep scaling degrees of freedom (length of scale_dof).

reset_parameters(scheme='identity')

Initialize spectral scale/bias DoFs.

Parameters:

scheme (Optional[Literal['identity', 'random']]) – "identity" sets all scales to one and the bias to zero; "random" samples both uniformly in [-1, 1]. Set to None to skip re-initialization when loading checkpoints.

Return type:

None
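The two documented schemes can be sketched in NumPy (init_dofs is a hypothetical helper, not the library's code). The sketch also checks that "identity" makes the map the identity for an orthogonal \(\mathbf{Q}\), since \(\mathbf{Q}\mathbf{I}\mathbf{Q}^T=\mathbf{I}\) and the bias is zero; the copy dimensions [1, 2, 2] are an assumption.

```python
import numpy as np


def init_dofs(scheme, num_scale_dof, num_bias_dof, rng=None):
    """Return (scale_dof, bias_dof) initialized according to `scheme`."""
    if scheme is None:
        return None  # skip init, e.g. weights will be loaded from a checkpoint
    if scheme == "identity":
        return np.ones(num_scale_dof), np.zeros(num_bias_dof)
    if scheme == "random":
        rng = np.random.default_rng() if rng is None else rng
        return (rng.uniform(-1.0, 1.0, num_scale_dof),
                rng.uniform(-1.0, 1.0, num_bias_dof))
    raise ValueError(f"unknown scheme: {scheme!r}")


scale, bias = init_dofs("identity", 3, 1)
Q, _ = np.linalg.qr(np.random.default_rng(0).standard_normal((5, 5)))
# Unit scale on every copy (dims [1, 2, 2]) -> Q I Q^T = I
A = Q @ np.diag(np.repeat(scale, [1, 2, 2])) @ Q.T
assert np.allclose(A, np.eye(5)) and not bias.any()
```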

train(mode=True)

Switch mode and keep cached affine expansion consistent.

Parameters:

mode (bool)