from __future__ import annotations
import functools
import itertools
from collections import OrderedDict
from typing import Callable, Dict, List, Union
import numpy as np
from escnn.group import Group, GroupElement, Representation, directsum
from escnn.nn import FieldType
from scipy.linalg import block_diag
from symm_learning.utils import CallableDict
[docs]
def isotypic_decomp_rep(rep: Representation) -> Representation:
r"""Return an equivalent representation disentangled into isotypic subspaces.
Given an input :class:`~escnn.group.Representation`, this function computes an
equivalent representation by updating the change of basis (and its inverse)
and reordering the irreducible representations. The returned representation
is guaranteed to be disentangled into its isotypic subspaces.
A representation is considered disentangled if, in its spectral basis, the
irreducible representations (irreps) are clustered by type, i.e., all
irreps of the same type are consecutive:
.. math::
\rho_{\mathcal{X}} = \mathbf{Q} \left( \bigoplus_{k\in[1,n_{\text{iso}}]} (\mathbf{I}_{m_{k}} \otimes
\hat{\rho}_k) \right) \mathbf{Q}^T
where :math:`\hat{\rho}_k` is the irreducible representation of type :math:`k`,
and :math:`m_{k}` is its multiplicity.
The change of basis of the representation returned by this function can be
used to decompose the representation space :math:`\mathcal{X}` into its
orthogonal isotypic subspaces:
.. math::
\mathcal{X} = \bigoplus_{k\in[1,n]} \mathcal{X}^{(k)}
Args:
rep (escnn.group.Representation): The input representation.
Returns:
escnn.group.Representation: An equivalent, disentangled representation.
"""
symm_group = rep.group
if rep.name + "-Iso" in symm_group.representations:
print(f"Returning cached {rep.name}-Iso")
return symm_group.representations[rep.name + "-Iso"]
potential_irreps = rep.group.irreps()
isotypic_subspaces_indices = {irrep.id: [] for irrep in potential_irreps}
for pot_irrep in potential_irreps:
cur_dim = 0
for rep_irrep_id in rep.irreps:
rep_irrep = symm_group.irrep(*rep_irrep_id)
if rep_irrep == pot_irrep:
isotypic_subspaces_indices[rep_irrep_id].append(list(range(cur_dim, cur_dim + rep_irrep.size)))
cur_dim += rep_irrep.size
# Remove inactive Isotypic Spaces
for irrep in potential_irreps:
if len(isotypic_subspaces_indices[irrep.id]) == 0:
del isotypic_subspaces_indices[irrep.id]
# Each Isotypic Space will be indexed by the irrep it is associated with.
active_isotypic_reps = {}
for irrep_id, indices in isotypic_subspaces_indices.items():
irrep = symm_group.irrep(*irrep_id)
multiplicities = len(indices)
active_isotypic_reps[irrep_id] = Representation(
group=rep.group,
irreps=[irrep_id] * multiplicities,
name=f"IsoSubspace {irrep_id}",
change_of_basis=np.identity(irrep.size * multiplicities),
supported_nonlinearities=irrep.supported_nonlinearities,
)
# Impose canonical order on the Isotypic Subspaces.
# If the trivial representation is active it will be the first Isotypic Subspace.
# Then sort by dimension of the space from smallest to largest.
ordered_isotypic_reps = OrderedDict(sorted(active_isotypic_reps.items(), key=lambda item: item[1].size))
if symm_group.trivial_representation.id in ordered_isotypic_reps:
ordered_isotypic_reps.move_to_end(symm_group.trivial_representation.id, last=False)
# Required permutation to change the order of the irreps. So we obtain irreps of the same type consecutively.
oneline_permutation = []
for irrep_id, iso_rep in ordered_isotypic_reps.items():
idx = isotypic_subspaces_indices[irrep_id]
oneline_permutation.extend(idx)
oneline_permutation = np.concatenate(oneline_permutation)
P_in2iso = permutation_matrix(oneline_permutation)
Q_iso = rep.change_of_basis @ P_in2iso.T
rep_iso_basis = directsum(list(ordered_isotypic_reps.values()), name=rep.name + "-Iso", change_of_basis=Q_iso)
# Get variable of indices of isotypic subspaces in the disentangled representation.
d = 0
iso_subspace_dims = {}
for irrep_id, iso_rep in ordered_isotypic_reps.items():
iso_subspace_dims[irrep_id] = slice(d, d + iso_rep.size)
d += iso_rep.size
iso_supported_nonlinearities = [iso_rep.supported_nonlinearities for iso_rep in ordered_isotypic_reps.values()]
rep_iso_basis.supported_nonlinearities = functools.reduce(set.intersection, iso_supported_nonlinearities)
rep_iso_basis.attributes["isotypic_reps"] = ordered_isotypic_reps
rep_iso_basis.attributes["isotypic_subspace_dims"] = iso_subspace_dims
return rep_iso_basis
[docs]
def field_type_to_isotypic_basis(field_type: FieldType):
"""Returns a new field type in a disentangled basis ignoring change of basis."""
rep = field_type.representation
# Organize the irreps such that we get: rep_ordered_irreps := Q (⊕_k (⊕_i^mk irrep_k)) Q^T
rep_ordered_irreps = isotypic_decomp_rep(rep)
# Get dictionary of irrep_id: (⊕_i^mk irrep_k)
iso_subspaces_reps = rep_ordered_irreps.attributes["isotypic_reps"]
# Define a field type composed of the representations of each isotypic subspace
new_field_type = FieldType(gspace=field_type.gspace, representations=list(iso_subspaces_reps.values()))
return new_field_type
[docs]
def permutation_matrix(oneline_notation):
"""Generate a permutation matrix from its oneline notation."""
d = len(oneline_notation)
assert d == np.unique(oneline_notation).size, "oneline_notation must describe a non-defective permutation"
P = np.zeros((d, d), dtype=int)
P[range(d), np.abs(oneline_notation)] = 1
return P
[docs]
def irreps_stats(irreps_ids):
"""Compute the unique irreps, their counts and indices in the input list of irreps ids."""
str_ids = [str(irrep_id) for irrep_id in irreps_ids]
unique_str_ids, counts, indices = np.unique(str_ids, return_counts=True, return_index=True)
unique_ids = [eval(s) for s in unique_str_ids]
return unique_ids, counts, indices
[docs]
def is_complex_irreducible(
group: Group, rep: Union[Dict[GroupElement, np.ndarray], Callable[[GroupElement], np.ndarray]]
):
"""Check if a representation is complex irreducible.
We check this by asserting weather non-scalar (no multiple of
identity ) Hermitian matrix `H` exists, such that `H` commutes with all group elements' representation.
If rho is irreducible, this function returns (True, H=I) where I is the identity matrix.
Otherwise, returns (False, H) where H is a non-scalar matrix that commutes with all elements' representation.
Args:
group (:class:`escnn.group.Group`): Symmetry group of the representation.
rep (Union[Dict[escnn.group.GroupElement, np.ndarray], Callable[[escnn.group.GroupElement], np.ndarray]]):
Mapping from group elements to their representation matrices.
"""
if isinstance(rep, dict):
def rep(g):
return rep[g]
else:
rep = rep
# Compute the dimension of the representation
n = rep(group.sample()).shape[0]
# Run through all r,s = 1,2,...,n
for r in range(n):
for s in range(n):
# Define H_rs
H_rs = np.zeros((n, n), dtype=complex)
if r == s:
H_rs[r, s] = 1
elif r > s:
H_rs[r, s] = 1
H_rs[s, r] = 1
else: # r < s
H_rs[r, s] = 1j
H_rs[s, r] = -1j
# Compute H
H = sum([rep(g).conj().T @ H_rs @ rep(g) for g in group.elements]) / group.order()
# If H is not a scalar matrix, then it is a matrix that commutes with all group actions.
if not np.allclose(H[0, 0] * np.eye(H.shape[0]), H):
return False, H
# No Hermitian matrix was found to commute with all group actions. This is an irreducible rep
return True, np.eye(n)
[docs]
def decompose_representation(
G: Group, rep: Union[Dict[GroupElement, np.ndarray], Callable[[GroupElement], np.ndarray]]
):
r"""Find the Hermitian matrix `Q` that block-diagonalizes the representation `rep` of group `G`.
Such that
.. math::
\mathbf{Q} \rho(g) \mathbf{Q}^\top = \operatorname{block\_ diag}(\rho_1(g), ..., \rho_m(g))
\quad \forall g \in G
Args:
G (:class:`escnn.group.Group`): The symmetry group.
rep (Union[Dict[escnn.group.GroupElement, np.ndarray], Callable[[escnn.group.GroupElement], np.ndarray]]):
The representation to decompose.
Returns:
tuple: A tuple containing:
- subreps (List[CallableDict]): A list of irreducible representations.
- Q (np.ndarray): The change of basis matrix.
"""
import networkx as nx
from networkx import Graph
eps = 1e-12
if isinstance(rep, dict):
def rep(g):
return rep[g]
else:
rep = rep
# Compute the dimension of the representation
n = rep(G.sample()).shape[0]
for g in G.elements: # Ensure the representation is unitary/orthogonal
error = np.abs((rep(g) @ rep(g).conj().T) - np.eye(n))
assert np.allclose(error, 0), f"Rep {rep} is not unitary: rep(g)@rep(g)^H=\n{(rep(g) @ rep(g).conj().T)}"
# Find Hermitian matrix non-scalar `H` that commutes with all group actions
is_irred, H = is_complex_irreducible(G, rep)
if is_irred:
return [rep], np.eye(n)
# Eigen-decomposition of matrix `H = P·A·P^-1` reveals the G-invariant subspaces/eigenspaces of the representations.
eivals, eigvects = np.linalg.eigh(H, UPLO="L")
P = eigvects.conj().T
assert np.allclose(P.conj().T @ np.diag(eivals) @ P, H)
# Eigendcomposition is not guaranteed to block_diagonalize the representation. An additional permutation of the
# rows and columns od the representation might be needed to produce a Jordan block canonical form.
# First: We want to identify the diagonal blocks. To find them we use the trick of thinking of the representation
# as an adjacency matrix of a graph. The non-zero entries of the adjacency matrix are the edges of the graph.
edges = set()
decomposed_reps = {}
for g in G.elements:
diag_rep = P @ rep(g) @ P.conj().T # Obtain block diagonal representation
diag_rep[np.abs(diag_rep) < eps] = 0 # Remove rounding errors.
non_zero_idx = np.nonzero(diag_rep)
edges.update([(x_idx, y_idx) for x_idx, y_idx in zip(*non_zero_idx)])
decomposed_reps[g] = diag_rep
# Each connected component of the graph is equivalent to the rows and columns determining a block in the diagonal
graph = Graph()
graph.add_edges_from(set(edges))
connected_components = [sorted(list(comp)) for comp in nx.connected_components(graph)]
connected_components = sorted(connected_components, key=lambda x: (len(x), min(x))) # Impose a canonical order
# If connected components are not adjacent dimensions, say subrep_1_dims = [0,2] and subrep_2_dims = [1,3] then
# We permute them to get a jordan block canonical form. I.e. subrep_1_dims = [0,1] and subrep_2_dims = [2,3].
oneline_notation = list(itertools.chain.from_iterable([list(comp) for comp in connected_components]))
PJ = permutation_matrix(oneline_notation=oneline_notation)
# After permuting the dimensions, we can assume the components are ordered in dimension
ordered_connected_components = []
idx = 0
for comp in connected_components:
ordered_connected_components.append(tuple(range(idx, idx + len(comp))))
idx += len(comp)
connected_components = ordered_connected_components
# The output of connected components is the set of nodes/row-indices of the rep.
subreps = [CallableDict() for _ in connected_components]
for g in G.elements:
for comp_id, comp in enumerate(connected_components):
block_start, block_end = comp[0], comp[-1] + 1
# Transform the decomposed representation into the Jordan Cannonical Form (jcf)
jcf_rep = PJ @ decomposed_reps[g] @ PJ.T
# Check Jordan Cannonical Form TODO: Extract this to a utils. function
above_block = jcf_rep[0:block_start, block_start:block_end]
below_block = jcf_rep[block_end:, block_start:block_end]
left_block = jcf_rep[block_start:block_end, 0:block_start]
right_block = jcf_rep[block_start:block_end, block_end:]
assert np.allclose(above_block, 0) or above_block.size == 0, "Non zero elements above block"
assert np.allclose(below_block, 0) or below_block.size == 0, "Non zero elements below block"
assert np.allclose(left_block, 0) or left_block.size == 0, "Non zero elements left of block"
assert np.allclose(right_block, 0) or right_block.size == 0, "Non zero elements right of block"
sub_g = jcf_rep[block_start:block_end, block_start:block_end]
subreps[comp_id][g] = sub_g
# Decomposition to Jordan Canonical form is accomplished by (PJ @ P) @ rep @ (PJ @ P)^-1
Q = PJ @ P
# Test decomposition.
for g in G.elements:
jcf_rep = block_diag(*[subrep[g] for subrep in subreps])
error = np.abs(jcf_rep - (Q @ rep(g) @ Q.conj().T))
assert np.allclose(error, 0), f"Q @ rep[g] @ Q^-1 != block_diag[{[f'rep{i},' for i in range(len(subreps))]}]"
return subreps, Q
[docs]
def compute_character_table(G: Group, reps: List[Union[Dict[GroupElement, np.ndarray], Representation]]):
"""Computes the character table of a group for a given set of representations.
Args:
G (:class:`escnn.group.Group`): The symmetry group.
reps (List[Union[Dict[escnn.group.GroupElement, np.ndarray], escnn.group.Representation]]): The representations
to compute the character table for.
"""
n_reps = len(reps)
table = np.zeros((n_reps, G.order()), dtype=complex)
for i, rep in enumerate(reps):
for j, g in enumerate(G.elements):
table[i, j] = rep.character(g) if isinstance(rep, Representation) else np.trace(rep(g))
return table
[docs]
def map_character_tables(in_table: np.ndarray, reference_table: np.ndarray):
"""Find a representation of a group in the set of irreducible representations."""
n_in_reps = in_table.shape[0]
out_ids, multiplicities = [], []
for in_id in range(n_in_reps):
character_orbit = in_table[in_id, :]
orbit_error = np.isclose(np.abs(reference_table - character_orbit), 0)
match_idx = np.argwhere(np.all(orbit_error, axis=1)).flatten()
multiplicity = len(match_idx)
out_ids.append(match_idx), multiplicities.append(multiplicity)
return multiplicities, out_ids
[docs]
def cplx_isotypic_decomposition(G: Group, rep: Callable[[GroupElement], np.ndarray]):
"""Perform the isotypic decomposition of unitary representation, decomposing the rep into complex irreps.
Args:
G (:class:`escnn.group.Group`): Symmetry group of the representation.
rep (Union[Dict[escnn.group.GroupElement, np.ndarray], Callable[[escnn.group.GroupElement, np.ndarray]]]):
dict/mapping from group elements to matrices, or a function that takes a group element and returns a matrix.
Returns:
sorted_irreps (List[Dict[escnn.group.GroupElement, np.ndarray]]): List of complex irreducible representations,
sorted in ascending order of dimension.
Q (:class:`numpy.ndarray`): Hermitian matrix such that Q @ rep[g] @ Q^-1 is block diagonal, with blocks
`sorted_irreps`.
"""
if isinstance(rep, dict):
def rep(g):
return rep[g]
else:
rep = rep
n = rep(G.sample()).shape[0]
subreps, Q_internal = decompose_representation(G, rep)
found_irreps = []
Qs = []
# Check if each subrepresentation can be further decomposed.
for subrep in subreps:
n_sub = subrep(G.sample()).shape[0] # Dimension of sub representation
is_irred, _ = is_complex_irreducible(G, subrep)
if is_irred:
found_irreps.append(subrep)
Qs.append(np.eye(n_sub))
else:
# Find Q_sub such that Q_sub @ subrep[g] @ Q_sub^-1 is block diagonal, with blocks `sub_subrep`
sub_subreps, Q_sub = cplx_isotypic_decomposition(G, subrep)
found_irreps += sub_subreps
Qs.append(Q_sub)
# Sort irreps by dimension.
P, sorted_irreps = sorted_jordan_canonical_form(G, found_irreps)
# If subreps were decomposable, then these get further decomposed with an additional Hermitian matrix such that:
# Q @ rep[g] @ Q^-1 = block_diag[irreps] | Q = (Q_external @ Q_internal)
Q_external = block_diag(*Qs)
Q = P @ Q_external @ Q_internal
# Test isotypic decomposition.
assert np.allclose(Q @ Q.conj().T, np.eye(n)), "Q is not unitary."
for g in G.elements:
g_iso = block_diag(*[irrep[g] if isinstance(irrep, dict) else irrep(g) for irrep in sorted_irreps])
error = np.abs(g_iso - (Q @ rep(g) @ Q.conj().T))
assert np.allclose(error, 0), f"Q @ rep[g] @ Q^-1 != block_diag[irreps[g]], for g={g}. Error \n:{error}"
return sorted_irreps, Q