Skip to content

refactor Flux transformer to use scanned blocks, dynamic checkpointing, and decoupled projections#417

Open
prishajain1 wants to merge 1 commit into
mainfrom
prisha/flux_training
Open

refactor Flux transformer to use scanned blocks, dynamic checkpointing, and decoupled projections#417
prishajain1 wants to merge 1 commit into
mainfrom
prisha/flux_training

Conversation

@prishajain1

Copy link
Copy Markdown
Collaborator

Overview

This PR refactors the Flux model architecture in MaxDiffusion to support scanned blocks (nn.scan) for double and single blocks, implements configurable gradient checkpointing (rematerialization) policies, and updates the weights loader to support loading pretrained checkpoints under the scanned format.

Key Changes

  • Decoupled Fused Projections: Decoupled the projection layers (implementing the MlpAndOutputBlock wrapper) to eliminate redundant recomputation of attention and projection outputs.
  • QKV Slicing Refactoring: Refactored the QKV projection slicing logic to use jnp.split across Flux transformer blocks for cleaner layout constraints.
  • Scanned Block Architecture: Migrated Flux Double and Single Transformer Blocks to use nn.scan to optimize compiler tracing and step execution speed on TPUs.
  • Dynamic Gradient Checkpointing: Added FLUX_OPTIMIZED to GradientCheckpointType to allow configuring block-specific rematerialization policies dynamically via configuration files instead of being hardcoded.
  • Stacked Weights Loading: Updated the weights loader (util.py) to slice, group, and stack PyTorch checkpoint weights along axis 0 to match the expected format of nn.scan layers.

@prishajain1 prishajain1 requested a review from entrpn as a code owner June 12, 2026 06:20
@github-actions

Copy link
Copy Markdown

@prishajain1 prishajain1 marked this pull request as draft June 12, 2026 06:20
@prishajain1 prishajain1 force-pushed the prisha/flux_training branch 2 times, most recently from 4696256 to 11ddfef Compare June 12, 2026 06:29
@prishajain1 prishajain1 marked this pull request as ready for review June 12, 2026 06:31
@github-actions

Copy link
Copy Markdown

🤖 Hi @prishajain1, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

@github-actions

Copy link
Copy Markdown

🤖 I'm sorry @prishajain1, but I was unable to process your request. Please see the logs for more details.

Comment thread src/maxdiffusion/checkpointing/flux_checkpointer.py
Comment thread src/maxdiffusion/checkpointing/flux_checkpointer.py
Comment thread src/maxdiffusion/configs/base_flux_dev.yml
Comment thread src/maxdiffusion/models/flux/transformers/transformer_flux.py
Comment thread src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py
Comment thread src/maxdiffusion/models/flux/transformers/transformer_flux.py Outdated
Comment thread src/maxdiffusion/models/normalization_flax.py
Comment thread src/maxdiffusion/models/normalization_flax.py
Comment thread src/maxdiffusion/models/normalization_flax.py
Comment thread src/maxdiffusion/generate_flux.py
@prishajain1 prishajain1 force-pushed the prisha/flux_training branch 4 times, most recently from 8c8dcec to f58fb9e Compare June 13, 2026 13:39
@github-actions

Copy link
Copy Markdown

🤖 Hi @Perseus14, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

@github-actions

Copy link
Copy Markdown

🤖 Hi @Perseus14, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

@github-actions

Copy link
Copy Markdown

🤖 I'm sorry @Perseus14, but I was unable to process your request. Please see the logs for more details.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants