# Source code for symm_learning.linalg

"""Symmetric Learning - Linear Algebra Utilities.

Utility functions for linear algebra operations on symmetric vector spaces with known
group representations.

Functions
---------
lstsq
    Least squares solution constrained to equivariant linear maps.
invariant_orthogonal_projector
    Orthogonal projection onto the G-invariant subspace.
equiv_orthogonal_projection_coefficients
    Orthogonal projection onto Hom_G returned in the flattened homomorphism basis.
equiv_linear_map
    Construct a dense equivariant linear map from flattened homomorphism-basis coefficients.
equiv_orthogonal_projection
    Orthogonal projection onto Hom_G using precomputed isotypic-basis tensors.
irrep_radii
    Compute Euclidean radius of each irreducible subspace.
project_in_isobasis
    Project a linear map onto Hom_G in isotypic coordinates, returning per-irrep coefficients and metadata.
isotypic_signal2irreducible_subspaces
    Flatten isotypic signals into irreducible subspace components.
"""

from __future__ import annotations

from collections.abc import Mapping
from typing import NamedTuple, TypedDict

import numpy as np
import torch
from escnn.group import Representation

from symm_learning.utils import _cached_rep_matrix


class IsoSpaceProjection(TypedDict):
    """Per-irrep projection metadata returned by :func:`project_in_isobasis`."""

    # Projection coefficients Theta^(k); shape (..., m_out * m_in, S_k).
    coeff: torch.Tensor
    # Stacked flattened endomorphism basis of the irrep; shape (S_k, d_k**2).
    endo_basis_flat: torch.Tensor
    # Multiplicity of the irrep type in the output representation.
    m_out: int
    # Multiplicity of the irrep type in the input representation.
    m_in: int
    # Dimension of the irreducible representation.
    d_k: int
    # Row slice locating this irrep's block inside the isotypic-basis operator.
    out_slice: slice
    # Column slice locating this irrep's block inside the isotypic-basis operator.
    in_slice: slice


class IsotypicTensorCache(NamedTuple):
    """Reusable tensors for repeated isotypic-basis projection and synthesis calls."""

    # Change of basis from isotypic to original output coordinates; validated
    # against shape (rep_Y_iso.size, rep_Y_iso.size).
    Q_out: torch.Tensor
    # Inverse change of basis for the input space; validated against shape
    # (rep_X_iso.size, rep_X_iso.size).
    Q_in_inv: torch.Tensor
    # Flattened endomorphism basis per irrep id; each entry has shape (S_k, d_k**2).
    endo_basis_flat: dict[tuple[int, ...], torch.Tensor]
    # Squared Frobenius norms of each endomorphism basis element, shape (S_k,).
    # Optional: only required when coefficients are normalized (see project_in_isobasis).
    endo_norm_sq: dict[tuple[int, ...], torch.Tensor] | None = None


def _validate_tensor_cache(
    tensor_cache: IsotypicTensorCache,
    rep_X_iso: Representation,
    rep_Y_iso: Representation,
    require_endo_norm_sq: bool,
) -> tuple[torch.Tensor, torch.Tensor, dict[tuple[int, ...], torch.Tensor], dict[tuple[int, ...], torch.Tensor] | None]:
    """Validate a user-supplied :class:`IsotypicTensorCache` and unpack its fields.

    Checks that both change-of-basis entries are tensors of the shapes implied by the
    isotypic representations, that the endomorphism-basis entry is a mapping, and —
    when ``require_endo_norm_sq`` is set — that the squared-norm mapping is present.
    The fields are returned unchanged as ``(Q_out, Q_in_inv, endo_basis_flat, endo_norm_sq)``.
    """
    q_out = tensor_cache.Q_out
    q_in_inv = tensor_cache.Q_in_inv
    basis_map = tensor_cache.endo_basis_flat
    norm_map = tensor_cache.endo_norm_sq

    if not (isinstance(q_out, torch.Tensor) and isinstance(q_in_inv, torch.Tensor)):
        raise TypeError("tensor_cache.Q_out and tensor_cache.Q_in_inv must be torch.Tensor instances")

    out_shape = (rep_Y_iso.size, rep_Y_iso.size)
    if q_out.shape != out_shape:
        raise ValueError(f"Expected tensor_cache.Q_out shape {out_shape}, got {q_out.shape}")
    in_shape = (rep_X_iso.size, rep_X_iso.size)
    if q_in_inv.shape != in_shape:
        raise ValueError(f"Expected tensor_cache.Q_in_inv shape {in_shape}, got {q_in_inv.shape}")

    if not isinstance(basis_map, Mapping):
        raise TypeError("tensor_cache.endo_basis_flat must be a mapping keyed by irrep id")
    if require_endo_norm_sq:
        if norm_map is None:
            raise ValueError("tensor_cache.endo_norm_sq is required for this operation")
        if not isinstance(norm_map, Mapping):
            raise TypeError("tensor_cache.endo_norm_sq must be a mapping keyed by irrep id")

    return q_out, q_in_inv, basis_map, norm_map


def isotypic_signal2irreducible_subspaces(x: torch.Tensor, rep_x: Representation):
    r"""Flatten a single-isotypic-type signal into per-irrep-copy columns.

    Assumes :math:`\mathcal{X}` is a single isotypic subspace of type :math:`k`, i.e.
    :math:`\rho_{\mathcal{X}} = \bigoplus_{i\in[1,m_k]} \hat{\rho}_k`. An input of shape
    :math:`(n, m_k \cdot d_k)`, with :math:`d_k=\dim(\hat{\rho}_k)`, is rearranged to shape
    :math:`(n \cdot d_k, m_k)` so that column :math:`i` holds the coordinates of the
    :math:`i`-th irrep copy stacked across the sample axis.

    Args:
        x (:class:`~torch.Tensor`): Shape :math:`(n, m_k \cdot d_k)`.
        rep_x (:class:`~escnn.group.Representation`): Representation in isotypic basis with a
            single active irrep type.

    Returns:
        :class:`~torch.Tensor`: Flattened irreducible-subspace signal of shape
        :math:`(n \cdot d_k, m_k)`.
    """
    assert len(rep_x._irreps_multiplicities) == 1, "Random variable is assumed to be in a single isotypic subspace."
    irrep_id = rep_x.irreps[0]
    dim = rep_x.group.irrep(*irrep_id).size
    mult = rep_x._irreps_multiplicities[irrep_id]  # multiplicity of the irrep in X
    # (n, mult * dim) -> (n, mult, dim) -> (n, dim, mult) -> (n * dim, mult)
    stacked = x.view(-1, mult, dim).transpose(-2, -1).reshape(-1, mult)
    assert stacked.shape == (x.shape[0] * dim, mult)
    return stacked
def irrep_radii(x: torch.Tensor, rep: Representation) -> torch.Tensor:
    r"""Euclidean norm of each irreducible-subspace component of ``x``.

    The input is first mapped to the irrep-spectral basis induced by the isotypic
    decomposition of ``rep``, :math:`\hat{\mathbf{x}} = \mathbf{Q}^{-1}\mathbf{x}`, and the
    2-norm of each irrep copy's coordinates is computed:

    .. math::
        r_{k,i} = \lVert \hat{\mathbf{x}}_{k,i} \rVert_2.

    Args:
        x (:class:`~torch.Tensor`): Shape :math:`(..., D)`; vectors transforming by ``rep``.
        rep (:class:`~escnn.group.Representation`): Representation acting on the last dimension.

    Returns:
        :class:`~torch.Tensor`: Radii of shape :math:`(..., N)` with
        :math:`N=\texttt{len(rep.irreps)}`, ordered as ``rep.irreps`` (one value per
        irreducible copy).

    Note:
        :math:`\mathbf{Q}^{-1}` is cached on ``rep`` (key ``"Q_inv"``) and reused across calls.
    """
    if x.shape[-1] != rep.size:
        raise ValueError(f"Expected last dimension {rep.size}, got {x.shape[-1]}")
    basis_inv = _cached_rep_matrix(rep=rep, key="Q_inv", matrix=rep.change_of_basis_inv, like=x)
    # Change to the irrep-spectral basis.
    spectral = torch.einsum("ij,...j->...i", basis_inv, x)

    block_dims = [rep.group.irrep(*irrep_id).size for irrep_id in rep.irreps]
    collapsed = spectral.reshape(-1, rep.size)
    # vector_norm has a stable subgradient at zero, unlike manual sqrt(sum(x^2)).
    norms = [
        torch.linalg.vector_norm(chunk, ord=2, dim=-1)
        for chunk in torch.split(collapsed, block_dims, dim=-1)
    ]
    return torch.stack(norms, dim=-1).reshape(*spectral.shape[:-1], len(block_dims))
def lstsq(x: torch.Tensor, y: torch.Tensor, rep_x: Representation, rep_y: Representation):
    r"""Least-squares solution of ``Y = A X`` constrained to equivariant linear maps.

    Solves the :math:`\mathbb{G}`-equivariant least squares problem

    .. math::
        \begin{align}
        &\operatorname{argmin}_{\mathbf{A}} \| \mathbf{Y} - \mathbf{A}\,\mathbf{X} \|_F \\
        & \text{s.t.} \quad \rho_{\mathcal{Y}}(g) \mathbf{A} = \mathbf{A}\rho_{\mathcal{X}}(g)
        \quad \forall g \in \mathbb{G},
        \end{align}

    where :math:`\rho_{\mathcal{X}}`, :math:`\rho_{\mathcal{Y}}` are the (possibly decomposable)
    representations acting on the sample spaces of ``x`` and ``y``.

    Args:
        x (:class:`~torch.Tensor`): Realizations of :math:`\mathbf{X}` with shape :math:`(N, D_x)`.
        y (:class:`~torch.Tensor`): Realizations of :math:`\mathbf{Y}` with shape :math:`(N, D_y)`.
        rep_x (:class:`~escnn.group.Representation`): Representation acting on :math:`\mathcal{X}`.
        rep_y (:class:`~escnn.group.Representation`): Representation acting on :math:`\mathcal{Y}`.

    Returns:
        :class:`~torch.Tensor`: A :math:`(D_y \times D_x)` matrix satisfying the equivariance
        constraint and minimizing :math:`\|\mathbf{Y} - \mathbf{A}\,\mathbf{X}\|^2`.

    Note:
        Calls :func:`~symm_learning.representation_theory.isotypic_decomp_rep`, which caches
        decompositions in the group representation registry; repeated calls with the same
        representations reuse them.
    """
    # Local import — presumably avoids a circular import at module load time; confirm upstream.
    from symm_learning.representation_theory import isotypic_decomp_rep

    # Fix: the previous revision asserted each of these shape conditions twice
    # (once via shape[1], once via shape[-1]); a single check per input suffices.
    assert x.shape[-1] == rep_x.size, f"Expected X shape (..., {rep_x.size}), got {x.shape}"
    assert y.shape[-1] == rep_y.size, f"Expected Y shape (..., {rep_y.size}), got {y.shape}"
    if x.device != y.device:
        raise ValueError(f"Expected x and y on same device, got {x.device} and {y.device}")

    # Work in a common dtype so mixed-precision inputs do not fail downstream.
    work_dtype = torch.promote_types(x.dtype, y.dtype)
    x_work = x if x.dtype == work_dtype else x.to(dtype=work_dtype)
    y_work = y if y.dtype == work_dtype else y.to(dtype=work_dtype)

    rep_X_iso = isotypic_decomp_rep(rep_x)
    rep_Y_iso = isotypic_decomp_rep(rep_y)
    # Changes of basis from the disentangled/isotypic basis of X and Y to the original basis.
    Qx = _cached_rep_matrix(rep=rep_X_iso, key="Q", matrix=rep_X_iso.change_of_basis, like=x_work)
    Qy = _cached_rep_matrix(rep=rep_Y_iso, key="Q", matrix=rep_Y_iso.change_of_basis, like=x_work)
    rep_X_iso_subspaces = rep_X_iso.attributes["isotypic_reps"]
    rep_Y_iso_subspaces = rep_Y_iso.attributes["isotypic_reps"]

    # Locate each isotypic subspace inside the isotypic-basis coordinates.
    iso_idx_X, iso_idx_Y = {}, {}
    x_dim = 0
    for iso_id, rep_k in rep_X_iso_subspaces.items():
        iso_idx_X[iso_id] = slice(x_dim, x_dim + rep_k.size)
        x_dim += rep_k.size
    y_dim = 0
    for iso_id, rep_k in rep_Y_iso_subspaces.items():
        iso_idx_Y[iso_id] = slice(y_dim, y_dim + rep_k.size)
        y_dim += rep_k.size

    x_iso = torch.einsum("ij,...j->...i", Qx.T, x_work)
    y_iso = torch.einsum("ij,...j->...i", Qy.T, y_work)
    A_iso = torch.zeros((rep_y.size, rep_x.size), dtype=work_dtype, device=x_work.device)
    for iso_id in rep_Y_iso_subspaces:
        if iso_id not in rep_X_iso_subspaces:
            continue  # No covariance between isotypic subspaces of different types.
        x_k = x_iso[..., iso_idx_X[iso_id]]
        y_k = y_iso[..., iso_idx_Y[iso_id]]
        rep_X_k = rep_X_iso_subspaces[iso_id]
        rep_Y_k = rep_Y_iso_subspaces[iso_id]
        # Unconstrained per-block least squares, then projection onto the equivariant subspace.
        A_k_emp = torch.linalg.lstsq(x_k, y_k).solution.T
        A_k = _project_to_irrep_endomorphism_basis(A_k_emp, rep_X_k, rep_Y_k)
        A_iso[iso_idx_Y[iso_id], iso_idx_X[iso_id]] = A_k

    # Change back to the original input/output bases.
    A = Qy @ A_iso @ Qx.T
    return A
def invariant_orthogonal_projector(
    rep_x: Representation, device: torch.device | str | None = None, dtype: torch.dtype | None = None
) -> torch.Tensor:
    r"""Orthogonal projector onto the :math:`\mathbb{G}`-invariant subspace of ``rep_x``.

    In the irrep-spectral basis
    :math:`\rho_{\mathcal{X}} = \mathbf{Q}\left(\bigoplus_{k}\bigoplus_{i} \hat{\rho}_k\right)\mathbf{Q}^T`,
    the invariant subspace is spanned exactly by the coordinates of trivial-irrep copies.
    With :math:`\mathbf{S}` the diagonal selector of those coordinates, the projector is

    .. math::
        \mathbf{P}_{\mathrm{inv}} = \mathbf{Q}\,\mathbf{S}\,\mathbf{Q}^T,

    which satisfies
    :math:`\rho_{\mathcal{X}}(g)\,\mathbf{P}_{\mathrm{inv}} = \mathbf{P}_{\mathrm{inv}}\,\rho_{\mathcal{X}}(g)
    = \mathbf{P}_{\mathrm{inv}}` for all :math:`g\in\mathbb{G}`.

    Args:
        rep_x (:class:`~escnn.group.Representation`): Representation whose invariant subspace is
            projected onto.
        device (:class:`~torch.device`, optional): Device of the returned projector. If ``None``,
            uses CPU.
        dtype (:class:`~torch.dtype`, optional): Data type of the returned projector. If ``None``,
            uses ``torch.get_default_dtype()``.

    Returns:
        :class:`~torch.Tensor`: The orthogonal projection matrix
        :math:`\mathbf{Q}\,\mathbf{S}\,\mathbf{Q}^T`.
    """
    device = torch.device("cpu") if device is None else torch.device(device)
    dtype = torch.get_default_dtype() if dtype is None else dtype
    Qx_T = torch.as_tensor(rep_x.change_of_basis_inv, device=device, dtype=dtype)
    Qx = torch.as_tensor(rep_x.change_of_basis, device=device, dtype=dtype)

    # S indicates which irrep-spectral coordinates belong to trivial irrep copies.
    # (Fix: dropped an `irreps_dimension` accumulator that was built but never used.)
    S = torch.zeros((rep_x.size, rep_x.size), device=device, dtype=dtype)
    trivial_id = rep_x.group.trivial_representation.id
    cum_dim = 0
    for irrep_id in rep_x.irreps:
        irrep = rep_x.group.irrep(*irrep_id)
        if irrep_id == trivial_id:
            # Explicit index tensor instead of a raw `range` object for the fancy indexing.
            idx = torch.arange(cum_dim, cum_dim + irrep.size, device=device)
            S[idx, idx] = 1
        cum_dim += irrep.size

    return Qx @ S @ Qx_T
def equiv_orthogonal_projection(
    W: torch.Tensor,
    rep_x: Representation,
    rep_y: Representation,
    tensor_cache: IsotypicTensorCache | None = None,
) -> torch.Tensor:
    r"""Orthogonally project a linear map onto :math:`\mathrm{Hom}_{\mathbb{G}}(\rho_{\mathcal{X}},\rho_{\mathcal{Y}})`.

    Returns the Frobenius-orthogonal projection of ``W`` onto the space of
    :math:`\mathbb{G}`-equivariant linear maps,

    .. math::
        \Pi_{\mathrm{Hom}_{\mathbb{G}}}(\mathbf{W}) =
        \operatorname*{argmin}_{\mathbf{A}\in\mathrm{Hom}_{\mathbb{G}}(\rho_{\mathcal{X}},\rho_{\mathcal{Y}})}
        \|\mathbf{W}-\mathbf{A}\|_F,

    computed blockwise in the isotypic decomposition of both representations. This agrees with
    the Reynolds/group-average operator while avoiding iteration over group elements.

    Args:
        W (:class:`~torch.Tensor`): Dense map (or batch of maps) of shape :math:`(..., D_y, D_x)`.
        rep_x (:class:`~escnn.group.Representation`): Input representation.
        rep_y (:class:`~escnn.group.Representation`): Output representation.
        tensor_cache (:class:`IsotypicTensorCache`, optional): Precomputed tensors forwarded to
            :func:`project_in_isobasis`; when provided, all required tensors must be present.

    Returns:
        :class:`~torch.Tensor`: Projected map(s) with the same shape, dtype, and device as ``W``.
    """
    Q_out, Q_in_inv, iso_projections = project_in_isobasis(
        W=W, rep_x=rep_x, rep_y=rep_y, tensor_cache=tensor_cache
    )
    batch_shape = W.shape[:-2]
    batch_dims = tuple(range(len(batch_shape)))
    projected_iso = W.new_zeros(*batch_shape, rep_y.size, rep_x.size)
    for info in iso_projections.values():
        m_out, m_in, d_k = info["m_out"], info["m_in"], info["d_k"]
        # Synthesize the projected block from its endomorphism-basis coefficients, then
        # undo the flatten/permute to recover matrix layout in isotypic coordinates:
        # [..., m_out, m_in, d_k, d_k] -> [..., m_out*d_k, m_in*d_k].
        flat = info["coeff"].matmul(info["endo_basis_flat"])
        block = (
            flat.view(*batch_shape, m_out, m_in, d_k, d_k)
            .permute(*batch_dims, -4, -2, -3, -1)
            .reshape(*batch_shape, m_out * d_k, m_in * d_k)
        )
        projected_iso[..., info["out_slice"], info["in_slice"]] = block
    # Back to the original coordinates.
    return (Q_out @ projected_iso) @ Q_in_inv
def equiv_orthogonal_projection_coefficients(
    W: torch.Tensor,
    rep_x: Representation,
    rep_y: Representation,
    tensor_cache: IsotypicTensorCache | None = None,
) -> torch.Tensor:
    r"""Flattened homomorphism-basis coefficients of :math:`\Pi_{\mathrm{Hom}_{\mathbb{G}}}(\mathbf{W})`.

    Projects ``W`` onto :math:`\mathrm{Hom}_{\mathbb{G}}(\rho_{\mathcal{X}}, \rho_{\mathcal{Y}})`
    and returns only the per-irrep coefficient tensors :math:`\mathbf{\Theta}^{(k)}` of shape
    :math:`(m_k^{\mathcal{Y}} m_k^{\mathcal{X}}, S_k)`, flattened and concatenated over all
    common irrep types, with the endomorphism-basis index varying fastest.

    Args:
        W (:class:`~torch.Tensor`): Dense map (or batch of maps) of shape :math:`(..., D_y, D_x)`.
        rep_x (:class:`~escnn.group.Representation`): Input representation.
        rep_y (:class:`~escnn.group.Representation`): Output representation.
        tensor_cache (:class:`IsotypicTensorCache`, optional): Precomputed tensors forwarded to
            :func:`project_in_isobasis`; when provided, all required tensors must be present.

    Returns:
        :class:`~torch.Tensor`: Coefficient vector(s) of shape
        :math:`(..., \dim(\mathrm{Hom}_{\mathbb{G}}(\rho_{\mathcal{X}}, \rho_{\mathcal{Y}})))`.
    """
    _, _, iso_projections = project_in_isobasis(W=W, rep_x=rep_x, rep_y=rep_y, tensor_cache=tensor_cache)
    batch_shape = W.shape[:-2]
    if not iso_projections:
        # No shared irrep types: the homomorphism space is zero-dimensional.
        return W.new_zeros(*batch_shape, 0)
    pieces = []
    for info in iso_projections.values():
        pieces.append(info["coeff"].reshape(*batch_shape, -1))
    return torch.cat(pieces, dim=-1)
def equiv_linear_map(
    w_dof: torch.Tensor,
    rep_x: Representation,
    rep_y: Representation,
    tensor_cache: IsotypicTensorCache | None = None,
) -> torch.Tensor:
    r"""Expand flattened homomorphism-basis coefficients into a dense equivariant linear map.

    Companion of :func:`equiv_orthogonal_projection_coefficients`: given coefficients
    :math:`\boldsymbol{\theta}` in the blockwise basis of
    :math:`\mathrm{Hom}_{\mathbb{G}}(\rho_{\mathcal{X}}, \rho_{\mathcal{Y}})`, synthesize

    .. math::
        \mathbf{W} = \mathbf{Q}_{\mathcal{Y}} \left( \bigoplus_{k} \mathbf{W}^{(k)} \right)
        \mathbf{Q}_{\mathcal{X}}^T, \qquad
        \mathbf{W}^{(k)} = \sum_{s=1}^{S_k}\mathbf{\Theta}^{(k)}_s \otimes \mathbf{\Psi}^{(k)}_s,

    where :math:`\{\mathbf{\Psi}^{(k)}_s\}` is a basis of :math:`\mathrm{End}_{\mathbb{G}}(\hat{\rho}_k)`.
    The flattened ordering matches :func:`equiv_orthogonal_projection_coefficients`: within each
    common irrep type the endomorphism-basis index varies fastest, then the input multiplicity
    index, then the output multiplicity index.

    Args:
        w_dof (:class:`~torch.Tensor`): Flattened coefficients of shape
            :math:`(..., \dim(\mathrm{Hom}_{\mathbb{G}}(\rho_{\mathcal{X}}, \rho_{\mathcal{Y}})))`.
        rep_x (:class:`~escnn.group.Representation`): Input representation.
        rep_y (:class:`~escnn.group.Representation`): Output representation.
        tensor_cache (:class:`IsotypicTensorCache`, optional): Optional override containing
            ``Q_out``, ``Q_in_inv``, and ``endo_basis_flat`` keyed by irrep id. When provided,
            all required tensors must be present.

    Returns:
        :class:`~torch.Tensor`: Dense equivariant linear map(s) of shape :math:`(..., D_y, D_x)`.

    Raises:
        ValueError: If ``w_dof`` is 0-dimensional or its trailing dimension does not match the
            homomorphism-space dimension, or a required cache entry is missing or malformed.
        TypeError: If a cached basis entry is not a :class:`~torch.Tensor`.
    """
    if w_dof.ndim == 0:
        raise ValueError("Expected w_dof to have at least one dimension")
    rep_X_iso, rep_Y_iso, hom_iso_spaces, hom_dim = _hom_irrep_iso_spaces(rep_x=rep_x, rep_y=rep_y)
    if w_dof.shape[-1] != hom_dim:
        raise ValueError(f"Expected w_dof shape (..., {hom_dim}), got {tuple(w_dof.shape)}")
    if tensor_cache is None:
        # No cache supplied: pull (and memoize) change-of-basis tensors from the reps themselves.
        Q_out = _cached_rep_matrix(rep=rep_Y_iso, key="Q", matrix=rep_Y_iso.change_of_basis, like=w_dof)
        Q_in_inv = _cached_rep_matrix(rep=rep_X_iso, key="Q_inv", matrix=rep_X_iso.change_of_basis_inv, like=w_dof)
        endo_basis_flat_cache = None
    else:
        # endo_norm_sq is not needed for synthesis, only for projection.
        Q_out, Q_in_inv, endo_basis_flat_cache, _ = _validate_tensor_cache(
            tensor_cache=tensor_cache,
            rep_X_iso=rep_X_iso,
            rep_Y_iso=rep_Y_iso,
            require_endo_norm_sq=False,
        )
    leading_shape = w_dof.shape[:-1]
    batch_axes = tuple(range(len(leading_shape)))
    W_iso = w_dof.new_zeros(*leading_shape, rep_Y_iso.size, rep_X_iso.size)
    for iso_space in hom_iso_spaces:
        irrep_id = iso_space["irrep_id"]
        if tensor_cache is None:
            irrep = rep_X_iso.group.irrep(*irrep_id)
            endo_basis_flat, _ = _cached_irrep_endomorphism_basis(irrep=irrep, like=w_dof)
        else:
            # Validate the user-supplied per-irrep basis before using it.
            if irrep_id not in endo_basis_flat_cache:
                raise ValueError(f"tensor_cache['endo_basis_flat'] missing tensor for irrep {irrep_id}")
            endo_basis_flat = endo_basis_flat_cache[irrep_id]
            if not isinstance(endo_basis_flat, torch.Tensor):
                raise TypeError(f"tensor_cache['endo_basis_flat'][{irrep_id}] must be a torch.Tensor")
            if endo_basis_flat.ndim != 2 or endo_basis_flat.shape[1] != iso_space["d_k"] * iso_space["d_k"]:
                raise ValueError(
                    "Expected tensor_cache['endo_basis_flat']["
                    f"{irrep_id}] shape (S_k, {iso_space['d_k'] * iso_space['d_k']}), "
                    f"got {tuple(endo_basis_flat.shape)}"
                )
        m_out = iso_space["m_out"]
        m_in = iso_space["m_in"]
        d_k = iso_space["d_k"]
        out_slice = iso_space["out_slice"]
        in_slice = iso_space["in_slice"]
        hom_basis_slice = iso_space["hom_basis_slice"]
        # Coefficients for this irrep type: [..., m_out*m_in, S_k].
        theta_k = w_dof[..., hom_basis_slice].view(*leading_shape, m_out * m_in, endo_basis_flat.size(0))
        # Synthesize the block and undo the flatten/permute:
        # [..., m_out, m_in, d_k, d_k] -> [..., m_out*d_k, m_in*d_k].
        block_flat = theta_k.matmul(endo_basis_flat)
        block = block_flat.view(*leading_shape, m_out, m_in, d_k, d_k)
        block = block.permute(*batch_axes, -4, -2, -3, -1)
        block = block.reshape(*leading_shape, m_out * d_k, m_in * d_k)
        W_iso[..., out_slice, in_slice] = block
    # Back to the original coordinates.
    return (Q_out @ W_iso) @ Q_in_inv
def project_in_isobasis(
    W: torch.Tensor,
    rep_x: Representation,
    rep_y: Representation,
    tensor_cache: IsotypicTensorCache | None = None,
) -> tuple[torch.Tensor, torch.Tensor, dict[tuple[int, ...], IsoSpaceProjection]]:
    r"""Project ``W`` onto :math:`\mathrm{Hom}_{\mathbb{G}}(\rho_{\mathcal{X}}, \rho_{\mathcal{Y}})` in isotypic coordinates.

    Utility handling the projection of a dense linear map :math:`\mathbf{W}:\mathcal{X}\to\mathcal{Y}`
    onto the space of :math:`\mathbb{G}`-equivariant linear maps, returning the intermediate
    tensors needed for algebraic and statistical computations. In the isotypic basis,

    .. math::
        \mathbf{A}_{\mathrm{iso}} = \mathbf{Q}_{\mathcal{Y}}^\top \mathbf{A} \mathbf{Q}_{\mathcal{X}}
        = \bigoplus_{k} \mathbf{A}^{(k)}, \qquad
        \mathbf{A}^{(k)} = \sum_{s=1}^{S_k}\mathbf{\Theta}^{(k)}_s \otimes \mathbf{\Psi}^{(k)}_s,

    and the coefficients of the projected map in each isotypic block are

    .. math::
        \mathbf{\Theta}^{(k)}_{o,i,s} =
        \frac{\langle \mathbf{W}^{(k)}_{o,i}, \mathbf{\Psi}^{(k)}_s\rangle_F}
        {\lVert \mathbf{\Psi}^{(k)}_s \rVert_F^2}.

    Args:
        W (:class:`~torch.Tensor`): Dense linear map (or batch) of shape :math:`(..., D_y, D_x)`.
        rep_x (:class:`~escnn.group.Representation`): Input representation.
        rep_y (:class:`~escnn.group.Representation`): Output representation.
        tensor_cache (:class:`IsotypicTensorCache`, optional): Optional override containing
            ``Q_out``, ``Q_in_inv``, ``endo_basis_flat``, and ``endo_norm_sq`` keyed by irrep id.
            Entries are reused directly without copying, so they must already have dtype and
            device compatible with ``W``.

    Returns:
        tuple: ``(Q_out, Q_in_inv, projection_iso_spaces)`` where ``Q_out`` and ``Q_in_inv`` are
        the (possibly cached) change-of-basis tensors and ``projection_iso_spaces`` maps each
        shared irrep id to an :class:`IsoSpaceProjection` with the coefficients
        :math:`\mathbf{\Theta}^{(k)}` of shape :math:`(..., m_k^{\mathcal{Y}} m_k^{\mathcal{X}}, S_k)`,
        the flattened endomorphism basis of shape :math:`(S_k, d_k^2)`, the multiplicities, the
        irrep dimension, and the row/column slices locating the block in isotypic coordinates.

    Raises:
        ValueError: If ``W`` has the wrong trailing shape, or a required cache entry is missing
            or malformed.
        TypeError: If a cached entry is not a :class:`~torch.Tensor`.
    """
    if W.shape[-2:] != (rep_y.size, rep_x.size):
        raise ValueError(f"Expected W shape (..., {rep_y.size}, {rep_x.size}), got {tuple(W.shape)}")
    rep_X_iso, rep_Y_iso, hom_iso_spaces, _ = _hom_irrep_iso_spaces(rep_x=rep_x, rep_y=rep_y)
    if tensor_cache is None:
        # No cache supplied: pull (and memoize) tensors from the reps themselves.
        Q_out = _cached_rep_matrix(rep=rep_Y_iso, key="Q", matrix=rep_Y_iso.change_of_basis, like=W)
        Q_in_inv = _cached_rep_matrix(rep=rep_X_iso, key="Q_inv", matrix=rep_X_iso.change_of_basis_inv, like=W)
        endo_basis_flat_cache = None
        endo_norm_sq_cache = None
    else:
        Q_out, Q_in_inv, endo_basis_flat_cache, endo_norm_sq_cache = _validate_tensor_cache(
            tensor_cache=tensor_cache,
            rep_X_iso=rep_X_iso,
            rep_Y_iso=rep_Y_iso,
            require_endo_norm_sq=True,
        )
    # NOTE(review): .mT is used as the inverse, which assumes the change of basis is
    # orthogonal — consistent with the module's docs, but confirm for custom caches.
    Q_out_inv = Q_out.mT
    Q_in = Q_in_inv.mT
    # Move dense map to isotypic coordinates:
    # W_iso = Q_out^{-1} W Q_in
    # Shape stays [..., d_out, d_in].
    W_iso = (Q_out_inv @ W) @ Q_in
    leading_shape = W_iso.shape[:-2]
    batch_axes = tuple(range(len(leading_shape)))
    iso_space_metadata = {}
    for iso_space in hom_iso_spaces:
        irrep_id = iso_space["irrep_id"]
        out_slice = iso_space["out_slice"]
        in_slice = iso_space["in_slice"]
        m_out = iso_space["m_out"]
        m_in = iso_space["m_in"]
        d_k = iso_space["d_k"]
        if tensor_cache is None:
            irrep = rep_X_iso.group.irrep(*irrep_id)
            endo_basis_flat, endo_norm_sq = _cached_irrep_endomorphism_basis(irrep=irrep, like=W)
        else:
            # Validate the user-supplied per-irrep tensors before using them.
            if irrep_id not in endo_basis_flat_cache or irrep_id not in endo_norm_sq_cache:
                raise ValueError(f"tensor_cache missing endomorphism tensors for irrep {irrep_id}")
            endo_basis_flat = endo_basis_flat_cache[irrep_id]
            endo_norm_sq = endo_norm_sq_cache[irrep_id]
            if not isinstance(endo_basis_flat, torch.Tensor) or not isinstance(endo_norm_sq, torch.Tensor):
                raise TypeError(f"tensor_cache entries for irrep {irrep_id} must be torch.Tensor instances")
            if endo_basis_flat.ndim != 2 or endo_basis_flat.shape[1] != d_k * d_k:
                raise ValueError(
                    f"Expected tensor_cache['endo_basis_flat'][{irrep_id}] shape (S_k, {d_k * d_k}), "
                    f"got {tuple(endo_basis_flat.shape)}"
                )
            if endo_norm_sq.shape != (endo_basis_flat.shape[0],):
                raise ValueError(
                    f"Expected tensor_cache['endo_norm_sq'][{irrep_id}] shape {(endo_basis_flat.shape[0],)}, "
                    f"got {tuple(endo_norm_sq.shape)}"
                )
        # Rearrange the block so each (output copy, input copy) pair exposes its d_k x d_k matrix.
        block = W_iso[..., out_slice, in_slice]
        block = block.view(*leading_shape, m_out, d_k, m_in, d_k)
        block = block.permute(*batch_axes, -4, -2, -3, -1)  # [..., m_out, m_in, d_k, d_k]
        block_flat = block.reshape(*leading_shape, m_out * m_in, d_k * d_k)  # [..., m_out*m_in, d_k^2]
        # coeff[..., p, s] = <block_{p}, E_s> / ||E_s||^2
        coeff = block_flat.matmul(endo_basis_flat.mT)
        coeff = coeff / endo_norm_sq
        iso_space_metadata[irrep_id] = dict(
            coeff=coeff,
            endo_basis_flat=endo_basis_flat,
            m_out=m_out,
            m_in=m_in,
            d_k=d_k,
            out_slice=out_slice,
            in_slice=in_slice,
        )
    return Q_out, Q_in_inv, iso_space_metadata
def _project_to_irrep_endomorphism_basis(
    A: torch.Tensor,
    rep_x: Representation,
    rep_y: Representation,
) -> torch.Tensor:
    r"""Project a dense linear map between two isotypic spaces onto the equivariant subspace.

    Let :math:`(\mathcal{X}, \rho_{\mathcal{X}})` and :math:`(\mathcal{Y}, \rho_{\mathcal{Y}})` be two
    isotypic vector spaces of the same irrep type :math:`k`, with multiplicities
    :math:`m_k^{\mathcal{X}}` and :math:`m_k^{\mathcal{Y}}`. This function projects a dense linear map
    :math:`\mathbf{A}:\mathcal{X}\to\mathcal{Y}` onto
    :math:`\mathrm{Hom}_{\mathbb{G}}(\rho_{\mathcal{X}}, \rho_{\mathcal{Y}})`.

    Args:
        A (:class:`~torch.Tensor`): Dense linear map of shape
            :math:`(m_k^{\mathcal{Y}} d_k,\; m_k^{\mathcal{X}} d_k)`.
        rep_x (:class:`~escnn.group.Representation`): Representation :math:`\rho_{\mathcal{X}}` of the
            isotypic space :math:`\mathcal{X}`.
        rep_y (:class:`~escnn.group.Representation`): Representation :math:`\rho_{\mathcal{Y}}` of the
            isotypic space :math:`\mathcal{Y}`.

    Returns:
        A_equiv (:class:`~torch.Tensor`): Projected map of shape
            :math:`(m_k^{\mathcal{Y}} d_k,\; m_k^{\mathcal{X}} d_k)` satisfying
            :math:`\rho_{\mathcal{Y}}(g)\mathbf{A}_{\mathrm{equiv}} =
            \mathbf{A}_{\mathrm{equiv}}\rho_{\mathcal{X}}(g)` for all :math:`g \in \mathbb{G}`.
    """
    irrep_id = rep_x.irreps[0]
    irrep = rep_x.group.irrep(*irrep_id)
    # Internal helper: preconditions are enforced with asserts (stripped under -O),
    # matching the validation style of the rest of this block.
    assert A.shape == (rep_y.size, rep_x.size), "Expected A: X -> Y"
    assert len(rep_x._irreps_multiplicities) == 1, f"Expected rep with a single irrep type, got {rep_x.irreps}"
    assert len(rep_y._irreps_multiplicities) == 1, f"Expected rep with a single irrep type, got {rep_y.irreps}"
    assert irrep_id == rep_y.irreps[0], f"Irreps {irrep_id} != {rep_y.irreps[0]}. Hence A=0"
    # Both reps must already be expressed in the spectral/isotypic basis (trivial change of basis).
    x_in_iso_basis = np.allclose(rep_x.change_of_basis_inv, np.eye(rep_x.size), atol=1e-6, rtol=1e-4)
    assert x_in_iso_basis, "Expected X to be in spectral/isotypic basis"
    y_in_iso_basis = np.allclose(rep_y.change_of_basis_inv, np.eye(rep_y.size), atol=1e-6, rtol=1e-4)
    assert y_in_iso_basis, "Expected Y to be in spectral/isotypic basis"

    m_x = rep_x._irreps_multiplicities[irrep_id]  # Multiplicity of the irrep in X
    m_y = rep_y._irreps_multiplicities[irrep_id]  # Multiplicity of the irrep in Y

    # Basis of endomorphisms of the irrep, shape (B, d, d), B = 1 | 2 | 4.
    # NOTE: as_tensor (not torch.tensor) for consistency with _cached_irrep_endomorphism_basis
    # and to skip the unconditional copy when dtype/device already match.
    irrep_end_basis = torch.as_tensor(irrep.endomorphism_basis(), device=A.device, dtype=A.dtype)
    # Reorder A into per-(output, input)-multiplicity d x d blocks: (m_y, m_x, d, d).
    A_irreps = A.view(m_y, irrep.size, m_x, irrep.size).permute(0, 2, 1, 3).contiguous()
    # Compute basis expansion coefficients of each irrep cross-covariance in basis of End(irrep) ========
    # Frobenius inner product <C , Ψ_b> = Σ_{i,j} C_{ij} Ψ_b,ij
    A_irreps_basis_coeff = torch.einsum("mnij,bij->mnb", A_irreps, irrep_end_basis)  # (m_y , m_x , B)
    # squared norms ‖Ψ_b‖² (only once, very small)
    basis_coeff_norms = torch.einsum("bij,bij->b", irrep_end_basis, irrep_end_basis)  # (B,)
    A_irreps_basis_coeff = A_irreps_basis_coeff / basis_coeff_norms[None, None]
    # Re-synthesize each block from its equivariant components only.
    A_irreps = torch.einsum("...b,bij->...ij", A_irreps_basis_coeff, irrep_end_basis)  # (m_y , m_x , d , d)
    # Reshape to (my * d, mx * d)
    A_equiv = A_irreps.permute(0, 2, 1, 3).reshape(m_y * irrep.size, m_x * irrep.size)
    return A_equiv


def _cached_irrep_endomorphism_basis(
    irrep: Representation,
    like: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Return the flattened endomorphism basis of ``irrep`` with squared norms, cached per device/dtype.

    Args:
        irrep (:class:`~escnn.group.Representation`): Irreducible representation whose
            endomorphism basis is requested.
        like (:class:`~torch.Tensor`): Tensor whose device and dtype determine both the
            cache key and the device/dtype of the returned tensors.

    Returns:
        Tuple ``(endo_basis_flat, endo_norm_sq)``: ``endo_basis_flat`` of shape ``(S, d*d)``
        holds one flattened ``d x d`` basis element per row; ``endo_norm_sq`` of shape
        ``(S,)`` holds each element's squared Frobenius norm.
    """
    # The cache lives on the irrep object itself, keyed by (device type, device index, dtype),
    # so repeated projections on the same device reuse the converted tensors.
    cache = irrep.attributes.setdefault("_endo_basis_flat_cache", {})
    cache_key = (like.device.type, like.device.index, like.dtype)
    if cache_key not in cache:
        endo_basis = torch.as_tensor(irrep.endomorphism_basis(), device=like.device, dtype=like.dtype).contiguous()
        endo_basis_flat = endo_basis.view(endo_basis.size(0), -1)
        endo_norm_sq = torch.einsum("sd,sd->s", endo_basis_flat, endo_basis_flat)
        cache[cache_key] = (endo_basis_flat, endo_norm_sq)
    return cache[cache_key]


def _hom_irrep_iso_spaces(
    rep_x: Representation,
    rep_y: Representation,
) -> tuple[Representation, Representation, list[dict], int]:
    """Enumerate the isotypic blocks that contribute to ``Hom_G(rep_x, rep_y)``.

    Args:
        rep_x (:class:`~escnn.group.Representation`): Input-space representation.
        rep_y (:class:`~escnn.group.Representation`): Output-space representation;
            must be defined over the same group as ``rep_x``.

    Returns:
        Tuple ``(rep_X_iso, rep_Y_iso, hom_iso_spaces, dof_offset)``: the isotypic
        decompositions of both representations, one metadata dict per common irrep
        (irrep id, endomorphism-basis size, multiplicities, irrep dimension, in/out
        slices, and the slice into the flattened homomorphism basis), and the total
        number of free coefficients ``dof_offset`` of the equivariant subspace.

    Raises:
        ValueError: If ``rep_x`` and ``rep_y`` are defined over different groups.
    """
    # Function-scope import; presumably breaks a circular import with
    # representation_theory at module load time — TODO confirm.
    from symm_learning.representation_theory import isotypic_decomp_rep

    if rep_x.group != rep_y.group:
        raise ValueError(f"Expected same group, got {rep_x.group} and {rep_y.group}")

    rep_X_iso = isotypic_decomp_rep(rep_x)
    rep_Y_iso = isotypic_decomp_rep(rep_y)
    iso_idx_X = rep_X_iso.attributes["isotypic_subspace_dims"]
    iso_idx_Y = rep_Y_iso.attributes["isotypic_subspace_dims"]
    # Only irreps present in BOTH decompositions carry equivariant degrees of freedom
    # (by Schur's lemma the cross-irrep blocks vanish).
    common_irreps = sorted(set(rep_X_iso.irreps).intersection(set(rep_Y_iso.irreps)))

    hom_iso_spaces = []
    dof_offset = 0  # Running offset into the flattened homomorphism-basis coefficient vector.
    for irrep_id in common_irreps:
        out_slice = iso_idx_Y[irrep_id]
        in_slice = iso_idx_X[irrep_id]
        m_out = rep_Y_iso._irreps_multiplicities[irrep_id]
        m_in = rep_X_iso._irreps_multiplicities[irrep_id]
        irrep = rep_X_iso.group.irrep(*irrep_id)
        d_k = irrep.size
        dim_endo_basis = len(irrep.endomorphism_basis())
        irrep_dim = m_out * m_in * dim_endo_basis
        hom_iso_spaces.append(
            dict(
                irrep_id=irrep_id,
                dim_endo_basis=dim_endo_basis,
                m_out=m_out,
                m_in=m_in,
                d_k=d_k,
                out_slice=out_slice,
                in_slice=in_slice,
                hom_basis_slice=slice(dof_offset, dof_offset + irrep_dim),
            )
        )
        dof_offset += irrep_dim
    return rep_X_iso, rep_Y_iso, hom_iso_spaces, dof_offset