from __future__ import annotations
import logging
import math
from typing import Literal
import torch
from escnn.group import Representation
from torch.nn.utils import parametrize
from symm_learning.nn.parametrizations import CommutingConstraint, InvariantConstraint
from symm_learning.representation_theory import GroupHomomorphismBasis
from symm_learning.utils import get_spectral_trivial_mask
logger = logging.getLogger(__name__)
eINIT_SCHEMES = Literal["xavier_normal", "xavier_uniform", "kaiming_normal", "kaiming_uniform"]
def impose_linear_equivariance(
lin: torch.nn.Linear,
in_rep: Representation,
out_rep: Representation,
basis_expansion_scheme: str = "isotypic_expansion",
) -> None:
r"""Impose equivariance constraints on a given torch.nn.Linear layer using torch parametrizations.
    Via torch parametrizations (hard constraints on the trainable parameters), the weight matrix of
    the given linear layer is constrained to commute with the group actions of the input and output
    representations. That is:
.. math::
\rho_{\text{out}}(g) W = W \rho_{\text{in}}(g) \quad \forall g \in G
If the layer has a bias term, it is constrained to be invariant:
.. math::
\rho_{\text{out}}(g) b = b \quad \forall g \in G
Parameters
----------
    lin : :class:`~torch.nn.Linear`
        The linear layer to impose equivariance on. Must have ``weight`` and optionally ``bias`` attributes.
in_rep : :class:`~escnn.group.Representation`
The input representation :math:`\rho_{\text{in}}` of the layer.
out_rep : :class:`~escnn.group.Representation`
The output representation :math:`\rho_{\text{out}}` of the layer.
basis_expansion_scheme : str
Basis expansion strategy for the commuting constraint (``"memory_heavy"`` or ``"isotypic_expansion"``).
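
    Examples
    --------
    A minimal sketch, assuming ``escnn``'s :class:`~escnn.group.CyclicGroup` and its regular
    representation (illustrative; any finite-group representation works analogously)::

        import torch
        from escnn.group import CyclicGroup

        G = CyclicGroup(3)
        rep = G.regular_representation
        lin = torch.nn.Linear(rep.size, rep.size)
        impose_linear_equivariance(lin, in_rep=rep, out_rep=rep)
        # The parametrized weight now satisfies rep(g) @ W == W @ rep(g) for all g in G.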
"""
assert isinstance(lin, torch.nn.Module), f"lin must be a torch.nn.Module, got {type(lin)}"
# Add attributes to the layer for later reference
lin.in_rep = in_rep
lin.out_rep = out_rep
# Register parametrizations enforcing equivariance
parametrize.register_parametrization(
lin, "weight", CommutingConstraint(in_rep, out_rep, basis_expansion=basis_expansion_scheme)
)
if lin.bias is not None:
parametrize.register_parametrization(lin, "bias", InvariantConstraint(out_rep))
class eLinear(torch.nn.Linear):
r"""Parameterize a :math:`\mathbb{G}`-equivariant linear map with optional invariant bias.
The layer learns coefficients over :math:`\operatorname{Hom}_\mathbb{G}(\rho_{\text{in}}, \rho_{\text{out}})`,
synthesizing a dense weight matrix :math:`\mathbf{W}` satisfying:
.. math::
\rho_{\text{out}}(g) \mathbf{W} = \mathbf{W} \rho_{\text{in}}(g) \quad \forall g \in \mathbb{G}
If ``bias=True``, the bias vector :math:`\mathbf{b}` is constrained to the invariant subspace:
.. math::
\rho_{\text{out}}(g) \mathbf{b} = \mathbf{b} \quad \forall g \in \mathbb{G}
Note:
Runtime behavior depends on mode.
In training mode (``model.train()``), the constrained dense tensors are recomputed every forward pass, which
is correct for gradient updates but slower.
In inference mode (``model.eval()``), the expanded dense weight (and optional invariant bias) are cached and
reused until parameters change or :meth:`invalidate_cache` is called, which is faster.
With the cache active, :meth:`forward` is computationally equivalent to a symmetry-agnostic
:class:`~torch.nn.Linear` with fixed dense ``weight`` and ``bias``.
Attributes:
homo_basis (:class:`~symm_learning.representation_theory.GroupHomomorphismBasis`): Handler exposing the
equivariant basis and metadata.
bias_module (:class:`~symm_learning.nn.linear.InvariantBias` | None): Optional module handling the invariant
bias.
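
    Example:
        A minimal usage sketch, assuming ``escnn``'s :class:`~escnn.group.CyclicGroup` regular
        representation (illustrative; any representation pair with a nonzero homomorphism space
        works)::

            import torch
            from escnn.group import CyclicGroup

            G = CyclicGroup(4)
            rep = G.regular_representation
            layer = eLinear(rep, rep, bias=True)
            x = torch.randn(8, rep.size)
            y = layer(x)  # (8, rep.size)

            # Numerically check equivariance for one group element.
            g = G.elements[1]
            rho = torch.tensor(rep(g), dtype=x.dtype)
            assert torch.allclose(layer(x @ rho.T), y @ rho.T, atol=1e-4)

            layer.eval()  # caches the dense weight; forward now matches a fixed nn.Linear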
"""
def __init__(
self,
in_rep: Representation,
out_rep: Representation,
bias: bool = True,
init_scheme: str | None = "xavier_normal",
basis_expansion_scheme: str = "isotypic_expansion",
):
r"""Initialize the equivariant layer.
Args:
in_rep (:class:`~escnn.group.Representation`): Representation :math:`\rho_{\text{in}}` describing how inputs
transform.
out_rep (:class:`~escnn.group.Representation`): Representation :math:`\rho_{\text{out}}` describing how
outputs transform.
bias (:class:`bool`, optional): Enables the invariant bias if the trivial irrep is present in ``out_rep``.
Default: ``True``.
init_scheme (:class:`str` | :class:`None`, optional): Initialization method passed to
:meth:`~symm_learning.representation_theory.GroupHomomorphismBasis.initialize_params`. Use ``None``
to skip initialization. Default: ``"xavier_normal"``.
basis_expansion_scheme (:class:`str`, optional): Strategy for materializing the basis
(``"isotypic_expansion"`` or ``"memory_heavy"``). Default: ``"isotypic_expansion"``.
Raises:
ValueError: If :math:`\dim(\mathrm{Hom}_{\mathbb{G}}(\rho_{\text{in}}, \rho_{\text{out}})) = 0`.
"""
super().__init__(in_features=in_rep.size, out_features=out_rep.size, bias=bias)
# Delete linear unconstrained module parameters
self.register_parameter("weight", None)
self.register_parameter("bias", None)
self.register_buffer("_weight", None, persistent=False)
self._weight_cache_dirty = True
        # Instantiate the handler of the basis of Hom_G(in_rep, out_rep)
self.homo_basis = GroupHomomorphismBasis(in_rep, out_rep, basis_expansion=basis_expansion_scheme)
self.in_rep, self.out_rep = self.homo_basis.in_rep, self.homo_basis.out_rep
if self.homo_basis.dim == 0:
raise ValueError(
f"No equivariant linear maps exist between {in_rep} and {out_rep}.\n dim(Hom_G(in_rep, out_rep))=0"
)
# Register weight parameters (degrees of freedom: dof) and buffers
self.register_parameter(
"weight_dof",
torch.nn.Parameter(torch.zeros(self.homo_basis.dim, dtype=torch.get_default_dtype()), requires_grad=True),
)
self.bias_module = InvariantBias(out_rep) if bias else None
# Register backward hook to flag caches stale/invalid whenever grads are produced.
self.weight_dof.register_hook(self._mark_weight_cache_dirty)
if init_scheme is not None:
self.reset_parameters(scheme=init_scheme)
def expand_weight(self):
r"""Return the dense equivariant weight, caching it outside training.
Returns:
torch.Tensor: Dense matrix of shape ``(out_rep.size, in_rep.size)``.
"""
W = self.homo_basis(self.weight_dof) # Recompute linear map
self._weight = W
self._weight_cache_dirty = False
return W
@property
def weight(self) -> torch.Tensor:
"""Dense equivariant weight; recomputed in train, cached in eval."""
if self.training or self._weight is None or self._weight_cache_dirty:
return self.expand_weight()
return self._weight
@property
def bias(self) -> torch.Tensor | None:
"""Invariant bias from :class:`InvariantBias` (``None`` if disabled)."""
return self.bias_module.bias if self.bias_module is not None else None
@torch.no_grad()
def reset_parameters(self, scheme="xavier_normal"):
"""Reset all trainable parameters.
Args:
scheme (:class:`str`): Initialization scheme (``"xavier_normal"``, ``"xavier_uniform"``,
``"kaiming_normal"``, or ``"kaiming_uniform"``).
"""
if not hasattr(self, "homo_basis"): # First call on torch.nn.Linear init
return super().reset_parameters()
new_params = self.homo_basis.initialize_params(scheme)
self.weight_dof.copy_(new_params)
# Update cache
self.expand_weight()
if self.bias_module is not None:
self.bias_module.reset_parameters(scheme=scheme)
logger.debug(f"Reset parameters of linear layer to {scheme}")
def _mark_weight_cache_dirty(self, grad: torch.Tensor) -> torch.Tensor:
self._weight_cache_dirty = True
return grad
def invalidate_cache(self) -> None:
"""Clear cached expansions and mark them stale."""
self._weight = None
self._weight_cache_dirty = True
if self.bias_module is not None:
self.bias_module.invalidate_cache()
def _refresh_eval_cache(self) -> None:
"""Ensure eval-mode caches are materialized."""
if self._weight is None or self._weight_cache_dirty:
self.expand_weight()
if self.bias_module is not None:
self.bias_module.refresh_eval_cache()
def _apply(self, fn):
super()._apply(fn)
self.invalidate_cache()
return self
def train(self, mode: bool = True): # noqa: D102
"""Switch mode and keep cached expanded tensors consistent."""
result = super().train(mode)
if mode: # Switching to train mode - invalidate cache
self.invalidate_cache()
else: # Switching to eval mode - refresh cache
self._refresh_eval_cache()
return result
def load_state_dict(self, state_dict, strict: bool = True): # noqa: D102
"""Load parameters and invalidate cached expanded tensors."""
result = super().load_state_dict(state_dict, strict)
self.invalidate_cache()
return result
class InvariantBias(torch.nn.Module):
r"""Module parameterizing a learnable :math:`\mathbb{G}`-invariant bias.
For representation space :math:`\mathcal{X}`, this module enforces
:math:`\rho_{\mathcal{X}}(g)\mathbf{b}=\mathbf{b}` for all :math:`g\in\mathbb{G}`. Hence only trivial-irrep
coordinates in the irrep-spectral basis carry free parameters.
If the input representation does not contain the trivial irrep (no trivial/invariant subspace), the module behaves
as the identity function.
Note:
Runtime behavior depends on mode.
In training mode (``model.train()``), the invariant bias is recomputed each forward pass.
In inference mode (``model.eval()``), the expanded invariant bias is cached and reused until ``bias_dof``
changes or :meth:`invalidate_cache` is called, which is faster.
With the cache active, the forward path is the same computation as the standard symmetry-agnostic bias add
``input + b`` with fixed ``b``.
Attributes:
in_rep (:class:`~escnn.group.Representation`): Representation defining the symmetry action on
:math:`\mathcal{X}`.
out_rep (:class:`~escnn.group.Representation`): Same as ``in_rep`` (bias acts in the same space).
has_bias (:class:`bool`): ``True`` iff the trivial irrep is present and a learnable invariant bias exists.
bias_dof (:class:`~torch.nn.Parameter`): Learnable trivial-subspace coefficients (present only if
``has_bias=True``).
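
    Example:
        A short sketch, assuming the regular representation of ``escnn``'s
        :class:`~escnn.group.CyclicGroup`, which contains one copy of the trivial irrep::

            import torch
            from escnn.group import CyclicGroup

            rep = CyclicGroup(3).regular_representation
            mod = InvariantBias(rep)
            x = torch.randn(2, rep.size)
            y = mod(x)  # x + b, with b constrained to the invariant subspace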
"""
def __init__(self, in_rep: Representation):
r"""Construct the invariant bias module.
Args:
in_rep (:class:`~escnn.group.Representation`): Representation :math:`\rho_{\text{in}}` of the input space
(same as output space).
"""
super().__init__()
self.in_rep, self.out_rep = in_rep, in_rep
G = self.in_rep.group
trivial_id = G.trivial_representation.id
        # Check whether the trivial irrep appears, i.e., whether a nonzero invariant bias exists.
self.has_bias = in_rep._irreps_multiplicities.get(trivial_id, 0) > 0
self.register_buffer("_bias", None, persistent=False)
self._bias_cache_dirty = False
if not self.has_bias: # No bias -> No buffer memory consumption
return
dtype = torch.get_default_dtype()
        # The number of trainable bias parameters equals the multiplicity of the trivial irrep
m_out_trivial = self.in_rep._irreps_multiplicities[trivial_id]
self.register_parameter("bias_dof", torch.nn.Parameter(torch.zeros(m_out_trivial), requires_grad=True))
self.register_buffer("Qout", torch.tensor(self.in_rep.change_of_basis, dtype=dtype))
# Save mask of trivial dimensions in the irrep-spectral basis
self.register_buffer("spectral_trivial_mask", get_spectral_trivial_mask(self.in_rep))
# Cache reference of last computed bias
self._bias_cache_dirty = True
self.bias_dof.register_hook(self._mark_bias_cache_dirty)
def forward(self, input: torch.Tensor) -> torch.Tensor:
"""Apply the invariant bias.
Args:
input (:class:`~torch.Tensor`): Tensor whose last dimension equals ``in_rep.size``.
Returns:
:class:`~torch.Tensor`: Output tensor with the same shape as ``input``.
"""
if not self.has_bias:
return input
return input + self.bias
@property
def bias(self):
"""Invariant bias; recomputed in training, cached otherwise."""
if not self.has_bias:
return None
# If training, recompute bias; else use cached version
if self.training or self._bias is None or self._bias_cache_dirty:
return self.expand_bias()
return self._bias
def expand_bias(self):
"""Expand the learnable parameters into the invariant bias in the original basis."""
bias = torch.mv(self.Qout[:, self.spectral_trivial_mask], self.bias_dof)
# Update cache
self._bias = bias
self._bias_cache_dirty = False
return bias
def expand_bias_spectral_basis(self):
"""Return the invariant bias expressed in the irrep-spectral basis."""
spectral_bias = torch.zeros(self.in_rep.size, dtype=self.bias_dof.dtype, device=self.bias_dof.device)
spectral_bias[self.spectral_trivial_mask] = self.bias_dof
return spectral_bias
def reset_parameters(self, scheme="zeros"):
"""Initialize the invariant bias degrees of freedom."""
if not self.has_bias:
return
if scheme == "zeros":
torch.nn.init.zeros_(self.bias_dof)
trivial_id = self.in_rep.group.trivial_representation.id
m = self.in_rep._irreps_multiplicities[trivial_id]
fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(torch.empty(m, m))
bound = 1 / torch.math.sqrt(fan_in) if fan_in > 0 else 0
torch.nn.init.uniform_(self.bias_dof, -bound, bound)
self.invalidate_cache()
def train(self, mode: bool = True):
"""Switch between training and evaluation modes, managing cache appropriately."""
result = super().train(mode)
if mode: # Switching to train mode - invalidate cache
self.invalidate_cache()
else: # Switching to eval mode - refresh cache
self.refresh_eval_cache()
return result
def _mark_bias_cache_dirty(self, grad: torch.Tensor) -> torch.Tensor:
self._bias_cache_dirty = True
return grad
def invalidate_cache(self) -> None:
"""Clear cached bias so it is recomputed on next use."""
if not self.has_bias:
return
self._bias = None
self._bias_cache_dirty = True
def refresh_eval_cache(self) -> None:
"""Ensure eval-mode cache is populated."""
if not self.has_bias:
return
if self._bias is None or self._bias_cache_dirty:
self.expand_bias()
def _apply(self, fn):
super()._apply(fn)
self.invalidate_cache()
return self
def load_state_dict(self, state_dict, strict: bool = True): # noqa: D102
"""Load parameters and invalidate cached expanded bias."""
result = super().load_state_dict(state_dict, strict)
self.invalidate_cache()
return result
class eAffine(torch.nn.Module):
r"""Equivariant affine map with per-irrep scales and invariant bias.
Let :math:`\mathbf{x}\in\mathcal{X}` with representation
.. math::
\rho_{\mathcal{X}} = \mathbf{Q}\left(
\bigoplus_{k\in[1,n_{\text{iso}}]}
\bigoplus_{i\in[1,n_k]}
\hat{\rho}_k
\right)\mathbf{Q}^T.
This module applies
.. math::
\mathbf{y} = \mathbf{Q}\,\mathbf{D}_{\alpha}\,\mathbf{Q}^T\mathbf{x} + \mathbf{b},
where :math:`\mathbf{D}_{\alpha}` is diagonal in irrep-spectral basis and constant over dimensions of each irrep
copy (:math:`\alpha_{k,i}`), while :math:`\mathbf{b}\in\mathrm{Fix}(\rho_{\mathcal{X}})` (trivial block only).
Therefore:
.. math::
\rho_{\mathcal{X}}(g)\mathbf{y}
= \operatorname{eAffine}\!\left(\rho_{\mathcal{X}}(g)\mathbf{x}\right)
\quad \forall g\in\mathbb{G}.
    When ``learnable=True``, the scale and bias degrees of freedom (DoFs) are trainable parameters.
    When ``learnable=False``, ``scale_dof`` and ``bias_dof`` are provided at call time (FiLM style).
Args:
in_rep: :class:`~escnn.group.Representation` describing the input/output space :math:`\rho_{\text{in}}`.
bias: include invariant biases when the trivial irrep is present. Default: ``True``.
learnable: if ``False``, no parameters are registered and ``scale_dof``/``bias_dof`` must
be passed at call time. Default: ``True``.
init_scheme: initialization for the learnable DoFs (``"identity"`` or ``"random"``). Set
to ``None`` to skip init (e.g. when loading weights). Ignored when ``learnable=False``.
Shape:
- Input: ``(..., D)`` with ``D = in_rep.size``.
        - ``scale_dof`` (optional): ``(..., num_scale_dof)``, one DoF per irrep copy in ``in_rep``.
- ``bias_dof`` (optional): ``(..., num_bias_dof)`` when ``bias=True``.
- Output: ``(..., D)``.
Note:
Runtime behavior depends on mode.
In training mode (``model.train()``), the affine map is recomputed each forward pass.
In inference mode (``model.eval()``) and with ``learnable=True``, the dense affine map
:math:`\mathbf{Q}\mathbf{D}_{\alpha}\mathbf{Q}^T` (and optional invariant bias) is cached and reused until
parameters change or :meth:`invalidate_cache` is called, which is faster.
Unlike :class:`~symm_learning.nn.linear.eLinear`, this module is not a strict symmetry-agnostic drop-in
affine block, because parameters are irrep-structured and may also be provided externally
(FiLM-style via ``scale_dof``/``bias_dof``).
Attributes:
rep_x (:class:`~escnn.group.Representation`): Representation :math:`\rho_{\mathcal{X}}` of the feature space.
Q (:class:`~torch.Tensor`): Change-of-basis matrix to the irrep-spectral basis.
Q_inv (:class:`~torch.Tensor`): Inverse change-of-basis matrix from irrep-spectral basis.
bias_module (:class:`~symm_learning.nn.linear.InvariantBias` | None): Optional module handling the invariant
bias.
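
    Example:
        A minimal sketch with learnable DoFs, assuming ``escnn``'s
        :class:`~escnn.group.CyclicGroup` regular representation::

            import torch
            from escnn.group import CyclicGroup

            rep = CyclicGroup(4).regular_representation
            affine = eAffine(rep, bias=True, learnable=True)
            x = torch.randn(8, rep.size)
            y = affine(x)  # (8, rep.size), equivariant affine transform of x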
"""
def __init__(
self,
in_rep: Representation,
bias: bool = True,
learnable: bool = True,
init_scheme: Literal["identity", "random"] | None = "identity",
):
super().__init__()
self.in_rep, self.out_rep = in_rep, in_rep
self.learnable = learnable
self.rep_x = in_rep
G = self.rep_x.group
dtype = torch.get_default_dtype()
# Common metadata --------------------------------------------------------------------------
self.register_buffer("Q", torch.tensor(self.rep_x.change_of_basis, dtype=dtype))
self.register_buffer("Q_inv", torch.tensor(self.rep_x.change_of_basis_inv, dtype=dtype))
# Buffers needed to map per-irrep scale to full spectral scale
irrep_dims_list = [G.irrep(*irrep_id).size for irrep_id in self.rep_x.irreps]
irrep_dims = torch.tensor(irrep_dims_list, dtype=torch.long)
self._num_scale_dof = len(irrep_dims_list)
self.register_buffer(
"irrep_indices", torch.repeat_interleave(torch.arange(len(irrep_dims), dtype=torch.long), irrep_dims)
)
trivial_id = G.trivial_representation.id
        self.has_bias = bias and self.rep_x._irreps_multiplicities.get(trivial_id, 0) > 0
self._num_bias_dof = 0
self.bias_module = None
if self.has_bias and self.learnable:
# Reuse invariant-bias helper (stores bias_dof and spectral mask)
self.bias_module = InvariantBias(self.rep_x)
self._num_bias_dof = self.bias_module.bias_dof.numel()
# Convenience handle so callers expecting ``bias_dof`` still find it.
self.bias_dof = self.bias_module.bias_dof
elif self.has_bias:
trivial_mask = torch.zeros(self.rep_x.size, dtype=torch.bool)
offset = 0
for irrep_id in self.rep_x.irreps:
if irrep_id == trivial_id:
trivial_mask[offset] = 1
self._num_bias_dof += 1
offset += G.irrep(*irrep_id).size
self.register_buffer("trivial_subspace_mask", trivial_mask)
# Mode-specific parameters -----------------------------------------------------------------
if self.learnable:
self.register_parameter("scale_dof", torch.nn.Parameter(torch.ones(self.num_scale_dof, dtype=dtype)))
if self.has_bias and self.bias_module is None:
self.register_parameter("bias_dof", torch.nn.Parameter(torch.zeros(self.num_bias_dof, dtype=dtype)))
elif self.has_bias:
# bias handled by bias_module; keep attribute for API compatibility
self.bias_dof = self.bias_module.bias_dof
else:
self.register_parameter("bias_dof", None)
self.register_buffer("_affine", None, persistent=False)
self._affine_cache_dirty = True
if self.learnable:
self.scale_dof.register_hook(self._mark_affine_cache_dirty)
if init_scheme is not None:
self.reset_parameters(scheme=init_scheme)
def forward(
self,
input: torch.Tensor,
scale_dof: torch.Tensor | None = None,
bias_dof: torch.Tensor | None = None,
) -> torch.Tensor:
r"""Apply the equivariant affine transform.
        When ``learnable=False``, ``scale_dof`` (and ``bias_dof`` if ``bias=True``) must be provided.
Args:
input: Tensor :math:`\mathbf{x}` whose last dimension matches ``in_rep.size``.
            scale_dof: Optional per-irrep scaling DoFs of shape ``(..., num_scale_dof)``, with
                leading dims matching ``input.shape[:-1]``. Required when ``learnable=False``.
            bias_dof: Optional invariant-bias DoFs of shape ``(..., num_bias_dof)``, with leading
                dims matching ``input.shape[:-1]``. Required when ``learnable=False`` and ``bias=True``.
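
        Example:
            A FiLM-style sketch with externally supplied DoFs (illustrative; assumes an escnn
            representation such as ``CyclicGroup(4).regular_representation``)::

                import torch
                from escnn.group import CyclicGroup

                rep = CyclicGroup(4).regular_representation
                affine = eAffine(rep, bias=True, learnable=False)
                x = torch.randn(8, rep.size)
                scale = torch.randn(8, affine.num_scale_dof)
                shift = torch.randn(8, affine.num_bias_dof)
                y = affine(x, scale_dof=scale, bias_dof=shift)  # (8, rep.size)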
"""
if input.shape[-1] != self.rep_x.size:
raise ValueError(f"Expected last dimension {self.rep_x.size}, got {input.shape[-1]}")
# Obtain per-dimension spectral scale; reuse learnable bias directly in original basis.
if self.learnable:
bias_orig = self.bias_module.bias if self.has_bias and self.bias_module is not None else None
if not self.training:
affine = self._affine
if affine is None or self._affine_cache_dirty:
affine = self.expand_affine()
y = torch.einsum("ij,...j->...i", affine, input)
if bias_orig is not None:
y = y + bias_orig
return y # Use cached matrix.
scale_spec = self.scale_dof[self.irrep_indices] # (D,)
else:
scale_spec, spectral_bias = self.broadcast_spectral_scale_and_bias(
scale_dof, bias_dof, input_shape=input.shape
)
bias_orig = None
if spectral_bias is not None:
# Map spectral bias back to original basis: (..., D)
bias_orig = torch.einsum("ij,...j->...i", self.Q, spectral_bias)
# Apply scaling in original basis via Q * diag(scale_spec) * Q_inv. Output shape matches input (..., D).
y = torch.einsum("ij,...j,jk,...k->...i", self.Q, scale_spec, self.Q_inv, input)
if bias_orig is not None:
y = y + bias_orig
return y
def expand_affine(self) -> torch.Tensor:
"""Expand the per-irrep scales into an affine matrix in the original basis."""
scale_spec = self.scale_dof[self.irrep_indices]
affine = (self.Q * scale_spec) @ self.Q_inv
self._affine = affine
self._affine_cache_dirty = False
return affine
def broadcast_spectral_scale_and_bias(
self,
scale_dof: torch.Tensor,
bias_dof: torch.Tensor | None = None,
input_shape: torch.Size | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""Return spectral scale and bias from provided DoFs.
Args:
scale_dof: Per-irrep scale coefficients shaped ``(..., num_scale_dof)``.
bias_dof: Invariant-bias coefficients shaped ``(..., num_bias_dof)`` when ``bias=True``.
input_shape: Shape of the input tensor when ``learnable=False``. Used to validate DoF shapes.
Returns:
tuple[torch.Tensor, torch.Tensor | None]: Pair ``(spectral_scale, spectral_bias)`` where
``spectral_scale`` is shaped ``(..., rep_x.size)`` and ``spectral_bias`` is shaped
``(..., rep_x.size)`` (or ``None`` when ``bias=False``).
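
        Example:
            A small sketch (``affine`` is an :class:`eAffine` built with ``learnable=False``, as in
            the example of :meth:`forward`)::

                x = torch.randn(2, affine.in_rep.size)
                scale = torch.ones(2, affine.num_scale_dof)
                shift = torch.zeros(2, affine.num_bias_dof)
                s, b = affine.broadcast_spectral_scale_and_bias(scale, shift, input_shape=x.shape)
                # s: (2, rep_x.size); b: (2, rep_x.size), or None when bias=False.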
"""
input_shape = () if input_shape is None else input_shape[:-1]
if scale_dof is None or scale_dof.shape != (*input_shape, self._num_scale_dof):
raise ValueError(
f"Expected scale_dof shape {(*input_shape, self._num_scale_dof)}, got "
f"{scale_dof.shape if scale_dof is not None else None}"
)
# Broadcast scale per irrep subspace to each irrep subspace dimension.
spectral_scale = scale_dof[..., self.irrep_indices]
spectral_bias = None
if self.has_bias:
            # Use provided bias DoFs when passed (external control); otherwise fall back to the learnable helper.
            external_bias = bias_dof is not None
            if bias_dof is None and self.bias_module is not None:
                bias_dof = self.bias_module.bias_dof
            if bias_dof is None or bias_dof.shape != (*input_shape, self._num_bias_dof):
                raise ValueError(
                    f"Expected bias_dof shape {(*input_shape, self._num_bias_dof)}, got "
                    f"{bias_dof.shape if bias_dof is not None else None}"
                )
            if self.bias_module is not None and not external_bias:
                # Learnable bias: expand the module's own DoFs into the spectral basis.
                spectral_bias = self.bias_module.expand_bias_spectral_basis()
            else:
                # Scatter the (possibly batched) DoFs into the trivial-irrep coordinates.
                mask = (
                    self.trivial_subspace_mask
                    if self.bias_module is None
                    else self.bias_module.spectral_trivial_mask
                )
                spectral_bias = bias_dof.new_zeros(*bias_dof.shape[:-1], self.rep_x.size)
                spectral_bias[..., mask] = bias_dof
return spectral_scale, spectral_bias
def reset_parameters(self, scheme: Literal["identity", "random"] = "identity") -> None:
"""Initialize spectral scale/bias DoFs.
Args:
scheme: ``"identity"`` sets all scales to one and bias to zero; ``"random"`` samples both
uniformly in ``[-1, 1]``. Set to ``None`` when loading checkpoints to skip reinit.
"""
if not self.learnable:
return
if scheme == "identity":
torch.nn.init.ones_(self.scale_dof)
if self.has_bias and self.bias_module is not None and self.bias_module.bias_dof is not None:
torch.nn.init.zeros_(self.bias_module.bias_dof)
elif self.has_bias and self.bias_dof is not None:
torch.nn.init.zeros_(self.bias_dof)
elif scheme == "random":
torch.nn.init.uniform_(self.scale_dof, -1, 1)
if self.has_bias and self.bias_module is not None and self.bias_module.bias_dof is not None:
torch.nn.init.uniform_(self.bias_module.bias_dof, -1, 1)
elif self.has_bias and self.bias_dof is not None:
torch.nn.init.uniform_(self.bias_dof, -1, 1)
else:
raise NotImplementedError(f"Init scheme {scheme} not implemented")
self.invalidate_cache()
def extra_repr(self) -> str: # noqa: D102
return f"bias={self.has_bias} learnable={self.learnable} \nin_rep={self.in_rep}"
def _mark_affine_cache_dirty(self, grad: torch.Tensor) -> torch.Tensor:
self._affine_cache_dirty = True
return grad
def invalidate_cache(self) -> None:
"""Clear cached affine map so it is recomputed on next use."""
self._affine = None
self._affine_cache_dirty = True
if self.bias_module is not None:
self.bias_module.invalidate_cache()
def _refresh_eval_cache(self) -> None:
if not self.learnable:
return
if self._affine is None or self._affine_cache_dirty:
self.expand_affine()
if self.bias_module is not None:
self.bias_module.refresh_eval_cache()
def _apply(self, fn):
super()._apply(fn)
self.invalidate_cache()
return self
def train(self, mode: bool = True): # noqa: D102
"""Switch mode and keep cached affine expansion consistent."""
result = super().train(mode)
if mode: # Switching to train mode - invalidate cache
self.invalidate_cache()
else: # Switching to eval mode - refresh cache
self._refresh_eval_cache()
return result
def load_state_dict(self, state_dict, strict: bool = True): # noqa: D102
"""Load parameters and invalidate cached affine expansion."""
result = super().load_state_dict(state_dict, strict)
self.invalidate_cache()
return result
@property
def num_scale_dof(self) -> int:
"""Number of per-irrep scaling degrees of freedom (length of ``scale_dof``)."""
return self._num_scale_dof
@property
def num_bias_dof(self) -> int:
"""Number of bias degrees of freedom (length of ``bias_dof``) for invariant irreps."""
return self._num_bias_dof