Source code for symm_learning.nn.pooling
# Created by Daniel Ordoñez (daniels.ordonez@gmail.com) at 12/02/25
from __future__ import annotations
import torch
from escnn.group import Representation
from symm_learning.linalg import irrep_radii
from symm_learning.representation_theory import direct_sum
[docs]
class IrrepSubspaceNormPooling(torch.nn.Module):
    r"""Pool irrep features into :math:`\mathbb{G}`-invariant radii.

    Given :math:`\mathbf{x}\in\mathcal{X}` with representation :math:`\rho_{\mathcal{X}}`, the module computes one
    scalar per irreducible copy in the isotypic/irrep-spectral basis:

    .. math::
        r_{k,i} = \lVert \hat{\mathbf{x}}_{k,i} \rVert_2,
        \qquad
        \hat{\mathbf{x}}=\mathbf{Q}^T\mathbf{x}.

    This is exactly :func:`~symm_learning.linalg.irrep_radii`, exposed as a module. The output transforms under a
    direct sum of trivial representations, hence is invariant.

    Args:
        in_rep (:class:`~escnn.group.Representation`): Representation :math:`\rho_{\text{in}}` describing how the input
            last dimension transforms.

    Attributes:
        in_rep: The input representation given at construction.
        out_rep: Direct sum of ``len(in_rep.irreps)`` copies of the group's trivial representation, i.e. the
            representation of the pooled (invariant) output.
    """

    def __init__(self, in_rep: Representation):
        super().__init__()
        group = in_rep.group
        trivial = group.trivial_representation
        # Each entry of ``in_rep.irreps`` is one irreducible copy, and each copy yields one invariant radius.
        n_irrep_copies = len(in_rep.irreps)
        self.in_rep = in_rep
        self.out_rep = direct_sum([trivial] * n_irrep_copies)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the Euclidean norm of every irreducible block of ``x``.

        Args:
            x (:class:`~torch.Tensor`): Tensor whose trailing dimension equals ``in_rep.size``; any leading
                batch/time dims are accepted.

        Returns:
            :class:`~torch.Tensor`: Trivial (invariant) features with the same leading shape as ``x`` and last dim
            ``out_rep.size``, one Euclidean norm per irrep block.
        """
        expected_dim = self.in_rep.size
        assert x.shape[-1] == expected_dim, f"Expected input shape (..., {expected_dim}), but got {x.shape}"
        radii = irrep_radii(x, self.in_rep)
        # Consistency check: the helper must emit exactly one radius per copy declared in ``out_rep``.
        assert radii.shape[-1] == self.out_rep.size, f"{self.out_rep.size} != {radii.shape[-1]}"
        return radii

    def extra_repr(self) -> str:
        """Describe the pooled output size and the symmetry group in the module's printed repr."""
        n_out = self.out_rep.size
        group = self.in_rep.group
        return f"Irrep Norm Pooling output={n_out} {group}-invariant features"