eEMAStats

class eEMAStats(x_rep, y_rep=None, momentum=0.1, eps=1e-06, center_with_running_mean=True)

Bases: EMAStats

Equivariant version of EMAStats using group-theoretic symmetry-aware statistics.

This module extends EMAStats to equivariant data, i.e. data whose group action is given by group representations. Instead of plain batch estimates, it tracks statistics that respect this symmetry structure, using the symmetry-aware mean and covariance computations from symm_learning.stats.
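To make "symmetry-aware" concrete, here is a minimal NumPy sketch for the simplest case of a finite group G given by explicit matrices rho(g). It computes the mean and covariance of the orbit-augmented dataset {rho(g) x_i}: the mean becomes the empirical mean projected onto the G-invariant subspace, and the covariance becomes a group average that commutes with every rho(g). The function name `symmetrized_mean_and_cov` and the list-of-matrices encoding of the group are hypothetical; symm_learning.stats may compute these quantities differently (for example via an isotypic decomposition), so treat this as an illustration of the idea, not the library's implementation.

```python
import numpy as np


def symmetrized_mean_and_cov(x, group_mats):
    """Hypothetical helper: statistics of the orbit-augmented data {rho(g) x_i}.

    x: (N, D) array of samples.
    group_mats: list of (D, D) matrices rho(g), one per element of a finite
    group G (an assumed, explicit encoding of the representation).
    """
    n_group = len(group_mats)
    # Projector onto the G-invariant subspace: P = (1/|G|) * sum_g rho(g)
    proj = sum(group_mats) / n_group
    # Invariant mean: the empirical mean projected onto the invariant subspace
    mean = proj @ x.mean(axis=0)
    # Population covariance of the data centered at the invariant mean
    xc = x - mean
    cov = xc.T @ xc / x.shape[0]
    # Group-average the covariance so that rho(g) @ cov == cov @ rho(g)
    cov = sum(g @ cov @ g.T for g in group_mats) / n_group
    return mean, cov


# Example: the group C2 acting on R^2 by swapping the two coordinates
swap = np.array([[0.0, 1.0], [1.0, 0.0]])
x = np.array([[1.0, 0.0], [3.0, 2.0]])
mean, cov = symmetrized_mean_and_cov(x, [np.eye(2), swap])
print(mean)  # [1.5 1.5]
print(cov)   # [[1.25 0.75] [0.75 1.25]]
```

Averaging over the group (rather than using the raw batch mean [2, 1]) guarantees that the estimated statistics are themselves invariant or equivariant, even when a single batch is not symmetric.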

Parameters:
  • x_rep (Representation) – Representation defining input x’s group action.

  • y_rep (Representation, optional) – Representation defining input y’s group action. If None, x_rep is used. Default: None.

  • momentum (float, optional) – Momentum factor for exponential moving average. Default: 0.1.

  • eps (float, optional) – Small constant for numerical stability. Default: 1e-6.

  • center_with_running_mean (bool, optional) – If True, center the covariance computation using the running means instead of the batch means (except for the first batch, which is centered with its batch means). Default: True.
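The interplay of momentum and center_with_running_mean can be sketched as a small running-statistics tracker. This is a hedged illustration under several assumptions not confirmed by the EMAStats docs: the PyTorch BatchNorm convention running = (1 - momentum) * running + momentum * batch, the first batch initializing the buffers, the running means being updated before centering, and only the x/y cross-covariance being tracked. The class `RunningMoments` is hypothetical, eps is omitted, and it uses plain batch estimates where eEMAStats would use symmetry-aware ones.

```python
import numpy as np


class RunningMoments:
    """Hypothetical EMA tracker for the means of x, y and their cross-covariance.

    Assumed (not confirmed): BatchNorm-style momentum, first batch initializes
    the buffers, running means are updated before the covariance is centered.
    """

    def __init__(self, momentum=0.1, center_with_running_mean=True):
        self.momentum = momentum
        self.center_with_running_mean = center_with_running_mean
        self.mean_x = self.mean_y = self.cov_xy = None

    def update(self, x, y):
        """x: (N, D_x), y: (N, D_y). Returns (x, y) unchanged."""
        m = self.momentum
        bx, by = x.mean(axis=0), y.mean(axis=0)
        first = self.mean_x is None
        if first:
            # No running statistics yet: initialize them from the batch
            self.mean_x, self.mean_y = bx, by
        else:
            self.mean_x = (1 - m) * self.mean_x + m * bx
            self.mean_y = (1 - m) * self.mean_y + m * by
        # Center with the running means, except on the first batch
        if self.center_with_running_mean and not first:
            cx, cy = self.mean_x, self.mean_y
        else:
            cx, cy = bx, by
        batch_cov = (x - cx).T @ (y - cy) / x.shape[0]
        if first:
            self.cov_xy = batch_cov
        else:
            self.cov_xy = (1 - m) * self.cov_xy + m * batch_cov
        # Like forward(), pass the data through untouched
        return x, y


stats = RunningMoments(momentum=0.5)
stats.update(np.array([[0.0], [2.0]]), np.array([[1.0], [3.0]]))
stats.update(np.array([[4.0], [6.0]]), np.array([[4.0], [6.0]]))
print(stats.mean_x, stats.mean_y, stats.cov_xy)  # [3.] [3.5] [[2.5]]
```

Centering with the running means makes each batch's covariance contribution consistent with the long-run mean estimate, at the cost of some bias while the running means are still converging.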

Shape:
  • Input x: (N, D_x)

  • Input y: (N, D_y)

  • Output: Same as inputs (data is not transformed)

Example

>>> stats = eEMAStats(x_rep=rep_x, y_rep=rep_y, momentum=0.1)
>>> x_out, y_out = stats(x, y)  # Same tensors, updated statistics
>>> standard_stats = stats.export()  # Export to standard EMAStats
export()

Export to a standard EMAStats layer.

Return type:

EMAStats

forward(x, y)

Update running statistics and return inputs unchanged.

Parameters:
  • x (Tensor) – Input tensor x with representation x_rep.

  • y (Tensor) – Input tensor y with representation y_rep.

Return type:

tuple[Tensor, Tensor]

Returns:

Tuple (x, y); the inputs are returned unchanged.