Source code for symm_learning.models.control.econd_transformer

from __future__ import annotations

import logging
import random
import time
from typing import Literal, Optional

import torch
from escnn.group import Representation

import symm_learning
from symm_learning.linalg import invariant_orthogonal_projector
from symm_learning.models.control.cond_transformer import (
    GenCondRegressor,
    SinusoidalPosEmb,
    build_causal_attention_masks,
    build_cond_positions,
)
from symm_learning.nn.module import eModule
from symm_learning.representation_theory import direct_sum
from symm_learning.utils import module_memory

logger = logging.getLogger(__name__)


class eCondTransformer(eModule, GenCondRegressor):
    r"""Equivariant encoder/decoder Transformer with configurable positional attention.

    Let :math:`A := \texttt{num\_cond\_layers}` and :math:`B := \texttt{num\_layers}`. This module is the
    equivariant counterpart of :class:`~symm_learning.models.control.cond_transformer.CondTransformer`:
    an encoder/decoder Transformer with :math:`A` conditioning encoder layers and :math:`B` decoder
    layers, following the architecture introduced in *Attention Is All You Need* (Vaswani et al.,
    NeurIPS 2017), while constraining every learnable map to respect the prescribed group actions.

    The conditioning stream is assembled as

    .. math::
        [k, \mathbf{z}_{-(T_z - 1)}, \ldots, \mathbf{z}_{0}],

    where :math:`k` is the inference-time optimisation step token and
    :math:`\mathbf{z}_{-(T_z - 1)}, \ldots, \mathbf{z}_{0}` are the conditioning tokens ordered from the
    oldest to the most recent observation. The decoder predicts the target sequence

    .. math::
        [\mathbf{x}_{0}, \ldots, \mathbf{x}_{T_x - 1}],

    ordered from the first to the last action in the predicted horizon.

    Supported positional encodings:

    ``"additive_absolute"``
        Uses :class:`~symm_learning.nn.activation.eAdditivePosMultiheadAttention`. Learned absolute
        positions are added in the equivariant embedding space before self-attention or cross-attention.

    ``"additive_relative"``
        Uses :class:`~symm_learning.nn.activation.eAdditiveRelMultiheadAttention`. Learned relative
        distance biases are injected into the attention logits, preserving the time-translation
        structure of the sequence while keeping the feature maps equivariant.

    ``"none"``
        Uses :class:`~symm_learning.nn.activation.eMultiheadAttention` with no explicit positional
        encoding.

    Temporal assumptions:

    * :math:`Z` must already be ordered in time from past to present.
    * :math:`X` must already be ordered from the first to the last predicted action.
    * For ``"additive_relative"``, the last conditioning token :math:`\mathbf{z}_{0}` and the first
      action token :math:`\mathbf{x}_{0}` are both placed at time index :math:`0`, so cross-attention
      is anchored at the present time.
    * The optimisation-step token :math:`k` is prepended to the conditioning memory, but it is not
      treated as part of the observation timeline.

    Equivariance is enforced by embedding tokens into a representation space built from copies of the
    regular representation, projecting the scalar step embedding onto the invariant subspace, and using
    equivariant encoder, decoder, normalization, and head layers throughout. The resulting conditional
    map satisfies

    .. math::
        \mathbf{f}_{\mathbf{\theta}}(\rho_{\mathcal{X}}(g)\mathbf{X}_k,\, \rho_{\mathcal{Z}}(g)\mathbf{Z},\, k)
        = \rho_{\mathcal{Y}}(g)\,\mathbf{f}_{\mathbf{\theta}}(\mathbf{X}_k, \mathbf{Z}, k),
        \qquad \forall g \in \mathbb{G}.
    """

    def __init__(
        self,
        in_rep: Representation,
        cond_rep: Representation,
        out_rep: Optional[Representation],
        in_horizon: int,
        cond_horizon: int,
        num_layers: int,
        num_attention_heads: int,
        embedding_dim: int,
        p_drop_emb: float = 0.1,
        p_drop_attn: float = 0.1,
        causal_attn: bool = False,
        num_cond_layers: int = 0,
        pos_encoding: Literal["additive_absolute", "additive_relative", "none"] = "additive_absolute",
        norm_first: bool = True,
        norm_module: Literal["layernorm", "rmsnorm"] = "rmsnorm",
        init_scheme: str = "xavier_uniform",
    ) -> None:
        r"""Create an equivariant conditional transformer regressor.
        Args:
            in_rep (:class:`~escnn.group.Representation`): Representation :math:`\rho_{\text{in}}` of
                the input tokens.
            cond_rep (:class:`~escnn.group.Representation`): Representation :math:`\rho_{\text{cond}}`
                of the conditioning tokens.
            out_rep (:class:`~escnn.group.Representation`, optional): Output representation
                :math:`\rho_{\text{out}}`. Defaults to ``in_rep`` if ``None``.
            in_horizon: Maximum length of the input sequence.
            cond_horizon: Maximum length of the conditioning sequence.
            num_layers: Number of transformer decoder layers in the main generation trunk.
            num_attention_heads: Number of attention heads in transformer layers.
            embedding_dim: Dimension of the regular-representation embedding space. Must be a multiple
                of the group order.
            p_drop_emb: Dropout probability for embeddings.
            p_drop_attn: Dropout probability inside attention blocks.
            causal_attn: Whether to mask future tokens (causal masking).
            num_cond_layers: Number of transformer encoder layers for processing conditioning tokens.
                If 0, an eMLP is used instead.
            pos_encoding: Positional attention backend (``"additive_absolute"``,
                ``"additive_relative"``, or ``"none"``).
            norm_first: Whether to apply normalization before each residual branch.
            norm_module: Normalization layer type (``'layernorm'`` or ``'rmsnorm'``).
            init_scheme: Initialization scheme for equivariant layers.
        """
        out_rep = out_rep or in_rep
        super().__init__(in_rep.size, out_rep.size, cond_rep.size)

        self.in_rep = in_rep
        self.out_rep = out_rep
        self.cond_rep = cond_rep
        self.in_horizon = in_horizon
        self.cond_horizon = cond_horizon
        self.cond_token_horizon = cond_horizon + 1
        self.num_layers = num_layers
        self.embedding_dim = embedding_dim
        self.pos_encoding = pos_encoding
        self.dropout = torch.nn.Dropout(p_drop_emb)

        G = in_rep.group
        assert cond_rep.group == G == out_rep.group, "All representations must belong to the same group"
        if embedding_dim % G.order() != 0:
            raise ValueError(f"embedding_dim ({embedding_dim}) must be a multiple of the group order ({G.order()})")
        regular_copies = max(1, embedding_dim // G.order())
        self.embedding_rep = direct_sum([G.regular_representation] * regular_copies)
        self.register_buffer("invariant_projector", invariant_orthogonal_projector(self.embedding_rep))

        self.input_emb = symm_learning.nn.eLinear(in_rep, self.embedding_rep, bias=True, init_scheme=None)
        self.cond_emb = symm_learning.nn.eLinear(cond_rep, self.embedding_rep, bias=True, init_scheme=None)
        self.opt_time_emb = SinusoidalPosEmb(embedding_dim)

        max_pos_len = max(self.in_horizon, self.cond_token_horizon)
        max_rel_distance = self.in_horizon + self.cond_token_horizon - 2

        def _build_attn():
            if pos_encoding == "additive_absolute":
                return symm_learning.nn.eAdditivePosMultiheadAttention(
                    in_rep=self.embedding_rep,
                    num_heads=num_attention_heads,
                    max_len=max_pos_len,
                    dropout=p_drop_attn,
                    bias=True,
                    init_scheme=None,
                )
            elif pos_encoding == "additive_relative":
                return symm_learning.nn.eAdditiveRelMultiheadAttention(
                    in_rep=self.embedding_rep,
                    num_heads=num_attention_heads,
                    max_distance=max_rel_distance,
                    dropout=p_drop_attn,
                    bias=True,
                    init_scheme=None,
                )
            elif pos_encoding == "none":
                return symm_learning.nn.eMultiheadAttention(
                    in_rep=self.embedding_rep,
                    num_heads=num_attention_heads,
                    dropout=p_drop_attn,
                    bias=True,
                    init_scheme=None,
                )
            else:
                raise ValueError(
                    f"Unknown pos_encoding={pos_encoding!r}. Expected "
                    "'additive_absolute', 'additive_relative', or 'none'."
                )

        # Conditioning encoder
        if num_cond_layers > 0:
            encoder_layer = symm_learning.nn.eTransformerEncoderLayer(
                in_rep=self.embedding_rep,
                self_attn=_build_attn(),
                dim_feedforward=4 * embedding_dim,
                dropout=p_drop_attn,
                activation=torch.nn.GELU(),
                norm_first=norm_first,
                norm_module=norm_module,
                init_scheme=None,
            )
            logger.debug(
                f"Initializing {num_cond_layers} layers of eTransformerEncoderLayer of "
                f"{sum(p.numel() for p in encoder_layer.parameters()) / 1e6:.2f}M parameters each"
            )
            self.encoder = symm_learning.nn.TransformerEncoder(
                encoder_layer=encoder_layer, num_layers=num_cond_layers, enable_nested_tensor=False
            )
        else:
            hidden_rep = direct_sum([self.embedding_rep] * 4)
            logger.debug(f"Initializing eMLP encoder with hidden representation {hidden_rep}")
            self.encoder = torch.nn.Sequential(
                symm_learning.nn.eLinear(in_rep=self.embedding_rep, out_rep=hidden_rep, bias=True, init_scheme=None),
                torch.nn.Mish(),
                symm_learning.nn.eLinear(in_rep=hidden_rep, out_rep=self.embedding_rep, bias=True, init_scheme=None),
            )

        # Decoder
        decoder_layer = symm_learning.nn.eTransformerDecoderLayer(
            in_rep=self.embedding_rep,
            self_attn=_build_attn(),
            multihead_attn=_build_attn(),
            dim_feedforward=4 * embedding_dim,
            dropout=p_drop_attn,
            activation=torch.nn.GELU(),
            norm_first=norm_first,
            norm_module=norm_module,
            init_scheme=None,
        )
        logger.debug(f"Initializing {num_layers} layers of eTransformerDecoderLayer")
        self.decoder = symm_learning.nn.TransformerDecoder(decoder_layer=decoder_layer, num_layers=num_layers)

        # Self-attention and cross-attention masks. Cross-attention computes updates to the action
        # tokens based on the conditioning tokens, composed of an inference-time optimization step
        # token followed by the observation conditioning tokens.
        if causal_attn:
            self_att_mask, cross_att_mask = build_causal_attention_masks(in_horizon, self.cond_token_horizon)
            self.register_buffer("self_att_mask", self_att_mask)
            self.register_buffer("cross_att_mask", cross_att_mask)
        else:
            self.self_att_mask = None
            self.cross_att_mask = None

        # Decoder head
        if norm_module == "layernorm":
            self.layer_norm = symm_learning.nn.eLayerNorm(self.embedding_rep, eps=1e-5, equiv_affine=True, bias=True)
        else:  # rmsnorm
            self.layer_norm = symm_learning.nn.eRMSNorm(self.embedding_rep, eps=1e-5, equiv_affine=True)
        self.head = symm_learning.nn.eLinear(self.embedding_rep, out_rep, bias=True, init_scheme=None)

        self.reset_parameters(scheme=init_scheme)

        trainable_mem, non_trainable_mem = module_memory(self, units="MiB")
        logger.info(
            f"[{self.__class__.__name__}]: {trainable_mem:.3f} MiB trainable, {non_trainable_mem:.3f} "
            f"MiB non-trainable parameters."
        )
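
    # Worked example (illustrative, not part of the class): with ``G = CyclicGroup(4)``
    # (group order 4) and ``embedding_dim = 16``, the constructor above builds the
    # embedding space as ``direct_sum([G.regular_representation] * 4)``, i.e. four
    # copies of the 4-dimensional regular representation, so every token lives in a
    # 16-dimensional space on which G acts by permuting coordinates.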
    @torch.no_grad()
    def reset_parameters(self, scheme="xavier_uniform") -> None:
        """Re-initialize all parameters."""
        logger.debug(f"Resetting parameters of {self.__class__.__name__} with scheme: {scheme}")
        # Initialize eLinear layers.
        self.input_emb.reset_parameters(scheme=scheme)
        self.cond_emb.reset_parameters(scheme=scheme)
        self.head.reset_parameters(scheme=scheme)
        # Initialize the final layer norm.
        self.layer_norm.reset_parameters()
        # Initialize conditioning encoder layers.
        if isinstance(self.encoder, symm_learning.nn.TransformerEncoder):
            for i, layer in enumerate(self.encoder.layers, start=1):
                assert isinstance(layer, symm_learning.nn.eTransformerEncoderLayer)
                logger.debug(f"Resetting encoder layer {i}:[{layer.__class__.__name__}] with scheme: {scheme}")
                layer.reset_parameters(scheme=scheme)
        else:  # eMLP.
            for i, module in enumerate(self.encoder, start=1):
                if isinstance(module, symm_learning.nn.eLinear):
                    logger.debug(f"Resetting encoder module {i}:[{module.__class__.__name__}] with scheme: {scheme}")
                    module.reset_parameters(scheme=scheme)
        # Initialize decoder layers.
        for i, layer in enumerate(self.decoder.layers, start=1):
            assert isinstance(layer, symm_learning.nn.eTransformerDecoderLayer)
            logger.debug(f"Resetting decoder layer {i}:[{layer.__class__.__name__}] with scheme: {scheme}")
            layer.reset_parameters(scheme=scheme)
        logger.info(f"[{self.__class__.__name__}]: parameters initialized with `{scheme}` scheme.")
    def get_optim_groups(self, weight_decay: float = 1e-3):
        """Split trainable parameters into weight-decayed and non-decayed optimizer groups.

        Biases, normalization weights, and positional tables (``pos_emb``/``rel_bias``) are excluded
        from weight decay; ``eLinear`` and attention projection weights receive it.

        Args:
            weight_decay: Weight-decay coefficient applied to the decayed group.

        Returns:
            A list of two parameter-group dicts suitable for :class:`torch.optim.AdamW`.
        """
        decay = set()
        no_decay = set()
        whitelist_weight_modules = (symm_learning.nn.eLinear, symm_learning.nn.eMultiheadAttention)
        blacklist_weight_modules = (symm_learning.nn.eLayerNorm, torch.nn.Embedding)
        for module_name, m in self.named_modules():
            for param_name, p in m.named_parameters():
                if not p.requires_grad:
                    continue
                fpn = f"{module_name}.{param_name}" if module_name else param_name
                if param_name.endswith("bias") or param_name.startswith("bias"):
                    no_decay.add(fpn)
                elif param_name.endswith("weight") and isinstance(m, whitelist_weight_modules):
                    decay.add(fpn)
                elif (
                    param_name.endswith("weight")
                    and isinstance(m, blacklist_weight_modules)
                    or param_name in {"pos_emb", "rel_bias"}
                ):
                    no_decay.add(fpn)
                else:
                    raise ValueError(f"Unrecognized parameter {fpn} in module {module_name}")

        param_dict = {pn: p for pn, p in self.named_parameters()}
        inter_params = decay & no_decay
        union_params = decay | no_decay
        assert len(inter_params) == 0, f"parameters {inter_params} in both decay/no_decay"
        assert len(param_dict.keys() - union_params) == 0, (
            f"parameters {param_dict.keys() - union_params} not separated into decay/no_decay"
        )

        optim_groups = [
            {"params": [param_dict[pn] for pn in sorted(decay)], "weight_decay": weight_decay},
            {"params": [param_dict[pn] for pn in sorted(no_decay)], "weight_decay": 0.0},
        ]
        return optim_groups
    def configure_optimizers(  # noqa: D102
        self, learning_rate: float = 1e-4, weight_decay: float = 1e-3, betas: tuple[float, float] = (0.9, 0.95)
    ):
        optim_groups = self.get_optim_groups(weight_decay=weight_decay)
        return torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas)
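
    # Usage sketch (illustrative): the call
    #   optimizer = model.configure_optimizers(learning_rate=1e-4, weight_decay=1e-3)
    # is equivalent to
    #   optimizer = torch.optim.AdamW(model.get_optim_groups(weight_decay=1e-3), lr=1e-4, betas=(0.9, 0.95))
    # so weight decay is applied only to the whitelisted weight matrices.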
    def forward(
        self,
        X: torch.Tensor,
        opt_step: torch.Tensor | float | int,
        Z: torch.Tensor,
    ):
        r"""Forward pass approximating :math:`V_k = f(X_k, Z, k)`."""
        assert X.dim() == 3 and X.shape[-1] == self.in_rep.size and X.shape[1] <= self.in_horizon, (
            f"Expected X shape (B, Tx<={self.in_horizon}, {self.in_rep.size}) got {X.shape}"
        )
        assert Z.dim() == 3 and Z.shape[-1] == self.cond_rep.size and Z.shape[1] <= self.cond_horizon, (
            f"Expected Z shape (B, Tz<={self.cond_horizon}, {self.cond_rep.size}) got {Z.shape}"
        )
        batch_size = X.shape[0]

        # 1. Inference-time optimization step embedding (k): the first conditioning token.
        if isinstance(opt_step, torch.Tensor):
            opt_steps = opt_step.to(device=X.device, dtype=torch.float32)
        else:
            opt_steps = torch.scalar_tensor(opt_step, device=X.device, dtype=torch.float32)
        # Broadcast to the batch dimension in a way that is compatible with ONNX/Core ML.
        opt_steps = opt_steps.reshape(-1).expand(batch_size)
        opt_time_emb = self.opt_time_emb(opt_steps).unsqueeze(1)  # (B, 1, D)
        # Project the step embedding onto the invariant subspace of the embedding space.
        opt_time_emb = torch.einsum("ij,...j->...i", self.invariant_projector, opt_time_emb)

        # 2. Conditioning variable Z embedding/tokenization.
        z_cond_emb = self.cond_emb(Z)
        cond_embeddings = torch.cat([opt_time_emb, z_cond_emb], dim=1)
        cond_tokens = self.dropout(cond_embeddings)
        cond_token_horizon = cond_embeddings.shape[1]
        cond_positions, cond_position_mask = build_cond_positions(
            self.pos_encoding, cond_token_horizon, device=X.device
        )
        if isinstance(self.encoder, symm_learning.nn.TransformerEncoder):
            cond_tokens = self.encoder(
                cond_tokens,
                src_positions=cond_positions,
                src_position_mask=cond_position_mask,
            )
        else:
            cond_tokens = self.encoder(cond_tokens)

        # 3. Input embedding/tokenization.
        input_tokens = self.dropout(self.input_emb(X))

        # 4. Transformer decoder over input tokens, with self-attention and cross-attention to cond tokens.
        input_horizon = input_tokens.shape[1]
        input_positions = torch.arange(input_horizon, device=X.device)
        input_position_mask = torch.ones(input_horizon, device=X.device, dtype=torch.bool)
        out_tokens = self.decoder(
            tgt=input_tokens,
            memory=cond_tokens,
            tgt_mask=self.self_att_mask,
            memory_mask=self.cross_att_mask,
            tgt_positions=input_positions,
            tgt_position_mask=input_position_mask,
            memory_positions=cond_positions,
            memory_position_mask=cond_position_mask,
        )

        # 5. Regression head projecting to the output dimension.
        out_tokens = self.layer_norm(out_tokens)
        out = self.head(out_tokens)  # (B, Tx, out_dim)
        return out
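
    # Shape sketch (illustrative): for a model built with in_horizon=8 and cond_horizon=6,
    #   X: (B, Tx, in_rep.size)   with Tx <= 8, ordered first-to-last predicted action,
    #   Z: (B, Tz, cond_rep.size) with Tz <= 6, ordered oldest-to-newest observation,
    #   opt_step: a scalar or a (B,) tensor of inference-time optimization steps,
    # and forward(...) returns a tensor of shape (B, Tx, out_rep.size).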
    @torch.no_grad()
    def check_equivariance(  # noqa: D102
        self,
        batch_size: int = 10,
        in_len: int = 10,
        cond_len: int = 5,
        atol: float = 1e-4,
        rtol: float = 1e-4,
    ) -> None:
        import escnn

        G = self.in_rep.group
        in_len = min(in_len, self.in_horizon)
        cond_len = min(cond_len, self.cond_horizon)

        training_mode = self.training
        self.eval()

        def act(rep: Representation, g: escnn.group.GroupElement, x: torch.Tensor) -> torch.Tensor:
            mat = torch.tensor(rep(g), dtype=x.dtype, device=x.device)
            return torch.einsum("ij,...j->...i", mat, x)

        device = self.invariant_projector.device
        dtype = self.invariant_projector.dtype
        for _ in range(min(10, G.order())):
            g = random.choice(list(G.elements[1:]))  # Skip the identity element.
            X = torch.randn(batch_size, in_len, self.in_rep.size, device=device, dtype=dtype)
            Z = torch.randn(batch_size, cond_len, self.cond_rep.size, device=device, dtype=dtype)
            k = torch.randn(batch_size, device=device, dtype=dtype)
            Y = self(X=X, Z=Z, opt_step=k)
            # Evaluate on symmetric points.
            gX = act(self.in_rep, g, X)
            gZ = act(self.cond_rep, g, Z)
            gY = self(X=gX, Z=gZ, opt_step=k)
            gY_expected = act(self.out_rep, g, Y)
            assert torch.allclose(gY, gY_expected, atol=atol, rtol=rtol), (
                f"Equivariance test failed for group element {g}.\n"
                f"Max absolute difference: {torch.max(torch.abs(gY - gY_expected))}\n"
            )

        if training_mode:
            self.train()
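
    # Property checked above (illustrative restatement of the class docstring): for every
    # sampled group element g,
    #   f(rho_X(g) X, rho_Z(g) Z, k) == rho_Y(g) f(X, Z, k)
    # up to the given numerical tolerances.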
if __name__ == "__main__":
    # Configure debug logging with a [name][level][time][message] format.
    logging.basicConfig(level=logging.DEBUG)
    formatter = logging.Formatter("[%(name)s][%(levelname)s][%(asctime)s]: %(message)s")
    for handler in logging.getLogger().handlers:
        handler.setFormatter(formatter)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.float32

    from escnn.group import CyclicGroup

    G = CyclicGroup(2)
    m = 5
    embedding_dim = G.order() * m * 4
    in_rep = direct_sum([G.regular_representation] * m)
    print(f"in_rep size: {in_rep.size}")
    cond_rep = in_rep
    out_rep = in_rep
    Tx, Tz = 8, 6

    start_time = time.time()
    model = eCondTransformer(
        in_rep=in_rep,
        cond_rep=cond_rep,
        out_rep=out_rep,
        in_horizon=Tx,
        cond_horizon=Tz,
        num_layers=5,
        num_attention_heads=1,
        embedding_dim=embedding_dim,
        num_cond_layers=0,
    )
    print(f"Model initialized in {time.time() - start_time:.2f} seconds")

    model.to(device=device, dtype=dtype)
    model.eval()
    model.check_equivariance(atol=1e-2, rtol=1e-2)
    print("Equivariance test passed!")
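
    # Additional usage sketch (illustrative): a forward pass and optimizer setup for the
    # demo model built above. Shapes follow the assertions in `forward`; values are arbitrary.
    B = 4
    X = torch.randn(B, Tx, in_rep.size, device=device, dtype=dtype)
    Z = torch.randn(B, Tz, cond_rep.size, device=device, dtype=dtype)
    with torch.no_grad():
        Y = model(X=X, opt_step=0, Z=Z)
    print(f"Forward pass output shape: {tuple(Y.shape)}")  # (B, Tx, out_rep.size)

    # Weight-decay-aware optimizer, as returned by `configure_optimizers`.
    optimizer = model.configure_optimizers(learning_rate=1e-4, weight_decay=1e-3)
    print(f"Optimizer ready with {len(optimizer.param_groups)} parameter groups.")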