Skip to content

Add Stable Diffusion 1.5 support with data-parallel inference#418

Open
csgoogle wants to merge 1 commit into
mainfrom
add-sd15-support
Open

Add Stable Diffusion 1.5 support with data-parallel inference#418
csgoogle wants to merge 1 commit into
mainfrom
add-sd15-support

Conversation

@csgoogle

@csgoogle csgoogle commented Jun 13, 2026

Copy link
Copy Markdown
Collaborator

Summary

Adds Stable Diffusion 1.5 to MaxDiffusion and makes its inference data parallel.

  • configs/base15.yml (new) — SD 1.5 config. Mirrors base14.yml (same architecture, different weights), points at the stable-diffusion-v1-5 checkpoint, sets from_pt: True (upstream ships PyTorch weights only), and defaults to the checkpoint's PNDM / epsilon scheduler.
  • generate.py
    • Build the sampler from the checkpoint's scheduler config via create_scheduler (config-driven, mirrors the SDXL path) instead of a hardcoded FlaxDDIMScheduler, and iterate the full PNDM schedule (skip_prk_steps emits one extra timestep).
    • Shard the latent batch over the data axis using with_sharding_constraint + a batch-sharded out_shardings, so GSPMD propagates data-parallelism through the whole UNet/VAE instead of replicating the entire batch on every device. A single get_batch_sharding helper is the source of truth and replicates for sub-device batches (per_device_batch_size < 1).
    • override_scheduler_config is now tolerant of scheduler configs that omit keys (e.g. SD 1.5's older PNDM config).

Why

The previous generate path declared a data mesh axis but the program ran fully replicated (device_put inside the jit is a weak hint GSPMD ignored), so each chip recomputed the whole batch. Forcing the batch sharding makes inference genuinely data parallel.

Performance

TPU7x (8 chips), SD 1.5 1024px, 20-step PNDM, 8 images: 0.130s . Scales cleanly to larger batches.

Test plan

  • Generates coherent images matching the prompts.

Prompt: "A cinematic photo of a glass greenhouse on a snowy mountain"
image
image
image

@csgoogle csgoogle requested a review from entrpn as a code owner June 13, 2026 14:10
@github-actions

Copy link
Copy Markdown

@github-actions

Copy link
Copy Markdown

🤖 Hi @Perseus14, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

@github-actions

Copy link
Copy Markdown

🤖 I'm sorry @Perseus14, but I was unable to process your request. Please see the logs for more details.

@github-actions

Copy link
Copy Markdown

🤖 Hi @Perseus14, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

@github-actions github-actions Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

## 📋 Review Summary

This PR successfully adds support for Stable Diffusion 1.5 to MaxDiffusion and implements genuine data-parallel inference by sharding latents and context along the data mesh axis. The changes are elegant, well-structured, and provide significant performance improvements by avoiding redundant recomputation of batches across devices.

🔍 General Feedback

  • Excellent Alignment with Checkpoint Semantics: Iterating over the actual timesteps shape in the loop instead of hardcoded steps ensures perfect alignment with schedulers like PNDM which emit an extra timestep when skip_prk_steps is enabled.
  • Robust Parallelism: Forcing genuine batch-sharding constraints is a clean and robust approach to propagation of data parallelism across UNet and VAE layers.
  • Configuration-driven Design: Transitioning scheduler instantiation to use config-driven create_scheduler mirrors the SDXL pipeline beautifully and increases code reuse.
  • Robustness Improvement: An issue was identified in the config-override fallback mechanism where explicit falsy overrides (such as False) were ignored, and checkpoint defaults were bypassed. An inline code suggestion has been provided to resolve this.

Comment thread src/maxdiffusion/maxdiffusion_utils.py Outdated
@csgoogle csgoogle force-pushed the add-sd15-support branch 2 times, most recently from 5de6089 to 308e742 Compare June 13, 2026 16:24
Add base15.yml for SD 1.5 (PyTorch weights via from_pt, PNDM/epsilon
scheduler) and wire generate.py to it:

- Build the sampler from the checkpoint's scheduler config via
  create_scheduler instead of a hardcoded DDIM scheduler, and iterate the
  full PNDM schedule (skip_prk_steps emits one extra timestep).
- Shard the latent batch over the data axis with sharding constraints plus
  out_shardings so inference runs data parallel instead of replicating the
  whole batch on every device. Sub-device batches replicate.
- Make override_scheduler_config tolerant of scheduler configs that omit
  keys (e.g. SD 1.5's PNDM config).
@github-actions

Copy link
Copy Markdown

🤖 Hi @csgoogle, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

@github-actions

Copy link
Copy Markdown

🤖 I'm sorry @csgoogle, but I was unable to process your request. Please see the logs for more details.

@github-actions

Copy link
Copy Markdown

🤖 Hi @Perseus14, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

@github-actions

Copy link
Copy Markdown

🤖 I'm sorry @Perseus14, but I was unable to process your request. Please see the logs for more details.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants