eMultivariateNormal
- class eMultivariateNormal(out_rep, diagonal=True)
Bases: Module
Conditional Gaussian with \(\mathbb{G}\)-equivariant parameters.
This module maps parameter vectors in an input representation space to a Gaussian over \(\mathcal{Y}\) with representation \(\rho_{\mathcal{Y}}\):
\[\mathbf{y}\,|\,\mathbf{u} \sim \mathcal{N}\!\left(\boldsymbol{\mu}(\mathbf{u}),\mathbf{\Sigma}(\mathbf{u})\right).\]
The constraints are
\[\boldsymbol{\mu}(\rho_{\mathrm{in}}(g)\mathbf{u}) = \rho_{\mathcal{Y}}(g)\,\boldsymbol{\mu}(\mathbf{u}), \quad \mathbf{\Sigma}(\rho_{\mathrm{in}}(g)\mathbf{u}) = \rho_{\mathcal{Y}}(g)\,\mathbf{\Sigma}(\mathbf{u})\,\rho_{\mathcal{Y}}(g)^T, \ \forall g\in\mathbb{G},\]
implying orbit-wise density invariance
\[p(\mathbf{y}\mid\mathbf{u}) = p(\rho_{\mathcal{Y}}(g)\mathbf{y}\mid \rho_{\mathrm{in}}(g)\mathbf{u}).\]
Implementation details:
- The first out_rep.size coordinates of the input are interpreted as \(\boldsymbol{\mu}\).
- The remaining coordinates are log-variances, one per irreducible copy in \(\rho_{\mathcal{Y}}\).
- Only diagonal covariances are implemented. In the irrep-spectral basis, each irrep copy uses one scalar variance shared by all dimensions of that copy.
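The parameter layout described above can be sketched in plain Python. This is an illustration only, not the module's implementation: `split_params` and `irrep_dims` are hypothetical names, and the irrep dimensions `[1, 2]` are an assumed decomposition (one 1-dimensional and one 2-dimensional irrep copy, giving an output size of 3 and an input size of 3 + 2 = 5).

```python
import math


def split_params(u, irrep_dims):
    """Split a flat parameter vector into a mean and per-dimension variances.

    u          : list of floats of length sum(irrep_dims) + len(irrep_dims)
    irrep_dims : dimension of each irrep copy in the output representation
                 (hypothetical argument; the module reads this from out_rep)
    """
    out_size = sum(irrep_dims)        # plays the role of out_rep.size
    n_cov_params = len(irrep_dims)    # one log-variance per irrep copy
    if len(u) != out_size + n_cov_params:
        raise ValueError("expected %d parameters, got %d"
                         % (out_size + n_cov_params, len(u)))

    mean = u[:out_size]               # first out_size coordinates
    log_vars = u[out_size:]           # remaining coordinates

    # Each irrep copy shares one scalar variance across all its dimensions
    # (these are diagonal entries in the irrep-spectral basis).
    variances = []
    for log_var, dim in zip(log_vars, irrep_dims):
        variances.extend([math.exp(log_var)] * dim)
    return mean, variances


irrep_dims = [1, 2]                                # assumed decomposition
u = [0.5, -1.0, 2.0, 0.0, math.log(4.0)]           # 3 mean + 2 log-variance entries
mean, variances = split_params(u, irrep_dims)
```

Here the second log-variance is expanded to both dimensions of the 2-dimensional irrep copy, so the spectral-basis diagonal has three entries but only two free parameters.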
- Parameters:
out_rep (Representation) – Representation \(\rho_{\mathcal{Y}}\) describing the output space \(\mathcal{Y}\).
diagonal (bool) – Only diagonal covariance matrices are implemented. These are not necessarily constant multiples of the identity. Default: True.
- in_rep
Input representation \(\rho_{\mathrm{in}}=\rho_{\mathcal{Y}}\oplus n_{\mathrm{irr}}\cdot\hat{\rho}_{\mathrm{triv}}\) carrying mean and covariance DoFs.
- out_rep
Output representation \(\rho_{\mathcal{Y}}\).
- n_cov_params
Number of independent covariance parameters (equals the number of irreps in out_rep).
Example
>>> import torch
>>> from escnn.group import CyclicGroup
>>> from symm_learning.models.emlp import eMLP
>>> # eMultivariateNormal is the class documented on this page
>>> G = CyclicGroup(3)
>>> rep_x = G.regular_representation
>>> rep_y = G.regular_representation
>>> e_normal = eMultivariateNormal(out_rep=rep_y, diagonal=True)
>>> # Create an eMLP that outputs mean + cov params
>>> nn = eMLP(in_rep=rep_x, out_rep=e_normal.in_rep, hidden_units=[32])
>>> x = torch.randn(1, rep_x.size)
>>> dist = e_normal(nn(x))  # Returns torch.distributions.MultivariateNormal
>>> y = dist.sample()  # Sample from the distribution
- check_equivariance(atol=1e-05, rtol=1e-05)
Verify that the distribution satisfies the equivariance constraint.
Checks \(p(\mathbf{y} \mid \mathbf{u}) = p(\rho_{\mathcal{Y}}(g)\mathbf{y} \mid \rho_{\mathrm{in}}(g)\mathbf{u})\) for sampled group elements.
- Parameters:
- Raises:
AssertionError – If the distribution is not equivariant within the given tolerances.
- Return type:
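Why the mean and covariance constraints give the density identity checked here can be sketched as follows, assuming \(\rho_{\mathcal{Y}}\) is orthogonal (an assumption not stated in this page; it guarantees \(|\det\rho_{\mathcal{Y}}(g)|=1\)). Write \(\rho=\rho_{\mathcal{Y}}(g)\), \(\mathbf{y}'=\rho\,\mathbf{y}\), \(\boldsymbol{\mu}'=\boldsymbol{\mu}(\rho_{\mathrm{in}}(g)\mathbf{u})=\rho\,\boldsymbol{\mu}(\mathbf{u})\) and \(\mathbf{\Sigma}'=\mathbf{\Sigma}(\rho_{\mathrm{in}}(g)\mathbf{u})=\rho\,\mathbf{\Sigma}(\mathbf{u})\,\rho^T\). Then
\[(\mathbf{y}'-\boldsymbol{\mu}')^T\,\mathbf{\Sigma}'^{-1}\,(\mathbf{y}'-\boldsymbol{\mu}') = (\mathbf{y}-\boldsymbol{\mu})^T\rho^T\,\rho^{-T}\mathbf{\Sigma}^{-1}\rho^{-1}\,\rho\,(\mathbf{y}-\boldsymbol{\mu}) = (\mathbf{y}-\boldsymbol{\mu})^T\,\mathbf{\Sigma}^{-1}\,(\mathbf{y}-\boldsymbol{\mu}), \qquad \det\mathbf{\Sigma}' = \det(\rho)^2\det\mathbf{\Sigma} = \det\mathbf{\Sigma}.\]
Both the quadratic form and the normalizing determinant of the Gaussian are unchanged, so \(p(\mathbf{y}\mid\mathbf{u}) = p(\rho_{\mathcal{Y}}(g)\mathbf{y}\mid \rho_{\mathrm{in}}(g)\mathbf{u})\).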
- forward(input)
Build MultivariateNormal from equivariant DoFs.
- Parameters:
input (Tensor) – Tensor of shape (..., in_rep.size) containing the mean and log-variance parameters. The first out_rep.size elements are the mean, and the remaining n_cov_params elements are the log-variances.
- Returns:
Gaussian with mean in \(\mathcal{Y}\) and diagonal covariance satisfying the constraints described in eMultivariateNormal.
- Return type:
MultivariateNormal
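As a hand-built illustration of these constraints (not the library's implementation), the sketch below uses the two-element group acting on \(\mathbb{R}^2\) by swapping coordinates. Everything here is an assumption made for the example: the change of basis `q` to the two 1-dimensional irreps, the helper names, and the parameter values. The input acts as the swap on the mean and trivially on the two log-variances, mirroring the structure of in_rep.

```python
import math


def matmul(a, b):
    """Multiply two small matrices given as nested lists."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]


def transpose(a):
    return [list(row) for row in zip(*a)]


def covariance_from_log_vars(log_vars):
    """Sigma = Q diag(exp(log_vars)) Q^T, diagonal in the irrep basis."""
    s = 1.0 / math.sqrt(2.0)
    q = [[s, s], [s, -s]]  # columns: trivial irrep, sign irrep (assumed basis)
    d = [[math.exp(log_vars[0]), 0.0], [0.0, math.exp(log_vars[1])]]
    return matmul(matmul(q, d), transpose(q))


def log_density(y, mean, cov):
    """Log-density of a 2D Gaussian via the closed-form 2x2 inverse."""
    (a, b), (c, d) = cov
    det = a * d - b * c
    r0, r1 = y[0] - mean[0], y[1] - mean[1]
    quad = (d * r0 * r0 - (b + c) * r0 * r1 + a * r1 * r1) / det
    return -0.5 * quad - 0.5 * math.log(det) - math.log(2.0 * math.pi)


swap = [[0.0, 1.0], [1.0, 0.0]]  # rho_Y(g) for the non-identity element


def act_in(u):
    """rho_in(g): swap the mean coordinates, keep the log-variances."""
    return [u[1], u[0], u[2], u[3]]


def act_out(y):
    """rho_Y(g): swap the two output coordinates."""
    return [y[1], y[0]]


u = [0.3, -1.2, math.log(0.5), math.log(2.0)]  # 2 mean + 2 log-variance entries
y = [1.0, 0.4]

mean, cov = u[:2], covariance_from_log_vars(u[2:])
gu = act_in(u)
g_mean, g_cov = gu[:2], covariance_from_log_vars(gu[2:])

# Covariance constraint: Sigma(rho_in(g) u) == rho_Y(g) Sigma(u) rho_Y(g)^T
conjugated = matmul(matmul(swap, cov), transpose(swap))
cov_ok = all(math.isclose(conjugated[i][j], g_cov[i][j], abs_tol=1e-12)
             for i in range(2) for j in range(2))

# Density invariance: p(y | u) == p(rho_Y(g) y | rho_in(g) u)
density_ok = math.isclose(log_density(y, mean, cov),
                          log_density(act_out(y), g_mean, g_cov),
                          abs_tol=1e-12)
```

Because each irrep copy gets a single variance, the covariance commutes with the group action, which is what makes the constraint hold even though the log-variances themselves do not change under the group.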