Source code for symm_learning.nn.normalization

from __future__ import annotations

import logging
from typing import Literal

import torch
from escnn.group import Representation

import symm_learning.stats
from symm_learning.linalg import irrep_radii
from symm_learning.nn.linear import eAffine
from symm_learning.representation_theory import direct_sum


class eRMSNorm(torch.nn.Module):
    r"""Root-mean-square normalization with :math:`\mathbb{G}`-equivariant affine map.

    For :math:`\mathbf{x}\in\mathcal{X}` with :math:`D=\dim(\rho_{\mathcal{X}})`, define

    .. math::
        \hat{\mathbf{x}} = \frac{\mathbf{x}}{\sqrt{\frac{1}{D} \|\mathbf{x}\|^2 + \varepsilon}},
        \qquad
        \mathbf{y} = \alpha \odot \hat{\mathbf{x}} + \boldsymbol{\beta},

    where the second step is implemented by :class:`~symm_learning.nn.linear.eAffine`
    (per-irrep scaling and optional invariant bias).

    Equivariance:

    .. math::
        \operatorname{eRMSNorm}(\rho_{\mathcal{X}}(g)\mathbf{x})
        = \rho_{\mathcal{X}}(g)\operatorname{eRMSNorm}(\mathbf{x}),
        \quad \forall g\in\mathbb{G},

    because the RMS factor is invariant and :class:`~symm_learning.nn.linear.eAffine`
    commutes with :math:`\rho_{\mathcal{X}}`.

    Args:
        in_rep (:class:`~escnn.group.Representation`): Description of the feature space
            :math:`\rho_{\mathcal{X}}`.
        eps (:class:`float`): Numerical stabilizer added inside the RMS computation.
        equiv_affine (:class:`bool`): If ``True``, apply a symmetry-preserving
            :class:`~symm_learning.nn.linear.eAffine` after normalization.
        device, dtype: Optional tensor factory kwargs passed to the affine parameters.
        init_scheme (Literal["identity", "random"] | None): Initialization scheme forwarded to
            :meth:`~symm_learning.nn.linear.eAffine.reset_parameters`. Set to ``None`` to skip
            initialization (useful when loading checkpoints).

    Shape:
        - Input: ``(..., in_rep.size)``
        - Output: same shape

    Note:
        The normalization factor is a single scalar per sample, so the operation commutes with
        any matrix representing the group action defined by ``in_rep``.
    """

    def __init__(
        self,
        in_rep: Representation,
        eps: float = 1e-6,
        equiv_affine: bool = True,
        device=None,
        dtype=None,
        init_scheme: Literal["identity", "random"] | None = "identity",
    ):
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}
        self.in_rep, self.out_rep = in_rep, in_rep
        if equiv_affine:
            self.affine = eAffine(in_rep, bias=False).to(**factory_kwargs)
        self.eps = eps
        self.normalized_shape = (in_rep.size,)
        if init_scheme is not None:
            self.reset_parameters(init_scheme)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Normalize by a single RMS scalar and (optionally) apply equivariant affine.

        Args:
            input: Tensor shaped ``(..., in_rep.size)``.

        Returns:
            Tensor with identical shape, RMS-normalized and possibly transformed by
            :class:`~symm_learning.nn.linear.eAffine`.
        """
        assert input.shape[-1] == self.in_rep.size, f"Expected (..., {self.in_rep.size}), got {input.shape}"
        rms_input = torch.sqrt(self.eps + torch.mean(input.pow(2), dim=-1, keepdim=True))
        normalized = input / rms_input
        if hasattr(self, "affine"):
            normalized = self.affine(normalized)
        return normalized

    def reset_parameters(self, scheme: Literal["identity", "random"] = "identity") -> None:
        """(Re)initialize the optional affine transform using the provided scheme."""
        if hasattr(self, "affine"):
            self.affine.reset_parameters(scheme)
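

# Usage sketch (illustrative comment, not executed by the module): how eRMSNorm is meant to
# be used and why it is equivariant. CyclicGroup(4) below is an arbitrary example choice;
# any escnn group/representation works. Since the RMS factor is one scalar per sample,
# applying the group action before or after the layer gives the same output.
#
#   from escnn.group import CyclicGroup
#
#   G = CyclicGroup(4)
#   rep = direct_sum([G.regular_representation] * 2)  # 8-dimensional feature space
#   layer = eRMSNorm(rep, eps=1e-6, equiv_affine=True)
#   x = torch.randn(32, rep.size)
#   y = layer(x)                                      # shape (32, rep.size)
#   g = G.sample()
#   rho_g = torch.tensor(rep(g), dtype=x.dtype)
#   # Equivariance check: eRMSNorm(rho(g) @ x) == rho(g) @ eRMSNorm(x)
#   assert torch.allclose(layer(x @ rho_g.T), y @ rho_g.T, atol=1e-5)
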
class eLayerNorm(torch.nn.Module):
    r"""Equivariant Layer Normalization.

    Given :math:`\mathbf{x}\in\mathcal{X}`, we first move to the irrep-spectral basis
    :math:`\hat{\mathbf{x}} = \mathbf{Q}^{-1}\mathbf{x}`, compute one variance scalar per
    irreducible block via :func:`~symm_learning.linalg.irrep_radii`, and normalize each block
    uniformly:

    .. math::
        \hat{\mathbf{y}} = \frac{\hat{\mathbf{x}}}{\sqrt{\boldsymbol{\sigma}^{2} + \varepsilon}},
        \qquad
        \mathbf{y} = \mathbf{Q}\hat{\mathbf{y}}.

    The layer is equivariant,

    .. math::
        \operatorname{eLayerNorm}(\rho_{\text{in}}(g)\mathbf{x})
        = \rho_{\text{in}}(g)\operatorname{eLayerNorm}(\mathbf{x}), \quad \forall g\in\mathbb{G},

    since the statistics are computed per irreducible subspace, and these subspaces are preserved
    by the group action. When ``equiv_affine=True`` the learnable affine step is performed
    directly in the spectral basis using the per-irrep scale/bias provided by
    :class:`~symm_learning.nn.linear.eAffine`, so the full map reads

    .. math::
        \mathbf{y} = \mathbf{Q}\bigl((\mathbf{Q}^{-1}\mathbf{x}) \odot \alpha + \boldsymbol{\beta}\bigr).

    Args:
        in_rep (:class:`~escnn.group.Representation`): Description of the feature space
            :math:`\rho_{\text{in}}`.
        eps (:class:`float`): Numerical stabilizer added to each variance.
        equiv_affine (:class:`bool`): If ``True``, applies an
            :class:`~symm_learning.nn.linear.eAffine` in spectral space.
        bias (:class:`bool`): Whether the affine term includes invariant biases (only used if
            ``equiv_affine``).
        device, dtype: Optional tensor factory kwargs.
        init_scheme (Literal["identity", "random"] | None): Initialization scheme forwarded to
            :meth:`~symm_learning.nn.linear.eAffine.reset_parameters`. Set to ``None`` to skip
            initialization.

    Note:
        This layer appears to generate numerical instability when used in equivariant transformer
        blocks. Use :class:`eRMSNorm` instead in such cases.
    """

    def __init__(
        self,
        in_rep: Representation,
        eps: float = 1e-6,
        equiv_affine: bool = True,
        bias: bool = True,
        device=None,
        dtype=None,
        init_scheme: Literal["identity", "random"] | None = "identity",
    ) -> None:
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}
        self.in_rep, self.out_rep = in_rep, in_rep
        # Transitions to/from the irrep-spectral basis are required for both the normalization
        # and the affine transform.
        self.register_buffer("Q", torch.tensor(self.in_rep.change_of_basis, dtype=torch.get_default_dtype()))
        self.register_buffer("Q_inv", torch.tensor(self.in_rep.change_of_basis_inv, dtype=torch.get_default_dtype()))
        # Only works for (..., in_rep.size) inputs with normalization over the last dimension.
        self.normalized_shape = (in_rep.size,)
        self.eps = eps
        self.equiv_affine = equiv_affine
        dims = torch.tensor(
            [self.in_rep.group.irrep(*irrep_id).size for irrep_id in self.in_rep.irreps],
            dtype=torch.long,
        )
        self.register_buffer("irrep_dims", dims)
        self.register_buffer("irrep_indices", torch.repeat_interleave(torch.arange(len(dims), dtype=torch.long), dims))
        if self.equiv_affine:
            self.affine = eAffine(in_rep, bias=bias).to(**factory_kwargs)
        if init_scheme is not None:
            self.reset_parameters(init_scheme)

    def reset_parameters(self, scheme: Literal["identity", "random"] = "identity") -> None:  # noqa: D102
        if self.equiv_affine:
            self.affine.reset_parameters(scheme)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        r"""Normalize per irreducible block and (optionally) apply the spectral affine transform."""
        assert input.shape[-1] == self.in_rep.size, f"Expected (..., {self.in_rep.size}), got {input.shape}"
        # One radius per irreducible block; the block variance is radius^2 / block_dimension.
        radii = irrep_radii(input, rep=self.in_rep)  # (..., num_irreps)
        dims = self.irrep_dims.to(radii.device, radii.dtype)
        var_irreps = radii.pow(2) / dims
        # Broadcast each block's variance to all dimensions of that block.
        var_broadcasted = var_irreps[..., self.irrep_indices.to(var_irreps.device)]
        x_spec = torch.einsum("ij,...j->...i", self.Q_inv, input)
        x_spec = x_spec / torch.sqrt(var_broadcasted + self.eps)
        if self.equiv_affine:
            spectral_scale, spectral_bias = self.affine.broadcast_spectral_scale_and_bias(
                self.affine.scale_dof, self.affine.bias_dof
            )
            x_spec = x_spec * spectral_scale.view(*([1] * (x_spec.ndim - 1)), -1)
            if spectral_bias is not None:
                x_spec = x_spec + spectral_bias.view(*([1] * (x_spec.ndim - 1)), -1)
        normalized = torch.einsum("ij,...j->...i", self.Q, x_spec)
        return normalized

    def extra_repr(self) -> str:  # noqa: D102
        return "{normalized_shape}, eps={eps}, affine={equiv_affine}".format(**self.__dict__)
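

# Usage sketch (illustrative comment, not executed by the module): eLayerNorm normalizes one
# variance scalar per irreducible block in the spectral basis, so unlike torch.nn.LayerNorm
# it never mixes statistics across irreps. CyclicGroup(3) is an arbitrary example choice.
#
#   from escnn.group import CyclicGroup
#
#   G = CyclicGroup(3)
#   rep = direct_sum([G.regular_representation] * 4)  # 12-dimensional feature space
#   layer = eLayerNorm(rep, eps=1e-6, equiv_affine=True, bias=True)
#   x = torch.randn(16, rep.size)
#   y = layer(x)
#   g = G.sample()
#   rho_g = torch.tensor(rep(g), dtype=x.dtype)
#   # The spectral statistics are invariant, so the layer commutes with the group action.
#   assert torch.allclose(layer(x @ rho_g.T), y @ rho_g.T, atol=1e-5)
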
class eBatchNorm1d(torch.nn.Module):
    r"""Symmetry-aware Batch Normalization over the representation dimension.

    The mean and variance are computed with :func:`~symm_learning.stats.var_mean`, enforcing that
    each irreducible subspace shares a single variance scalar. The optional affine parameters are
    implemented via :class:`~symm_learning.nn.linear.eAffine` to preserve equivariance.

    The layer satisfies

    .. math::
        \operatorname{eBatchNorm}(\rho_{\text{in}}(g)\mathbf{x})
        = \rho_{\text{in}}(g)\operatorname{eBatchNorm}(\mathbf{x}), \quad \forall g\in\mathbb{G}.

    Args:
        in_rep (:class:`~escnn.group.Representation`): Representation :math:`\rho_{\text{in}}`
            describing the feature space.
        eps: Numerical stabilizer added to the variance.
        momentum: Momentum for the exponential moving averages.
        affine: If ``True``, apply a symmetry-preserving affine transform.
        track_running_stats: If ``True``, keep running mean/variance buffers.

    Shape:
        - Input: ``(..., in_rep.size)``
        - Output: same shape
    """

    def __init__(
        self,
        in_rep: Representation,
        eps: float = 1e-5,
        momentum: float = 0.1,
        affine: bool = True,
        track_running_stats: bool = True,
    ):
        super().__init__()
        if not isinstance(in_rep, Representation):
            raise TypeError(f"in_rep must be a Representation, got {type(in_rep)}")
        self.in_rep = in_rep
        self.out_rep = in_rep
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        self._rep_x = in_rep
        if self.track_running_stats:
            self.register_buffer("running_mean", torch.zeros(in_rep.size))
            self.register_buffer("running_var", torch.ones(in_rep.size))
            self.register_buffer("num_batches_tracked", torch.tensor(0, dtype=torch.long))
        if self.affine:
            self.affine_transform = eAffine(in_rep, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize using symmetry-constrained batch statistics and an optional equivariant affine map."""
        assert x.shape[-1] == self.in_rep.size, f"Expected (..., {self.in_rep.size}), got {x.shape}"
        x_flat = x.reshape(-1, self.in_rep.size)
        var_batch, mean_batch = symm_learning.stats.var_mean(x_flat, rep_x=self._rep_x)
        if self.track_running_stats:
            if self.training:
                with torch.no_grad():
                    self.running_mean.mul_(1 - self.momentum).add_(mean_batch, alpha=self.momentum)
                    self.running_var.mul_(1 - self.momentum).add_(var_batch, alpha=self.momentum)
                    self.num_batches_tracked += 1
                # As in torch.nn.BatchNorm1d: normalize with batch statistics in training mode
                # and with the running buffers in eval mode.
                mean, var = mean_batch, var_batch
            else:
                mean, var = self.running_mean, self.running_var
        else:
            mean, var = mean_batch, var_batch
        view_shape = [1] * (x.ndim - 1) + [-1]
        mean = mean.view(*view_shape)
        var = var.view(*view_shape)
        y = (x - mean) / torch.sqrt(var + self.eps)
        if self.affine:
            y = self.affine_transform(y)
        return y

    def evaluate_output_shape(self, input_shape):  # noqa: D102
        return input_shape

    def extra_repr(self) -> str:  # noqa: D102
        return (
            f"in_rep: {self.in_rep}, affine: {self.affine}, track_running_stats: {self.track_running_stats}, "
            f"eps: {self.eps}, momentum: {self.momentum}"
        )

    def check_equivariance(self, atol=1e-5, rtol=1e-5):
        """Check equivariance using random group elements."""
        was_training = self.training
        batch_size = 50
        self.train()
        # Warm up the running statistics before switching to eval mode.
        for _ in range(5):
            x = torch.randn(batch_size, self.in_rep.size)
            _ = self(x)
        self.eval()
        x_batch = torch.randn(batch_size, self.in_rep.size)
        G = self.in_rep.group
        for _ in range(10):
            g = G.sample()
            if g == G.identity:
                continue
            rho_g = torch.tensor(self.in_rep(g), dtype=x_batch.dtype, device=x_batch.device)
            gx_batch = x_batch @ rho_g.T
            var, mean = symm_learning.stats.var_mean(x_batch, rep_x=self.in_rep)
            g_var, g_mean = symm_learning.stats.var_mean(gx_batch, rep_x=self.in_rep)
            assert torch.allclose(mean, g_mean, atol=1e-4, rtol=1e-4), f"Mean {mean} != {g_mean}"
            assert torch.allclose(var, g_var, atol=1e-4, rtol=1e-4), f"Var {var} != {g_var}"
            y = self(x_batch)
            g_y = self(gx_batch)
            g_y_gt = y @ rho_g.T
            assert torch.allclose(g_y, g_y_gt, atol=atol, rtol=rtol), (
                f"Output {g_y} does not match the expected output {g_y_gt} for group element {g}"
            )
        self.train(was_training)
        return None
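

# Usage sketch (illustrative comment, not executed by the module): eBatchNorm1d follows the
# torch.nn.BatchNorm1d convention of normalizing with batch statistics in training mode and
# with the running buffers in eval mode. CyclicGroup(2) is an arbitrary example choice.
#
#   from escnn.group import CyclicGroup
#
#   G = CyclicGroup(2)
#   rep = direct_sum([G.regular_representation] * 3)  # 6-dimensional feature space
#   bn = eBatchNorm1d(rep, momentum=0.1)
#   bn.train()
#   for _ in range(10):                   # warm up the running statistics
#       _ = bn(torch.randn(64, rep.size))
#   bn.eval()                             # now normalizes with the running buffers
#   y = bn(torch.randn(8, rep.size))
#   bn.check_equivariance()               # sanity-check the symmetry constraint
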

if __name__ == "__main__":
    import sys
    import types
    from pathlib import Path

    from escnn.group import CyclicGroup, Icosahedral

    repo_root = Path(__file__).resolve().parents[2]
    test_dir = repo_root / "test"
    sys.path.insert(0, str(repo_root))
    # Make the repository's `test` directory importable as the `test` package.
    test_pkg = sys.modules.get("test")
    test_paths = [str(path) for path in getattr(test_pkg, "__path__", [])] if test_pkg else []
    if str(test_dir) not in test_paths:
        test_pkg = types.ModuleType("test")
        test_pkg.__path__ = [str(test_dir)]
        sys.modules["test"] = test_pkg

    from symm_learning.utils import bytes_to_mb, module_device_memory, module_memory
    from test.utils import benchmark, benchmark_eval_forward

    # G = CyclicGroup(2)
    G = Icosahedral()
    m = 2
    eps = 1e-6
    in_rep = direct_sum([G.regular_representation] * m)

    rms_norm = torch.nn.RMSNorm(in_rep.size, eps=eps, elementwise_affine=True)
    eq_rms_norm = eRMSNorm(in_rep, eps=eps, equiv_affine=True)

    batch_size = 1024
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    rms_norm = rms_norm.to(device)
    eq_rms_norm = eq_rms_norm.to(device)
    print(f"Device: {device}")

    x = torch.randn(batch_size, in_rep.size, device=device)

    def run_forward(mod):  # noqa: D103
        return mod(x)

    modules_to_benchmark = [
        ("RMSNorm", rms_norm),
        ("eRMSNorm", eq_rms_norm),
    ]
    results = []
    for name, module in modules_to_benchmark:

        def forward_fn(mod=module):  # noqa: D103
            return run_forward(mod)

        train_mem, non_train_mem = module_memory(module)
        gpu_alloc, gpu_peak = module_device_memory(module)
        eval_fwd_mean, eval_fwd_std = benchmark_eval_forward(module, forward_fn)
        (fwd_mean, fwd_std), (bwd_mean, bwd_std) = benchmark(module, forward_fn)
        results.append(
            {
                "name": name,
                "fwd_eval_mean": eval_fwd_mean,
                "fwd_eval_std": eval_fwd_std,
                "fwd_mean": fwd_mean,
                "fwd_std": fwd_std,
                "bwd_mean": bwd_mean,
                "bwd_std": bwd_std,
                "total_time": fwd_mean + bwd_mean,
                "train_mem": train_mem,
                "non_train_mem": non_train_mem,
                "gpu_mem": gpu_alloc,
                "gpu_peak": gpu_peak,
            }
        )

    name_width = 20
    header = (
        f"{'Layer':<{name_width}} {'Forward eval (ms)':>18} {'Forward (ms)':>18} {'Backward (ms)':>18} "
        f"{'Total (ms)':>15} "
        f"{'Trainable MB':>15} {'Non-train MB':>15} {'Total MB':>12} {'GPU Alloc MB':>15} {'GPU Peak MB':>15}"
    )
    separator = "-" * len(header)
    print(f"\nBenchmark results per {batch_size}-sample batch")
    print(separator)
    print(header)
    print(separator)
    for res in results:
        fwd_eval_str = f"{res['fwd_eval_mean']:.3f} +/- {res['fwd_eval_std']:.3f}"
        fwd_str = f"{res['fwd_mean']:.3f} +/- {res['fwd_std']:.3f}"
        bwd_str = f"{res['bwd_mean']:.3f} +/- {res['bwd_std']:.3f}"
        total_bytes = res["train_mem"] + res["non_train_mem"]
        gpu_alloc_mb = bytes_to_mb(res["gpu_mem"])
        gpu_peak_mb = bytes_to_mb(res["gpu_peak"])
        print(
            f"{res['name']:<{name_width}} {fwd_eval_str:>18} {fwd_str:>18} {bwd_str:>18} "
            f"{res['total_time']:>15.3f} {bytes_to_mb(res['train_mem']):>15.3f} "
            f"{bytes_to_mb(res['non_train_mem']):>15.3f} {bytes_to_mb(total_bytes):>12.3f} "
            f"{gpu_alloc_mb:>15.3f} {gpu_peak_mb:>15.3f}"
        )
    print(separator)