
DPO integration test - logits comparison against PyTorch #4026

Open: igorts-git wants to merge 1 commit into main from igorts/dpo-logits-test

igorts-git commented May 30, 2026

Collaborator

Description

Add a logits comparison test that checks MaxText DPO against the Hugging Face TRL implementation.

The new test consists of two parts:

  1. generate_dpo_golden_data_and_compare_pytorch_logits.py - compares DPO in MaxText against HF TRL and saves the results to a golden-data JSON file.
  2. A unit test that runs only MaxText DPO and compares its results against the golden data in the JSON files.

This separation keeps TRL out of the CI dependencies.
The tests run on CPU to reduce the floating-point differences that arise on GPU/TPU. We use a tiny 2-layer Qwen2-like model so that the unit test runs fast.
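
As a rough illustration of the two-part split described above, the sketch below writes reference metrics to a JSON file in one step and checks freshly computed metrics against that file in another. The function names (`save_golden`, `assert_matches_golden`), the metric keys, the file name, and the tolerance values are hypothetical placeholders, not the code in this PR.

```python
import json
import math
import os
import tempfile


def save_golden(path, metrics):
  """Generator side: persist reference metrics (hypothetical schema)."""
  with open(path, "w", encoding="utf-8") as f:
    json.dump(metrics, f, indent=2, sort_keys=True)


def assert_matches_golden(path, actual, rtol=1e-4, atol=1e-5):
  """Test side: compare metrics against the stored file; no TRL import needed."""
  with open(path, encoding="utf-8") as f:
    golden = json.load(f)
  for name, expected in golden.items():
    if not math.isclose(actual[name], expected, rel_tol=rtol, abs_tol=atol):
      raise AssertionError(f"{name}: got {actual[name]}, expected {expected}")


# Illustrative values only; real numbers would come from the reference run.
golden_path = os.path.join(tempfile.mkdtemp(), "dpo_golden.json")
save_golden(golden_path, {"loss": 0.6931, "reward_margin": 0.0})
assert_matches_golden(golden_path, {"loss": 0.69312, "reward_margin": 0.0})
```

Only the generator step would need the reference framework installed; the comparison step reads plain JSON, which is what lets the CI test avoid that dependency.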

A separate PR will add a similar test for ORPO. The Tunix implementation first needs a small adjustment to match the canonical implementation (see google/tunix#1574).

Tests

Ran the test manually.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

igorts-git force-pushed the igorts/dpo-logits-test branch from 0e12528 to 4d24b21 on May 30, 2026 07:09
codecov Bot commented May 30, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.


igorts-git force-pushed the igorts/dpo-logits-test branch 2 times, most recently from 269b98f to f66353b, on June 2, 2026 20:26
igorts-git force-pushed the igorts/dpo-logits-test branch 10 times, most recently from 5a31620 to d7f0162, on June 11, 2026 22:14
igorts-git changed the title from "Dpo integration test - logits comparision against Pytorch" to "DPO integration test - logits comparision against Pytorch" on Jun 11, 2026
igorts-git marked this pull request as ready for review on June 11, 2026 22:35
@github-actions

🤖 Hi @igorts-git, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

@github-actions

🤖 I'm sorry @igorts-git, but I was unable to process your request. Please see the logs for more details.

@github-actions

🤖 Hi @igorts-git, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

github-actions Bot left a comment

📋 Review Summary

This PR introduces a robust integration test verifying JAX DPO correctness against Hugging Face's canonical TRL implementation. By running a structurally identical tiny architecture and saving golden results, the test achieves high verification confidence while maintaining zero PyTorch/TRL dependencies in CI. The implementation quality is extremely high, with clear documentation and strict adherence to project standards.

🔍 General Feedback

  • Symmetric Architecture Design: The miniaturized 2-layer model approach running entirely on CPU is excellent for avoiding GPU/TPU non-determinism while remaining extremely fast.
  • Well-documented Tolerances: The rigorous calibration of DPO loss and log probability tolerances across multiple random seeds is an exceptional testing practice and is highly commendable.
  • Robust Environment Handling: Clean setup and restoration of JAX environment configurations and training hooks prevent state leakage to neighboring tests.
  • Suggested Improvements: A few minor optimizations around dynamic temporary directories, offline/hermetic tokenizer resolution, and explicit class-level attribute mutations have been suggested for maximum reliability.

Comment thread tests/post_training/integration/dpo_correctness_base.py Outdated
Comment thread tests/assets/logits_generation/dpo_pytorch_helpers.py Outdated
Comment thread tests/post_training/integration/dpo_correctness_base.py Outdated
igorts-git force-pushed the igorts/dpo-logits-test branch 2 times, most recently from 15384b6 to 933ee07, on June 12, 2026 20:41
This commit introduces a JAX DPO correctness integration test that validates JAX DPO training step metrics (loss, margin, chosen/rejected logprobs) against stored golden outputs. It also includes the CPU-only JAX/PyTorch parallel validation script used to verify parity against Hugging Face TRL and generate the remote-canonical golden assets.
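
For readers unfamiliar with the metrics named in the commit message, here is a minimal sketch of the standard sigmoid DPO loss for a single preference pair, written in plain Python. It assumes "margin" means the difference between the chosen and rejected implicit rewards; the function name `dpo_metrics`, the returned key names, and the `beta=0.1` default are illustrative assumptions and may not match what the test in this PR records.

```python
import math


def dpo_metrics(policy_chosen_logp, policy_rejected_logp,
                ref_chosen_logp, ref_rejected_logp, beta=0.1):
  """Sigmoid DPO loss and reward margin for one preference pair.

  Inputs are sequence-level log-probabilities (sums over response tokens).
  The returned key names are illustrative, not the PR's metric names.
  """
  chosen_reward = beta * (policy_chosen_logp - ref_chosen_logp)
  rejected_reward = beta * (policy_rejected_logp - ref_rejected_logp)
  margin = chosen_reward - rejected_reward
  # -log(sigmoid(margin)) == softplus(-margin), computed in a stable form.
  loss = max(-margin, 0.0) + math.log1p(math.exp(-abs(margin)))
  return {"loss": loss, "margin": margin,
          "chosen_logp": policy_chosen_logp,
          "rejected_logp": policy_rejected_logp}


# When the policy equals the reference, the margin is 0 and the loss is log(2).
same = dpo_metrics(-12.0, -15.0, -12.0, -15.0)
# Raising the chosen log-prob relative to the reference lowers the loss.
better = dpo_metrics(-10.0, -15.0, -12.0, -15.0)
```

Because every quantity here is a deterministic function of four log-probabilities, small numerical differences between frameworks show up directly in these scalars, which is why they are convenient values to store as golden data.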
@github-actions

🤖 Hi @igorts-git, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

@github-actions

🤖 I'm sorry @igorts-git, but I was unable to process your request. Please see the logs for more details.
