# Source code for symm_learning.models.time_cnn.ecnn_encoder

from __future__ import annotations

import math
from typing import Iterable

import torch
from escnn.group import Representation

from symm_learning.models.emlp import eMLP, iMLP
from symm_learning.nn import eConv1d, eRMSNorm
from symm_learning.representation_theory import direct_sum


class _eChannelRMSNorm(torch.nn.Module):
    """Apply eRMSNorm over the channel dimension for tensors shaped (B, C, L)."""

    def __init__(self, rep: Representation, eps: float = 1e-6):
        super().__init__()
        self.rep = rep
        self.norm = eRMSNorm(rep, eps=eps, equiv_affine=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize over channels; treat time steps as leading dimensions.
        y = x.permute(0, 2, 1)  # (B, L, C)
        y = self.norm(y)
        return y.permute(0, 2, 1)  # (B, C, L)


class eTimeCNNEncoder(torch.nn.Module):
    r"""Equivariant 1D CNN encoder built from channel-equivariant blocks.

    Inputs are plain tensors of shape ``(N, in_rep.size, H)``. Each conv block halves
    the time horizon via stride-2 convolution; optional eRMSNorm and pointwise
    activation follow. The flattened feature map feeds either an equivariant head
    (:class:`eMLP`) or an invariant head (:class:`iMLP`) depending on whether
    ``out_rep`` contains only the trivial irrep.

    The encoder defines:

    .. math::
        \mathbf{f}_{\mathbf{\theta}}: \mathcal{X}^{H} \to \mathcal{Y},

    where :math:`H` is the input horizon, :math:`\mathcal{X}` is the channel feature
    space transforming by :math:`\rho_{\mathcal{X}}`, and :math:`\mathcal{Y}`
    transforms by :math:`\rho_{\mathcal{Y}}`.

    Functional constraint (equivariant head):

    .. math::
        \mathbf{f}_{\mathbf{\theta}}(\rho_{\mathcal{X}}(g)\mathbf{x}) =
        \rho_{\mathcal{Y}}(g)\mathbf{f}_{\mathbf{\theta}}(\mathbf{x})
        \quad \forall g\in\mathbb{G}.

    If ``out_rep`` is trivial-only, the head is invariant:

    .. math::
        \mathbf{f}_{\mathbf{\theta}}(\rho_{\mathcal{X}}(g)\mathbf{x}) =
        \mathbf{f}_{\mathbf{\theta}}(\mathbf{x}) \quad \forall g\in\mathbb{G}.
    """

    def __init__(
        self,
        in_rep: Representation,
        out_rep: Representation,
        hidden_channels: list[int],
        time_horizon: int,
        activation: torch.nn.Module = torch.nn.ReLU(),
        batch_norm: bool = False,
        bias: bool = True,
        mlp_hidden: list[int] = (128,),
        downsample: str = "stride",
        append_last_frame: bool = False,
        init_scheme: str | None = "xavier_uniform",
    ) -> None:
        r"""Create an equivariant time-series CNN encoder.

        Args:
            in_rep (:class:`~escnn.group.Representation`): Input representation
                :math:`\rho_{\text{in}}` defining the group action on the input channels.
            out_rep (:class:`~escnn.group.Representation`): Output representation
                :math:`\rho_{\text{out}}`. If it contains only trivial irreps, an
                :class:`iMLP` head is used; otherwise an :class:`eMLP` head is used.
            hidden_channels: List of output channel counts for each convolution block.
            time_horizon: Length of the input time series (number of frames).
            activation: Non-linearity applied after every convolution block.
            batch_norm: Whether to include channel-wise RMS normalization.
            bias: Whether to include bias in convolutions and linear heads.
            mlp_hidden: Hidden layer widths for the final MLP head.
            downsample: Downsampling strategy, either ``'stride'`` (stride-2 conv)
                or ``'pooling'`` (max pool).
            append_last_frame: Whether to concatenate the last frame of the input
                to the encoding before the head.
            init_scheme: Initialization scheme for equivariant layers.
        """
        super().__init__()
        assert len(hidden_channels) > 0, "At least one conv block is required"
        assert downsample in {"stride", "pooling"}, "downsample must be 'stride' or 'pooling'"

        self.in_rep = in_rep
        self.out_rep = out_rep
        self.time_horizon = int(time_horizon)
        self.append_last_frame = append_last_frame
        self.downsample = downsample

        G = in_rep.group
        reg_rep = G.regular_representation

        layers: list[torch.nn.Module] = []
        cnn_in_rep = in_rep
        # h tracks the remaining time horizon after each downsampling block.
        h = self.time_horizon
        for c_out in hidden_channels:
            # Round the requested channel count up to a whole number of copies of
            # the regular representation (at least one copy per block).
            multiplicity = max(1, math.ceil(c_out / reg_rep.size))
            cnn_out_rep = direct_sum([reg_rep] * multiplicity)
            if self.downsample == "stride":
                layers.append(
                    eConv1d(
                        cnn_in_rep, cnn_out_rep, kernel_size=3, stride=2, padding=1, bias=bias, init_scheme=init_scheme
                    )
                )
                if batch_norm:
                    layers.append(_eChannelRMSNorm(cnn_out_rep))
                layers.append(activation)
                # Conv with kernel=3, stride=2, padding=1 yields ceil(h / 2) steps.
                h = (h + 1) // 2
            else:  # pooling
                layers.append(
                    eConv1d(
                        cnn_in_rep, cnn_out_rep, kernel_size=3, stride=1, padding=1, bias=bias, init_scheme=init_scheme
                    )
                )
                if batch_norm:
                    layers.append(_eChannelRMSNorm(cnn_out_rep))
                layers.append(activation)
                layers.append(torch.nn.MaxPool1d(kernel_size=2, stride=2))
                # MaxPool1d(kernel=2, stride=2) yields floor(h / 2) steps.
                h = h // 2
            cnn_in_rep = cnn_out_rep

        self.feature_layers = torch.nn.Sequential(*layers)
        assert h > 0, f"Horizon {self.time_horizon} too short for {len(hidden_channels)} blocks"
        self.time_horizon_out = h

        # Head input representation: repeat conv out_rep for each remaining time step
        head_rep = direct_sum([cnn_in_rep] * self.time_horizon_out)
        if self.append_last_frame:
            head_rep = direct_sum([head_rep, in_rep])
        self.head_in_rep = head_rep

        # Choose head: invariant if out_rep is trivial-only, else equivariant
        trivial_id = G.trivial_representation.id
        invariant_head = set(out_rep.irreps) == {trivial_id}
        if invariant_head:
            self.head = iMLP(
                in_rep=head_rep,
                out_dim=out_rep.size,
                hidden_units=list(mlp_hidden),
                activation=activation,
                bias=bias,
                init_scheme=init_scheme,
            )
        else:
            self.head = eMLP(
                in_rep=head_rep,
                out_rep=out_rep,
                hidden_units=list(mlp_hidden),
                activation=activation,
                bias=bias,
                init_scheme=init_scheme,
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode input of shape ``(N, in_rep.size, H)`` into ``(N, out_rep.size)``."""
        assert x.shape[-2:] == (self.in_rep.size, self.time_horizon), (
            f"Expected input shape (..., {self.in_rep.size}, {self.time_horizon}), got {x.shape}"
        )
        feats = self.feature_layers(x)
        # Flatten time-major, (B, L, C) -> (B, L*C), so consecutive channel blocks
        # match head_rep = direct_sum([cnn_in_rep] * time_horizon_out) built in __init__.
        z = feats.permute(0, 2, 1).reshape(feats.size(0), -1)
        if self.append_last_frame:
            # Append raw last frame; head_rep already accounts for the extra in_rep block.
            z = torch.cat([z, x[:, :, -1]], dim=1)
        return self.head(z)

    @torch.no_grad()
    def check_equivariance(self, atol: float = 1e-5, rtol: float = 1e-5):
        """Check equivariance under channel actions of the underlying group."""
        import random

        G = self.in_rep.group
        B, L = 10, self.time_horizon
        dtype, device = next(self.head.parameters()).dtype, next(self.head.parameters()).device
        x = torch.randn(B, self.in_rep.size, L, device=device, dtype=dtype)
        y = self(x)
        # Sample group elements without replacement; at most 10 are tested.
        # NOTE(review): assumes G.elements is enumerable (finite group) — confirm for continuous groups.
        elements = set(G.elements)
        for _ in range(10):
            g = random.choice(tuple(elements))
            rho_in = torch.tensor(self.in_rep(g), dtype=x.dtype, device=x.device)
            rho_out = torch.tensor(self.out_rep(g), dtype=y.dtype, device=y.device)
            # Act on channels at every time step, then compare f(g.x) with g.f(x).
            gx = torch.einsum("ij,bjl->bil", rho_in, x)
            gy = self(gx)
            gy_exp = torch.einsum("ij,bj->bi", rho_out, y)
            assert torch.allclose(gy_exp, gy, atol=atol, rtol=rtol), (
                f"Equivariance failed for group element {g} with max error {(gy_exp - gy).abs().max().item():.3e}"
            )
            elements.remove(g)
            if len(elements) == 0:
                break