[DRAFT] Add heterogeneous AnyModel distillation example for Puzzletron. #1725

Open
chochowski wants to merge 1 commit into main from 1679-add-mbridge-dsitillation-for-puzzletron

@chochowski chochowski commented Jun 15, 2026

Contributor

Ports examples/puzzletron/distillation from the puzzletron_distillation branch with SPDX headers, nemo2 registry fix, and puzzletron README link.

What does this PR do?

This PR adds support for mbridge (Megatron Bridge) distillation of puzzled heterogeneous models.

Usage

In general, distillation requires more compute than normal training, so ideally you should run this in a multi-node setting (unless you are in a very constrained scenario).

torchrun --nproc-per-node=8  examples/puzzletron/distillation/distill.py \
--student llama \
--teacher llama \
--student-checkpoint /puzzletron/workspaces/Llama-3.1-8B-Instruct/mip/puzzle_solutions/target_memory_78000MiB-num_params_7G/solutions--checkpoints/solution_0/ \
--teacher-checkpoint /puzzletron/workspaces/Llama-3.1-8B-Instruct/ckpts/teacher/   \
--config-file examples/puzzletron/distillation/kd-container-llama.yaml \
--tensor-model-parallel-size 1 \
--pipeline-model-parallel-size 8 \
--expert-model-parallel-size 1 \
--expert-tensor-parallel-size 1 \
train.train_iters=1000 \
checkpoint.save=/puzzletron/workspaces/Llama-3.1-8B-Instruct/kd/puzzle_solutions/target_memory_78000MiB-num_params_7G-intermediate-fix/ \
logger.wandb_exp_name=Llama-3.1-8B-Instruct-target_memory_78000MiB-num_params_7G-intermediate-fix-verify

Testing

TBD

Before your PR is "Ready for review"

Make sure you read and follow Contributor guidelines and your commits are signed (git commit -s -S).

Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded trust_remote_code=True, torch.load(..., weights_only=False), pickle, etc.).

  • Is this change backward compatible?: ✅ / ❌ / N/A
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md: ✅ / ❌ / N/A
  • Did you write any new necessary tests?: ✅ / ❌ / N/A
  • Did you update Changelog?: ✅ / ❌ / N/A
  • Did you get Claude approval on this PR?: ✅ / ❌ / N/A

Additional Information

Summary by CodeRabbit

  • New Features

    • Added knowledge distillation example supporting heterogeneous model pairs via Megatron Bridge.
    • Added script to export trained Megatron checkpoints back to HuggingFace format.
    • Added support for GPT-OSS model distillation.
  • Documentation

    • Added comprehensive distillation guide with setup, configuration, and usage examples.

Ports examples/puzzletron/distillation from the puzzletron_distillation
branch with SPDX headers, nemo2 registry fix, and puzzletron README link.

Signed-off-by: mchochowski <mchochowski@nvidia.com>
@chochowski chochowski requested review from a team as code owners June 15, 2026 09:27
@chochowski chochowski linked an issue Jun 15, 2026 that may be closed by this pull request
@coderabbitai

coderabbitai Bot commented Jun 15, 2026

Contributor

📝 Walkthrough

Adds a new examples/puzzletron/distillation/ package implementing heterogeneous AnyModel/Puzzletron knowledge distillation via Megatron Bridge. It introduces per-layer block config translation to MCore overrides, a mbridge_patcher context manager for monkey-patching layer construction, bridge and provider patches, a GPTOSSBridge with MXFP4 dequantization, distill.py and export_to_hf.py entrypoints, a default YAML config, and accompanying documentation.

Changes

Heterogeneous Puzzletron distillation example

| Layer / File(s) | Summary |
|---|---|
| Per-layer block config data contract and MCoreLayerOverrides<br>`examples/puzzletron/distillation/block_config_utils.py` | Defines MCoreLayerOverrides dataclass and all functions to load, normalize, and translate per-decoder-layer block_configs into Megatron-Core TransformerConfig override structures with attention/MLP no-op flags, supporting GQA, Mamba slots, dense MLP, and MoE FFN. |
| MCore layer patching and no-op submodule replacement<br>`examples/puzzletron/distillation/layer_patchers.py` | Introduces NoOpWithBias, NoOpRMSNorm, thread-local patcher state, _NO_OP_RULES, _apply_no_ops, and the mbridge_patcher context manager that monkey-patches build_module in multiple Megatron namespaces and `MambaLayer.__init__` to apply per-layer config overrides and replace disabled subblocks during model construction. |
| Model bridge and provider monkey-patches<br>`examples/puzzletron/distillation/model_bridge_patch.py`, `examples/puzzletron/distillation/provider_patch.py` | Patches MegatronModelBridge.load_weights_hf_to_megatron with a heterogeneous MoE name resolver and post-load diagnostic, and patches ModelProviderMixin.provide / DistillationProvider.provide to wrap provider construction inside mbridge_patcher; includes set_provider_block_configs and set_student_block_configs helpers. |
| GPT-OSS Megatron bridge and MXFP4 dequantization<br>`examples/puzzletron/distillation/gpt_oss_bridge.py` | Registers GPTOSSBridge converting GptOssForCausalLM to Megatron GPTModel with per-expert weight re-indexing, MXFP4 block+scale dequantization via _dequantize_mxfp4, and custom AutoMapping subclasses for MoE MLP expert down-projection and gate/up projection with interleaving. |
| Shared helpers and MODEL_REGISTRY<br>`examples/puzzletron/distillation/_common.py` | Defines MODEL_REGISTRY, path constants, logging setup, HF config/block config/descriptor loading helpers, _load_bridge via deci_x_patcher, _build_provider, and the run_entrypoint wrapper with distributed teardown. |
| Distillation entrypoint and default YAML config<br>`examples/puzzletron/distillation/distill.py`, `examples/puzzletron/distillation/kd-container-default.yaml` | Implements main() orchestrating provider setup, ConfigContainer construction, YAML/Hydra override merging, teacher config sync, _install_hybrid_moe_aux_loss_size_fix (monkey-patching track_moe_metrics to prevent NCCL deadlocks on zero-MoE PP stages), and distill() invocation. The default YAML provides training, optimizer, dataset, checkpoint, logging, and runtime defaults. |
| Export-to-HF entrypoint<br>`examples/puzzletron/distillation/export_to_hf.py` | Implements main() for exporting an MCore checkpoint back to HuggingFace format: validates registry key, loads student HF config and block configs, applies provider patches, exports via student_bridge.export_ckpt() inside mbridge_patcher, and copies tokenizer/config files into the output directory. |
| Public API, docs, and project config<br>`examples/puzzletron/distillation/__init__.py`, `examples/puzzletron/distillation/README.md`, `examples/puzzletron/README.md`, `pyproject.toml` | Adds the `__init__.py` re-exporting the public API, a full distillation/README.md covering setup/usage/export/debugging, a pointer in the parent README, and a Ruff per-file-ignores entry. |
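
The GPT-OSS entry above mentions MXFP4 block+scale dequantization. As a rough illustration of the idea (this is not the PR's `_dequantize_mxfp4`, whose body is not shown here), the sketch below decodes one block in pure Python. It assumes the element and scale encodings of the OCP Microscaling format (4-bit E2M1 codes sharing an E8M0 scale with bias 127) and, as a further assumption, two codes per byte with the low nibble first. `FP4_VALUES` and `dequantize_mxfp4_block` are hypothetical names.

```python
# FP4 (E2M1) code points indexed by the 4-bit code; the top bit is the sign.
FP4_VALUES = [
    0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
    -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0,
]


def dequantize_mxfp4_block(packed: bytes, scale_exponent: int) -> list[float]:
    """Decode packed FP4 codes that share one power-of-two scale.

    Assumptions (not taken from the PR): two codes per byte with the low
    nibble first, and an E8M0 scale stored as a biased exponent (bias 127).
    """
    scale = 2.0 ** (scale_exponent - 127)
    values = []
    for byte in packed:
        values.append(FP4_VALUES[byte & 0x0F] * scale)  # low nibble
        values.append(FP4_VALUES[byte >> 4] * scale)    # high nibble
    return values


# Two bytes hold four codes; exponent 128 means a scale of 2.0.
decoded = dequantize_mxfp4_block(bytes([0x21, 0xF7]), 128)
print(decoded)  # [1.0, 2.0, 12.0, -12.0]
```

A real bridge would vectorize this over tensors and over many blocks at once; the loop form is only meant to make the code-to-value mapping visible.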

Sequence Diagram(s)

sequenceDiagram
    participant CLI
    participant distill_main as distill.py main()
    participant _common
    participant provider_patch
    participant DistillationProvider
    participant mbridge_patcher
    participant MCore_build_module as MCore build_module
    participant Bridge_distill as Bridge.distill()

    CLI->>distill_main: student/teacher keys, checkpoint paths, YAML + CLI overrides
    distill_main->>_common: _load_hf_config → _get_block_configs (student + teacher)
    distill_main->>provider_patch: apply_patch() + apply_distillation_patch()
    distill_main->>_common: _load_bridge + _build_provider (student + teacher)
    distill_main->>DistillationProvider: set_student_block_configs(student_block_configs)
    distill_main->>distill_main: _build_distill_config_container → merge YAML + CLI overrides
    distill_main->>distill_main: _sync_teacher_config_from_student
    distill_main->>distill_main: _install_hybrid_moe_aux_loss_size_fix (patch track_moe_metrics)
    distill_main->>Bridge_distill: distill(config)
    Bridge_distill->>DistillationProvider: provide(config)
    DistillationProvider->>mbridge_patcher: enter(student_block_configs, num_heads, hidden_size)
    mbridge_patcher->>MCore_build_module: patched — apply MCoreLayerOverrides + NoOpWithBias per layer
    MCore_build_module-->>mbridge_patcher: layer built with overrides
    mbridge_patcher-->>DistillationProvider: exit, restore build_module
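
The enter/patch/restore steps at the end of the diagram follow the usual monkey-patching context manager pattern. A minimal sketch of that pattern, assuming nothing about the real `mbridge_patcher` beyond what the walkthrough states, could look like this. `patch_attribute`, `fake_mcore`, and `build_with_overrides` are hypothetical stand-ins, not names from the PR or from Megatron-Core.

```python
import contextlib
import types


@contextlib.contextmanager
def patch_attribute(owner, name, replacement):
    """Temporarily replace owner.<name>, restoring it even if the body raises."""
    original = getattr(owner, name)
    setattr(owner, name, replacement)
    try:
        yield original
    finally:
        setattr(owner, name, original)


# Hypothetical stand-in for a namespace that exposes build_module.
fake_mcore = types.SimpleNamespace(build_module=lambda spec: f"built:{spec}")


def build_with_overrides(spec):
    # A real patcher would apply per-layer overrides here before building.
    return f"built:{spec}+overrides"


with patch_attribute(fake_mcore, "build_module", build_with_overrides):
    inside = fake_mcore.build_module("layer0")
outside = fake_mcore.build_module("layer0")

print(inside)   # built:layer0+overrides
print(outside)  # built:layer0
```

Restoring in `finally` is the important part: if model construction fails midway, later code in the same process still sees the unpatched function.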

Estimated code review effort

🎯 5 (Critical) | ⏱️ ~120 minutes

🚥 Pre-merge checks | ✅ 5 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 53.25% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (5 passed)

| Check name | Status | Explanation |
|---|---|---|
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title check | ✅ Passed | The title accurately describes the main change: adding a heterogeneous AnyModel distillation example for Puzzletron, which is the core purpose of this PR based on the raw_summary and PR objectives. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Security Anti-Patterns | ✅ Passed | All critical security anti-patterns from SECURITY.md are absent: no torch.load(weights_only=False), no numpy.load(allow_pickle=True), no hardcoded trust_remote_code=True (only user-configurable via... |


@codecov

codecov Bot commented Jun 15, 2026


Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 75.82%. Comparing base (cc17f2c) to head (48cb89a).
⚠️ Report is 10 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1725      +/-   ##
==========================================
- Coverage   77.09%   75.82%   -1.27%     
==========================================
  Files         511      511              
  Lines       56168    57001     +833     
==========================================
- Hits        43302    43223      -79     
- Misses      12866    13778     +912     
| Flag | Coverage Δ | |
|---|---|---|
| examples | 40.59% <ø> (-1.36%) | ⬇️ |
| gpu | 57.70% <ø> (-0.62%) | ⬇️ |
| regression | 14.66% <ø> (+0.06%) | ⬆️ |
| unit | 54.40% <ø> (+0.07%) | ⬆️ |


@coderabbitai coderabbitai Bot left a comment


Warning

CodeRabbit couldn't request changes on this pull request because it doesn't have sufficient GitHub permissions.

Please grant CodeRabbit Pull requests: Read and write permission and re-run the review.

Actionable comments posted: 7

🧹 Nitpick comments (6)
examples/puzzletron/distillation/export_to_hf.py (1)

112-113: ⚡ Quick win

Move deferred imports to module scope, or explicitly justify them inline.

Imports inside main() currently have no documented reason (circular dependency, optional dependency, or heavy import deferral), so this deviates from the repository import rule.

As per coding guidelines, imports should stay at module top unless delayed import is necessary and explicitly justified.

Also applies to: 135-135

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@examples/puzzletron/distillation/export_to_hf.py` around lines 112 - 113,
Move the deferred imports out of the main() function and place them at the
module scope with other imports at the top of the file. The import statement
`from provider_patch import apply_distillation_patch, apply_patch` at line
112-113 and the other deferred import at line 135 should both be relocated to
the module-level import section unless there is a documented reason (such as
circular dependency, optional dependency, or heavy import deferral) for keeping
them deferred—in which case, add an explicit inline comment justifying the
deferred import at that location.

Source: Coding guidelines

examples/puzzletron/distillation/block_config_utils.py (1)

42-52: 💤 Low value

Consider adding __all__ to declare the public API.

Per coding guidelines, each module should declare its public surface with __all__ at the top. This module exports MCoreLayerOverrides, load_block_configs, block_config_to_mcore_overrides, and get_overrides_for_layer.

Suggested addition after imports
__all__ = [
    "MCoreLayerOverrides",
    "load_block_configs",
    "block_config_to_mcore_overrides",
    "get_overrides_for_layer",
]
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@examples/puzzletron/distillation/block_config_utils.py` around lines 42 - 52,
The module is missing an `__all__` declaration to explicitly define its public
API. Add an `__all__` list after the import statements (after the logger
initialization) that includes the four public symbols: MCoreLayerOverrides,
load_block_configs, block_config_to_mcore_overrides, and
get_overrides_for_layer. This makes the module's public interface explicit and
aligns with coding guidelines.

Source: Coding guidelines

examples/puzzletron/distillation/layer_patchers.py (1)

43-59: 💤 Low value

Consider adding __all__ to declare the public API.

This module exports NoOpWithBias, NoOpRMSNorm, and mbridge_patcher as public symbols. Per coding guidelines, each module should declare its public surface with __all__.

Suggested addition after imports
__all__ = [
    "NoOpWithBias",
    "NoOpRMSNorm",
    "mbridge_patcher",
]
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@examples/puzzletron/distillation/layer_patchers.py` around lines 43 - 59, The
module is missing an `__all__` declaration to explicitly define its public API.
After the logger initialization line where `logger =
logging.getLogger(__name__)` is defined, add an `__all__` list that includes the
three exported symbols: NoOpWithBias, NoOpRMSNorm, and mbridge_patcher. This
declaration clarifies the module's public surface and follows coding guidelines
for proper API documentation.

Source: Coding guidelines

examples/puzzletron/distillation/distill.py (2)

616-623: 💤 Low value

Redundant distributed teardown.

torch.distributed.destroy_process_group() is already called in run_entrypoint() (lines 176-177 in _common.py). This duplicate call is safe (it checks is_initialized()) but unnecessary.

♻️ Proposed fix
     try:
         distill(config)
     except Exception as e:
         logger.error("Error during distillation: %s", e)
         raise e
-    finally:
-        if torch.distributed.is_initialized():
-            torch.distributed.destroy_process_group()
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@examples/puzzletron/distillation/distill.py` around lines 616 - 623, The
finally block in the distill function that calls
torch.distributed.destroy_process_group() is redundant because this cleanup is
already handled by run_entrypoint() in _common.py. Remove the entire finally
block (which includes the is_initialized() check and destroy_process_group()
call) from the distill function, as the distributed process group teardown will
still be properly managed at the higher level in run_entrypoint().

151-152: 💤 Low value

Star imports pollute the module namespace.

The star imports from modelopt.torch.puzzletron.anymodel.converter and model_descriptor can pollute the namespace and make it unclear which names are used. Consider importing only the specific symbols needed.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@examples/puzzletron/distillation/distill.py` around lines 151 - 152, The star
imports from modelopt.torch.puzzletron.anymodel.converter and
modelopt.torch.puzzletron.anymodel.model_descriptor pollute the namespace and
obscure which symbols are actually used. Replace both star import statements
with explicit imports of only the specific symbols needed from each module.
Identify which classes, functions, or constants from converter and
model_descriptor are actually used in this file and import them by name instead
of using the wildcard syntax.
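
One way to find which names a star import actually supplies is to intersect the names referenced in the file with the module's public names. The sketch below does this with the standard `ast` module; it uses `math` as the example module because the contents of the `converter` and `model_descriptor` modules are not shown here. `names_used_from_star_import` is a hypothetical helper, and it ignores local shadowing, so treat its output as a starting list to review rather than a definitive answer.

```python
import ast
import importlib


def names_used_from_star_import(source: str, module_name: str) -> list[str]:
    """List public names of module_name that the given source references."""
    module = importlib.import_module(module_name)
    public = getattr(module, "__all__", None)
    if public is None:
        # Without __all__, a star import brings in every non-underscore name.
        public = [name for name in vars(module) if not name.startswith("_")]
    referenced = {
        node.id for node in ast.walk(ast.parse(source)) if isinstance(node, ast.Name)
    }
    return sorted(referenced & set(public))


example = "from math import *\nprint(sqrt(16) + floor(pi))\n"
needed = names_used_from_star_import(example, "math")
print(needed)  # ['floor', 'pi', 'sqrt']
```

The result maps directly onto the explicit form, here `from math import floor, pi, sqrt`.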
examples/puzzletron/distillation/gpt_oss_bridge.py (1)

188-202: 💤 Low value

Potential variable shadowing with num_experts and ep_size.

Lines 188-190 calculate num_experts from self.hf_config.num_local_experts, but lines 199-202 immediately override these values from block_config. The initial calculation at lines 188-190 is never used.

Consider removing the dead code:

♻️ Proposed fix
     def maybe_modify_converted_hf_weight(
         self,
         task: WeightConversionTask,
         converted_weights_dict: dict[str, torch.Tensor],
         hf_state_dict: Mapping[str, torch.Tensor],
     ) -> dict[str, torch.Tensor]:
-        num_experts = self.hf_config.num_local_experts
-        ep_size = parallel_state.get_expert_model_parallel_world_size()
-        experts_per_rank = num_experts // ep_size
-
         try:
             layer_idx = extract_layer_idx_from_param(task.param_name)
             expert_number = extract_expert_number_from_param(task.param_name)
         except ValueError:
             # not an expert weight
             return converted_weights_dict

         block_config = self.hf_config.block_configs[layer_idx]
         num_experts = block_config["ffn"]["moe"]["num_local_experts"]
         ep_size = parallel_state.get_expert_model_parallel_world_size()
         experts_per_rank = num_experts // ep_size
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@examples/puzzletron/distillation/gpt_oss_bridge.py` around lines 188 - 202,
Remove the dead code that calculates num_experts, ep_size, and experts_per_rank
from self.hf_config.num_local_experts before the try-except block. These
variables are shadowed and immediately recalculated after the
extract_layer_idx_from_param and extract_expert_number_from_param calls using
block_config instead, making the initial calculations at lines 188-190 unused.
Delete the redundant variable assignments and keep only the recalculation that
uses block_config["ffn"]["moe"]["num_local_experts"].
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@examples/puzzletron/distillation/__init__.py`:
- Around line 36-55: The import statements at lines 36-43 use absolute imports
(block_config_utils, layer_patchers, provider_patch) which can break when
importing the distillation package as a whole. Convert these to package-relative
imports by prepending a dot (.) before each module name (from
.block_config_utils import ..., from .layer_patchers import ..., from
.provider_patch import ...). This ensures the imports are resolved relative to
the current package and follows the repository __init__.py API pattern using
relative star re-exports.

In `@examples/puzzletron/distillation/block_config_utils.py`:
- Line 177: The parameter moe_shared_expert_intermediate_omitted is defined and
documented in the docstring but never used in the function body. Either
implement the MoE handling logic described in the docstring that should use this
flag to control whether shared experts are cleared when the key was omitted from
JSON block_configs, or remove the parameter entirely if the functionality is not
needed. If implementing, add the conditional logic in the MoE handling section
that checks this flag and clears shared experts accordingly.
- Around line 130-137: The condition checking if converter_cls is None will
never be true because ConverterFactory.get() returns the input string itself
when the converter is not found, not None. Replace the `if converter_cls is
None:` check with a condition that verifies converter_cls is actually a class or
callable (not a string), such as checking if it is not a string type or if it is
callable. This will properly catch cases where an invalid converter_name is
provided and prevent the AttributeError that would occur when trying to call
create_block_configs_from_main_config() on a string object.

In `@examples/puzzletron/distillation/export_to_hf.py`:
- Around line 209-211: Update the help text string that describes the student
model key argument to reference the correct flag name. Change the mention of
`--student-checkpoint` to `--student-hf-checkpoint` in the help text to match
the actual CLI argument name, ensuring the help documentation is accurate and
not misleading to users.
- Around line 155-166: The copy loop that iterates over file names and uses
shutil.copy() will crash when student_path is a remote HF repo id or when
optional files like chat_template.jinja are missing. Before attempting to copy
each file in the loop where src is constructed as Path(student_path) / fname,
check if the source file exists using Path.exists(). For optional configuration
files, skip the copy operation gracefully if the source file doesn't exist
instead of raising FileNotFoundError. Ensure the success message is only printed
when a file was actually copied.

In `@examples/puzzletron/distillation/provider_patch.py`:
- Line 210: Replace all implicit relative imports with explicit relative imports
using dot notation to ensure the modules can be imported correctly regardless of
execution context. In examples/puzzletron/distillation/provider_patch.py at
lines 210 and 329, change `from layer_patchers import mbridge_patcher` to `from
.layer_patchers import mbridge_patcher`. In
examples/puzzletron/distillation/_common.py at lines 105-107, change `from
block_config_utils import load_block_configs` to `from .block_config_utils
import load_block_configs`. In examples/puzzletron/distillation/distill.py at
lines 467-473, change all import statements like `from model_bridge_patch import
...` and `from provider_patch import ...` to use explicit relative imports by
adding a leading dot prefix (e.g., `from .model_bridge_patch import ...` and
`from .provider_patch import ...`).

In `@examples/puzzletron/distillation/README.md`:
- Around line 169-186: The opening code fence in the README.md file does not
include a language identifier. Add a language specifier (such as `text`) to the
opening triple backticks of the fenced code block containing the directory tree
structure to ensure markdown linting compliance and maintain documentation
tooling standards.
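
For the `export_to_hf.py` copy loop mentioned above, a guarded copy might look like the following sketch. It is not the PR's code: `copy_existing_files` is a hypothetical helper, `tokenizer.json` is an assumed example file name, and the remote repo id is made up. The point is only that a missing source directory or a missing optional file is skipped instead of raising `FileNotFoundError`, and that the caller learns which files were actually copied.

```python
import shutil
import tempfile
from pathlib import Path


def copy_existing_files(src_dir, dst_dir, file_names):
    """Copy only the files that exist locally; return the names copied."""
    src_dir, dst_dir = Path(src_dir), Path(dst_dir)
    copied = []
    if not src_dir.is_dir():
        # For example a remote repo id rather than a local checkpoint directory.
        return copied
    dst_dir.mkdir(parents=True, exist_ok=True)
    for name in file_names:
        src = src_dir / name
        if not src.is_file():
            continue  # optional file is absent, skip it quietly
        shutil.copy(src, dst_dir / name)
        copied.append(name)
    return copied


with tempfile.TemporaryDirectory() as tmp:
    student = Path(tmp) / "student"
    student.mkdir()
    (student / "tokenizer.json").write_text("{}")
    export_dir = Path(tmp) / "export"
    wanted = ["tokenizer.json", "chat_template.jinja"]
    copied = copy_existing_files(student, export_dir, wanted)
    from_remote = copy_existing_files("org/some-remote-repo", export_dir, wanted)

print(copied)       # ['tokenizer.json']
print(from_remote)  # []
```

Returning the list of copied names lets the caller print a success message only for files that were really written.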


ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 18de3e9f-6e82-4675-b39f-0cf70e941dfe

📥 Commits

Reviewing files that changed from the base of the PR and between 9f6e8fd and 48cb89a.

📒 Files selected for processing (13)
  • examples/puzzletron/README.md
  • examples/puzzletron/distillation/README.md
  • examples/puzzletron/distillation/__init__.py
  • examples/puzzletron/distillation/_common.py
  • examples/puzzletron/distillation/block_config_utils.py
  • examples/puzzletron/distillation/distill.py
  • examples/puzzletron/distillation/export_to_hf.py
  • examples/puzzletron/distillation/gpt_oss_bridge.py
  • examples/puzzletron/distillation/kd-container-default.yaml
  • examples/puzzletron/distillation/layer_patchers.py
  • examples/puzzletron/distillation/model_bridge_patch.py
  • examples/puzzletron/distillation/provider_patch.py
  • pyproject.toml

Comment on lines +36 to +55
from block_config_utils import (
MCoreLayerOverrides,
block_config_to_mcore_overrides,
get_overrides_for_layer,
load_block_configs,
)
from layer_patchers import NoOpWithBias, mbridge_patcher
from provider_patch import apply_patch, remove_patch, set_provider_block_configs

__all__ = [
"MCoreLayerOverrides",
"NoOpWithBias",
"apply_patch",
"block_config_to_mcore_overrides",
"get_overrides_for_layer",
"load_block_configs",
"mbridge_patcher",
"remove_patch",
"set_provider_block_configs",
]


⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Use package-relative star re-exports in __init__.py to avoid import breakage.

Lines 36-43 use absolute imports, which can break when importing examples.puzzletron.distillation as a package. This also misses the repository __init__.py API pattern.

Proposed fix
+__all__ = [
+    "MCoreLayerOverrides",
+    "NoOpWithBias",
+    "apply_patch",
+    "block_config_to_mcore_overrides",
+    "get_overrides_for_layer",
+    "load_block_configs",
+    "mbridge_patcher",
+    "remove_patch",
+    "set_provider_block_configs",
+]
+
-from block_config_utils import (
-    MCoreLayerOverrides,
-    block_config_to_mcore_overrides,
-    get_overrides_for_layer,
-    load_block_configs,
-)
-from layer_patchers import NoOpWithBias, mbridge_patcher
-from provider_patch import apply_patch, remove_patch, set_provider_block_configs
-
-__all__ = [
-    "MCoreLayerOverrides",
-    "NoOpWithBias",
-    "apply_patch",
-    "block_config_to_mcore_overrides",
-    "get_overrides_for_layer",
-    "load_block_configs",
-    "mbridge_patcher",
-    "remove_patch",
-    "set_provider_block_configs",
-]
+from .block_config_utils import *  # re-export package API
+from .layer_patchers import *  # re-export package API
+from .provider_patch import *  # re-export package API

As per coding guidelines, package __init__.py files should define __all__ and re-export using from .module import *.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@examples/puzzletron/distillation/__init__.py` around lines 36 - 55, The
import statements at lines 36-43 use absolute imports (block_config_utils,
layer_patchers, provider_patch) which can break when importing the distillation
package as a whole. Convert these to package-relative imports by prepending a
dot (.) before each module name (from .block_config_utils import ..., from
.layer_patchers import ..., from .provider_patch import ...). This ensures the
imports are resolved relative to the current package and follows the repository
__init__.py API pattern using relative star re-exports.

Source: Coding guidelines
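
To see how relative star re-exports behave, the following self-contained sketch builds a throwaway package on disk and imports it. The package and module names (`demo_distillation`, `block_utils`) are hypothetical and unrelated to the PR's files. It illustrates two points: the leading dot resolves the submodule relative to the package, and the submodule's `__all__` decides what `import *` re-exports.

```python
import importlib
import sys
import tempfile
from pathlib import Path


def build_demo_package(root: Path) -> None:
    """Write a tiny package whose __init__.py re-exports via a relative star import."""
    pkg = root / "demo_distillation"
    pkg.mkdir()
    (pkg / "block_utils.py").write_text(
        '__all__ = ["load_block_configs"]\n\n'
        "def load_block_configs():\n"
        '    return ["layer0"]\n\n'
        "def _private_helper():\n"
        "    return None\n"
    )
    (pkg / "__init__.py").write_text(
        "from .block_utils import *  # re-export package API\n"
    )


with tempfile.TemporaryDirectory() as tmp:
    build_demo_package(Path(tmp))
    sys.path.insert(0, tmp)
    try:
        demo = importlib.import_module("demo_distillation")
        exported = demo.load_block_configs()
        has_private = hasattr(demo, "_private_helper")
    finally:
        sys.path.remove(tmp)

print(exported)     # ['layer0']
print(has_private)  # False
```

A top-level `from block_utils import *` in the same `__init__.py` would only work if the package directory itself happened to be on `sys.path`, which is the breakage the comment describes.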

Comment on lines +130 to +137
converter_cls = ConverterFactory.get(converter_name)
if converter_cls is None:
logger.warning(
"Unknown AnyModel converter '%s'. Cannot generate block_configs. "
"Valid converters can be listed with ConverterFactory.list().",
converter_name,
)
return None


⚠️ Potential issue | 🟠 Major | ⚡ Quick win

ConverterFactory.get() returns the input string when not found, not None.

Per context snippet 2, ConverterFactory.get(value) returns value itself (the string) if value is a string not in CLASS_MAPPING. This means converter_cls is None will never be True for a string input, and the code will proceed to line 139 where converter_cls.create_block_configs_from_main_config(...) will fail with an AttributeError on a string.

Suggested fix
     converter_cls = ConverterFactory.get(converter_name)
-    if converter_cls is None:
+    if converter_cls is None or converter_cls is converter_name:
         logger.warning(
             "Unknown AnyModel converter '%s'. Cannot generate block_configs. "
             "Valid converters can be listed with ConverterFactory.list().",
             converter_name,
         )
         return None

Alternatively, check if converter_cls is a class/callable:

     converter_cls = ConverterFactory.get(converter_name)
-    if converter_cls is None:
+    if not callable(converter_cls) or isinstance(converter_cls, str):
         logger.warning(
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@examples/puzzletron/distillation/block_config_utils.py` around lines 130 -
137, The condition checking if converter_cls is None will never be true because
ConverterFactory.get() returns the input string itself when the converter is not
found, not None. Replace the `if converter_cls is None:` check with a condition
that verifies converter_cls is actually a class or callable (not a string), such
as checking if it is not a string type or if it is callable. This will properly
catch cases where an invalid converter_name is provided and prevent the
AttributeError that would occur when trying to call
create_block_configs_from_main_config() on a string object.
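
A minimal sketch of the guard described above, assuming the lookup behaves as this comment reports (a miss returns the input string unchanged). `StandInFactory`, `LlamaConverter`, and `resolve_converter` are hypothetical names written for this illustration; they are not the real `ConverterFactory` API.

```python
import inspect
import logging

logger = logging.getLogger(__name__)


class LlamaConverter:
    """Hypothetical converter class used only for this sketch."""


class StandInFactory:
    """Hypothetical stand-in mimicking the behavior described in the comment:
    a lookup miss returns the input string unchanged instead of None."""

    CLASS_MAPPING = {"llama": LlamaConverter}

    @classmethod
    def get(cls, value):
        if isinstance(value, str):
            return cls.CLASS_MAPPING.get(value, value)
        return value


def resolve_converter(name):
    """Return the converter class, or None when the name is not registered."""
    converter_cls = StandInFactory.get(name)
    # A miss comes back as the original string, so test for "is a class"
    # rather than "is None".
    if not inspect.isclass(converter_cls):
        logger.warning("Unknown converter '%s'.", name)
        return None
    return converter_cls


known = resolve_converter("llama")
unknown = resolve_converter("not-registered")
```

Checking for a class rather than comparing against the input string also covers the case where the factory is handed a non-string value that is not a converter.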

num_attention_heads: int,
hidden_size: int,
*,
moe_shared_expert_intermediate_omitted: bool = False,

Contributor

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Parameter moe_shared_expert_intermediate_omitted is documented but never used.

The docstring explains this parameter should cause Megatron to clear shared experts when the key was omitted from JSON block_configs, but there's no code in the function body that uses this flag. This appears to be an incomplete implementation.

If the behavior described in the docstring is intended, the MoE handling section should use this flag:

Suggested implementation
                     _set_if_not_none(
                         config_overrides,
                         "moe_shared_expert_intermediate_size",
                         getattr(moe_cfg, "shared_expert_intermediate_dim", None),
                     )
+                    # If the key was omitted from JSON, we must explicitly disable shared experts
+                    # to avoid MoEConfig defaulting to 8192.
+                    if moe_shared_expert_intermediate_omitted:
+                        config_overrides.setdefault("moe_shared_expert_intermediate_size", None)

Or remove the parameter if not needed.

Also applies to: 198-211

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@examples/puzzletron/distillation/block_config_utils.py` at line 177, The
parameter moe_shared_expert_intermediate_omitted is defined and documented in
the docstring but never used in the function body. Either implement the MoE
handling logic described in the docstring that should use this flag to control
whether shared experts are cleared when the key was omitted from JSON
block_configs, or remove the parameter entirely if the functionality is not
needed. If implementing, add the conditional logic in the MoE handling section
that checks this flag and clears shared experts accordingly.
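
The suggested handling can be reduced to a small sketch of the one key under discussion. It assumes `_set_if_not_none` writes a key only when the value is not `None` (the helper's body is not shown in this thread, so the version below is an assumption), and `build_moe_overrides` is a hypothetical function written for this illustration.

```python
def _set_if_not_none(overrides, key, value):
    # Assumed behavior of the helper named in the suggestion:
    # write the key only when a concrete value is available.
    if value is not None:
        overrides[key] = value


def build_moe_overrides(shared_expert_dim, *, moe_shared_expert_intermediate_omitted=False):
    """Hypothetical reduction of the MoE branch to a single override key."""
    config_overrides = {}
    _set_if_not_none(
        config_overrides,
        "moe_shared_expert_intermediate_size",
        shared_expert_dim,
    )
    if moe_shared_expert_intermediate_omitted:
        # Record an explicit None so a downstream default cannot fill the gap.
        # setdefault leaves any concrete value written above untouched.
        config_overrides.setdefault("moe_shared_expert_intermediate_size", None)
    return config_overrides


with_value = build_moe_overrides(4096, moe_shared_expert_intermediate_omitted=True)
omitted = build_moe_overrides(None, moe_shared_expert_intermediate_omitted=True)
flag_off = build_moe_overrides(None)
```

The difference between `omitted` and `flag_off` is the point of the flag: one carries an explicit `None`, the other leaves the key absent and so open to whatever default the consumer applies.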

Comment on lines +155 to +166
print_rank_0(f"  Copying configs from: {student_path}")
for fname in [
    "config.json",
    "tokenizer.json",
    "tokenizer_config.json",
    "special_tokens_map.json",
    "chat_template.jinja",
]:
    src = Path(student_path) / fname
    dst = Path(args.output_hf_checkpoint) / fname
    shutil.copy(src, dst)
    print_rank_0(f"  ✅ Copied {fname}")

Contributor

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Handle non-local and missing source files during artifact copy.

This block will raise FileNotFoundError in documented flows: when --student-hf-checkpoint is omitted (so student_path is an HF repo id, not a local dir) or when optional files (e.g., chat_template.jinja) are absent. That turns a successful export into a hard failure.

💡 Proposed fix
-    print_rank_0(f"  Copying configs from: {student_path}")
-    for fname in [
+    print_rank_0(f"  Copying configs from: {student_path}")
+    source_dir = Path(student_path)
+    if not source_dir.is_dir():
+        print_rank_0(
+            "  ⚠️ Skipping tokenizer/config copy because --student-hf-checkpoint "
+            "is not a local directory."
+        )
+    else:
+        for fname in [
         "config.json",
         "tokenizer.json",
         "tokenizer_config.json",
         "special_tokens_map.json",
         "chat_template.jinja",
-    ]:
-        src = Path(student_path) / fname
-        dst = Path(args.output_hf_checkpoint) / fname
-        shutil.copy(src, dst)
-        print_rank_0(f"  ✅ Copied {fname}")
+        ]:
+            src = source_dir / fname
+            if not src.exists():
+                print_rank_0(f"  ⚠️ Skipping missing {fname}")
+                continue
+            dst = Path(args.output_hf_checkpoint) / fname
+            shutil.copy(src, dst)
+            print_rank_0(f"  ✅ Copied {fname}")
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@examples/puzzletron/distillation/export_to_hf.py` around lines 155 - 166, The
copy loop that iterates over file names and uses shutil.copy() will crash when
student_path is a remote HF repo id or when optional files like
chat_template.jinja are missing. Before attempting to copy each file in the loop
where src is constructed as Path(student_path) / fname, check if the source file
exists using Path.exists(). For optional configuration files, skip the copy
operation gracefully if the source file doesn't exist instead of raising
FileNotFoundError. Ensure the success message is only printed when a file was
actually copied.
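
The tolerant copy loop described above can be sketched as a standalone helper. `copy_existing_files` is a hypothetical function written for this illustration, and the repo id `some-org/some-model` is a placeholder; the file list is the one quoted in the comment.

```python
import shutil
import tempfile
from pathlib import Path

CONFIG_FILES = [
    "config.json",
    "tokenizer.json",
    "tokenizer_config.json",
    "special_tokens_map.json",
    "chat_template.jinja",
]


def copy_existing_files(source, destination, filenames):
    """Copy the files that exist under a local source dir; return the names copied."""
    source_dir = Path(source)
    dest_dir = Path(destination)
    copied = []
    if not source_dir.is_dir():
        # For example, a Hugging Face repo id rather than a local path.
        return copied
    dest_dir.mkdir(parents=True, exist_ok=True)
    for fname in filenames:
        src = source_dir / fname
        if not src.is_file():
            # Optional file is absent: skip instead of raising.
            continue
        shutil.copy(src, dest_dir / fname)
        copied.append(fname)
    return copied


# Usage: a local directory holding only two of the five files.
workdir = Path(tempfile.mkdtemp())
student_dir = workdir / "student"
student_dir.mkdir()
(student_dir / "config.json").write_text("{}")
(student_dir / "tokenizer.json").write_text("{}")

copied_local = copy_existing_files(student_dir, workdir / "out", CONFIG_FILES)
copied_remote = copy_existing_files("some-org/some-model", workdir / "out2", CONFIG_FILES)
```

Returning the list of copied names lets the caller print a success message only for files that were actually written.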

Comment on lines +209 to +211
"Student model key. Determines the HuggingFace model ID (used when "
"--student-checkpoint is omitted) and the AnyModel converter for block_configs."
),

Contributor

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Fix stale flag name in CLI help text.

Line 210 references --student-checkpoint, but the actual argument is --student-hf-checkpoint. This is user-facing and misleading in --help.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@examples/puzzletron/distillation/export_to_hf.py` around lines 209 - 211,
Update the help text string that describes the student model key argument to
reference the correct flag name. Change the mention of `--student-checkpoint` to
`--student-hf-checkpoint` in the help text to match the actual CLI argument
name, ensuring the help documentation is accurate and not misleading to users.
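
One way to keep help text from drifting away from the real flag name is to define the flag once and interpolate it. The sketch below assumes only the two flag names mentioned in this thread; the constant `STUDENT_CKPT_FLAG`, the `default`, and the second help string are illustrative assumptions, not the script's actual parser.

```python
import argparse

# Single source of truth for the flag name (hypothetical constant).
STUDENT_CKPT_FLAG = "--student-hf-checkpoint"

STUDENT_HELP = (
    "Student model key. Determines the HuggingFace model ID (used when "
    f"{STUDENT_CKPT_FLAG} is omitted) and the AnyModel converter for block_configs."
)


def build_parser():
    parser = argparse.ArgumentParser(prog="export_to_hf.py")
    parser.add_argument("--student", help=STUDENT_HELP)
    # The default and help text below are assumptions for this sketch.
    parser.add_argument(STUDENT_CKPT_FLAG, default=None, help="Local HF checkpoint directory.")
    return parser


args = build_parser().parse_args(["--student", "llama"])
```

Renaming the flag then means editing one constant, and the `--help` output follows automatically.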

2. Detects whether this is a Mamba/hybrid provider (to set ``apply_no_ops``).
3. Activates ``mbridge_patcher`` and delegates to the original ``provide()``.
"""
from layer_patchers import mbridge_patcher

Contributor

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Bare sibling-module imports may fail depending on execution context.

All three files import sibling modules by bare name (e.g., from layer_patchers import ... instead of from .layer_patchers import ...). In Python 3 these are resolved as absolute imports, so they work only when the script directory is on sys.path and fail if the modules are imported as part of a package from a different working directory.

  • examples/puzzletron/distillation/provider_patch.py#L210-L210: Change from layer_patchers import mbridge_patcher to from .layer_patchers import mbridge_patcher (also at line 329).
  • examples/puzzletron/distillation/_common.py#L105-L107: Change from block_config_utils import load_block_configs to from .block_config_utils import load_block_configs.
  • examples/puzzletron/distillation/distill.py#L467-L473: Change from model_bridge_patch import ... and from provider_patch import ... to use explicit relative imports with the dot prefix.
📍 Affects 3 files
  • examples/puzzletron/distillation/provider_patch.py#L210-L210 (this comment)
  • examples/puzzletron/distillation/_common.py#L105-L107
  • examples/puzzletron/distillation/distill.py#L467-L473
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@examples/puzzletron/distillation/provider_patch.py` at line 210, Replace all
implicit relative imports with explicit relative imports using dot notation to
ensure the modules can be imported correctly regardless of execution context. In
examples/puzzletron/distillation/provider_patch.py at lines 210 and 329, change
`from layer_patchers import mbridge_patcher` to `from .layer_patchers import
mbridge_patcher`. In examples/puzzletron/distillation/_common.py at lines
105-107, change `from block_config_utils import load_block_configs` to `from
.block_config_utils import load_block_configs`. In
examples/puzzletron/distillation/distill.py at lines 467-473, change all import
statements like `from model_bridge_patch import ...` and `from provider_patch
import ...` to use explicit relative imports by adding a leading dot prefix
(e.g., `from .model_bridge_patch import ...` and `from .provider_patch import
...`).
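
The failure mode can be reproduced in isolation. This sketch builds a throwaway package in a temporary directory; `demo_distill`, `demo_layer_patchers`, and `patcher` are hypothetical names and do not correspond to the PR's modules.

```python
import importlib
import sys
import tempfile
from pathlib import Path

root = Path(tempfile.mkdtemp())
pkg = root / "demo_distill"
pkg.mkdir()
(pkg / "__init__.py").write_text("")
(pkg / "demo_layer_patchers.py").write_text("def patcher():\n    return 'patched'\n")
# Bare import: resolves only if the package directory itself is on sys.path.
(pkg / "bare.py").write_text("from demo_layer_patchers import patcher\n")
# Explicit relative import: resolves relative to the containing package.
(pkg / "relative.py").write_text("from .demo_layer_patchers import patcher\n")

# Only the parent directory is on sys.path, as when the package is
# imported from another working directory.
sys.path.insert(0, str(root))

relative_mod = importlib.import_module("demo_distill.relative")
relative_result = relative_mod.patcher()

try:
    importlib.import_module("demo_distill.bare")
    bare_import_failed = False
except ModuleNotFoundError:
    bare_import_failed = True
```

One caveat when applying the change: a file executed directly as a script has no parent package, so explicit relative imports inside an entrypoint such as distill.py would only work if it is launched as a module of the package rather than by file path.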

Comment on lines +169 to +186
```
examples/puzzletron/distillation/
├── README.md # this file
├── distill.py # KD entrypoint (HF -> Bridge -> distill loop)
├── export_to_hf.py # MCore -> HF checkpoint export
├── _common.py # MODEL_REGISTRY + shared HF/Bridge helpers
├── block_config_utils.py # Per-layer block_configs loader & translation
├── layer_patchers.py # mbridge_patcher: per-layer MCore overrides
├── provider_patch.py # ModelProviderMixin / DistillationProvider patches
├── model_bridge_patch.py # Misc Bridge model-class patches
├── gpt_oss_bridge.py # Patched Megatron-Bridge GPT-OSS bridge (overlay)
├── kd-container-default.yaml # Default ConfigContainer overrides
├── kd-container-{llama,nemotron3,qwen}.yaml # Per-model recipes
├── kd-dummy.yaml # Smoke-test config
├── data_prep_{llama,nemotron3,qwen3}.ipynb # Dataset tokenization notebooks
├── interactive.sh # Reference srun + torchrun command snippets
└── run_validate.py # Standalone validation-loss runner
```

Contributor

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Add a language identifier to the fenced code block.

Line 169 opens a code fence without a language specifier (```). Please tag it (for example text) to satisfy markdown linting and keep docs tooling clean.

🧰 Tools
🪛 markdownlint-cli2 (0.22.1)

[warning] 169-169: Fenced code blocks should have a language specified

(MD040, fenced-code-language)

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@examples/puzzletron/distillation/README.md` around lines 169 - 186, The
opening code fence in the README.md file does not include a language identifier.
Add a language specifier (such as `text`) to the opening triple backticks of the
fenced code block containing the directory tree structure to ensure markdown
linting compliance and maintain documentation tooling standards.

Source: Linters/SAST tools

Collaborator

We have already been using https://github.com/NVIDIA/Model-Optimizer/blob/main/examples/megatron_bridge/distill.py for Puzzletron AnyModel distillation. Why do we now need all this extra patching logic?

Development

Successfully merging this pull request may close these issues.

Add MBridge dsitillation for puzzletron

2 participants