Source code for symm_learning.models.diffusion.econd_transformer_regressor

from __future__ import annotations

import logging
import random
import time
from typing import Literal, Optional, Tuple

import torch
from escnn.group import Representation

import symm_learning
from symm_learning.linalg import invariant_orthogonal_projector
from symm_learning.models.diffusion.cond_transformer_regressor import GenCondRegressor
from symm_learning.models.diffusion.cond_unet1d import SinusoidalPosEmb
from symm_learning.representation_theory import direct_sum
from symm_learning.utils import module_memory

logger = logging.getLogger(__name__)


[docs] class eCondTransformerRegressor(GenCondRegressor): r"""Equivariant analogue of the conditional transformer regressor baseline. This module mirrors :class:`~symm_learning.models.diffusion.cond_transformer_regressor.CondTransformerRegressor` while enforcing equivariance constraints. Tokens transforming according to ``in_rep`` are embedded into an ``embedding_rep`` space built from copies of the regular representation so that :class:`~symm_learning.models.transformer.etransformer.eTransformerEncoderLayer`/ :class:`~symm_learning.models.transformer.etransformer.eTransformerDecoderLayer` can be used directly. Positional encodings and timestep embeddings are projected onto the invariant subspace so they can be added to equivariant tokens without breaking symmetry. The model defines: .. math:: \mathbf{f}_{\mathbf{\theta}}: \mathcal{X}^{T_x} \times \mathcal{Z}^{T_z} \times \mathbb{R} \to \mathcal{Y}^{T_x}. Functional equivariance constraint: .. math:: \mathbf{f}_{\mathbf{\theta}}(\rho_{\mathcal{X}}(g)\mathbf{X}_k,\, \rho_{\mathcal{Z}}(g)\mathbf{Z},\, k) = \rho_{\mathcal{Y}}(g)\,\mathbf{f}_{\mathbf{\theta}}(\mathbf{X}_k,\mathbf{Z},k), \quad \forall g\in\mathbb{G}. """ def __init__( self, in_rep: Representation, cond_rep: Representation, out_rep: Optional[Representation], in_horizon: int, cond_horizon: int, num_layers: int, num_attention_heads: int, embedding_dim: int, p_drop_emb: float = 0.1, p_drop_attn: float = 0.1, causal_attn: bool = False, num_cond_layers: int = 0, norm_module: Literal["layernorm", "rmsnorm"] = "rmsnorm", init_scheme: str = "xavier_uniform", ) -> None: r"""Create an equivariant conditional transformer regressor. Args: in_rep (:class:`~escnn.group.Representation`): Representation :math:`\rho_{\text{in}}` of the input tokens. cond_rep (:class:`~escnn.group.Representation`): Representation :math:`\rho_{\text{cond}}` of the conditioning tokens. out_rep (:class:`~escnn.group.Representation`, optional): Output representation :math:`\rho_{\text{out}}`. Defaults to ``in_rep`` if ``None``. in_horizon: Maximum length of the input sequence. cond_horizon: Maximum length of the conditioning sequence. num_layers: Number of transformer decoder layers in the main generation trunk. num_attention_heads: Number of attention heads in transformer layers. embedding_dim: Dimension of the regular representation embedding space. Must be a multiple of the group order. p_drop_emb: Dropout probability for embeddings. p_drop_attn: Dropout probability inside attention blocks. causal_attn: Whether to mask future tokens (causal masking). num_cond_layers: Number of transformer encoder layers for processing conditioning tokens. If 0, an eMLP is used instead. norm_module: Normalization layer type (``'layernorm'`` or ``'rmsnorm'``). init_scheme: Initialization scheme for equivariant layers. """ out_rep = out_rep or in_rep super().__init__(in_rep.size, out_rep.size, cond_rep.size) self.in_rep = in_rep self.out_rep = out_rep self.cond_rep = cond_rep self.in_horizon = in_horizon self.cond_horizon = cond_horizon self.num_layers = num_layers self.embedding_dim = embedding_dim self.dropout = torch.nn.Dropout(p_drop_emb) G = in_rep.group assert cond_rep.group == G == out_rep.group, "All representations must belong to the same group" if embedding_dim % G.order() != 0: raise ValueError(f"embedding_dim ({embedding_dim}) must be a multiple of the group order ({G.order()})") regular_copies = max(1, embedding_dim // G.order()) self.embedding_rep = direct_sum([G.regular_representation] * regular_copies) self.register_buffer("invariant_projector", invariant_orthogonal_projector(self.embedding_rep)) self.input_emb = symm_learning.nn.eLinear(in_rep, self.embedding_rep, bias=True, init_scheme=None) self.cond_emb = symm_learning.nn.eLinear(cond_rep, self.embedding_rep, bias=True, init_scheme=None) self.opt_time_emb = SinusoidalPosEmb(embedding_dim) self.pos_emb = torch.nn.Parameter(torch.zeros(1, in_horizon, embedding_dim)) self.cond_pos_emb = torch.nn.Parameter(torch.zeros(1, cond_horizon + 1, embedding_dim)) # Encoder parameterized as an equivariant MLP or a Transformer if num_cond_layers > 0: encoder_layer = symm_learning.models.eTransformerEncoderLayer( in_rep=self.embedding_rep, nhead=num_attention_heads, dim_feedforward=4 * embedding_dim, dropout=p_drop_attn, activation="gelu", batch_first=True, norm_first=True, # important for stability. norm_module=norm_module, # important for stability. init_scheme=None, ) logger.debug( f"Initializing {num_cond_layers} layers of eTransformerEncoderLayer of " f"{sum(p.numel() for p in encoder_layer.parameters()) / 1e6:.2f}M parameters each" ) self.encoder = torch.nn.TransformerEncoder( encoder_layer=encoder_layer, num_layers=num_cond_layers, enable_nested_tensor=False ) else: hidden_rep = direct_sum([self.embedding_rep] * 4) logger.debug(f"Initializing eMLP encoder with hidden representation {hidden_rep}") self.encoder = torch.nn.Sequential( symm_learning.nn.eLinear(in_rep=self.embedding_rep, out_rep=hidden_rep, bias=True, init_scheme=None), torch.nn.Mish(), symm_learning.nn.eLinear(in_rep=hidden_rep, out_rep=self.embedding_rep, bias=True, init_scheme=None), ) decoder_layer = symm_learning.models.eTransformerDecoderLayer( in_rep=self.embedding_rep, nhead=num_attention_heads, dim_feedforward=4 * embedding_dim, dropout=p_drop_attn, activation="gelu", batch_first=True, norm_first=True, # important for stability. norm_module=norm_module, # important for stability. init_scheme=None, ) logger.debug(f"Initializing {num_layers} layers of eTransformerDecoderLayer") self.decoder = torch.nn.TransformerDecoder(decoder_layer=decoder_layer, num_layers=num_layers) # Self-Attention and Cross-Attention mask. # Cross-attention is used to compute updates to the action vector based on the conditioing tokens # composed of a inference-time optimization step token and the observation conditioning tokens. if causal_attn: # causal mask to ensure that attention is only applied to the left in the input sequence # torch.nn.Transformer uses additive mask as opposed to multiplicative mask in minGPT # therefore, the upper triangle should be -inf and others (including diag) should be 0. mask = (torch.triu(torch.ones(in_horizon, in_horizon)) == 1).transpose(0, 1) mask = mask.float().masked_fill(mask == 0, float("-inf")).masked_fill(mask == 1, float(0.0)) self.register_buffer("self_att_mask", mask) t, s = torch.meshgrid(torch.arange(in_horizon), torch.arange(cond_horizon), indexing="ij") mask = t >= (s - 1) mask = mask.float().masked_fill(mask == 0, float("-inf")).masked_fill(mask == 1, float(0.0)) self.register_buffer("cross_att_mask", mask) else: self.self_att_mask = None self.cross_att_mask = None if norm_module == "layernorm": self.layer_norm = symm_learning.nn.eLayerNorm(self.embedding_rep, eps=1e-5, equiv_affine=True, bias=True) raise ValueError("eLayerNorm is numerically unstable. Use eRMSNorm instead for now.") else: # rmsnorm self.layer_norm = symm_learning.nn.eRMSNorm(self.embedding_rep, eps=1e-5, equiv_affine=True) self.head = symm_learning.nn.eLinear(self.embedding_rep, out_rep, bias=True, init_scheme=None) self.reset_parameters(scheme=init_scheme) trainable_mem, non_trainable_mem = module_memory(self, units="MiB") logger.info( f"[{self.__class__.__name__}]: {trainable_mem:.3f} MiB trainable, {non_trainable_mem:.3f} " f"MiB non-trainable parameters." )
[docs] @torch.no_grad() def reset_parameters(self, scheme="xavier_uniform") -> None: """Re-initialize all parameters.""" logger.debug(f"Resetting parameters of {self.__class__.__name__} with scheme: {scheme}") # Initialize eLinear layers. self.input_emb.reset_parameters(scheme=scheme) self.cond_emb.reset_parameters(scheme=scheme) self.head.reset_parameters(scheme=scheme) # Initialize final layer norm and head. self.layer_norm.reset_parameters() # Initalize conditional encoder layers. if isinstance(self.encoder, torch.nn.TransformerEncoder): for i, layer in enumerate(self.encoder.layers, start=1): assert isinstance(layer, symm_learning.models.eTransformerEncoderLayer) logger.debug(f"Resetting encoder layer {i}:[{layer.__class__.__name__}] with scheme: {scheme}") layer.reset_parameters(scheme=scheme) else: # eMLP. for i, module in enumerate(self.encoder, start=1): if isinstance(module, symm_learning.nn.eLinear): logger.debug(f"Resetting encoder module {i}:[{module.__class__.__name__}] with scheme: {scheme}") module.reset_parameters(scheme=scheme) # Initialize decoder layers. for i, layer in enumerate(self.decoder.layers, start=1): assert isinstance(layer, symm_learning.models.eTransformerDecoderLayer) logger.debug(f"Resetting decoder layer {i}:[{layer.__class__.__name__}] with scheme: {scheme}") layer.reset_parameters(scheme=scheme) logger.info(f"[{self.__class__.__name__}]: parameters initialized with `{scheme}` scheme.")
[docs] def get_optim_groups(self, weight_decay: float = 1e-3): """Todo.""" decay = set() no_decay = set() whitelist_weight_modules = (symm_learning.nn.eLinear, symm_learning.nn.eMultiheadAttention) blacklist_weight_modules = (symm_learning.nn.eLayerNorm, torch.nn.Embedding) for module_name, m in self.named_modules(): for param_name, p in m.named_parameters(): if not p.requires_grad: continue fpn = f"{module_name}.{param_name}" if module_name else param_name if param_name.endswith("bias") or param_name.startswith("bias"): no_decay.add(fpn) elif param_name.endswith("weight") and isinstance(m, whitelist_weight_modules): decay.add(fpn) elif param_name.endswith("weight") and isinstance(m, blacklist_weight_modules): no_decay.add(fpn) else: raise ValueError(f"Unrecognized parameter {fpn} in module {module_name}") no_decay.add("pos_emb") no_decay.add("cond_pos_emb") param_dict = {pn: p for pn, p in self.named_parameters()} inter_params = decay & no_decay union_params = decay | no_decay assert len(inter_params) == 0, f"parameters {inter_params} in both decay/no_decay" assert len(param_dict.keys() - union_params) == 0, ( f"parameters {param_dict.keys() - union_params} not separated into decay/no_decay" ) optim_groups = [ {"params": [param_dict[pn] for pn in sorted(decay)], "weight_decay": weight_decay}, {"params": [param_dict[pn] for pn in sorted(no_decay)], "weight_decay": 0.0}, ] return optim_groups
def configure_optimizers( # noqa: D102 self, learning_rate: float = 1e-4, weight_decay: float = 1e-3, betas: tuple[float, float] = (0.9, 0.95) ): # noqa: D102 optim_groups = self.get_optim_groups(weight_decay=weight_decay) return torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas)
[docs] def forward( self, X: torch.Tensor, opt_step: torch.Tensor | float | int, Z: torch.Tensor, ): r"""Forward pass approximating :math:`V_k = f(X_k, Z, k)`.""" assert X.dim() == 3 and X.shape[-1] == self.in_rep.size and X.shape[1] <= self.in_horizon, ( f"Expected X shape (B, Tx<={self.in_horizon}, {self.in_rep.size}) got {X.shape}" ) assert Z.dim() == 3 and Z.shape[-1] == self.cond_rep.size and Z.shape[1] <= self.cond_horizon, ( f"Expected Z shape (B, Tz<={self.cond_horizon}, {self.cond_rep.size}) got {Z.shape}" ) batch_size = X.shape[0] # 1. Inference-time optimization step embedding (k). First conditioning token. if isinstance(opt_step, torch.Tensor): opt_steps = opt_step.to(device=X.device, dtype=torch.float32) else: opt_steps = torch.scalar_tensor(opt_step, device=X.device, dtype=torch.float32) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML opt_steps = opt_steps.reshape(-1).expand(batch_size) opt_time_emb = self.opt_time_emb(opt_steps).unsqueeze(1) # (B, 1, D) # Project time embedding onto embedding space's invariant subspace opt_time_emb = torch.einsum("ij,...j->...i", self.invariant_projector, opt_time_emb) # 2. Conditioning variable Z embedding/tokenization z_cond_emb = self.cond_emb(Z) # (B, Tz-1, D) cond_embeddings = torch.cat([opt_time_emb, z_cond_emb], dim=1) # (B, Tz, D) cond_horizon = z_cond_emb.shape[1] # (Tz) # Project time embedding onto embedding space's invariant subspace cond_pos_emb = torch.einsum( "ij,...j->...i", self.invariant_projector, self.cond_pos_emb[:, : cond_horizon + 1, :] ) # Transformer encoder of conditing tokens cond_tokens = self.dropout(cond_embeddings + cond_pos_emb) # (B, Tz, D) cond_tokens = self.encoder(cond_tokens) # (B, Tz, D) # 3. Input embedding/tokenization input_tokens = self.input_emb(X) # (B, Tx, D) # 4. Transformer encoder of input tokens with self-attention and cross-attention to cond tokens input_horizon = input_tokens.shape[1] # (Tx) # Project time embedding onto embedding space's invariant subspace pos_emb = torch.einsum("ij,...j->...i", self.invariant_projector, self.pos_emb[:, :input_horizon, :]) input_tokens = self.dropout(input_tokens + pos_emb) # (B, Tx, D) out_tokens = self.decoder( tgt=input_tokens, memory=cond_tokens, tgt_mask=self.self_att_mask, memory_mask=self.cross_att_mask ) # (B, Tx, D) # 5. Regression head projecting to output dimension. out_tokens = self.layer_norm(out_tokens) out = self.head(out_tokens) # (B, Tx, out_dim) return out
@torch.no_grad() def check_equivariance( # noqa: D102 self, batch_size: int = 10, in_len: int = 10, cond_len: int = 5, atol: float = 1e-4, rtol: float = 1e-4, ) -> None: import escnn G = self.in_rep.group in_len = min(in_len, self.in_horizon) cond_len = min(cond_len, self.cond_horizon) training_mode = self.training self.eval() def act(rep: Representation, g: escnn.group.GroupElement, x: torch.Tensor) -> torch.Tensor: mat = torch.tensor(rep(g), dtype=x.dtype, device=x.device) return torch.einsum("ij,...j->...i", mat, x) device = self.pos_emb.device dtype = self.pos_emb.dtype for _ in range(min(10, G.order())): g = random.choice(list(G.elements[1:])) # skip identity X = torch.randn(batch_size, in_len, self.in_rep.size, device=device, dtype=dtype) Z = torch.randn(batch_size, cond_len, self.cond_rep.size, device=device, dtype=dtype) k = torch.randn(batch_size, device=device, dtype=dtype) Y = self(X=X, Z=Z, opt_step=k) # Evaluate on symmetric points. gX = act(self.in_rep, g, X) gZ = act(self.cond_rep, g, Z) gY = self(X=gX, Z=gZ, opt_step=k) gY_expected = act(self.out_rep, g, Y) assert torch.allclose(gY, gY_expected, atol=atol, rtol=rtol), ( f"Equivariance test failed for group element {g}.\n" f"Max absolute difference: {torch.max(torch.abs(gY - gY_expected))}\n" ) if training_mode: self.train()
if __name__ == "__main__": from symm_learning.utils import describe_memory # Set logging to debug logging.basicConfig(level=logging.DEBUG) # Set debug message to [name][level][time][message] formatter = logging.Formatter("[%(name)s][%(levelname)s][%(asctime)s]: %(message)s") for handler in logging.getLogger().handlers: handler.setFormatter(formatter) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") dtype = torch.float32 from escnn.group import CyclicGroup, Icosahedral # G = Icosahedral() # # G = CyclicGroup(2) # d = 70 # # m = int(70 // G.order()) # m = 1 # in_rep = directsum([G.regular_representation] * m) # dim 8 # cond_rep = in_rep # out_rep = in_rep # Tx, Tz = 8, 6 # model = eCondTransformerRegressor( # in_rep=in_rep, # cond_rep=cond_rep, # out_rep=out_rep, # in_horizon=Tx, # cond_horizon=Tz, # num_layers=8, # num_attention_heads=1, # embedding_dim=G.order() * m, # num_cond_layers=0, # ).to(device=device, dtype=dtype) # model.check_equivariance() # G = Icosahedral() G = CyclicGroup(2) m = 5 embedding_dim = G.order() * m * 4 in_rep = direct_sum([G.regular_representation] * m) # dim 8 print(in_rep.size) cond_rep = in_rep out_rep = in_rep Tx, Tz = 8, 6 start_time = time.time() model = eCondTransformerRegressor( in_rep=in_rep, cond_rep=cond_rep, out_rep=out_rep, in_horizon=Tx, cond_horizon=Tz, num_layers=5, num_attention_heads=1, embedding_dim=embedding_dim, num_cond_layers=0, ) # print(describe_memory("eCondTransformer", model)) print(f"Model initialized in {time.time() - start_time:.2f} seconds") model.to(device=device, dtype=dtype) model.eval() model.check_equivariance(atol=1e-2, rtol=1e-2) print("Equivariance test passed!")