EMAStats
- class EMAStats(dim_x, dim_y=None, momentum=0.1, eps=1e-06, center_with_running_mean=True)
Bases: Module
Exponential Moving Average (EMA) statistics tracker for paired data.
This module tracks running statistics of two input tensors using exponential moving averages without transforming the data. It computes and maintains estimates of:
\(\mu_x\): Mean of input tensor x
\(\mu_y\): Mean of input tensor y
\(\Sigma_{xx}\): Covariance matrix of x
\(\Sigma_{yy}\): Covariance matrix of y
\(\Sigma_{xy}\): Cross-covariance matrix between x and y
Mathematical Formulation:
The exponential moving average update rule for any statistic \(S\) is:
\[S_{\text{running}} \leftarrow (1 - \alpha) \cdot S_{\text{running}} + \alpha \cdot S_{\text{batch}}\]
where \(\alpha\) is the momentum parameter and \(S_{\text{batch}}\) is the statistic computed from the current batch.
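A minimal pure-Python sketch of this update rule for a scalar statistic follows. The helper names ema_update and batch_weights are hypothetical and not part of this API. Unrolling the recursion from a running value of 0 shows that the batch statistic seen k updates ago carries weight \(\alpha (1 - \alpha)^k\), which is why a larger momentum favors recent batches.

```python
def ema_update(running: float, batch: float, momentum: float) -> float:
    """One EMA step: (1 - momentum) * running + momentum * batch."""
    if not 0.0 < momentum <= 1.0:
        raise ValueError("momentum must be in (0, 1]")
    return (1.0 - momentum) * running + momentum * batch


def batch_weights(num_batches: int, momentum: float) -> list:
    """Weight carried by each past batch statistic after num_batches updates,
    starting from a running value of 0. Index 0 is the most recent batch."""
    return [momentum * (1.0 - momentum) ** k for k in range(num_batches)]


# Feeding a constant batch statistic of 5.0 makes the running value approach 5.0.
running = 0.0
for _ in range(50):
    running = ema_update(running, 5.0, momentum=0.1)
print(running)  # close to 5.0; the remaining gap is 5.0 * 0.9 ** 50
```

With momentum=0.5, batch_weights(3, 0.5) gives [0.5, 0.25, 0.125]: each older batch counts half as much as the next newer one.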
Covariance Computation:
For tensors of shape \((N, D)\):
Mean: \(\mu = \mathbb{E}[x]\), computed over the batch dimension
Covariance: \(\Sigma = \mathbb{E}[(x - \mu)(x - \mu)^T]\)
Cross-covariance: \(\Sigma_{xy} = \mathbb{E}[(x - \mu_x)(y - \mu_y)^T]\)
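The batch statistics above can be sketched in NumPy for x of shape \((N, D_x)\) and y of shape \((N, D_y)\). This is an illustration, not the module's code: it assumes each expectation is a plain batch average (division by N), since the documentation does not say whether EMAStats divides by N or N - 1, and the helper name batch_moments is hypothetical.

```python
import numpy as np


def batch_moments(x: np.ndarray, y: np.ndarray):
    """Batch mean, covariance and cross-covariance for x: (N, D_x), y: (N, D_y).

    Expectations are averages over the batch dimension (division by N).
    """
    n = x.shape[0]
    mu_x = x.mean(axis=0)      # (D_x,)
    mu_y = y.mean(axis=0)      # (D_y,)
    xc = x - mu_x              # centered x, (N, D_x)
    yc = y - mu_y              # centered y, (N, D_y)
    cov_xx = xc.T @ xc / n     # (D_x, D_x)
    cov_yy = yc.T @ yc / n     # (D_y, D_y)
    cov_xy = xc.T @ yc / n     # (D_x, D_y): rows index x features, columns y features
    return mu_x, mu_y, cov_xx, cov_yy, cov_xy


rng = np.random.default_rng(0)
x = rng.standard_normal((32, 10))
y = rng.standard_normal((32, 5))
mu_x, mu_y, cov_xx, cov_yy, cov_xy = batch_moments(x, y)
print(cov_xy.shape)  # (10, 5)
```

Under the 1/N assumption these match numpy.cov with rowvar=False and bias=True; the cross-covariance is the off-diagonal block of the joint covariance of the concatenated features.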
- Parameters:
dim_x (int) – Number of features in input tensor x.
dim_y (int | None) – Number of features in input tensor y. If None, the same value as dim_x is used. Default: None.
momentum (float) – Momentum factor for the exponential moving average. Must be in (0, 1]. Higher values give more weight to recent batches. Default: 0.1.
eps (float) – Small constant for numerical stability. Default: 1e-6.
center_with_running_mean (bool) – If True, center the covariance computation using the running means instead of the batch means (batch means are still used for the first batch). Default: True.
- Shape:
Input x: \((N, D_x)\), where N is the batch size and \(D_x\) is dim_x.
Input y: \((N, D_y)\), where \(D_y\) is dim_y.
Output: the pair (x, y), returned unchanged (the data is not transformed).
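To show how the parameters and shapes described above fit together, here is a hypothetical NumPy class, EMAStatsSketch, that mirrors the documented behavior. It is not the real EMAStats, and several details are assumptions rather than documented facts: the first batch seeds the running statistics directly, centering with running means uses the means from before the current batch's update, covariances divide by N, and eps is left out because the documentation does not say where it is applied.

```python
import numpy as np


class EMAStatsSketch:
    """Hypothetical NumPy analogue of EMAStats; illustrative only, not the real class."""

    def __init__(self, dim_x, dim_y=None, momentum=0.1, center_with_running_mean=True):
        if not 0.0 < momentum <= 1.0:
            raise ValueError("momentum must be in (0, 1]")
        self.dim_x = dim_x
        self.dim_y = dim_x if dim_y is None else dim_y  # None -> same as dim_x
        self.momentum = momentum
        self.center_with_running_mean = center_with_running_mean
        self.initialized = False  # assumption: the first batch seeds the statistics
        # eps is omitted: the docs do not say where it is applied.

    def _ema(self, running, batch):
        # S_running <- (1 - alpha) * S_running + alpha * S_batch
        return (1.0 - self.momentum) * running + self.momentum * batch

    def __call__(self, x, y):
        if x.shape[1] != self.dim_x or y.shape[1] != self.dim_y or x.shape[0] != y.shape[0]:
            raise ValueError("expected x: (N, dim_x) and y: (N, dim_y)")
        n = x.shape[0]
        batch_mu_x, batch_mu_y = x.mean(axis=0), y.mean(axis=0)
        if self.initialized and self.center_with_running_mean:
            # assumption: center on the running means from before this update
            cx, cy = self.mean_x, self.mean_y
        else:
            # batch means: first batch, or running-mean centering disabled
            cx, cy = batch_mu_x, batch_mu_y
        xc, yc = x - cx, y - cy
        batch_cov_xx = xc.T @ xc / n
        batch_cov_yy = yc.T @ yc / n
        batch_cov_xy = xc.T @ yc / n
        if not self.initialized:
            self.mean_x, self.mean_y = batch_mu_x, batch_mu_y
            self.cov_xx, self.cov_yy, self.cov_xy = batch_cov_xx, batch_cov_yy, batch_cov_xy
            self.initialized = True
        else:
            self.mean_x = self._ema(self.mean_x, batch_mu_x)
            self.mean_y = self._ema(self.mean_y, batch_mu_y)
            self.cov_xx = self._ema(self.cov_xx, batch_cov_xx)
            self.cov_yy = self._ema(self.cov_yy, batch_cov_yy)
            self.cov_xy = self._ema(self.cov_xy, batch_cov_xy)
        return x, y  # data is passed through untransformed


rng = np.random.default_rng(0)
stats = EMAStatsSketch(dim_x=10, dim_y=5, momentum=0.1)
for _ in range(20):
    x = rng.standard_normal((32, 10))
    y = rng.standard_normal((32, 5))
    x_out, y_out = stats(x, y)  # same objects come back
print(stats.mean_x.shape, stats.cov_xy.shape)  # (10,) (10, 5)
```

In this sketch the first batch falls back to batch means simply because no running mean exists yet; after that, center_with_running_mean decides which mean each batch's covariance is centered on.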
Example
>>> import torch
>>> stats = EMAStats(dim_x=10, dim_y=5, momentum=0.1)
>>> x = torch.randn(32, 10)  # batch of 32 samples, 10 features
>>> y = torch.randn(32, 5)  # batch of 32 samples, 5 features
>>> x_out, y_out = stats(x, y)  # x_out == x, y_out == y (no transformation)
>>> print(stats.mean_x.shape)
torch.Size([10])
>>> print(stats.cov_xy.shape)
torch.Size([10, 5])