eEMAStats
- class eEMAStats(x_rep, y_rep=None, momentum=0.1, eps=1e-06, center_with_running_mean=True)
Bases: EMAStats
Equivariant EMA statistics on symmetric vector spaces.
Let \(\mathbf{X}:\Omega\to\mathcal{X}\) and \(\mathbf{Y}:\Omega\to\mathcal{Y}\) be two \(\mathbb{G}\)-invariant random variables taking values in the symmetric vector spaces \((\mathcal{X}, \rho_{\mathcal{X}})\) and \((\mathcal{Y}, \rho_{\mathcal{Y}})\). This module tracks exponential moving averages of their symmetry-constrained first and second moments.
The running means are constrained to be \(\mathbb{G}\)-invariant,
\[\rho_{\mathcal{X}}(g)\boldsymbol{\mu}_x = \boldsymbol{\mu}_x, \qquad \rho_{\mathcal{Y}}(g)\boldsymbol{\mu}_y = \boldsymbol{\mu}_y, \qquad \forall g\in\mathbb{G},\]
and the running covariance operators commute with the corresponding group actions,
\[\rho_{\mathcal{X}}(g)\mathbf{C}_{xx} = \mathbf{C}_{xx}\rho_{\mathcal{X}}(g), \qquad \rho_{\mathcal{Y}}(g)\mathbf{C}_{yy} = \mathbf{C}_{yy}\rho_{\mathcal{Y}}(g), \qquad \rho_{\mathcal{X}}(g)\mathbf{C}_{xy} = \mathbf{C}_{xy}\rho_{\mathcal{Y}}(g), \qquad \forall g\in\mathbb{G}.\]
Equivalently,
\[\mathbf{C}_{xx} \in \mathrm{Hom}_{\mathbb{G}}(\rho_{\mathcal{X}}, \rho_{\mathcal{X}}), \qquad \mathbf{C}_{yy} \in \mathrm{Hom}_{\mathbb{G}}(\rho_{\mathcal{Y}}, \rho_{\mathcal{Y}}), \qquad \mathbf{C}_{xy} \in \mathrm{Hom}_{\mathbb{G}}(\rho_{\mathcal{Y}}, \rho_{\mathcal{X}}).\]
For any tracked statistic \(\mathbf{S}\), the running update is
\[\mathbf{S}_{t} = (1-\alpha)\mathbf{S}_{t-1} + \alpha\,\mathbf{S}^{\text{batch}}_{t},\]
where \(\alpha\in(0,1]\) is the momentum and \(\mathbf{S}^{\text{batch}}_{t}\) is the statistic computed from the current batch.
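The running update above can be sketched in a few lines of NumPy. This is only an illustration of the formula, not the library's implementation: the helper name `ema_update` is hypothetical, and starting the running value at zeros is an arbitrary assumption made for the example.

```python
import numpy as np

def ema_update(prev, batch_stat, momentum=0.1):
    """Return (1 - momentum) * prev + momentum * batch_stat."""
    return (1.0 - momentum) * prev + momentum * batch_stat

# Assumed initial value (zeros) and two identical batch statistics.
running = np.zeros(2)
for batch_stat in (np.array([1.0, 2.0]), np.array([1.0, 2.0])):
    running = ema_update(running, batch_stat, momentum=0.5)

print(running)  # [0.75 1.5]
```

A larger momentum weights the current batch more heavily; with `momentum=1.0` the running value would simply equal the latest batch statistic.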
- Parameters:
x_rep (Representation) – Representation defining the group action on input x.
y_rep (Representation) – Representation defining the group action on input y. If None, uses x_rep.
momentum (float, optional) – Momentum factor for the exponential moving average. Default: 0.1.
eps (float, optional) – Small constant for numerical stability. Default: 1e-6.
center_with_running_mean (bool, optional) – If True, center covariance computation using running means instead of batch means (except for first batch). Default: True.
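To illustrate what the choice of centering changes, the sketch below compares a batch covariance centered at the batch mean with one centered at a running mean. The helper `batch_cov`, the division by N, and the running-mean value are assumptions made for this example; the module's actual normalization and use of `eps` are not shown here.

```python
import numpy as np

def batch_cov(x, center):
    """Second moment of the rows of x around a given center (divides by N)."""
    xc = x - center
    return xc.T @ xc / x.shape[0]

x = np.array([[1.0, 0.0], [3.0, 0.0]])
batch_mean = x.mean(axis=0)           # [2., 0.]
running_mean = np.array([1.0, 0.0])   # hypothetical running mean

cov_batch_centered = batch_cov(x, batch_mean)
cov_running_centered = batch_cov(x, running_mean)

# Centering at another point adds the outer product of the mean offset.
offset = batch_mean - running_mean
print(np.allclose(cov_running_centered,
                  cov_batch_centered + np.outer(offset, offset)))  # True
```

When the running mean lags behind the batch mean, the running-mean-centered estimate is larger by the outer product of the offset, so the two options coincide only once the running mean has converged.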
- Shape:
Input x: (N, D_x)
Input y: (N, D_y)
Output: Same as inputs (data is not transformed)
Example
>>> stats = eEMAStats(x_rep=rep_x, y_rep=rep_y, momentum=0.1)
>>> x_out, y_out = stats(x, y)  # Same tensors, updated statistics
>>> mu_x = stats.mean_x
>>> mu_y = stats.mean_y
>>> C_xx = stats.cov_xx
>>> C_xy = stats.cov_xy
>>> print(mu_x.shape, mu_y.shape, C_xx.shape, C_xy.shape)
Note
Running covariance buffers are stored internally as the degrees of freedom (DoF) of \(\mathrm{Hom}_{\mathbb{G}}\) rather than as dense matrices. As in EMAStats, previously tracked means \(\boldsymbol{\mu}_{x,t-1}\), \(\boldsymbol{\mu}_{y,t-1}\) and covariance coefficients \(\boldsymbol{\theta}_{xx,t-1}\), \(\boldsymbol{\theta}_{yy,t-1}\), \(\boldsymbol{\theta}_{xy,t-1}\) are always replaced by their detached counterparts before reuse in covariance centering or in the EMA update. In training mode, the DoF statistics are updated directly using
\[\boldsymbol{\theta}_{t} = (1-\alpha)\,\operatorname{detach}(\boldsymbol{\theta}_{t-1}) + \alpha\,\boldsymbol{\theta}^{\text{batch}}_{t},\]
so gradients never propagate through older batches. The dense covariance matrices exposed by cov_xx, cov_yy, and cov_xy are expanded from the current DoF coefficients on demand. In eval mode, they are expanded lazily and cached until the module changes mode, device, dtype, or reloads from a checkpoint.
- property cov_xx: Tensor
Running covariance operator in \(\mathrm{Hom}_{\mathbb{G}}(\rho_{\mathcal{X}}, \rho_{\mathcal{X}})\).
Equivalently, the returned matrix \(\mathbf{C}_{xx}\) satisfies
\[\rho_{\mathcal{X}}(g)\mathbf{C}_{xx} = \mathbf{C}_{xx}\rho_{\mathcal{X}}(g), \qquad \forall g\in\mathbb{G}.\]
- Shape:
\((D_x, D_x)\).
- property cov_xy: Tensor
Running cross-covariance operator.
The returned matrix \(\mathbf{C}_{xy}\) maps coordinates in \(\mathcal{Y}\) to coordinates in \(\mathcal{X}\). It belongs to \(\mathrm{Hom}_{\mathbb{G}}(\rho_{\mathcal{Y}}, \rho_{\mathcal{X}})\) and satisfies the intertwining constraint
\[\rho_{\mathcal{X}}(g)\mathbf{C}_{xy} = \mathbf{C}_{xy}\rho_{\mathcal{Y}}(g), \qquad \forall g\in\mathbb{G}.\]
- Shape:
\((D_x, D_y)\).
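The intertwining constraint can be checked numerically on a toy case. The sketch below uses a hypothetical two-element group acting by a coordinate swap on \(\mathcal{X}=\mathbb{R}^2\) and by a sign flip on \(\mathcal{Y}=\mathbb{R}^3\), and obtains an intertwiner by group averaging. Group averaging is a stand-in chosen for brevity; the module itself stores the degrees of freedom of \(\mathrm{Hom}_{\mathbb{G}}\), and the names `rho_x`, `rho_y`, and `project_to_intertwiner` are assumptions made for this example.

```python
import numpy as np

# Hypothetical toy group C2 = {e, s}: matrices of each element on X and on Y.
rho_x = [np.eye(2), np.array([[0.0, 1.0], [1.0, 0.0]])]  # s swaps coordinates
rho_y = [np.eye(3), np.diag([1.0, -1.0, 1.0])]           # s flips one sign

def project_to_intertwiner(c, rho_out, rho_in):
    """Group-average c so that rho_out(g) @ c == c @ rho_in(g) for all g."""
    terms = [go @ c @ np.linalg.inv(gi) for go, gi in zip(rho_out, rho_in)]
    return sum(terms) / len(terms)

rng = np.random.default_rng(0)
c_raw = rng.normal(size=(2, 3))                    # unconstrained (D_x, D_y) matrix
c_xy = project_to_intertwiner(c_raw, rho_x, rho_y)

is_equivariant = all(
    np.allclose(gx @ c_xy, c_xy @ gy) for gx, gy in zip(rho_x, rho_y)
)
print(c_xy.shape, is_equivariant)  # (2, 3) True
```

The same check with `rho_x` on both sides applies to \(\mathbf{C}_{xx}\), and with `rho_y` on both sides to \(\mathbf{C}_{yy}\).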
- property cov_yy: Tensor
Running covariance operator in \(\mathrm{Hom}_{\mathbb{G}}(\rho_{\mathcal{Y}}, \rho_{\mathcal{Y}})\).
Equivalently, the returned matrix \(\mathbf{C}_{yy}\) satisfies
\[\rho_{\mathcal{Y}}(g)\mathbf{C}_{yy} = \mathbf{C}_{yy}\rho_{\mathcal{Y}}(g), \qquad \forall g\in\mathbb{G}.\]
- Shape:
\((D_y, D_y)\).
- forward(x, y)
Update equivariant running statistics and return the inputs unchanged.
- Parameters:
- Return type:
- Returns:
Tuple (x, y). The activations are not transformed; only the running invariant means and equivariant covariance operators are updated.
- Shape:
x: \((N, D_x)\).
y: \((N, D_y)\).
Output: ((N, D_x), (N, D_y)).
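The pass-through contract of forward can be pictured with a small stand-in. `ToyEMAStats` below is a hypothetical NumPy class written for this illustration only: it tracks running means, ignores the symmetry constraints and covariances, and assumes zero-initialized buffers.

```python
import numpy as np

class ToyEMAStats:
    """Hypothetical stand-in: EMA of batch means only, no symmetry constraints."""

    def __init__(self, dim_x, dim_y, momentum=0.1):
        self.momentum = momentum
        self.mean_x = np.zeros(dim_x)  # assumed zero initialization
        self.mean_y = np.zeros(dim_y)

    def __call__(self, x, y):
        a = self.momentum
        # Update the running statistics as a side effect.
        self.mean_x = (1 - a) * self.mean_x + a * x.mean(axis=0)
        self.mean_y = (1 - a) * self.mean_y + a * y.mean(axis=0)
        # Return the inputs themselves, untransformed.
        return x, y

stats = ToyEMAStats(dim_x=3, dim_y=2, momentum=0.5)
x = np.ones((4, 3))        # (N, D_x)
y = 2.0 * np.ones((4, 2))  # (N, D_y)
x_out, y_out = stats(x, y)

print(x_out is x, y_out is y)      # True True
print(stats.mean_x, stats.mean_y)  # [0.5 0.5 0.5] [1. 1.]
```

Because the outputs are the inputs, such a module can be placed anywhere in a pipeline purely for monitoring statistics.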