EMAStats
- class EMAStats(dim_x, dim_y=None, momentum=0.1, eps=1e-06, center_with_running_mean=True)
Bases: Module

Exponential moving averages of first and second moments.
Let \(\mathbf{X}:\Omega\to\mathcal{X}\) and \(\mathbf{Y}:\Omega\to\mathcal{Y}\) be two random variables. This module tracks exponential moving averages of their batch means \(\boldsymbol{\mu}_x\), \(\boldsymbol{\mu}_y\), self-covariances \(\mathbf{C}_{xx}\), \(\mathbf{C}_{yy}\), and cross-covariance \(\mathbf{C}_{xy}\).
For any tracked statistic \(\mathbf{S}\), the running update is
\[\mathbf{S}_{t} = (1-\alpha)\mathbf{S}_{t-1} + \alpha\,\mathbf{S}^{\text{batch}}_{t},\]
where \(\alpha\in(0,1]\) is the momentum and \(\mathbf{S}^{\text{batch}}_{t}\) is the statistic computed from the current batch.
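The update can be read as a two-step recipe: compute the batch statistics, then blend them into the running state. A minimal sketch, assuming an unbiased (N-1) covariance estimator and zero-initialized running buffers (both are assumptions; the module's internals may differ):

import torch

def ema_update(prev, batch_stat, momentum=0.1):
    # S_t = (1 - alpha) * S_{t-1} + alpha * S_batch
    return (1.0 - momentum) * prev + momentum * batch_stat

# Illustrative batch statistics for one batch.
x = torch.randn(32, 10)
y = torch.randn(32, 5)
mu_x, mu_y = x.mean(dim=0), y.mean(dim=0)
xc, yc = x - mu_x, y - mu_y
C_xx = xc.T @ xc / (x.shape[0] - 1)   # (10, 10) self-covariance
C_xy = xc.T @ yc / (x.shape[0] - 1)   # (10, 5) cross-covariance

# One running update per tracked statistic.
running_mu_x = ema_update(torch.zeros(10), mu_x)
running_C_xy = ema_update(torch.zeros(10, 5), C_xy)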
- Parameters:
dim_x (int) – Number of features in input tensor x.
dim_y (int | None) – Number of features in input tensor y. If None, uses the same value as dim_x.
momentum (float) – Momentum factor for the exponential moving average. Must be in (0, 1]. Higher values give more weight to recent batches. Default: 0.1.
eps (float) – Small constant for numerical stability. Default: 1e-6.
center_with_running_mean (bool) – If True, center the covariance computation using running means instead of batch means (except for the first batch); see the sketch after this list. Default: True.
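To make the center_with_running_mean flag concrete, here is a hedged sketch of the two centering choices; batch_cov, running_mu_x, and running_mu_y are illustrative names, not the module's internals:

import torch

def batch_cov(x, y, center_x, center_y):
    # Center with the supplied means, then form the cross-covariance.
    xc, yc = x - center_x, y - center_y
    return xc.T @ yc / (x.shape[0] - 1)

x, y = torch.randn(32, 10), torch.randn(32, 5)
running_mu_x, running_mu_y = torch.zeros(10), torch.zeros(5)

# center_with_running_mean=True: center with the running means
# (used after the first batch).
C_xy_running = batch_cov(x, y, running_mu_x, running_mu_y)

# center_with_running_mean=False: center with the current batch means.
C_xy_batch = batch_cov(x, y, x.mean(dim=0), y.mean(dim=0))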
- Shape:
Input x: \((N, D_x)\) where N is the batch size and \(D_x\) is dim_x.
Input y: \((N, D_y)\) where \(D_y\) is dim_y.
Output: Same as inputs (data is not transformed).
Example
>>> stats = EMAStats(dim_x=10, dim_y=5, momentum=0.1)
>>> x = torch.randn(32, 10)
>>> y = torch.randn(32, 5)
>>> x_out, y_out = stats(x, y)
>>> mu_x = stats.mean_x
>>> C_xy = stats.cov_xy
>>> print(mu_x.shape, C_xy.shape)
Note
Whenever previously tracked statistics are reused, the historical means \(\boldsymbol{\mu}_{x,t-1}\), \(\boldsymbol{\mu}_{y,t-1}\) and covariance operators \(\mathbf{C}_{xx,t-1}\), \(\mathbf{C}_{yy,t-1}\), \(\mathbf{C}_{xy,t-1}\) are first detached from the autograd graph. In particular, the centering mean and the previous EMA state in the update
\[\mathbf{S}_{t} = (1-\alpha)\,\operatorname{detach}(\mathbf{S}_{t-1}) + \alpha\,\mathbf{S}^{\text{batch}}_{t}\]
are treated as constants with respect to the current optimization step. This truncates the computation graph across batches, preventing gradients from propagating through historical running state while keeping the current batch contribution differentiable.
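A minimal sketch of the detached update and its effect on gradients, using a plain tensor as a stand-in for the module's buffers:

import torch

def detached_ema_update(prev, batch_stat, momentum=0.1):
    # The previous state is detached, so autograd treats it as a constant;
    # gradients reach only the current batch statistic.
    return (1.0 - momentum) * prev.detach() + momentum * batch_stat

prev = torch.randn(10, requires_grad=True)
batch_stat = torch.randn(10, requires_grad=True)
detached_ema_update(prev, batch_stat).sum().backward()
print(prev.grad)        # None: the detached history receives no gradient
print(batch_stat.grad)  # tensor of 0.1s: only the batch term is differentiable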
- forward(x, y)
Update running statistics and return inputs unchanged.
- Parameters:
x (Tensor) – Input tensor of shape \((N, D_x)\).
y (Tensor) – Input tensor of shape \((N, D_y)\).
- Return type:
tuple[Tensor, Tensor]
- Returns:
Tuple (x, y) – inputs are returned unchanged.
Note
For batches after the first one, the EMA update uses \(\operatorname{detach}(\boldsymbol{\mu}_{x,t-1})\), \(\operatorname{detach}(\boldsymbol{\mu}_{y,t-1})\), \(\operatorname{detach}(\mathbf{C}_{xx,t-1})\), \(\operatorname{detach}(\mathbf{C}_{yy,t-1})\), and \(\operatorname{detach}(\mathbf{C}_{xy,t-1})\) before mixing them with the current batch statistics. This preserves differentiability with respect to the current batch while preventing the running buffers from carrying a growing cross-batch autograd graph.
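For illustration, a hypothetical training step in which a loss is built from the updated cross-covariance. The encoder, objective, and hyperparameters here are made up; only cov_xy from the example above is assumed to expose the updated EMA statistic:

import torch

stats = EMAStats(dim_x=10, dim_y=5, momentum=0.1)
encoder = torch.nn.Linear(20, 10)
opt = torch.optim.Adam(encoder.parameters(), lr=1e-3)

for _ in range(3):
    inp, y = torch.randn(32, 20), torch.randn(32, 5)
    x, y = stats(encoder(inp), y)   # updates the running statistics
    loss = -stats.cov_xy.norm()     # illustrative objective on the EMA statistic
    opt.zero_grad()
    loss.backward()                 # reaches the encoder via this batch only;
    opt.step()                      # the detached history carries no graph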