"""Symmetric Learning - Representation Theory Utilities.
Tools for working with group representations, homomorphism bases, and irreducible
decompositions. These utilities support the construction of equivariant neural network
layers and the analysis of symmetric vector spaces.
Classes
-------
GroupHomomorphismBasis
Handle bases of equivariant linear maps between representation spaces.
Functions
---------
isotypic_decomp_rep
Decompose a representation into isotypic components.
direct_sum
Direct sum of representations.
irreps_stats
Statistics about irreducible representations in a representation.
escnn_representation_form_mapping
Map between ESCNN representation forms.
is_complex_irreducible
Check if a representation is complex irreducible.
decompose_representation
Full decomposition of a representation into irreducibles.
"""
from __future__ import annotations
import functools
import itertools
import logging
from collections import OrderedDict
from typing import Callable
import numpy as np
import torch
from escnn.group import Group, GroupElement, Representation
from scipy.linalg import block_diag
from symm_learning.utils import CallableDict
logger = logging.getLogger(__name__)
# Global cache for basis elements storage.
_cache_ = {}
[docs]
class GroupHomomorphismBasis(torch.nn.Module):
r"""Basis handler for :math:`\operatorname{Hom}_\mathbb{G}(\rho_{\mathcal{X}}, \rho_{\mathcal{Y}})`.
This module provides:
1. Synthesis of equivariant maps from basis coefficients.
2. Orthogonal projection of dense matrices onto the equivariant subspace.
Using the :ref:`isotypic decomposition <isotypic-decomposition-example>`
(:func:`~symm_learning.representation_theory.isotypic_decomp_rep`) of
:math:`\rho_{\mathcal{X}}` and :math:`\rho_{\mathcal{Y}}`, each shared irrep type :math:`k` contributes a block
with structure
.. math::
\mathbf{A}^{(k)}_{o,i} = \sum_{s=1}^{S_k}\theta^{(k)}_{s,o,i}\,\mathbf{\Psi}^{(k)}_s,
where :math:`\{\mathbf{\Psi}^{(k)}_s\}_{s=1}^{S_k}` is a basis of
:math:`\operatorname{End}_\mathbb{G}(\hat{\rho}_k)`. This induces
:math:`m^{\text{out}}_k m^{\text{in}}_k S_k` degrees of freedom for type :math:`k`.
Core utilities:
- :meth:`forward`: map basis coefficients to a dense equivariant matrix
:math:`\mathbf{W}\in\operatorname{Hom}_\mathbb{G}(\rho_{\mathcal{X}},\rho_{\mathcal{Y}})`.
- :meth:`orthogonal_projection`: project any dense matrix onto
:math:`\operatorname{Hom}_\mathbb{G}(\rho_{\mathcal{X}},\rho_{\mathcal{Y}})` in Frobenius norm.
- :meth:`initialize_params`: sample valid initialization either in coefficient space or as a dense equivariant
matrix.
.. rubric:: Examples
Build and synthesize an equivariant matrix with :meth:`forward`.
.. code-block:: pycon
>>> from escnn.group import CyclicGroup
>>> G = CyclicGroup(4)
>>> rep = direct_sum([G.regular_representation] * 2)
>>> basis = GroupHomomorphismBasis(rep, rep)
>>> w_dof = torch.randn(basis.dim)
>>> W = basis(w_dof)
>>> # W satisfies rho(g) @ W == W @ rho(g) for all g in G
Project an unconstrained matrix with :meth:`orthogonal_projection`.
.. code-block:: pycon
>>> W0 = torch.randn(rep.size, rep.size)
>>> W_proj = basis.orthogonal_projection(W0)
>>> # W_proj is the closest equivariant matrix to W0 in Frobenius norm
Sample initialization with :meth:`initialize_params`.
.. code-block:: pycon
>>> w0 = basis.initialize_params("xavier_normal", return_dense=False)
>>> W0 = basis.initialize_params("xavier_normal", return_dense=True)
>>> # both parameterize equivariant maps
Attributes:
G (:class:`~escnn.group.group.Group`): Symmetry group shared by in_rep and out_rep.
in_rep (:class:`~escnn.group.representation.Representation`): Input representation
:math:`\rho_{\mathcal{X}}` rewritten in an isotypic basis.
out_rep (:class:`~escnn.group.representation.Representation`): Output representation
:math:`\rho_{\mathcal{Y}}` rewritten in an isotypic basis.
basis_expansion (:class:`str`): Strategy ("memory_heavy" or "isotypic_expansion") controlling
storage/perf trade-offs.
common_irreps (list[tuple]): Irrep identifiers present in both in_rep and out_rep.
iso_blocks (`dict <https://docs.python.org/3/library/stdtypes.html#dict>`_): Per-irrep metadata for
irreps shared by in_rep/out_rep, keys:
- out_slice / in_slice: slice selecting the isotypic coordinates in out/in reps.
- mul_out / mul_in: multiplicities of the irrep in out/in reps.
- irrep_dim: Dimension of the irreducible representation :math:`d_k`.
- endomorphism_basis: basis of :math:`\operatorname{End}_\mathbb{G}(\hat{\rho}_k)` with shape
:math:`(S_k, d_k, d_k)`.
- dim_hom_basis: Dimension of the homomorphism between isotypic spaces of type :math:`k`, i.e.
:math:`\dim(\operatorname{Hom}_\mathbb{G}(\rho_k^{\mathcal{X}}, \rho_k^{\mathcal{Y}})) = m_{out}
\cdot m_{in} \cdot S_k`.
- hom_basis_slice: slice containing the degrees of freedom associated to
:math:`\dim(\operatorname{Hom}_\mathbb{G}(\rho_k^{\mathcal{X}}, \rho_k^{\mathcal{Y}}))` from a vector of
shape :math:`(\dim(\operatorname{Hom}_\mathbb{G}(\rho_{\mathcal{X}}, \rho_{\mathcal{Y}})),)`.
basis_elements (:class:`~torch.Tensor`): Full dense basis stack
``(dim, out_rep.size, in_rep.size)`` when ``basis_expansion="memory_heavy"``.
basis_norm_sq (:class:`~torch.Tensor`): Squared Frobenius norms of basis elements when
``basis_expansion="memory_heavy"``.
endo_basis_flat_<irrep_id> (:class:`~torch.Tensor`): Per-irrep flattened endomorphism bases
``(S_k, d_k*d_k)`` when ``basis_expansion="isotypic_expansion"``.
endo_basis_norm_sq_<irrep_id> (:class:`~torch.Tensor`): Per-irrep squared norms of the flattened
endomorphism bases with isotypic expansion.
Q_in_inv (:class:`~torch.Tensor`): Change-of-basis matrix :math:`Q_{in}^{-1}` cached as buffer for
isotypic expansion.
Q_out (:class:`~torch.Tensor`): Change-of-basis matrix :math:`Q_{out}` cached as buffer for
isotypic expansion.
"""
def __init__(
    self, in_rep: Representation, out_rep: Representation, basis_expansion: str = "isotypic_expansion"
) -> None:
    r"""Collect per-irrep block metadata and register the buffers used by the chosen expansion.

    Both representations are rewritten in an isotypic basis through :func:`isotypic_decomp_rep`. For every
    irrep type found in both, the block location, the multiplicities and the irrep endomorphism basis are
    stored in ``iso_blocks``; the degrees of freedom of the blocks are laid out consecutively following the
    order of ``common_irreps``.

    Args:
        in_rep: Input representation :math:`\rho_{\mathcal{X}}`; its isotypic form indexes the columns of the
            dense matrices (size ``in_rep.size``).
        out_rep: Output representation :math:`\rho_{\mathcal{Y}}`; its isotypic form indexes the rows of the
            dense matrices (size ``out_rep.size``).
        basis_expansion: ``"memory_heavy"`` registers the dense basis stack
            ``(dim, out_rep.size, in_rep.size)`` together with the squared norm of each element.
            ``"isotypic_expansion"`` registers only the flattened irrep endomorphism bases ``(S_k, d_k * d_k)``,
            their squared norms, and the change-of-basis matrices ``Q_in_inv`` / ``Q_out``.

    Raises:
        AssertionError: If `in_rep` and `out_rep` do not belong to the same symmetry group.
        NotImplementedError: If ``basis_expansion`` is neither of the two supported strategies.
    """
    super().__init__()
    assert in_rep.group == out_rep.group, f"in group: {in_rep.group} != out group: {out_rep.group}"
    self.G = in_rep.group
    self.basis_expansion = basis_expansion
    self.in_rep = isotypic_decomp_rep(in_rep)
    self.out_rep = isotypic_decomp_rep(out_rep)
    dtype = torch.get_default_dtype()

    # Only irrep types present on both sides yield non-zero blocks in the isotypic basis.
    shared_ids = set(self.in_rep.irreps) & set(self.out_rep.irreps)
    self.common_irreps = sorted(shared_ids)

    self.iso_blocks = {}
    next_dof = 0  # running offset into the flat coefficient vector
    for irrep_id in self.common_irreps:
        irrep = self.G.irrep(*irrep_id)
        rows = self.out_rep.attributes["isotypic_subspace_dims"][irrep_id]
        cols = self.in_rep.attributes["isotypic_subspace_dims"][irrep_id]
        n_out = self.out_rep._irreps_multiplicities[irrep_id]
        n_in = self.in_rep._irreps_multiplicities[irrep_id]
        end_basis = torch.tensor(irrep.endomorphism_basis(), dtype=dtype)  # [S_k, d_k, d_k]
        # dim(Hom_G) between the two isotypic subspaces of this type: m_out * m_in * S_k
        n_block_dof = n_out * n_in * end_basis.size(0)
        dof_stop = next_dof + n_block_dof
        self.iso_blocks[irrep_id] = {
            "out_slice": rows,
            "in_slice": cols,
            "endomorphism_basis": end_basis,
            "mul_out": n_out,
            "mul_in": n_in,
            "irrep_dim": irrep.size,
            "dim_hom_basis": n_block_dof,
            "hom_basis_slice": slice(next_dof, dof_stop),
        }
        next_dof = dof_stop
    self._dim_homomorphism = next_dof

    if self.basis_expansion not in ("memory_heavy", "isotypic_expansion"):
        raise NotImplementedError(f"Basis expansion '{self.basis_expansion}' not implemented yet.")

    if self.basis_expansion == "memory_heavy":
        # One dense matrix per basis element; kept contiguous so that flattening it is a view.
        dense_basis = self._build_fullsize_homomorphism_basis()
        self.register_buffer("basis_elements", dense_basis.contiguous())
        self.register_buffer(
            "basis_norm_sq", torch.einsum("sab,sab->s", self.basis_elements, self.basis_elements)
        )
        return

    # "isotypic_expansion": only the small per-irrep endomorphism bases are stored as buffers.
    for irrep_id, meta in self.iso_blocks.items():
        end_basis = meta["endomorphism_basis"].contiguous()  # [S_k, d_k, d_k]
        flat = end_basis.view(end_basis.size(0), -1)  # [S_k, d_k * d_k]
        self.register_buffer(f"endo_basis_flat_{irrep_id}", flat)
        self.register_buffer(f"endo_basis_norm_sq_{irrep_id}", torch.einsum("sd,sd->s", flat, flat))
    self.register_buffer("Q_in_inv", torch.tensor(self.in_rep.change_of_basis_inv, dtype=dtype))
    self.register_buffer("Q_out", torch.tensor(self.out_rep.change_of_basis, dtype=dtype))
[docs]
def forward(self, w_dof: torch.Tensor) -> torch.Tensor:
    r"""Expand basis coefficients into a dense equivariant matrix.

    The result satisfies

    .. math::
        \rho_{\mathcal{Y}}(g)\mathbf{W} = \mathbf{W}\rho_{\mathcal{X}}(g), \quad \forall g\in\mathbb{G}.

    With ``basis_expansion="memory_heavy"`` the coefficients weight the stored dense basis elements. With
    ``"isotypic_expansion"`` each irrep block reads its slice of ``w_dof`` as a ``[m_out * m_in, S_k]`` table,
    expands it with the irrep endomorphism basis, and the assembled block matrix is mapped back to the original
    coordinates with ``Q_out`` and ``Q_in_inv``.

    Args:
        w_dof (:class:`~torch.Tensor`): Coefficients of shape ``(D,)`` or ``(..., D)`` with
            ``D = dim(Hom_G(in_rep, out_rep))``.

    Returns:
        :class:`~torch.Tensor`: Dense matrix of shape ``(..., out_rep.size, in_rep.size)``; the leading
        dimensions are those of ``w_dof``.

    Raises:
        AssertionError: If the last dimension of ``w_dof`` differs from ``self.dim``.

    Example:
        >>> W = basis(w_dof)
        >>> # rho_out(g) @ W == W @ rho_in(g) for all g
    """
    assert w_dof.shape[-1] == (self.dim), f"Expected w_dof shape (B,{self.dim}), but got {w_dof.shape}"
    batch_shape = w_dof.shape[:-1]
    n_rows, n_cols = self.out_rep.size, self.in_rep.size

    if self.basis_expansion == "memory_heavy":
        flat_basis = self.basis_elements.view(self.dim, -1)  # [D, n_rows * n_cols]
        W = torch.matmul(w_dof, flat_basis).view(*batch_shape, n_rows, n_cols)
    elif self.basis_expansion == "isotypic_expansion":
        W = w_dof.new_zeros(*batch_shape, n_rows, n_cols)
        # Axis order turning [*, m_out, m_in, d_k, d_k] into [*, m_out, d_k, m_in, d_k].
        interleave = (*range(len(batch_shape)), -4, -2, -3, -1)
        for irrep_id, meta in self.iso_blocks.items():
            n_out, n_in, d_k = meta["mul_out"], meta["mul_in"], meta["irrep_dim"]
            end_flat = getattr(self, f"endo_basis_flat_{irrep_id}")  # [S_k, d_k * d_k]
            theta = w_dof[..., meta["hom_basis_slice"]].view(*batch_shape, n_out * n_in, end_flat.size(0))
            expanded = torch.matmul(theta, end_flat)  # [*, m_out * m_in, d_k * d_k]
            block = expanded.view(*batch_shape, n_out, n_in, d_k, d_k).permute(*interleave)
            W[..., meta["out_slice"], meta["in_slice"]] = block.reshape(*batch_shape, n_out * d_k, n_in * d_k)
        # Back from the isotypic basis to the original input/output coordinates.
        W = self.Q_out @ (W @ self.Q_in_inv)
    return W
[docs]
def orthogonal_projection(self, W: torch.Tensor) -> torch.Tensor:
    r"""Project a dense matrix onto :math:`\operatorname{Hom}_\mathbb{G}(\rho_{\mathcal{X}}, \rho_{\mathcal{Y}})`.

    The projection :math:`\Pi_{\mathrm{Hom}}(\mathbf{W})` enforces

    .. math::
        \rho_{\mathcal{Y}}(g)\,\Pi_{\mathrm{Hom}}(\mathbf{W})
        = \Pi_{\mathrm{Hom}}(\mathbf{W})\,\rho_{\mathcal{X}}(g),
        \quad \forall g\in\mathbb{G}.

    Each coefficient is computed independently as :math:`\langle \mathbf{W}, \mathbf{B}_s\rangle / \|\mathbf{B}_s\|^2`
    and the matrix is recomposed from them.

    Args:
        W (:class:`~torch.Tensor`): Weight matrix of shape ``(..., out_rep.size, in_rep.size)`` in the
            original basis. It may be non-contiguous (e.g. a transposed view).

    Returns:
        :class:`~torch.Tensor`: Projection of ``W``, with the same shape as the input.

    Raises:
        AssertionError: If the trailing two dimensions of ``W`` are not ``(out_rep.size, in_rep.size)``.
        NotImplementedError: If ``basis_expansion`` is not a supported strategy.

    Note:
        The ``"isotypic_expansion"`` branch uses the transposes of ``Q_out`` and ``Q_in_inv`` as their inverses,
        i.e. it assumes orthogonal change-of-basis matrices.

    Example:
        >>> W_proj = basis.orthogonal_projection(W)
        >>> # rho_out(g) @ W_proj == W_proj @ rho_in(g) for all g
        >>> # basis.orthogonal_projection(W_proj) == W_proj
    """
    assert W.shape[-2:] == (self.out_rep.size, self.in_rep.size), (
        f"Expected weight shape (..., {self.out_rep.size}, {self.in_rep.size}), got {W.shape}"
    )

    if self.basis_expansion == "memory_heavy":
        basis_flat = self.basis_elements.view(self.dim, -1)  # [S, d_out*d_in]
        # reshape (not view): the caller's matrix is not guaranteed to be contiguous.
        W_flat = W.reshape(*W.shape[:-2], -1)  # [..., d_out*d_in]
        # coeff = <W, B_s> / ||B_s||^2, yields [..., S]
        coeff = W_flat.matmul(basis_flat.mT) / self.basis_norm_sq
        # Recompose: sum_s coeff_s * B_s, returning [..., d_out*d_in]
        return coeff.matmul(basis_flat).reshape(W.shape)

    if self.basis_expansion == "isotypic_expansion":
        Q_out_inv = self.Q_out.mT  # transpose used as inverse (orthogonal change of basis assumed)
        Q_in = self.Q_in_inv.mT
        W_iso_in = (Q_out_inv @ W) @ Q_in  # [..., d_out, d_in] in isotypic basis
        W_iso = torch.zeros_like(W_iso_in)  # accumulator in isotypic basis
        leading_shape = W_iso_in.shape[:-2]  # batch (possibly empty) dims
        # Swaps the two middle axes of a 4-axis block; it is its own inverse.
        swap_mid = (*range(len(leading_shape)), -4, -2, -3, -1)

        for irrep_id, irrep_metadata in self.iso_blocks.items():
            m_out, m_in = irrep_metadata["mul_out"], irrep_metadata["mul_in"]
            out_slice, in_slice = irrep_metadata["out_slice"], irrep_metadata["in_slice"]
            d_k = irrep_metadata["irrep_dim"]
            endo_basis_flat = getattr(self, f"endo_basis_flat_{irrep_id}")  # [S_k, d_k^2]
            endo_norm_sq = getattr(self, f"endo_basis_norm_sq_{irrep_id}")  # [S_k]

            block = W_iso_in[..., out_slice, in_slice].reshape(*leading_shape, m_out, d_k, m_in, d_k)
            block = block.permute(*swap_mid)  # [..., m_out, m_in, d_k, d_k]
            block_flat = block.reshape(*leading_shape, m_out * m_in, d_k * d_k)
            # coeff[..., (o,i), s] = <block_{o,i}, E_s> / ||E_s||^2
            coeff = block_flat.matmul(endo_basis_flat.mT) / endo_norm_sq
            block_proj = coeff.matmul(endo_basis_flat)  # [..., m_out*m_in, d_k^2]
            block_proj = block_proj.reshape(*leading_shape, m_out, m_in, d_k, d_k).permute(*swap_mid)
            block_proj = block_proj.reshape(*leading_shape, m_out * d_k, m_in * d_k)
            # Indexing with slices returns a view, so this writes into W_iso.
            W_iso[..., out_slice, in_slice].copy_(block_proj)

        return (self.Q_out @ W_iso) @ self.Q_in_inv  # back to original basis

    raise NotImplementedError(f"Basis expansion '{self.basis_expansion}' not implemented yet.")
@property
def dim(self) -> int:
    r"""Number of basis coefficients, :math:`\dim(\operatorname{Hom}_\mathbb{G}(\rho_{\mathcal{X}}, \rho_{\mathcal{Y}}))`.

    This is the total accumulated over all shared irrep blocks when the module is constructed.
    """
    return self._dim_homomorphism
@torch.no_grad()
def _build_fullsize_homomorphism_basis(self):
    r"""Build (or fetch from cache) the dense basis of :math:`\operatorname{Hom}_\mathbb{G}`.

    For each shared irrep block, one dense matrix is created per triple ``(s, o, i)`` of endomorphism-basis
    index, output multiplicity index and input multiplicity index (``s`` varies slowest, ``i`` fastest). The
    matrix is zero except for the ``d_k x d_k`` sub-block at position ``(o, i)`` of the isotypic block, which
    holds the ``s``-th endomorphism basis element. All matrices are then mapped to the original coordinates as
    ``Q_out @ B @ Q_in_inv``.

    Returns:
        ``basis_elements`` (:class:`~torch.Tensor`): Stack of homomorphism basis elements of shape
        `(dim(Hom_G(in_rep, out_rep)), out_rep.size, in_rep.size)`.

    Note:
        The dense basis is cached in module-level ``_cache_`` keyed by ``(self.in_rep, self.out_rep)``; a
        cache hit returns the stored tensor directly.
    """
    cache_key = (self.in_rep, self.out_rep)
    if cache_key in _cache_:
        logger.debug(f"Using cached basis for Hom_G({self.in_rep}, {self.out_rep})")
        return _cache_[cache_key]["basis_elements"]

    logger.debug(f"Building full-size basis for Hom_G({self.in_rep}, {self.out_rep})")
    n_rows, n_cols = self.out_rep.size, self.in_rep.size
    stacks_per_irrep = []
    for meta in self.iso_blocks.values():
        end_basis = meta["endomorphism_basis"]  # [S_k, d_k, d_k]
        d_k = meta["irrep_dim"]
        row0, col0 = meta["out_slice"].start, meta["in_slice"].start
        elements = []
        index_triples = itertools.product(
            range(end_basis.size(0)), range(meta["mul_out"]), range(meta["mul_in"])
        )
        for s, o, i in index_triples:
            element = torch.zeros(n_rows, n_cols)  # [d_out, d_in]
            top, left = row0 + o * d_k, col0 + i * d_k
            element[top : top + d_k, left : left + d_k] = end_basis[s]
            elements.append(element)
        stacks_per_irrep.append(torch.stack(elements, dim=0))  # [S_k*m_out_k*m_in_k, d_out, d_in]

    iso_basis = torch.cat(stacks_per_irrep, dim=0)  # [dim(Hom_G(in_rep, out_rep)), d_out, d_in]

    if self.basis_expansion == "memory_heavy":
        # Built from the representations directly rather than read from module buffers.
        Q_out = torch.tensor(self.out_rep.change_of_basis, dtype=iso_basis.dtype)
        Q_in_inv = torch.tensor(self.in_rep.change_of_basis_inv, dtype=iso_basis.dtype)
    else:
        Q_out, Q_in_inv = self.Q_out, self.Q_in_inv

    # Change every element from the isotypic basis to the original input/output basis.
    basis_elements = torch.einsum("ab,sbc,cd->sad", Q_out, iso_basis, Q_in_inv).contiguous()
    _cache_[cache_key] = {"basis_elements": basis_elements.cpu()}
    return basis_elements
[docs]
@torch.no_grad()
def initialize_params(
    self, scheme: str = "kaiming_uniform", return_dense: bool = False, leading_shape: int | tuple | None = None
) -> torch.Tensor:
    r"""Sample valid parameters in :math:`\operatorname{Hom}_\mathbb{G}(\rho_{\mathcal{X}},\rho_{\mathcal{Y}})`.

    Coefficients are drawn independently per irrep block, with fans
    ``fan_in = S_k * m_in`` and ``fan_out = S_k * m_out``. The sampled tensor of a block has shape
    ``(*leading_shape, S_k, m_out, m_in)`` and is flattened in that order into the block's slice of the
    coefficient vector.

    Args:
        scheme (:class:`str`): Initialization scheme: ``"xavier_normal"``, ``"xavier_uniform"``,
            ``"kaiming_normal"`` (alias ``"he_normal"``) or ``"kaiming_uniform"`` (alias ``"he_uniform"``).
        return_dense: If ``True``, return dense weights in the original basis; otherwise return basis expansion
            coefficients.
        leading_shape: Optional leading dimensions (an int or an iterable of ints). ``None`` yields no leading
            dims. Examples: ``None`` → ``(dim,)`` / ``(d_out, d_in)``; ``B`` → ``(B, dim)`` / ``(B, d_out, d_in)``.

    Shapes:
        - return_dense=False → ``(*leading_shape, dim(Hom_G(in_rep, out_rep)))``
        - return_dense=True → ``(*leading_shape, out_rep.size, in_rep.size)``

    Returns:
        :class:`~torch.Tensor`: Initialized parameters with the shapes above, created with the dtype and device
        of the module's first buffer (default dtype if the module has no buffers).

    Raises:
        ValueError: If ``scheme`` is not recognized (raised while sampling the first irrep block).

    Example:
        >>> w_dof = basis.initialize_params("xavier_uniform")
        >>> W = basis(w_dof)
        >>> W2 = basis.initialize_params("xavier_uniform", return_dense=True)
    """
    if leading_shape is None:
        leading_shape = ()
    elif isinstance(leading_shape, int):
        leading_shape = (leading_shape,)
    else:
        leading_shape = tuple(leading_shape)

    # Follow the module's buffers so the result lives where the module lives.
    buffer = next(self.buffers(), None)
    device = buffer.device if buffer is not None else None
    dtype = buffer.dtype if buffer is not None else torch.get_default_dtype()

    if return_dense:
        W_iso = torch.zeros((*leading_shape, self.out_rep.size, self.in_rep.size), dtype=dtype, device=device)
    else:
        w_dof = torch.zeros((*leading_shape, self.dim), dtype=dtype, device=device)

    for irrep_metadata in self.iso_blocks.values():
        # The metadata tensors are plain dict entries (not buffers), so they do not follow the module across
        # devices/dtypes; bring them to the target dtype/device instead of adopting theirs.
        endo_basis = irrep_metadata["endomorphism_basis"].to(device=device, dtype=dtype)
        m_out, m_in = irrep_metadata["mul_out"], irrep_metadata["mul_in"]
        out_slice, in_slice = irrep_metadata["out_slice"], irrep_metadata["in_slice"]
        irrep_dim = irrep_metadata["irrep_dim"]
        hom_basis_slice = irrep_metadata["hom_basis_slice"]
        dim_endo_basis = endo_basis.size(0)

        # Fans for this irrep block.
        fan_in = dim_endo_basis * m_in
        fan_out = dim_endo_basis * m_out
        # (dim_irrep_endomorphism, irrep_multiplicity_out, irrep_multiplicity_in)
        theta = torch.empty((*leading_shape, dim_endo_basis, m_out, m_in), device=device, dtype=dtype)

        if scheme == "xavier_uniform":
            bound = (6.0 / (fan_in + fan_out)) ** 0.5
            theta.uniform_(-bound, bound)
        elif scheme == "xavier_normal":
            theta.normal_(0.0, (2.0 / (fan_in + fan_out)) ** 0.5)
        elif scheme in {"kaiming_normal", "he_normal"}:
            theta.normal_(0.0, (2.0 / fan_in) ** 0.5)
        elif scheme in {"kaiming_uniform", "he_uniform"}:
            bound = (6.0 / fan_in) ** 0.5
            theta.uniform_(-bound, bound)
        else:
            raise ValueError(f"Unknown scheme: {scheme}")

        if return_dense:
            # Expand to the irrep block of shape (m_out * irrep_dim, m_in * irrep_dim).
            block = torch.einsum("...soi,sab->...oaib", theta, endo_basis).reshape(
                *leading_shape, m_out * irrep_dim, m_in * irrep_dim
            )
            W_iso[..., out_slice, in_slice] = block
        else:
            w_dof[..., hom_basis_slice] = theta.reshape(*leading_shape, -1)

    if not return_dense:
        return w_dof

    # Change from the isotypic basis to the original coordinates.
    Q_in_inv = torch.tensor(self.in_rep.change_of_basis_inv, dtype=W_iso.dtype, device=W_iso.device)
    Q_out = torch.tensor(self.out_rep.change_of_basis, dtype=W_iso.dtype, device=W_iso.device)
    return torch.einsum("ab,...bc,cd->...ad", Q_out, W_iso, Q_in_inv)
def extra_repr(self):  # noqa: D102
    # One "key=value" entry per line: the expansion strategy followed by both representations.
    fields = (
        f"basis_expansion={self.basis_expansion}",
        f"in_rep={self.in_rep}",
        f"out_rep={self.out_rep}",
    )
    return "\n".join(fields)
[docs]
def isotypic_decomp_rep(rep: Representation) -> Representation:
    r"""Return an equivalent representation whose irreps are grouped by type (isotypic basis).

    The irreps of ``rep`` are permuted so that copies of the same irrep type become consecutive, and the
    change of basis is updated accordingly:

    .. math::
        \rho_{\mathcal{X}} = \mathbf{Q}\left(
        \bigoplus_{k\in[1,n_{\text{iso}}]}
        \bigoplus_{i\in[1,n_k]}
        \hat{\rho}_k
        \right)\mathbf{Q}^T

    where :math:`\hat{\rho}_k` is the irrep of type :math:`k` and :math:`n_k` its multiplicity, so that

    .. math::
        \mathcal{X} = \bigoplus_{k\in[1,n_{\text{iso}}]} \mathcal{X}^{(k)}.

    Ordering convention:
        - Isotypic blocks are sorted by subspace dimension, smallest first.
        - If present, the trivial-irrep block is then moved to the front.

    Output metadata:
        - ``attributes["isotypic_reps"]``: ordered mapping ``irrep_id -> isotypic representation``.
        - ``attributes["isotypic_subspace_dims"]``: slice of each isotypic block in the new basis.
        - ``attributes["in_isotypic_basis"]``: set to ``True``.

    Args:
        rep (:class:`~escnn.group.representation.Representation`): The input representation :math:`\rho`.

    Returns:
        :class:`~escnn.group.representation.Representation`: An equivalent, disentangled representation. If
        ``rep`` is already flagged as being in an isotypic basis it is returned unchanged.

    Raises:
        RuntimeError: If the computed index list is not a permutation of ``range(rep.size)`` or the isotypic
            slices do not add up to ``rep.size``.

    Note:
        The result is stored in ``rep.group.representations`` under its own name and looked up there under
        ``rep.name + "-Iso"`` on later calls.
    """
    # Already disentangled: nothing to do.
    if rep.attributes.get("in_isotypic_basis", False):
        logger.debug(f"Representation {rep.name} is already in isotypic basis, returning as is.")
        return rep

    symm_group = rep.group
    iso_rep_name = rep.name + "-Iso"
    if iso_rep_name in symm_group.representations:
        logger.debug(f"Returning cached {iso_rep_name}")
        return symm_group.representations[iso_rep_name]

    logger.debug(f"Computing isotypic decomposition of representation {rep.name}")

    candidates = rep.group.irreps()
    # Coordinate span of every irrep copy in the spectral basis of `rep`.
    spans = []
    offset = 0
    for member_id in rep.irreps:
        member = symm_group.irrep(*member_id)
        spans.append((member_id, member, offset, offset + member.size))
        offset += member.size

    # For each candidate irrep type, the index lists of its copies (in order of appearance).
    coords_by_type = {irrep.id: [] for irrep in candidates}
    for candidate in candidates:
        for member_id, member, lo, hi in spans:
            if member == candidate:
                coords_by_type[member_id].append(list(range(lo, hi)))
    # Keep only the irrep types that actually occur.
    coords_by_type = {irrep_id: runs for irrep_id, runs in coords_by_type.items() if len(runs) != 0}

    # One representation per active isotypic subspace, keyed by its irrep id.
    active_isotypic_reps = {}
    for irrep_id, runs in coords_by_type.items():
        irrep = symm_group.irrep(*irrep_id)
        n_copies = len(runs)
        iso_subspace_rep = Representation(
            group=rep.group,
            irreps=[irrep_id] * n_copies,
            name=f"{rep.name}-IsoSubspace{irrep_id}",
            change_of_basis=np.identity(irrep.size * n_copies),
            supported_nonlinearities=irrep.supported_nonlinearities,
        )
        # Attributes consumed by stats/linalg utilities.
        iso_subspace_rep.attributes["is_isotypic_rep"] = True
        iso_subspace_rep.attributes["irrep_multiplicity"] = n_copies
        iso_subspace_rep.attributes["irrep_dim"] = irrep.size
        iso_subspace_rep.attributes["irrep_endomorphism_basis"] = torch.tensor(irrep.endomorphism_basis())
        active_isotypic_reps[irrep_id] = iso_subspace_rep

    # Canonical order: ascending subspace dimension, with the trivial block (if any) moved to the front.
    ordered_isotypic_reps = OrderedDict(sorted(active_isotypic_reps.items(), key=lambda item: item[1].size))
    trivial_id = symm_group.trivial_representation.id
    if trivial_id in ordered_isotypic_reps:
        ordered_isotypic_reps.move_to_end(trivial_id, last=False)

    # Permutation (one-line notation) that makes irreps of the same type consecutive.
    index_runs = []
    for irrep_id in ordered_isotypic_reps:
        index_runs.extend(coords_by_type[irrep_id])
    oneline_permutation = np.concatenate(index_runs)
    is_full_permutation = oneline_permutation.shape[0] == rep.size and np.array_equal(
        np.sort(oneline_permutation), np.arange(rep.size)
    )
    if not is_full_permutation:
        raise RuntimeError(
            "Invalid isotypic permutation: indices must be a permutation of [0, ..., rep.size-1], "
            f"got size={oneline_permutation.shape[0]} and unique={np.unique(oneline_permutation).size}"
        )

    P_in2iso = permutation_matrix(oneline_permutation)
    Q_iso = rep.change_of_basis @ P_in2iso.T
    rep_iso_basis = direct_sum(list(ordered_isotypic_reps.values()), name=iso_rep_name, change_of_basis=Q_iso)

    # Slice occupied by each isotypic subspace in the disentangled basis.
    iso_subspace_dims = {}
    covered = 0
    for irrep_id, iso_rep in ordered_isotypic_reps.items():
        iso_subspace_dims[irrep_id] = slice(covered, covered + iso_rep.size)
        covered += iso_rep.size
    if covered != rep.size:
        raise RuntimeError(f"Isotypic slices cover {covered} dimensions, expected {rep.size}")

    # Only non-linearities supported by every isotypic subspace remain supported.
    rep_iso_basis.supported_nonlinearities = functools.reduce(
        set.intersection, [iso_rep.supported_nonlinearities for iso_rep in ordered_isotypic_reps.values()]
    )
    rep_iso_basis.attributes["isotypic_reps"] = ordered_isotypic_reps
    rep_iso_basis.attributes["isotypic_subspace_dims"] = iso_subspace_dims
    rep_iso_basis.attributes["in_isotypic_basis"] = True

    # Store the result in the symmetry group's representation cache.
    symm_group.representations[rep_iso_basis.name] = rep_iso_basis
    logger.debug(f"Stored isotypic decomposition representation {rep_iso_basis.name} in cache.")
    return rep_iso_basis
[docs]
def direct_sum(
    reps: list[Representation], name: str | None = None, change_of_basis: np.ndarray | None = None
) -> Representation:
    r"""Return the direct sum :math:`\rho=\bigoplus_i \rho_i` of representations.

    If ``name`` is not provided, one is built by run-length encoding consecutive repeated names:

    - ``[rep_1, rep_1, rep_2, rep_1]`` -> ``"(rep_1 x 2)⊕rep_2⊕rep_1"``

    This preserves order and makes repeated contiguous blocks explicit.

    Args:
        reps (list[:class:`~escnn.group.representation.Representation`]): Summands in the direct sum.
        name (:class:`str`, optional): Name of the resulting representation. Auto-generated when ``None``.
        change_of_basis (:class:`~numpy.ndarray`, optional): Optional change of basis for the resulting representation.

    Returns:
        :class:`~escnn.group.representation.Representation`: Direct-sum representation with
        ``attributes["direct_sum_reps"]`` set. When ``reps`` holds a single representation and neither ``name``
        nor ``change_of_basis`` is given, that representation is returned as is (no attribute is added).
    """
    # A lone summand may be returned untouched only if the caller requested no new name or basis;
    # otherwise those arguments would be silently discarded.
    if len(reps) == 1 and name is None and change_of_basis is None:
        return reps[0]

    from escnn.group import directsum

    if name is None:
        # Consecutive repetitions of a name are collapsed into "(name x count)".
        labels = []
        for rep_name, run in itertools.groupby(reps, key=lambda r: r.name):
            run_length = sum(1 for _ in run)
            labels.append(f"({rep_name} x {run_length})" if run_length > 1 else rep_name)
        name = "⊕".join(labels)

    out_rep = directsum(reps, name=name, change_of_basis=change_of_basis)
    out_rep.attributes["direct_sum_reps"] = reps
    return out_rep
[docs]
def permutation_matrix(oneline_notation):
    r"""Build the permutation matrix described by a permutation in one-line notation.

    For ``oneline_notation = [p_0,\dots,p_{d-1}]`` the returned integer matrix :math:`\mathbf{P}` has a single
    one per row, at column :math:`p_i` of row :math:`i`, hence :math:`(\mathbf{P}\mathbf{x})_i = x_{p_i}`.

    Raises:
        ValueError: If the entries are not pairwise distinct or fall outside ``[0, d - 1]``.

    Example:
        >>> permutation_matrix([2, 0, 1])
        array([[0, 0, 1],
               [1, 0, 0],
               [0, 1, 0]])
    """
    size = len(oneline_notation)
    targets = np.asarray(oneline_notation, dtype=int)

    has_repeats = np.unique(targets).size != size
    if has_repeats:
        raise ValueError("oneline_notation must describe a non-defective permutation")
    if targets.min() < 0 or targets.max() >= size:
        raise ValueError(f"oneline_notation entries must be in [0, {size - 1}]")

    matrix = np.zeros((size, size), dtype=int)
    row_ids = np.arange(size)
    matrix[row_ids, targets] = 1  # row i gets its one at column p_i
    return matrix
[docs]
def irreps_stats(irreps_ids):
    r"""Compute summary statistics for a sequence of irrep identifiers.

    Identifiers are compared through their string representation, so ids with equal ``str`` are treated as the
    same irrep.

    Args:
        irreps_ids (:class:`typing.Iterable`): Sequence of ESCNN irrep IDs.

    Returns:
        tuple:
        - unique_ids: unique irrep IDs (sorted by numpy lexical order over their string representation). Each
          entry is the first occurrence of that ID in the input.
        - counts: multiplicity of each unique irrep.
        - indices: first position where each unique irrep appears.
    """
    ids = list(irreps_ids)  # materialize once: the input may be a one-shot iterable
    str_ids = [str(irrep_id) for irrep_id in ids]
    # numpy always returns (unique, first_index, counts) in this order, whatever the keyword order.
    _, first_positions, counts = np.unique(str_ids, return_index=True, return_counts=True)
    # Recover the original id objects from their first occurrence rather than evaluating their strings.
    unique_ids = [ids[position] for position in first_positions]
    return unique_ids, counts, first_positions
[docs]
def is_complex_irreducible(group: Group, rep: dict[GroupElement, np.ndarray] | Callable[[GroupElement], np.ndarray]):
    r"""Check complex irreducibility of :math:`\rho:\mathbb{G}\to\mathrm{GL}(\mathcal{X})`.

    By Schur's lemma, :math:`\rho` is complex-irreducible iff every Hermitian matrix commuting with all
    :math:`\rho(g)` is scalar. The routine searches for a non-scalar commuting Hermitian witness :math:`\mathbf{H}`
    by averaging every element :math:`\mathbf{H}_{rs}` of a basis of the Hermitian matrices over the group:

    .. math::
        \mathbf{H} = \frac{1}{|\mathbb{G}|}\sum_{g\in\mathbb{G}} \rho(g)^H \mathbf{H}_{rs}\, \rho(g).

    The group must be finite: the routine iterates ``group.elements`` and divides by ``group.order()``.

    Args:
        group (:class:`~escnn.group.group.Group`): Symmetry group of the representation.
        rep (dict | collections.abc.Callable): A :class:`~typing.Union`-style input.
            It must be either a `dict <https://docs.python.org/3/library/stdtypes.html#dict>`_ or a
            :class:`collections.abc.Callable`.
            It must map each :class:`~escnn.group.group.GroupElement` to a matrix :class:`~numpy.ndarray`.

    Returns:
        tuple[:class:`bool`, :class:`~numpy.ndarray`]:
            - ``(True, I)`` if irreducible (identity witness),
            - ``(False, H)`` if reducible with a non-scalar commuting Hermitian witness.
    """
    # Normalize the input to a callable. The lookup is bound to a *new* name on purpose: a nested
    # ``def rep(g): return rep[g]`` would rebind ``rep`` to the wrapper itself and fail on every call.
    rep_fn = rep.__getitem__ if isinstance(rep, dict) else rep

    # Evaluate the representation once per group element; the matrices are reused for every H_rs below.
    rep_mats = [rep_fn(g) for g in group.elements]
    order = group.order()
    # Dimension of the representation.
    n = rep_mats[0].shape[0]

    # Run through all r,s = 0,1,...,n-1. The matrices H_rs span the Hermitian n x n matrices.
    for r in range(n):
        for s in range(n):
            H_rs = np.zeros((n, n), dtype=complex)
            if r == s:  # Diagonal unit matrix.
                H_rs[r, s] = 1
            elif r > s:  # Real symmetric pair.
                H_rs[r, s] = 1
                H_rs[s, r] = 1
            else:  # r < s: imaginary anti-symmetric pair.
                H_rs[r, s] = 1j
                H_rs[s, r] = -1j

            # Group average of H_rs.
            H = sum(mat.conj().T @ H_rs @ mat for mat in rep_mats) / order

            # A non-scalar average is a non-trivial witness: the representation is reducible.
            if not np.allclose(H[0, 0] * np.eye(n), H):
                return False, H

    # Every averaged basis matrix is scalar, so no non-scalar witness exists: the rep is irreducible.
    return True, np.eye(n)
[docs]
def decompose_representation(G: Group, rep: dict[GroupElement, np.ndarray] | Callable[[GroupElement], np.ndarray]):
    r"""Block-diagonalize :math:`\rho:\mathbb{G}\to\mathrm{GL}(\mathcal{X})` into invariant subspaces.

    Finds a unitary matrix :math:`\mathbf{Q}` such that

    .. math::
        \mathbf{Q}\rho(g)\mathbf{Q}^H = \operatorname{blockdiag}(\rho_1(g), \ldots, \rho_m(g)),
        \quad \forall g\in\mathbb{G},

    where each :math:`\rho_i` is an invariant subrepresentation. The algorithm diagonalizes a commuting-Hermitian
    witness and uses graph connected-components to obtain contiguous block structure. Only one splitting step is
    performed, so the returned blocks are not checked for irreducibility. If the input is already irreducible it is
    returned unchanged together with the identity.

    Args:
        G (:class:`~escnn.group.group.Group`): The symmetry group. Its ``elements`` are enumerated, so it must be
            finite.
        rep (dict | collections.abc.Callable): A :class:`~typing.Union`-style input.
            It must be either a `dict <https://docs.python.org/3/library/stdtypes.html#dict>`_ or a
            :class:`collections.abc.Callable`.
            It must map each :class:`~escnn.group.group.GroupElement` to a matrix :class:`~numpy.ndarray`.

    Returns:
        tuple:
            - subreps (list[Callable]): Decomposed subrepresentations (not guaranteed sorted by dimension).
            - Q (:class:`~numpy.ndarray`): Unitary change-of-basis matrix.

    Raises:
        AssertionError: If ``rep`` is not unitary, or if an internal consistency check of the decomposition fails.
    """
    import networkx as nx

    eps = 1e-12  # Entries with a smaller magnitude are treated as exact zeros.

    # Normalize the input to a callable. The lookup is bound to a *new* name on purpose: a nested
    # ``def rep(g): return rep[g]`` would rebind ``rep`` to the wrapper itself and fail on every call.
    rep_fn = rep.__getitem__ if isinstance(rep, dict) else rep

    # Evaluate the representation once per group element and reuse the matrices below.
    elements = list(G.elements)
    rep_mats = [rep_fn(g) for g in elements]
    # Dimension of the representation.
    n = rep_mats[0].shape[0]

    identity = np.eye(n)
    for mat in rep_mats:  # Ensure the representation is unitary/orthogonal
        gram = mat @ mat.conj().T
        error = np.abs(gram - identity)
        assert np.allclose(error, 0), f"Rep {rep} is not unitary: rep(g)@rep(g)^H=\n{gram}"

    # Find a non-scalar Hermitian matrix `H` that commutes with all group actions
    is_irred, H = is_complex_irreducible(G, rep_fn)
    if is_irred:
        return [rep_fn], np.eye(n)

    # Eigen-decomposition of `H = P^H·A·P` reveals the G-invariant subspaces/eigenspaces of the representation.
    eivals, eigvects = np.linalg.eigh(H, UPLO="L")
    P = eigvects.conj().T
    assert np.allclose(P.conj().T @ np.diag(eivals) @ P, H)

    # The eigendecomposition is not guaranteed to block-diagonalize the representation. An additional permutation
    # of the rows and columns of the representation might be needed to obtain contiguous diagonal blocks (the
    # variable names below call this the Jordan canonical form, "jcf").
    # First: identify the diagonal blocks. To find them we think of the representation as the adjacency matrix of a
    # graph: the non-zero entries of the matrix are the edges of the graph.
    edges = set()
    decomposed_reps = {}
    for g, mat in zip(elements, rep_mats):
        diag_rep = P @ mat @ P.conj().T  # Representation expressed in the eigenbasis of H
        diag_rep[np.abs(diag_rep) < eps] = 0  # Remove rounding errors.
        edges.update(zip(*np.nonzero(diag_rep)))
        decomposed_reps[g] = diag_rep

    # Each connected component of the graph holds the rows and columns that make up one diagonal block.
    graph = nx.Graph()
    graph.add_edges_from(edges)
    connected_components = [sorted(comp) for comp in nx.connected_components(graph)]
    connected_components.sort(key=lambda comp: (len(comp), min(comp)))  # Impose a canonical order

    # If connected components are not adjacent dimensions, say subrep_1_dims = [0,2] and subrep_2_dims = [1,3], we
    # permute them into contiguous blocks, i.e. subrep_1_dims = [0,1] and subrep_2_dims = [2,3].
    oneline_notation = list(itertools.chain.from_iterable(connected_components))
    PJ = permutation_matrix(oneline_notation=oneline_notation)

    # After the permutation every component occupies a contiguous index range [start, end).
    block_bounds = []
    start = 0
    for comp in connected_components:
        block_bounds.append((start, start + len(comp)))
        start += len(comp)

    subreps = [CallableDict() for _ in block_bounds]
    for g in elements:
        # Permute the decomposed representation into contiguous blocks (once per element, shared by all blocks).
        jcf_rep = PJ @ decomposed_reps[g] @ PJ.T
        for comp_id, (block_start, block_end) in enumerate(block_bounds):
            # Check the block structure: everything outside the diagonal block must vanish.
            # TODO: Extract this to a utils function
            above_block = jcf_rep[0:block_start, block_start:block_end]
            below_block = jcf_rep[block_end:, block_start:block_end]
            left_block = jcf_rep[block_start:block_end, 0:block_start]
            right_block = jcf_rep[block_start:block_end, block_end:]
            assert np.allclose(above_block, 0) or above_block.size == 0, "Non zero elements above block"
            assert np.allclose(below_block, 0) or below_block.size == 0, "Non zero elements below block"
            assert np.allclose(left_block, 0) or left_block.size == 0, "Non zero elements left of block"
            assert np.allclose(right_block, 0) or right_block.size == 0, "Non zero elements right of block"
            # Copy so that the stored block does not keep a view into the shared full matrix.
            subreps[comp_id][g] = jcf_rep[block_start:block_end, block_start:block_end].copy()

    # The full change of basis is (PJ @ P) @ rep @ (PJ @ P)^H
    Q = PJ @ P

    # Test decomposition.
    for g, mat in zip(elements, rep_mats):
        jcf_rep = block_diag(*[subrep[g] for subrep in subreps])
        error = np.abs(jcf_rep - (Q @ mat @ Q.conj().T))
        assert np.allclose(error, 0), f"Q @ rep[g] @ Q^-1 != block_diag[{[f'rep{i},' for i in range(len(subreps))]}]"

    return subreps, Q
def compute_character_table(G: Group, reps: list[dict[GroupElement, np.ndarray] | Representation]):
    """Computes the character table of a group for a given set of representations.

    The character of a representation at ``g`` is the trace of its matrix at ``g``.

    Args:
        G (:class:`~escnn.group.group.Group`): Symmetry group. Its ``elements`` are enumerated, in order, to index
            the columns of the table.
        reps (:class:`list`): Representations (or representation mappings) for which characters are evaluated on all
            group elements. Each entry is a :class:`~typing.Union` of:

            - a mapping from :class:`~escnn.group.group.GroupElement` to :class:`~numpy.ndarray` (indexed as
              ``rep[g]``),
            - an :class:`~escnn.group.representation.Representation` (its ``character`` method is used), or
            - any other callable returning the matrix of ``g`` (called as ``rep(g)``).

    Returns:
        :class:`~numpy.ndarray`: The complex character table of shape `(n_reps, G.order())`.
    """
    n_reps = len(reps)
    table = np.zeros((n_reps, G.order()), dtype=complex)
    for i, rep in enumerate(reps):
        for j, g in enumerate(G.elements):
            if isinstance(rep, dict):
                # Mappings are indexed rather than called, so plain dicts are supported as well.
                character = np.trace(rep[g])
            elif isinstance(rep, Representation):
                character = rep.character(g)
            else:
                character = np.trace(rep(g))
            table[i, j] = character
    return table
def map_character_tables(in_table: np.ndarray, reference_table: np.ndarray):
    """Match every row of a character table against the rows of a reference table.

    A row of ``in_table`` matches a row of ``reference_table`` when all their entries agree up to the default
    tolerance of :func:`numpy.isclose`.

    Args:
        in_table (:class:`~numpy.ndarray`): Characters to look up, one representation per row.
        reference_table (:class:`~numpy.ndarray`): Characters to search in, one representation per row.

    Returns:
        tuple:
            - multiplicities (:class:`list`): Number of matching reference rows for every input row.
            - out_ids (:class:`list`): For every input row, the array of indices of the matching reference rows.
    """
    multiplicities = []
    out_ids = []
    for row in range(in_table.shape[0]):
        deviation = np.abs(reference_table - in_table[row, :])
        # A reference row matches only if every one of its entries is numerically equal to the query row.
        row_matches = np.isclose(deviation, 0).all(axis=1)
        matching_rows = np.nonzero(row_matches)[0]
        out_ids.append(matching_rows)
        multiplicities.append(len(matching_rows))
    return multiplicities, out_ids
def cplx_isotypic_decomposition(
    G: Group, rep: dict[GroupElement, np.ndarray] | Callable[[GroupElement], np.ndarray]
):
    """Perform the :ref:`isotypic decomposition <isotypic-decomposition-example>` of a unitary representation.

    This routine decomposes the representation into complex irreducibles. The representation is first split into
    invariant blocks; every block that is still reducible is decomposed again by a recursive call.

    Args:
        G (:class:`~escnn.group.group.Group`): Symmetry group of the representation.
        rep (dict | collections.abc.Callable): A :class:`~typing.Union`-style input.
            It must be either a `dict <https://docs.python.org/3/library/stdtypes.html#dict>`_ or a
            :class:`collections.abc.Callable`.
            It must map each :class:`~escnn.group.group.GroupElement` to a matrix :class:`~numpy.ndarray`.

    Returns:
        sorted_irreps (list): List of complex irreducible representations, sorted in ascending order of
            dimension. Every entry can be called with a group element to obtain its matrix.
        Q (:class:`~numpy.ndarray`): Unitary matrix such that ``Q @ rep[g] @ Q^-1`` is block diagonal, with blocks
            ``sorted_irreps``.

    Raises:
        AssertionError: If the computed ``Q`` is not unitary or does not block-diagonalize ``rep``.
    """
    # Normalize the input to a callable. The lookup is bound to a *new* name on purpose: a nested
    # ``def rep(g): return rep[g]`` would rebind ``rep`` to the wrapper itself and fail on every call.
    rep_fn = rep.__getitem__ if isinstance(rep, dict) else rep

    n = rep_fn(G.sample()).shape[0]
    subreps, Q_internal = decompose_representation(G, rep_fn)

    found_irreps = []
    Qs = []
    # Check if each subrepresentation can be further decomposed.
    for subrep in subreps:
        # Give the helper a plain callable view, so it never has to special-case dict-like subrepresentations.
        subrep_fn = subrep.__getitem__ if isinstance(subrep, dict) else subrep
        is_irred, _ = is_complex_irreducible(G, subrep_fn)
        if is_irred:
            n_sub = subrep_fn(G.sample()).shape[0]  # Dimension of sub representation
            found_irreps.append(subrep)
            Qs.append(np.eye(n_sub))
        else:
            # Find Q_sub such that Q_sub @ subrep[g] @ Q_sub^-1 is block diagonal, with blocks `sub_subreps`
            sub_subreps, Q_sub = cplx_isotypic_decomposition(G, subrep)
            found_irreps += sub_subreps
            Qs.append(Q_sub)

    # Sort irreps by dimension.
    P, sorted_irreps = sorted_jordan_canonical_form(G, found_irreps)

    # If subreps were decomposable, then these get further decomposed with an additional unitary matrix such that:
    # Q @ rep[g] @ Q^-1 = block_diag[irreps] | Q = (P @ Q_external @ Q_internal)
    Q_external = block_diag(*Qs)
    Q = P @ Q_external @ Q_internal

    # Test isotypic decomposition.
    assert np.allclose(Q @ Q.conj().T, np.eye(n)), "Q is not unitary."
    for g in G.elements:
        g_iso = block_diag(*[irrep[g] if isinstance(irrep, dict) else irrep(g) for irrep in sorted_irreps])
        error = np.abs(g_iso - (Q @ rep_fn(g) @ Q.conj().T))
        assert np.allclose(error, 0), f"Q @ rep[g] @ Q^-1 != block_diag[irreps[g]], for g={g}. Error \n:{error}"

    return sorted_irreps, Q
def sorted_jordan_canonical_form(group: Group, reps: list[Callable[[GroupElement], np.ndarray]]):
    """Order representations by ascending dimension and build the matching block permutation.

    The dimension of every representation is read from the matrix it returns for ``group.sample()``. Ties keep
    their original relative order.

    Args:
        group (:class:`~escnn.group.group.Group`): Symmetry group of the representation.
        reps (:class:`list`[:class:`collections.abc.Callable`]): List of representations to sort by dimension.

    Returns:
        P (:class:`~numpy.ndarray`): Permutation matrix sorting the input reps. It is the identity when ``reps``
            is already sorted.
        reps (:class:`list`[:class:`collections.abc.Callable`]): Sorted list of representations. The input list
            itself is returned when it is already sorted.
    """
    dims = [rep(group.sample()).shape[0] for rep in reps]
    positions = range(len(reps))
    order = sorted(positions, key=dims.__getitem__)

    # Nothing to reorder: the identity permutation leaves the blocks where they are.
    if order == list(positions):
        return np.eye(sum(dims)), reps

    # First row/column index of every block in the original (unsorted) block-diagonal layout.
    block_starts = np.cumsum([0] + dims[:-1])

    oneline_perm = []
    for rep_idx in order:
        first = block_starts[rep_idx]
        oneline_perm.extend(range(first, first + dims[rep_idx]))

    return permutation_matrix(oneline_perm), [reps[rep_idx] for rep_idx in order]