# Source code for symm_learning.models.diffusion.cond_eunet1d

from __future__ import annotations

from math import ceil
from typing import Iterable

import torch
from escnn.group import Representation

from symm_learning.models.diffusion.cond_unet1d import SinusoidalPosEmb
from symm_learning.nn import IrrepSubspaceNormPooling, eAffine, eConv1d, eConvTranspose1d, eRMSNorm
from symm_learning.nn.module import eModule
from symm_learning.representation_theory import direct_sum


class _eChannelRMSNorm(eModule):
    """Equivariant RMS normalization over the channel axis of ``(B, C, L)`` tensors.

    :class:`eRMSNorm` normalizes the trailing dimension, so this wrapper moves
    the channel axis last before normalizing and restores the layout afterwards.
    """

    def __init__(self, rep: Representation, eps: float = 1e-6):
        super().__init__()
        self.in_rep = rep
        self.out_rep = rep
        self.norm = eRMSNorm(rep, eps=eps, equiv_affine=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # eRMSNorm acts on the last axis: (B, C, L) -> (B, L, C).
        channels_last = x.transpose(1, 2)
        normalized = self.norm(channels_last)
        # Restore the (B, C, L) layout expected by the surrounding convolutions.
        return normalized.transpose(1, 2)


class eConditionalResidualBlock1D(eModule):
    r"""Channel-equivariant conditional residual block with FiLM modulation.

    This block applies two equivariant convolutional layers with a residual
    connection. The output of the first convolutional block is modulated by a
    conditioning tensor using an equivariant FiLM (Feature-wise Linear Modulation)
    layer. The FiLM layer applies an equivariant affine transformation, where the
    scale and bias parameters are predicted by an invariant neural network from the
    conditioning tensor.

    The scale transformation is applied in the spectral domain, with a separate
    learnable parameter for each irreducible representation (irrep). The bias
    is applied only to the invariant (trivial) subspaces of the representation.
    See details in :class:`~symm_learning.nn.eAffine`.

    Functional constraint:

    .. math::
        f(\rho_{\mathcal{X}}(g)\mathbf{x},\,\rho_{\mathcal{C}}(g)\mathbf{c})
        = \rho_{\mathcal{Y}}(g)\,f(\mathbf{x},\mathbf{c}),
        \quad \forall g\in\mathbb{G}.
    """

    def __init__(
        self,
        in_rep: Representation,
        out_rep: Representation,
        cond_rep: Representation,
        kernel_size: int = 3,
        cond_predict_scale: bool = True,
        activation: torch.nn.Module = torch.nn.ReLU(),
        normalize: bool = True,
        init_scheme: str | None = "xavier_uniform",
    ):
        r"""Initialize the conditional residual block.

        Args:
            in_rep (:class:`~escnn.group.Representation`): Input representation :math:`\rho_{\text{in}}` acting on the
                channel axis.
            out_rep (:class:`~escnn.group.Representation`): Output representation :math:`\rho_{\text{out}}` for the
                convolutions and residual.
            cond_rep (:class:`~escnn.group.Representation`): Representation :math:`\rho_{\text{cond}}` of the
                conditioning vector used to predict FiLM parameters.
            kernel_size (int, optional): Spatial kernel size for the equivariant convolutions. Defaults to 3.
            cond_predict_scale (bool, optional): Whether the FiLM encoder predicts scale parameters.
                Currently must be ``True``. Defaults to True.
            activation (torch.nn.Module, optional): Non-linearity applied after each normalization. Defaults to ``ReLU``
            normalize (bool, optional): If ``True``, apply channel-wise equivariant RMS normalization. Defaults to True.
            init_scheme (str | None, optional): Weight initialization scheme for convolutions.
                Defaults to ``"xavier_uniform"``.
        """
        super().__init__()
        self.in_rep, self.out_rep, self.cond_rep = in_rep, out_rep, cond_rep
        if not cond_predict_scale:
            raise NotImplementedError("Currently only cond_predict_scale=True is supported.")

        # Same-length ("same" padding) convolutions so the sequence length L is preserved.
        self.conv1 = eConv1d(in_rep, out_rep, kernel_size, padding=kernel_size // 2, init_scheme=init_scheme)
        self.conv2 = eConv1d(out_rep, out_rep, kernel_size, padding=kernel_size // 2, init_scheme=init_scheme)

        self.norm1 = _eChannelRMSNorm(out_rep) if normalize else torch.nn.Identity()
        self.norm2 = _eChannelRMSNorm(out_rep) if normalize else torch.nn.Identity()
        self.act = activation

        # FiLM affine: non-learnable here, its DOF are supplied per-sample by cond_encoder.
        self.affine = eAffine(in_rep=out_rep, learnable=False)
        self.film_dims = self.affine.num_scale_dof + self.affine.num_bias_dof

        # Invariant encoder of FiLM modulation parameters: pooling over irrep subspaces
        # yields a G-invariant feature vector, so the predicted DOF are invariant.
        self.cond_encoder = torch.nn.Sequential(
            IrrepSubspaceNormPooling(in_rep=cond_rep),
            torch.nn.Linear(in_features=len(cond_rep.irreps), out_features=self.film_dims),
            torch.nn.Mish(),
        )

        # 1x1 conv to match channel representations for the residual addition;
        # identity when in_rep == out_rep (always a module, never None).
        self.residual_conv = (
            eConv1d(in_rep, out_rep, 1, init_scheme=init_scheme) if in_rep != out_rep else torch.nn.Identity()
        )

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        """Apply two equivariant conv blocks with FiLM modulation and a residual connection.

        Args:
            x (:class:`~torch.Tensor`): Input of shape ``(B, in_rep.size, L)`` where ``B`` is batch and ``L`` is
                sequence length.
            cond (:class:`~torch.Tensor`): Conditioning tensor of shape ``(B, cond_rep.size)`` used to predict FiLM
                parameters.

        Returns:
            :class:`~torch.Tensor`: Output tensor of shape ``(B, out_rep.size, L)`` after modulation and residual
                addition.
        """
        assert x.shape[1] == self.in_rep.size, f"Expected channel dim {self.in_rep.size}, got {x.shape}"
        assert cond.shape[-1] == self.cond_rep.size, f"Cond dim mismatch {cond.shape}"

        B, _, L = x.shape

        out = self.conv1(x)  # (B, C, L)
        out = self.act(self.norm1(out))

        if self.film_dims > 0:
            film_params = self.cond_encoder(cond)  # (B, film_dims)
            # Broadcast the per-sample FiLM DOF over the sequence dimension L.
            scale_dof = torch.broadcast_to(
                film_params[:, None, : self.affine.num_scale_dof], (B, L, self.affine.num_scale_dof)
            )  # (B, num_scale_dof) -> (B, L, num_scale_dof)
            bias_dof = torch.broadcast_to(
                film_params[:, None, self.affine.num_scale_dof :], (B, L, self.affine.num_bias_dof)
            )  # (B, num_bias_dof) -> (B, L, num_bias_dof)
            out = self.affine(
                out.permute(0, 2, 1),  # (B, C, L) -> (B, L, C)
                scale_dof=scale_dof,
                bias_dof=bias_dof,
            ).permute(0, 2, 1)  # (B, L, C) -> (B, C, L)

        out = self.conv2(out)
        out = self.act(self.norm2(out))

        # residual_conv is always a module (Identity or 1x1 eConv1d), so apply it directly.
        return out + self.residual_conv(x)

    @torch.no_grad()
    def check_equivariance(self, atol=1e-5, rtol=1e-5):  # noqa: D102
        G = self.in_rep.group
        B, L = 10, 30
        device, dtype = next(self.parameters()).device, next(self.parameters()).dtype
        x = torch.randn(B, self.in_rep.size, L, device=device, dtype=dtype)
        z = torch.randn(B, self.cond_rep.size, device=device, dtype=dtype)
        y = self(x, z)

        for _ in range(10):
            g = G.sample()
            rho_in = torch.tensor(self.in_rep(g), dtype=x.dtype, device=x.device)
            rho_out = torch.tensor(self.out_rep(g), dtype=y.dtype, device=y.device)
            rho_cond = torch.tensor(self.cond_rep(g), dtype=z.dtype, device=z.device)
            gx = torch.einsum("ij,bjl->bil", rho_in, x)
            gz = torch.einsum("ij,bj->bi", rho_cond, z)
            y_expected = self(gx, gz)
            gy = torch.einsum("ij,bjl->bil", rho_out, y)
            assert torch.allclose(gy, y_expected, atol=atol, rtol=rtol), (
                f"Equivariance failed for group element {g} with max error {(gy - y_expected).abs().max().item():.3e}"
            )


class eConditionalUnet1D(eModule):
    r"""Equivariant U-Net for 1D signals with global conditioning and FiLM.

    The model defines:

    .. math::
        \mathbf{f}_{\mathbf{\theta}}: \mathcal{X}^{L} \times \mathcal{Z}_{\mathrm{local}}^{L}
        \times \mathcal{Z}_{\mathrm{global}} \times \mathbb{R} \to \mathcal{X}^{L}.

    Functional equivariance constraint:

    .. math::
        \mathbf{f}_{\mathbf{\theta}}(\rho_{\mathcal{X}}(g)\mathbf{x},\,\rho_{\mathcal{Z}}(g)\mathbf{z},\,t)
        = \rho_{\mathcal{X}}(g)\,\mathbf{f}_{\mathbf{\theta}}(\mathbf{x},\mathbf{z},t),
        \quad \forall g\in\mathbb{G}.
    """

    def __init__(
        self,
        in_rep: Representation,
        local_cond_rep: Representation | None,
        global_cond_rep: Representation | None = None,
        diffusion_step_embed_dim: int = 256,
        down_dims: Iterable[int] = (256, 512, 1024),
        kernel_size: int = 3,
        cond_predict_scale: bool = True,
        activation: torch.nn.Module = torch.nn.ReLU(),
        normalize: bool = True,
        downsample: str = "stride",
        init_scheme: str | None = "xavier_uniform",
    ):
        r"""Initialize the equivariant conditional U-Net.

        Args:
            in_rep (:class:`~escnn.group.Representation`): Representation acting on the input/output channel axis.
            local_cond_rep (:class:`~escnn.group.Representation` | None): Representation of the per-timestep local
                conditioning signal, or ``None`` to disable the local-conditioning encoder.
            global_cond_rep (:class:`~escnn.group.Representation` | None, optional): Representation of the global
                conditioning vector concatenated with the diffusion-step embedding. Defaults to None.
            diffusion_step_embed_dim (int, optional): Dimension of the sinusoidal diffusion-step embedding
                (G-invariant, hence represented by trivial irreps). Defaults to 256.
            down_dims (Iterable[int], optional): Target channel dimensions per U-Net level; each is rounded up to a
                multiple of the regular representation size. Defaults to (256, 512, 1024).
            kernel_size (int, optional): Spatial kernel size for the residual blocks. Defaults to 3.
            cond_predict_scale (bool, optional): Whether FiLM predicts scale parameters. Defaults to True.
            activation (torch.nn.Module, optional): Non-linearity used throughout. Defaults to ``ReLU``.
            normalize (bool, optional): If ``True``, use channel-wise equivariant RMS normalization. Defaults to True.
            downsample (str, optional): ``"stride"`` (strided equivariant convs) or ``"pooling"``
                (max-pool / nearest-upsample). Defaults to ``"stride"``.
            init_scheme (str | None, optional): Weight initialization scheme for convolutions.
                Defaults to ``"xavier_uniform"``.
        """
        super().__init__()
        assert downsample in {"stride", "pooling"}, "downsample must be 'stride' or 'pooling'"
        self.in_rep = self.out_rep = in_rep
        self.global_cond_rep = global_cond_rep
        self.local_cond_rep = local_cond_rep
        self.downsample = downsample

        G = in_rep.group
        reg_rep = G.regular_representation
        trivial_rep = G.trivial_representation

        down_dims = list(down_dims)
        # Each level uses copies of the regular representation, rounded up to cover the requested dim.
        reps = [in_rep] + [direct_sum([reg_rep] * ceil(d / reg_rep.size)) for d in down_dims]

        diffusion_step_encoder = torch.nn.Sequential(
            SinusoidalPosEmb(diffusion_step_embed_dim),
            torch.nn.Linear(diffusion_step_embed_dim, diffusion_step_embed_dim * 4),
            torch.nn.Mish(),
            torch.nn.Linear(diffusion_step_embed_dim * 4, diffusion_step_embed_dim),
        )

        # The diffusion-step embedding is G-invariant, so it transforms by trivial irreps.
        if global_cond_rep is not None:
            cond_rep = direct_sum([global_cond_rep, direct_sum([trivial_rep] * diffusion_step_embed_dim)])
        else:
            cond_rep = direct_sum([trivial_rep] * diffusion_step_embed_dim)

        self.cond_rep = cond_rep
        self.diffusion_step_encoder = diffusion_step_encoder

        block_kwargs = dict(
            kernel_size=kernel_size,
            cond_predict_scale=cond_predict_scale,
            activation=activation,
            normalize=normalize,
            init_scheme=init_scheme,
        )
        conv_kwargs = dict(bias=True, init_scheme=init_scheme)
        in_out_reps = list(zip(reps[:-1], reps[1:]))

        # Optional encoder of the local (per-timestep) conditioning signal: one block
        # injected at the first down level, one at the last up level (see forward).
        local_cond_encoder = None
        if local_cond_rep is not None:
            first_out_rep = in_out_reps[0][1]
            local_cond_encoder = torch.nn.ModuleList(
                [
                    eConditionalResidualBlock1D(
                        in_rep=local_cond_rep,
                        out_rep=first_out_rep,
                        cond_rep=cond_rep,
                        **block_kwargs,
                    ),
                    eConditionalResidualBlock1D(
                        in_rep=local_cond_rep,
                        out_rep=first_out_rep,
                        cond_rep=cond_rep,
                        **block_kwargs,
                    ),
                ]
            )
        self.local_cond_encoder = local_cond_encoder

        mid_rep = reps[-1]
        self.mid_modules = torch.nn.ModuleList(
            [
                eConditionalResidualBlock1D(mid_rep, mid_rep, cond_rep=cond_rep, **block_kwargs),
                eConditionalResidualBlock1D(mid_rep, mid_rep, cond_rep=cond_rep, **block_kwargs),
            ]
        )

        # Encoder path: two residual blocks + downsampling op per level (identity at the deepest level).
        down_modules = torch.nn.ModuleList()
        for idx, (rep_in, rep_out) in enumerate(in_out_reps):
            is_last = idx == len(in_out_reps) - 1
            down_modules.append(
                torch.nn.ModuleList(
                    [
                        eConditionalResidualBlock1D(rep_in, rep_out, cond_rep=cond_rep, **block_kwargs),
                        eConditionalResidualBlock1D(rep_out, rep_out, cond_rep=cond_rep, **block_kwargs),
                        eConv1d(
                            rep_out,
                            rep_out,
                            kernel_size=3,
                            stride=2 if not is_last and downsample == "stride" else 1,
                            padding=1,
                            **conv_kwargs,
                        )
                        if (downsample == "stride" and not is_last)
                        else (
                            torch.nn.MaxPool1d(kernel_size=2, stride=2)
                            if downsample == "pooling" and not is_last
                            else torch.nn.Identity()
                        ),
                    ]
                )
            )
        self.down_modules = down_modules

        # Decoder path: mirrors the encoder, skipping the shallowest pair
        # (the first down level's skip is consumed by final_conv's input rep).
        up_modules = torch.nn.ModuleList()
        current_rep = in_out_reps[-1][1]  # deepest representation
        up_pairs = list(reversed(in_out_reps[1:]))  # skip shallowest pair
        for idx, (target_rep, skip_rep) in enumerate((pair[0], pair[1]) for pair in up_pairs):
            is_last = idx == len(up_pairs) - 1
            upsample_in_rep = direct_sum([current_rep, skip_rep])  # concat current + skip
            out_rep = target_rep
            up_modules.append(
                torch.nn.ModuleList(
                    [
                        eConditionalResidualBlock1D(upsample_in_rep, out_rep, cond_rep=cond_rep, **block_kwargs),
                        eConditionalResidualBlock1D(out_rep, out_rep, cond_rep=cond_rep, **block_kwargs),
                        eConvTranspose1d(
                            out_rep,
                            out_rep,
                            kernel_size=3,
                            stride=2 if not is_last and downsample == "stride" else 1,
                            padding=1,
                            **conv_kwargs,
                        )
                        if (downsample == "stride" and not is_last)
                        else (
                            torch.nn.Upsample(scale_factor=2, mode="nearest")
                            if downsample == "pooling" and not is_last
                            else torch.nn.Identity()
                        ),
                    ]
                )
            )
            current_rep = out_rep
        self.up_modules = up_modules

        first_rep = in_out_reps[0][1]
        final_norm = _eChannelRMSNorm(first_rep) if normalize else torch.nn.Identity()
        self.final_conv = torch.nn.Sequential(
            eConv1d(first_rep, first_rep, kernel_size=kernel_size, padding=kernel_size // 2, **conv_kwargs),
            final_norm,
            activation,
            eConv1d(first_rep, in_rep, kernel_size=1, padding=0, **conv_kwargs),
        )

    def forward(
        self,
        sample: torch.Tensor,
        timestep: torch.Tensor | float | int,
        local_cond: torch.Tensor | None = None,
        film_cond: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run a forward pass of the equivariant U-Net.

        Args:
            sample (:class:`~torch.Tensor`): Input signal shaped ``(B, in_rep.size, L)``.
            timestep (torch.Tensor | float | int): Diffusion step; scalar or batch, broadcast to ``B``.
            local_cond (torch.Tensor | None, optional): Local conditioning signal shaped
                ``(B, local_cond_rep.size, L)`` when provided.
            film_cond (torch.Tensor | None, optional): Global conditioning vector shaped
                ``(B, global_cond_rep.size)`` to drive FiLM.

        Returns:
            :class:`~torch.Tensor`: Output tensor shaped ``(B, in_rep.size, L)``.
        """
        assert sample.shape[1] == self.in_rep.size, f"Expected channels {self.in_rep.size}, got {sample.shape}"
        device = sample.device

        # Normalize timestep to a 1D tensor broadcast over the batch.
        t = timestep
        if not torch.is_tensor(t):
            t = torch.tensor([timestep], dtype=torch.long, device=device)
        elif t.ndim == 0:
            t = t[None].to(device)
        t = t.expand(sample.shape[0])

        film_diff = self.diffusion_step_encoder(t)
        # FiLM feature = [global cond | diffusion-step embedding], matching cond_rep's ordering.
        film_feature = torch.cat([film_cond, film_diff], dim=-1) if film_cond is not None else film_diff
        assert film_feature.shape[-1] == self.cond_rep.size

        h_local = []
        if self.local_cond_encoder is not None and local_cond is not None:
            resnet_a, resnet_b = self.local_cond_encoder
            x_loc = resnet_a(local_cond, film_feature)
            h_local.append(x_loc)
            x_loc = resnet_b(local_cond, film_feature)
            h_local.append(x_loc)

        x = sample
        skips = []
        for idx, (res1, res2, down) in enumerate(self.down_modules):
            x = res1(x, film_feature)
            if idx == 0 and h_local:
                x = x + h_local[0]  # inject first local-cond feature at the shallowest level
            x = res2(x, film_feature)
            skips.append(x)
            x = down(x)

        for mid in self.mid_modules:
            x = mid(x, film_feature)

        for idx, (res1, res2, up) in enumerate(self.up_modules):
            skip = skips.pop()
            x = torch.cat([x, skip], dim=1)
            x = res1(x, film_feature)
            if idx == len(self.up_modules) - 1 and h_local:
                x = x + h_local[1]  # inject second local-cond feature at the last up level
            x = res2(x, film_feature)
            x = up(x)

        x = self.final_conv(x)
        return x

    @torch.no_grad()
    def check_equivariance(self, batch_size=3, length=5, atol=1e-5, rtol=1e-5):  # noqa: D102
        """Check equivariance under channel actions of the underlying fiber group."""
        G = self.in_rep.group
        device = next(self.parameters()).device
        dtype = next(self.parameters()).dtype
        x = torch.randn(batch_size, self.in_rep.size, length, device=device, dtype=dtype)
        local_cond = (
            torch.randn(batch_size, self.local_cond_rep.size, length, device=device, dtype=dtype)
            if self.local_cond_rep is not None
            else None
        )
        global_cond = (
            torch.randn(batch_size, self.global_cond_rep.size, device=device, dtype=dtype)
            if self.global_cond_rep is not None
            else None
        )
        t = torch.tensor(0, dtype=torch.long, device=device)
        y = self(x, timestep=t, local_cond=local_cond, film_cond=global_cond)

        for _ in range(10):
            g = G.sample()
            rho_in = torch.tensor(self.in_rep(g), dtype=dtype, device=device)
            gx = torch.einsum("ij,bjl->bil", rho_in, x)
            g_local = None
            if local_cond is not None:
                rho_local = torch.tensor(self.local_cond_rep(g), dtype=dtype, device=device)
                g_local = torch.einsum("ij,bjl->bil", rho_local, local_cond)
            g_global = None
            if global_cond is not None:
                rho_global = torch.tensor(self.global_cond_rep(g), dtype=dtype, device=device)
                g_global = torch.einsum("ij,bj->bi", rho_global, global_cond)
            y_expected = self(gx, timestep=t, local_cond=g_local, film_cond=g_global)
            rho_out = torch.tensor(self.out_rep(g), dtype=y.dtype, device=y.device)
            gy = torch.einsum("ij,bjl->bil", rho_out, y)
            assert torch.allclose(gy, y_expected, atol=atol, rtol=rtol), (
                f"Equivariance failed for group element {g} with max error {(gy - y_expected).abs().max().item():.3e}"
            )
if __name__ == "__main__":
    # Example usage: sanity-check equivariance of both modules on a small cyclic group.
    from escnn.group import CyclicGroup

    diffusion_step_embed_dim = 32
    G = CyclicGroup(2)
    mx, my = 1, 2
    in_rep = direct_sum([G.regular_representation] * mx)
    out_rep = direct_sum([G.regular_representation] * my)
    cond_rep = direct_sum([G.regular_representation] * 2)

    print("Testing eConditionalResidualBlock1D ------------------------------------------")
    res_block = eConditionalResidualBlock1D(
        in_rep=in_rep,
        out_rep=out_rep,
        cond_rep=cond_rep,
        kernel_size=3,
        cond_predict_scale=True,
    )
    res_block.check_equivariance(atol=1e-5, rtol=1e-5)

    print("\nTesting eConditionalUnet1D ----------------------------------------------------")
    model = eConditionalUnet1D(
        in_rep=in_rep,
        local_cond_rep=None,
        global_cond_rep=cond_rep,
        diffusion_step_embed_dim=16,
        down_dims=[64, 128, 256],
        kernel_size=3,
        cond_predict_scale=True,
    )
    model.check_equivariance(atol=1e-5, rtol=1e-5)