[NNX] NNX migration (12/N): delete Linen code paths, classes, and NNX compatibility flags #4038
Draft
ecnal-cienet wants to merge 27 commits into
…X default flip
Pre-flip safety: PR11 will flip pure_nnx/enable_nnx/pure_nnx_decoder from
False to True in base.yml. Some existing tests are Linen-coupled and would
either silently switch to NNX (and break) or silently SKIP after that flip.
Pin them to Linen explicitly so they keep exercising the Linen path, with
no behavior change today (the pin matches the current default).
tests/unit/tiling_test.py:
LossAndGradientCorrectnessTest builds models via transformer_as_linen and
exercises the Linen vocab_tiling path. Extend self.base_config in setUp
with enable_nnx=False, pure_nnx=False, pure_nnx_decoder=False, then drop
the 6 stale pytest.skip("We currently don't support vocab tiling on NNX
module.") guards (NNX-side coverage lives in VocabTilingNNXTest in the
same file, added in PR10).
tests/unit/pipeline_parallelism_test.py:
Pipeline parallelism does not yet have an NNX path (deferred to PR11.5).
Add _LINEN_PIN class const and append *self._LINEN_PIN to the 6
train_main arg lists in test_full_train_circular,
test_full_train_circular_pipeline_ag_per_repeat,
test_full_train_non_circular, test_subset_layers, test_full_train_fp8,
and test_full_train_nanoo_fp8. The unit-style
assert_pipeline_same_output_and_grad tests bypass the dispatch by
calling pipeline.create_pipeline + SimpleDecoderLayerToLinen directly,
so they are flag-immune and need no change.
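The pin described above can be sketched as follows. This is a minimal illustration, not the test file itself: the three flag names and `_LINEN_PIN` come from the commit message, while the `key=value` override form, the `build_train_args` helper, and the sample arguments are assumptions made for the example.

```python
class PipelineParallelismTestSketch:
  """Hypothetical stand-in for the test class that carries the pin."""

  # Flag names are from the commit message; the key=value form is assumed.
  _LINEN_PIN = (
      "enable_nnx=False",
      "pure_nnx=False",
      "pure_nnx_decoder=False",
  )

  def build_train_args(self, base_args):
    """Return a new arg list with the Linen pin appended (hypothetical helper)."""
    return [*base_args, *self._LINEN_PIN]


# Sample arguments are placeholders, not the real train_main arg lists.
args = PipelineParallelismTestSketch().build_train_args(
    ["train.py", "base.yml", "steps=3"]
)
print(args)
```

Because the pin matches the pre-flip default, appending it changes nothing today and only takes effect once the defaults change.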
The PR6-PR10 sequence promoted every routed-to-Linen feature to NNX-native (DPO/PR6, MaxEngine/PR7, LoRA+GRPO/PR8, QK-Clip + checkpoint utilities/PR9, AQT + serve-mode/PR9.5, vocab tiling custom_vjp/PR10). With those gaps closed, NNX is the production path; this commit makes it the default.
Empirical break-test on CPU (pytest before/after the flip across tiling_test, train_compile_test, sharding_compare_test, maxtext_utils_test, maxengine_test) showed zero flip-induced failures - every CPU unit-test failure pre-existed on PR10 tip.
TPU smoke verified end-to-end: gemma2-2b 3-step train under the new defaults logged "pure_nnx: True" in pyconfig and produced loss 13.04 -> 12.32 -> 11.82 (decreasing, no NaN/inf, no Traceback). Linen-only test files were already pinned in the prior commit so no per-test breakage from the flip.
base.yml: enable_nnx, pure_nnx_decoder, pure_nnx all flip False -> True.
No use_nnx_pipeline flag is added: PR10 tip has no NNX pipeline path to opt out of, so a one-valued flag would be dead weight. Pipeline tests keep their Linen pin from the prior commit; the eventual NNX pipeline work (PR11.5) will introduce its own opt-in if needed.
Sharding goldens not regenerated: tests/unit/sharding_compare_test.py already pins enable_nnx=False, pure_nnx=False, pure_nnx_decoder=False explicitly when invoking the dump utility, so existing goldens at tests/utils/sharding_info/ stay valid against the flipped default.
…NX::test_nnx_model_dispatches_to_tree_map_with_path
1. Sanitize unmapped logical axes to None in maxtext_utils.py get_nnx_named_sharding_with_scan_axis to prevent compilation ValueError.
2. Fix qk_clip_utils.py broadcast shape mismatch (axis=0 to axis=-2) causing TypeError.
3. Update max_utils_test.py unscan utility to correctly parse TrainStateNNX and its parameters/sharding trees.
4. Fix muon_utils_test.py NNX dict mapping assertIsNone() against raw objects rather than .
5. Patch train_distill and train_sft to explicitly nnx.pop(Intermediate) to prevent GraphDef mutation ValueErrors.
6. Update diloco.py to use nnx.split instead of the deprecated filter() method for param extraction.
7. Update diloco_test.py to execute pure NNX training loop simulations instead of legacy Linen.
After flipping pure_nnx/enable_nnx/pure_nnx_decoder to True, several
integration tests broke because their code paths assumed Linen. Fixes:
- maxengine_test: remove the Linen-only test_basic_prefill / test_basic_decode
  (they build the model with transformer_as_linen but the engine now expects
  NNX state). The NNX path is already covered by test_basic_prefill_nnx /
  test_basic_decode_nnx. Drop the now-unused imports and get_data helper.
- train_sft_deprecated: support the NNX train loop. Split the TrainStateNNX
  into GraphDef + flat state before jit, only pass a dropout rng on the Linen
  path (the NNX step takes (state, batch)), and read setup params via
  nnx.split on the NNX path.
- quantizations.maybe_quantize_model: qwix.quantize_model traces NNX modules
  and needs example inputs, so pass dummy decoder tokens/positions for the
  NNX path. Fixes the fp8 sparsity smoke test.
- generate_param_only_checkpoint (NNX param-only flow):
  - checkpointing._load_full_state_from_path: restore into a pure dict, since
    NNX checkpoints are saved as pure dicts; a boxed nnx.State did not match.
  - read opt_state from state.optimizer.opt_state on the NNX path.
  - save only nnx.Param leaves (the rng PRNGKeyArray can't be cast to bf16)
    and wrap each leaf as {"value": ...} so from_pretrained can read it back.
  - skip the int8 case: it is a convert-on-load scenario (the fp32 training
    checkpoint has no AqtDotGeneral state the int8 model expects); tracked as
    a follow-up alongside layerwise_quantization.
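The "save only Param leaves and wrap each as {"value": ...}" step can be illustrated with a plain-Python sketch. It deliberately avoids Flax: `Param` and `RngKey` are hypothetical stand-ins for nnx.Param and the rng leaf, and `param_only_tree` is an invented helper name, so this shows the shape of the transformation rather than the actual MaxText code.

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class Param:
  """Stand-in for an nnx.Param leaf (hypothetical, plain Python)."""
  value: Any


@dataclass
class RngKey:
  """Stand-in for a non-parameter leaf such as an rng key."""
  value: Any


def param_only_tree(tree):
  """Keep only Param leaves and wrap each one as {"value": ...}."""
  out = {}
  for name, node in tree.items():
    if isinstance(node, dict):
      sub = param_only_tree(node)
      if sub:  # drop subtrees that contain no parameters
        out[name] = sub
    elif isinstance(node, Param):
      out[name] = {"value": node.value}
    # any other leaf type (e.g. RngKey) is skipped
  return out


state = {
    "decoder": {"kernel": Param([1.0, 2.0]), "rngs": {"key": RngKey(0)}},
    "dropout_rng": RngKey(1),
}
saved = param_only_tree(state)
print(saved)
```

Filtering before any dtype cast is what keeps the rng leaf out of the bf16 conversion mentioned above.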
…product test
NNX int8 parameter-only generation requires a convert-on-load setup, which causes a ValueError since the fp32 training checkpoint lacks the AqtDotGeneral state that the target int8 model expects. This aligns the GPU/dot-product test with the existing skip in the TPU/autoselected test variant.
Linen Fp8DotGeneralBase.setup leaks intermediates inside an NNX context, so once NNX defaults flip to True (PR#11) the fp8 sparsity smoke and the fp8 GPU unit-test cases that go through Qwix/Linen quant break. Skip them until b/509790223 is fixed:
- tests/integration/sparsity_test.py: fp8_full, fp8_full_with_sparsity
- tests/unit/quantizations_test.py: test_fp8_gpu_quantization, test_fp8_nanoo_quantization
PR#11 flips the defaults to NNX, so the Linen reference engine in the prefill_multisampling/prefill_concat parity tests silently became NNX and crashed (device_put State-vs-dict), and test_stack_and_unstack_prefill_cache hit the NNX no-op branch. Drop the Linen comparisons and assert the NNX result shapes directly, rewrite the cache test for the NNX scan_layers=False path, and remove _build_linen_params and its imports.
PR #3929 moved src/maxtext/layers/train_state_nnx.py to src/maxtext/common/train_state_nnx.py. Update remaining imports in diloco.py and three test files so PR11 still imports correctly.
Under shard_optimizer_over_data, train_compile builds the AOT train-step input shardings by calling state_mesh_shardings.replace(params=params_shardings). That's a TrainState (flax.struct) method; with PR#11's NNX defaults, state_mesh_shardings is a flat nnx.State and the call dies with 'No attribute replace in State'. Add sharding.build_zero1_input_state_mesh_shardings that overlays params_shardings' Param leaves onto state_mesh_shardings.model for the NNX path while keeping the existing .replace behavior for Linen, and route both train_compile call sites through it. Fixes test_zero1_optimizer_sharding.
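The dispatch described above can be sketched in plain Python. This is only an illustration under stated assumptions: a frozen dataclass stands in for the Linen TrainState, a nested dict with a "model" key stands in for the flat nnx.State, sharding specs are plain strings, and `overlay` and `build_zero1_input_shardings_sketch` are hypothetical names rather than the real helper.

```python
import dataclasses
from typing import Any


@dataclasses.dataclass(frozen=True)
class LinenLikeShardings:
  """Stand-in for a struct-style state that supports a replace-style update."""
  params: Any
  opt_state: Any


def overlay(base, override):
  """Return a copy of base with the leaves present in override replaced."""
  merged = dict(base)
  for key, value in override.items():
    if isinstance(value, dict) and isinstance(merged.get(key), dict):
      merged[key] = overlay(merged[key], value)
    else:
      merged[key] = value
  return merged


def build_zero1_input_shardings_sketch(state_shardings, params_shardings):
  """Overlay on the NNX-like path, replace on the Linen-like path."""
  if isinstance(state_shardings, dict):  # NNX-like flat state
    model = overlay(state_shardings["model"], params_shardings)
    return {**state_shardings, "model": model}
  return dataclasses.replace(state_shardings, params=params_shardings)


nnx_like = {
    "model": {"layer": {"kernel": "fsdp", "bias": "replicated"}},
    "optimizer": {"mu": "fsdp"},
}
nnx_out = build_zero1_input_shardings_sketch(nnx_like, {"layer": {"kernel": "data"}})
linen_out = build_zero1_input_shardings_sketch(
    LinenLikeShardings(params="fsdp", opt_state="fsdp"), "data"
)
print(nnx_out, linen_out)
```

Routing both call sites through one helper keeps the type check in a single place instead of repeating it wherever the input shardings are built.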
Under enable_diloco the state becomes a DiLoCoTrainState, but the pure_nnx path still merged it against the plain-model graphdef (nnx.merge leaf mismatch + segfault), and several downstream sites assumed a plain TrainStateNNX. Guard the merge and surface the graphdef as model; fix get_first_step, jit_model, params_shardings, setup_params, and the rng args in train_loop; match the diloco sharding's params to_pure_dict; and handle the DiLoCoTrainState in maybe_save_checkpoint by saving the synchronized global model. Train + checkpoint save/restore validated end-to-end on CPU.
The NNX decoder has no pipeline path yet, so under pure_nnx the scanned-layers axis is sharded by 'stage' and dies with a cryptic IndivisibleError at state init. Raise a clear NotImplementedError at config validation pointing users to ici_pipeline_parallelism=1 or the Linen path. NNX pipeline support is tracked as PR11.5.
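A fail-fast check of this kind might look like the sketch below. The option names pure_nnx and ici_pipeline_parallelism come from the commit message; the function name, its signature, and the exact error wording are assumptions, and the real validation may consider additional settings.

```python
def validate_nnx_pipeline_sketch(pure_nnx: bool, ici_pipeline_parallelism: int) -> None:
  """Reject NNX + pipeline parallelism at config time (hypothetical helper)."""
  if pure_nnx and ici_pipeline_parallelism > 1:
    raise NotImplementedError(
        "Pipeline parallelism is not supported on the NNX decoder yet. "
        "Set ici_pipeline_parallelism=1 or use the Linen path."
    )


# Supported combinations pass silently.
validate_nnx_pipeline_sketch(pure_nnx=True, ici_pipeline_parallelism=1)
validate_nnx_pipeline_sketch(pure_nnx=False, ici_pipeline_parallelism=4)

# The unsupported combination fails with an actionable message.
try:
  validate_nnx_pipeline_sketch(pure_nnx=True, ici_pipeline_parallelism=4)
  raised = False
except NotImplementedError as err:
  raised = True
  message = str(err)
```

Failing during validation surfaces the limitation before state init, where the same misconfiguration would otherwise appear as an IndivisibleError.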
…patch to NNX-only
Across the core training/utils/inference/RL/checkpoint-conversion code, statically collapse every pure_nnx / enable_nnx / isinstance(model, nn.Module) branch to the NNX path (the model is always NNX now). No flag reads remain in these files.
…s_linen wrappers
Delete TransformerLinenPure, the Linen Decoder/DecoderLayer/SequentialBlockDecoderLayers stack (decoders.py), and the dead *_as_linen ToLinen wrappers across the layer/model files. The wrapped NNX classes are unchanged; transformer_as_linen (the NNX->Linen bridge) is kept for the checkpoint-conversion tools.
Remove obsolete Linen-only tests, drop redundant flag args from the rest, and compile the hlo_diff tests via base.yml + model_name so they exercise the real NNX path.
…oder config flags
Remove the three flags from types.py, base.yml, inference/vllm.yml, pyconfig, and the post-train distillation configs. NNX is the only path; the flags no longer exist.
NNX Migration Route Map
- pure_nnx flag, init_state_fn, TrainStateNNX, NNX utils. Linen workflow unchanged. (PR NNX migration prep (1/N): pure_nnx flag and init_state_fn scaffolding #3427)
- get_abstract_state_nnx, get_named_sharding_nnx, set_named_sharding_nnx, get_partition_spec_nnx, get_mesh_from_config. (PR NNX migration prep (2/N): NNX utils and sharding utilities #3470)
- TrainStateNNX, model creation, gradient accumulation, checkpointing, and training loop dispatch. (PR NNX migration prep (3/N): TrainState, model creation, and end-to-end training loop #3500)
- 9.5. ✅ NNX + AQT in MaxEngine + serve-mode reload + gpt3 prefill fix. (PR [NNX] NNX migration prep (9.5/N): NNX + AQT in MaxEngine + serve-mode reload + gpt3 prefill fix #3844)
- custom_vjp for NNX (+ output-head carve-out).
- True. ([NNX] NNX migration (11/N): set pure_nnx / enable_nnx / pure_nnx_decoder defaults to True #3526)
- pure_nnx/enable_nnx/pure_nnx_decoder flags. ([NNX] NNX migration (12/N): delete Linen code paths, classes, and NNX compatibility flags #4038)
Description
PR11 made NNX the default and PR6–PR10 promoted every routed-to-Linen feature to NNX-native, so NNX is the only production path. This PR deletes Linen: all flag/isinstance dispatch, the dead Linen classes and *_as_linen wrappers, the three compatibility flags, and the now-obsolete Linen-only tests.
Net: 64 files, +728 / −6,205 (−5,477 lines). Organized into 4 reviewable commits (ordered so each is self-consistent: flags removed last, after every reference is gone):
- pure_nnx/enable_nnx/isinstance(model, nn.Module) branch to the NNX path across core training / utils / inference / RL / checkpoint-conversion. Zero executable flag reads remain in src.
- TransformerLinenPure; the Linen Decoder/DecoderLayer/SequentialBlockDecoderLayers stack (decoders.py 1525→47, only deepstack_process kept); and 28 dead *_as_linen ToLinen wrappers across the layer/model files. The wrapped NNX classes are untouched.
- hlo_diff_test (see below).
- types.py, base.yml, inference/vllm.yml, pyconfig, and the post-train distillation configs.
Stats
Diff: 64 files, +728 / −6,205 (net −5,477), in 4 commits; overwhelmingly deletion (it's a "remove dead Linen" PR). Production-vs-test split:
src/maxtext) *_as_linen) tests/)
So ~75% of the line changes are production code and ~25% are tests (by deletions, 4,662 vs 1,543 ≈ 3:1). The largest single chunk is the Linen decoder stack + *_as_linen wrappers (−2,746 in layers/models). Note this is pure refactor/removal: no new feature code; the "+728" insertions are almost entirely de-indenting kept NNX branches and rewriting a handful of tests.
Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-review label.