CommutingConstraint

class CommutingConstraint(in_rep, out_rep, basis_expansion='isotypic_expansion')

Bases: Module

Orthogonal projection onto \(\operatorname{Hom}_{\mathbb{G}}(\rho_{\text{in}},\rho_{\text{out}})\).

For a dense weight \(\mathbf{W}\in\mathbb{R}^{D_{\text{out}}\times D_{\text{in}}}\), this module returns \(\Pi_{\mathrm{Hom}}(\mathbf{W})\), the Frobenius-orthogonal projection onto

\[\operatorname{Hom}_{\mathbb{G}}(\rho_{\text{in}},\rho_{\text{out}}) = \{\mathbf{A}: \rho_{\text{out}}(g)\mathbf{A}=\mathbf{A}\rho_{\text{in}}(g),\ \forall g\in\mathbb{G}\}.\]

Both the basis and the projection are provided by GroupHomomorphismBasis, which works block by block on the isotypic decomposition (isotypic_decomp_rep()).
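As a rough illustration of what the projection computes (not of how GroupHomomorphismBasis computes it), the sketch below replaces the isotypic decomposition with plain group averaging over a small finite group. It assumes orthogonal representations, for which averaging \(\rho_{\text{out}}(g)\mathbf{W}\rho_{\text{in}}(g)^{\top}\) over the group gives the Frobenius-orthogonal projection onto \(\operatorname{Hom}_{\mathbb{G}}\). The helpers cyclic_shift_rep and project_onto_hom are hypothetical names chosen for this example, not part of the library.

```python
import numpy as np


def cyclic_shift_rep(n):
    """Permutation matrices of the cyclic group C_n acting on R^n by shifts."""
    shift = np.roll(np.eye(n), 1, axis=0)  # generator: e_j -> e_{j+1 mod n}
    return [np.linalg.matrix_power(shift, k) for k in range(n)]


def project_onto_hom(W, rho_in, rho_out):
    """Group-average W; assumes orthogonal reps, so rho_in(g)^{-1} = rho_in(g)^T."""
    terms = [o @ W @ i.T for i, o in zip(rho_in, rho_out)]
    return sum(terms) / len(terms)


rng = np.random.default_rng(0)
rho_in = cyclic_shift_rep(4)
rho_out = cyclic_shift_rep(4)
W = rng.standard_normal((4, 4))
P = project_onto_hom(W, rho_in, rho_out)

# Defining property of Hom: rho_out(g) P == P rho_in(g) for every g.
commutes = all(np.allclose(o @ P, P @ i) for i, o in zip(rho_in, rho_out))
# A projection is idempotent: projecting again changes nothing.
idempotent = np.allclose(project_onto_hom(P, rho_in, rho_out), P)
# Orthogonality: the residual W - P is Frobenius-orthogonal to P.
orthogonal = np.isclose(np.sum((W - P) * P), 0.0)
```

Group averaging costs one pair of matrix products per group element, which is one reason a basis built from the isotypic decomposition can be preferable for larger groups.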

Parameters:
  • in_rep (Representation) – Input representation \(\rho_{\text{in}}\) of size in_rep.size.

  • out_rep (Representation) – Output representation \(\rho_{\text{out}}\) of size out_rep.size.

  • basis_expansion (str, optional) – Strategy used to realize the basis: "memory_heavy" or "isotypic_expansion". Defaults to "isotypic_expansion".

Note

Runtime behavior depends on the module's mode. In training mode (model.train()), the projection is recomputed on every forward pass. In inference mode (model.eval()), the projected matrix is cached and reused as long as the input tensor is unchanged (same object identity and same version counter), which avoids recomputing the projection. When the module is used as a parametrization of Linear and the cache is active, the forward path is therefore equivalent to a standard, symmetry-agnostic linear layer with a fixed, already projected dense weight.

homo_basis

Basis generator carrying the isotypic decomposition and block metadata for \(\operatorname{Hom}_\mathbb{G}(\rho_{\text{in}}, \rho_{\text{out}})\).

Type:

GroupHomomorphismBasis

in_rep / out_rep

Cached references to the isotypic versions of the supplied representations.

Type:

Representation

forward(W)

Project \(\mathbf{W}\) onto the space of equivariant linear maps.

Parameters:

W (Tensor) – Dense matrix \(\mathbf{W}\in\mathbb{R}^{D_{\mathrm{out}}\times D_{\mathrm{in}}}\).

Returns:

Frobenius-orthogonal projection \(\Pi_{\mathrm{Hom}}(\mathbf{W})\), which satisfies \(\rho_{\mathrm{out}}(g)\Pi_{\mathrm{Hom}}(\mathbf{W})=\Pi_{\mathrm{Hom}}(\mathbf{W})\rho_{\mathrm{in}}(g)\) for all \(g\in\mathbb{G}\).

Return type:

Tensor

invalidate_cache()

Clear the cached projection so that it is recomputed on the next use.

Return type:

None

load_state_dict(state_dict, strict=True)

Load parameters and clear the cached projected weight.

Parameters:

strict (bool)

right_inverse(tensor)

Return a pre-image of tensor under the parametrization (currently the identity).

Parameters:

tensor (Tensor)

Return type:

Tensor
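To show where right_inverse enters, the sketch below registers a small stand-in module as a parametrization of a Linear weight with torch.nn.utils.parametrize. SwapEquivariantProjection is a hypothetical toy (group averaging for C2 acting on \(\mathbb{R}^2\) by coordinate swap), not the documented class; only its identity right_inverse mirrors the behavior described above.

```python
import torch
from torch import nn
from torch.nn.utils import parametrize


class SwapEquivariantProjection(nn.Module):
    """Hypothetical stand-in: projects onto maps commuting with a coordinate swap."""

    def __init__(self):
        super().__init__()
        self.register_buffer("swap", torch.tensor([[0.0, 1.0], [1.0, 0.0]]))

    def forward(self, W):
        # Group average over C2 = {identity, swap}; swap is its own inverse.
        return 0.5 * (W + self.swap @ W @ self.swap)

    def right_inverse(self, tensor):
        # Identity pre-image: store the assigned tensor as the raw weight.
        return tensor


layer = nn.Linear(2, 2, bias=False)
parametrize.register_parametrization(layer, "weight", SwapEquivariantProjection())

new_weight = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
layer.weight = new_weight                        # routed through right_inverse
stored = layer.parametrizations.weight.original  # raw, unconstrained tensor
effective = layer.weight                         # projection of the raw tensor
```

With an identity right_inverse, assigning a matrix that is not already equivariant stores it unchanged, and reading layer.weight returns its projection rather than the assigned value.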