System Implementation

Architecture, training loop, and inference pipeline.

Architecture Overview

Model Core

Standard UNet architecture modified for vector fields (2-channel input/output). Uses sinusoidal positional embeddings and residual blocks.
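
The sinusoidal timestep embedding can be sketched as follows (a minimal NumPy version; the dimension argument `dim` and the frequency base 10000 follow the standard transformer-style scheme, and the function name is illustrative):

```python
import numpy as np

def timestep_embedding(t, dim):
    """Embed a scalar timestep t as a dim-vector of sin/cos features
    at geometrically spaced frequencies."""
    half = dim // 2
    freqs = np.exp(-np.log(10000.0) * np.arange(half) / half)
    args = t * freqs
    return np.concatenate([np.sin(args), np.cos(args)])

emb = timestep_embedding(25, 128)  # one 128-d embedding for t = 25
```

The embedding is concatenated to (or added into) the residual blocks' activations so every layer knows the current noise level.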

Scheduler

Denoising Diffusion Probabilistic Models (DDPM) scheduler for noise addition and removal steps.
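
A minimal sketch of the quantities the scheduler maintains, assuming a linear beta schedule (the endpoints 1e-4 and 0.02 follow the original DDPM paper; variable names are illustrative):

```python
import numpy as np

T = 1000
betas = np.linspace(1e-4, 0.02, T)   # linear noise schedule
alphas = 1.0 - betas
alpha_bar = np.cumprod(alphas)       # \bar{alpha}_t, strictly decreasing in t

def add_noise(x0, eps, t):
    """Forward diffusion q(x_t | x_0) in closed form."""
    return np.sqrt(alpha_bar[t]) * x0 + np.sqrt(1.0 - alpha_bar[t]) * eps
```

The reverse (denoising) step uses the same `alpha_bar` table to convert a predicted noise into the previous latent.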

Data Pipeline

Handles HDF5 datasets containing vector fields. Supports masking for inpainting tasks during training and inference.
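
Mask generation can be sketched as below (a rectangular-hole mask is one plausible choice; the shapes and names are illustrative, and the HDF5 reading via h5py is omitted):

```python
import numpy as np

def random_rect_mask(h, w, rng):
    """Binary mask: 1 where the field is known, 0 inside a random
    rectangular hole to be inpainted."""
    mask = np.ones((h, w), dtype=np.float32)
    hh = rng.integers(h // 4, h // 2)       # hole height
    ww = rng.integers(w // 4, w // 2)       # hole width
    top = rng.integers(0, h - hh)
    left = rng.integers(0, w - ww)
    mask[top:top + hh, left:left + ww] = 0.0
    return mask

rng = np.random.default_rng(0)
m = random_rect_mask(64, 64, rng)
```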

Training Pipeline

training.py
  1. Data Preparation

     Load a batch of clean vector fields $x_0$ and create random binary masks $m$.

  2. Forward Diffusion

     Sample a random timestep $t$ and noise $\epsilon$. Add noise to the field according to the schedule:

     $$x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1-\bar{\alpha}_t}\, \epsilon$$
  3. Prediction & Loss

     The model predicts the noise $\epsilon_\theta(x_t, t, m)$. Compute the MSE loss between the predicted and actual noise.

     Note: A physics loss can also be added here by predicting $\hat{x}_0$ and penalizing its divergence, though this is typically handled at inference time.
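
The three steps above can be condensed into a single training iteration. A NumPy sketch with a stand-in `model` (in the real training.py the model is the UNet and the loss drives an optimizer step):

```python
import numpy as np

rng = np.random.default_rng(0)
T = 1000
alpha_bar = np.cumprod(1.0 - np.linspace(1e-4, 0.02, T))

def training_step(x0, mask, model):
    """One DDPM training step: noise x0 at a random t, predict the
    noise, return the MSE between predicted and actual noise."""
    t = rng.integers(0, T)
    eps = rng.standard_normal(x0.shape)
    x_t = np.sqrt(alpha_bar[t]) * x0 + np.sqrt(1.0 - alpha_bar[t]) * eps
    eps_pred = model(x_t, t, mask)
    return np.mean((eps_pred - eps) ** 2)

# Stand-in model that always predicts zero noise, so the loss is > 0:
loss = training_step(np.zeros((2, 8, 8)), np.ones((2, 8, 8)),
                     lambda x_t, t, m: np.zeros_like(x_t))
```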

Inference Pipeline

inference.py
def run_inference(model, mask, x_known, steps=50):
    # 1. Initialize noisy latent (same shape as the field)
    x_t = torch.randn_like(x_known)

    # 2. Iterative denoising loop
    for t in reversed(range(steps)):

        # A. Predict noise
        noise_pred = model(x_t, t, mask)

        # B. Classifier-free guidance (if enabled)
        # noise_pred = w * cond + (1 - w) * uncond

        # C. Physics guidance (the "autograd" trick)
        if use_physics:
            # Enable grad on x_t (detach first so it is a leaf tensor)
            x_t = x_t.detach().requires_grad_(True)

            with torch.enable_grad():
                # Estimate x0 (Tweedie); noise_pred is treated as a constant,
                # so the gradient flows through the x_t term only
                x0_hat = tweedie_formula(x_t, noise_pred, t)

                # Physics loss: scalar penalties on divergence and curl residuals
                loss = divergence(x0_hat) + curl(x0_hat)

            # Gradient step on x_t (torch.autograd.grad returns a tuple)
            grad, = torch.autograd.grad(loss, x_t)
            x_t = (x_t - eta * grad).detach()

        # D. Scheduler step
        x_t = scheduler.step(noise_pred, t, x_t)

        # E. Re-inject known data (if inpainting); for best results, x_known
        #    should itself be noised to level t (RePaint-style)
        x_t = mask * x_known + (1 - mask) * x_t

    return x_t
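
The `tweedie_formula` helper used in the physics-guidance branch estimates $\hat{x}_0$ by inverting the forward-diffusion equation at timestep $t$ (NumPy sketch, with the timestep argument written explicitly; names are illustrative):

```python
import numpy as np

alpha_bar = np.cumprod(1.0 - np.linspace(1e-4, 0.02, 1000))

def tweedie_formula(x_t, noise_pred, t):
    """x0_hat = (x_t - sqrt(1 - abar_t) * eps_pred) / sqrt(abar_t)."""
    return (x_t - np.sqrt(1.0 - alpha_bar[t]) * noise_pred) / np.sqrt(alpha_bar[t])

# Sanity check: given the true noise, the estimate recovers x0 exactly.
rng = np.random.default_rng(0)
x0 = rng.standard_normal((2, 8, 8))
eps = rng.standard_normal((2, 8, 8))
t = 500
x_t = np.sqrt(alpha_bar[t]) * x0 + np.sqrt(1.0 - alpha_bar[t]) * eps
x0_hat = tweedie_formula(x_t, eps, t)
```

At inference the model's predicted noise replaces `eps`, so `x0_hat` is only an estimate, which is exactly why it is useful as a differentiable target for the physics loss.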