distributions
- class symm_learning.nn.distributions.EquivMultivariateNormal(y_type: FieldType, diagonal=True)
G-equivariant multivariate normal.
Utility layer to parameterize a G-equivariant multivariate Gaussian (normal) distribution.
\[y \sim \mathcal{N} \bigl( \mu(x), \Sigma(x) \bigr),\]
where \(x\) is the input to the equivariant network that parameterizes the mean and the free degrees of freedom of the covariance matrix. Both are constrained to satisfy:
\[
\begin{aligned}
\rho_Y(g)\,\mu(x) &= \mu\bigl(\rho_X(g)\,x\bigr), \\
\rho_Y(g)\,\Sigma(x)\,\rho_Y(g)^{\top} &= \Sigma\bigl(\rho_X(g)\,x\bigr) \qquad \forall g \in G,
\end{aligned}
\]
such that:
\[P(y \mid x) = P\bigl(\rho_Y(g)\,y \mid \rho_X(g)\,x\bigr) \quad \forall g \in G.\]
The input of this layer is composed of the desired mean of the distribution and one log-variance for each irreducible subspace of the output representation \(\rho_Y\). Because the number of log-variances depends on the number of irreducible subspaces of \(\rho_Y\), this layer is meant to be instantiated before the EquivariantModule that parameterizes the multivariate normal distribution, so that the module's output type can be set to this layer's in_type. See the example below.
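The constraints above can be checked numerically. The following is a minimal NumPy sketch, not the library's implementation. It assumes \(X = Y = \mathbb{R}^3\), both carrying the regular representation of \(C_3\) (cyclic permutation matrices), a hypothetical equivariant mean `mu`, hypothetical invariant log-variances `log_vars` (one per irreducible subspace), and a covariance that is diagonal in a hypothetical irrep-adapted orthonormal basis `Q`. Under these assumptions it checks the mean constraint, the covariance constraint, and the invariance of the density \(P(y \mid x)\).

```python
import numpy as np

# Assumed setup: X = Y = R^3 with the regular representation of C3,
# realized as cyclic permutation matrices.
def rho(k: int) -> np.ndarray:
    """Permutation matrix of the group element g^k in C3."""
    return np.roll(np.eye(3), shift=k, axis=0)

# Hypothetical irrep-adapted orthonormal basis (columns): the trivial irrep
# spans the all-ones direction, the 2D real irrep spans the sum-zero plane.
Q = np.column_stack([
    np.array([1.0, 1.0, 1.0]) / np.sqrt(3.0),
    np.array([1.0, -1.0, 0.0]) / np.sqrt(2.0),
    np.array([1.0, 1.0, -2.0]) / np.sqrt(6.0),
])
IRREP_DIMS = [1, 2]  # one log-variance per irreducible subspace

def mu(x: np.ndarray) -> np.ndarray:
    """Hypothetical equivariant mean: scales each irrep subspace by a constant."""
    return Q @ np.diag([2.0, 0.5, 0.5]) @ Q.T @ x

def log_vars(x: np.ndarray) -> np.ndarray:
    """Hypothetical invariant log-variances (sum and norm ignore cyclic shifts)."""
    return np.array([0.1 * x.sum(), 0.1 * np.linalg.norm(x)])

def sigma(x: np.ndarray) -> np.ndarray:
    """Covariance that is diagonal in the irrep basis: Q diag(var) Q^T."""
    var = np.exp(np.repeat(log_vars(x), IRREP_DIMS))
    return Q @ np.diag(var) @ Q.T

def log_density(y: np.ndarray, x: np.ndarray) -> float:
    """log N(y; mu(x), sigma(x))."""
    d = y - mu(x)
    cov = sigma(x)
    _, logdet = np.linalg.slogdet(cov)
    return -0.5 * (y.size * np.log(2 * np.pi) + logdet + d @ np.linalg.solve(cov, d))

rng = np.random.default_rng(0)
x, y = rng.normal(size=3), rng.normal(size=3)
for k in range(3):
    g = rho(k)
    assert np.allclose(g @ mu(x), mu(g @ x))                   # mean constraint
    assert np.allclose(g @ sigma(x) @ g.T, sigma(g @ x))       # covariance constraint
    assert np.isclose(log_density(y, x), log_density(g @ y, g @ x))  # P(y|x) invariance
print("equivariance constraints hold for all g in C3")
```

Because the trivial and the 2D irreducible subspaces receive different variances, this covariance is diagonal in the irrep basis without being a constant multiple of the identity; in the original basis it is a circulant matrix with nonzero off-diagonal entries.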
- Parameters:
y_type (FieldType) – Field type of the output \(y\), and hence of the mean \(\mu\).
diagonal (bool, default True) – Only diagonal covariance matrices are implemented. Note that these are not necessarily constant multiples of the identity.
Example
---------
>>> import torch
>>> import escnn
>>> from escnn.nn import FieldType
>>> from escnn.group import CyclicGroup
>>> from symm_learning.models.emlp import EMLP
>>> from symm_learning.nn.distributions import EquivMultivariateNormal
>>> G = CyclicGroup(3)
>>> y_type = FieldType(escnn.gspaces.no_base_space(G), representations=[G.regular_representation])
>>> x_type = FieldType(escnn.gspaces.no_base_space(G), representations=[G.regular_representation] * 1)
>>> e_normal = EquivMultivariateNormal(y_type, diagonal=True)
>>> nn = EMLP(in_type=x_type, out_type=e_normal.in_type)
>>> x = torch.randn(1, x_type.size)
>>> dist = e_normal.get_distribution(nn(x_type(x)))
>>> # Sample from the distribution
>>> y = dist.sample()
- evaluate_output_shape(input_shape)
The output is a vector of samples from the normal distribution.
- forward(input)
Compute the mean and variance of an equivariant multivariate normal distribution.
- Parameters:
input (FieldType) – Input tensor of shape (B, n + n_irreps), where B is the batch size, n is the size of the output type (the mean), and n_irreps is the number of irreducible subspaces of \(\rho_Y\) (the covariance degrees of freedom).
- Returns:
mu (torch.Tensor) – Mean of the distribution, of shape (B, n).
var (torch.Tensor) – Variance of the distribution, of shape (B, n).
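The shape bookkeeping of forward can be illustrated with a hypothetical NumPy helper, `split_mean_and_variance`, which is not part of symm_learning. It assumes that the mean occupies the first n entries of the input, that the n_irreps log-variances follow, and that each log-variance is shared by every dimension of its irreducible subspace, with `irrep_dims` listing those dimensions.

```python
import numpy as np

def split_mean_and_variance(z: np.ndarray, irrep_dims: list[int]):
    """Hypothetical sketch of the (B, n + n_irreps) -> (B, n), (B, n) split.

    Assumes the n mean entries come first, followed by one log-variance
    per irreducible subspace of the output representation.
    """
    n, n_irreps = sum(irrep_dims), len(irrep_dims)
    if z.shape[-1] != n + n_irreps:
        raise ValueError(f"expected last dim {n + n_irreps}, got {z.shape[-1]}")
    mu = z[:, :n]
    log_var = z[:, n:]
    # Broadcast each irrep's log-variance to all dimensions of its subspace.
    var = np.exp(np.repeat(log_var, irrep_dims, axis=1))
    return mu, var

# Assumed decomposition: regular rep of C3 = trivial (dim 1) + 2D real irrep.
irrep_dims = [1, 2]
z = np.zeros((4, 3 + 2))  # B = 4, n = 3, n_irreps = 2
mu, var = split_mean_and_variance(z, irrep_dims)
print(mu.shape, var.shape)  # (4, 3) (4, 3)
```

For the \(C_3\) regular representation used in the example above, this gives n = 3 and n_irreps = 2, so the network feeding this layer outputs 5 values per sample while the mean and variance each have 3 entries.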