EMAStats

class EMAStats(dim_x, dim_y=None, momentum=0.1, eps=1e-06, center_with_running_mean=True)

Bases: Module

Exponential Moving Average (EMA) statistics tracker for paired data.

This module tracks running statistics of two input tensors using exponential moving averages without transforming the data. It computes and maintains estimates of:

  • \(\mu_x\): Mean of input tensor x

  • \(\mu_y\): Mean of input tensor y

  • \(\Sigma_{xx}\): Covariance matrix of x

  • \(\Sigma_{yy}\): Covariance matrix of y

  • \(\Sigma_{xy}\): Cross-covariance matrix between x and y

Mathematical Formulation:

The exponential moving average update rule for any statistic \(S\) is:

\[S_{\text{running}} = (1 - \alpha) \cdot S_{\text{running}} + \alpha \cdot S_{\text{batch}}\]

where \(\alpha\) is the momentum parameter and \(S_{\text{batch}}\) is the statistic computed from the current batch.
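As a minimal sketch of the update rule above, the following pure-Python snippet applies it to a scalar statistic. The helper name `ema_update` and the starting value of 0.0 are illustrative assumptions; how EMAStats initializes its buffers is not stated here.

```python
def ema_update(running, batch, momentum=0.1):
    """Blend the running statistic with the current batch statistic."""
    return (1.0 - momentum) * running + momentum * batch


# Assumed starting value, for illustration only.
running = 0.0

# Feed the same batch statistic three times; the running value
# moves a fraction `momentum` of the remaining distance each step.
for batch_stat in [1.0, 1.0, 1.0]:
    running = ema_update(running, batch_stat, momentum=0.1)
```

With momentum set to 1.0 the running value is simply replaced by the batch statistic, which is the upper end of the allowed range (0, 1].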

Covariance Computation:

For tensors of shape \((N, D)\):

  • Mean: \(\mu = \mathbb{E}[x]\) computed over batch dimension

  • Covariance: \(\Sigma = \mathbb{E}[(x - \mu)(x - \mu)^T]\)

  • Cross-covariance: \(\Sigma_{xy} = \mathbb{E}[(x - \mu_x)(y - \mu_y)^T]\)
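The batch statistics above can be sketched without any dependencies as follows. This is an illustration of the formulas only, not the module's implementation: the helper names `batch_mean` and `cross_cov` are hypothetical, and dividing by N (rather than N - 1) is an assumption taken from the expectation notation.

```python
def batch_mean(rows):
    """Column-wise mean of an (N, D) batch given as a list of rows."""
    n = len(rows)
    return [sum(col) / n for col in zip(*rows)]


def cross_cov(x_rows, y_rows, mu_x, mu_y):
    """Batch average of (x - mu_x)(y - mu_y)^T, with shape (D_x, D_y)."""
    n = len(x_rows)
    d_x, d_y = len(mu_x), len(mu_y)
    cov = [[0.0] * d_y for _ in range(d_x)]
    for x, y in zip(x_rows, y_rows):
        for i in range(d_x):
            for j in range(d_y):
                cov[i][j] += (x[i] - mu_x[i]) * (y[j] - mu_y[j]) / n
    return cov


x = [[1.0, 2.0], [3.0, 6.0]]  # N = 2, D_x = 2
y = [[0.0], [4.0]]            # N = 2, D_y = 1

mu_x, mu_y = batch_mean(x), batch_mean(y)
cov_xx = cross_cov(x, x, mu_x, mu_x)  # covariance is cross-covariance of x with itself
cov_xy = cross_cov(x, y, mu_x, mu_y)
```

Passing the same batch twice gives the covariance matrix, so a single helper covers all three second-order statistics.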

Parameters:
  • dim_x (int) – Number of features in input tensor x.

  • dim_y (int | None) – Number of features in input tensor y. If None, the same value as dim_x is used. Default: None.

  • momentum (float) – Momentum factor for exponential moving average. Must be in (0, 1]. Higher values give more weight to recent batches. Default: 0.1.

  • eps (float) – Small constant for numerical stability. Default: 1e-6.

  • center_with_running_mean (bool) – If True, center the covariance computation on the running means instead of the batch means (except for the first batch). Default: True.
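To illustrate what the choice of centering changes, here is a one-dimensional sketch that compares the two options on a single batch. It only demonstrates the arithmetic; the helper `centered_second_moment` and the running mean value of 3.0 are hypothetical and do not reflect the module's internals.

```python
def centered_second_moment(values, center):
    """Mean of squared deviations of `values` from `center`."""
    return sum((v - center) ** 2 for v in values) / len(values)


batch = [4.0, 6.0]
batch_mean = sum(batch) / len(batch)
running_mean = 3.0  # hypothetical value accumulated from earlier batches

# Centering on the batch mean measures spread within this batch only.
var_batch_centered = centered_second_moment(batch, batch_mean)

# Centering on the running mean also includes the squared offset
# between this batch's mean and the running mean.
var_running_centered = centered_second_moment(batch, running_mean)
offset_sq = (batch_mean - running_mean) ** 2
```

The two results differ by exactly the squared offset between the batch mean and the running mean, so centering on the running mean lets drift between batches contribute to the covariance estimate.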

Shape:
  • Input x: \((N, D_x)\) where N is batch size and \(D_x\) is dim_x.

  • Input y: \((N, D_y)\) where \(D_y\) is dim_y.

  • Output: Same as inputs (data is not transformed).

running_mean_x

Running mean of x. Shape: \((D_x,)\).

Type:

Tensor

running_mean_y

Running mean of y. Shape: \((D_y,)\).

Type:

Tensor

running_cov_xx

Running covariance of x. Shape: \((D_x, D_x)\).

Type:

Tensor

running_cov_yy

Running covariance of y. Shape: \((D_y, D_y)\).

Type:

Tensor

running_cov_xy

Running cross-covariance. Shape: \((D_x, D_y)\).

Type:

Tensor

num_batches_tracked

Number of batches processed.

Type:

Tensor

Example

>>> import torch
>>> stats = EMAStats(dim_x=10, dim_y=5, momentum=0.1)
>>> x = torch.randn(32, 10)  # Batch of 32 samples, 10 features
>>> y = torch.randn(32, 5)  # Batch of 32 samples, 5 features
>>> x_out, y_out = stats(x, y)  # x_out == x, y_out == y (no transformation)
>>> print(stats.mean_x.shape)  # torch.Size([10])
>>> print(stats.cov_xy.shape)  # torch.Size([10, 5])

property cov_xx: Tensor

Running covariance matrix of x.

property cov_xy: Tensor

Running cross-covariance matrix between x and y.

property cov_yy: Tensor

Running covariance matrix of y.

extra_repr()

String representation of module parameters.

Return type:

str

forward(x, y)

Update running statistics and return inputs unchanged.

Parameters:
  • x (Tensor) – Input tensor x of shape (N, dim_x).

  • y (Tensor) – Input tensor y of shape (N, dim_y).

Return type:

tuple[Tensor, Tensor]

Returns:

Tuple (x, y). The inputs are returned unchanged.

property mean_x: Tensor

Running mean of input x.

property mean_y: Tensor

Running mean of input y.