project_in_isobasis

project_in_isobasis(W, rep_x, rep_y, tensor_cache=None)

Project W onto \(\mathrm{Hom}_{\mathbb{G}}(\rho_{\mathcal{X}}, \rho_{\mathcal{Y}})\) in isotypic coordinates.

This is a utility function that projects a dense linear map \(\mathbf{W}:\mathcal{X}\to\mathcal{Y}\) onto the space of \(\mathbb{G}\)-equivariant linear maps, \(\mathrm{Hom}_{\mathbb{G}}(\rho_{\mathcal{X}}, \rho_{\mathcal{Y}})\). It also returns intermediate tensors that are useful for downstream algebraic and statistical computations.

Let \(\mathbf{Q}_{\mathcal{X}}\) and \(\mathbf{Q}_{\mathcal{Y}}\) be the orthogonal change-of-basis matrices exposing the isotypic decomposition of the symmetric vector spaces \((\mathcal{X}, \rho_{\mathcal{X}})\) and \((\mathcal{Y}, \rho_{\mathcal{Y}})\) (see Isotypic Decomposition). In this basis, any \(\mathbb{G}\)-equivariant linear map \(\mathbf{A} \in \mathrm{Hom}_{\mathbb{G}}(\rho_{\mathcal{X}}, \rho_{\mathcal{Y}})\) takes a block-diagonal form, with one block per isotypic subspace, as described in Leveraging the structure of Equivariant Linear maps:

\[\mathbf{A}_{\mathrm{iso}} = \mathbf{Q}_{\mathcal{Y}}^\top \mathbf{A} \mathbf{Q}_{\mathcal{X}} = \bigoplus_{k\in[1,n_{\text{iso}}]} \mathbf{A}^{(k)}, \qquad \mathbf{A}^{(k)} = \sum_{s=1}^{S_k}\mathbf{\Theta}^{(k)}_s \otimes \mathbf{\Psi}^{(k)}_s.\]

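where \(\{\mathbf{\Psi}^{(k)}_s \in \mathbb{R}^{d_k \times d_k}\}_{s=1}^{S_k}\) is a fixed basis of the endomorphism space \(\mathrm{End}_{\mathbb{G}}(\hat{\rho}_k)\) of the \(k\)-th irrep \(\hat{\rho}_k\) (of dimension \(d_k\)), and the matrices \(\mathbf{\Theta}^{(k)}_s \in \mathbb{R}^{m_k^{\mathcal{Y}} \times m_k^{\mathcal{X}}}\) are the free parameters (degrees of freedom) of the equivariant map, serving as basis-expansion coefficients. Here \(m_k^{\mathcal{X}}\) and \(m_k^{\mathcal{Y}}\) are the multiplicities of \(\hat{\rho}_k\) in \(\rho_{\mathcal{X}}\) and \(\rho_{\mathcal{Y}}\), respectively.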
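To make the block structure concrete, the following sketch builds a single block \(\mathbf{A}^{(k)} = \sum_s \mathbf{\Theta}^{(k)}_s \otimes \mathbf{\Psi}^{(k)}_s\) with `torch.kron` and checks that it commutes with the group action, which inside the \(k\)-th isotypic block is \(\mathbf{I}_{m}\otimes\hat{\rho}_k(g)\) under this Kronecker layout. The irrep (the 2D frequency-1 irrep of SO(2), a complex-type irrep), its endomorphism basis \(\{\mathbf{I}, \mathbf{J}\}\), and the multiplicities are illustrative assumptions, not values the library computes for any particular representation.

```python
import math
import torch

# Assumed example: frequency-1 irrep of SO(2), d_k = 2 (complex type).
# One standard End_G basis is {I, J}, with J a 90-degree rotation, so S_k = 2.
J = torch.tensor([[0.0, -1.0], [1.0, 0.0]], dtype=torch.float64)
Psi = torch.stack([torch.eye(2, dtype=torch.float64), J])  # (S_k, d_k, d_k)
S_k, d_k = Psi.shape[0], Psi.shape[-1]
m_out, m_in = 3, 2  # hypothetical multiplicities m_k^Y and m_k^X

# Free parameters: one (m_out x m_in) matrix Theta_s per basis element Psi_s.
torch.manual_seed(0)
Theta = torch.randn(S_k, m_out, m_in, dtype=torch.float64)

# A^(k) = sum_s kron(Theta_s, Psi_s), shape (m_out * d_k, m_in * d_k).
A_k = sum(torch.kron(Theta[s], Psi[s]) for s in range(S_k))

# In this layout the group acts on the block as kron(I_m, rho_k(g)).
theta = 0.7
R = torch.tensor([[math.cos(theta), -math.sin(theta)],
                  [math.sin(theta), math.cos(theta)]], dtype=torch.float64)
rho_y = torch.kron(torch.eye(m_out, dtype=torch.float64), R)
rho_x = torch.kron(torch.eye(m_in, dtype=torch.float64), R)
print(torch.allclose(rho_y @ A_k, A_k @ rho_x))  # True: A^(k) is equivariant

# The basis is orthogonal under the Frobenius inner product (Gram = d_k * I).
gram = torch.einsum("sab,tab->st", Psi, Psi)
print(torch.allclose(gram, d_k * torch.eye(S_k, dtype=torch.float64)))  # True
```

A generic dense block fails the same commutation test, which is why only the \(S_k\) matrices \(\mathbf{\Theta}^{(k)}_s\), rather than all \(m_k^{\mathcal{Y}} d_k \times m_k^{\mathcal{X}} d_k\) entries, are free.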

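Accordingly, this function first maps the input to isotypic coordinates, \(\mathbf{W}_{\mathrm{iso}} = \mathbf{Q}_{\mathcal{Y}}^\top \mathbf{W} \mathbf{Q}_{\mathcal{X}}\), and then computes the coefficients \(\mathbf{\Theta}^{(k)}\) of the projected map in each isotypic block. Let \(\mathbf{W}^{(k)}\) denote the \(k\)-th isotypic block of \(\mathbf{W}_{\mathrm{iso}}\) and \(\mathbf{W}^{(k)}_{o,i} \in \mathbb{R}^{d_k \times d_k}\) its \((o,i)\)-th sub-block. Because each coefficient is computed independently, the formula below is the orthogonal projection provided the basis \(\{\mathbf{\Psi}^{(k)}_s\}_{s=1}^{S_k}\) is orthogonal under the Frobenius inner product: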

\[\mathbf{\Theta}^{(k)}_{o,i,s} = \frac{\langle \mathbf{W}^{(k)}_{o,i}, \mathbf{\Psi}^{(k)}_s\rangle_F} {\lVert \mathbf{\Psi}^{(k)}_s \rVert_F^2}, \qquad \forall \; k \in [1, n_{\text{iso}}], o \in [1, m_k^{\mathcal{Y}}], i \in [1, m_k^{\mathcal{X}}], s \in [1, S_k].\]
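For a single isotypic block, the coefficient formula can be sketched in PyTorch by reshaping \(\mathbf{W}^{(k)}\) into its \(d_k \times d_k\) sub-blocks and contracting each one with the basis. The helpers `block_coefficients` and `block_from_coefficients` are hypothetical names, and the sketch assumes the Kronecker layout of the block-diagonal equation above (multiplicity index major, irrep index minor) and a Frobenius-orthogonal basis.

```python
import torch

def block_coefficients(W_k: torch.Tensor, Psi: torch.Tensor,
                       m_out: int, m_in: int) -> torch.Tensor:
    """Theta[..., o, i, s] = <W_{o,i}, Psi_s>_F / ||Psi_s||_F^2 for one isotypic block.

    W_k: (..., m_out * d_k, m_in * d_k) block of W_iso; Psi: (S_k, d_k, d_k).
    """
    d_k = Psi.shape[-1]
    # Row index o * d_k + a and column index i * d_k + b, as in kron(Theta_s, Psi_s).
    sub = W_k.reshape(*W_k.shape[:-2], m_out, d_k, m_in, d_k)
    inner = torch.einsum("...oaib,sab->...ois", sub, Psi)  # Frobenius inner products
    norm_sq = Psi.pow(2).sum(dim=(-2, -1))                 # ||Psi_s||_F^2, shape (S_k,)
    return inner / norm_sq

def block_from_coefficients(Theta: torch.Tensor, Psi: torch.Tensor) -> torch.Tensor:
    """Rebuild sum_s kron(Theta_s, Psi_s) from Theta of shape (..., m_out, m_in, S_k)."""
    m_out, m_in = Theta.shape[-3], Theta.shape[-2]
    d_k = Psi.shape[-1]
    sub = torch.einsum("...ois,sab->...oaib", Theta, Psi)
    return sub.reshape(*Theta.shape[:-3], m_out * d_k, m_in * d_k)

# Usage with the complex-type basis {I, J} (an illustrative assumption).
Psi = torch.stack([torch.eye(2, dtype=torch.float64),
                   torch.tensor([[0.0, -1.0], [1.0, 0.0]], dtype=torch.float64)])
torch.manual_seed(0)
W_k = torch.randn(4, 3 * 2, 2 * 2, dtype=torch.float64)  # batch of 4 blocks, m_out=3, m_in=2
Theta = block_coefficients(W_k, Psi, m_out=3, m_in=2)     # (4, 3, 2, 2)
P_k = block_from_coefficients(Theta, Psi)                 # projected block, (4, 6, 4)
# Projecting the projected block again returns the same coefficients.
print(torch.allclose(block_coefficients(P_k, Psi, 3, 2), Theta))  # True
```

To match the documented `coeff` shape \((..., m_k^{\mathcal{Y}} m_k^{\mathcal{X}}, S_k)\), one could flatten the \((o, i)\) axes with `Theta.flatten(-3, -2)`; whether the library flattens in this output-major order is an assumption worth checking against the source.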
Parameters:
  • W (Tensor) – Dense linear map (or batch of maps) from \(\mathcal{X}\) to \(\mathcal{Y}\) with shape \((..., D_y, D_x)\), where \(D_y = \dim(\mathcal{Y})\) and \(D_x = \dim(\mathcal{X})\).

  • rep_x (Representation) – Input representation \(\rho_{\mathcal{X}}\).

  • rep_y (Representation) – Output representation \(\rho_{\mathcal{Y}}\).

  • tensor_cache (IsotypicTensorCache, optional) – Optional precomputed tensors that override the ones derived from rep_x and rep_y: the matrices Q_out and Q_in_inv, plus the dictionaries endo_basis_flat and endo_norm_sq keyed by irrep id. When provided, these tensors are used directly, without copying, so they must already have a dtype and device compatible with W. Specifically, the function reads tensor_cache.Q_out, tensor_cache.Q_in_inv, tensor_cache.endo_basis_flat[irrep_id], and tensor_cache.endo_norm_sq[irrep_id] for each irrep shared by rep_x and rep_y.
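Because cached tensors are used without copying, one natural pattern is to cast them once, when the cache is built, to the dtype and device that W will have. The sketch below relies only on `torch.Tensor.to`, which returns the tensor itself when no conversion is needed; the helper `cast_like` and the tensor names are hypothetical, and building an actual IsotypicTensorCache is not shown.

```python
import torch

def cast_like(t: torch.Tensor, like: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: match a cached tensor's dtype and device to `like`."""
    return t.to(dtype=like.dtype, device=like.device)

W = torch.randn(16, 6, 4, dtype=torch.float64)  # batch of dense maps in float64

# One-time conversion when building the cache: float32 -> float64 makes a copy.
Q_out_cached = cast_like(torch.eye(6), W)

# Afterwards the conversion is a no-op: Tensor.to returns the same tensor object
# when dtype and device already match, so no copy is made.
print(cast_like(Q_out_cached, W) is Q_out_cached)  # True
```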

Returns:

Tuple with entries:

  • Q_out: matrix \(\mathbf{Q}_{\mathcal{Y}}\). If tensor_cache is None, this is a tensor copy with dtype and device matching W. Otherwise, tensor_cache.Q_out is returned directly.

  • Q_in_inv: matrix \(\mathbf{Q}_{\mathcal{X}}^{\top}\). If tensor_cache is None, this is a tensor copy with dtype and device matching W. Otherwise, tensor_cache.Q_in_inv is returned directly.

  • projection_iso_spaces: dictionary keyed by irrep identifier irrep_id. Each projection_iso_spaces[irrep_id] is an IsoSpaceProjection containing:

    • coeff (Tensor): projection coefficients \(\mathbf{\Theta}^{(k)}\) of shape \((..., m_k^{\mathcal{Y}} m_k^{\mathcal{X}}, S_k)\).

    • endo_basis_flat (Tensor): stacked flattened endomorphism basis \(\operatorname{flat}(\mathbf{\Psi}^{(k)}) := [\operatorname{flat}(\mathbf{\Psi}^{(k)}_1), \ldots, \operatorname{flat}(\mathbf{\Psi}^{(k)}_{S_k})]^\top \in \mathbb{R}^{S_k \times d_k^2}\). If tensor_cache is provided, this entry is tensor_cache.endo_basis_flat[irrep_id] returned directly.

    • m_out (int): output multiplicity \(m_k^{\mathcal{Y}}\).

    • m_in (int): input multiplicity \(m_k^{\mathcal{X}}\).

    • d_k (int): irrep dimension \(d_k = \dim(\hat{\rho}_k)\).

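    • out_slice (slice): row slice locating \(\mathbf{W}^{(k)}\) inside \(\mathbf{W}_{\mathrm{iso}}\), i.e., its output (\(\mathcal{Y}\)) dimensions.

    • in_slice (slice): column slice locating \(\mathbf{W}^{(k)}\) inside \(\mathbf{W}_{\mathrm{iso}}\), i.e., its input (\(\mathcal{X}\)) dimensions. Together, W_iso[..., out_slice, in_slice] recovers \(\mathbf{W}^{(k)}\).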
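The returned pieces are enough to rebuild the projected map as a dense matrix: for each isotypic block, coeff @ endo_basis_flat yields all flattened \(d_k \times d_k\) sub-blocks, which are arranged in Kronecker layout, written at [out_slice, in_slice] of a zero matrix, and mapped back with Q_out and Q_in_inv. The sketch below uses a hypothetical `IsoBlock` named tuple in place of IsoSpaceProjection, placeholder dictionary keys, and assumes coeff flattens the \((o, i)\) pair in output-major order.

```python
import math
from typing import Dict, NamedTuple, Tuple
import torch

class IsoBlock(NamedTuple):
    """Hypothetical stand-in exposing the same fields as IsoSpaceProjection."""
    coeff: torch.Tensor            # (..., m_out * m_in, S_k)
    endo_basis_flat: torch.Tensor  # (S_k, d_k ** 2)
    m_out: int
    m_in: int
    d_k: int
    out_slice: slice
    in_slice: slice

def assemble_projection(Q_out: torch.Tensor, Q_in_inv: torch.Tensor,
                        blocks: Dict[Tuple[int, ...], IsoBlock],
                        batch_shape: Tuple[int, ...] = ()) -> torch.Tensor:
    """Rebuild the projected dense map Q_out @ W_iso_proj @ Q_in_inv."""
    D_y, D_x = Q_out.shape[0], Q_in_inv.shape[0]
    W_iso = Q_out.new_zeros((*batch_shape, D_y, D_x))  # cross-isotypic blocks stay zero
    for p in blocks.values():
        flat = p.coeff @ p.endo_basis_flat                               # (..., m_out*m_in, d_k**2)
        sub = flat.reshape(*batch_shape, p.m_out, p.m_in, p.d_k, p.d_k)  # assumes o-major order
        sub = sub.transpose(-3, -2)                                      # (..., m_out, d_k, m_in, d_k)
        W_iso[..., p.out_slice, p.in_slice] = sub.reshape(
            *batch_shape, p.m_out * p.d_k, p.m_in * p.d_k)
    # Back to the original bases: W = Q_Y W_iso Q_X^T.
    return Q_out @ W_iso @ Q_in_inv

# Usage: trivial irrep (d_k=1; m_out=1, m_in=2) plus the SO(2) frequency-1 irrep
# (d_k=2, End basis {I, J}; m_out=2, m_in=1), so D_y = 5 and D_x = 4.
f64 = torch.float64
torch.manual_seed(0)
J = torch.tensor([[0.0, -1.0], [1.0, 0.0]], dtype=f64)
blocks = {
    (0,): IsoBlock(torch.randn(2, 1, dtype=f64), torch.ones(1, 1, dtype=f64),
                   1, 2, 1, slice(0, 1), slice(0, 2)),
    (1,): IsoBlock(torch.randn(2, 2, dtype=f64),
                   torch.stack([torch.eye(2, dtype=f64), J]).reshape(2, 4),
                   2, 1, 2, slice(1, 5), slice(2, 4)),
}
W_proj = assemble_projection(torch.eye(5, dtype=f64), torch.eye(4, dtype=f64), blocks)
c, s = math.cos(0.3), math.sin(0.3)
R = torch.tensor([[c, -s], [s, c]], dtype=f64)
rho_y = torch.block_diag(torch.eye(1, dtype=f64), R, R)
rho_x = torch.block_diag(torch.eye(2, dtype=f64), R)
print(torch.allclose(rho_y @ W_proj, W_proj @ rho_x))  # True
```

With identity change-of-basis matrices the result is \(\mathbf{W}_{\mathrm{iso}}\) itself, which makes the block-diagonal, equivariant structure easy to inspect.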

Return type:

tuple[torch.Tensor, torch.Tensor, dict[tuple[int, …], IsoSpaceProjection]]

Shape:
  • W: \((..., D_y, D_x)\).

  • Q_out: \((D_y, D_y)\).

  • Q_in_inv: \((D_x, D_x)\).
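With these shapes, converting a batch of maps to isotypic coordinates and back is a pair of broadcast matrix products, using Q_out \(= \mathbf{Q}_{\mathcal{Y}}\) and Q_in_inv \(= \mathbf{Q}_{\mathcal{X}}^{\top}\) as listed under Returns. A sketch with random orthogonal stand-ins for the change-of-basis matrices (an assumption for illustration only; the real ones come from the isotypic decompositions of rep_y and rep_x):

```python
import torch

torch.manual_seed(0)
D_y, D_x = 5, 4
# Random orthogonal stand-ins for Q_Y and Q_X.
Q_out, _ = torch.linalg.qr(torch.randn(D_y, D_y, dtype=torch.float64))
Q_x, _ = torch.linalg.qr(torch.randn(D_x, D_x, dtype=torch.float64))
Q_in_inv = Q_x.T  # Q_X^{-1} = Q_X^T for an orthogonal matrix

W = torch.randn(8, 3, D_y, D_x, dtype=torch.float64)  # leading dims (8, 3) play the role of "..."

# (D_y, D_y) @ (..., D_y, D_x) @ (D_x, D_x): matmul broadcasts the leading dims.
W_iso = Q_out.T @ W @ Q_in_inv.T   # = Q_Y^T W Q_X
W_back = Q_out @ W_iso @ Q_in_inv  # = Q_Y W_iso Q_X^T

print(W_iso.shape)                # torch.Size([8, 3, 5, 4])
print(torch.allclose(W_back, W))  # True: the change of basis is invertible
# Orthogonal changes of basis preserve Frobenius norms (and inner products).
print(torch.allclose(torch.linalg.matrix_norm(W_iso), torch.linalg.matrix_norm(W)))  # True
```

Norm preservation is what lets the projection be computed blockwise in isotypic coordinates: an orthogonal projection there is also an orthogonal projection in the original coordinates.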