Source code for symm_learning.models.imlp

# Created by Daniel Ordoñez (daniels.ordonez@gmail.com) at 12/02/25
from __future__ import annotations

from math import ceil

import escnn
import torch
from escnn.group import Representation
from escnn.nn import EquivariantModule, FieldType, GeometricTensor

from symm_learning.models.emlp import EMLP
from symm_learning.nn import IrrepSubspaceNormPooling


class IMLP(EquivariantModule):
    """G-Invariant Multi-Layer Perceptron.

    This module extracts G-invariant features from the input tensor. The input is first processed by an
    :class:`EMLP` module that extracts G-equivariant features. The EMLP output is then processed by an
    :class:`IrrepSubspaceNormPooling` module that computes the norm of the features in each G-stable subspace
    associated with an individual irreducible representation. The resulting G-invariant features can be processed
    by any NN architecture; the default implementation adds a single (unconstrained) linear layer projecting the
    invariant features to the desired output dimension.

    Args:
        in_type: Field type of the input geometric tensor.
        out_dim: Number of G-invariant output features.
        hidden_units: Hidden units of each EMLP layer. Must be a non-empty sized iterable. The last entry also
            determines the size of the equivariant feature space, which is rounded up to
            ``ceil(hidden_units[-1] / |G|)`` copies of the regular representation (at least one copy).
        activation: Name of the activation function, forwarded to :class:`EMLP`.
        bias: Whether the EMLP layers and the linear head use a bias term.
        hidden_rep: Optional representation for the hidden layers, forwarded to :class:`EMLP`.
    """

    def __init__(
        self,
        in_type: FieldType,
        out_dim: int,  # Number of G-invariant features to extract.
        hidden_units: list[int] = [128, 128, 128],  # noqa: B006  (copied below; the default is never mutated)
        activation: str = "ReLU",
        bias: bool = False,
        hidden_rep: Representation | None = None,
    ):
        super().__init__()
        assert hasattr(hidden_units, "__iter__") and hasattr(hidden_units, "__len__"), (
            "hidden_units must be a list of integers"
        )
        assert len(hidden_units) > 0, "At least one equivariant layer is required"
        # Work on a private copy so neither the shared default list nor a caller's list can be mutated downstream.
        hidden_units = list(hidden_units)

        self.G = in_type.fibergroup
        self.in_type = in_type

        # Equivariant feature space: ceil(hidden_units[-1] / |G|) copies of the regular representation, at least one.
        n_regular_fields = max(1, ceil(hidden_units[-1] / self.G.order()))
        equiv_out_type = FieldType(
            gspace=in_type.gspace,
            representations=[self.G.regular_representation] * n_regular_fields,
        )
        self.equiv_feature_extractor = EMLP(
            in_type=in_type,
            out_type=equiv_out_type,
            hidden_units=hidden_units,
            activation=activation,
            bias=bias,
            hidden_rep=hidden_rep,
        )
        # Norm of the features in each irrep subspace -> G-invariant features.
        self.inv_feature_extractor = IrrepSubspaceNormPooling(in_type=self.equiv_feature_extractor.out_type)
        # Unconstrained linear projection of the invariant features to `out_dim` outputs.
        self.head = torch.nn.Linear(
            in_features=self.inv_feature_extractor.out_type.size, out_features=out_dim, bias=bias
        )
        # Output is `out_dim` copies of the trivial representation (invariant scalars).
        self.out_type = FieldType(gspace=in_type.gspace, representations=[self.G.trivial_representation] * out_dim)

    def forward(self, x: GeometricTensor) -> GeometricTensor:
        """Forward pass of the G-invariant MLP.

        Args:
            x: Input geometric tensor, expected to be of field type ``self.in_type``.

        Returns:
            Geometric tensor of field type ``self.out_type`` (``out_dim`` trivial fields).
        """
        z = self.equiv_feature_extractor(x)  # G-equivariant features from the EMLP module.
        z_inv = self.inv_feature_extractor(z)  # G-invariant features (irrep-subspace norms).
        out = self.head(z_inv.tensor)  # Unconstrained linear head acts on the raw torch tensor.
        return self.out_type(out)

    def evaluate_output_shape(self, input_shape: tuple[int, ...]) -> tuple[int, ...]:
        """Return the output shape produced for an input of shape ``input_shape``.

        All leading dimensions are preserved; the last (feature) dimension becomes ``self.out_type.size``
        (i.e. ``out_dim``).
        """
        # `FieldType.size` is already an int (the feature dimension); taking `len()` of it would raise TypeError.
        return tuple(input_shape[:-1]) + (self.out_type.size,)

    def check_equivariance(self, atol: float = 1e-6, rtol: float = 1e-4) -> list[tuple[object, float]]:
        """Check the equivariance of the sub-modules and then of the whole model.

        The EMLP feature extractor and the norm-pooling module are checked first; the result of the parent
        :meth:`EquivariantModule.check_equivariance` on the full model is returned.

        Args:
            atol: Absolute tolerance forwarded to every check.
            rtol: Relative tolerance forwarded to every check.
        """
        self.equiv_feature_extractor.check_equivariance(atol=atol, rtol=rtol)
        self.inv_feature_extractor.check_equivariance(atol=atol, rtol=rtol)
        return super().check_equivariance(atol=atol, rtol=rtol)

    def export(self) -> torch.nn.Sequential:
        """Export the model to a :class:`torch.nn.Sequential`.

        The EMLP and norm-pooling modules are exported via :class:`escnn.nn.SequentialModule`, then the linear
        head is appended under the name ``"head"`` and the result is put in eval mode. Note that the head module
        is shared with this model (not copied), so calling ``eval()`` on the export also affects ``self.head``.
        """
        sequential: torch.nn.Sequential = escnn.nn.SequentialModule(
            self.equiv_feature_extractor,
            self.inv_feature_extractor,
        ).export()
        sequential.add_module("head", self.head)
        sequential.eval()
        return sequential