Source code for symm_learning.nn.disentangled
# Created by Daniel Ordoñez (daniels.ordonez@gmail.com) at 12/02/25
from __future__ import annotations
import torch
from escnn.group import Representation
from symm_learning.representation_theory import direct_sum, isotypic_decomp_rep
[docs]
class Change2DisentangledBasis(torch.nn.Module):
r"""Map features to the isotypic/irrep-spectral basis.
For :math:`\mathbf{x}\in\mathcal{X}` with representation :math:`\rho_{\mathcal{X}}`, this module applies
:math:`\mathbf{Q}^{-1}` from :func:`~symm_learning.representation_theory.isotypic_decomp_rep`:
.. math::
\hat{\mathbf{x}} = \mathbf{Q}^{-1}\mathbf{x},
\qquad
\rho_{\mathcal{X}} = \mathbf{Q}\left(
\bigoplus_{k\in[1,n_{\text{iso}}]}
\bigoplus_{i\in[1,n_k]}
\hat{\rho}_k
\right)\mathbf{Q}^T.
Hence, coordinates in ``out_rep`` are grouped by isotypic subspace (same irrep type contiguous).
The map is linear and :math:`\mathbb{G}`-equivariant:
.. math::
\hat{\rho}_{\mathcal{X}}(g)\,\hat{\mathbf{x}}
= \mathbf{Q}^{-1}\rho_{\mathcal{X}}(g)\mathbf{x},
\quad
\hat{\rho}_{\mathcal{X}}(g) = \mathbf{Q}^{-1}\rho_{\mathcal{X}}(g)\mathbf{Q}.
Args:
in_rep (:class:`~escnn.group.Representation`): Representation :math:`\rho_{\text{in}}` describing the input
feature space.
learnable (:class:`bool`, optional): If ``True``, the change-of-basis matrix is a trainable parameter.
Defaults to ``False``.
"""
def __init__(self, in_rep: Representation, learnable: bool = False):
super().__init__()
in_rep_iso_basis = isotypic_decomp_rep(in_rep)
iso_subspaces_reps = in_rep_iso_basis.attributes["isotypic_reps"]
self.in_rep, self.out_rep = in_rep, direct_sum(list(iso_subspaces_reps.values()))
Qin2iso = torch.as_tensor(in_rep_iso_basis.change_of_basis_inv, dtype=torch.get_default_dtype())
identity = torch.eye(Qin2iso.shape[-1], device=Qin2iso.device, dtype=Qin2iso.dtype)
self._is_in_iso_basis = torch.allclose(Qin2iso, identity, atol=1e-5, rtol=1e-5)
self._learnable = learnable
if self._learnable:
self.Qin2iso = torch.nn.Linear(in_features=self.in_rep.size, out_features=self.out_rep.size, bias=False)
with torch.no_grad():
self.Qin2iso.weight.copy_(Qin2iso)
else:
self.register_buffer("Qin2iso", Qin2iso)
[docs]
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Apply the basis change to isotypic coordinates.
Args:
x (:class:`~torch.Tensor`): Input whose last dimension equals ``in_rep.size``; arbitrary leading dimensions
allowed.
Returns:
:class:`~torch.Tensor`: Tensor with the same leading shape and last dimension ``out_rep.size``
(same as ``in_rep``), expressed in the isotypic basis. If the input is already in that basis,
the tensor is returned unchanged.
"""
assert x.shape[-1] == self.in_rep.size, f"Expected input shape (..., {self.in_rep.size}), got {x.shape}"
if self._is_in_iso_basis:
return x
if self._learnable:
return self.Qin2iso(x)
Q = self.Qin2iso.to(dtype=x.dtype, device=x.device)
return torch.matmul(x, Q.t())
def extra_repr(self) -> str: # noqa: D102
return f"learnable: {self._learnable}"