DPO integration test - logits comparison against PyTorch #4026
igorts-git wants to merge 1 commit into
Codecov Report: ✅ All modified and coverable lines are covered by tests.
This PR introduces a robust integration test verifying JAX DPO correctness against Hugging Face's canonical TRL implementation. By running a structurally identical tiny architecture and saving golden results, the test achieves high verification confidence while maintaining zero PyTorch/TRL dependencies in CI. The implementation quality is extremely high, with clear documentation and strict adherence to project standards.
🔍 General Feedback
- Symmetric Architecture Design: The miniaturized 2-layer model approach running entirely on CPU is excellent for avoiding GPU/TPU non-determinism while remaining extremely fast.
- Well-documented Tolerances: The rigorous calibration of DPO loss and log probability tolerances across multiple random seeds is an exceptional testing practice and is highly commendable.
- Robust Environment Handling: Clean setup and restoration of JAX environment configurations and training hooks prevent state leakage to neighboring tests.
- Suggested Improvements: A few minor optimizations around dynamic temporary directories, offline/hermetic tokenizer resolution, and explicit class-level attribute mutations have been suggested for maximum reliability.
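The environment-handling and temporary-directory points above could be sketched roughly as follows. This is only an illustration, not the code in this PR: `temporary_env` is a hypothetical helper, and `JAX_PLATFORMS=cpu` is assumed to be the setting a CPU-only test would pin.

```python
import contextlib
import os
import tempfile


@contextlib.contextmanager
def temporary_env(**overrides):
    """Set environment variables for the duration of a block, then restore them.

    Hypothetical helper: variables that did not exist before are removed again,
    and variables that did exist get their previous value back.
    """
    saved = {name: os.environ.get(name) for name in overrides}
    os.environ.update(overrides)
    try:
        yield
    finally:
        for name, old_value in saved.items():
            if old_value is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = old_value


# Assumed usage: pin JAX to CPU and write outputs to a per-run temp directory
# instead of a fixed path, so neighboring tests see no leftover state.
with temporary_env(JAX_PLATFORMS="cpu"), tempfile.TemporaryDirectory() as out_dir:
    scratch_file = os.path.join(out_dir, "metrics.json")
    with open(scratch_file, "w", encoding="utf-8") as f:
        f.write("{}")
```

Restoring in a `finally` block means the previous values come back even when the test body raises, which is the property the review calls out as preventing state leakage.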
This commit introduces a JAX DPO correctness integration test that validates JAX DPO training step metrics (loss, margin, chosen/rejected logprobs) against stored golden outputs. It also includes the CPU-only JAX/PyTorch parallel validation script used to verify parity against Hugging Face TRL and generate the remote-canonical golden assets.
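For readers unfamiliar with the metrics named above, here is a minimal sketch of how the standard DPO loss and margin follow from per-sequence log probabilities. It uses the textbook DPO formula in plain Python; the function name, the dictionary keys, and the default `beta=0.1` are assumptions for illustration and are not taken from MaxText or TRL.

```python
import math


def log_sigmoid(x: float) -> float:
    """Numerically stable log(sigmoid(x)) = min(x, 0) - log1p(exp(-|x|))."""
    return min(x, 0.0) - math.log1p(math.exp(-abs(x)))


def dpo_metrics(policy_chosen, policy_rejected, ref_chosen, ref_rejected, beta=0.1):
    """Compute DPO metrics for one preference pair.

    Inputs are summed log probabilities of the chosen/rejected responses under
    the policy and the frozen reference model. Names are illustrative only.
    """
    # Implicit rewards: scaled log-ratio of policy to reference.
    chosen_reward = beta * (policy_chosen - ref_chosen)
    rejected_reward = beta * (policy_rejected - ref_rejected)
    margin = chosen_reward - rejected_reward
    # Standard DPO objective: -log sigmoid(margin).
    loss = -log_sigmoid(margin)
    return {
        "loss": loss,
        "margin": margin,
        "chosen_logps": policy_chosen,
        "rejected_logps": policy_rejected,
    }
```

When the policy equals the reference model, the margin is zero and the loss is log 2, which makes a convenient sanity check for a freshly initialized run.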
Description
Add a logits comparison test for MaxText DPO vs HuggingFace TRL implementation.
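The new test consists of two parts:
- generate_dpo_golden_data_and_compare_pytorch_logits.py: compares DPO in MaxText vs HF TRL and saves a golden data JSON file.
- The integration test that runs in CI and validates the JAX DPO training step metrics (loss, margin, chosen/rejected logprobs) against the stored golden data.
This separation lets CI run without a dependency on TRL.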
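The golden-data pattern described here can be sketched as below. This is a hedged illustration only: the JSON layout, the metric names, the tolerance values, and the helpers `save_golden` and `check_against_golden` are all hypothetical and are not the ones used in this PR.

```python
import json
import math
import os
import tempfile

# Hypothetical per-metric absolute tolerances; real values would be calibrated.
TOLERANCES = {"loss": 1e-4, "margin": 1e-4}


def save_golden(path, metrics):
    """Offline step: write reference metrics to a JSON file."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(metrics, f, indent=2, sort_keys=True)


def check_against_golden(path, metrics, tolerances):
    """CI step: return a list of mismatch messages (empty means pass)."""
    with open(path, encoding="utf-8") as f:
        golden = json.load(f)
    mismatches = []
    for name, expected in golden.items():
        actual = metrics[name]
        if not math.isclose(actual, expected, rel_tol=0.0, abs_tol=tolerances[name]):
            mismatches.append(f"{name}: got {actual}, expected {expected}")
    return mismatches


# Round trip in a temporary directory: generate once, then compare.
with tempfile.TemporaryDirectory() as tmp_dir:
    golden_path = os.path.join(tmp_dir, "dpo_golden.json")
    save_golden(golden_path, {"loss": 0.6931, "margin": 0.0})
    problems = check_against_golden(
        golden_path, {"loss": 0.69312, "margin": 0.00001}, TOLERANCES
    )
```

Only the comparison half needs to run in CI, so the heavier reference framework is required just when the golden file is regenerated.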
The tests run on CPU to avoid GPU/TPU floating-point differences. We use a tiny 2-layer Qwen2-like model so that the test runs fast.
A separate PR will be sent out to add a similar test for ORPO. The Tunix implementation would first need to be adjusted a little bit to match the canonical implementation (see google/tunix#1574).
Tests
Ran the test manually.