distributions

class symm_learning.nn.distributions.EquivMultivariateNormal(y_type: FieldType, diagonal=True)

G-equivariant multivariate normal.

Utility layer to parameterize a G-equivariant multivariate Gaussian (normal) distribution:

\[y \sim \mathcal{N} \bigl( \mu(x), \Sigma(x) \bigr),\]

where \(x\) is the input of the equivariant network that parameterizes the mean and the free degrees of freedom of the covariance matrix. Both are constrained to satisfy:

\[ \begin{aligned} \rho_Y(g)\, \mu(x) &= \mu(\rho_X(g)\, x), \\ \rho_Y(g)\, \Sigma(x)\, \rho_Y(g)^{\top} &= \Sigma(\rho_X(g)\, x) \quad \forall g \in G. \end{aligned} \]

Consequently, the conditional density is G-invariant:

\[P(y \mid x) = P(\rho_Y(g) y \mid \rho_X(g) x) \quad \forall g \in G.\]
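The constraints and the resulting invariance can be checked numerically on a toy case. The NumPy sketch below is not part of symm_learning; it assumes G = C3 with \(\rho_X = \rho_Y\) given by cyclic shifts of \(\mathbb{R}^3\) (the regular representation), a hypothetical equivariant mean \(\mu(x) = Wx\) with a circulant matrix W, and a hypothetical covariance \(\Sigma(x) = (1 + \|x\|^2) I\) whose scale is a G-invariant function of x.

```python
import numpy as np

def perm_matrix(k, n=3):
    """Regular representation of g^k in C_n (assumed toy setup): cyclic shift by k."""
    return np.roll(np.eye(n), k, axis=0)

# Hypothetical equivariant mean: a circulant matrix commutes with every cyclic shift.
W = np.array([[1.0, 0.5, -0.2],
              [-0.2, 1.0, 0.5],
              [0.5, -0.2, 1.0]])

def mu(x):
    return W @ x

def sigma(x):
    # Hypothetical covariance scaled by ||x||^2, which permutations preserve,
    # so rho(g) Sigma(x) rho(g)^T = Sigma(rho(g) x).
    return (1.0 + x @ x) * np.eye(len(x))

def gaussian_logpdf(y, mean, cov):
    """Log-density of N(mean, cov) evaluated at y."""
    d = y - mean
    _, logdet = np.linalg.slogdet(cov)
    return -0.5 * (d @ np.linalg.solve(cov, d) + logdet + len(y) * np.log(2 * np.pi))

rng = np.random.default_rng(0)
x, y = rng.normal(size=3), rng.normal(size=3)
for k in range(3):
    rho = perm_matrix(k)  # here rho_X(g) = rho_Y(g)
    assert np.allclose(rho @ mu(x), mu(rho @ x))
    assert np.allclose(rho @ sigma(x) @ rho.T, sigma(rho @ x))
    # P(y | x) == P(rho_Y(g) y | rho_X(g) x)
    assert np.isclose(gaussian_logpdf(y, mu(x), sigma(x)),
                      gaussian_logpdf(rho @ y, mu(rho @ x), sigma(rho @ x)))
```

In the layer itself, \(\mu\) and the covariance parameters come from an equivariant network rather than from a fixed W; the toy functions only make the constraints concrete.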

The input of this layer is composed of the desired mean of the distribution and one log-variance per irreducible subspace of the output representation \(\rho_Y\). Since the number of log-variances depends on the number of irreducible subspaces of the representation, this layer should be instantiated before the EquivariantModule that parameterizes the multivariate normal distribution, so that the module's output type can be set to the layer's input type. See the example below.
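To see why one log-variance is taken per irreducible subspace rather than per output dimension, consider the following NumPy sketch. It assumes (this is not confirmed to be how the layer is implemented) that the diagonal covariance is expressed in an orthonormal basis adapted to the irreducible subspaces of \(\rho_Y\), again for G = C3 acting on \(\mathbb{R}^3\) by cyclic shifts. The names Q, irrep_dims and covariance_from_log_vars are hypothetical.

```python
import numpy as np

# Toy setup (assumption): G = C3 acting on R^3 by cyclic shifts (regular representation).
S = np.roll(np.eye(3), 1, axis=0)  # rho_Y(g) for the generator g

# Orthonormal basis adapted to the irreducible subspaces of rho_Y:
# column 0 spans the trivial irrep, columns 1-2 span the 2-dim irrep (120-degree rotations).
Q = np.column_stack([
    np.ones(3) / np.sqrt(3),
    np.array([1.0, -1.0, 0.0]) / np.sqrt(2),
    np.array([1.0, 1.0, -2.0]) / np.sqrt(6),
])
irrep_dims = [1, 2]  # n = 3 output dimensions, n_irreps = 2 log-variances

def covariance_from_log_vars(log_vars):
    """Hypothetical helper: one log-variance per irrep, shared by all dimensions of that irrep."""
    variances = np.repeat(np.exp(log_vars), irrep_dims)
    return Q @ np.diag(variances) @ Q.T

Sigma = covariance_from_log_vars(np.array([0.5, -0.7]))

# rho_Y(g) Sigma rho_Y(g)^T = Sigma, even though Sigma is not a multiple of the identity.
assert np.allclose(S @ Sigma @ S.T, Sigma)
assert not np.allclose(Sigma, Sigma[0, 0] * np.eye(3))

# Giving the two dimensions of the 2-dim irrep different variances breaks this property.
bad = Q @ np.diag(np.exp([0.5, -0.7, 0.1])) @ Q.T
assert not np.allclose(S @ bad @ S.T, bad)
```

Sharing a variance within each irreducible subspace is what makes \(\Sigma\) commute with \(\rho_Y(g)\); different subspaces may still have different variances, so the covariance need not be isotropic.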

Parameters:
  • y_type (FieldType) – Field/feature type of y (the mean of the distribution).

  • diagonal (bool, default True) – Only diagonal covariance matrices are currently implemented. Note that these are not necessarily constant multiples of the identity.

Example:

>>> import torch
>>> import escnn
>>> from escnn.group import CyclicGroup
>>> from escnn.nn import FieldType
>>> from symm_learning.models.emlp import EMLP
>>> from symm_learning.nn.distributions import EquivMultivariateNormal
>>> G = CyclicGroup(3)
>>> y_type = FieldType(escnn.gspaces.no_base_space(G), representations=[G.regular_representation])
>>> x_type = FieldType(escnn.gspaces.no_base_space(G), representations=[G.regular_representation] * 1)
>>> e_normal = EquivMultivariateNormal(y_type, diagonal=True)
>>> nn = EMLP(in_type=x_type, out_type=e_normal.in_type)
>>> x = torch.randn(1, x_type.size)
>>> dist = e_normal.get_distribution(nn(x_type(x)))
>>> # Sample from the distribution
>>> y = dist.sample()

check_equivariance(atol=1e-05, rtol=1e-05)

Numerically check the equivariance of the module, with absolute tolerance atol and relative tolerance rtol.

evaluate_output_shape(input_shape)

Return the output shape. Outputs are vectors of samples from the normal distribution.

export()

Export the module to a torch.nn.Module.

forward(input)

Compute the mean and variance of an equivariant multivariate normal distribution.

Parameters:
  • input (GeometricTensor) – Input tensor of shape (B, n + n_irreps), where B is the batch size, n is the size of the output type, and n_irreps is the number of irreducible subspaces of the output type (the covariance degrees of freedom).

Returns:
  • mu (torch.Tensor) – Mean of the distribution, of shape (B, n).
  • var (torch.Tensor) – Variance of the distribution, of shape (B, n).

Return type:

tuple(torch.Tensor, torch.Tensor)
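The shapes documented for forward can be illustrated with a NumPy sketch of how an input of shape (B, n + n_irreps) might be split into a mean of shape (B, n) and a per-dimension variance of shape (B, n). The ordering (mean first, then log-variances), the order of the irreducible subspaces, and the name split_mean_and_variance are assumptions for illustration, not the library's code.

```python
import numpy as np

def split_mean_and_variance(layer_input, irrep_dims):
    """Hypothetical sketch: (B, n + n_irreps) -> mu (B, n), var (B, n)."""
    n = sum(irrep_dims)
    n_irreps = len(irrep_dims)
    if layer_input.shape[-1] != n + n_irreps:
        raise ValueError(f"expected last dim {n + n_irreps}, got {layer_input.shape[-1]}")
    mu = layer_input[:, :n]                                # (B, n)
    log_var = layer_input[:, n:]                           # (B, n_irreps)
    # Each irrep's variance is repeated over that irrep's dimensions.
    var = np.repeat(np.exp(log_var), irrep_dims, axis=1)   # (B, n)
    return mu, var

irrep_dims = [1, 2]  # e.g. C3 regular representation: trivial irrep + 2-dim irrep
batch = np.random.default_rng(0).normal(size=(4, 3 + 2))
mu, var = split_mean_and_variance(batch, irrep_dims)
print(mu.shape, var.shape)  # (4, 3) (4, 3)
```

Predicting log-variances and exponentiating them keeps every variance strictly positive without constraining the network's output.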

get_distribution(input)

Returns the MultivariateNormal distribution.