eConditionalUnet1D

class eConditionalUnet1D(in_rep, local_cond_rep, global_cond_rep=None, diffusion_step_embed_dim=256, down_dims=(256, 512, 1024), kernel_size=3, cond_predict_scale=True, activation=ReLU(), normalize=True, downsample='stride', init_scheme='xavier_uniform')

Bases: eModule

Equivariant U-Net for 1D signals, with optional local conditioning and global conditioning applied through FiLM (feature-wise linear modulation).

The model defines:

\[\mathbf{f}_{\mathbf{\theta}}: \mathcal{X}^{L} \times \mathcal{Z}_{\mathrm{local}}^{L} \times \mathcal{Z}_{\mathrm{global}} \times \mathbb{R} \to \mathcal{X}^{L}.\]

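Functional equivariance constraint:

\[\mathbf{f}_{\mathbf{\theta}}\big(\rho_{\mathcal{X}}(g)\mathbf{x},\,\rho_{\mathcal{Z}_{\mathrm{local}}}(g)\mathbf{z}_{\mathrm{local}},\,\rho_{\mathcal{Z}_{\mathrm{global}}}(g)\mathbf{z}_{\mathrm{global}},\,t\big) = \rho_{\mathcal{X}}(g)\,\mathbf{f}_{\mathbf{\theta}}(\mathbf{x},\mathbf{z}_{\mathrm{local}},\mathbf{z}_{\mathrm{global}},t), \quad \forall g\in\mathbb{G}.\]

Each \(\rho(g)\) acts on the channels only, identically at every position along the length \(L\), and the diffusion step \(t\) is left unchanged.
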
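As a toy illustration of the constraint (not the eConditionalUnet1D implementation), the sketch below takes 𝔾 = C2 acting on two-channel signals by swapping the channels, and builds a hypothetical map f from 'same'-padded 1D convolutions whose per-tap 2×2 weight matrices have the form [[a, b], [b, a]], so that they commute with the swap. The diffusion step enters as an offset that is identical on both channels, hence invariant. All names are illustrative, plain Python lists stand in for tensors of shape (channels, L), and global conditioning is omitted for brevity.

```python
# Toy instance of f(rho_X(g) x, rho_Z(g) z, t) = rho_X(g) f(x, z, t) for G = C2,
# where the non-trivial element swaps the two channels. Hypothetical, illustrative only.

SWAP = [[0.0, 1.0], [1.0, 0.0]]  # rho(g) for the non-trivial element of C2


def act(rho, x):
    """Apply the channel matrix rho at every position of a (channels, L) signal."""
    length = len(x[0])
    return [[sum(rho[i][j] * x[j][p] for j in range(len(x))) for p in range(length)]
            for i in range(len(rho))]


def conv1d(weights, x):
    """'Same'-padded 1D cross-correlation; weights[k] is the (C_out x C_in) matrix of tap k."""
    taps, length, pad = len(weights), len(x[0]), len(weights) // 2
    c_out = len(weights[0])
    out = [[0.0] * length for _ in range(c_out)]
    for p in range(length):
        for k in range(taps):
            src = p + k - pad
            if 0 <= src < length:
                for i in range(c_out):
                    out[i][p] += sum(weights[k][i][j] * x[j][src] for j in range(len(x)))
    return out


# Every tap has the form [[a, b], [b, a]], which commutes with SWAP.
W_X = [[[0.5, -0.2], [-0.2, 0.5]], [[1.0, 0.3], [0.3, 1.0]], [[-0.4, 0.1], [0.1, -0.4]]]
W_Z = [[[0.2, 0.7], [0.7, 0.2]], [[0.0, -1.0], [-1.0, 0.0]], [[0.6, 0.6], [0.6, 0.6]]]


def f(x, z, t):
    """Equivariant toy map: conv(x) + conv(z) + t on every channel (t is invariant)."""
    hx, hz = conv1d(W_X, x), conv1d(W_Z, z)
    return [[hx[i][p] + hz[i][p] + t for p in range(len(x[0]))] for i in range(len(hx))]


x = [[1.0, 2.0, -1.0, 0.5], [0.0, 3.0, 2.0, -2.0]]  # sample, shape (2, L=4)
z = [[0.3, -0.7, 1.1, 0.0], [2.0, 0.4, -0.5, 1.5]]  # local condition, shape (2, 4)
t = 7.0

lhs = f(act(SWAP, x), act(SWAP, z), t)
rhs = act(SWAP, f(x, z, t))
max_err = max(abs(a - b) for row_l, row_r in zip(lhs, rhs) for a, b in zip(row_l, row_r))
print(f"max |lhs - rhs| = {max_err:.2e}")
```

Replacing any tap with a matrix that does not commute with SWAP, such as [[1, 0], [0, 0]], makes the two sides of the constraint differ.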
Parameters:
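  • in_rep (Representation) – Representation acting on the channels of the input and output signals; in_rep.size is the number of channels.

  • local_cond_rep (Representation | None) – Representation acting on the channels of the local conditioning signal, or None if no local conditioning is used.

  • global_cond_rep (Representation | None) – Representation acting on the global (FiLM) conditioning vector, or None if no global conditioning is used.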

  • diffusion_step_embed_dim (int)

  • down_dims (Iterable[int])

  • kernel_size (int)

  • cond_predict_scale (bool)

  • activation (torch.nn.Module)

  • normalize (bool)

  • downsample (str)

  • init_scheme (str | None)
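FiLM modulates each feature channel with a scale and a bias predicted from the conditioning vector. The sketch below shows one common reading of cond_predict_scale, borrowed from the non-equivariant ConditionalUnet1D of Diffusion Policy: with True, the conditioning head predicts a per-channel scale and bias; with False, only a bias. Whether eConditionalUnet1D follows exactly this convention is an assumption, the function film is hypothetical, and it ignores any constraints needed to keep the modulation equivariant.

```python
def film(h, cond_out, predict_scale=True):
    """Feature-wise linear modulation of features h with shape (C, L).

    cond_out holds per-channel values predicted from the conditioning vector:
      predict_scale=True  -> 2*C values read as [scale_0..scale_{C-1}, bias_0..bias_{C-1}],
                             out[c] = scale[c] * h[c] + bias[c]
      predict_scale=False -> C values used as a bias only, out[c] = h[c] + bias[c]
    """
    channels = len(h)
    if predict_scale:
        if len(cond_out) != 2 * channels:
            raise ValueError("expected 2*C conditioning values when predicting a scale")
        scale, bias = cond_out[:channels], cond_out[channels:]
    else:
        if len(cond_out) != channels:
            raise ValueError("expected C conditioning values for a bias-only FiLM")
        scale, bias = [1.0] * channels, cond_out
    return [[scale[c] * v + bias[c] for v in h[c]] for c in range(channels)]


h = [[1.0, 2.0], [3.0, 4.0]]                     # C = 2 channels, L = 2
print(film(h, [2.0, 0.5, 1.0, -1.0]))             # [[3.0, 5.0], [0.5, 1.0]]
print(film(h, [1.0, -1.0], predict_scale=False))  # [[2.0, 3.0], [2.0, 3.0]]
```

Predicting the scale doubles the output size of the conditioning head but lets the condition rescale features multiplicatively rather than only shift them.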

check_equivariance(batch_size=3, length=5, atol=1e-05, rtol=1e-05)

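Numerically check the equivariance constraint under channel actions of the underlying fiber group, using a batch of batch_size signals of the given length and comparing both sides within the absolute and relative tolerances atol and rtol.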
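The method's implementation is not shown here. As a hedged sketch of what a test of this kind typically does, the pure-Python check_equivariance_sketch below (a hypothetical name) draws a random batch of batch_size signals of the given length, applies each group element to the sample and conditioning channels, and compares both sides of the constraint with the |a - b| <= atol + rtol * |b| criterion used by torch.allclose. Reading atol and rtol this way is an assumption; lists stand in for tensors, and global conditioning is omitted for brevity.

```python
import random


def act(rho, x):
    """Apply the channel matrix rho at every position of a (channels, length) signal."""
    return [[sum(rho[i][j] * x[j][p] for j in range(len(x))) for p in range(len(x[0]))]
            for i in range(len(rho))]


def flatten(batch):
    return [v for signal in batch for channel in signal for v in channel]


def allclose(a, b, atol, rtol):
    # Elementwise criterion of numpy.allclose / torch.allclose: |a - b| <= atol + rtol * |b|
    return all(abs(u - v) <= atol + rtol * abs(v) for u, v in zip(a, b))


def check_equivariance_sketch(f, group_actions, x_channels, z_channels,
                              batch_size=3, length=5, atol=1e-5, rtol=1e-5, seed=0):
    """Hypothetical checker; group_actions lists one (rho_x, rho_z) pair per group element."""
    rng = random.Random(seed)

    def random_batch(channels):
        return [[[rng.gauss(0.0, 1.0) for _ in range(length)] for _ in range(channels)]
                for _ in range(batch_size)]

    x, z = random_batch(x_channels), random_batch(z_channels)
    t = rng.randint(0, 99)  # diffusion step; left untouched by the group action
    for rho_x, rho_z in group_actions:
        lhs = [f(act(rho_x, xb), act(rho_z, zb), t) for xb, zb in zip(x, z)]
        rhs = [act(rho_x, f(xb, zb, t)) for xb, zb in zip(x, z)]
        if not allclose(flatten(lhs), flatten(rhs), atol, rtol):
            return False
    return True


IDENTITY = [[1.0, 0.0], [0.0, 1.0]]
SWAP = [[0.0, 1.0], [1.0, 0.0]]
C2 = [(IDENTITY, IDENTITY), (SWAP, SWAP)]


def f_equivariant(x, z, t):
    # ReLU(x_c + 0.5 z_c) plus a term built from the swap-invariant sum x_0 + x_1.
    return [[max(0.0, x[c][p] + 0.5 * z[c][p]) + 0.01 * t * (x[0][p] + x[1][p])
             for p in range(len(x[0]))] for c in range(2)]


def f_broken(x, z, t):
    # Copies channel 0 into both outputs, which does not commute with the swap.
    return [list(x[0]), list(x[0])]


print(check_equivariance_sketch(f_equivariant, C2, 2, 2))  # True
print(check_equivariance_sketch(f_broken, C2, 2, 2))       # False
```

Checking every group element is feasible for small finite groups; for large groups a check over a set of generators suffices, since the constraint is preserved under composition.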

forward(sample, timestep, local_cond=None, film_cond=None)

Run a forward pass of the equivariant U-Net.

Parameters:
  • sample (Tensor) – Input signal shaped (B, in_rep.size, L).

  • timestep (torch.Tensor | float | int) – Diffusion step, given either as a scalar or as one value per sample; broadcast to the batch size B.

  • local_cond (torch.Tensor | None, optional) – Local conditioning signal shaped (B, local_cond_rep.size, L) when provided.

  • film_cond (torch.Tensor | None, optional) – Global conditioning vector shaped (B, global_cond_rep.size) to drive FiLM.

Returns:

Output tensor shaped (B, in_rep.size, L).

Return type:

Tensor
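The forward signature accepts the diffusion step as a scalar or as one value per sample, broadcast to B. The sketch below shows that broadcasting rule with a hypothetical helper broadcast_timestep, followed by the sinusoidal step embedding commonly used in diffusion U-Nets (including Diffusion Policy's ConditionalUnet1D) to map each step to a vector of size diffusion_step_embed_dim. That eConditionalUnet1D uses this particular embedding is an assumption, and lists stand in for tensors. Since t is left unchanged by the group action, an embedding computed from t alone is invariant too.

```python
import math


def broadcast_timestep(timestep, batch_size):
    """Hypothetical helper: return one float diffusion step per sample (length batch_size)."""
    if isinstance(timestep, (int, float)):
        return [float(timestep)] * batch_size
    steps = [float(s) for s in timestep]
    if len(steps) == 1:
        return steps * batch_size
    if len(steps) != batch_size:
        raise ValueError(f"expected 1 or {batch_size} timesteps, got {len(steps)}")
    return steps


def sinusoidal_embedding(t, dim):
    """Sinusoidal step embedding [sin(t * w_k)..., cos(t * w_k)...] of even size dim >= 4."""
    if dim < 4 or dim % 2:
        raise ValueError("dim must be an even integer >= 4")
    half = dim // 2
    step = math.log(10000.0) / (half - 1)
    freqs = [math.exp(-step * k) for k in range(half)]  # geometric frequencies from 1 to 1e-4
    return [math.sin(t * w) for w in freqs] + [math.cos(t * w) for w in freqs]


steps = broadcast_timestep(10, batch_size=3)                # scalar -> [10.0, 10.0, 10.0]
embeddings = [sinusoidal_embedding(t, 256) for t in steps]  # 256 = default diffusion_step_embed_dim
print(len(embeddings), len(embeddings[0]))                  # 3 256
```

Geometrically spaced frequencies let the embedding distinguish both nearby and distant steps before it is passed to the conditioning layers.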