# Source code for symm_learning.linalg

"""Symmetric Learning - Linear Algebra Utilities.

Utility functions for linear algebra operations on symmetric vector spaces with known
group representations.

Functions
---------
lstsq
    Least squares solution constrained to equivariant linear maps.
invariant_orthogonal_projector
    Orthogonal projection onto the G-invariant subspace.
irrep_radii
    Compute Euclidean radius of each irreducible subspace.
isotypic_signal2irreducible_subspaces
    Flatten isotypic signals into irreducible subspace components.
"""

import numpy as np
import torch
from escnn.group import Representation


def isotypic_signal2irreducible_subspaces(x: torch.Tensor, rep_x: "Representation"):
    r"""Flatten an isotypic signal into its irreducible-subspace coordinates.

    Assumes :math:`\mathcal{X}` is a single isotypic subspace of type :math:`k`,
    i.e. :math:`\rho_{\mathcal{X}} = \bigoplus_{i\in[1,n_k]} \hat{\rho}_k`. Given
    an input of shape :math:`(n, n_k \cdot d_k)`, where
    :math:`d_k=\dim(\hat{\rho}_k)`, the coordinates are rearranged to shape
    :math:`(n \cdot d_k, n_k)` so that each column collects one irrep copy across
    the sample axis:

    .. math::
        \mathbf{z}[:, i] = [x_{1,i,1}, \ldots, x_{1,i,d_k},
        x_{2,i,1}, \ldots, x_{n,i,d_k}]^\top.

    Args:
        x (:class:`~torch.Tensor`): Signal of shape :math:`(n, n_k \cdot d_k)`.
        rep_x (:class:`~escnn.group.Representation`): Representation in isotypic
            basis with a single active irrep type.

    Returns:
        :class:`~torch.Tensor`: Flattened signal of shape :math:`(n \cdot d_k, n_k)`.
    """
    assert len(rep_x._irreps_multiplicities) == 1, "Random variable is assumed to be in a single isotypic subspace."
    irrep_id = rep_x.irreps[0]
    d = rep_x.group.irrep(*irrep_id).size  # dimension of the (single) irrep type
    multiplicity = rep_x._irreps_multiplicities[irrep_id]  # number of irrep copies in X
    # (n, m*d) -> (n, m, d) -> (n, d, m) -> (n*d, m)
    flattened = x.view(-1, multiplicity, d).transpose(1, 2).reshape(-1, multiplicity)
    assert flattened.shape == (x.shape[0] * d, multiplicity)
    return flattened
def irrep_radii(x: torch.Tensor, rep: "Representation") -> torch.Tensor:
    r"""Compute Euclidean radii for all irreducible-subspace features.

    The input is first mapped to the irrep-spectral basis of ``rep``,
    :math:`\hat{\mathbf{x}}=\mathbf{Q}^{-1}\mathbf{x}`, after which the radius of
    each irrep copy is

    .. math::
        r_{k,i} = \lVert \hat{\mathbf{x}}_{k,i} \rVert_2.

    Args:
        x (:class:`~torch.Tensor`): Shape :math:`(..., D)`, vectors transforming
            according to ``rep``.
        rep (:class:`~escnn.group.Representation`): Representation acting on the
            last dimension of ``x``.

    Returns:
        :class:`~torch.Tensor`: Radii of shape :math:`(..., N)` with
        :math:`N=\texttt{len(rep.irreps)}`; output order follows ``rep.irreps``
        (one radius per irreducible copy).

    Note:
        The inverse change of basis :math:`\mathbf{Q}^{-1}` is cached in
        ``rep.attributes["Q_inv"]`` on first use and reused on later calls.
    """
    if x.shape[-1] != rep.size:
        raise ValueError(f"Expected last dimension {rep.size}, got {x.shape[-1]}")

    # Cache the inverse change of basis on the representation for reuse.
    if "Q_inv" not in rep.attributes:
        Q_inv = torch.tensor(rep.change_of_basis_inv, device=x.device, dtype=x.dtype)
        rep.attributes["Q_inv"] = Q_inv
    else:
        Q_inv = rep.attributes["Q_inv"].to(x.device, x.dtype)

    # Move to the irrep-spectral basis.
    x_spectral = torch.einsum("ij,...j->...i", Q_inv, x)

    # Sum squared coordinates inside each irrep block via scatter-add.
    n_subspaces = len(rep.irreps)
    labels = _get_irrep_subspace_index(rep).to(x.device)
    squares = x_spectral.reshape(-1, rep.size).pow(2)
    batch = squares.size(0)
    idx = labels.view(1, -1).expand(batch, -1)  # broadcast block labels across batch
    accum = torch.zeros(batch, n_subspaces, device=x.device, dtype=x.dtype)
    accum.scatter_add_(1, idx, squares)
    return accum.sqrt().reshape(*x_spectral.shape[:-1], n_subspaces)
def lstsq(x: torch.Tensor, y: torch.Tensor, rep_x: "Representation", rep_y: "Representation"):
    r"""Computes a solution to the least squares problem of a system of linear equations with equivariance constraints.

    The :math:`\mathbb{G}`-equivariant least squares problem for the linear system
    :math:`\mathbf{Y} = \mathbf{A}\,\mathbf{X}` is defined as:

    .. math::
        \begin{align}
        &\operatorname{argmin}_{\mathbf{A}} \| \mathbf{Y} - \mathbf{A}\,\mathbf{X} \|_F \\
        & \text{s.t.} \quad \rho_{\mathcal{Y}}(g) \mathbf{A} = \mathbf{A}\rho_{\mathcal{X}}(g)
        \quad \forall g \in \mathbb{G},
        \end{align}

    where :math:`\rho_{\mathcal{X}}`, :math:`\rho_{\mathcal{Y}}` are the (possibly
    decomposable) representations acting on :math:`\mathcal{X}` and :math:`\mathcal{Y}`.

    Args:
        x (:class:`~torch.Tensor`): Realizations of :math:`\mathbf{X}` with shape
            :math:`(N, D_x)`, where :math:`N` is the number of samples.
        y (:class:`~torch.Tensor`): Realizations of :math:`\mathbf{Y}` with shape :math:`(N, D_y)`.
        rep_x (:class:`~escnn.group.Representation`): Representation :math:`\rho_{\mathcal{X}}`.
        rep_y (:class:`~escnn.group.Representation`): Representation :math:`\rho_{\mathcal{Y}}`.

    Returns:
        :class:`~torch.Tensor`: A :math:`(D_y \times D_x)` matrix :math:`\mathbf{A}` satisfying
        the equivariance constraint and minimizing :math:`\|\mathbf{Y} - \mathbf{A}\,\mathbf{X}\|^2`.

    Note:
        This function calls :func:`~symm_learning.representation_theory.isotypic_decomp_rep`,
        which caches decompositions in the group representation registry, so repeated
        calls with the same representations reuse cached decompositions.
    """
    from symm_learning.representation_theory import isotypic_decomp_rep

    assert x.shape[-1] == rep_x.size, f"Expected X shape (..., {rep_x.size}), got {x.shape}"
    assert y.shape[-1] == rep_y.size, f"Expected Y shape (..., {rep_y.size}), got {y.shape}"

    rep_X_iso = isotypic_decomp_rep(rep_x)
    rep_Y_iso = isotypic_decomp_rep(rep_y)
    # Changes of basis from the disentangled/isotypic basis of X and Y to the original basis.
    Qx = torch.tensor(rep_X_iso.change_of_basis, device=x.device, dtype=x.dtype)
    Qy = torch.tensor(rep_Y_iso.change_of_basis, device=y.device, dtype=y.dtype)
    rep_X_iso_subspaces = rep_X_iso.attributes["isotypic_reps"]
    rep_Y_iso_subspaces = rep_Y_iso.attributes["isotypic_reps"]

    # Index slices of the isotypic subspaces of each type in the input/output representations.
    iso_idx_X, iso_idx_Y = {}, {}
    x_dim = 0
    for iso_id, rep_k in rep_X_iso_subspaces.items():
        iso_idx_X[iso_id] = slice(x_dim, x_dim + rep_k.size)
        x_dim += rep_k.size
    y_dim = 0
    for iso_id, rep_k in rep_Y_iso_subspaces.items():
        iso_idx_Y[iso_id] = slice(y_dim, y_dim + rep_k.size)
        y_dim += rep_k.size

    # Project data to the isotypic basis.
    x_iso = torch.einsum("ij,...j->...i", Qx.T, x)
    y_iso = torch.einsum("ij,...j->...i", Qy.T, y)

    A_iso = torch.zeros((rep_y.size, rep_x.size), dtype=x.dtype, device=x.device)
    for iso_id in rep_Y_iso_subspaces:
        if iso_id not in rep_X_iso_subspaces:
            continue  # No covariance between the isotypic subspaces of different types.
        x_k = x_iso[..., iso_idx_X[iso_id]]
        y_k = y_iso[..., iso_idx_Y[iso_id]]
        rep_X_k = rep_X_iso_subspaces[iso_id]
        rep_Y_k = rep_Y_iso_subspaces[iso_id]
        # Empirical (unconstrained) least squares per isotypic block, then
        # projection onto the space of equivariant maps between the blocks.
        A_k_emp = torch.linalg.lstsq(x_k, y_k).solution.T
        A_k = _project_to_irrep_endomorphism_basis(A_k_emp, rep_X_k, rep_Y_k)
        A_iso[iso_idx_Y[iso_id], iso_idx_X[iso_id]] = A_k

    # Change back to the original input/output bases.
    A = Qy @ A_iso @ Qx.T
    return A
def invariant_orthogonal_projector(rep_x: "Representation") -> torch.Tensor:
    r"""Computes the orthogonal projection to the invariant subspace.

    The representation is decomposed in the irrep-spectral basis,

    .. math::
        \rho_{\mathcal{X}} = \mathbf{Q}\left( \bigoplus_{k} \bigoplus_{i\in[1,n_k]}
        \hat{\rho}_k \right)\mathbf{Q}^T,

    and a diagonal selector :math:`\mathbf{S}` is built with :math:`S_{jj}=1` iff
    coordinate :math:`j` belongs to a trivial irrep copy. The projector onto the
    invariant subspace
    :math:`\mathcal{X}^{\text{inv}}=\{\mathbf{x}: \rho_{\mathcal{X}}(g)\mathbf{x}=\mathbf{x},
    \forall g\in\mathbb{G}\}` is then

    .. math::
        \mathbf{P}_{\mathrm{inv}} = \mathbf{Q}\,\mathbf{S}\,\mathbf{Q}^T.

    Args:
        rep_x (:class:`~escnn.group.Representation`): The representation for which the
            orthogonal projection to the invariant subspace is computed.

    Returns:
        :class:`~torch.Tensor`: The orthogonal projection matrix
        :math:`\mathbf{Q} \mathbf{S} \mathbf{Q}^T` of shape
        :math:`(\dim\mathcal{X}, \dim\mathcal{X})`.
    """
    Qx_T = torch.as_tensor(rep_x.change_of_basis_inv)
    Qx = torch.as_tensor(rep_x.change_of_basis)
    # S indicates which spectral-basis coordinates belong to a trivial irrep.
    # Match Qx's dtype (typically float64 from numpy) so the matmul below does
    # not mix float32 with float64.
    S = torch.zeros((rep_x.size, rep_x.size), dtype=Qx.dtype)
    trivial_id = rep_x.group.trivial_representation.id
    cum_dim = 0
    for irrep_id in rep_x.irreps:
        irrep = rep_x.group.irrep(*irrep_id)
        if irrep_id == trivial_id:
            # Set the diagonal entries of this (trivial) irrep block to 1.
            irrep_dims = range(cum_dim, cum_dim + irrep.size)
            S[irrep_dims, irrep_dims] = 1
        cum_dim += irrep.size
    inv_projector = Qx @ S @ Qx_T
    return inv_projector
def _project_to_irrep_endomorphism_basis( A: torch.Tensor, rep_x: Representation, rep_y: Representation, ) -> torch.Tensor: r"""Projects a linear map A: X -> Y between two isotypic spaces to the space of equivariant linear maps. Given a linear map :math:`A: X \to Y`, where :math:`X` and :math:`Y` are isotypic vector spaces of the same type, that is, their representations are built from :math:`m_x` and :math:`m_y` copies of the same irrep, this function projects the linear map to the space of equivariant linear maps between the two isotypic spaces. Args: A (:class:`~torch.Tensor`): The linear map to be projected, of shape :math:`(m_y \cdot d, m_x \cdot d)`, where :math:`d` is the dimension of the irreducible representation, and :math:`m_x` and :math:`m_y` are the multiplicities of the irreducible representation in :math:`X` and :math:`Y`, respectively. rep_x (:class:`~escnn.group.Representation`): The representation of the isotypic space :math:`X`. rep_y (:class:`~escnn.group.Representation`): The representation of the isotypic space :math:`Y`. Returns: A_equiv (:class:`~torch.Tensor`): A projected linear map of shape :math:`(m_y \cdot d, m_x \cdot d)` which commutes with the action of the group on the isotypic spaces :math:`X` and :math:`Y`. That is: :math:`A_{equiv} \circ \rho_X(g) = \rho_Y(g) \circ A_{equiv}` for all :math:`g \in \mathbb{G}`. """ irrep_id = rep_x.irreps[0] irrep = rep_x.group.irrep(*irrep_id) assert A.shape == (rep_y.size, rep_x.size), "Expected A: X -> Y" assert len(rep_x._irreps_multiplicities) == 1, f"Expected rep with a single irrep type, got {rep_x.irreps}" assert len(rep_y._irreps_multiplicities) == 1, f"Expected rep with a single irrep type, got {rep_y.irreps}" assert irrep_id == rep_y.irreps[0], f"Irreps {irrep_id} != {rep_y.irreps[0]}. 
Hence A=0" x_in_iso_basis = np.allclose(rep_x.change_of_basis_inv, np.eye(rep_x.size), atol=1e-6, rtol=1e-4) assert x_in_iso_basis, "Expected X to be in spectral/isotypic basis" y_in_iso_basis = np.allclose(rep_y.change_of_basis_inv, np.eye(rep_y.size), atol=1e-6, rtol=1e-4) assert y_in_iso_basis, "Expected Y to be in spectral/isotypic basis" m_x = rep_x._irreps_multiplicities[irrep_id] # Multiplicity of the irrep in X m_y = rep_y._irreps_multiplicities[irrep_id] # Multiplicity of the irrep in Y # Get the basis of endomorphisms of the irrep (B, d, d) B = 1 | 2 | 4 irrep_end_basis = torch.tensor(irrep.endomorphism_basis(), device=A.device, dtype=A.dtype) A_irreps = A.view(m_y, irrep.size, m_x, irrep.size).permute(0, 2, 1, 3).contiguous() # Compute basis expansion coefficients of each irrep cross-covariance in basis of End(irrep) ======== # Frobenius inner product <C , Ψ_b> = Σ_{i,j} C_{ij} Ψ_b,ij A_irreps_basis_coeff = torch.einsum("mnij,bij->mnb", A_irreps, irrep_end_basis) # (m_y , m_x , B) # squared norms ‖Ψ_b‖² (only once, very small) basis_coeff_norms = torch.einsum("bij,bij->b", irrep_end_basis, irrep_end_basis) # (B,) A_irreps_basis_coeff = A_irreps_basis_coeff / basis_coeff_norms[None, None] A_irreps = torch.einsum("...b,bij->...ij", A_irreps_basis_coeff, irrep_end_basis) # (m_y , m_x , d , d) # Reshape to (my * d, mx * d) A_equiv = A_irreps.permute(0, 2, 1, 3).reshape(m_y * irrep.size, m_x * irrep.size) return A_equiv def _get_irrep_subspace_index(rep: Representation): labels = torch.empty(rep.size, dtype=torch.long) irreps_dims = {id: rep.group.irrep(*id).size for id in rep._irreps_multiplicities} pos = 0 for count, irrep_id in enumerate(rep.irreps): labels[pos : pos + irreps_dims[irrep_id]] = count # fill contiguous block for this copy pos += irreps_dims[irrep_id] return labels