Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/maxdiffusion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,7 @@
_import_structure["models.controlnet_flax"] = ["FlaxControlNetModel"]
_import_structure["models.modeling_flax_utils"] = ["FlaxModelMixin"]
_import_structure["models.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"]
_import_structure["models.flux.transformers.transformer_flux_flax"] = ["FluxTransformer2DModel"]
_import_structure["models.flux.transformers.transformer_flux"] = ["FluxTransformer2DModel"]
_import_structure["models.vae_flax"] = ["FlaxAutoencoderKL"]
_import_structure["models.ltx_video.transformers.transformer3d"] = ["Transformer3DModel"]
_import_structure["pipelines"].extend(["FlaxDiffusionPipeline"])
Expand Down Expand Up @@ -444,7 +444,7 @@
from .models.controlnet_flax import FlaxControlNetModel
from .models.modeling_flax_utils import FlaxModelMixin
from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel
from .models.flux.transformers.transformer_flux_flax import FluxTransformer2DModel
from .models.flux.transformers.transformer_flux import FluxTransformer2DModel
from .models.ltx_video.transformers.transformer3d import Transformer3DModel
from .models.vae_flax import FlaxAutoencoderKL
from .pipelines import FlaxDiffusionPipeline
Expand Down
12 changes: 11 additions & 1 deletion src/maxdiffusion/checkpointing/flux_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
FlaxAutoencoderKL,
max_logging,
)
from maxdiffusion.models.flux.transformers.transformer_flux_flax import FluxTransformer2DModel
from maxdiffusion.models.flux.transformers.transformer_flux import FluxTransformer2DModel
from ..pipelines.flux.flux_pipeline import FluxPipeline

from transformers import (CLIPTokenizer, FlaxCLIPTextModel, FlaxT5EncoderModel, AutoTokenizer)
Expand Down Expand Up @@ -214,6 +214,11 @@ def load_diffusers_checkpoint(self):
dtype=self.config.activations_dtype,
weights_dtype=self.config.weights_dtype,
precision=max_utils.get_precision(self.config),
use_base2_exp=self.config.use_base2_exp,
use_experimental_scheduler=self.config.use_experimental_scheduler,
remat_policy=self.config.remat_policy,
names_which_can_be_saved=self.config.names_which_can_be_saved,
names_which_can_be_offloaded=self.config.names_which_can_be_offloaded,
Comment thread
prishajain1 marked this conversation as resolved.
)
transformer_eval_params = transformer.init_weights(
rngs=self.rng, max_sequence_length=self.config.max_sequence_length, eval_only=True
Expand Down Expand Up @@ -279,6 +284,11 @@ def load_checkpoint(self, step=None, scheduler_class=None):
weights_dtype=self.config.weights_dtype,
precision=max_utils.get_precision(self.config),
from_pt=self.config.from_pt,
use_base2_exp=self.config.use_base2_exp,
use_experimental_scheduler=self.config.use_experimental_scheduler,
remat_policy=self.config.remat_policy,
names_which_can_be_saved=self.config.names_which_can_be_saved,
names_which_can_be_offloaded=self.config.names_which_can_be_offloaded,
)
Comment thread
prishajain1 marked this conversation as resolved.

pipeline = FluxPipeline(
Expand Down
47 changes: 33 additions & 14 deletions src/maxdiffusion/configs/base_flux_dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ jit_initializers: True
from_pt: True
split_head_dim: True
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te
use_base2_exp: False
use_experimental_scheduler: False
# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.
# Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster.
# However, when padding tokens are significant, this will lead to worse quality and should be set to True.
Expand All @@ -73,18 +75,18 @@ mask_padding_tokens: True
# in cross attention q.
attention_sharding_uniform: True

flash_block_sizes: {}
#flash_block_sizes: {}
# Use the following flash_block_sizes on v6e (Trillium) due to larger vmem.
# flash_block_sizes: {
# "block_q" : 1536,
# "block_kv_compute" : 1536,
# "block_kv" : 1536,
# "block_q_dkv" : 1536,
# "block_kv_dkv" : 1536,
# "block_kv_dkv_compute" : 1536,
# "block_q_dq" : 1536,
# "block_kv_dq" : 1536
# }
flash_block_sizes: {
"block_q" : 1536,
"block_kv_compute" : 1536,
"block_kv" : 1536,
"block_q_dkv" : 1536,
"block_kv_dkv" : 1536,
"block_kv_dkv_compute" : 1536,
"block_q_dq" : 1536,
"block_kv_dq" : 1536
}
# GroupNorm groups
norm_num_groups: 32

Expand Down Expand Up @@ -147,9 +149,11 @@ mesh_axes: ['data', 'fsdp', 'context', 'tensor']
# conv_in : conv.shape[2] weight
# conv_out : conv.shape[-1] weight
logical_axis_rules: [
['batch', 'data'],
['batch', ['data','fsdp']],
['activation_batch', ['data','fsdp']],
['activation_heads', 'tensor'],
['activation_heads', 'fsdp'],
['activation_length', 'context'],
['activation_kv_length', 'context'],
['activation_kv', 'tensor'],
['mlp','tensor'],
['embed','fsdp'],
Expand Down Expand Up @@ -188,7 +192,7 @@ dataset_type: 'tfrecord' # Options: 'tfrecord', 'hf', 'tf', 'grain', 'synthetic
# 2. Optionally set synthetic_num_samples (null=infinite, or a number like 10000)
# 3. Optionally override dimensions
#
# synthetic_num_samples: null # null for infinite, or set a number
synthetic_num_samples: 1000 # null for infinite, or set a number
#
# Optional dimension overrides:
# resolution: 512
Expand Down Expand Up @@ -218,6 +222,21 @@ transform_images_num_proc: 4
reuse_example_batch: False
enable_data_shuffling: True

# Defines the type of gradient checkpoint to enable.
# NONE - means no gradient checkpoint
# FULL - means full gradient checkpoint, whenever possible (minimum memory usage)
# MATMUL_WITHOUT_BATCH - means gradient checkpoint for every linear/matmul operation,
# except for ones that involve batch dimension - that means that all attention and projection
# layers will have gradient checkpoint, but not the backward with respect to the parameters.
# OFFLOAD_MATMUL_WITHOUT_BATCH - same as MATMUL_WITHOUT_BATCH but offload instead of recomputing.
# CUSTOM - set names to offload and save.
remat_policy: "FLUX_OPTIMIZED"
# For CUSTOM policy set below, current annotations are for: attn_output, query_proj, key_proj, value_proj
# xq_out, xk_out, ffn_activation
names_which_can_be_saved: []
names_which_can_be_offloaded: []
Comment thread
prishajain1 marked this conversation as resolved.
flash_min_seq_length: 0

# checkpoint every number of samples, -1 means don't checkpoint.
checkpoint_every: -1
# enables one replica to read the ckpt then broadcast to the rest
Expand Down
5 changes: 4 additions & 1 deletion src/maxdiffusion/generate_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from transformers import (CLIPTokenizer, FlaxCLIPTextModel, T5EncoderModel, FlaxT5EncoderModel, AutoTokenizer)

from maxdiffusion import FlaxAutoencoderKL, pyconfig, max_logging, max_utils
from maxdiffusion.models.flux.transformers.transformer_flux_flax import FluxTransformer2DModel
from maxdiffusion.models.flux.transformers.transformer_flux import FluxTransformer2DModel
from maxdiffusion.train_utils import transformer_engine_context
from maxdiffusion.max_utils import (
device_put_replicated,
Expand Down Expand Up @@ -314,6 +314,9 @@ def run(config):
dtype=config.activations_dtype,
weights_dtype=config.weights_dtype,
precision=get_precision(config),
remat_policy=config.remat_policy,
names_which_can_be_saved=config.names_which_can_be_saved,
names_which_can_be_offloaded=config.names_which_can_be_offloaded,
Comment thread
prishajain1 marked this conversation as resolved.
)

num_channels_latents = transformer.in_channels // 4
Expand Down
2 changes: 1 addition & 1 deletion src/maxdiffusion/generate_flux_multi_res.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from transformers import (CLIPTokenizer, FlaxCLIPTextModel, T5EncoderModel, FlaxT5EncoderModel, AutoTokenizer)

from maxdiffusion import FlaxAutoencoderKL, pyconfig, max_logging, max_utils
from maxdiffusion.models.flux.transformers.transformer_flux_flax import FluxTransformer2DModel
from maxdiffusion.models.flux.transformers.transformer_flux import FluxTransformer2DModel
from maxdiffusion.max_utils import (
device_put_replicated,
get_memory_allocations,
Expand Down
2 changes: 1 addition & 1 deletion src/maxdiffusion/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from .unet_2d_condition_flax import FlaxUNet2DConditionModel
from .vae_flax import FlaxAutoencoderKL
from .lora import *
from .flux.transformers.transformer_flux_flax import FluxTransformer2DModel
from .flux.transformers.transformer_flux import FluxTransformer2DModel
from .ltx_video.transformers.transformer3d import Transformer3DModel

else:
Expand Down
64 changes: 43 additions & 21 deletions src/maxdiffusion/models/attention_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -1824,6 +1824,8 @@ class FlaxFluxAttention(nn.Module):
out_axis_names: AxisNames = (BATCH, LENGTH, EMBED)
precision: jax.lax.Precision = None
qkv_bias: bool = False
use_base2_exp: bool = False
use_experimental_scheduler: bool = False

def setup(self):
if self.attention_kernel in {"flash", "cudnn_flash_te"} and self.mesh is None:
Expand All @@ -1843,6 +1845,8 @@ def setup(self):
flash_block_sizes=self.flash_block_sizes,
dtype=self.dtype,
float32_qk_product=False,
use_base2_exp=self.use_base2_exp,
use_experimental_scheduler=self.use_experimental_scheduler,
)

kernel_axes = ("embed", "heads")
Expand Down Expand Up @@ -1923,41 +1927,59 @@ def __call__(
attention_mask=None,
image_rotary_emb=None,
):
qkv_proj = self.qkv(hidden_states)
B, L = hidden_states.shape[:2]
H, D, K = self.heads, qkv_proj.shape[-1] // (self.heads * 3), 3
qkv_proj = qkv_proj.reshape(B, L, K, H, D).transpose(2, 0, 3, 1, 4)
query_proj, key_proj, value_proj = qkv_proj
# Deduce dimensions cleanly from class attributes
H, D = self.heads, self.dim_head

query_proj = self.query_norm(query_proj)
qkv_proj = self.qkv(hidden_states)
qkv_proj = checkpoint_name(qkv_proj, "img_qkv_proj")

qkv_proj = qkv_proj.reshape(B, L, 3, H, D)
query_proj, key_proj, value_proj = jnp.split(qkv_proj, 3, axis=2)
query_proj = query_proj.squeeze(2)
key_proj = key_proj.squeeze(2)
value_proj = value_proj.squeeze(2)

query_proj = self.query_norm(query_proj)
key_proj = self.key_norm(key_proj)

if encoder_hidden_states is not None:
B_enc, L_txt = encoder_hidden_states.shape[:2]
encoder_qkv_proj = self.encoder_qkv(encoder_hidden_states)
B, L = encoder_hidden_states.shape[:2]
H, D, K = self.heads, encoder_qkv_proj.shape[-1] // (self.heads * 3), 3
encoder_qkv_proj = encoder_qkv_proj.reshape(B, L, K, H, D).transpose(2, 0, 3, 1, 4)
encoder_query_proj, encoder_key_proj, encoder_value_proj = encoder_qkv_proj
encoder_qkv_proj = checkpoint_name(encoder_qkv_proj, "txt_qkv_proj")
encoder_qkv_proj = encoder_qkv_proj.reshape(B_enc, L_txt, 3, H, D)
enc_query_proj, enc_key_proj, enc_value_proj = jnp.split(encoder_qkv_proj, 3, axis=2)
enc_query_proj = enc_query_proj.squeeze(2)
enc_key_proj = enc_key_proj.squeeze(2)
enc_value_proj = enc_value_proj.squeeze(2)

encoder_query_proj = self.encoder_query_norm(encoder_query_proj)
encoder_query_proj = self.encoder_query_norm(enc_query_proj)
encoder_key_proj = self.encoder_key_norm(enc_key_proj)

encoder_key_proj = self.encoder_key_norm(encoder_key_proj)
query_proj = jnp.concatenate((encoder_query_proj, query_proj), axis=1)
key_proj = jnp.concatenate((encoder_key_proj, key_proj), axis=1)
value_proj = jnp.concatenate((enc_value_proj, value_proj), axis=1)

query_proj = jnp.concatenate((encoder_query_proj, query_proj), axis=2)
key_proj = jnp.concatenate((encoder_key_proj, key_proj), axis=2)
value_proj = jnp.concatenate((encoder_value_proj, value_proj), axis=2)

query_proj = nn.with_logical_constraint(query_proj, self.query_axis_names)
key_proj = nn.with_logical_constraint(key_proj, self.key_axis_names)
value_proj = nn.with_logical_constraint(value_proj, self.value_axis_names)
# query_proj = nn.with_logical_constraint(query_proj, self.query_axis_names)
# key_proj = nn.with_logical_constraint(key_proj, self.key_axis_names)
# value_proj = nn.with_logical_constraint(value_proj, self.value_axis_names)

image_rotary_emb = rearrange(image_rotary_emb, "n d (i j) -> n d i j", i=2, j=2)

query_proj = query_proj.swapaxes(1, 2)
key_proj = key_proj.swapaxes(1, 2)
query_proj, key_proj = apply_rope(query_proj, key_proj, image_rotary_emb)
query_proj = query_proj.swapaxes(1, 2)
key_proj = key_proj.swapaxes(1, 2)

query_proj = query_proj.reshape(B, -1, H * D)
key_proj = key_proj.reshape(B, -1, H * D)
value_proj = value_proj.reshape(B, -1, H * D)

query_proj = query_proj.transpose(0, 2, 1, 3).reshape(query_proj.shape[0], query_proj.shape[2], -1)
key_proj = key_proj.transpose(0, 2, 1, 3).reshape(key_proj.shape[0], key_proj.shape[2], -1)
value_proj = value_proj.transpose(0, 2, 1, 3).reshape(value_proj.shape[0], value_proj.shape[2], -1)
if encoder_hidden_states is not None:
query_proj = nn.with_logical_constraint(query_proj, self.query_axis_names)
key_proj = nn.with_logical_constraint(key_proj, self.key_axis_names)
value_proj = nn.with_logical_constraint(value_proj, self.value_axis_names)

attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj, attention_mask=attention_mask)
context_attn_output = None
Expand Down
2 changes: 1 addition & 1 deletion src/maxdiffusion/models/flux/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@
limitations under the License.
"""

from .transformers.transformer_flux_flax import FluxTransformer2DModel
from .transformers.transformer_flux import FluxTransformer2DModel
Loading
Loading