# Source code for symm_learning.stats

"""Statistics utilities for symmetric random variables with known group representations."""

from __future__ import annotations

import numpy as np
import torch
from escnn.group import Representation
from torch import Tensor

from symm_learning.linalg import invariant_orthogonal_projector
from symm_learning.representation_theory import isotypic_decomp_rep


def var_mean(x: Tensor, rep_x: Representation):
    r"""Compute the symmetry-constrained variance and mean of a symmetric random variable.

    The mean is restricted to the trivial/G-invariant subspace of the symmetric vector space (the empirical mean is
    orthogonally projected onto it). The variance is constrained such that, in the irrep-spectral basis, each
    G-irreducible subspace (i.e., each subspace associated with one copy of an irrep) has the same variance in all
    dimensions of that subspace. Variances use the unbiased ``1 / (n - 1)`` normalization, where ``n`` is the number
    of pooled observations, and are centered at the symmetry-constrained mean.

    Tensors derived from ``rep_x`` are cached in ``rep_x.attributes`` under the keys
    ``"invariant_orthogonal_projector"``, ``"Q_inv"``, ``"Q_squared"``, ``"irrep_dims"`` and ``"irrep_indices"``.
    Cached tensors are moved to the device and dtype of ``x`` on every call, so one cache can serve inputs on
    different devices and precisions.

    Args:
        x: (:class:`torch.Tensor`) observations of the symmetric random variable.
        rep_x: (:class:`~escnn.group.Representation`) representation of the symmetric random variable.

    Returns:
        (:class:`torch.Tensor`, :class:`torch.Tensor`): ``(var, mean)`` — the symmetry-constrained variance and mean,
        in that order.

    Shape:
        - **x**: :math:`(N, D_x)` or :math:`(N, D_x, T)` where N is the number of samples, D_x is the dimension of the
          symmetric random variable, and T is the sequence length (if applicable). For sequences, the N samples and
          T time steps are pooled into :math:`N \cdot T` observations.
        - **Output**: ``(var, mean)``, both of shape :math:`(D_x,)`.
    """
    assert x.ndim in (2, 3), f"Expected x to be a 2D or 3D tensor, got {x.ndim}D tensor"
    device, dtype = x.device, x.dtype
    cache = rep_x.attributes

    # Populate the per-representation cache on first use. Entries are re-cast below on every call, so a cache
    # created for one device/dtype remains usable for another.
    if "invariant_orthogonal_projector" not in cache:
        cache["invariant_orthogonal_projector"] = invariant_orthogonal_projector(rep_x)
    if "Q_inv" not in cache:
        cache["Q_inv"] = torch.tensor(rep_x.change_of_basis_inv, device=device, dtype=dtype)
    if "Q_squared" not in cache:
        cache["Q_squared"] = torch.tensor(rep_x.change_of_basis, device=device, dtype=dtype).pow(2)
    if "irrep_dims" not in cache:
        cache["irrep_dims"] = torch.tensor(
            [rep_x.group.irrep(*irrep_id).size for irrep_id in rep_x.irreps], device=device
        )
    if "irrep_indices" not in cache:
        # Label every spectral coordinate with the index of the irrep copy it belongs to,
        # e.g. irrep dims [3, 2, 4] -> [0, 0, 0, 1, 1, 2, 2, 2, 2].
        dims = cache["irrep_dims"]
        cache["irrep_indices"] = torch.repeat_interleave(torch.arange(len(dims), device=dims.device), dims)

    P_inv = cache["invariant_orthogonal_projector"].to(device=device, dtype=dtype)
    Q_inv = cache["Q_inv"].to(device=device, dtype=dtype)
    Q_squared = cache["Q_squared"].to(device=device, dtype=dtype)
    irrep_dims = cache["irrep_dims"].to(device=device)
    irrep_indices = cache["irrep_indices"].to(device=device)

    # Pool all observations into an (N*T, D_x) matrix. A sequence tensor (N, D_x, T) must be permuted to
    # (N, T, D_x) first: reshaping it directly to (-1, D_x) would interleave feature and time entries.
    x_flat = x if x.ndim == 2 else x.permute(0, 2, 1).reshape(-1, x.shape[1])
    n_samples = x_flat.shape[0]

    # Symmetry-constrained mean: orthogonal projection of the empirical mean onto the G-invariant subspace.
    mean = P_inv @ x_flat.mean(dim=0)

    # Unbiased per-coordinate variance in the irrep-spectral basis (x_spectral = Q^{-1} (x - mean)).
    x_spectral = (x_flat - mean) @ Q_inv.T
    var_spectral = x_spectral.pow(2).sum(dim=0) / (n_samples - 1)

    # Average the spectral variances within each irreducible subspace:
    # scatter_add_ sums the coordinates sharing an irrep index, then we divide by the irrep dimension.
    irrep_var = torch.zeros(len(irrep_dims), device=device, dtype=var_spectral.dtype)
    irrep_var.scatter_add_(0, irrep_indices, var_spectral)
    irrep_var = irrep_var / irrep_dims

    # Marginal variances in the original basis: diag(Q diag(v) Q^T)_i = sum_j Q_ij^2 v_j.
    var = Q_squared @ irrep_var[irrep_indices]
    return var, mean


def _isotypic_cov(x: Tensor, rep_x: Representation, y: Tensor = None, rep_y: Representation = None): r"""Cross-covariance between two **isotypic sub-spaces that share the same irrep**. If both signals live in :math:`\rho_X=\bigoplus_{i=1}^{m_x}\rho_k` and :math:`\rho_Y=\bigoplus_{i=1}^{m_y}\rho_k` (with :math:`\rho_k` of dimension *d*), every :math:`G`-equivariant linear map factorises as .. math:: \operatorname{Cov}(X,Y) \;=\;\mathbf Z_{XY}\otimes \mathbf I_d, \qquad \mathbf Z_{XY}\in\mathbb R^{m_y\times m_x}. We estimate the free matrix :math:`\mathbf Z_{XY}` by 1. **centering** (skipped if the irrep is trivial); 2. **reshaping** the data so that each copy of the irrep becomes one “channel” of length *d·N*; 3. **projecting** every :math:`d\times d` block onto the orthogonal basis of :math:`\mathrm{End}_G(\rho_k)` via Frobenius inner products (see `arXiv:2505.19809 <https://arxiv.org/abs/2505.19809>`_); 4. rebuilding the block matrix that respects the constraint above. When ``y is None`` the routine reduces to an **auto-covariance** and only the symmetric (identity) basis element is kept. Args: x (Tensor): shape :math:`(N,\; m_x d)` — samples drawn from ``rep_x``. rep_x (escnn.group.Representation): isotypic representation containing exactly one irrep type. y (Tensor, optional): shape :math:`(N,\; m_y d)` — samples drawn from ``rep_y``. If *None*, computes the auto-covariance of *x*. rep_y (escnn.group.Representation, optional): isotypic representation matching the irrep of ``rep_x``; ignored when *y* is *None*. Returns: (Tensor, Tensor): - C_xy: :math:`(m_y d,\; m_x d)` projected covariance. - Z_xy: :math:`(m_y,\; m_x,\; B)`, free coefficients of each cross-covariance between irrep subspaces, representing basis expansion coefficients in the basis of endomorphisms of the irrep subspaces. Where :math:`B = 1, 2, 4` for real, complex, quaternionic irreps, respectively. 
""" irrep_id = rep_x.irreps[0] # Irrep id of the isotypic subspace assert rep_x.size == x.shape[-1], f"Expected signal shape to be (..., {rep_x.size}) got {x.shape}" assert len(rep_x._irreps_multiplicities) == 1, f"Expected rep with a single irrep type, got {rep_x.irreps}" x_in_iso_basis = np.allclose(rep_x.change_of_basis_inv, np.eye(rep_x.size), atol=1e-6, rtol=1e-4) assert x_in_iso_basis, "Expected X to be in spectral/isotypic basis" if y is not None: assert len(rep_y._irreps_multiplicities) == 1, f"Expected rep with a single irrep type, got {rep_y.irreps}" assert rep_x.group == rep_y.group, f"{rep_x.group} != {rep_y.group}" assert irrep_id == rep_y.irreps[0], f"Irreps {irrep_id} != {rep_y.irreps[0]}. Hence Cxy=0" assert rep_y.size == y.shape[-1], f"Expected signal shape to be (..., {rep_y.size}) got {y.shape}" y_in_iso_basis = np.allclose(rep_y.change_of_basis_inv, np.eye(rep_y.size), atol=1e-6, rtol=1e-4) assert y_in_iso_basis, "Expected Y to be in spectral/isotypic basis" # Get information about the irreducible representation present in the isotypic subspace irrep_dim = rep_x.group.irrep(*irrep_id).size # irrep_end_basis := (dim(End(irrep)), dim(irrep), dim(irrep)) irrep_end_basis = torch.tensor(rep_x.group.irrep(*irrep_id).endomorphism_basis(), device=x.device, dtype=x.dtype) # if y is None Cxy = Cx is symmetric matrix. Hence it has non-zero entries only in the diagonal. if y is None: irrep_end_basis = irrep_end_basis[[0]] rep_y = rep_x # Use the same representation for Y y = x m_x = rep_x._irreps_multiplicities[irrep_id] # Multiplicity of the irrep in X m_y = rep_y._irreps_multiplicities[irrep_id] # Multiplicity of the irrep in Y x_iso, y_iso = x, y is_inv_subspace = irrep_id == rep_x.group.trivial_representation.id if is_inv_subspace: # Nothing to do, return empirical covariance. 
x_iso = x - torch.mean(x, dim=0, keepdim=True) y_iso = y - torch.mean(y, dim=0, keepdim=True) Cxy_iso = torch.einsum("...y,...x->yx", y_iso, x_iso) / (x_iso.shape[0] - 1) return Cxy_iso, Cxy_iso # Invariant subspace covariance is the same as the covariance matrix. # Compute empirical cross-covariance Cxy_iso = torch.einsum("...y,...x->yx", y_iso, x_iso) / (x_iso.shape[0] - 1) # ReshapCxy_isoe from (my * d, mx * d) to (my, mx, d, d) Cxy_irreps = Cxy_iso.view(m_y, irrep_dim, m_x, irrep_dim).permute(0, 2, 1, 3).contiguous() # Compute basis expansion coefficients of each irrep cross-covariance in basis of End(irrep) ======== # Frobenius inner product <C , Ψ_b> = Σ_{i,j} C_{ij} Ψ_b,ij Cxy_irreps_basis_coeff = torch.einsum("mnij,bij->mnb", Cxy_irreps, irrep_end_basis) # (m_y , m_x , B) # squared norms ‖Ψ_b‖² (only once, very small) basis_coeff_norms = torch.einsum("bij,bij->b", irrep_end_basis, irrep_end_basis) # (B,) Cxy_irreps_basis_coeff = Cxy_irreps_basis_coeff / basis_coeff_norms[None, None] Cxy_irreps = torch.einsum("...b,bij->...ij", Cxy_irreps_basis_coeff, irrep_end_basis) # (m_y , m_x , d , d) # Reshape to (my * d, mx * d) Cxy_iso = Cxy_irreps.permute(0, 2, 1, 3).reshape(m_y * irrep_dim, m_x * irrep_dim) return Cxy_iso, Cxy_irreps_basis_coeff
def cov(x: Tensor, y: Tensor, rep_x: Representation, rep_y: Representation):
    r"""Compute the covariance between two symmetric random variables.

    The covariance of r.v. can be computed from the orthogonal projections of the r.v. to each isotypic subspace.
    Hence, in the disentangled/isotypic basis the covariance can be computed in block-diagonal form:

    .. math::
        \begin{align}
            \mathbf{C}_{xy} &= \mathbf{Q}_y (\bigoplus_{k} \mathbf{C}_{xy}^{(k)} )\mathbf{Q}_x^{\mathsf T} \\
            &= \mathbf{Q}_y ( \bigoplus_{k} \sum_{b\in \mathbb{B}_k} \mathbf{Z}_b^{(k)} \otimes \mathbf{b} )
            \mathbf{Q}_x^{\mathsf T} \\
        \end{align}

    Where :math:`\mathbf{Q}_x` and :math:`\mathbf{Q}_y` are the change-of-basis matrices from the isotypic bases of
    :math:`\mathcal{X}` and :math:`\mathcal{Y}` to the original bases (so that :math:`x = \mathbf{Q}_x x_{iso}`);
    :math:`\mathbf{C}_{xy}^{(k)}` is the covariance restricted to the isotypic subspaces of type *k*; and
    :math:`\mathbf{Z}_b^{(k)}` are the free parameters—i.e. the expansion coefficients in the endomorphism
    basis :math:`\mathbb{B}_k` of the irreducible representation of type *k*. Isotypic types present in only one
    of the two representations contribute a zero block.

    Args:
        x (Tensor): Realizations of a random variable :math:`X`.
        y (Tensor): Realizations of a random variable :math:`Y`.
        rep_x (Representation): The representation acting on the symmetric vector spaces :math:`\mathcal{X}`.
        rep_y (Representation): The representation acting on the symmetric vector spaces :math:`\mathcal{Y}`.

    Returns:
        Tensor: The covariance matrix between the two random variables, of shape :math:`(D_y, D_x)`.

    Shape:
        X: :math:`(N, D_x)` where :math:`D_x` is the dimension of the random variable X.
        Y: :math:`(N, D_y)` where :math:`D_y` is the dimension of the random variable Y.

        Output: :math:`(D_y, D_x)`
    """
    # Mismatched sample counts would otherwise silently broadcast (e.g. N=1 vs N=5) inside the einsum.
    assert x.shape[0] == y.shape[0], "Expected equal number of samples in X and Y"
    assert x.shape[1] == rep_x.size, f"Expected X shape (N, {rep_x.size}), got {x.shape}"
    assert y.shape[1] == rep_y.size, f"Expected Y shape (N, {rep_y.size}), got {y.shape}"

    rep_X_iso = isotypic_decomp_rep(rep_x)
    rep_Y_iso = isotypic_decomp_rep(rep_y)

    # Q maps isotypic-basis coordinates to the original basis (x = Q x_iso); Q_inv maps back.
    Qx = torch.tensor(rep_X_iso.change_of_basis, device=x.device, dtype=x.dtype)
    Qx_inv = torch.tensor(rep_X_iso.change_of_basis_inv, device=x.device, dtype=x.dtype)
    Qy = torch.tensor(rep_Y_iso.change_of_basis, device=y.device, dtype=y.dtype)
    Qy_inv = torch.tensor(rep_Y_iso.change_of_basis_inv, device=y.device, dtype=y.dtype)

    rep_X_iso_subspaces = rep_X_iso.attributes["isotypic_reps"]
    rep_Y_iso_subspaces = rep_Y_iso.attributes["isotypic_reps"]

    def _iso_slices(iso_reps):
        # Contiguous coordinate range of each isotypic subspace, following the dict's (basis) order.
        slices, start = {}, 0
        for iso_id, rep_k in iso_reps.items():
            slices[iso_id] = slice(start, start + rep_k.size)
            start += rep_k.size
        return slices

    iso_idx_X = _iso_slices(rep_X_iso_subspaces)
    iso_idx_Y = _iso_slices(rep_Y_iso_subspaces)

    # Coordinates in the isotypic bases: x_iso = Qx^{-1} x, y_iso = Qy^{-1} y.
    X_iso = x @ Qx_inv.T
    Y_iso = y @ Qy_inv.T

    Cxy_iso = torch.zeros((rep_y.size, rep_x.size), dtype=x.dtype, device=x.device)
    for iso_id, rep_Y_k in rep_Y_iso_subspaces.items():
        rep_X_k = rep_X_iso_subspaces.get(iso_id)
        if rep_X_k is None:
            continue  # No covariance between isotypic subspaces of different types.
        # Cxy_k = Z_xy_k ⊗ (End basis) of shape [my * d x mx * d]
        Cxy_k, _ = _isotypic_cov(
            x=X_iso[..., iso_idx_X[iso_id]], y=Y_iso[..., iso_idx_Y[iso_id]], rep_x=rep_X_k, rep_y=rep_Y_k
        )
        Cxy_iso[iso_idx_Y[iso_id], iso_idx_X[iso_id]] = Cxy_k

    # Back to the original bases: E[y x^T] = Qy E[y_iso x_iso^T] Qx^T.
    Cxy = Qy @ Cxy_iso @ Qx.T
    return Cxy