"""Equivariant probability distributions (``symm_learning.nn.distributions``)."""

from __future__ import annotations

import torch
from escnn.group import Representation
from torch.distributions import MultivariateNormal

from symm_learning.representation_theory import direct_sum


def _equiv_mean_var_from_input(
    input: torch.Tensor,
    idx: torch.Tensor,
    Q2_T: torch.Tensor,
    dim_y: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Extract mean and variance from the input tensor."""
    mu = input[..., :dim_y]  # (B, n)
    log_eigvals = input[..., dim_y:]  # (B, n_irreps)
    var_irrep_spectral_basis = torch.exp(log_eigvals[..., idx]) + 1e-6  # (B, n)
    var = var_irrep_spectral_basis @ Q2_T  # (B, n)
    return mu, var


class eMultivariateNormal(torch.nn.Module):
    r"""Conditional Gaussian with :math:`\mathbb{G}`-equivariant parameters.

    This module maps parameter vectors in an input representation space to a
    Gaussian over :math:`\mathcal{Y}` with representation :math:`\rho_{\mathcal{Y}}`:

    .. math::
        \mathbf{y}\,|\,\mathbf{u} \sim
        \mathcal{N}\!\left(\boldsymbol{\mu}(\mathbf{u}), \mathbf{\Sigma}(\mathbf{u})\right).

    The constraints are

    .. math::
        \boldsymbol{\mu}(\rho_{\mathrm{in}}(g)\mathbf{u})
        = \rho_{\mathcal{Y}}(g)\,\boldsymbol{\mu}(\mathbf{u}),
        \quad
        \mathbf{\Sigma}(\rho_{\mathrm{in}}(g)\mathbf{u})
        = \rho_{\mathcal{Y}}(g)\,\mathbf{\Sigma}(\mathbf{u})\,\rho_{\mathcal{Y}}(g)^T,
        \ \forall g \in \mathbb{G},

    implying orbit-wise density invariance

    .. math::
        p(\mathbf{y} \mid \mathbf{u})
        = p(\rho_{\mathcal{Y}}(g)\mathbf{y} \mid \rho_{\mathrm{in}}(g)\mathbf{u}).

    Implementation details:

    - The first ``out_rep.size`` coordinates of the input are interpreted as
      :math:`\boldsymbol{\mu}`.
    - Remaining coordinates are log-variances, one per irreducible copy in
      :math:`\rho_{\mathcal{Y}}`.
    - Only diagonal covariances are implemented. In the irrep-spectral basis,
      each irrep copy uses one scalar variance shared by all dimensions of
      that copy.

    Args:
        out_rep (:class:`~escnn.group.Representation`): Representation
            :math:`\rho_{\mathcal{Y}}` describing the output space :math:`\mathcal{Y}`.
        diagonal: Only diagonal covariance matrices are implemented. These are
            not necessarily constant multiples of identity. Default: ``True``.

    Attributes:
        in_rep: Input representation
            :math:`\rho_{\mathrm{in}}=\rho_{\mathcal{Y}}\oplus n_{\mathrm{irr}}\cdot\hat{\rho}_{\mathrm{triv}}`
            carrying mean and covariance DoFs.
        out_rep: Output representation :math:`\rho_{\mathcal{Y}}`.
        n_cov_params: Number of independent covariance parameters (equals the
            number of irreps in ``out_rep``).

    Example:
        >>> from escnn.group import CyclicGroup
        >>> from symm_learning.models.emlp import eMLP
        >>> G = CyclicGroup(3)
        >>> rep_x = G.regular_representation
        >>> rep_y = G.regular_representation
        >>> e_normal = eMultivariateNormal(out_rep=rep_y, diagonal=True)
        >>> # Create an eMLP that outputs mean + cov params
        >>> nn = eMLP(in_rep=rep_x, out_rep=e_normal.in_rep, hidden_units=[32])
        >>> x = torch.randn(1, rep_x.size)
        >>> dist = e_normal(nn(x))  # Returns torch.distributions.MultivariateNormal
        >>> y = dist.sample()  # Sample from the distribution
    """

    def __init__(self, out_rep: Representation, diagonal: bool = True):
        super().__init__()
        if not diagonal:
            raise NotImplementedError("Full covariance matrices are not implemented yet.")
        self.diagonal = diagonal
        self.out_rep = out_rep
        G = out_rep.group
        # ----- irrep metadata ------------------------------------------------
        self.irrep_dims = torch.tensor([G.irrep(*irr).size for irr in out_rep.irreps], dtype=torch.long)
        # Index vector that broadcasts one scalar per irrep copy to every
        # component dimension of that copy.
        idx = [i for i, d in enumerate(self.irrep_dims) for _ in range(d)]
        self.register_buffer("idx", torch.tensor(idx, dtype=torch.long))
        self.n_cov_params = len(out_rep.irreps)  # Number of params for the covariance matrix
        # ----- change-of-basis (irrep_spectral -> user) ----------------------
        Q = torch.tensor(out_rep.change_of_basis, dtype=torch.get_default_dtype())
        # Element-wise squared change of basis maps diagonal variances from the
        # irrep-spectral basis back to the user basis.
        self.register_buffer("Q2_T", (Q.pow(2)).t())  # (n, n) transposed
        # ----- Group action on the degrees of freedom of the Cov matrix ------
        rep_cov_dof = direct_sum([G.trivial_representation] * len(out_rep.irreps))
        self.in_rep = direct_sum([out_rep, rep_cov_dof])

    def forward(self, input: torch.Tensor) -> MultivariateNormal:
        r"""Build :class:`~torch.distributions.multivariate_normal.MultivariateNormal` from equivariant DoFs.

        Args:
            input (:class:`~torch.Tensor`): Tensor of shape ``(..., in_rep.size)``
                containing the mean and log-variance parameters. The first
                ``out_rep.size`` elements are the mean, and the remaining
                ``n_cov_params`` elements are the log-variances.

        Returns:
            :class:`~torch.distributions.multivariate_normal.MultivariateNormal`:
            Gaussian with mean in :math:`\mathcal{Y}` and diagonal covariance
            satisfying the constraints described in :class:`eMultivariateNormal`.

        Raises:
            ValueError: If the last dimension of ``input`` does not match ``in_rep.size``.
        """
        if input.shape[-1] != self.in_rep.size:
            raise ValueError(f"Expected last dimension {self.in_rep.size}, got {input.shape[-1]}")
        if self.diagonal:
            mu, var = _equiv_mean_var_from_input(input, self.idx, self.Q2_T, self.out_rep.size)
        else:
            raise NotImplementedError("Full covariance matrices are not implemented yet.")
        # The Cholesky factor of a diagonal covariance is diag(sqrt(var))
        # (var > 0 by construction). Passing scale_tril defines the identical
        # distribution while skipping the dense Cholesky factorization that
        # MultivariateNormal performs when given a covariance matrix.
        return MultivariateNormal(mu, scale_tril=torch.diag_embed(var.sqrt()))

    def check_equivariance(self, atol: float = 1e-5, rtol: float = 1e-5) -> None:  # noqa: D301
        r"""Verify that the distribution satisfies the equivariance constraint.

        Checks :math:`p(\mathbf{y} \mid \mathbf{u}) =
        p(\rho_{\mathcal{Y}}(g)\mathbf{y} \mid \rho_{\mathrm{in}}(g)\mathbf{u})`
        for sampled group elements.

        Args:
            atol (:class:`float`): Absolute tolerance for the equivariance check.
            rtol (:class:`float`): Relative tolerance for the equivariance check.

        Raises:
            AssertionError: If the distribution is not equivariant within the
                given tolerances.
        """
        B = 50
        G = self.out_rep.group
        # Random parameter vectors and query points for the density check.
        input = torch.randn(B, self.in_rep.size)
        y = torch.randn(B, self.out_rep.size)
        prob_Gy = []
        for g in G.elements:
            # Transform input: x -> rho_in(g) x
            rho_in_g = torch.tensor(self.in_rep(g), dtype=torch.get_default_dtype())
            g_input = input @ rho_in_g.T
            # Transform output: y -> rho_out(g) y
            rho_out_g = torch.tensor(self.out_rep(g), dtype=torch.get_default_dtype())
            gy = y @ rho_out_g.T
            normal = self(g_input)
            prob_Gy.append(normal.log_prob(gy))
        prob_Gy = torch.stack(prob_Gy, dim=1)
        # All log-probabilities must agree along each group orbit.
        assert torch.allclose(prob_Gy, prob_Gy.mean(dim=1, keepdim=True), atol=atol, rtol=rtol), (
            "Probabilities are not invariant on group orbits"
        )
if __name__ == "__main__":
    # Example usage: verify end-to-end that the predicted density is invariant
    # on group orbits when the parameters come from an equivariant MLP.
    from escnn.group import CyclicGroup

    from symm_learning.models.emlp import eMLP

    G = CyclicGroup(3)
    rep_x = G.regular_representation
    rep_y = G.regular_representation

    e_normal = eMultivariateNormal(out_rep=rep_y, diagonal=True)
    nn = eMLP(in_rep=rep_x, out_rep=e_normal.in_rep, hidden_units=[32])

    batch_size = 1
    x = torch.randn(batch_size, rep_x.size)
    y = torch.randn(batch_size, rep_y.size)

    # NOTE: the original computed `params = nn(x)` here but never used it;
    # the dead forward pass has been removed.
    dtype = torch.get_default_dtype()  # hoisted: loop-invariant
    prob_Gx = []
    for g in G.elements:
        # Transform input: x -> rho_x(g) x
        gx = x @ torch.tensor(rep_x(g), dtype=dtype).T
        # Transform output: y -> rho_y(g) y
        gy = y @ torch.tensor(rep_y(g), dtype=dtype).T
        normal = e_normal(nn(gx))
        prob_Gx.append(normal.log_prob(gy))
    prob_Gx = torch.stack(prob_Gx, dim=1)

    # All log-probabilities must agree along each group orbit.
    assert torch.allclose(prob_Gx, prob_Gx.mean(dim=1, keepdim=True)), "Probabilities are not equal on group orbits"

    e_normal.check_equivariance(atol=1e-5, rtol=1e-5)
    print("Equivariance check passed!")