eConditionalUnet1D
- class eConditionalUnet1D(in_rep, local_cond_rep, global_cond_rep=None, diffusion_step_embed_dim=256, down_dims=(256, 512, 1024), kernel_size=3, cond_predict_scale=True, activation=ReLU(), normalize=True, downsample='stride', init_scheme='xavier_uniform')
Bases: eModule
Equivariant U-Net for 1D signals with local conditioning and FiLM-based global conditioning.
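FiLM (feature-wise linear modulation) conditions a feature map on a vector by predicting a per-channel scale and bias. The NumPy sketch below shows the standard, non-equivariant recipe from the ConditionalUnet1D of Diffusion Policy, where `cond_predict_scale=True` means the conditioning encoder predicts both a scale and a bias and `False` means a bias only. We assume the flag of the same name here has the same meaning; the sketch does not show how eConditionalUnet1D constrains the scale and bias so that the modulation commutes with the group action, and the helper name `film` is hypothetical.

```python
import numpy as np


def film(h: np.ndarray, cond_embed: np.ndarray, cond_predict_scale: bool = True) -> np.ndarray:
    """Plain (non-equivariant) FiLM on features h of shape (B, C, L).

    If cond_predict_scale is True, cond_embed has shape (B, 2 * C): the first C
    entries are per-channel scales, the last C are per-channel biases.
    Otherwise cond_embed has shape (B, C) and is used as a bias only.
    """
    _, channels, _ = h.shape
    if cond_predict_scale:
        scale = cond_embed[:, :channels, None]  # (B, C, 1), broadcast over L
        bias = cond_embed[:, channels:, None]   # (B, C, 1)
        return scale * h + bias
    return h + cond_embed[:, :, None]


rng = np.random.default_rng(0)
h = rng.standard_normal((2, 4, 8))        # (B, C, L) features inside a U-Net block
cond_embed = rng.standard_normal((2, 8))  # (B, 2 * C) projected conditioning vector
out = film(h, cond_embed)
print(out.shape)  # (2, 4, 8)
```

Because the scale and bias are shared across all L positions, FiLM injects a global signal without touching the temporal structure of the features.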
The model defines:
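\[\mathbf{f}_{\mathbf{\theta}}: \mathcal{X}^{L} \times \mathcal{Z}_{\mathrm{local}}^{L} \times \mathcal{Z}_{\mathrm{global}} \times \mathbb{R} \to \mathcal{X}^{L}.\]
It satisfies the functional equivariance constraint
\[\mathbf{f}_{\mathbf{\theta}}\big(\rho_{\mathcal{X}}(g)\mathbf{x},\,\rho_{\mathcal{Z}_{\mathrm{local}}}(g)\mathbf{z}_{\mathrm{local}},\,\rho_{\mathcal{Z}_{\mathrm{global}}}(g)\mathbf{z}_{\mathrm{global}},\,t\big) = \rho_{\mathcal{X}}(g)\,\mathbf{f}_{\mathbf{\theta}}(\mathbf{x},\mathbf{z}_{\mathrm{local}},\mathbf{z}_{\mathrm{global}},t), \quad \forall g\in\mathbb{G},\]
where \(\rho_{\mathcal{X}}\), \(\rho_{\mathcal{Z}_{\mathrm{local}}}\), and \(\rho_{\mathcal{Z}_{\mathrm{global}}}\) are given by in_rep, local_cond_rep, and global_cond_rep. The first two act on the channel axis at each of the \(L\) positions, the third acts on the global conditioning vector, and the diffusion step \(t\) is left unchanged.
- Parameters: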
- check_equivariance(batch_size=3, length=5, atol=1e-05, rtol=1e-05)
Check equivariance under channel actions of the underlying fiber group.
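The exact procedure of this method is not documented here. The NumPy sketch below illustrates the kind of test the equivariance constraint suggests: draw random inputs, transform each one by \(\rho(g)\) on its channel axis (leaving the length axis and \(t\) untouched), and compare the result with \(\rho(g)\) applied to the original output within `atol` and `rtol`. The group C2 acting by a channel swap, `toy_model`, and `check_equivariance_sketch` are hypothetical stand-ins, not the library's API; only the default arguments mirror the signature above.

```python
import numpy as np

# Regular representation of the cyclic group C2 = {e, g}: g swaps the two channels.
# Hypothetical: in_rep, local_cond_rep and global_cond_rep all use this representation.
REP = [np.eye(2), np.array([[0.0, 1.0], [1.0, 0.0]])]


def act(rho: np.ndarray, v: np.ndarray) -> np.ndarray:
    """Apply rho to the channel axis (axis 1) of v, shaped (B, C) or (B, C, L)."""
    return np.einsum("ij,bj...->bi...", rho, v)


def toy_model(x, z_local, z_global, t, W):
    """Hypothetical stand-in for f_theta: mix channels with W, add both conditionings.

    Equivariant exactly when W commutes with every rho(g); the conditioning terms
    are added in the same representation, so they commute automatically.
    """
    scale = 1.0 + 0.1 * t  # depends on t only, which the group leaves unchanged
    return scale * act(W, x) + z_local + z_global[:, :, None]


def check_equivariance_sketch(model, batch_size=3, length=5, atol=1e-5, rtol=1e-5, seed=0):
    """True if model(rho(g) x, rho(g) z_local, rho(g) z_global, t) ~= rho(g) model(x, ...) for all g."""
    rng = np.random.default_rng(seed)
    x = rng.standard_normal((batch_size, 2, length))
    z_local = rng.standard_normal((batch_size, 2, length))
    z_global = rng.standard_normal((batch_size, 2))
    t = 7.0
    y = model(x, z_local, z_global, t)
    for rho in REP:
        y_g = model(act(rho, x), act(rho, z_local), act(rho, z_global), t)
        if not np.allclose(y_g, act(rho, y), atol=atol, rtol=rtol):
            return False
    return True


W_equivariant = np.array([[2.0, 0.5], [0.5, 2.0]])  # commutes with the swap
W_broken = np.array([[1.0, 2.0], [0.0, 1.0]])       # does not commute with the swap
print(check_equivariance_sketch(lambda x, zl, zg, t: toy_model(x, zl, zg, t, W_equivariant)))  # True
print(check_equivariance_sketch(lambda x, zl, zg, t: toy_model(x, zl, zg, t, W_broken)))       # False
```

Checking every group element on random inputs catches a broken constraint with high probability, but it is a numerical test, not a proof.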
- forward(sample, timestep, local_cond=None, film_cond=None)
Run a forward pass of the equivariant U-Net.
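The diffusion step passed to the forward pass may be a scalar or a batch, and is broadcast to the batch size B. The NumPy sketch below shows that broadcasting step followed by a sinusoidal step embedding of width diffusion_step_embed_dim, the encoding used by the non-equivariant ConditionalUnet1D in Diffusion Policy. Whether eConditionalUnet1D embeds the step exactly this way is an assumption, and `broadcast_timestep` and `sinusoidal_embedding` are hypothetical helpers.

```python
import math

import numpy as np


def broadcast_timestep(timestep, batch_size: int) -> np.ndarray:
    """Turn a scalar or per-sample diffusion step into a float array of shape (B,)."""
    t = np.asarray(timestep, dtype=np.float64)
    if t.ndim == 0:
        t = t[None]  # scalar -> (1,)
    return np.broadcast_to(t, (batch_size,))  # (1,) -> (B,); (B,) is kept; other sizes raise


def sinusoidal_embedding(t: np.ndarray, dim: int = 256) -> np.ndarray:
    """Sinusoidal step embedding of shape (B, dim), assuming an even dim."""
    half = dim // 2
    freqs = np.exp(-math.log(10000.0) * np.arange(half) / (half - 1))
    args = t[:, None] * freqs[None, :]  # (B, half)
    return np.concatenate([np.sin(args), np.cos(args)], axis=-1)


t_scalar = broadcast_timestep(10, batch_size=4)           # same step for every sample
t_batch = broadcast_timestep([0, 5, 9, 3], batch_size=4)  # one step per sample
emb = sinusoidal_embedding(t_batch, dim=256)
print(t_scalar, emb.shape)  # [10. 10. 10. 10.] (4, 256)
```

Since the group leaves \(t\) unchanged, any embedding computed from \(t\) alone is invariant, so it can be fed into the conditioning path without transforming it.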
- Parameters:
  - sample (torch.Tensor) – Input signal shaped (B, in_rep.size, L).
  - timestep (torch.Tensor | float | int) – Diffusion step; a scalar or a batch, broadcast to B.
  - local_cond (torch.Tensor | None, optional) – Local conditioning signal shaped (B, local_cond_rep.size, L), when provided.
  - film_cond (torch.Tensor | None, optional) – Global conditioning vector shaped (B, global_cond_rep.size) used to drive FiLM.
- Returns:
Output tensor shaped (B, in_rep.size, L).
- Return type: