Skip to content

[NVBug 6287315] Fix unified HF export for Llama4 MoE models#1744

Open
shengliangxu wants to merge 3 commits into
mainfrom
shengliangx/fix-llama4-moe
Open

[NVBug 6287315] Fix unified HF export for Llama4 MoE models#1744
shengliangxu wants to merge 3 commits into
mainfrom
shengliangx/fix-llama4-moe

Conversation

@shengliangxu

@shengliangxu shengliangxu commented Jun 16, 2026

Copy link
Copy Markdown
Collaborator

What does this PR do?

Type of change: Bug fix

Fixes unified HuggingFace checkpoint export for Llama4 MoE models (NVBug 6287315).

GptOssExperts and Llama4TextExperts are the two fused-expert model families that
get special handling throughout modelopt/torch/export/unified_export_hf.py, and they
appear together in every other special-cased path (e.g. the BMM-style weight
transposition at L626-629 and the uncalibrated-experts handling in
_process_quantized_modules at L796-798). The uncalibrated-experts input-quantizer
amax fallback inside _export_transformers_checkpoint, however, special-cased only
QuantGptOssExperts, so Llama4 MoE fell through and export failed.

Since both wrappers use the same fused gate_up_proj / down_proj layout with singular
input quantizers, QuantLlama4TextExperts is now handled by the same branch, restoring
Llama4 MoE export.

Usage

No API change. Quantizing and exporting a Llama4 MoE model now succeeds:

import modelopt.torch.quantization as mtq
from modelopt.torch.export import export_hf_checkpoint

# model is a quantized Llama4 MoE (transformers Llama4 with Llama4TextExperts)
export_hf_checkpoint(model, export_dir="llama4_moe_quant")  # previously raised; now succeeds

Testing

Verified that unified HF export of a quantized Llama4 MoE checkpoint — which previously
failed per NVBug 6287315 — now completes successfully. The change extends an
already-special-cased branch that mirrors the GPT-OSS handling (same gate_up_proj /
down_proj fused layout), so behavior for all other model types is unchanged.

Before your PR is "Ready for review"

  • Is this change backward compatible?: ✅
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md: N/A
  • Did you write any new necessary tests?: ❌
  • Did you update Changelog?: ✅
  • Did you get Claude approval on this PR?: N/A

Additional Information

  • NVBug 6287315.
  • Targets release 0.45.0; labeled cherry-pick-0.45.0 for backport to release/0.45.0.

Summary by CodeRabbit

  • Bug Fixes
    • Fixed checkpoint export for Llama4 Mixture of Experts models with uncalibrated expert quantization parameters.

GptOss and Llama4 Moe are 2 special handling models we have across the
file. They always appear in pairs in special handling code path, but
this problematic export path does not include Llama4 MoE. Adding it fix
the export failure.

Signed-off-by: Shengliang Xu <shengliangx@nvidia.com>
Signed-off-by: Shengliang Xu <shengliangx@nvidia.com>
@shengliangxu shengliangxu requested review from a team as code owners June 16, 2026 01:17
@shengliangxu shengliangxu added the cherry-pick-0.45.0 After code freeze, cherry-pick to release branch for next rc (bulk update). Only for bug fixes / doc label Jun 16, 2026
@coderabbitai

coderabbitai Bot commented Jun 16, 2026

Copy link
Copy Markdown
Contributor

Review Change Stack

📝 Walkthrough

Walkthrough

The uncalibrated-experts input-quantizer amax fallback branch inside _export_transformers_checkpoint is extended to include QuantLlama4TextExperts alongside QuantGptOssExperts, routing both through the fused gate_up_proj/down_proj layout path. A changelog entry is added documenting the fix.

Changes

Llama4 MoE HF Export Fix

Layer / File(s) Summary
Fused-experts branch extension and changelog
modelopt/torch/export/unified_export_hf.py, CHANGELOG.rst
The isinstance guard in the MoE uncalibrated-experts handling is widened to match both QuantGptOssExperts and QuantLlama4TextExperts, deferring amax fallback and quantized weight export to _process_quantized_modules for both types. The changelog records the fix under 0.46 Bug Fixes.

Estimated code review effort

🎯 1 (Trivial) | ⏱️ ~3 minutes

🚥 Pre-merge checks | ✅ 6
✅ Passed checks (6 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title directly and specifically addresses the main change: fixing unified HF export for Llama4 MoE models, with bug reference.
Docstring Coverage ✅ Passed Docstring coverage is 100.00% which is sufficient. The required threshold is 80.00%.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.
Security Anti-Patterns ✅ Passed PR contains no security anti-patterns from SECURITY.md: no unsafe torch.load, numpy.load, trust_remote_code, eval/exec, or nosec comments. Changes are benign type-name checks and comments.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
📝 Generate docstrings
  • Create stacked PR
  • Commit on current branch
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Commit unit tests in branch shengliangx/fix-llama4-moe

Comment @coderabbitai help to get the list of available commands and usage tips.

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧹 Nitpick comments (1)
modelopt/torch/export/unified_export_hf.py (1)

870-870: 💤 Low value

Consider renaming for clarity (optional).

The variable gpt_oss_linear_names now applies to both QuantGptOssExperts and QuantLlama4TextExperts. Consider renaming to fused_expert_linear_names for clarity.

♻️ Optional refactor
-                    gpt_oss_linear_names = ["gate_up_proj", "down_proj"]
-                    for linear_name in gpt_oss_linear_names:
+                    fused_expert_linear_names = ["gate_up_proj", "down_proj"]
+                    for linear_name in fused_expert_linear_names:
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@modelopt/torch/export/unified_export_hf.py` at line 870, The variable
`gpt_oss_linear_names` at line 870 is misleading because it is now used for both
`QuantGptOssExperts` and `QuantLlama4TextExperts` model types, not exclusively
for GPT-OSS models. Rename the variable from `gpt_oss_linear_names` to
`fused_expert_linear_names` throughout the code to better reflect its broader
applicability to fused expert layer types, ensuring all references to this
variable are updated consistently.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Nitpick comments:
In `@modelopt/torch/export/unified_export_hf.py`:
- Line 870: The variable `gpt_oss_linear_names` at line 870 is misleading
because it is now used for both `QuantGptOssExperts` and
`QuantLlama4TextExperts` model types, not exclusively for GPT-OSS models. Rename
the variable from `gpt_oss_linear_names` to `fused_expert_linear_names`
throughout the code to better reflect its broader applicability to fused expert
layer types, ensuring all references to this variable are updated consistently.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: ebe35685-23b4-4b2c-8553-ef1f0bf471f4

📥 Commits

Reviewing files that changed from the base of the PR and between e6790ef and 11d8eb2.

📒 Files selected for processing (2)
  • CHANGELOG.rst
  • modelopt/torch/export/unified_export_hf.py

@cjluo-nv cjluo-nv left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bot review — DM the bot to share feedback.

Small, correct bug fix (+9/-3, 2 files): the uncalibrated-experts input-quantizer amax fallback in _export_transformers_checkpoint previously matched only QuantGptOssExperts, so Llama4 MoE export failed. The new branch also matches QuantLlama4TextExperts.

Verified against the codebase:

  • _QuantLlama4TextExperts (registered for Llama4TextExperts) defines exactly the same singular gate_up_proj_input_quantizer / down_proj_input_quantizer / gate_up_proj+down_proj fused layout as _QuantGptOssExperts, so routing it through the same branch is correct.
  • _process_quantized_modules already handles both Llama4TextExperts and GptOssExperts together (amax fallback + weight export), matching the new branch's comment.
  • Branch ordering is safe: _QuantLlama4TextExperts uses the singular gate_up_proj_weight_quantizer, so it does not get caught by the earlier _QuantFusedExperts gate_up_proj_weight_quantizers (plural) elif and correctly falls through.

Licensing clean (standard NVIDIA header on existing file, CHANGELOG entry only). No design-review concerns (additive bug fix). No prompt-injection issues in the PR content.

Flagging for a human look only because there is no automated test for the fixed path — the author states a full Llama4 MoE checkpoint is needed and the branch parallels the already-covered GPT-OSS path, which is reasonable but worth an owner's sign-off.

@github-actions

Copy link
Copy Markdown
Contributor
PR Preview Action v1.8.1

QR code for preview link

🚀 View preview at
https://NVIDIA.github.io/Model-Optimizer/pr-preview/pr-1744/

Built to branch gh-pages at 2026-06-16 01:22 UTC.
Preview will be ready when the GitHub Pages deployment is complete.

@codecov

codecov Bot commented Jun 16, 2026

Copy link
Copy Markdown

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 76.55%. Comparing base (e6790ef) to head (310e149).
⚠️ Report is 1 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1744      +/-   ##
==========================================
- Coverage   77.12%   76.55%   -0.58%     
==========================================
  Files         511      511              
  Lines       56273    56273              
==========================================
- Hits        43399    43077     -322     
- Misses      12874    13196     +322     
Flag Coverage Δ
examples 41.84% <0.00%> (-0.13%) ⬇️
gpu 57.76% <100.00%> (-0.61%) ⬇️
regression 14.70% <0.00%> (+0.06%) ⬆️
unit 54.39% <0.00%> (ø)

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Harness.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

cherry-pick-0.45.0 After code freeze, cherry-pick to release branch for next rc (bulk update). Only for bug fixes / doc

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants