# Source code for symm_learning.nn.distributions

from __future__ import annotations

import escnn
import torch
import torch.nn.functional as F
from escnn.group import directsum
from escnn.nn import EquivariantModule, FieldType


def _equiv_mean_var_from_input(
    input: torch.Tensor,
    idx: torch.Tensor,
    Q2_T: torch.Tensor,
    dim_y: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Extract mean and variance from the input tensor."""
    mu = input[..., :dim_y]  # (B, n)
    log_eigvals = input[..., dim_y:]  # (B, n_irreps)
    var_irrep_spectral_basis = torch.exp(log_eigvals[..., idx]) + 1e-6  # (B, n)
    var = var_irrep_spectral_basis @ Q2_T  # (B, n)
    return mu, var


class EquivMultivariateNormal(EquivariantModule):
    r"""G-equivariant multivariate normal.

    Utility layer to parameterize a G-equivariant multivariate gaussian/normal distribution.

    .. math::
        y \sim \mathcal{N} \bigl( \mu(x), \Sigma(x) \bigr),

    Where x is the input to the layer parameterizing the mean and (free) degrees of freedom
    of the covariance matrix, constrained to satisfy:

    .. math::
        \rho_Y(g) \mu(x) = \mu(\rho_X(g) \cdot x)

        \rho_Y(g) \Sigma(x) \rho_Y(g)^{\top} = \Sigma(\rho_X(g) x) \quad \forall g \in G.

    Such that:

    .. math::
        P(y \mid x) = P(\rho_Y(g) y \mid \rho_X(g) x) \quad \forall g \in G.

    The input of the layer is composed of the desired mean of the distribution and the
    log-variances of each irreducible subspace of the representation :math:`\rho_Y` of the
    output. The number of log-variances varies with the number of irreducible subspaces of
    the representation, hence this layer is meant to be instantiated before the
    `EquivariantModule` that will be used to parameterize the multivariate normal
    distribution. See the example below.

    Parameters
    ----------
    y_type : FieldType
        Field/feature type of *X* (the mean).
    diagonal : bool, default ``True``
        Only diagonal covariance matrices are implemented. Note these are not necessarily
        constant multiples of the identity.

    Example
    -------
    >>> from escnn.group import CyclicGroup
    >>> from symm_learning.models.emlp import EMLP
    >>> G = CyclicGroup(3)
    >>> x_type = FieldType(escnn.gspaces.no_base_space(G), representations=[G.regular_representation])
    >>> y_type = FieldType(escnn.gspaces.no_base_space(G), representations=[G.regular_representation] * 1)
    >>> e_normal = EquivMultivariateNormal(y_type, diagonal=True)
    >>> nn = EMLP(in_type=x_type, out_type=e_normal.in_type)
    >>> x = torch.randn(1, x_type.size)
    >>> dist = e_normal.get_distribution(nn(x_type(x)))
    >>> # Sample from the distribution
    >>> y = dist.sample()
    """

    def __init__(self, y_type: FieldType, diagonal: bool = True):
        super().__init__()
        self.y_type = y_type
        self.diagonal = diagonal
        rep_y = y_type.representation
        G = rep_y.group
        if not diagonal:
            raise NotImplementedError("Full covariance matrices are not implemented yet.")
        # ----- irrep metadata ------------------------------------------------
        # Dimension of each irreducible subspace of rho_Y.
        self.irrep_dims = torch.tensor([G.irrep(*irr).size for irr in rep_y.irreps], dtype=torch.long)
        # index vector that broadcasts irrep-scalars to component level
        idx = [i for i, d in enumerate(self.irrep_dims) for _ in range(d)]
        self.register_buffer("idx", torch.tensor(idx, dtype=torch.long))
        self.n_cov_params = len(rep_y.irreps)  # Number of params for the covariance matrix
        # ----- change-of-basis (irrep_spectral -> user) -----------------------------
        Q = torch.tensor(rep_y.change_of_basis, dtype=torch.get_default_dtype())
        self.register_buffer("Q2_T", (Q.pow(2)).t())  # (n, n) transposed
        # ----- Group action on the degrees of freedom of the Cov matrix ------------
        # Per-irrep log-variances are G-invariant scalars, hence trivial representations.
        self.rep_cov_dof = directsum([G.trivial_representation] * len(rep_y.irreps))
        self.in_type = FieldType(y_type.gspace, [rep_y, self.rep_cov_dof])
        self.out_type = self.y_type  # Not used.

    def forward(self, input):
        """Compute the mean and variance of a equivariant multivariate normal distribution

        Args:
            input (FieldType): Input tensor of shape (B, n + n_irreps) where: `B` is the
                batch size, `n` is the size of the output type (mean), `n_irreps` is the
                number of irreducible representations in the output type (covariance
                degrees of freedom)

        Returns:
            mu (torch.Tensor): Mean of the distribution of shape (B, n).
            var (torch.Tensor): Variance of the distribution of shape (B, n).
        """
        assert input.type == self.in_type, "Input type does not match the expected input type."
        # Extract the mean and covariance from the input
        if self.diagonal:
            mu, var = _equiv_mean_var_from_input(input.tensor, self.idx, self.Q2_T, self.y_type.size)
        else:
            raise NotImplementedError("Full covariance matrices are not implemented yet.")
        return mu, var

    def get_distribution(self, input):
        """Returns the MultivariateNormal distribution."""
        mu, var = self(input)
        # Diagonal covariance: embed the per-component variances on the diagonal.
        return torch.distributions.MultivariateNormal(mu, torch.diag_embed(var))

    def evaluate_output_shape(self, input_shape):
        """Output shape are vector of samples from the normal distribution"""
        return input_shape[0], self.y_type.size

    def check_equivariance(self, atol=1e-5, rtol=1e-5):
        """Check equivariance of the module."""
        B = 50
        # Generate random input
        input = torch.randn(B, self.in_type.size)
        y = torch.randn(B, self.y_type.size)
        prob_Gy = []
        for g in self.y_type.fibergroup.elements:
            # Transform the input
            g_input = self.in_type.transform_fibers(input, g)
            gy = self.y_type.transform_fibers(y, g)
            normal = self.get_distribution(self.in_type(g_input))
            prob_Gy.append(normal.log_prob(gy))
        prob_Gy = torch.stack(prob_Gy, dim=1)
        # Check that all probabilities are equal on group orbits
        assert torch.allclose(prob_Gy, prob_Gy.mean(dim=1, keepdim=True), atol=atol, rtol=rtol), (
            "Probabilities are not invariant on group orbits"
        )

    def export(self):
        """Exporting to a torch.nn.Module"""
        torch_e_normal = _EquivMultivariateNormal(
            idx=self.idx,
            Q2_T=self.Q2_T,
            dim_y=self.y_type.size,
            diagonal=self.diagonal,
        )
        torch_e_normal.eval()  # Set to eval mode
        return torch_e_normal
class _EquivMultivariateNormal(torch.nn.Module): """Utility class to export `EquivMultivariateNormal` to a standard PyTorch module.""" def __init__( self, idx: torch.Tensor, # shape (n,) – broadcast map irrep→component Q2_T: torch.Tensor, # shape (n, n) – (Q ** 2).T from escnn dim_y: int, # dim of the output space (= y_type.size) diagonal: bool = True, # only diagonal covariance matrices are implemented ): super().__init__() self.register_buffer("idx", idx.clone()) self.register_buffer("Q2_T", Q2_T.clone()) self.dim_y = dim_y self.diagonal = diagonal def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: """Compute the mean and variance of a multivariate normal distribution.""" if self.diagonal: mu, var = _equiv_mean_var_from_input(x, self.idx, self.Q2_T, self.dim_y) else: raise NotImplementedError("Full covariance matrices are not implemented yet.") return mu, var def get_distribution(self, x: torch.Tensor) -> torch.distributions.MultivariateNormal: """Returns the MultivariateNormal distribution.""" mu, var = self(x) return torch.distributions.MultivariateNormal(mu, torch.diag_embed(var)) def check_equivariance(self, in_type, y_type, atol=1e-5, rtol=1e-5): """Check equivariance of the module.""" B = 50 # Generate random input input = torch.randn(B, in_type.size) y = torch.randn(B, y_type.size) prob_Gy = [] for g in y_type.fibergroup.elements: # Transform the input g_input = in_type.transform_fibers(input, g) gy = y_type.transform_fibers(y, g) normal = self.get_distribution(g_input) prob_Gy.append(normal.log_prob(gy)) prob_Gy = torch.stack(prob_Gy, dim=1) # Check that all probabilities are equal on group orbits assert torch.allclose(prob_Gy, prob_Gy.mean(dim=1, keepdim=True), atol=atol, rtol=rtol), ( "Probabilities are not invariant on group orbits" ) if __name__ == "__main__": # Example usage from escnn.group import CyclicGroup, DihedralGroup, Icosahedral from torch.distributions import MultivariateNormal from symm_learning.models.emlp import 
EMLP G = CyclicGroup(3) x_type = FieldType(escnn.gspaces.no_base_space(G), representations=[G.regular_representation]) y_type = FieldType(escnn.gspaces.no_base_space(G), representations=[G.regular_representation] * 1) rep_x = x_type.representation G = rep_x.group rep_var = directsum([G.trivial_representation] * rep_x.size) e_normal = EquivMultivariateNormal(y_type, diagonal=True) nn = EMLP(in_type=x_type, out_type=e_normal.in_type) batch_size = 1 x = torch.randn(batch_size, x_type.size) y = torch.randn(batch_size, y_type.size) n_params = nn(x_type(x)) prob_Gx = [] for g in G.elements: gx = x_type.transform_fibers(x, g) gy = y_type.transform_fibers(y, g) out = nn(x_type(gx)) normal = e_normal.get_distribution(out) prob_Gx.append(normal.log_prob(gy)) prob_Gx = torch.stack(prob_Gx, dim=1) # Check that all probabilities are equal on group orbits assert torch.allclose(prob_Gx, prob_Gx.mean(dim=1, keepdim=True)), "Probabilities are not equal on group orbits" e_normal.check_equivariance(atol=1e-5, rtol=1e-5) torch_e_normal: _EquivMultivariateNormal = e_normal.export() torch_e_normal.check_equivariance(in_type=e_normal.in_type, y_type=y_type, atol=1e-5, rtol=1e-5)