from __future__ import annotations
import torch
from escnn.nn import EquivariantModule, FieldType, GeometricTensor
import symm_learning
import symm_learning.stats
from symm_learning.nn import eAffine
class eBatchNorm1d(EquivariantModule):
r"""Applies Batch Normalization over a 2D or 3D symmetric input :class:`escnn.nn.GeometricTensor`.
Method described in the paper
`Batch Normalization: Accelerating Deep Network Training by Reducing
Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__.
.. math::
y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
The mean and standard-deviation are calculated **using symmetry-aware estimates** (see
:func:`~symm_learning.stats.var_mean`) over the mini-batches and :math:`\gamma` and :math:`\beta` are
the scale and bias vectors of a :class:`eAffine`, which ensures that the affine transformation is
symmetry-preserving. By default, the elements of :math:`\gamma` are initialized to 1 and the elements
of :math:`\beta` are set to 0.
Also by default, during training this layer keeps running estimates of its
computed mean and variance, which are then used for normalization during
evaluation. The running estimates are kept with a default :attr:`momentum`
of 0.1.
If :attr:`track_running_stats` is set to ``False``, this layer then does not
keep running estimates, and batch statistics are instead used during
evaluation time as well.
.. note::
If the input tensor is of shape :math:`(N, C, L)`, this module
computes a unique mean and variance for each feature or channel :math:`C` and applies it to
all the elements in the sequence length :math:`L`.
Args:
in_type: the :class:`escnn.nn.FieldType` of the input geometric tensor.
The output type is the same as the input type.
eps: a value added to the denominator for numerical stability.
Default: 1e-5
momentum: the value used for the running_mean and running_var
computation. Can be set to ``None`` for cumulative moving average
(i.e. simple average). Default: 0.1
affine: a boolean value that when set to ``True``, this module has
learnable affine parameters. Default: ``True``
track_running_stats: a boolean value that when set to ``True``, this
module tracks the running mean and variance, and when set to ``False``,
this module does not track such statistics and does not register the
:attr:`running_mean` and :attr:`running_var` buffers; batch statistics are
then always used, in both training and eval modes. Default: ``True``
Shape:
- Input: :math:`(N, C)` or :math:`(N, C, L)`, where :math:`N` is the batch size,
:math:`C` is the number of features or channels, and :math:`L` is the sequence length
- Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)
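Examples:
A minimal usage sketch (the cyclic-group field type below is purely illustrative):
>>> import torch
>>> from escnn import gspaces, nn as escnn_nn
>>> from escnn.group import CyclicGroup
>>> G = CyclicGroup(4)
>>> in_type = escnn_nn.FieldType(gspaces.no_base_space(G), [G.regular_representation] * 2)
>>> bn = eBatchNorm1d(in_type)
>>> x = in_type(torch.randn(32, in_type.size))  # wrap a (N, C) tensor in a GeometricTensor
>>> y = bn(x)                                   # output has the same field type and shape
>>> y.shape
torch.Size([32, 8])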
"""
def __init__(
self,
in_type: FieldType,
eps: float = 1e-5,
momentum: float | None = 0.1,
affine: bool = True,
track_running_stats: bool = True,
):
super().__init__()
self.in_type = in_type
self.out_type = in_type
self.eps = eps
self.momentum = momentum
self.affine = affine
self.track_running_stats = track_running_stats
self._rep_x = in_type.representation
if self.track_running_stats:
self.register_buffer("running_mean", torch.zeros(in_type.size))
self.register_buffer("running_var", torch.ones(in_type.size))
self.register_buffer("num_batches_tracked", torch.tensor(0, dtype=torch.long))
if self.affine:
self.affine_transform = eAffine(
in_type=in_type,
bias=True,
)
def forward(self, x: GeometricTensor): # noqa: D102
assert x.type == self.in_type, "Input type does not match the expected input type."
var_batch, mean_batch = symm_learning.stats.var_mean(x.tensor, rep_x=self._rep_x)
if self.track_running_stats and self.training:
    with torch.no_grad():
        self.num_batches_tracked += 1
        # Support momentum=None (cumulative moving average), as documented.
        m = 1.0 / float(self.num_batches_tracked) if self.momentum is None else self.momentum
        self.running_mean = (1 - m) * self.running_mean + m * mean_batch
        self.running_var = (1 - m) * self.running_var + m * var_batch
if self.track_running_stats and not self.training:
    # Evaluation with tracked statistics: normalize with the running estimates.
    mean, var = self.running_mean, self.running_var
else:
    # Training (or no tracking): normalize with the current batch statistics.
    mean, var = mean_batch, var_batch
mean, var = (mean[..., None], var[..., None]) if x.tensor.ndim == 3 else (mean, var)
y = (x.tensor - mean) / torch.sqrt(var + self.eps)
y = self.affine_transform(self.out_type(y)) if self.affine else self.out_type(y)
return y
def evaluate_output_shape(self, input_shape): # noqa: D102
return input_shape
def check_equivariance(self, atol=1e-5, rtol=1e-5):
"""Check the equivariance of the convolution layer."""
was_training = self.training
time = 1
batch_size = 50
self.train()
# Compute some empirical statistics
for _ in range(5):
x = torch.randn(batch_size, self.in_type.size, time)
x = self.in_type(x)
_ = self(x)
self.eval()
x_batch = torch.randn(batch_size, self.in_type.size, time)
x_batch = self.in_type(x_batch)
for _ in range(10):
    g = self.in_type.representation.group.sample()
    if g == self.in_type.representation.group.identity:
        # Skip the identity element (decrementing the loop index of a for-loop has no effect).
        continue
gx_batch = x_batch.transform(g)
var, mean = symm_learning.stats.var_mean(x_batch.tensor, rep_x=self.in_type.representation)
g_var, g_mean = symm_learning.stats.var_mean(gx_batch.tensor, rep_x=self.in_type.representation)
assert torch.allclose(mean, g_mean, atol=1e-4, rtol=1e-4), f"Mean {mean} != {g_mean}"
assert torch.allclose(var, g_var, atol=1e-4, rtol=1e-4), f"Var {var} != {g_var}"
y = self(x_batch)
g_y = self(gx_batch)
g_y_gt = y.transform(g)
assert torch.allclose(g_y.tensor, g_y_gt.tensor, atol=atol, rtol=rtol), (
f"Output {g_y.tensor} does not match the expected output {g_y_gt.tensor} for group element {g}"
)
self.train(was_training)
return None
def export(self) -> torch.nn.BatchNorm1d:
"""Export the layer to a standard PyTorch BatchNorm1d layer."""
bn = torch.nn.BatchNorm1d(
num_features=self.in_type.size,
eps=self.eps,
momentum=self.momentum,
affine=self.affine,
track_running_stats=self.track_running_stats,
)
if self.affine:
bn.weight.data = self.affine_transform.scale.clone()
bn.bias.data = self.affine_transform.bias.clone()
if self.track_running_stats:
bn.running_mean.data = self.running_mean.clone()
bn.running_var.data = self.running_var.clone()
bn.num_batches_tracked.data = self.num_batches_tracked.clone()
else:
bn.running_mean = None
bn.running_var = None
bn.eval()
return bn
class DataNorm(torch.nn.Module):
r"""Applies data normalization to a 2D or 3D tensor.
This module standardizes input data by centering (subtracting the mean) and optionally
scaling (dividing by the standard deviation). The module supports multiple modes of
operation controlled by its configuration parameters.
**Mathematical Formulation:**
The normalization is applied element-wise as:
.. math::
y = \begin{cases}
x - \mu & \text{if } \texttt{only centering} = \text{True} \\
\frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} & \text{otherwise}
\end{cases}
where :math:`\mu` is the mean, :math:`\sigma^2` is the variance, and :math:`\epsilon`
is a small constant for numerical stability.
**Mode of Operation:**
This layer features a non-standard behavior during training. Unlike typical
normalization layers (e.g., ``torch.nn.BatchNorm1d``) that normalize using
batch statistics, this layer normalizes the data using the running statistics
that have been updated with the current batch's information.
- **During training**:
1. Batch statistics (:math:`\mu_{\text{batch}}`, :math:`\sigma^2_{\text{batch}}`) are computed from the input.
2. Running statistics (:math:`\mu_{\text{run}}`, :math:`\sigma^2_{\text{run}}`) are updated using exponential
moving average:
:math:`\text{running stat} = (1-\alpha) \cdot \text{running stat} + \alpha \cdot \text{batch stat}`
3. The input data is then normalized using these newly updated running statistics.
This allows the loss to be dependent on the running statistics, with gradients flowing
back through the batch statistics component of the update, but not into the historical
state of the running statistics from previous steps.
- **During evaluation**: Uses the final stored running statistics for normalization.
**Special case**: When ``momentum=1.0``, the layer effectively uses batch statistics for
normalization, becoming equivalent to a ``torch.nn.BatchNorm1d`` layer with ``track_running_stats=False``.
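For example, this small illustrative check demonstrates the equivalence (``eps`` is set explicitly
because the two layers use different defaults; both modules are in their default training mode):
>>> import torch
>>> x = torch.randn(32, 8)
>>> dn = DataNorm(num_features=8, eps=1e-5, momentum=1.0)
>>> bn = torch.nn.BatchNorm1d(8, eps=1e-5, affine=False, track_running_stats=False)
>>> torch.allclose(dn(x), bn(x), atol=1e-5)
True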
Args:
num_features (int): Number of features or channels in the input tensor.
eps (float, optional): Small constant added to the denominator for numerical
stability. Only used when ``only_centering=False``. Default: ``1e-6``.
only_centering (bool, optional): If ``True``, only centers the data (subtracts mean)
without scaling by standard deviation. Default: ``False``.
compute_cov (bool, optional): If ``True``, computes and tracks the full covariance
matrix in addition to mean and variance. Accessible via the ``cov`` property.
Default: ``False``.
momentum (float, optional): Momentum factor for exponential moving average of
running statistics. Must be greater than 0. Setting to ``1.0`` effectively
uses only batch statistics. Default: ``1.0``.
Shape:
- Input: :math:`(N, C)` or :math:`(N, C, L)` where:
- :math:`N` is the batch size
- :math:`C` is the number of features (must equal ``num_features``)
- :math:`L` is the sequence length (optional, for 3D inputs)
- Output: Same shape as input
Attributes:
running_mean (torch.Tensor): Running average of input means. Shape: ``(num_features,)``.
running_var (torch.Tensor): Running average of input variances. Shape: ``(num_features,)``.
running_cov (torch.Tensor): Running average of input covariance matrix.
Shape: ``(num_features, num_features)``. Only available when ``compute_cov=True``.
num_batches_tracked (torch.Tensor): Number of batches processed during training.
Note:
When using 3D inputs :math:`(N, C, L)`, statistics are computed over both the batch
dimension :math:`N` and sequence dimension :math:`L`, treating each feature channel
independently.
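Examples:
A minimal usage sketch (the feature count, batch size, and momentum below are illustrative):
>>> import torch
>>> norm = DataNorm(num_features=8, momentum=0.1)
>>> for _ in range(10):
...     _ = norm(torch.randn(64, 8))  # training mode: updates the running statistics
>>> _ = norm.eval()                   # evaluation mode: uses the stored running statistics
>>> y = norm(torch.randn(64, 8, 5))   # 3D inputs (N, C, L) are also supported
>>> y.shape
torch.Size([64, 8, 5])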
"""
def __init__(
self,
num_features: int,
eps: float = 1e-6,
only_centering: bool = False,
compute_cov: bool = False,
momentum: float = 1.0,
):
super().__init__()
self.num_features = num_features
self.eps = eps
self.only_centering = only_centering
self.compute_cov = compute_cov
if momentum <= 0.0:
raise ValueError(f"momentum must be greater than 0, got {momentum}")
self.momentum = momentum
# Initialize running statistics buffers
self.register_buffer("num_batches_tracked", torch.tensor(0, dtype=torch.long))
self.register_buffer("running_mean", torch.zeros(num_features))
self.register_buffer("running_var", torch.ones(num_features))
if compute_cov:
self.register_buffer("running_cov", torch.eye(num_features))
self._last_cov = None
def _compute_batch_stats(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""Compute batch statistics (mean and var). Can be overridden for equivariant versions."""
dims = [0] + ([2] if x.ndim > 2 else [])
batch_mean = x.mean(dim=dims)
batch_var = torch.ones_like(batch_mean) if self.only_centering else x.var(dim=dims, unbiased=False)
return batch_mean, batch_var
def _compute_batch_cov(self, x: torch.Tensor) -> torch.Tensor:
"""Compute batch covariance. Can be overridden for equivariant versions."""
x_flat = x.permute(0, 2, 1).reshape(-1, self.num_features) if x.ndim == 3 else x
x_centered = x_flat - x_flat.mean(dim=0, keepdim=True)
batch_cov = torch.mm(x_centered.T, x_centered) / (x_centered.shape[0] - 1)
return batch_cov
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Apply the normalization to the input tensor."""
assert x.shape[1] == self.num_features
# --- Determine statistics for normalization ---
if self.training:
# Training mode: Use running stats for normalization, but in a way
# that gradients flow back through the current batch's statistics.
batch_mean, batch_var = self._compute_batch_stats(x)
# Update the running statistics with the current batch's statistics.
if self.num_batches_tracked == 0:
# For the very first batch, initialize running stats with batch stats
self.running_mean = batch_mean
self.running_var = batch_var
else:
# Detach prev running stats to prevent gradients from flowing into the previous iteration's state.
self.running_mean = (1 - self.momentum) * self.running_mean.detach() + self.momentum * batch_mean
self.running_var = (1 - self.momentum) * self.running_var.detach() + self.momentum * batch_var
self.num_batches_tracked += 1
# --- Covariance computation (if enabled) ---
if self.compute_cov and self.training:
    batch_cov = self._compute_batch_cov(x)
    if self.num_batches_tracked == 1:
        # First batch (the counter was already incremented above): initialize with the batch covariance.
        self.running_cov = batch_cov
    else:
        self.running_cov = (1 - self.momentum) * self.running_cov.detach() + self.momentum * batch_cov
# --- Apply Normalization ---
mean = self.running_mean
std = torch.sqrt(self.running_var + self.eps) if not self.only_centering else torch.ones_like(self.running_mean)
# Reshape for broadcasting
if x.ndim == 3:
mean = mean.view(1, self.num_features, 1)
std = std.view(1, self.num_features, 1)
else:
mean = mean.view(1, self.num_features)
std = std.view(1, self.num_features)
return (x - mean) / std
@property
def mean(self) -> torch.Tensor:
"""Return the current mean estimate."""
return self.running_mean
@property
def var(self) -> torch.Tensor:
"""Return the current variance estimate."""
return self.running_var
@property
def std(self) -> torch.Tensor:
"""Return the current std estimate (computed from variance)."""
return torch.sqrt(self.running_var)
@property
def cov(self) -> torch.Tensor:
"""Return the current covariance matrix estimate."""
if not self.compute_cov:
raise RuntimeError("Covariance computation is disabled. Set compute_cov=True to enable.")
return self.running_cov
class eDataNorm(DataNorm, EquivariantModule):
r"""Equivariant version of DataNorm using group-theoretic symmetry-aware statistics.
This module extends :class:`DataNorm` to work with equivariant data by computing
statistics that respect the symmetry structure defined by a group representation.
It maintains the same API and modes of operation as ``DataNorm`` while using
symmetry-aware mean, variance, and covariance computations from
:mod:`symm_learning.stats`.
**Mathematical Formulation:**
The equivariant normalization follows the same mathematical form as ``DataNorm``:
.. math::
y = \begin{cases}
x - \mu_{\text{equiv}} & \text{if } \texttt{only\_centering} = \text{True} \\
\frac{x - \mu_{\text{equiv}}}{\sqrt{\sigma^2_{\text{equiv}} + \epsilon}} & \text{otherwise}
\end{cases}
However, the statistics :math:`\mu_{\text{equiv}}` and :math:`\sigma^2_{\text{equiv}}`
are computed using symmetry-aware estimators:
- **Mean**: Projected onto the :math:`G`-invariant subspace
- **Variance**: Constrained to be constant within each irreducible subspace
- **Covariance**: Respects the block-diagonal structure imposed by the representation
**Symmetry Properties:**
The computed statistics satisfy equivariance and invariance properties:
- :math:`\mathbb{E}[g \cdot x] = g \cdot \mathbb{E}[x]` (mean equivariance)
- :math:`\text{Var}[g \cdot x] = \text{Var}[x]` (variance invariance)
- :math:`\text{Cov}[g \cdot x, g \cdot y] = g \cdot \text{Cov}[x, y] \cdot g^T` (covariance equivariance)
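These properties can be checked empirically with :func:`symm_learning.stats.var_mean`, the same
estimator used internally by this layer (a short sketch; the group and field type are illustrative):
>>> import torch
>>> from escnn import gspaces, nn as escnn_nn
>>> from escnn.group import CyclicGroup
>>> from symm_learning.stats import var_mean
>>> G = CyclicGroup(4)
>>> in_type = escnn_nn.FieldType(gspaces.no_base_space(G), [G.regular_representation] * 2)
>>> x = in_type(torch.randn(256, in_type.size))
>>> g = G.sample()
>>> var, mean = var_mean(x.tensor, rep_x=in_type.representation)
>>> g_var, g_mean = var_mean(x.transform(g).tensor, rep_x=in_type.representation)
>>> torch.allclose(var, g_var, atol=1e-4) and torch.allclose(mean, g_mean, atol=1e-4)
True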
**Input/Output Types:**
Unlike ``DataNorm`` which operates on raw tensors, ``eDataNorm`` processes
:class:`escnn.nn.GeometricTensor` objects that encode the group representation
information along with the tensor data.
Args:
in_type (escnn.nn.FieldType): The field type defining the input's group
representation structure. The output type will be the same as the input type.
eps (float, optional): Small constant added to the denominator for numerical
stability. Only used when ``only_centering=False``. Default: ``1e-6``.
only_centering (bool, optional): If ``True``, only centers the data using
equivariant mean without scaling. Default: ``False``.
compute_cov (bool, optional): If ``True``, computes and tracks the equivariant
covariance matrix. Default: ``False``.
momentum (float, optional): Momentum factor for exponential moving average of
running statistics. Must be greater than 0. Setting to ``1.0`` effectively
uses only batch statistics. Default: ``1.0``.
Shape:
- Input: :class:`escnn.nn.GeometricTensor` with tensor shape :math:`(N, D)` or :math:`(N, D, L)` where:
- :math:`N` is the batch size
- :math:`D` is ``in_type.size`` (total representation dimension)
- :math:`L` is the sequence length (optional, for 3D inputs)
- Output: :class:`escnn.nn.GeometricTensor` with the same type and shape as input
Methods:
export() -> DataNorm: Exports the current state to a standard ``DataNorm`` layer
that can operate on raw tensors, transferring all learned statistics.
Examples:
>>> from escnn import gspaces, nn as escnn_nn
>>> from escnn.group import CyclicGroup
>>>
>>> # Define group and representation
>>> G = CyclicGroup(4)
>>> gspace = gspaces.no_base_space(G)
>>> in_type = escnn_nn.FieldType(gspace, [G.regular_representation] * 2)
>>>
>>> # Create equivariant normalization layer
>>> norm = eDataNorm(in_type=in_type, compute_cov=True)
>>>
>>> # Process equivariant data
>>> x_tensor = torch.randn(16, in_type.size) # Raw tensor data
>>> x_geom = in_type(x_tensor) # Wrap in GeometricTensor
>>> y_geom = norm(x_geom) # Normalized GeometricTensor
>>>
>>> # Export to standard DataNorm
>>> standard_norm = norm.export()
>>> y_tensor = standard_norm(x_tensor) # Same result on raw tensor
Note:
This layer inherits all modes of operation from :class:`DataNorm` (running statistics,
fixed statistics, centering-only, covariance computation) while computing all
statistics using group-theoretic constraints. The statistics respect the irreducible
decomposition of the input representation, ensuring that symmetries are preserved
throughout the normalization process.
See Also:
:class:`DataNorm`: The base normalization layer for standard (non-equivariant) data.
:func:`symm_learning.stats.var_mean`: Equivariant mean and variance computation.
:func:`symm_learning.stats.cov`: Equivariant covariance computation.
"""
def __init__(
self,
in_type: FieldType,
eps: float = 1e-6,
only_centering: bool = False,
compute_cov: bool = False,
momentum: float = 1.0,
):
# Initialize DataNorm with the field type size
super().__init__(
num_features=in_type.size,
eps=eps,
only_centering=only_centering,
compute_cov=compute_cov,
momentum=momentum,
)
# Store the EquivariantModule-specific attributes.
self.in_type = in_type
self.out_type = in_type
self._rep_x = in_type.representation
def _compute_batch_stats(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""Compute equivariant batch statistics using symm_learning.stats."""
batch_var, batch_mean = symm_learning.stats.var_mean(x, rep_x=self._rep_x)
if self.only_centering:
batch_var = torch.ones_like(batch_mean)
return batch_mean, batch_var
def _compute_batch_cov(self, x: torch.Tensor) -> torch.Tensor:
"""Compute equivariant batch covariance using symm_learning.stats."""
return symm_learning.stats.cov(x=x, y=x, rep_x=self._rep_x, rep_y=self._rep_x)
def forward(self, x: GeometricTensor) -> GeometricTensor:
"""Apply equivariant normalization to the input GeometricTensor."""
assert x.type == self.in_type, f"Input type {x.type} does not match expected type {self.in_type}."
# Apply DataNorm forward to the tensor data
normalized_tensor = super().forward(x.tensor)
# Return as GeometricTensor
return self.out_type(normalized_tensor)
def evaluate_output_shape(self, input_shape):
"""Return the same shape as input for EquivariantModule compatibility."""
return input_shape
def check_equivariance(self, atol=1e-5, rtol=1e-5):
"""Check the equivariance of the normalization layer."""
was_training = self.training
batch_size = 50
self.train()
# Process a few batches to get some running statistics
for _ in range(3):
x = torch.randn(batch_size, self.in_type.size)
x_geom = self.in_type(x)
_ = self(x_geom)
self.eval()
# Test equivariance
x = torch.randn(batch_size, self.in_type.size)
x_geom = self.in_type(x)
for _ in range(5):
g = self.in_type.representation.group.sample()
if g == self.in_type.representation.group.identity:
continue
gx_geom = x_geom.transform(g)
y = self(x_geom)
gy = self(gx_geom)
gy_expected = y.transform(g)
assert torch.allclose(gy.tensor, gy_expected.tensor, atol=atol, rtol=rtol), (
f"Equivariance check failed for group element {g}"
)
self.train(was_training)
def export(self) -> DataNorm:
"""Export to a standard DataNorm layer."""
exported = DataNorm(
num_features=self.num_features,
eps=self.eps,
only_centering=self.only_centering,
compute_cov=self.compute_cov,
momentum=self.momentum,
)
# Transfer state
exported.running_mean.data = self.running_mean.clone()
exported.running_var.data = self.running_var.clone()
exported.num_batches_tracked.data = self.num_batches_tracked.clone()
if self.compute_cov and hasattr(self, "running_cov"):
exported.running_cov.data = self.running_cov.clone()
exported._last_cov = self._last_cov
exported.eval()
return exported