From 164ca87b0f66e6d42bc02dc44a3a145ed7e269b4 Mon Sep 17 00:00:00 2001 From: csgoogle <198951627+csgoogle@users.noreply.github.com> Date: Sat, 13 Jun 2026 14:10:16 +0000 Subject: [PATCH] Add Stable Diffusion 1.5 support with data-parallel inference Add base15.yml for SD 1.5 (PyTorch weights via from_pt, PNDM/epsilon scheduler) and wire generate.py to it: - Build the sampler from the checkpoint's scheduler config via create_scheduler instead of a hardcoded DDIM scheduler, and iterate the full PNDM schedule (skip_prk_steps emits one extra timestep). - Shard the latent batch over the data axis with sharding constraints plus out_shardings so inference runs data parallel instead of replicating the whole batch on every device. Sub-device batches replicate. - Make override_scheduler_config tolerant of scheduler configs that omit keys (e.g. SD 1.5's PNDM config). --- .../base_stable_diffusion_checkpointer.py | 3 + src/maxdiffusion/configs/README.md | 7 + src/maxdiffusion/configs/base15.yml | 279 ++++++++++++++++++ src/maxdiffusion/generate.py | 59 +++- src/maxdiffusion/maxdiffusion_utils.py | 31 +- 5 files changed, 354 insertions(+), 25 deletions(-) create mode 100644 src/maxdiffusion/configs/base15.yml diff --git a/src/maxdiffusion/checkpointing/base_stable_diffusion_checkpointer.py b/src/maxdiffusion/checkpointing/base_stable_diffusion_checkpointer.py index baf5bdd67..26851b1ff 100644 --- a/src/maxdiffusion/checkpointing/base_stable_diffusion_checkpointer.py +++ b/src/maxdiffusion/checkpointing/base_stable_diffusion_checkpointer.py @@ -209,6 +209,7 @@ def load_diffusers_checkpoint(self): split_head_dim=self.config.split_head_dim, norm_num_groups=self.config.norm_num_groups, attention_kernel=self.config.attention, + flash_min_seq_length=getattr(self.config, "flash_min_seq_length", 4096), flash_block_sizes=flash_block_sizes, mesh=self.mesh, precision=precision, @@ -220,6 +221,7 @@ def load_diffusers_checkpoint(self): split_head_dim=self.config.split_head_dim, norm_num_groups=self.config.norm_num_groups, attention_kernel=self.config.attention, + flash_min_seq_length=getattr(self.config, "flash_min_seq_length", 4096), flash_block_sizes=flash_block_sizes, dtype=self.activations_dtype, weights_dtype=self.weights_dtype, @@ -279,6 +281,7 @@ def load_checkpoint(self, step=None, scheduler_class=None): split_head_dim=self.config.split_head_dim, norm_num_groups=self.config.norm_num_groups, attention_kernel=self.config.attention, + flash_min_seq_length=getattr(self.config, "flash_min_seq_length", 4096), flash_block_sizes=flash_block_sizes, mesh=self.mesh, precision=precision, diff --git a/src/maxdiffusion/configs/README.md b/src/maxdiffusion/configs/README.md index a052df291..d57d8b0c7 100644 --- a/src/maxdiffusion/configs/README.md +++ b/src/maxdiffusion/configs/README.md @@ -2,6 +2,13 @@ This directory contains model configuration for different Stable Diffusion models. +## Stable Diffusion 1.5 + +base15.yml - used for training and inference using [stable-diffusion-v1-5](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5). +The upstream checkpoint ships PyTorch weights only, so this config sets `from_pt: True`; point +`pretrained_model_name_or_path` at a local diffusers snapshot for offline runs. It defaults to the +checkpoint's PNDM scheduler (epsilon prediction) to match the reference inference path. + ## Stable Diffusion 2.1 base21.yml - used for training and inference using [stable-diffusion-2-1](https://huggingface.co/stabilityai/stable-diffusion-2-1) diff --git a/src/maxdiffusion/configs/base15.yml b/src/maxdiffusion/configs/base15.yml new file mode 100644 index 000000000..ce7a25b98 --- /dev/null +++ b/src/maxdiffusion/configs/base15.yml @@ -0,0 +1,279 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Stable Diffusion 1.5 base config. +# +# SD 1.5 shares the same architecture as SD 1.4 (CLIP ViT-L/14 text encoder, +# 860M UNet, AutoencoderKL) and only differs by the trained weights, so this +# config mirrors base14.yml and points at the v1-5 checkpoint. The upstream +# checkpoint only ships PyTorch weights, so from_pt is True by default; override +# pretrained_model_name_or_path to a local diffusers snapshot for offline runs. + +# This sentinel is a reminder to choose a real run name. +run_name: '' + +metrics_file: "" # for testing, local file that stores scalar metrics. If empty, no metrics are written. +# If true save metrics such as loss and TFLOPS to GCS in {base_output_directory}/{run_name}/metrics/ +write_metrics: True +gcs_metrics: True + +# For testing, local file that stores function timing metrics such as state creation and compilation. +# If empty, no metrics are written. +timing_metrics_file: "" +write_timing_metrics: True + +# If true save config to GCS in {base_output_directory}/{run_name}/ +save_config_to_gcs: False +log_period: 10000000000 # Flushes Tensorboard + +pretrained_model_name_or_path: 'stable-diffusion-v1-5/stable-diffusion-v1-5' +unet_checkpoint: '' +# The canonical v1-5 repo only publishes the main (PyTorch) revision. +revision: 'main' + +# This will convert the weights to this dtype. +weights_dtype: 'float32' +# This sets the layer's dtype in the model. Ex: nn.Dense(dtype=activations_dtype) +activations_dtype: 'bfloat16' + +# matmul and conv precision from https://jax.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision +# Options are "DEFAULT", "HIGH", "HIGHEST" +# fp32 activations and fp32 weights with HIGHEST will provide the best precision +# at the cost of time. +precision: "DEFAULT" + +# if False state is not jitted and instead replicate is called. This is good for debugging on single host +# It must be True for multi-host. +jit_initializers: True + +# Set true to load weights from pytorch. The v1-5 checkpoint is PyTorch-only. +from_pt: True +split_head_dim: True +attention: 'tokamax_flash' # Supported attention: dot_product, flash, tokamax_flash +# Minimum Q/K/V sequence length required to use flash attention. For SD 1.5 +# 1024x1024 inference, the two largest self-attention lengths are 16384 and +# 4096, while cross-attention falls back to dot_product because text KV length +# is 77. +flash_min_seq_length: 4096 +# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens. +# Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster. +# However, when padding tokens are significant, this will lead to worse quality and should be set to True. +mask_padding_tokens: True +# Maxdiffusion has 2 types of attention sharding strategies: +# 1. attention_sharding_uniform = True : same sequence sharding rules applied for q in both (self and cross attention) +# 2. attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention +# while sequence is sharded in cross attention q. +attention_sharding_uniform: True +flash_block_sizes: { + "block_q" : 2048, + "block_kv_compute" : 1024, + "block_kv" : 2048, + "block_q_dkv" : 2048, + "block_kv_dkv" : 2048, + "block_kv_dkv_compute" : 1024 +} +# GroupNorm groups +norm_num_groups: 32 + +# If train_new_unet, unet weights will be randomly initialized to train the unet from scratch +# else they will be loaded from pretrained_model_name_or_path +train_new_unet: False + +# train text_encoder +train_text_encoder: False +text_encoder_learning_rate: 4.25e-6 + +# https://arxiv.org/pdf/2305.08891.pdf +snr_gamma: -1.0 + +timestep_bias: { + # a value of later will increase the frequence of the model's final training steps. + # none, earlier, later, range + strategy: "none", + # multiplier for bias, a value of 2.0 will double the weight of the bias, 0.5 will halve it. + multiplier: 1.0, + # when using strategy=range, the beginning (inclusive) timestep to bias. + begin: 0, + # when using strategy=range, the final step (inclusive) to bias. + end: 1000, + # portion of timesteps to bias. + # 0.5 will bias one half of the timesteps. Value of strategy determines + # whether the biased portions are in the earlier or later timesteps. + portion: 0.25 +} + +# SD 1.5 uses a PNDM sampler with epsilon prediction and leading timestep +# spacing. These mirror the checkpoint's scheduler_config.json so generation +# matches the diffusers/reference defaults. +diffusion_scheduler_config: { + _class_name: 'FlaxPNDMScheduler', + prediction_type: 'epsilon', + rescale_zero_terminal_snr: False, + timestep_spacing: 'leading' +} + +# Hardware +hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu' +skip_jax_distributed_system: False + +base_output_directory: "" + +# Parallelism +mesh_axes: ['data', 'fsdp', 'context', 'tensor'] + +# batch : batch dimension of data and activations +# hidden : +# embed : attention qkv dense layer hidden dim named as embed +# heads : attention head dim = num_heads * head_dim +# length : attention sequence length +# temb_in : dense.shape[0] of resnet dense before conv +# out_c : dense.shape[1] of resnet dense before conv +# out_channels : conv.shape[-1] activation +# keep_1 : conv.shape[0] weight +# keep_2 : conv.shape[1] weight +# conv_in : conv.shape[2] weight +# conv_out : conv.shape[-1] weight +logical_axis_rules: [ + ['batch', 'data'], + ['activation_batch', ['data','fsdp']], + ['activation_heads', 'tensor'], + ['activation_kv', 'tensor'], + ['embed','fsdp'], + ['heads', 'tensor'], + ['conv_batch', ['data','fsdp']], + ['out_channels', 'tensor'], + ['conv_out', 'fsdp'], + ] +data_sharding: [['data', 'fsdp', 'context', 'tensor']] + +# One axis for each parallelism type may hold a placeholder (-1) +# value to auto-shard based on available slices and devices. +# By default, product of the DCN axes should equal number of slices +# and product of the ICI axes should equal number of devices per slice. +dcn_data_parallelism: -1 # recommended DCN axis to be auto-sharded +dcn_fsdp_parallelism: 1 +dcn_context_parallelism: 1 +dcn_tensor_parallelism: 1 +ici_data_parallelism: -1 # recommended ICI axis to be auto-sharded for TPUv5e +ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded +ici_context_parallelism: 1 +ici_tensor_parallelism: 1 + +allow_split_physical_axes: False + +# Dataset +# Replace with dataset path or train_data_dir. One has to be set. +dataset_name: 'diffusers/pokemon-gpt4-captions' +train_split: 'train' +dataset_type: 'tf' +cache_latents_text_encoder_outputs: True +# cache_latents_text_encoder_outputs only apply to dataset_type="tf", +# only apply to small dataset that fits in memory +# prepare image latents and text encoder outputs +# Reduce memory consumption and reduce step time during training +# transformed dataset is saved at dataset_save_location +dataset_save_location: '/tmp/pokemon-gpt4-captions_sd15' +train_data_dir: '' +dataset_config_name: '' +jax_cache_dir: '' +hf_data_dir: '' +hf_train_files: '' +hf_access_token: '' +image_column: 'image' +caption_column: 'text' +resolution: 512 +center_crop: False +random_flip: False +# If cache_latents_text_encoder_outputs is True +# the num_proc is set to 1 +tokenize_captions_num_proc: 4 +transform_images_num_proc: 4 +reuse_example_batch: False +enable_data_shuffling: True + +# checkpoint every number of samples, -1 means don't checkpoint. +checkpoint_every: -1 +# enables one replica to read the ckpt then broadcast to the rest +enable_single_replica_ckpt_restoring: False + +# Training loop +learning_rate: 1.e-7 +scale_lr: False +max_train_samples: -1 +# max_train_steps takes priority over num_train_epochs. +max_train_steps: 800 +seed: 0 +# Output directory +# Create a GCS bucket, e.g. my-maxdiffusion-outputs and set this to "gs://my-maxdiffusion-outputs/" +output_dir: '' +per_device_batch_size: 1 + +warmup_steps_fraction: 0.0 +learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps. + +# However you may choose a longer schedule (learning_rate_schedule_steps > steps), in which case the training will end before +# dropping fully down. Or you may choose a shorter schedule, where the unspecified steps will have a learning rate of 0. + +# AdamW optimizer parameters +adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradients. +adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients. +adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root. +adam_weight_decay: 1.e-2 # AdamW Weight decay +opt_enable_grad_clipping: False +max_grad_value: 1.0 +opt_enable_grad_global_norm_clipping: False +max_grad_norm: 1.0 + +enable_profiler: False +# Skip first n steps for profiling, to omit things like compilation and to give +# the iteration time a chance to stabilize. +skip_first_n_steps_for_profiler: 1 +profiler_steps: 5 + +# Generation parameters +prompt: "A magical castle in the middle of a forest, artistic drawing" +negative_prompt: "purple, red" +guidance_scale: 7.5 +# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf +guidance_rescale: 0.0 +# SD 1.5 reference inference default. +num_inference_steps: 20 + +enable_mllog: False + +# controlnet +controlnet_model_name_or_path: 'lllyasviel/sd-controlnet-canny' +controlnet_from_pt: True +controlnet_conditioning_scale: 1.0 +controlnet_image: 'https://huggingface.co/datasets/YiYiXu/test-doc-assets/resolve/main/blog_post_cell_10_output_0.jpeg' + +# dreambooth - this script always uses prior preservation. +instance_data_dir: '' +class_data_dir: '' +instance_prompt: '' +class_prompt: '' +prior_loss_weight: 1.0 +num_class_images: 100 +# If true, set dataset_save_location. +cache_dreambooth_dataset: False +quantization: '' +# Shard the range finding operation for quantization. By default this is set to number of slices. +quantization_local_shard_count: -1 +use_qwix_quantization: False +compile_topology_num_slices: -1 # Number of target slices, set to a positive integer. + +# ML Diagnostics settings +enable_ml_diagnostics: False +profiler_gcs_path: "" +enable_ondemand_xprof: False diff --git a/src/maxdiffusion/generate.py b/src/maxdiffusion/generate.py index 72422ef09..c026e8d20 100644 --- a/src/maxdiffusion/generate.py +++ b/src/maxdiffusion/generate.py @@ -22,12 +22,13 @@ import jax from jax.sharding import PartitionSpec as P +from jax.sharding import NamedSharding import jax.numpy as jnp from absl import app -from maxdiffusion import (pyconfig, FlaxDDIMScheduler, max_utils) +from maxdiffusion import (pyconfig, max_utils) from maxdiffusion.train_utils import transformer_engine_context -from maxdiffusion.maxdiffusion_utils import rescale_noise_cfg +from maxdiffusion.maxdiffusion_utils import rescale_noise_cfg, create_scheduler from flax.linen import partitioning as nn_partitioning from maxdiffusion.image_processor import VaeImageProcessor from maxdiffusion.trainers.stable_diffusion_trainer import (StableDiffusionTrainer) @@ -46,7 +47,17 @@ def post_training_steps(self, pipeline, params, train_states): return super().post_training_steps(pipeline, params, train_states) -def loop_body(step, args, model, pipeline, prompt_embeds, guidance_scale, guidance_rescale): +def get_batch_sharding(mesh, config): + """Sharding for the batch dimension. + + Shard the batch over the data axis to run data parallel. For sub-device + batches (per_device_batch_size < 1) the batch can't be split, so replicate. + """ + spec = P() if config.per_device_batch_size < 1 else P("data") + return NamedSharding(mesh, spec) + + +def loop_body(step, args, model, pipeline, prompt_embeds, guidance_scale, guidance_rescale, batch_sharding): latents, scheduler_state, state = args latents_input = jnp.concatenate([latents] * 2) @@ -76,6 +87,7 @@ def loop_body(step, args, model, pipeline, prompt_embeds, guidance_scale, guidan ) latents, scheduler_state = pipeline.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() + latents = jax.lax.with_sharding_constraint(latents, batch_sharding) return latents, scheduler_state, state @@ -86,9 +98,7 @@ def tokenize(prompt, tokenizer): ).input_ids -def get_unet_inputs(pipeline, params, states, config, rng, mesh, batch_size): - data_sharding = jax.sharding.NamedSharding(mesh, P(*config.data_sharding)) - +def get_unet_inputs(pipeline, params, states, config, rng, batch_sharding, batch_size): vae_scale_factor = 2 ** (len(pipeline.vae.config.block_out_channels) - 1) prompt_ids = [config.prompt] * batch_size prompt_ids = tokenize(prompt_ids, pipeline.tokenizer) @@ -118,8 +128,9 @@ def get_unet_inputs(pipeline, params, states, config, rng, mesh, batch_size): ) latents = latents * params["scheduler"].init_noise_sigma - latents = jax.device_put(latents, data_sharding) - context = jax.device_put(context, data_sharding) + # Seed the batch sharding so it propagates through the program. + latents = jax.lax.with_sharding_constraint(latents, batch_sharding) + context = jax.lax.with_sharding_constraint(context, batch_sharding) return latents, context, guidance_scale, guidance_rescale, scheduler_state @@ -135,8 +146,10 @@ def run_inference(states, pipeline, params, config, rng, mesh, batch_size): unet_state = states["unet_state"] vae_state = states["vae_state"] + batch_sharding = get_batch_sharding(mesh, config) + (latents, context, guidance_scale, guidance_rescale, scheduler_state) = get_unet_inputs( - pipeline, params, states, config, rng, mesh, batch_size + pipeline, params, states, config, rng, batch_sharding, batch_size ) loop_body_p = functools.partial( @@ -146,18 +159,29 @@ def run_inference(states, pipeline, params, config, rng, mesh, batch_size): prompt_embeds=context, guidance_scale=guidance_scale, guidance_rescale=guidance_rescale, + batch_sharding=batch_sharding, ) vae_decode_p = functools.partial(vae_decode, pipeline=pipeline) - with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): - latents, _, _ = jax.lax.fori_loop(0, config.num_inference_steps, loop_body_p, (latents, scheduler_state, unet_state)) + # Loop over the full sampler schedule. For most schedulers this equals + # num_inference_steps, but PNDM with skip_prk_steps emits one extra (PLMS + # warmup) timestep, so iterate over the actual scheduler timesteps to match + # diffusers/reference semantics exactly. + num_steps = scheduler_state.timesteps.shape[0] + with nn_partitioning.axis_rules(config.logical_axis_rules): + latents, _, _ = jax.lax.fori_loop(0, num_steps, loop_body_p, (latents, scheduler_state, unet_state)) image = vae_decode_p(latents, vae_state) return image def run(config): checkpoint_loader = GenerateSD(config, STABLE_DIFFUSION_CHECKPOINT) + with jax.set_mesh(checkpoint_loader.mesh): + return _run_with_mesh(config, checkpoint_loader) + + +def _run_with_mesh(config, checkpoint_loader): pipeline, params = checkpoint_loader.load_checkpoint() weights_init_fn = functools.partial(pipeline.unet.init_weights, rng=checkpoint_loader.rng) @@ -221,12 +245,17 @@ def run(config): states["vae_state"] = vae_state states["text_encoder_state"] = text_encoder_state - scheduler, scheduler_state = FlaxDDIMScheduler.from_pretrained( - config.pretrained_model_name_or_path, revision=config.revision, subfolder="scheduler", dtype=jnp.float32 - ) + # Build the sampler from the checkpoint's scheduler config (PNDM for SD 1.5), + # honoring any overrides in config.diffusion_scheduler_config. This mirrors the + # SDXL generate path and keeps the scheduler choice driven by config rather + # than hardcoded here. + scheduler, scheduler_state = create_scheduler(pipeline.scheduler.config, config) pipeline.scheduler = scheduler params["scheduler"] = scheduler_state + # Keep the output sharding in line with the batch sharding. + image_out_sharding = get_batch_sharding(checkpoint_loader.mesh, config) + p_run_inference = jax.jit( functools.partial( run_inference, @@ -238,7 +267,7 @@ def run(config): batch_size=checkpoint_loader.total_train_batch_size, ), in_shardings=(state_shardings,), - out_shardings=None, + out_shardings=image_out_sharding, ) s = time.time() diff --git a/src/maxdiffusion/maxdiffusion_utils.py b/src/maxdiffusion/maxdiffusion_utils.py index c43813c37..ef152b15e 100644 --- a/src/maxdiffusion/maxdiffusion_utils.py +++ b/src/maxdiffusion/maxdiffusion_utils.py @@ -147,18 +147,29 @@ def get_add_time_ids(original_size, crops_coords_top_left, target_size, bs, dtyp def override_scheduler_config(scheduler_config, config): - """Overrides diffusion scheduler params from checkpoint.""" + """Overrides diffusion scheduler params from config. - maxdiffusion_scheduler_config = config.diffusion_scheduler_config + Values set in ``config.diffusion_scheduler_config`` take precedence; empty + values fall back to the checkpoint's scheduler config, and finally to safe + defaults. Older diffusers scheduler configs (e.g. the SD 1.5 PNDM config) may + omit keys such as ``timestep_spacing`` or ``prediction_type``, so the + checkpoint lookups are tolerant of missing keys instead of raising. + """ - scheduler_config["_class_name"] = maxdiffusion_scheduler_config.get("_class_name", scheduler_config["_class_name"]) - scheduler_config["prediction_type"] = maxdiffusion_scheduler_config.get( - "prediction_type", scheduler_config["prediction_type"] - ) - scheduler_config["timestep_spacing"] = maxdiffusion_scheduler_config.get( - "timestep_spacing", scheduler_config["timestep_spacing"] - ) - scheduler_config["rescale_zero_terminal_snr"] = maxdiffusion_scheduler_config.get("rescale_zero_terminal_snr", False) + maxdiffusion_scheduler_config = getattr(config, "diffusion_scheduler_config", {}) or {} + + def _resolve(key, default): + # An explicit override wins; otherwise keep the checkpoint value, + # falling back to a safe default when the checkpoint omits the key. + override = maxdiffusion_scheduler_config.get(key, None) + if override is not None: + return override + return scheduler_config.get(key, default) + + scheduler_config["_class_name"] = _resolve("_class_name", scheduler_config.get("_class_name", "")) + scheduler_config["prediction_type"] = _resolve("prediction_type", "epsilon") + scheduler_config["timestep_spacing"] = _resolve("timestep_spacing", "leading") + scheduler_config["rescale_zero_terminal_snr"] = _resolve("rescale_zero_terminal_snr", False) return scheduler_config