GenCondRegressor
- class GenCondRegressor(in_dim, out_dim, cond_dim)
Generative Conditional Regressor module.
This is an abstract module intended to be used as the backbone of a conditional flow-matching or diffusion process, enabling sampling from the conditional probability distribution:
\[\mathbb{P}(X \mid Z).\]
Let \(\mathcal{X}=\mathbb{R}^{d_x}\), \(\mathcal{Z}=\mathbb{R}^{d_z}\), and \(\mathcal{Y}=\mathbb{R}^{d_v}\). Here \(X = [x_1,\ldots,x_{T_x}] \in \mathcal{X}^{T_x}\) is the input/data sample, a trajectory of \(T_x\) points, and \(Z = [z_1,\ldots,z_{T_z}] \in \mathcal{Z}^{T_z}\) is the conditioning/observation variable, composed of \(T_z\) points.
The module parameterizes a conditional vector-valued regression map
\[\mathbf{f}_{\mathbf{\theta}}: \mathcal{X}^{T_x} \times \mathcal{Z}^{T_z} \times \mathbb{R} \to \mathcal{Y}^{T_x},\]
with
\[V_k = \mathbf{f}_{\mathbf{\theta}}(X_k, Z, k),\]
where \(k\) denotes the inference-time optimization step (i.e., the step of the flow-matching/diffusion process), \(X_k\) is the noisy version of the data sample at step \(k\), and \(V_k \in \mathcal{Y}^{T_x} = (\mathbb{R}^{d_v})^{T_x}\) is the vector-valued regression target. For diffusion models, \(V_k\) typically corresponds to the score function \(\nabla_{X_k} \log \mathbb{P}_k(X_k \mid Z)\), while for flow-matching models it typically corresponds to the flow-matching velocity vector field.
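As a hedged illustration of what the regression target \(V_k\) can look like during training, the sketch below builds targets for two common choices, assuming PyTorch tensors of shape (B, T_x, d_x) and d_v = d_x. The straight-line (rectified-flow style) path from Gaussian noise at k = 0 to data at k = 1, and the variance-exploding perturbation X_k = X + sigma_k * eps used for denoising score matching, are assumptions chosen for illustration. The function names are hypothetical and not part of this API.

```python
import torch


def flow_matching_target(x_data: torch.Tensor, k: torch.Tensor):
    """Hypothetical helper: straight path from noise (k=0) to data (k=1).

    x_data: (B, T_x, d_x); k: (B,) with values in [0, 1].
    Returns the interpolated sample X_k and the velocity target V_k.
    """
    noise = torch.randn_like(x_data)
    k_ = k.view(-1, 1, 1)  # broadcast the per-sample step over (T_x, d_x)
    x_k = (1.0 - k_) * noise + k_ * x_data
    v_k = x_data - noise  # dX_k/dk along the straight path
    return x_k, v_k


def score_matching_target(x_data: torch.Tensor, sigma: torch.Tensor):
    """Hypothetical helper: variance-exploding perturbation X_k = X + sigma_k * eps.

    The score of the Gaussian kernel N(X_k; X, sigma_k^2 I) is
    -(X_k - X) / sigma_k^2 = -eps / sigma_k (denoising score-matching target).
    """
    eps = torch.randn_like(x_data)
    s = sigma.view(-1, 1, 1)  # sigma: (B,) noise level per sample
    x_k = x_data + s * eps
    score = -eps / s
    return x_k, score


# Usage sketch (model and Z are assumed to exist):
#   x_k, v_k = flow_matching_target(x, k)
#   loss = ((model(x_k, k, Z) - v_k) ** 2).mean()
```

In both cases the network only ever sees (X_k, Z, k), matching the signature of \(\mathbf{f}_{\mathbf{\theta}}\); the choice of path or noise schedule lives entirely in how the targets are built.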
This abstract base class does not impose equivariance/invariance constraints by itself.
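The following is a minimal sketch, assuming PyTorch, of how a concrete subclass might implement `forward` and how a flow-matching regressor could then be sampled with explicit Euler steps. `GenCondRegressorBase`, `MLPRegressor` and `euler_sample` are hypothetical names: the stand-in base class only mirrors the documented signature and is not the library's class. The MLP, the mean-pooling over Z, the k in [0, 1] convention, and the interpretation of in_dim/out_dim/cond_dim as d_x/d_v/d_z are illustrative assumptions; the sampler additionally assumes out_dim == in_dim (d_v = d_x).

```python
import abc
from typing import Union

import torch
from torch import nn


class GenCondRegressorBase(nn.Module, abc.ABC):
    """Hypothetical stand-in mirroring the documented interface (not the library class)."""

    def __init__(self, in_dim: int, out_dim: int, cond_dim: int):
        super().__init__()
        # Assumed meaning: in_dim = d_x, out_dim = d_v, cond_dim = d_z.
        self.in_dim, self.out_dim, self.cond_dim = in_dim, out_dim, cond_dim

    @abc.abstractmethod
    def forward(self, X: torch.Tensor, opt_step: Union[torch.Tensor, float, int],
                Z: torch.Tensor) -> torch.Tensor:
        ...


class MLPRegressor(GenCondRegressorBase):
    """Toy, non-equivariant regressor: each point of X sees the pooled Z and the step k."""

    def __init__(self, in_dim: int, out_dim: int, cond_dim: int, hidden: int = 64):
        super().__init__(in_dim, out_dim, cond_dim)
        self.net = nn.Sequential(
            nn.Linear(in_dim + cond_dim + 1, hidden),
            nn.SiLU(),
            nn.Linear(hidden, out_dim),
        )

    def forward(self, X, opt_step, Z):
        B, T_x, _ = X.shape
        # Accept a Python scalar, a 0-d tensor, or a (B,) tensor for the step k.
        k = torch.as_tensor(opt_step, dtype=X.dtype, device=X.device)
        k = k.expand(B) if k.dim() == 0 else k
        k = k.reshape(B, 1, 1).expand(B, T_x, 1)
        # Mean-pool the T_z observation points, then repeat for every point of X.
        z = Z.mean(dim=1, keepdim=True).expand(B, T_x, self.cond_dim)
        return self.net(torch.cat([X, z, k], dim=-1))  # (B, T_x, d_v)


@torch.no_grad()
def euler_sample(model: GenCondRegressorBase, Z: torch.Tensor, T_x: int,
                 n_steps: int = 10) -> torch.Tensor:
    """Integrate dX/dk = f(X_k, Z, k) from Gaussian noise at k=0 to k=1."""
    B = Z.shape[0]
    X = torch.randn(B, T_x, model.in_dim, device=Z.device)
    dk = 1.0 / n_steps
    for i in range(n_steps):
        X = X + dk * model(X, i * dk, Z)  # one explicit Euler step along the velocity field
    return X
```

Like the base class, this toy regressor imposes no equivariance constraint; mean-pooling Z only makes it invariant to permutations of the observation points.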
- abstractmethod forward(X, opt_step, Z)
Forward pass of the generative conditional regressor.
- Parameters:
X (Tensor) – The input/data sample, a trajectory of T_x points in a d_x-dimensional space. Shape: (B, T_x, d_x), where B is the batch size.
opt_step (Tensor | float | int) – The optimization step(s) k at which to evaluate the regressor. Either a single scalar or a tensor of shape (B,).
Z (Tensor) – The conditioning/observation variable, composed of T_z points in a d_z-dimensional space. Shape: (B, T_z, d_z), where B is the batch size.
- Returns:
The output regression variable of shape (B, T_x, d_v).
- Return type: