eRMSNorm

class eRMSNorm(in_rep, eps=1e-06, equiv_affine=True, device=None, dtype=None, init_scheme='identity')

Bases: Module

Root-mean-square normalization followed by a \(\mathbb{G}\)-equivariant affine map.

For \(\mathbf{x}\in\mathcal{X}\) with \(D=\dim(\rho_{\mathcal{X}})\), define

\[\hat{\mathbf{x}} = \frac{\mathbf{x}}{\sqrt{\frac{1}{D} \|\mathbf{x}\|^2 + \varepsilon}}, \qquad \mathbf{y} = \boldsymbol{\alpha} \odot \hat{\mathbf{x}} + \boldsymbol{\beta}\]

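where the second step is implemented by eAffine and is applied only when equiv_affine=True. The scale \(\alpha\) holds one learnable scalar per irreducible block, shared across that block's dimensions, and the optional bias \(\beta\) is supported only on the invariant (trivial) subspace.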
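The two formulas can be sketched in NumPy as follows. This is a minimal illustration, not the library's implementation: NumPy arrays stand in for torch Tensors, and the irrep decomposition is described by a hypothetical list of block dimensions (`irrep_dims`) plus a mask of trivial blocks (`is_trivial`). The helper names `rms_normalize` and `expand_per_irrep` are likewise hypothetical.

```python
import numpy as np

def rms_normalize(x: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """x_hat = x / sqrt(||x||^2 / D + eps), computed over the last axis."""
    d = x.shape[-1]
    mean_sq = np.sum(x * x, axis=-1, keepdims=True) / d  # (1/D) ||x||^2
    return x / np.sqrt(mean_sq + eps)

def expand_per_irrep(values: np.ndarray, irrep_dims: list) -> np.ndarray:
    """Repeat one scalar per irrep block across all dimensions of that block."""
    return np.repeat(values, irrep_dims)

# Hypothetical decomposition: one 2-dim irrep followed by one trivial 1-dim irrep (D = 3).
irrep_dims = [2, 1]
is_trivial = np.array([False, True])

# One scale per irrep block -> per-dimension vector [2.0, 2.0, 0.5].
alpha = expand_per_irrep(np.array([2.0, 0.5]), irrep_dims)
# Bias allowed only on trivial blocks -> per-dimension vector [0.0, 0.0, 1.0].
beta = expand_per_irrep(np.where(is_trivial, 1.0, 0.0), irrep_dims)

x = np.array([3.0, 4.0, 0.0])
x_hat = rms_normalize(x, eps=0.0)  # (1/D) ||x||^2 = 25 / 3
y = alpha * x_hat + beta
```

Because \(\alpha\) is constant inside each irrep block, it acts on that block as a multiple of the identity, which is what lets it commute with the group action.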

Equivariance:

\[\operatorname{eRMSNorm}(\rho_{\mathcal{X}}(g)\mathbf{x}) = \rho_{\mathcal{X}}(g)\operatorname{eRMSNorm}(\mathbf{x}), \quad \forall g\in\mathbb{G},\]

because the RMS factor is invariant (for an orthogonal representation, \(\|\rho_{\mathcal{X}}(g)\mathbf{x}\| = \|\mathbf{x}\|\)) and eAffine commutes with \(\rho_{\mathcal{X}}\).
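The equivariance identity can be checked numerically. The sketch below assumes a hypothetical orthogonal representation made of a 2-D rotation irrep plus a 1-D trivial irrep, with hand-picked affine parameters that respect the per-irrep structure; it uses NumPy in place of torch and does not call the library.

```python
import numpy as np

def rms_normalize(x, eps=1e-6):
    # One RMS scalar per feature vector (last axis).
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)

def rho(theta):
    # Hypothetical orthogonal representation: 2-D rotation irrep (+) 1-D trivial irrep.
    c, s = np.cos(theta), np.sin(theta)
    return np.array([[c, -s, 0.0],
                     [s, c, 0.0],
                     [0.0, 0.0, 1.0]])

# Equivariant affine: one scale shared by the rotation block, bias only on the trivial block.
alpha = np.array([1.5, 1.5, 0.7])
beta = np.array([0.0, 0.0, 0.3])

def ermsnorm(x):
    return alpha * rms_normalize(x) + beta

rng = np.random.default_rng(0)
x = rng.normal(size=3)
for theta in np.linspace(0.0, 2 * np.pi, 7):
    g = rho(theta)
    # eRMSNorm(rho(g) x) == rho(g) eRMSNorm(x)
    assert np.allclose(ermsnorm(g @ x), g @ ermsnorm(x))
```

Giving the two dimensions of the rotation block different scales (e.g. 1.5 and 0.2) breaks the identity, which is why eAffine ties the scale across each irrep.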

Parameters:
  • in_rep (Representation) – Description of the feature space \(\rho_{\text{in}}\).

  • eps (float) – Numerical stabilizer added to the mean square before taking the square root.

  • equiv_affine (bool) – If True, apply a symmetry-preserving eAffine after normalization; if False, return the normalized input.

  • device – Optional tensor factory kwarg: device on which the affine parameters are created.

  • dtype – Optional tensor factory kwarg: data type of the affine parameters.

  • init_scheme (Literal["identity", "random"] | None) – Initialization scheme forwarded to reset_parameters(). Set to None to skip initialization (useful when loading checkpoints).

Shape:
  • Input: (..., in_rep.size)

  • Output: same shape as the input
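The shape contract can be illustrated with a NumPy sketch: the RMS is taken over the last axis only, so any leading axes behave as batch axes and the output keeps the input's shape. Here `in_size` is a hypothetical stand-in for `in_rep.size`.

```python
import numpy as np

def rms_normalize(x, eps=1e-6):
    # Normalize over the last axis only; all leading axes are treated as batch axes.
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)

in_size = 5  # stand-in for in_rep.size
x = np.random.default_rng(1).normal(size=(4, 3, in_size))  # shape (..., in_rep.size)
y = rms_normalize(x)

# Every feature vector along the last axis now has RMS close to 1 (eps keeps it slightly below).
rms = np.sqrt(np.mean(y * y, axis=-1))

# eps also keeps an all-zero vector finite instead of producing 0/0.
z = rms_normalize(np.zeros(in_size))
```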

Note

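The normalization factor is a single scalar per feature vector (one per position in the leading dimensions), so rescaling by it commutes with every matrix \(\rho_{\text{in}}(g)\) of the group action defined by in_rep. Because that scalar is also invariant under the group action, the normalization step is equivariant.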
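The key fact in this note is that a multiple of the identity commutes with every matrix, whereas a per-coordinate scaling generally does not. A small NumPy check, using a 2-D rotation as a hypothetical group element and a hypothetical per-coordinate scaling for contrast:

```python
import numpy as np

theta = 0.3
c, s = np.cos(theta), np.sin(theta)
R = np.array([[c, -s], [s, c]])  # a group element acting on a 2-D irrep

scalar = 0.5 * np.eye(2)         # single RMS scalar: a multiple of the identity
per_dim = np.diag([0.5, 2.0])    # hypothetical per-coordinate normalization

commutes_scalar = np.allclose(scalar @ R, R @ scalar)
commutes_per_dim = np.allclose(per_dim @ R, R @ per_dim)
```

Here `commutes_scalar` is True and `commutes_per_dim` is False, which is why normalizing each coordinate separately would not preserve equivariance.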

forward(input)

Normalize by a single RMS scalar and, optionally, apply the equivariant affine map.

Parameters:

input (Tensor) – Tensor shaped (..., in_rep.size).

Return type:

Tensor

Returns:

Tensor with identical shape, RMS-normalized and, if equiv_affine=True, transformed by eAffine.
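The control flow of forward() can be sketched as below. This is a NumPy approximation, not the library's code: `ermsnorm_forward` is a hypothetical function, and `alpha` and `beta` are assumed to be already expanded to one value per feature dimension.

```python
import numpy as np

def ermsnorm_forward(x, alpha=None, beta=None, eps=1e-6, equiv_affine=True):
    """Hypothetical sketch: RMS-normalize over the last axis, then optionally apply alpha * x_hat + beta."""
    x_hat = x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    if not equiv_affine:
        return x_hat  # normalization only
    if alpha is None or beta is None:
        raise ValueError("alpha and beta are required when equiv_affine=True")
    # alpha and beta have shape (in_rep.size,) and broadcast over the leading axes.
    return alpha * x_hat + beta

x = np.array([[1.0, 2.0, 2.0],
              [0.0, 0.0, 3.0]])
alpha = np.array([2.0, 2.0, 1.0])  # hypothetical: one scale per irrep, repeated per dimension
beta = np.array([0.0, 0.0, 0.5])   # hypothetical: bias only on a trivial component

y_plain = ermsnorm_forward(x, equiv_affine=False)
y_affine = ermsnorm_forward(x, alpha, beta)
```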

reset_parameters(scheme='identity')

(Re)initialize the optional affine transform using the provided scheme.

Parameters:

scheme (Literal['identity', 'random']) – Initialization scheme for the affine parameters.

Return type:

None
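A hedged sketch of what re-initializing per-irrep parameters could look like. It assumes that "identity" makes the affine map the identity (\(\alpha = 1\), \(\beta = 0\)); the "random" branch draws from a normal distribution purely as a placeholder, since the library's actual random scheme is not documented here. The function `reset_affine` and its arguments are hypothetical.

```python
import numpy as np

def reset_affine(irrep_dims, is_trivial, scheme="identity", rng=None):
    """Hypothetical re-initialization of per-irrep scales and invariant biases."""
    n = len(irrep_dims)
    if scheme == "identity":
        # Assumed meaning: the affine map starts as the identity.
        scales, biases = np.ones(n), np.zeros(n)
    elif scheme == "random":
        # Placeholder distribution, not the library's documented scheme.
        rng = np.random.default_rng() if rng is None else rng
        scales = rng.normal(1.0, 0.1, size=n)
        biases = np.where(is_trivial, rng.normal(0.0, 0.1, size=n), 0.0)  # bias only on trivial irreps
    else:
        raise ValueError(f"unknown scheme: {scheme!r}")
    # Expand one value per irrep block to one value per feature dimension.
    return np.repeat(scales, irrep_dims), np.repeat(biases, irrep_dims)

irrep_dims = [2, 1]
is_trivial = np.array([False, True])
alpha, beta = reset_affine(irrep_dims, is_trivial)
alpha_r, beta_r = reset_affine(irrep_dims, is_trivial, "random", np.random.default_rng(0))
```

Whatever the scheme, both parameters stay tied within each irrep block and the bias stays zero on non-trivial blocks, so equivariance holds from initialization onward.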