# Source code for symm_learning.models.diffusion.cond_unet1d

#  Code Taken from: https://github.com/real-stanford/diffusion_policy/blob/main/diffusion_policy/model/diffusion/conditional_unet1d.py

from __future__ import annotations

import logging
import math
from typing import Union

import torch
import torch.nn as nn


class UnsqueezeLast(nn.Module):
    """Add a trailing singleton dimension to the input tensor."""

    def forward(self, x):  # noqa: D102
        # Equivalent to x.unsqueeze(-1): (..., d) -> (..., d, 1).
        return x[..., None]


# Module-level logger; ConditionalUnet1D uses it to report its parameter count.
logger = logging.getLogger(__name__)


class ConditionalUnet1D(nn.Module):
    r"""A 1D/time U-Net for predicting conditional score/velocity fields.

    The model defines:

    .. math::
        \mathbf{f}_{\mathbf{\theta}}: \mathbb{R}^{d_x \times L} \times \mathbb{R}^{d_{\mathrm{local}}}
        \times \mathbb{R}^{d_{\mathrm{global}}} \times \mathbb{R} \to \mathbb{R}^{d_x \times L}.

    It can be used to parameterize score-like or velocity-like targets in conditional
    diffusion/flow pipelines, e.g. :math:`\nabla \log \mathbb{P}(Y\mid X)`. This baseline
    model is unconstrained with respect to group actions (no explicit equivariance/invariance
    constraints are imposed).

    The influence of x in the diffusion process is captured via `local` and `global`
    conditioning of the Unet architecture.

    Local conditioning:
        Provided a local conditioning encoder `z(x)`, the output of the encoder is
        concatenated to the input of the Unet architecture.
    Global conditioning:
        Provided a global conditioning vector `c = b(x)`, the output of the encoder is used
        to modulate the convolutional layers of the Unet architecture via Feature-Wise
        Linear Modulation (FiLM).

    Args:
        input_dim (:class:`int`): The dimension of the input data.
        local_cond_dim (:class:`int`, optional): Dimension of the local conditioning vector.
            Defaults to None.
        global_cond_dim (:class:`int`, optional): Dimension of the global conditioning
            vector. Defaults to None.
        diffusion_step_embed_dim (:class:`int`, optional): Dimension of the diffusion step
            embedding. Defaults to 16.
        down_dims (sequence of :class:`int`, optional): Channel dimensions of the
            downsampling path. Defaults to (32, 64).
        kernel_size (:class:`int`, optional): Size of the convolutional kernels.
            Defaults to 3.
        n_groups (:class:`int`, optional): Number of groups for GroupNorm. Defaults to 1.
        cond_predict_scale (:class:`bool`, optional): Whether FiLM predicts both scale and
            bias (True) or only bias (False). Defaults to True.
    """

    def __init__(
        self,
        input_dim,
        local_cond_dim: int | None = None,
        global_cond_dim: int | None = None,
        diffusion_step_embed_dim=16,
        down_dims=(32, 64),  # tuple default: avoids the mutable-default-argument pitfall
        kernel_size=3,
        n_groups=1,
        cond_predict_scale=True,
    ):
        super().__init__()
        # Local conditioning is concatenated channel-wise at the UNet input, so the
        # first stage sees input_dim (+ local_cond_dim) channels.
        effective_input_dim = input_dim
        if local_cond_dim is not None:
            effective_input_dim = input_dim + local_cond_dim
        all_dims = [effective_input_dim] + list(down_dims)

        # Small MLP on top of a sinusoidal embedding of the scalar diffusion step.
        diffusion_step_encoder = nn.Sequential(
            SinusoidalPosEmb(diffusion_step_embed_dim),
            nn.Linear(diffusion_step_embed_dim, diffusion_step_embed_dim * 4),
            nn.Mish(),
            nn.Linear(diffusion_step_embed_dim * 4, diffusion_step_embed_dim),
        )
        # FiLM conditioning vector = concat(global_cond, diffusion-step features).
        cond_dim = diffusion_step_embed_dim
        if global_cond_dim is not None:
            cond_dim += global_cond_dim

        # (in_channels, out_channels) pairs for each resolution level.
        in_out = list(zip(all_dims[:-1], all_dims[1:]))

        mid_dim = all_dims[-1]
        self.mid_modules = nn.ModuleList(
            [
                ConditionalResidualBlock1D(
                    mid_dim,
                    mid_dim,
                    cond_dim=cond_dim,
                    kernel_size=kernel_size,
                    n_groups=n_groups,
                    cond_predict_scale=cond_predict_scale,
                ),
                ConditionalResidualBlock1D(
                    mid_dim,
                    mid_dim,
                    cond_dim=cond_dim,
                    kernel_size=kernel_size,
                    n_groups=n_groups,
                    cond_predict_scale=cond_predict_scale,
                ),
            ]
        )

        down_modules = nn.ModuleList([])
        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (len(in_out) - 1)
            down_modules.append(
                nn.ModuleList(
                    [
                        ConditionalResidualBlock1D(
                            dim_in,
                            dim_out,
                            cond_dim=cond_dim,
                            kernel_size=kernel_size,
                            n_groups=n_groups,
                            cond_predict_scale=cond_predict_scale,
                        ),
                        ConditionalResidualBlock1D(
                            dim_out,
                            dim_out,
                            cond_dim=cond_dim,
                            kernel_size=kernel_size,
                            n_groups=n_groups,
                            cond_predict_scale=cond_predict_scale,
                        ),
                        # No temporal downsampling at the deepest level.
                        Downsample1d(dim_out) if not is_last else nn.Identity(),
                    ]
                )
            )

        up_modules = nn.ModuleList([])
        for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):
            is_last = ind >= (len(in_out) - 1)
            up_modules.append(
                nn.ModuleList(
                    [
                        # dim_out * 2: skip connection from the down path is concatenated.
                        ConditionalResidualBlock1D(
                            dim_out * 2,
                            dim_in,
                            cond_dim=cond_dim,
                            kernel_size=kernel_size,
                            n_groups=n_groups,
                            cond_predict_scale=cond_predict_scale,
                        ),
                        ConditionalResidualBlock1D(
                            dim_in,
                            dim_in,
                            cond_dim=cond_dim,
                            kernel_size=kernel_size,
                            n_groups=n_groups,
                            cond_predict_scale=cond_predict_scale,
                        ),
                        Upsample1d(dim_in) if not is_last else nn.Identity(),
                    ]
                )
            )

        # Project back from effective_input_dim to input_dim channels.
        final_conv = nn.Sequential(
            Conv1dBlock(effective_input_dim, effective_input_dim, kernel_size=kernel_size, n_groups=n_groups),
            nn.Conv1d(effective_input_dim, input_dim, 1),
        )

        self.input_dim = input_dim
        self.local_cond_dim = local_cond_dim
        self.diffusion_step_encoder = diffusion_step_encoder
        self.up_modules = up_modules
        self.down_modules = down_modules
        self.final_conv = final_conv

        logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))

    def forward(
        self,
        sample: torch.Tensor,
        timestep: Union[torch.Tensor, float, int],
        local_cond: torch.Tensor = None,
        film_cond: torch.Tensor = None,
        **kwargs,
    ):
        """Forward pass of the Conditional Unet 1D model.

        Args:
            sample (:class:`~torch.Tensor`): The input tensor of shape (B, input_dim, T).
            timestep (:class:`~torch.Tensor` | :class:`float` | :class:`int`): The diffusion timestep.
            local_cond (:class:`~torch.Tensor`, optional): The local conditioning tensor of shape
                (B, local_cond_dim). Required when the model was built with ``local_cond_dim``.
                Defaults to None.
            film_cond (:class:`~torch.Tensor`, optional): The global conditioning tensor of shape
                (B, global_cond_dim). Defaults to None.
            kwargs: Additional keyword arguments reserved for API compatibility.

        Returns:
            :class:`~torch.Tensor`: The output tensor of shape (B, input_dim, T).
        """
        # 1. Normalize the timestep to a (B,) long tensor on the sample's device.
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
        elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)
        # Broadcast to batch dimension in a way that's compatible with ONNX/Core ML.
        timesteps = timesteps.expand(sample.shape[0])

        film_diff_step_features = self.diffusion_step_encoder(timesteps)

        # 2. FiLM features: global conditioning (if any) + diffusion-step features.
        if film_cond is not None:
            film_features = torch.cat([film_cond, film_diff_step_features], dim=-1)
        else:
            film_features = film_diff_step_features

        # 3. Local conditioning: broadcast over time and concatenate channel-wise.
        if self.local_cond_dim is not None:
            assert local_cond is not None and local_cond.shape == (sample.shape[0], self.local_cond_dim), (
                f"local_cond does not match expected {(sample.shape[0], self.local_cond_dim)}"
            )
            # (B, local_cond_dim) -> (B, local_cond_dim, T)
            local_cond_expanded = local_cond.unsqueeze(-1).expand(-1, -1, sample.shape[-1])
            x = torch.cat([sample, local_cond_expanded], dim=1)
        else:
            x = sample

        # 4. Encoder path, storing skip connections.
        h = []
        for resnet, resnet2, downsample in self.down_modules:
            x = resnet(x, film_features)
            x = resnet2(x, film_features)
            h.append(x)
            x = downsample(x)

        for mid_module in self.mid_modules:
            x = mid_module(x, film_features)

        # 5. Decoder path, consuming skip connections in reverse order.
        for resnet, resnet2, upsample in self.up_modules:
            x = torch.cat((x, h.pop()), dim=1)
            x = resnet(x, film_features)
            x = resnet2(x, film_features)
            x = upsample(x)

        x = self.final_conv(x)
        return x
class SinusoidalPosEmb(nn.Module):
    """Sinusoidal positional embedding layer.

    Encodes a scalar input (e.g., a diffusion timestep) into a high-dimensional vector
    using sine and cosine functions of varying frequencies, as introduced in
    "Attention Is All You Need". The embedding is:

        emb(x, 2i)   = sin(x / 10000^(2i/dim))
        emb(x, 2i+1) = cos(x / 10000^(2i/dim))

    where `x` is the input scalar, `dim` is the embedding dimension and `i` the channel
    index. ``forward`` first computes the frequency term `1 / 10000^(2i/dim)` and then
    multiplies the input by these frequencies to form the sin/cos arguments.

    Args:
        dim (:class:`int`): The dimension of the embedding.
            NOTE(review): output has 2 * (dim // 2) channels, so odd `dim` yields
            dim - 1 channels — presumably `dim` is expected to be even; verify at call sites.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):  # noqa: D102
        # x is expected to be a 1-D tensor of shape (B,).
        device = x.device
        half_dim = self.dim // 2
        # max(..., 1) guards against division by zero when dim <= 2.
        emb = math.log(10000) / max((half_dim - 1), 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]  # (B, half_dim)
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)  # (B, 2 * half_dim)
        return emb


class Downsample1d(nn.Module):
    """Downsampling layer for 1D data (strided convolution, halves the length).

    Args:
        dim (:class:`int`): The number of input and output channels.
    """

    def __init__(self, dim):
        super().__init__()
        self.conv = nn.Conv1d(in_channels=dim, out_channels=dim, kernel_size=3, stride=2, padding=1)

    def forward(self, x):  # noqa: D102
        return self.conv(x)


class Upsample1d(nn.Module):
    """Upsampling layer for 1D data (transposed convolution, doubles the length).

    Args:
        dim (:class:`int`): The number of input and output channels.
    """

    def __init__(self, dim):
        super().__init__()
        self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1)

    def forward(self, x):  # noqa: D102
        return self.conv(x)


class Conv1dBlock(nn.Module):
    """A 1D convolutional block: Conv1d -> GroupNorm -> Mish.

    Args:
        inp_channels (:class:`int`): The number of input channels.
        out_channels (:class:`int`): The number of output channels. Must be divisible
            by ``n_groups`` (GroupNorm requirement).
        kernel_size (:class:`int`): The size of the convolutional kernel.
        n_groups (int, optional): The number of groups for GroupNorm. Defaults to 8.
    """

    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
        super().__init__()
        self.block = nn.Sequential(
            # "same" padding for odd kernel sizes: preserves the temporal length.
            nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
            nn.GroupNorm(n_groups, out_channels),
            nn.Mish(),
        )

    def forward(self, x):  # noqa: D102
        return self.block(x)


class ConditionalResidualBlock1D(nn.Module):
    """Conditional residual block for 1D convolution with FiLM modulation.

    Applies two 1D convolutional blocks with conditional modulation between them.
    Conditioning is applied via Feature-wise Linear Modulation (FiLM,
    https://arxiv.org/abs/1709.07871), which predicts either per-channel scale and
    bias or only bias, depending on ``cond_predict_scale``.

    Args:
        in_channels (:class:`int`): Number of input channels.
        out_channels (:class:`int`): Number of output channels.
        cond_dim (:class:`int`): Dimension of the conditioning vector.
        kernel_size (:class:`int`, optional): Size of the convolutional kernel. Defaults to 3.
        n_groups (:class:`int`, optional): Number of groups for GroupNorm. Defaults to 8.
        cond_predict_scale (:class:`bool`, optional): If True, conditioning predicts both
            scale and bias. If False, conditioning only predicts bias. Defaults to False.
    """

    def __init__(self, in_channels, out_channels, cond_dim, kernel_size=3, n_groups=8, cond_predict_scale=False):
        super().__init__()
        self.blocks = nn.ModuleList(
            [
                Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups),
                Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups),
            ]
        )

        # FiLM head: predicts per-channel (scale, bias) or bias only.
        cond_channels = out_channels
        if cond_predict_scale:
            cond_channels = out_channels * 2
        self.cond_predict_scale = cond_predict_scale
        self.out_channels = out_channels
        self.cond_encoder = nn.Sequential(nn.Linear(cond_dim, cond_channels), nn.Mish(), UnsqueezeLast())

        # 1x1 conv so the residual path matches out_channels when needed.
        self.residual_conv = nn.Conv1d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()

    def forward(self, x, cond):
        """Forward pass of the conditional residual block.

        Args:
            x (:class:`~torch.Tensor`): Input tensor of shape (batch_size, in_channels, horizon).
            cond (:class:`~torch.Tensor`): Conditioning vector of shape (batch_size, cond_dim).

        Returns:
            :class:`~torch.Tensor`: Output tensor of shape (batch_size, out_channels, horizon).
        """
        out = self.blocks[0](x)
        embed = self.cond_encoder(cond)  # (B, cond_channels, 1)
        if self.cond_predict_scale:
            # Split the FiLM prediction into per-channel scale and bias.
            embed = embed.reshape(embed.shape[0], 2, self.out_channels, 1)
            scale = embed[:, 0, ...]
            bias = embed[:, 1, ...]
            out = scale * out + bias
        else:
            out = out + embed
        out = self.blocks[1](out)
        out = out + self.residual_conv(x)
        return out


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Test parameters
    batch_size = 512
    input_dim = 99
    local_cond_dim = 10
    film_cond_dim = 5
    time_steps = 200

    # Create test inputs
    sample = torch.randn(batch_size, input_dim, time_steps, device=device, requires_grad=True)
    timestep = torch.randint(0, 1000, (batch_size,), device=device)
    local_cond = torch.randn(batch_size, local_cond_dim, device=device)
    film_cond = torch.randn(batch_size, film_cond_dim, device=device)

    print("\n" + "=" * 60)
    print("TEST 1: Model with LOCAL + GLOBAL conditioning")
    print("=" * 60)

    # Create model (the original script built an identical model twice; once suffices).
    model = ConditionalUnet1D(
        input_dim=input_dim,
        local_cond_dim=local_cond_dim,
        global_cond_dim=film_cond_dim,
        diffusion_step_embed_dim=16,
        down_dims=[32, 64],
        kernel_size=3,
        n_groups=1,
        cond_predict_scale=True,
    ).to(device)

    # Forward pass
    output = model(sample=sample, timestep=timestep, local_cond=local_cond, film_cond=film_cond)
    print(f"Output shape: {output.shape}")