eEMAStats

class eEMAStats(x_rep, y_rep=None, momentum=0.1, eps=1e-06, center_with_running_mean=True)

Bases: EMAStats

Equivariant EMA statistics on symmetric vector spaces.

Let \(\mathbf{X}:\Omega\to\mathcal{X}\) and \(\mathbf{Y}:\Omega\to\mathcal{Y}\) be two \(\mathbb{G}\)-invariant random variables taking values in the symmetric vector spaces \((\mathcal{X}, \rho_{\mathcal{X}})\) and \((\mathcal{Y}, \rho_{\mathcal{Y}})\), i.e., \((\rho_{\mathcal{X}}(g)\mathbf{X}, \rho_{\mathcal{Y}}(g)\mathbf{Y})\) has the same joint distribution as \((\mathbf{X}, \mathbf{Y})\) for every \(g\in\mathbb{G}\). This module tracks exponential moving averages of their symmetry-constrained first and second moments.
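If the joint law of \((\mathbf{X}, \mathbf{Y})\) is invariant under the group action, i.e., \((\rho_{\mathcal{X}}(g)\mathbf{X}, \rho_{\mathcal{Y}}(g)\mathbf{Y})\) is distributed like \((\mathbf{X}, \mathbf{Y})\) for all \(g\in\mathbb{G}\), then the constraints on the moments follow from linearity of expectation. A short derivation, additionally assuming orthogonal representations so that \(\rho(g)^\top = \rho(g)^{-1}\):

\[\boldsymbol{\mu}_x = \mathbb{E}[\mathbf{X}] = \mathbb{E}[\rho_{\mathcal{X}}(g)\mathbf{X}] = \rho_{\mathcal{X}}(g)\boldsymbol{\mu}_x,\]

\[\mathbf{C}_{xy} = \mathbb{E}\big[(\mathbf{X}-\boldsymbol{\mu}_x)(\mathbf{Y}-\boldsymbol{\mu}_y)^\top\big] = \mathbb{E}\big[\rho_{\mathcal{X}}(g)(\mathbf{X}-\boldsymbol{\mu}_x)(\mathbf{Y}-\boldsymbol{\mu}_y)^\top\rho_{\mathcal{Y}}(g)^\top\big] = \rho_{\mathcal{X}}(g)\,\mathbf{C}_{xy}\,\rho_{\mathcal{Y}}(g)^{-1}.\]

The second equality in the covariance line substitutes \((\rho_{\mathcal{X}}(g)\mathbf{X}, \rho_{\mathcal{Y}}(g)\mathbf{Y})\) for \((\mathbf{X}, \mathbf{Y})\) and uses the invariance of the means, \(\rho_{\mathcal{X}}(g)\mathbf{X}-\boldsymbol{\mu}_x = \rho_{\mathcal{X}}(g)(\mathbf{X}-\boldsymbol{\mu}_x)\). Multiplying on the right by \(\rho_{\mathcal{Y}}(g)\) gives \(\rho_{\mathcal{X}}(g)\mathbf{C}_{xy} = \mathbf{C}_{xy}\rho_{\mathcal{Y}}(g)\); setting \(\mathbf{Y}=\mathbf{X}\) (respectively \(\mathbf{X}=\mathbf{Y}\)) recovers the conditions on \(\mathbf{C}_{xx}\) and \(\mathbf{C}_{yy}\).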

The running means are constrained to be \(\mathbb{G}\)-invariant,

\[\rho_{\mathcal{X}}(g)\boldsymbol{\mu}_x = \boldsymbol{\mu}_x, \qquad \rho_{\mathcal{Y}}(g)\boldsymbol{\mu}_y = \boldsymbol{\mu}_y, \qquad \forall g\in\mathbb{G},\]

and the running covariance operators commute with the corresponding group actions (the cross-covariance intertwines \(\rho_{\mathcal{Y}}\) with \(\rho_{\mathcal{X}}\)),

\[\rho_{\mathcal{X}}(g)\mathbf{C}_{xx} = \mathbf{C}_{xx}\rho_{\mathcal{X}}(g), \qquad \rho_{\mathcal{Y}}(g)\mathbf{C}_{yy} = \mathbf{C}_{yy}\rho_{\mathcal{Y}}(g), \qquad \rho_{\mathcal{X}}(g)\mathbf{C}_{xy} = \mathbf{C}_{xy}\rho_{\mathcal{Y}}(g), \qquad \forall g\in\mathbb{G}.\]

Equivalently,

\[\mathbf{C}_{xx} \in \mathrm{Hom}_{\mathbb{G}}(\rho_{\mathcal{X}}, \rho_{\mathcal{X}}), \qquad \mathbf{C}_{yy} \in \mathrm{Hom}_{\mathbb{G}}(\rho_{\mathcal{Y}}, \rho_{\mathcal{Y}}), \qquad \mathbf{C}_{xy} \in \mathrm{Hom}_{\mathbb{G}}(\rho_{\mathcal{Y}}, \rho_{\mathcal{X}}).\]
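One standard way to produce matrices that satisfy these constraints is group averaging (the Reynolds operator). The NumPy sketch below is illustrative only: it assumes a finite group given as an explicit list of orthogonal matrices, uses a hypothetical toy setup (the cyclic group of order 2 swapping the two coordinates of \(\mathcal{X}\) and flipping the sign of \(\mathcal{Y}\)), and the helper project_to_hom is a hypothetical name, not part of the library. It does not claim to be how eEMAStats enforces the constraints internally.

```python
import numpy as np

# Hypothetical toy setup: G = {e, s} (cyclic group of order 2).
# X = R^2 with s swapping coordinates; Y = R^1 with s flipping the sign.
rho_x = [np.eye(2), np.array([[0.0, 1.0], [1.0, 0.0]])]
rho_y = [np.eye(1), np.array([[-1.0]])]


def project_to_hom(C, rho_out, rho_in):
    """Group-average C onto Hom_G(rho_in, rho_out), assuming orthogonal representations.

    Returns (1/|G|) * sum_g rho_out(g) @ C @ rho_in(g).T, which satisfies
    rho_out(g) @ P == P @ rho_in(g) for every g in G.
    """
    return sum(Ro @ C @ Ri.T for Ro, Ri in zip(rho_out, rho_in)) / len(rho_out)


rng = np.random.default_rng(0)
C_xx = project_to_hom(rng.standard_normal((2, 2)), rho_x, rho_x)  # in Hom_G(rho_X, rho_X)
C_xy = project_to_hom(rng.standard_normal((2, 1)), rho_x, rho_y)  # in Hom_G(rho_Y, rho_X)

for Rx, Ry in zip(rho_x, rho_y):
    assert np.allclose(Rx @ C_xx, C_xx @ Rx)  # C_xx commutes with rho_X
    assert np.allclose(Rx @ C_xy, C_xy @ Ry)  # C_xy intertwines rho_Y -> rho_X
```

Because the averaging map is idempotent, applying it to an already-equivariant matrix leaves that matrix unchanged.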

For any tracked statistic \(\mathbf{S}\), the running update is

\[\mathbf{S}_{t} = (1-\alpha)\mathbf{S}_{t-1} + \alpha\,\mathbf{S}^{\text{batch}}_{t},\]

where \(\alpha\in(0,1]\) is the momentum and \(\mathbf{S}^{\text{batch}}_{t}\) is the statistic computed from the current batch.
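The update rule can be sketched in a few lines of NumPy (ema_update is a hypothetical helper name). With a constant batch statistic \(\mathbf{b}\), unrolling the recursion gives \(\mathbf{S}_k = \mathbf{b} + (1-\alpha)^k(\mathbf{S}_0 - \mathbf{b})\), which the sketch checks numerically.

```python
import numpy as np


def ema_update(S_prev, S_batch, momentum):
    """One EMA step: S_t = (1 - momentum) * S_{t-1} + momentum * S_batch."""
    if not 0.0 < momentum <= 1.0:
        raise ValueError("momentum must lie in (0, 1]")
    return (1.0 - momentum) * S_prev + momentum * S_batch


alpha = 0.1
S = np.zeros(3)
target = np.array([1.0, 2.0, 3.0])  # a batch statistic that stays constant across steps
for _ in range(20):
    S = ema_update(S, target, alpha)

# Closed form for a constant batch statistic b: S_k = b + (1 - alpha)^k (S_0 - b)
closed_form = target + (1 - alpha) ** 20 * (np.zeros(3) - target)
assert np.allclose(S, closed_form)
```

Setting the momentum to 1 discards the history entirely, so the running statistic equals the latest batch statistic.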

Parameters:
  • x_rep (Representation) – Representation \(\rho_{\mathcal{X}}\) defining the group action on x.

  • y_rep (Representation) – Representation \(\rho_{\mathcal{Y}}\) defining the group action on y. If None, x_rep is used.

  • momentum (float, optional) – Momentum \(\alpha\in(0,1]\) of the exponential moving average. Default: 0.1.

  • eps (float, optional) – Small constant for numerical stability. Default: 1e-6.

  • center_with_running_mean (bool, optional) – If True, batch covariances are centered with the running means instead of the batch means (except for the first batch, which is centered with its own means). Default: True.
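The effect of center_with_running_mean can be quantified. Assuming a \(1/N\) normalization (the library's actual normalization is not stated here), centering a batch with fixed means \(\mathbf{m}_x, \mathbf{m}_y\) instead of the batch means \(\bar{\mathbf{x}}, \bar{\mathbf{y}}\) adds exactly the outer product \((\bar{\mathbf{x}}-\mathbf{m}_x)(\bar{\mathbf{y}}-\mathbf{m}_y)^\top\), because the batch-centered residuals sum to zero. The helper centered_cross_cov and the running-mean values below are hypothetical.

```python
import numpy as np


def centered_cross_cov(x, y, mean_x, mean_y):
    """Cross-covariance of a batch, centered with the given means (1/N normalization assumed)."""
    xc = x - mean_x
    yc = y - mean_y
    return xc.T @ yc / x.shape[0]


rng = np.random.default_rng(1)
x = rng.standard_normal((64, 3)) + 2.0  # batch of shape (N, D_x)
y = rng.standard_normal((64, 2)) - 1.0  # batch of shape (N, D_y)
running_mx = np.full(3, 1.5)            # hypothetical running means from earlier batches
running_my = np.full(2, -0.5)

bx, by = x.mean(axis=0), y.mean(axis=0)
C_batch = centered_cross_cov(x, y, bx, by)                    # analogous to center_with_running_mean=False
C_running = centered_cross_cov(x, y, running_mx, running_my)  # analogous to center_with_running_mean=True

# The two estimates differ exactly by the outer product of the mean offsets.
assert np.allclose(C_running, C_batch + np.outer(bx - running_mx, by - running_my))
```

When the running means track the data well, the offset term is small and both centerings agree closely.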

Shape:
  • Input x: (N, D_x)

  • Input y: (N, D_y)

  • Output: Same as inputs (data is not transformed)

Example

>>> # rep_x, rep_y: Representation objects; x: (N, D_x) and y: (N, D_y) tensors
>>> stats = eEMAStats(x_rep=rep_x, y_rep=rep_y, momentum=0.1)
>>> x_out, y_out = stats(x, y)  # Same tensors, updated statistics
>>> mu_x = stats.mean_x
>>> mu_y = stats.mean_y
>>> C_xx = stats.cov_xx
>>> C_xy = stats.cov_xy
>>> print(mu_x.shape, mu_y.shape, C_xx.shape, C_xy.shape)

Note

Running covariance buffers are stored internally as coefficients in a basis of \(\mathrm{Hom}_{\mathbb{G}}\), i.e., in its degrees of freedom (DoF), rather than as dense matrices. As in EMAStats, the previously tracked means \(\boldsymbol{\mu}_{x,t-1}\), \(\boldsymbol{\mu}_{y,t-1}\) and covariance coefficients \(\boldsymbol{\theta}_{xx,t-1}\), \(\boldsymbol{\theta}_{yy,t-1}\), \(\boldsymbol{\theta}_{xy,t-1}\) are always replaced by their detached counterparts before they are reused for covariance centering or in the EMA update. In training mode, the DoF coefficients are updated directly using

\[\boldsymbol{\theta}_{t} = (1-\alpha)\,\operatorname{detach}(\boldsymbol{\theta}_{t-1}) + \alpha\,\boldsymbol{\theta}^{\text{batch}}_{t},\]

so gradients never propagate through older batches. The dense covariance matrices exposed by cov_xx, cov_yy, and cov_xy are expanded on demand from the current DoF coefficients. In eval mode, these expansions are additionally cached until the module changes mode, device, or dtype, or is reloaded from a checkpoint.
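Why storing DoF coefficients loses nothing can be seen on a toy case. In the hypothetical example below, the swap of two coordinates has \(\mathrm{Hom}_{\mathbb{G}}\) spanned by the identity and the swap matrix, so a \(2\times 2\) equivariant matrix has only 2 degrees of freedom. Since the EMA is linear, updating the coefficients and expanding on demand yields the same matrix as updating dense entries. The basis, to_dof, and to_dense are illustrative names; how the library builds its basis for general groups is not shown here.

```python
import numpy as np

# Hypothetical example: G = {e, s} acting on R^2 by swapping coordinates.
# Every matrix commuting with the swap has the form [[a, b], [b, a]],
# so Hom_G is 2-dimensional, with Frobenius-orthonormal basis {I/sqrt(2), S/sqrt(2)}.
S = np.array([[0.0, 1.0], [1.0, 0.0]])
basis = np.stack([np.eye(2), S]) / np.sqrt(2.0)  # shape (n_dof, D, D)


def to_dof(C):
    """Coefficients theta_k = <B_k, C>_F (orthogonal projection onto Hom_G)."""
    return np.einsum("kij,ij->k", basis, C)


def to_dense(theta):
    """Expand DoF coefficients into a dense (D, D) matrix: sum_k theta_k B_k."""
    return np.einsum("k,kij->ij", theta, basis)


alpha = 0.1
rng = np.random.default_rng(2)
theta = np.zeros(2)
dense = np.zeros((2, 2))
for _ in range(5):
    C_batch = to_dense(to_dof(rng.standard_normal((2, 2))))  # an equivariant batch statistic
    theta = (1 - alpha) * theta + alpha * to_dof(C_batch)    # EMA on 2 coefficients
    dense = (1 - alpha) * dense + alpha * C_batch            # EMA on 4 dense entries

# Linearity of the EMA: expanding the tracked coefficients reproduces the dense EMA.
assert np.allclose(to_dense(theta), dense)
```

Besides saving memory, tracking coefficients keeps the running matrix exactly in \(\mathrm{Hom}_{\mathbb{G}}\), since any expansion of the basis is equivariant by construction.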

property cov_xx: Tensor

Running covariance operator in \(\mathrm{Hom}_{\mathbb{G}}(\rho_{\mathcal{X}}, \rho_{\mathcal{X}})\).

Equivalently, the returned matrix \(\mathbf{C}_{xx}\) satisfies

\[\rho_{\mathcal{X}}(g)\mathbf{C}_{xx} = \mathbf{C}_{xx}\rho_{\mathcal{X}}(g), \qquad \forall g\in\mathbb{G}.\]

Shape:

\((D_x, D_x)\).

property cov_xy: Tensor

Running cross-covariance operator.

The returned matrix \(\mathbf{C}_{xy}\) maps coordinates in \(\mathcal{Y}\) to coordinates in \(\mathcal{X}\). It belongs to \(\mathrm{Hom}_{\mathbb{G}}(\rho_{\mathcal{Y}}, \rho_{\mathcal{X}})\) and satisfies the intertwining constraint

\[\rho_{\mathcal{X}}(g)\mathbf{C}_{xy} = \mathbf{C}_{xy}\rho_{\mathcal{Y}}(g), \qquad \forall g\in\mathbb{G}.\]

Shape:

\((D_x, D_y)\).
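To verify the intertwining constraint numerically, it is enough to check it on a generating set of \(\mathbb{G}\): if it holds for two elements it holds for their product (\(\rho_{\mathcal{X}}(ab)\mathbf{C} = \rho_{\mathcal{X}}(a)\mathbf{C}\rho_{\mathcal{Y}}(b) = \mathbf{C}\rho_{\mathcal{Y}}(ab)\)), and likewise for inverses. A NumPy sketch with a hypothetical helper is_intertwiner and a toy cyclic group of order 3:

```python
import numpy as np


def is_intertwiner(C, gens_out, gens_in, atol=1e-6):
    """Check rho_out(g) @ C == C @ rho_in(g) for each generator g.

    Holding on generators implies holding on the whole group, because the
    relation is preserved under products and inverses.
    """
    return all(np.allclose(Ro @ C, C @ Ri, atol=atol) for Ro, Ri in zip(gens_out, gens_in))


# Hypothetical example: cyclic group of order 3, generated by one rotation r,
# acting on R^3 by a cyclic shift of coordinates (X) and trivially on R^1 (Y).
R = np.roll(np.eye(3), 1, axis=0)  # permutation matrix of the cyclic shift
gens_x = [R]
gens_y = [np.eye(1)]

C_good = np.ones((3, 1))               # constant column: unchanged by the shift
C_bad = np.array([[1.0], [0.0], [0.0]])

print(is_intertwiner(C_good, gens_x, gens_y))  # True
print(is_intertwiner(C_bad, gens_x, gens_y))   # False
```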

property cov_yy: Tensor

Running covariance operator in \(\mathrm{Hom}_{\mathbb{G}}(\rho_{\mathcal{Y}}, \rho_{\mathcal{Y}})\).

Equivalently, the returned matrix \(\mathbf{C}_{yy}\) satisfies

\[\rho_{\mathcal{Y}}(g)\mathbf{C}_{yy} = \mathbf{C}_{yy}\rho_{\mathcal{Y}}(g), \qquad \forall g\in\mathbb{G}.\]

Shape:

\((D_y, D_y)\).

forward(x, y)

Update equivariant running statistics and return the inputs unchanged.

Parameters:
  • x (Tensor) – Samples in \(\mathcal{X}\) transforming according to \(\rho_{\mathcal{X}}\).

  • y (Tensor) – Samples in \(\mathcal{Y}\) transforming according to \(\rho_{\mathcal{Y}}\).

Return type:

tuple[Tensor, Tensor]

Returns:

Tuple (x, y). The activations are not transformed; only the running invariant means and equivariant covariance operators are updated.

Shape:
  • x: \((N, D_x)\).

  • y: \((N, D_y)\).

  • Output: ((N, D_x), (N, D_y)).
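To make the forward contract concrete (inputs returned unchanged, statistics updated as a side effect), here is a hypothetical NumPy stand-in, not the library class. Assumptions: the group is finite and given as lists of orthogonal matrices (one per element, in matching order); statistics are symmetrized by group averaging rather than stored as DoF coefficients; each batch is centered with its own symmetrized mean (center_with_running_mean is not modeled); the cross-covariance uses \(1/N\) normalization; and the first batch initializes the buffers.

```python
import numpy as np


class ToyEquivariantEMA:
    """Hypothetical NumPy sketch of the forward contract; not the library implementation.

    rho_x, rho_y: lists of orthogonal matrices, one per element of a finite group G
    (same order in both lists). Statistics are symmetrized by group averaging.
    """

    def __init__(self, rho_x, rho_y, momentum=0.1):
        self.rho_x, self.rho_y, self.momentum = rho_x, rho_y, momentum
        self.mean_x = self.mean_y = self.cov_xy = None

    def _avg_vec(self, v, rho):
        # Projection onto the invariant subspace: (1/|G|) sum_g rho(g) v
        return sum(R @ v for R in rho) / len(rho)

    def _avg_map(self, C):
        # Projection onto Hom_G(rho_y, rho_x): (1/|G|) sum_g rho_x(g) C rho_y(g)^T
        return sum(Rx @ C @ Ry.T for Rx, Ry in zip(self.rho_x, self.rho_y)) / len(self.rho_x)

    def forward(self, x, y):
        mx = self._avg_vec(x.mean(axis=0), self.rho_x)
        my = self._avg_vec(y.mean(axis=0), self.rho_y)
        # Batch cross-covariance centered with the symmetrized batch means, then symmetrized
        cxy = self._avg_map((x - mx).T @ (y - my) / x.shape[0])
        if self.mean_x is None:  # assumption: the first batch initializes the buffers
            self.mean_x, self.mean_y, self.cov_xy = mx, my, cxy
        else:
            a = self.momentum
            self.mean_x = (1 - a) * self.mean_x + a * mx
            self.mean_y = (1 - a) * self.mean_y + a * my
            self.cov_xy = (1 - a) * self.cov_xy + a * cxy
        return x, y  # inputs are returned unchanged


# Usage: swap on X = R^2, sign flip on Y = R^1
S = np.array([[0.0, 1.0], [1.0, 0.0]])
stats = ToyEquivariantEMA([np.eye(2), S], [np.eye(1), -np.eye(1)])
rng = np.random.default_rng(3)
for _ in range(3):
    x, y = rng.standard_normal((32, 2)), rng.standard_normal((32, 1))
    x_out, y_out = stats.forward(x, y)
```

In this toy case the sign representation has no nonzero invariant vectors, so the tracked mean_y stays at zero no matter what the data look like.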

invalidate_cache()

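Clear the cached dense covariance matrices expanded from the DoF coefficients, so that the next access to cov_xx, cov_yy, or cov_xy expands them again.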

Return type:

None
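The caching described in the Note follows a common lazy-evaluation pattern. The sketch below is hypothetical (CachedDenseCov, theta, and n_expansions are illustrative names, not library attributes) and shows why an explicit reset matters: if the coefficients change while a cached expansion is alive, the cached matrix is stale until the cache is cleared.

```python
import numpy as np


class CachedDenseCov:
    """Hypothetical sketch of lazy dense expansion with an explicit cache reset.

    theta holds DoF coefficients; basis has shape (n_dof, D, D).
    """

    def __init__(self, theta, basis):
        self.theta = theta
        self.basis = basis
        self._dense = None
        self.n_expansions = 0  # counts how often the dense matrix was rebuilt

    @property
    def cov(self):
        if self._dense is None:  # expand lazily, only on first access after a reset
            self._dense = np.einsum("k,kij->ij", self.theta, self.basis)
            self.n_expansions += 1
        return self._dense

    def invalidate_cache(self):
        """Drop the cached dense matrix so the next access re-expands from theta."""
        self._dense = None


basis = np.stack([np.eye(2), np.array([[0.0, 1.0], [1.0, 0.0]])])
c = CachedDenseCov(np.array([1.0, 0.5]), basis)
_ = c.cov, c.cov                 # the second access reuses the cached matrix
c.theta = np.array([2.0, 0.0])
stale = c.cov                    # still the old expansion: the cache was not cleared
c.invalidate_cache()
fresh = c.cov                    # rebuilt from the new coefficients
```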

property mean_x: Tensor

Running invariant mean, which lies in the \(\mathbb{G}\)-invariant subspace \(\mathcal{X}^{\text{inv}}\subseteq\mathcal{X}\).

This vector satisfies

\[\rho_{\mathcal{X}}(g)\boldsymbol{\mu}_x = \boldsymbol{\mu}_x, \qquad \forall g\in\mathbb{G}.\]

property mean_y: Tensor

Running invariant mean, which lies in the \(\mathbb{G}\)-invariant subspace \(\mathcal{Y}^{\text{inv}}\subseteq\mathcal{Y}\).

This vector satisfies

\[\rho_{\mathcal{Y}}(g)\boldsymbol{\mu}_y = \boldsymbol{\mu}_y, \qquad \forall g\in\mathbb{G}.\]
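For a finite group with orthogonal representation \(\rho\), the operator \(\mathbf{P} = \frac{1}{|\mathbb{G}|}\sum_{g}\rho(g)\) is the orthogonal projector onto the invariant subspace in which mean_x and mean_y live, and \(\operatorname{tr}(\mathbf{P})\) is that subspace's dimension. A NumPy sketch on a hypothetical example (cyclic shifts of three coordinates):

```python
import numpy as np

# Hypothetical example: cyclic group of order 3 acting on R^3 by shifting coordinates.
R = np.roll(np.eye(3), 1, axis=0)
rho = [np.linalg.matrix_power(R, k) for k in range(3)]  # all group elements

# Group-averaging operator; for orthogonal representations it is the orthogonal
# projector onto the invariant subspace {v : rho(g) v = v for all g}.
P = sum(rho) / len(rho)
dim_inv = int(round(np.trace(P)))  # trace of a projector = dimension of its image

mu = np.array([3.0, 0.0, 0.0])  # an arbitrary, non-invariant vector
mu_inv = P @ mu                 # -> [1., 1., 1.]

assert all(np.allclose(g @ mu_inv, mu_inv) for g in rho)
```

If a representation contains no trivial component, the invariant subspace is \(\{0\}\) and the corresponding invariant mean is identically zero.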