eMultivariateNormal

class eMultivariateNormal(out_rep, diagonal=True)

Bases: Module

Conditional Gaussian with \(\mathbb{G}\)-equivariant parameters.

This module maps parameter vectors in an input representation space to a Gaussian over \(\mathcal{Y}\) with representation \(\rho_{\mathcal{Y}}\):

\[\mathbf{y}\,|\,\mathbf{u} \sim \mathcal{N}\!\left(\boldsymbol{\mu}(\mathbf{u}),\mathbf{\Sigma}(\mathbf{u})\right).\]

The constraints are

\[\boldsymbol{\mu}(\rho_{\mathrm{in}}(g)\mathbf{u}) = \rho_{\mathcal{Y}}(g)\,\boldsymbol{\mu}(\mathbf{u}), \quad \mathbf{\Sigma}(\rho_{\mathrm{in}}(g)\mathbf{u}) = \rho_{\mathcal{Y}}(g)\,\mathbf{\Sigma}(\mathbf{u})\,\rho_{\mathcal{Y}}(g)^T, \ \forall g\in\mathbb{G},\]

implying orbit-wise density invariance

\[p(\mathbf{y}\mid\mathbf{u}) = p(\rho_{\mathcal{Y}}(g)\mathbf{y}\mid \rho_{\mathrm{in}}(g)\mathbf{u}).\]
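A short derivation of this invariance may help. It assumes \(|\det\rho_{\mathcal{Y}}(g)| = 1\), which holds for orthogonal representations (an assumption, since orthogonality is not stated above). Writing \(R = \rho_{\mathcal{Y}}(g)\) and \(\mathbf{u}' = \rho_{\mathrm{in}}(g)\mathbf{u}\), the constraints give \(\boldsymbol{\mu}(\mathbf{u}') = R\,\boldsymbol{\mu}(\mathbf{u})\) and \(\mathbf{\Sigma}(\mathbf{u}') = R\,\mathbf{\Sigma}(\mathbf{u})\,R^T\), so the quadratic form in the Gaussian exponent is unchanged:

\[\left(R\mathbf{y} - \boldsymbol{\mu}(\mathbf{u}')\right)^T \mathbf{\Sigma}(\mathbf{u}')^{-1} \left(R\mathbf{y} - \boldsymbol{\mu}(\mathbf{u}')\right) = \left(\mathbf{y} - \boldsymbol{\mu}(\mathbf{u})\right)^T R^T \left(R\,\mathbf{\Sigma}(\mathbf{u})\,R^T\right)^{-1} R \left(\mathbf{y} - \boldsymbol{\mu}(\mathbf{u})\right) = \left(\mathbf{y} - \boldsymbol{\mu}(\mathbf{u})\right)^T \mathbf{\Sigma}(\mathbf{u})^{-1} \left(\mathbf{y} - \boldsymbol{\mu}(\mathbf{u})\right).\]

The normalizing constant is unchanged as well:

\[\det \mathbf{\Sigma}(\mathbf{u}') = \det(R)^2 \det \mathbf{\Sigma}(\mathbf{u}) = \det \mathbf{\Sigma}(\mathbf{u}).\]

Both factors of the Gaussian density therefore agree at \((\mathbf{y}, \mathbf{u})\) and \((R\mathbf{y}, \mathbf{u}')\).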

Implementation details:

  • The first out_rep.size coordinates of the input are interpreted as \(\boldsymbol{\mu}\).

  • Remaining coordinates are log-variances, one per irreducible copy in \(\rho_{\mathcal{Y}}\).

  • Only diagonal covariances are implemented. In the irrep-spectral basis, each irrep copy uses one scalar variance shared by all dimensions of that copy.
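The parameter layout described above can be sketched in plain Python. This is an illustration, not the module's actual implementation: the helper `split_gaussian_params` and the `irrep_dims` list (the dimension of each irrep copy in \(\rho_{\mathcal{Y}}\)) are hypothetical names introduced here, and the change of basis between the irrep-spectral basis and the original basis is omitted.

```python
import math


def split_gaussian_params(u, irrep_dims):
    """Split a flat parameter vector into a mean and per-dimension variances.

    u          : list of floats laid out as [mean | log-variances]
    irrep_dims : hypothetical list holding the dimension of each irrep copy
    """
    out_size = sum(irrep_dims)        # plays the role of out_rep.size
    n_cov_params = len(irrep_dims)    # one log-variance per irrep copy
    if len(u) != out_size + n_cov_params:
        raise ValueError(f"expected {out_size + n_cov_params} values, got {len(u)}")

    mean = u[:out_size]
    log_vars = u[out_size:]

    # Each irrep copy shares a single variance across all of its dimensions,
    # giving the diagonal of the covariance in the irrep-spectral basis.
    variances = []
    for dim, log_var in zip(irrep_dims, log_vars):
        variances.extend([math.exp(log_var)] * dim)
    return mean, variances


# One 1D irrep copy and one 2D irrep copy: 3 mean entries + 2 log-variances.
mean, variances = split_gaussian_params(
    [0.5, -1.0, 2.0, 0.0, math.log(4.0)], irrep_dims=[1, 2]
)
```

Here the two dimensions of the 2D copy receive the same variance, so the covariance has only two free parameters even though the output is three-dimensional.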

Parameters:
  • out_rep (Representation) – Representation \(\rho_{\mathcal{Y}}\) describing the output space \(\mathcal{Y}\).

  • diagonal (bool) – Whether the covariance is diagonal. Only diagonal covariance matrices are currently implemented; they are not necessarily constant multiples of the identity. Default: True.

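in_rep

Input representation \(\rho_{\mathrm{in}}=\rho_{\mathcal{Y}}\oplus n_{\mathrm{irr}}\cdot\hat{\rho}_{\mathrm{triv}}\) carrying the mean and covariance degrees of freedom (DoFs), where \(n_{\mathrm{irr}}\) is the number of irreps in \(\rho_{\mathcal{Y}}\) (n_cov_params) and \(\hat{\rho}_{\mathrm{triv}}\) is the trivial representation.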

out_rep

Output representation \(\rho_{\mathcal{Y}}\).

n_cov_params

Number of independent covariance parameters (equals the number of irreps in out_rep).
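As a worked count, consider the regular representation of the cyclic group of order 3 and assume it decomposes over the reals into the trivial irrep plus one two-dimensional irrep (this decomposition is an assumption about the underlying group library, not something stated on this page). Then

\[n_{\mathrm{irr}} = 2, \qquad \dim\rho_{\mathrm{in}} = \dim\rho_{\mathcal{Y}} + n_{\mathrm{irr}} = 3 + 2 = 5,\]

so the module would expect 5 input coordinates: 3 for the mean and 2 log-variances.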

Example

>>> import torch
>>> from escnn.group import CyclicGroup
>>> from symm_learning.models.emlp import eMLP
>>> # eMultivariateNormal is assumed to be in scope; its import path is not shown here
>>> G = CyclicGroup(3)
>>> rep_x = G.regular_representation
>>> rep_y = G.regular_representation
>>> e_normal = eMultivariateNormal(out_rep=rep_y, diagonal=True)
>>> # Create an eMLP that outputs the mean and covariance parameters
>>> nn = eMLP(in_rep=rep_x, out_rep=e_normal.in_rep, hidden_units=[32])
>>> x = torch.randn(1, rep_x.size)
>>> dist = e_normal(nn(x))  # Returns torch.distributions.MultivariateNormal
>>> y = dist.sample()  # Sample from the distribution

check_equivariance(atol=1e-05, rtol=1e-05)

Verify that the distribution satisfies the equivariance constraint.

Checks \(p(\mathbf{y} \mid \mathbf{u}) = p(\rho_{\mathcal{Y}}(g)\mathbf{y} \mid \rho_{\mathrm{in}}(g)\mathbf{u})\) for sampled group elements.

Parameters:
  • atol (float) – Absolute tolerance for the equivariance check.

  • rtol (float) – Relative tolerance for the equivariance check.

Raises:

AssertionError – If the distribution is not equivariant within the given tolerances.

Return type:

None
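The property being checked can be illustrated without the library. The sketch below is not the method's implementation; it works directly in an irrep-spectral basis with one trivial irrep and one two-dimensional rotation irrep (an assumed layout), and all helper names are hypothetical. The log-variance coordinates transform trivially, so the variances are left untouched by the group action.

```python
import math


def act(v, angle):
    """Group action on a 3-vector laid out as [trivial irrep | 2D rotation irrep]."""
    c, s = math.cos(angle), math.sin(angle)
    return [v[0], c * v[1] - s * v[2], s * v[1] + c * v[2]]


def diag_gaussian_log_prob(y, mean, variances):
    """Log-density of a Gaussian with diagonal covariance."""
    return sum(
        -0.5 * (math.log(2 * math.pi * var) + (yi - mi) ** 2 / var)
        for yi, mi, var in zip(y, mean, variances)
    )


def invariance_gap(y, mean, variances, angle):
    """|log p(y | mean) - log p(g.y | g.mean)| for the group element `angle`."""
    before = diag_gaussian_log_prob(y, mean, variances)
    after = diag_gaussian_log_prob(act(y, angle), act(mean, angle), variances)
    return abs(before - after)


angle = 2 * math.pi / 3          # generator of the cyclic group of order 3
y = [0.3, -1.2, 0.7]
mean = [0.1, 0.4, -0.5]

# One variance per irrep copy: the 2D copy shares a single value.
gap_shared = invariance_gap(y, mean, [0.5, 2.0, 2.0], angle)

# Different variances inside the 2D copy break the invariance.
gap_unshared = invariance_gap(y, mean, [0.5, 2.0, 3.0], angle)
```

With a shared variance the gap is at the level of floating-point round-off, while the unshared case gives a clearly nonzero gap. This is why each irrep copy uses one scalar variance across all of its dimensions.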

forward(input)

Build a MultivariateNormal from the equivariant degrees of freedom (DoFs).

Parameters:

input (Tensor) – Tensor of shape (..., in_rep.size) containing the mean and log-variance parameters. Along the last dimension, the first out_rep.size elements are the mean and the remaining n_cov_params elements are the log-variances.

Returns:

Gaussian with mean in \(\mathcal{Y}\) and diagonal covariance satisfying the constraints described in eMultivariateNormal.

Return type:

MultivariateNormal