# Source code for symm_learning.stats

"""Symmetric Learning - Statistics Utilities.

Functions for computing statistics of symmetric random variables that respect group
symmetry constraints.

Functions
---------
mean
    Compute the mean projected onto the G-invariant subspace.
var
    Compute variance respecting symmetry structure.
var_mean
    Compute variance and mean together efficiently.
cov
    Compute covariance between symmetric random variables.
"""

from __future__ import annotations

import torch
from escnn.group import Representation
from torch import Tensor

from symm_learning.linalg import equiv_orthogonal_projection, invariant_orthogonal_projector


def mean(x: Tensor, rep_x: Representation) -> Tensor:
    r"""Estimate the :math:`\mathbb{G}`-invariant mean of a random variable.

    Let :math:`\mathbf{X}: \Omega \to \mathcal{X}` be a random variable taking values in the symmetric vector space
    :math:`\mathcal{X}`, with group representation :math:`\rho_{\mathcal{X}}:\mathbb{G}\to\mathrm{GL}(\mathcal{X})`,
    and marginal density :math:`\mathbb{P}_{\mathbf{X}}`. Under the assumption that this marginal is invariant under
    the group action (i.e., a point and all its symmetric points have equal likelihood under the marginal), formally:

    .. math::
        \mathbb{P}_{\mathbf{X}}(\mathbf{x}) = \mathbb{P}_{\mathbf{X}}\!\left(\rho_{\mathcal{X}}(g)\mathbf{x}\right),
        \quad \forall \mathbf{x}\in\mathcal{X},\ \forall g\in\mathbb{G},

    the true mean satisfies

    .. math::
        \mathbb{E}[\mathbf{X}] = \rho_{\mathcal{X}}(g)\,\mathbb{E}[\mathbf{X}], \quad \forall g\in\mathbb{G},

    hence :math:`\mathbb{E}[\mathbf{X}] \in \mathcal{X}^{\text{inv}}`.

    Implementation: from samples :math:`\{\mathbf{x}^{(n)}\}_{n=1}^N`, we first compute the empirical mean

    .. math::
        \widehat{\mathbb{E}}[\mathbf{X}] = \frac{1}{N}\sum_{n=1}^N \mathbf{x}^{(n)},

    and then project it onto the invariant subspace:

    .. math::
        \widehat{\mathbb{E}}_{\mathbb{G}}[\mathbf{X}] = \mathbf{P}_{\mathrm{inv}}\,\widehat{\mathbb{E}}[\mathbf{X}],
        \quad \mathbf{P}_{\mathrm{inv}} = \mathbf{Q}\mathbf{S}\mathbf{Q}^T,

    where :math:`\mathbf{S}` selects trivial-irrep coordinates in the irrep-spectral basis
    (see :func:`~symm_learning.linalg.invariant_orthogonal_projector`). Under the repository's canonical isotypic
    ordering, this corresponds to the first isotypic block when present.

    Args:
        x: (:class:`~torch.Tensor`) samples of :math:`\mathbf{X}` with shape :math:`(N,D_x)` or :math:`(N,D_x,T)`;
            when a time axis is present it is folded into the sample axis, i.e. every time step of every sample
            counts as one observation of the :math:`D_x`-dimensional variable.
        rep_x: (:class:`~escnn.group.Representation`) representation :math:`\rho_{\mathcal{X}}` on
            :math:`\mathcal{X}`.

    Returns:
        (:class:`~torch.Tensor`): Invariant mean vector in :math:`\mathcal{X}`, with the dtype and device of ``x``.

    Shape:
        - **x**: :math:`(N,D_x)` or :math:`(N,D_x,T)`.
        - **Output**: :math:`(D_x,)`.

    Note:
        For repeated calls with the same representation object ``rep_x``, this function caches
        :math:`\mathbf{P}_{\mathrm{inv}}` in ``rep_x.attributes["invariant_orthogonal_projector"]``.
    """
    assert x.ndim in (2, 3), f"Expected x to be a 2D or 3D tensor, got {x.ndim}D tensor"

    if "invariant_orthogonal_projector" not in rep_x.attributes:
        rep_x.attributes["invariant_orthogonal_projector"] = invariant_orthogonal_projector(rep_x)
    P_inv = rep_x.attributes["invariant_orthogonal_projector"]

    # Fold the optional time axis into the sample axis: (N, D, T) -> (N*T, D). The time axis must be moved next to
    # the sample axis first; reshaping (N, D, T) straight to (-1, D) would interleave channels and time steps.
    x_flat = x if x.ndim == 2 else x.permute(0, 2, 1).reshape(-1, x.shape[1])
    mean_empirical = x_flat.mean(dim=0)

    # Project onto the invariant subspace (P_inv is expressed in the original basis, so the result is too).
    # The cached projector is cast to the current input's dtype/device on every call.
    return P_inv.to(device=mean_empirical.device, dtype=mean_empirical.dtype) @ mean_empirical
def var(x: Tensor, rep_x: Representation, center: Tensor = None) -> Tensor:
    r"""Estimate the symmetry-constrained variance of :math:`\mathbf{X}:\Omega\to\mathcal{X}`.

    Let :math:`\mathbf{X}: \Omega \to \mathcal{X}` be a random variable taking values in the symmetric vector space
    :math:`\mathcal{X}`, with group representation :math:`\rho_{\mathcal{X}}:\mathbb{G}\to\mathrm{GL}(\mathcal{X})`,
    and marginal density :math:`\mathbb{P}_{\mathbf{X}}`. Under the assumption that this marginal is invariant under
    the group action (i.e., a point and all its symmetric points have equal likelihood under the marginal), formally:

    .. math::
        \mathbb{P}_{\mathbf{X}}(\mathbf{x}) = \mathbb{P}_{\mathbf{X}}\!\left(\rho_{\mathcal{X}}(g)\mathbf{x}\right),
        \quad \forall \mathbf{x}\in\mathcal{X},\ \forall g\in\mathbb{G},

    the true variance in the irrep-spectral basis (:func:`~symm_learning.representation_theory.isotypic_decomp_rep`)
    is constant within each irreducible copy:

    .. math::
        \operatorname{Var}(\hat{\mathbf{X}}_{k,i,1}) = \cdots = \operatorname{Var}(\hat{\mathbf{X}}_{k,i,d_k})
        = \sigma^2_{k,i}.

    Implementation: given samples :math:`\{\mathbf{x}^{(n)}\}_{n=1}^{N}`, we compute:

    1. Centering (using provided center or :func:`mean`):

       .. math::
           \widehat{\boldsymbol{\mu}} =
           \begin{cases}
               \texttt{center}, & \text{if provided} \\
               \widehat{\mathbb{E}}_{\mathbb{G}}[\mathbf{X}], & \text{otherwise}
           \end{cases}

    2. Empirical spectral variance:

       .. math::
           \hat{\mathbf{x}}^{(n)} = \mathbf{Q}^{T}(\mathbf{x}^{(n)}-\widehat{\boldsymbol{\mu}}),\qquad
           \widehat{v}_{j} = \frac{1}{N-1}\sum_{n=1}^{N}\left(\hat{x}^{(n)}_{j}\right)^2.

    3. Irrep-wise averaging for each copy :math:`(k,i)`:

       .. math::
           \widehat{\sigma}^{2}_{k,i} = \frac{1}{d_k}\sum_{r=1}^{d_k}\widehat{v}_{k,i,r}, \qquad
           \widehat{v}_{k,i,1}=\cdots=\widehat{v}_{k,i,d_k}:=\widehat{\sigma}^{2}_{k,i}.

    4. Mapping back to the original basis:

       .. math::
           \widehat{\operatorname{Var}}(\mathbf{X}) = \mathbf{Q}^{\odot 2}\,\widehat{\mathbf{v}},

       where :math:`\mathbf{Q}^{\odot 2}` is the elementwise square of :math:`\mathbf{Q}` and
       :math:`\widehat{\mathbf{v}}` denotes the broadcast spectral variance vector after step 3.

    Args:
        x: (:class:`~torch.Tensor`) samples with shape :math:`(N,D_x)` or :math:`(N,D_x,T)`; the optional time axis
            is folded into samples, i.e. every time step of every sample counts as one observation.
        rep_x: (:class:`~escnn.group.Representation`) representation :math:`\rho_{\mathcal{X}}`.
        center: (:class:`~torch.Tensor`, optional) Center for variance computation. If None, computes the mean.

    Returns:
        (:class:`~torch.Tensor`): Variance vector in the original basis, consistent with the irrep-wise constraint
        above.

    Shape:
        - **x**: :math:`(N,D_x)` or :math:`(N,D_x,T)`.
        - **center**: :math:`(D_x,)` if provided.
        - **Output**: :math:`(D_x,)`.

    Note:
        For repeated calls with the same representation object ``rep_x``, this function caches and reuses:
        ``Q_inv``, ``Q_squared``, ``irrep_dims``, and ``irrep_indices`` in ``rep_x.attributes``. The cached
        change-of-basis tensors are cast to the dtype/device of the current input on every call.
    """
    assert x.ndim in (2, 3), f"Expected x to be a 2D or 3D tensor, got {x.ndim}D tensor"

    # Fold the optional time axis into the sample axis: (N, D, T) -> (N*T, D). The time axis must be moved next to
    # the sample axis first; reshaping (N, D, T) straight to (-1, D) would interleave channels and time steps.
    x_flat = x if x.ndim == 2 else x.permute(0, 2, 1).reshape(-1, x.shape[1])
    device = x_flat.device

    # Build the cached change-of-basis tensors on first use.
    if "Q_inv" not in rep_x.attributes:
        rep_x.attributes["Q_inv"] = torch.tensor(rep_x.change_of_basis_inv, device=device, dtype=x_flat.dtype)
    if "Q_squared" not in rep_x.attributes:
        Q = torch.tensor(rep_x.change_of_basis, device=device, dtype=x_flat.dtype)
        rep_x.attributes["Q_squared"] = Q.pow(2)

    # Use the provided center or compute the invariant mean of the (already folded) samples.
    if center is None:
        center = mean(x_flat, rep_x)

    x_centered = x_flat - center
    dtype = x_centered.dtype
    # The caches keep the dtype/device of the call that created them; cast them to the current input so that a
    # later call with a different dtype (e.g. float32 then float64) does not fail inside einsum.
    Q_inv = rep_x.attributes["Q_inv"].to(device=device, dtype=dtype)
    Q_squared = rep_x.attributes["Q_squared"].to(device=device, dtype=dtype)

    # Symmetry-constrained variance: the variance is constrained to be a single constant per irreducible
    # subspace. Hence, compute the empirical variance in the irrep-spectral basis and average within each subspace.
    n_samples = x_flat.shape[0]
    x_c_irrep_spectral = torch.einsum("ij,nj->ni", Q_inv, x_centered)
    var_spectral = torch.sum(x_c_irrep_spectral**2, dim=0) / (n_samples - 1)

    if "irrep_dims" not in rep_x.attributes:
        rep_x.attributes["irrep_dims"] = torch.tensor(
            [rep_x.group.irrep(*irrep_id).size for irrep_id in rep_x.irreps], device=device
        )
    irrep_dims = rep_x.attributes["irrep_dims"].to(device=device)

    if "irrep_indices" not in rep_x.attributes:
        # Index of the irrep copy owning each spectral coordinate: irrep dims [3, 2, 4] -> [0,0,0,1,1,2,2,2,2].
        indices = torch.arange(len(irrep_dims), device=irrep_dims.device)
        rep_x.attributes["irrep_indices"] = torch.repeat_interleave(indices, irrep_dims)
    irrep_indices = rep_x.attributes["irrep_indices"].to(device=device)

    # Sum the spectral variances inside each irrep copy, e.g. for indices [0,0,0,1,1,2,2,2,2]:
    # avg_vars[0] = v0+v1+v2, avg_vars[1] = v3+v4, avg_vars[2] = v5+v6+v7+v8; then divide by the copy dimension.
    avg_vars = torch.zeros(len(irrep_dims), device=device, dtype=var_spectral.dtype)
    avg_vars.scatter_add_(0, irrep_indices, var_spectral)
    avg_vars = avg_vars / irrep_dims

    # Broadcast back to full spectral dimension and map to the original basis.
    var_spectral = avg_vars[irrep_indices]
    return torch.einsum("ij,j->i", Q_squared, var_spectral)
def var_mean(x: Tensor, rep_x: Representation):
    r"""Compute :func:`var` and :func:`mean` under symmetry constraints.

    The invariant mean is computed once and reused as the center of the variance, so the mean is not estimated a
    second time inside :func:`var`.

    Args:
        x: (:class:`~torch.Tensor`) samples with shape :math:`(N,D_x)` or :math:`(N,D_x,T)`.
        rep_x: (:class:`~escnn.group.Representation`) representation :math:`\rho_{\mathcal{X}}`.

    Returns:
        (:class:`~torch.Tensor`, :class:`~torch.Tensor`): Tuple ``(var, mean)`` where mean is projected to
        :math:`\mathcal{X}^{\text{inv}}` and variance satisfies irrep-wise isotropy in spectral basis.

    Shape:
        - **x**: :math:`(N,D_x)` or :math:`(N,D_x,T)`.
        - **Output**: ``(var, mean)`` both with shape :math:`(D_x,)`.

    Note:
        This function reuses the same caches as :func:`mean` and :func:`var` when called repeatedly with the same
        representation object ``rep_x``.
    """
    invariant_mean = mean(x, rep_x)
    return var(x, rep_x, center=invariant_mean), invariant_mean
def cov(
    x: Tensor,
    y: Tensor,
    rep_x: Representation,
    rep_y: Representation,
    uncentered: bool = False,
):
    r"""Compute symmetry-aware cross-covariance.

    Let :math:`\mathbf{X}:\Omega\to\mathcal{X}` and :math:`\mathbf{Y}:\Omega\to\mathcal{Y}` be two
    :math:`\mathbb{G}`-invariant random variables endowed with the :math:`\mathbb{G}` representations
    :math:`\rho_{\mathcal{X}}` and :math:`\rho_{\mathcal{Y}}` respectively. This function computes the
    symmetry-aware cross-covariance which by construction is a :math:`\mathbb{G}`-equivariant linear map in
    :math:`\mathrm{Hom}_{\mathbb{G}}(\rho_{\mathcal{X}},\rho_{\mathcal{Y}})`.

    Implementation: To achieve this we first compute the symmetry-agnostic empirical covariance
    :math:`\mathbf{C}^{\text{raw}}_{yx} = \frac{1}{N-1}\sum_{n=1}^{N}\mathbf{y}^{\star}_n (\mathbf{x}^{\star}_n)^\top`.
    By default (:attr:`uncentered=False`), centered variables use invariant means from :func:`mean`:
    :math:`\mathbf{x}^{\star}_n = \mathbf{x}_n - \boldsymbol{\mu}_x`,
    :math:`\mathbf{y}^{\star}_n = \mathbf{y}_n - \boldsymbol{\mu}_y`, with
    :math:`\boldsymbol{\mu}_x = \widehat{\mathbb{E}}_{\mathbb{G}}[\mathbf{X}]`,
    :math:`\boldsymbol{\mu}_y = \widehat{\mathbb{E}}_{\mathbb{G}}[\mathbf{Y}]`.
    If :attr:`uncentered=True`, :math:`\mathbf{x}^{\star}_n=\mathbf{x}_n` and :math:`\mathbf{y}^{\star}_n=\mathbf{y}_n`.

    The returned covariance is the orthogonal projection of :math:`\mathbf{C}^{\text{raw}}_{yx}` onto
    :math:`\mathrm{Hom}_{\mathbb{G}}(\rho_{\mathcal{X}},\rho_{\mathcal{Y}})` via
    :func:`~symm_learning.linalg.equiv_orthogonal_projection`:

    .. math::
        \mathbf{C}_{yx} = \Pi_{\mathrm{Hom}_{\mathbb{G}}}(\mathbf{C}^{\text{raw}}_{yx}).

    This orthogonal projector is equivalent to the Reynolds/group-average operator:

    .. math::
        \Pi_{\mathrm{Hom}_{\mathbb{G}}}(\mathbf{A}) = \frac{1}{|\mathbb{G}|}\sum_{g\in\mathbb{G}}
        \rho_{\mathcal{Y}}(g)\,\mathbf{A}\,\rho_{\mathcal{X}}(g^{-1}).

    Args:
        x (:class:`~torch.Tensor`): Samples of :math:`\mathbf{X}`.
        y (:class:`~torch.Tensor`): Samples of :math:`\mathbf{Y}`, paired row-by-row with ``x``.
        rep_x (:class:`~escnn.group.Representation`): Representation :math:`\rho_{\mathcal{X}}`.
        rep_y (:class:`~escnn.group.Representation`): Representation :math:`\rho_{\mathcal{Y}}`.
        uncentered (:class:`bool`): If ``False`` (default), subtract invariant means before covariance computation.
            If ``True``, compute the uncentered second moment and project it.

    Returns:
        :class:`~torch.Tensor`: Projected cross-covariance :math:`\mathbf{C}_{yx}` with shape :math:`(D_y, D_x)`.

    Raises:
        AssertionError: If ``x``/``y`` are not 2D, their feature dimensions do not match the representation sizes,
            or they do not have the same number of samples.

    Shape:
        - **x**: :math:`(N, D_x)`. With `N` denoting the number of samples and `D_x` the dimension of the
          representation space of `x`.
        - **y**: :math:`(N, D_y)`. With `D_y` the dimension of the representation space of `y`; `N` must match `x`.
        - **Output**: :math:`(D_y, D_x)`.

    Note:
        With :math:`N=1` the :math:`1/(N-1)` normalization divides by zero.
    """
    # Only paired 2D sample matrices are supported: the normalization below uses shape[0] as N, and the centering
    # subtracts a (D,) mean along the last axis.
    assert x.ndim == 2 and x.shape[1] == rep_x.size, f"Expected X shape (N, {rep_x.size}), got {x.shape}"
    assert y.ndim == 2 and y.shape[1] == rep_y.size, f"Expected Y shape (N, {rep_y.size}), got {y.shape}"
    # Without this check a single-sample y (or x) would silently broadcast against the other inside the einsum.
    assert x.shape[0] == y.shape[0], f"Expected X and Y to have the same number of samples, got {x.shape[0]} and {y.shape[0]}"

    if uncentered:
        x_eff, y_eff = x, y
    else:
        x_eff = x - mean(x, rep_x)
        y_eff = y - mean(y, rep_y)

    # Empirical cross-covariance (centered or uncentered) in original coordinates: [D_y, D_x]
    n_samples = x_eff.shape[0]
    Cyx_raw = torch.einsum("ny,nx->yx", y_eff, x_eff) / (n_samples - 1)

    # Orthogonal projection to Hom_G(rep_x, rep_y), preserving shape [D_y, D_x].
    return equiv_orthogonal_projection(W=Cyx_raw, rep_x=rep_x, rep_y=rep_y)