"""Equivariant (eMLP), invariant (iMLP) and unconstrained (MLP) multi-layer perceptron models."""

from __future__ import annotations

import logging
from math import ceil

import torch
from escnn.group import Representation

from symm_learning.nn.linear import eLinear
from symm_learning.nn.pooling import IrrepSubspaceNormPooling
from symm_learning.representation_theory import direct_sum

logger = logging.getLogger(__name__)


class eMLP(torch.nn.Module):
    r"""Equivariant MLP composed of :class:`~symm_learning.nn.linear.eLinear` layers.

    The network preserves the action of the underlying group on every layer by
    constructing hidden representations from the group regular representation
    (or a user-provided base representation) repeated as needed to reach the
    requested width. The network defines:

    .. math::
        \mathbf{f}_{\mathbf{\theta}}: \mathcal{X} \to \mathcal{Y}.

    Functional equivariance constraint:

    .. math::
        \mathbf{f}_{\mathbf{\theta}}(\rho_{\mathcal{X}}(g)\mathbf{x}) =
        \rho_{\mathcal{Y}}(g)\mathbf{f}_{\mathbf{\theta}}(\mathbf{x})
        \quad \forall g\in\mathbb{G}.
    """

    def __init__(
        self,
        in_rep: Representation,
        out_rep: Representation,
        hidden_units: list[int],
        activation: torch.nn.Module = torch.nn.ReLU(),
        dropout: float = 0.0,
        bias: bool = True,
        hidden_rep: Representation | None = None,
        init_scheme: str | None = "xavier_normal",
    ) -> None:
        r"""Create an equivariant MLP.

        Args:
            in_rep (:class:`~escnn.group.Representation`): Input representation
                :math:`\rho_{\text{in}}` defining the group action on the input.
            out_rep (:class:`~escnn.group.Representation`): Output representation
                :math:`\rho_{\text{out}}`; must belong to the same group as ``in_rep``.
            hidden_units: Target dimensionality of each hidden layer. Each hidden
                representation is built by repeating the base representation enough
                times to reach (at least) the requested width.
            activation: Non-linearity inserted after every hidden layer. The same
                module instance is shared by all layers (fine for stateless
                activations such as ``ReLU``).
            dropout: Dropout probability applied after activations; ``0.0`` disables it.
            bias: Whether to include a bias term in equivariant linear layers.
            hidden_rep (:class:`~escnn.group.Representation`, optional): Base
                representation used to build hidden layers. Defaults to the regular
                representation when ``None``.
            init_scheme: Parameter initialization scheme passed to
                :class:`~symm_learning.nn.linear.eLinear`; ``None`` skips
                reinitialization.

        Raises:
            ValueError: If ``hidden_units`` is empty, the representations belong to
                different groups, or ``dropout`` is outside ``[0, 1]``.
            TypeError: If ``activation`` is not a :class:`torch.nn.Module`.
        """
        super().__init__()
        # Validate with explicit raises (asserts are stripped under `python -O`).
        if len(hidden_units) == 0:
            raise ValueError("hidden_units must contain at least one layer")
        if in_rep.group != out_rep.group:
            raise ValueError("Input and output representations must belong to the same group")
        if not isinstance(activation, torch.nn.Module):
            raise TypeError(f"activation must be a torch.nn.Module, got {type(activation)}")
        drop_value = float(dropout)
        if not 0.0 <= drop_value <= 1.0:
            raise ValueError(f"dropout must be within [0, 1], got {drop_value}")

        G = in_rep.group
        self.in_rep, self.out_rep = in_rep, out_rep

        # Explicit `is None` test: `hidden_rep or ...` would silently discard a
        # user-provided representation that happens to evaluate falsy.
        base_hidden_rep = G.regular_representation if hidden_rep is None else hidden_rep
        if base_hidden_rep.group != G:
            raise ValueError("hidden_rep must belong to the same group as in_rep")

        self.hidden_specs = []  # Representation used at each hidden layer.
        layers: list[torch.nn.Module] = []
        prev_rep = in_rep
        for requested_dim in hidden_units:
            target_rep = _hidden_representation(base_hidden_rep, requested_dim)
            layers.append(eLinear(prev_rep, target_rep, bias=bias, init_scheme=init_scheme))
            layers.append(activation)
            if drop_value > 0:
                layers.append(torch.nn.Dropout(drop_value))
            prev_rep = target_rep
            self.hidden_specs.append(target_rep)
        # Output projection to the requested output representation.
        layers.append(eLinear(prev_rep, out_rep, bias=bias, init_scheme=init_scheme))
        self.net = torch.nn.Sequential(*layers)

        if init_scheme is not None:
            self.reset_parameters(init_scheme)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the equivariant MLP to ``x`` preserving group structure.

        Args:
            x: Tensor with trailing dimension matching ``in_rep.size``.

        Returns:
            Tensor with trailing dimension ``out_rep.size``.

        Raises:
            ValueError: If the trailing dimension of ``x`` does not match ``in_rep.size``.
        """
        if x.shape[-1] != self.in_rep.size:
            raise ValueError(f"Expected (..., {self.in_rep.size}), got {x.shape}")
        return self.net(x)

    @torch.no_grad()
    def reset_parameters(self, scheme: str = "xavier_normal") -> None:
        """Reinitialize all :class:`~symm_learning.nn.linear.eLinear` layers with the provided scheme."""
        for module in self.net:
            if isinstance(module, eLinear):
                module.reset_parameters(scheme)
        # Lazy %-formatting avoids building the message when DEBUG is disabled.
        logger.debug("Initialized eMLP with scheme '%s'", scheme)
class iMLP(torch.nn.Module):
    r"""Invariant MLP built from an equivariant backbone and invariant pooling.

    The network defines:

    .. math::
        \mathbf{f}_{\mathbf{\theta}}: \mathcal{X} \to \mathcal{Y}^{\text{inv}}.

    Functional invariance constraint:

    .. math::
        \mathbf{f}_{\mathbf{\theta}}(\rho_{\mathcal{X}}(g)\mathbf{x}) =
        \mathbf{f}_{\mathbf{\theta}}(\mathbf{x}) \quad \forall g\in\mathbb{G}.
    """

    def __init__(
        self,
        in_rep: Representation,
        out_dim: int,
        hidden_units: list[int],
        activation: torch.nn.Module = torch.nn.ReLU(),
        dropout: float = 0.0,
        bias: bool = True,
        hidden_rep: Representation | None = None,
        init_scheme: str | None = "xavier_normal",
    ):
        r"""Create a group-invariant MLP.

        The model first applies an equivariant MLP to extract group-aware
        features, pools them into the trivial representation, and finishes with
        an unconstrained linear head to produce invariant outputs.

        Args:
            in_rep (:class:`~escnn.group.Representation`): Input representation
                :math:`\rho_{\text{in}}` defining the group action on the input.
            out_dim: Dimension of the invariant output vector.
            hidden_units: Width of each hidden layer in the equivariant backbone.
            activation: Non-linearity inserted after every hidden layer and after
                the backbone.
            dropout: Dropout probability applied after backbone activations.
            bias: Whether to include biases in the backbone and head.
            hidden_rep: Base representation used to build hidden layers. Defaults
                to the regular representation when ``None``.
            init_scheme: Parameter initialization scheme passed to
                :class:`~symm_learning.nn.linear.eLinear`; ``None`` skips
                reinitialization.

        Raises:
            TypeError: If ``hidden_units`` is not a list.
            ValueError: If ``hidden_units`` is empty.
        """
        super().__init__()
        # Validate with explicit raises (asserts are stripped under `python -O`).
        if not isinstance(hidden_units, list):
            raise TypeError(f"hidden_units must be a non-empty list, got {hidden_units}")
        if len(hidden_units) == 0:
            raise ValueError(f"hidden_units must be a non-empty list, got {hidden_units}")

        self.in_rep = in_rep
        # Invariant outputs transform trivially: one trivial irrep per output dim.
        self.out_rep = direct_sum([in_rep.group.trivial_representation] * out_dim)
        G = in_rep.group

        # Build the equivariant feature extractor (eMLP). The backbone's output
        # representation reuses the last hidden width.
        last_dim = hidden_units[-1]
        # Explicit `is None` test: `hidden_rep or ...` would silently discard a
        # user-provided representation that happens to evaluate falsy.
        base_hidden_rep = G.regular_representation if hidden_rep is None else hidden_rep
        out_rep = _hidden_representation(base_hidden_rep, last_dim)
        self.emlp_backbone = eMLP(
            in_rep=in_rep,
            out_rep=out_rep,
            hidden_units=hidden_units,
            activation=activation,
            dropout=dropout,
            bias=bias,
            hidden_rep=hidden_rep,
            # Defer initialization: reset_parameters below initializes backbone + head.
            init_scheme=None,
        )
        # G-invariant pooling collapses each irrep subspace to a norm (trivial rep).
        inv_pooling = IrrepSubspaceNormPooling(in_rep=out_rep)
        # Unconstrained head acting on invariant features only.
        self.head = torch.nn.Linear(
            in_features=inv_pooling.out_rep.size,
            out_features=out_dim,
            bias=bias,
        )
        # Network: [emlp -> activation -> inv pooling -> head]
        self.net = torch.nn.Sequential(*self.emlp_backbone.net, activation, inv_pooling, self.head)

        if init_scheme is not None:
            self.reset_parameters(init_scheme)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute invariant outputs from the input representation values.

        Raises:
            ValueError: If the trailing dimension of ``x`` does not match ``in_rep.size``.
        """
        if x.shape[-1] != self.in_rep.size:
            raise ValueError(f"Expected (..., {self.in_rep.size}), got {x.shape}")
        return self.net(x)

    @torch.no_grad()
    def reset_parameters(self, scheme: str = "xavier_normal") -> None:
        """Reinitialize all :class:`~symm_learning.nn.linear.eLinear` layers with the provided scheme."""
        self.emlp_backbone.reset_parameters(scheme)
        # Initialize the unconstrained head. Unknown schemes fall through and
        # leave the head weights at torch.nn.Linear's default initialization;
        # the head bias is likewise never reinitialized here.
        if scheme == "xavier_normal":
            torch.nn.init.xavier_normal_(self.head.weight)
        elif scheme == "xavier_uniform":
            torch.nn.init.xavier_uniform_(self.head.weight)
        elif scheme == "kaiming_normal":
            torch.nn.init.kaiming_normal_(self.head.weight, nonlinearity="linear")
        elif scheme == "kaiming_uniform":
            torch.nn.init.kaiming_uniform_(self.head.weight, nonlinearity="linear")
        # Lazy %-formatting avoids building the message when DEBUG is disabled.
        logger.debug("Initialized iMLP head with scheme '%s'", scheme)
def _hidden_representation(base: Representation, target_dim: int) -> Representation:
    """Repeat ``base`` enough times (at least once) to span ``target_dim`` dimensions."""
    full_copies, remainder = divmod(target_dim, base.size)
    copies = full_copies + (1 if remainder else 0)
    return direct_sum([base] * max(1, copies))
[docs] class MLP(torch.nn.Module): r"""Standard baseline MLP with no symmetry constraints. The network defines: .. math:: \mathbf{f}_{\mathbf{\theta}}: \mathbb{R}^{d_{\mathrm{in}}} \to \mathbb{R}^{d_{\mathrm{out}}}. No equivariance or invariance constraints are imposed. """ def __init__( self, in_dim: int, out_dim: int, hidden_units: list[int], activation: torch.nn.Module | list[torch.nn.Module] = torch.nn.ReLU(), batch_norm: bool = False, bias: bool = True, ): """Constructor of a Multi-Layer Perceptron (MLP) model. Args: in_dim: Dimension of the input space. out_dim: Dimension of the output space. hidden_units: List of number of units in each hidden layer. activation: Activation module or list of activation modules. batch_norm: Whether to include batch normalization. bias: Whether to include a bias term in the linear layers. """ super().__init__() self.in_dim, self.out_dim = in_dim, out_dim assert hasattr(hidden_units, "__iter__") and hasattr(hidden_units, "__len__"), ( "hidden_units must be a list of integers" ) assert len(hidden_units) > 0, "A MLP with 0 hidden layers is equivalent to a linear layer" # Handle activation modules if isinstance(activation, list): assert len(activation) == len(hidden_units), ( "List of activation modules must have the same length as the number of hidden layers" ) activations = activation else: activations = [activation] * len(hidden_units) layers = [] dim_in = in_dim for units, act_module in zip(hidden_units, activations): layers.append(torch.nn.Linear(dim_in, units, bias=bias)) if batch_norm: layers.append(torch.nn.BatchNorm1d(units)) layers.append(act_module) dim_in = units # Head layer (output layer) layers.append(torch.nn.Linear(dim_in, out_dim, bias=bias)) self.net = torch.nn.Sequential(*layers)
[docs] def forward(self, input: torch.Tensor) -> torch.Tensor: """Forward pass of the MLP model.""" output = self.net(input) return output