Source code for symm_learning.nn.parametrizations

import logging

import torch
from escnn.group import Representation

from symm_learning.linalg import invariant_orthogonal_projector
from symm_learning.representation_theory import GroupHomomorphismBasis, direct_sum
from symm_learning.utils import bytes_to_mb, check_equivariance, module_device_memory, module_memory

logger = logging.getLogger(__name__)


class InvariantConstraint(torch.nn.Module):
    r"""Orthogonally project vectors onto :math:`\mathrm{Fix}(\rho)`.

    For a representation :math:`\rho_{\mathcal{X}}`, this parametrization enforces

    .. math::
        \mathbf{b} \in \mathrm{Fix}(\rho_{\mathcal{X}})
        = \{\mathbf{v} \in \mathcal{X} : \rho_{\mathcal{X}}(g)\mathbf{v} = \mathbf{v},\ \forall g \in \mathbb{G}\},

    by applying :math:`\mathbf{P}_{\mathrm{inv}}` from
    :func:`~symm_learning.linalg.invariant_orthogonal_projector`.

    Note:
        Runtime behavior depends on mode. In training mode (``model.train()``), the projection is
        recomputed on each forward pass. In inference mode (``model.eval()``), the projected output
        is cached for the same unchanged input tensor (same object identity and version counter),
        which is faster. With the cache active, the operation is equivalent to a symmetry-agnostic
        fixed linear map :math:`\mathbf{b} \mapsto \mathbf{P}_{\mathrm{inv}}\mathbf{b}`.

    Attributes:
        rep (:class:`~escnn.group.Representation`): Representation :math:`\rho` whose action defines
            invariance, of dimension ``rep.size``.
        inv_projector (:class:`~torch.Tensor`): Orthogonal projector of shape ``(rep.size, rep.size)``
            onto the fixed subspace :math:`\mathrm{Fix}(\rho)`.
    """

    def __init__(self, rep: Representation):
        """Precompute the invariant projector for the supplied representation."""
        super().__init__()
        self.rep = rep
        self.register_buffer("inv_projector", invariant_orthogonal_projector(rep))
        self.register_buffer("_bias", None, persistent=False)
        self._cached_input_id = None
        self._cached_input_version = None

    def _cache_is_valid(self, b: torch.Tensor) -> bool:
        if self._bias is None:
            return False
        if id(b) != self._cached_input_id:
            return False
        version = getattr(b, "_version", None)
        return version is not None and version == self._cached_input_version

    def _update_cache(self, b: torch.Tensor, bias: torch.Tensor) -> None:
        version = getattr(b, "_version", None)
        if version is None:
            self.invalidate_cache()
            return
        self._bias = bias
        self._cached_input_id = id(b)
        self._cached_input_version = version

    def forward(self, b: torch.Tensor) -> torch.Tensor:
        """Project ``b`` onto the invariant subspace using the precomputed projector."""
        if not self.training and self._cache_is_valid(b):
            return self._bias
        bias = torch.mv(self.inv_projector, b)
        if not self.training:
            self._update_cache(b, bias)
        return bias

    def right_inverse(self, tensor: torch.Tensor) -> torch.Tensor:
        """Return a parameter tensor whose projection equals ``tensor``."""
        return tensor

    def invalidate_cache(self) -> None:
        """Clear the cached projection so it is recomputed on next use."""
        self._bias = None
        self._cached_input_id = None
        self._cached_input_version = None

    def _apply(self, fn):
        super()._apply(fn)
        self.invalidate_cache()
        return self

    def load_state_dict(self, state_dict, strict: bool = True):  # noqa: D102
        """Load parameters and clear the cached projected bias."""
        result = super().load_state_dict(state_dict, strict)
        self.invalidate_cache()
        return result
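
# Usage sketch for ``InvariantConstraint`` (illustrative only; it assumes nothing beyond
# the standard ``torch.nn.utils.parametrize`` machinery and an escnn ``Representation``
# instance named ``rep``):
#
#     import torch.nn.utils.parametrize as parametrize
#
#     layer = torch.nn.Linear(rep.size, rep.size, bias=True)
#     parametrize.register_parametrization(layer, "bias", InvariantConstraint(rep))
#     # ``layer.bias`` now always equals P_inv @ (unconstrained bias parameter), i.e. it
#     # stays in Fix(rho) after every optimizer step. In ``layer.eval()`` mode the
#     # projected bias is cached until the underlying parameter changes.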


class CommutingConstraint(torch.nn.Module):
    r"""Orthogonal projection onto :math:`\operatorname{Hom}_{\mathbb{G}}(\rho_{\text{in}}, \rho_{\text{out}})`.

    For a dense weight :math:`\mathbf{W} \in \mathbb{R}^{D_{\text{out}} \times D_{\text{in}}}`, this
    module returns :math:`\Pi_{\mathrm{Hom}}(\mathbf{W})`, the Frobenius-orthogonal projection onto

    .. math::
        \operatorname{Hom}_{\mathbb{G}}(\rho_{\text{in}}, \rho_{\text{out}})
        = \{\mathbf{A} : \rho_{\text{out}}(g)\mathbf{A} = \mathbf{A}\rho_{\text{in}}(g),\ \forall g \in \mathbb{G}\}.

    The basis and projection are handled by
    :class:`~symm_learning.representation_theory.GroupHomomorphismBasis`, using the
    :ref:`isotypic decomposition <isotypic-decomposition-example>`
    (:func:`~symm_learning.representation_theory.isotypic_decomp_rep`) blockwise.

    Args:
        in_rep (:class:`~escnn.group.Representation`): Input representation :math:`\rho_{\text{in}}`
            of size ``in_rep.size``.
        out_rep (:class:`~escnn.group.Representation`): Output representation :math:`\rho_{\text{out}}`
            of size ``out_rep.size``.
        basis_expansion (:class:`str`, optional): Strategy used to realize the basis
            (``"memory_heavy"`` or ``"isotypic_expansion"``).

    Note:
        Runtime behavior depends on mode. In training mode (``model.train()``), the projection is
        recomputed on each forward pass. In inference mode (``model.eval()``), the projected matrix
        is cached for the same unchanged input tensor (same object identity and version counter),
        which is faster. With the cache active, as a parametrization of :class:`~torch.nn.Linear`,
        the forward path is equivalent to a symmetry-agnostic standard linear layer with a fixed
        projected dense weight.

    Attributes:
        homo_basis (:class:`~symm_learning.representation_theory.GroupHomomorphismBasis`): Basis
            generator carrying the :ref:`isotypic decomposition <isotypic-decomposition-example>`
            and block metadata for :math:`\operatorname{Hom}_\mathbb{G}(\rho_{\text{in}}, \rho_{\text{out}})`.
        in_rep / out_rep (:class:`~escnn.group.Representation`): Cached references to the isotypic
            versions of the supplied representations.
    """

    def __init__(self, in_rep: Representation, out_rep: Representation, basis_expansion: str = "isotypic_expansion"):
        super().__init__()
        self.homo_basis = GroupHomomorphismBasis(in_rep, out_rep, basis_expansion=basis_expansion)
        self.in_rep = self.homo_basis.in_rep
        self.out_rep = self.homo_basis.out_rep
        self.register_buffer("_weight", None, persistent=False)
        self._cached_input_id = None
        self._cached_input_version = None

    def _cache_is_valid(self, W: torch.Tensor) -> bool:
        if self._weight is None:
            return False
        if id(W) != self._cached_input_id:
            return False
        version = getattr(W, "_version", None)
        return version is not None and version == self._cached_input_version

    def _update_cache(self, W: torch.Tensor, W_proj: torch.Tensor) -> None:
        version = getattr(W, "_version", None)
        if version is None:
            self.invalidate_cache()
            return
        self._weight = W_proj
        self._cached_input_id = id(W)
        self._cached_input_version = version

    def forward(self, W: torch.Tensor) -> torch.Tensor:
        r"""Project :math:`\mathbf{W}` onto the space of equivariant linear maps.

        Args:
            W (:class:`~torch.Tensor`): Dense matrix
                :math:`\mathbf{W} \in \mathbb{R}^{D_{\mathrm{out}} \times D_{\mathrm{in}}}`.

        Returns:
            :class:`~torch.Tensor`: Frobenius-orthogonal projection :math:`\Pi_{\mathrm{Hom}}(\mathbf{W})`,
            which satisfies
            :math:`\rho_{\mathrm{out}}(g)\Pi_{\mathrm{Hom}}(\mathbf{W}) = \Pi_{\mathrm{Hom}}(\mathbf{W})\rho_{\mathrm{in}}(g)`
            for all :math:`g \in \mathbb{G}`.
        """
        if not self.training and self._cache_is_valid(W):
            return self._weight
        W_proj = self.homo_basis.orthogonal_projection(W)
        if not self.training:
            self._update_cache(W, W_proj)
        return W_proj

    def right_inverse(self, tensor: torch.Tensor) -> torch.Tensor:
        """Return a pre-image for the parametrization (identity for now)."""
        return tensor

    def invalidate_cache(self) -> None:
        """Clear the cached projection so it is recomputed on next use."""
        self._weight = None
        self._cached_input_id = None
        self._cached_input_version = None

    def _apply(self, fn):
        super()._apply(fn)
        self.invalidate_cache()
        return self

    def load_state_dict(self, state_dict, strict: bool = True):  # noqa: D102
        """Load parameters and clear the cached projected weight."""
        result = super().load_state_dict(state_dict, strict)
        self.invalidate_cache()
        return result
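
# Usage sketch for ``CommutingConstraint`` (illustrative only; it assumes the standard
# ``torch.nn.utils.parametrize`` machinery and escnn ``Representation`` instances
# ``in_rep`` / ``out_rep``; compare the ``impose_linear_equivariance`` helper exercised
# in the ``__main__`` block below):
#
#     import torch.nn.utils.parametrize as parametrize
#
#     lin = torch.nn.Linear(in_rep.size, out_rep.size, bias=False)
#     parametrize.register_parametrization(lin, "weight", CommutingConstraint(in_rep, out_rep))
#     # ``lin.weight`` now always satisfies rho_out(g) @ W == W @ rho_in(g) for all g,
#     # so ``lin`` is an equivariant linear map by construction.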


if __name__ == "__main__":
    import sys
    import types
    from pathlib import Path

    import escnn
    from escnn.group import CyclicGroup, Icosahedral
    from escnn.gspaces import no_base_space
    from escnn.nn import FieldType

    # Make the repository's ``test`` package importable so its benchmark utilities can be reused.
    repo_root = Path(__file__).resolve().parents[2]
    test_dir = repo_root / "test"
    sys.path.insert(0, str(repo_root))
    test_pkg = sys.modules.get("test")
    test_paths = [str(path) for path in getattr(test_pkg, "__path__", [])] if test_pkg else []
    if str(test_dir) not in test_paths:
        test_pkg = types.ModuleType("test")
        test_pkg.__path__ = [str(test_dir)]
        sys.modules["test"] = test_pkg

    from symm_learning.nn.linear import eAffine, eLinear, impose_linear_equivariance
    from test.utils import benchmark, benchmark_eval_forward

    # G = CyclicGroup(2)
    G = Icosahedral()
    m = 2
    bias = True
    in_rep = direct_sum([G.regular_representation] * m)
    out_rep = direct_sum([G.regular_representation] * m * 2)

    eq_layer_proj_heavy = torch.nn.Linear(in_features=in_rep.size, out_features=out_rep.size, bias=bias)
    impose_linear_equivariance(
        lin=eq_layer_proj_heavy, in_rep=in_rep, out_rep=out_rep, basis_expansion_scheme="memory_heavy"
    )
    eq_layer_proj_iso = torch.nn.Linear(in_features=in_rep.size, out_features=out_rep.size, bias=bias)
    impose_linear_equivariance(
        lin=eq_layer_proj_iso, in_rep=in_rep, out_rep=out_rep, basis_expansion_scheme="isotypic_expansion"
    )

    # Check that the orthogonal projection onto Hom_G(in_rep, out_rep) works as expected.
    in_type = escnn.nn.FieldType(escnn.gspaces.no_base_space(G), [G.regular_representation] * m)
    out_type = escnn.nn.FieldType(escnn.gspaces.no_base_space(G), [G.regular_representation] * m * 2)
    escnn_layer = escnn.nn.Linear(in_type, out_type, bias=bias)
    W, b = escnn_layer.expand_parameters()

    def check_projection(layer, label):  # noqa: D103
        # An already-equivariant weight must be a fixed point of the projection.
        layer.weight = W
        W_projected = layer.weight
        assert torch.allclose(W, W_projected, atol=1e-5, rtol=1e-5), f"Max err: {(W - W_projected).abs().max()}"
        check_equivariance(layer, atol=1e-5, rtol=1e-5, in_rep=in_rep, out_rep=out_rep)
        print(f"{label} projection equivariance test passed.")
        # A random weight is projected into Hom_G, and the projection is idempotent.
        W_random = torch.randn_like(W)
        layer.weight = W_random
        check_equivariance(layer, atol=1e-5, rtol=1e-5, in_rep=in_rep, out_rep=out_rep)
        W_proj = layer.weight
        layer.weight = W_proj
        W_proj2 = layer.weight
        assert torch.allclose(W_proj, W_proj2, atol=1e-5, rtol=1e-5), f"Max err: {(W_proj - W_proj2).abs().max()}"
        print(f"{label} projection idempotence test passed.")

    check_projection(eq_layer_proj_heavy, "Memory-heavy")
    check_projection(eq_layer_proj_iso, "Isotypic-expansion")

    # ____________________________________________________________________________________
    # Benchmark forward/backward time and memory across standard, parametrized, and escnn layers.
    standard_layer = torch.nn.Linear(in_features=in_rep.size, out_features=out_rep.size, bias=bias)
    eq_layer_iso = eLinear(in_rep, out_rep, bias=bias, basis_expansion_scheme="isotypic_expansion")
    eq_layer_heavy = eLinear(in_rep, out_rep, bias=bias, basis_expansion_scheme="memory_heavy")
    e_affine = eAffine(in_rep, bias=bias)
    in_type = FieldType(no_base_space(G), [in_rep])
    out_type = FieldType(no_base_space(G), [out_rep])
    escnn_layer = escnn.nn.Linear(in_type, out_type, bias=bias)

    batch_size = 1024
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    standard_layer = standard_layer.to(device)
    eq_layer_iso = eq_layer_iso.to(device)
    eq_layer_heavy = eq_layer_heavy.to(device)
    eq_layer_proj_heavy = eq_layer_proj_heavy.to(device)
    eq_layer_proj_iso = eq_layer_proj_iso.to(device)
    e_affine = e_affine.to(device)
    escnn_layer = escnn_layer.to(device)
    print(f"Device: {device}")
    x = torch.randn(batch_size, in_rep.size, device=device)

    def run_forward(mod):  # noqa: D103
        return mod(in_type(x)).tensor if isinstance(mod, escnn.nn.Linear) else mod(x)

    modules_to_benchmark = [
        ("Standard", standard_layer),
        ("eLinear (iso)", eq_layer_iso),
        ("eLinear (heavy)", eq_layer_heavy),
        ("Linear Proj (iso)", eq_layer_proj_iso),
        ("Linear Proj (heavy)", eq_layer_proj_heavy),
        ("eAffine", e_affine),
        ("escnn", escnn_layer),
    ]
    results = []
    for name, module in modules_to_benchmark:

        def forward_fn(mod=module):  # noqa: D103
            return run_forward(mod)

        train_mem, non_train_mem = module_memory(module)
        gpu_alloc, gpu_peak = module_device_memory(module)
        eval_fwd_mean, eval_fwd_std = benchmark_eval_forward(module, forward_fn)
        (fwd_mean, fwd_std), (bwd_mean, bwd_std) = benchmark(module, forward_fn)
        results.append(
            {
                "name": name,
                "fwd_eval_mean": eval_fwd_mean,
                "fwd_eval_std": eval_fwd_std,
                "fwd_mean": fwd_mean,
                "fwd_std": fwd_std,
                "bwd_mean": bwd_mean,
                "bwd_std": bwd_std,
                "total_time": fwd_mean + bwd_mean,
                "train_mem": train_mem,
                "non_train_mem": non_train_mem,
                "gpu_mem": gpu_alloc,
                "gpu_peak": gpu_peak,
            }
        )

    name_width = 20
    header = (
        f"{'Layer':<{name_width}} {'Forward eval (ms)':>18} {'Forward (ms)':>18} {'Backward (ms)':>18} "
        f"{'Total (ms)':>15} "
        f"{'Trainable MB':>15} {'Non-train MB':>15} {'Total MB':>12} {'GPU Alloc MB':>15} {'GPU Peak MB':>15}"
    )
    separator = "-" * len(header)
    print(f"\nBenchmark results per {batch_size}-sample batch")
    print(separator)
    print(header)
    print(separator)
    for res in results:
        fwd_eval_str = f"{res['fwd_eval_mean']:.3f} ± {res['fwd_eval_std']:.3f}"
        fwd_str = f"{res['fwd_mean']:.3f} ± {res['fwd_std']:.3f}"
        bwd_str = f"{res['bwd_mean']:.3f} ± {res['bwd_std']:.3f}"
        total_bytes = res["train_mem"] + res["non_train_mem"]
        gpu_alloc_mb = bytes_to_mb(res["gpu_mem"])
        gpu_peak_mb = bytes_to_mb(res["gpu_peak"])
        print(
            f"{res['name']:<{name_width}} {fwd_eval_str:>18} {fwd_str:>18} {bwd_str:>18} "
            f"{res['total_time']:>15.3f} {bytes_to_mb(res['train_mem']):>15.3f} "
            f"{bytes_to_mb(res['non_train_mem']):>15.3f} {bytes_to_mb(total_bytes):>12.3f} "
            f"{gpu_alloc_mb:>15.3f} {gpu_peak_mb:>15.3f}"
        )
    print(separator)