Source code for symm_learning.models.time_cnn.cnn_encoder

from __future__ import annotations

import torch


class TimeCNNEncoder(torch.nn.Module):
    """1D CNN encoder for inputs of shape (N, in_dim, H).

    The model applies a stack of 1D conv blocks over time. Each block halves the
    time horizon using either stride-2 convolution or max-pooling. The final
    feature map of shape (N, C_L, H_out) is flattened and passed to a
    1-hidden-layer MLP head.

    Convolutional features are time-translation equivariant; flattening preserves
    this up to a fixed permutation under shifts. The optional concatenation of the
    last input frame and the final MLP can break output equivariance.

    Args:
        in_dim: Input channel dimension.
        out_dim: Output feature dimension.
        hidden_channels: Channels per conv block; depth equals len(hidden_channels).
        horizon: Input sequence length H.
        activation: Activation module or list (one per block). If a single module
            is given, it is replicated for all blocks.
        batch_norm: If True, add BatchNorm1d after each convolution.
        bias: Use bias in conv/linear layers.
        mlp_hidden: Hidden units of the final MLP head (list for deeper heads).
        downsample: "stride" (default) or "pooling"; each block halves H.
        append_last_frame: If True, concatenate x[:, :, -1] (N, in_dim) to
            flattened conv features before the MLP.

    Returns:
        Tensor of shape (N, out_dim).
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        hidden_channels: list[int],
        horizon: int,
        activation: torch.nn.Module | list[torch.nn.Module] = torch.nn.ReLU(),
        batch_norm: bool = False,
        bias: bool = True,
        mlp_hidden: list[int] = [128],
        downsample: str = "stride",
        append_last_frame: bool = False,
    ) -> None:
        super().__init__()
        assert hasattr(hidden_channels, "__iter__") and hasattr(hidden_channels, "__len__"), (
            "hidden_channels must be a list of integers"
        )
        assert len(hidden_channels) > 0, "At least one conv block is required"
        assert downsample in {"stride", "pooling"}, "downsample must be 'stride' or 'pooling'"

        self.in_dim, self.out_dim = in_dim, out_dim
        self.horizon_in = int(horizon)
        self.batch_norm = batch_norm
        self.bias = bias
        self.downsample = downsample
        self.append_last_frame = append_last_frame

        # Prepare per-block activations (minimal logic)
        if not isinstance(activation, list):
            conv_acts = [activation] * len(hidden_channels)
        else:
            assert len(activation) == len(hidden_channels), f"{len(activation)} != {len(hidden_channels)}"
            conv_acts = activation
        head_act = conv_acts[-1]

        # Build conv feature extractor; kernel=3, padding=1 (time length preserved before downsampling)
        conv_layers = []
        cin = in_dim
        h = self.horizon_in
        for cout, act in zip(hidden_channels, conv_acts):
            if self.downsample == "stride":
                # Conv1d with stride=2 halves time: L_out = floor((L+1)/2) for k=3, p=1
                conv_layers.append(torch.nn.Conv1d(cin, cout, kernel_size=3, stride=2, padding=1, bias=bias))
                if batch_norm:
                    conv_layers.append(torch.nn.BatchNorm1d(cout))
                conv_layers.append(act)
                h = (h + 1) // 2
            else:  # pooling
                # Conv stride=1 keeps time, then MaxPool1d(2) halves: floor(L/2)
                conv_layers.append(torch.nn.Conv1d(cin, cout, kernel_size=3, stride=1, padding=1, bias=bias))
                if batch_norm:
                    conv_layers.append(torch.nn.BatchNorm1d(cout))
                conv_layers.append(act)
                conv_layers.append(torch.nn.MaxPool1d(kernel_size=2, stride=2))
                h = h // 2
            cin = cout

        self.feature_extractor = torch.nn.Sequential(*conv_layers)
        self.horizon_out = max(1, h)
        flat_dim = cin * self.horizon_out
        if self.append_last_frame:
            flat_dim += in_dim

        # MLP head (supports multiple hidden layers)
        head_layers = []
        prev = flat_dim
        for hidden in mlp_hidden:
            head_layers.append(torch.nn.Linear(prev, hidden, bias=bias))
            # Replicate a fresh activation instance of the same type as head_act
            head_layers.append(type(head_act)())
            prev = hidden
        head_layers.append(torch.nn.Linear(prev, out_dim, bias=bias))
        self.head = torch.nn.Sequential(*head_layers)
        torch.nn.init.orthogonal_(self.head[-1].weight)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        :param x: Input tensor of shape ``(N, in_dim, H)``.
        :type x: torch.Tensor
        :returns: Encoded tensor of shape ``(N, out_dim)``.
        :rtype: torch.Tensor
        """
        z = self.feature_extractor(x)
        z = z.view(z.size(0), -1)
        if self.append_last_frame:
            last = x[:, :, -1]
            z = torch.cat([z, last], dim=1)
        return self.head(z)
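
A minimal usage sketch (illustration only, not part of the module source; it assumes TimeCNNEncoder is importable from the module path in the page title, and all dimension values are arbitrary). With horizon=16 and three stride-2 blocks, the time axis halves as 16 -> 8 -> 4 -> 2, so the flattened feature dimension is 64 * 2 = 128, plus in_dim when append_last_frame=True:

import torch

from symm_learning.models.time_cnn.cnn_encoder import TimeCNNEncoder

encoder = TimeCNNEncoder(
    in_dim=6,                      # illustrative channel count
    out_dim=32,
    hidden_channels=[16, 32, 64],  # three conv blocks -> horizon_out = 2
    horizon=16,
    downsample="stride",           # "pooling" also gives 16 -> 8 -> 4 -> 2 here
    append_last_frame=True,        # MLP input dim becomes 64 * 2 + 6 = 134
)

x = torch.randn(8, 6, 16)   # (N, in_dim, H)
z = encoder(x)
assert z.shape == (8, 32)   # (N, out_dim)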