eEMAStats
- class eEMAStats(x_rep, y_rep=None, momentum=0.1, eps=1e-06, center_with_running_mean=True)
Bases: EMAStats
Equivariant version of EMAStats that uses group-theoretic, symmetry-aware statistics.
This module extends EMAStats to equivariant data by computing statistics that respect the symmetry structure defined by the group representations. It uses the symmetry-aware mean and covariance computations from symm_learning.stats.
- Parameters:
x_rep (Representation) – Representation defining the group action on input x.
y_rep (Representation, optional) – Representation defining the group action on input y. If None, x_rep is used.
momentum (float, optional) – Momentum factor for the exponential moving average. Default: 0.1.
eps (float, optional) – Small constant for numerical stability. Default: 1e-6.
center_with_running_mean (bool, optional) – If True, the covariance is centered with the running means instead of the batch means, except on the first batch, where the batch means are used. Default: True.
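The parameter list does not spell out the update rule. As a minimal sketch, assuming the PyTorch BatchNorm convention running = (1 - momentum) * running + momentum * batch_stat and assuming the running statistics are initialized from the first batch, the interaction between momentum and center_with_running_mean could look like the following. The function ema_update and its dict-based state are hypothetical, the symmetry-aware projection is omitted, and the real EMAStats may differ (for example in initialization, or in using a biased versus unbiased covariance).

```python
import numpy as np


def _blend(old, new, momentum):
    # Assumed BatchNorm-style EMA: (1 - momentum) * old + momentum * new.
    return (1.0 - momentum) * old + momentum * new


def ema_update(state, x, y, momentum=0.1, center_with_running_mean=True):
    """Hypothetical single EMA step for the means and the x-y cross-covariance.

    state: dict with keys 'mean_x', 'mean_y', 'cov_xy' (all None before the
    first batch). x: (N, D_x) array, y: (N, D_y) array. Returns a new dict.
    """
    batch_mean_x = x.mean(axis=0)
    batch_mean_y = y.mean(axis=0)
    first_batch = state["mean_x"] is None

    # Center with the running means if requested and available; otherwise
    # (always on the first batch) center with the batch means.
    if center_with_running_mean and not first_batch:
        cx, cy = state["mean_x"], state["mean_y"]
    else:
        cx, cy = batch_mean_x, batch_mean_y
    batch_cov_xy = (x - cx).T @ (y - cy) / x.shape[0]  # biased (1/N) estimate

    if first_batch:
        # Assumption: running statistics start from the first batch's values.
        return {"mean_x": batch_mean_x, "mean_y": batch_mean_y, "cov_xy": batch_cov_xy}

    return {
        "mean_x": _blend(state["mean_x"], batch_mean_x, momentum),
        "mean_y": _blend(state["mean_y"], batch_mean_y, momentum),
        "cov_xy": _blend(state["cov_xy"], batch_cov_xy, momentum),
    }


# Usage: stream a few random batches through the update.
rng = np.random.default_rng(0)
state = {"mean_x": None, "mean_y": None, "cov_xy": None}
for _ in range(5):
    state = ema_update(state, rng.normal(size=(32, 4)), rng.normal(size=(32, 2)))
```

Centering with the running means keeps every batch's covariance contribution tied to the same reference point; in this sketch the first batch has no running means yet, so it falls back to the batch means.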
- Shape:
Input x: (N, D_x)
Input y: (N, D_y)
Output: Same as inputs (data is not transformed)
Example
>>> stats = eEMAStats(x_rep=rep_x, y_rep=rep_y, momentum=0.1)
>>> x_out, y_out = stats(x, y)  # Same tensors, updated statistics
>>> standard_stats = stats.export()  # Export to a standard EMAStats
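One standard way to make a mean and covariance respect a finite group G is group averaging (the Reynolds operator): the mean is projected onto the G-invariant subspace, mean_sym = (1/|G|) Σ_g ρ_x(g) μ, and the cross-covariance is symmetrized as cov_sym = (1/|G|) Σ_g ρ_x(g) Σ_xy ρ_y(g)^T. The sketch below applies this with explicit matrices for the two-element group C2 (x in R^2 with its coordinates swapped, y in R^1 with its sign flipped). It only illustrates the idea: passing the group as a list of matrices is a hypothetical format rather than the Representation API, the function names are hypothetical, and symm_learning.stats may compute these statistics differently, for example through an isotypic decomposition.

```python
import numpy as np


def symmetric_mean(x, group_mats):
    """Mean of x (N, D) averaged over the group: (1/|G|) sum_g rho(g) @ mu."""
    mu = x.mean(axis=0)
    return sum(g @ mu for g in group_mats) / len(group_mats)


def symmetric_cov(x, y, group_mats_x, group_mats_y):
    """Cross-covariance (D_x, D_y) averaged over the group so that
    rho_x(g) @ C @ rho_y(g).T == C for every listed group element g."""
    xc = x - symmetric_mean(x, group_mats_x)
    yc = y - symmetric_mean(y, group_mats_y)
    cov = xc.T @ yc / x.shape[0]  # biased (1/N) estimate, chosen for simplicity
    # Group elements must be listed in the same order for x and y.
    averaged = sum(gx @ cov @ gy.T for gx, gy in zip(group_mats_x, group_mats_y))
    return averaged / len(group_mats_x)


# Hypothetical example: C2 = {e, s}.
swap = np.array([[0.0, 1.0], [1.0, 0.0]])  # s acts on x by swapping coordinates
rep_x = [np.eye(2), swap]
rep_y = [np.eye(1), -np.eye(1)]            # s acts on y by a sign flip

rng = np.random.default_rng(0)
x = rng.normal(size=(64, 2))
y = rng.normal(size=(64, 1))

mu_x = symmetric_mean(x, rep_x)  # both entries equal: invariant under the swap
mu_y = symmetric_mean(y, rep_y)  # zero: only 0 is invariant under a sign flip
cov_xy = symmetric_cov(x, y, rep_x, rep_y)
```

Because the averaged statistics are fixed points of every ρ(g), they stay consistent with the symmetry even when a single batch is not perfectly symmetric.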