fix(jax): stabilize repflow dynamic selection export#5533
Conversation
📝 WalkthroughWalkthroughAdds a backend-internal ChangesStatic dynamic selection for RepFlows
Estimated code review effort🎯 4 (Complex) | ⏱️ ~45 minutes Possibly related PRs
Suggested labels
Suggested reviewers
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
🧹 Nitpick comments (1)
source/tests/universal/dpmodel/descriptor/test_descriptor.py (1)
915-933: ⚡ Quick winCover the
use_loc_mapping=Falsestatic-index branch too.
_get_static_graph_index()changes its indexing stride whenuse_loc_mappingis disabled, but this helper always constructs the default mapped configuration. Extending this test to run both values would cover the second branch of the new static-dynamic layout logic that JAX now enables by default.♻️ Suggested test expansion
- def _make_dpa3(self, use_static_dynamic_sel: bool) -> DescrptDPA3: + def _make_dpa3( + self, + use_static_dynamic_sel: bool, + *, + use_loc_mapping: bool, + ) -> DescrptDPA3: # The switch is intentionally class-level and internal, so tests toggle # it only around construction and then restore the previous backend mode. old_use_static_dynamic_sel = DescrptBlockRepflows._use_static_dynamic_sel DescrptBlockRepflows._use_static_dynamic_sel = use_static_dynamic_sel try: return DescrptDPA3( **DescriptorParamDPA3( self.nt, self.rcut, self.rcut_smth, self.sel, ["O", "H"], smooth_edge_update=True, use_dynamic_sel=True, + use_loc_mapping=use_loc_mapping, ) ) finally: DescrptBlockRepflows._use_static_dynamic_sel = old_use_static_dynamic_sel def test_static_dynamic_sel_matches_packed_dynamic_sel(self) -> None: - packed = self._make_dpa3(False) - static = self._make_dpa3(True) + for use_loc_mapping in (True, False): + packed = self._make_dpa3(False, use_loc_mapping=use_loc_mapping) + static = self._make_dpa3(True, use_loc_mapping=use_loc_mapping) - packed_out = packed( - self.coord_ext, - self.atype_ext, - self.nlist, - mapping=self.mapping, - ) - static_out = static( - self.coord_ext, - self.atype_ext, - self.nlist, - mapping=self.mapping, - ) + packed_out = packed( + self.coord_ext, + self.atype_ext, + self.nlist, + mapping=self.mapping, + ) + static_out = static( + self.coord_ext, + self.atype_ext, + self.nlist, + mapping=self.mapping, + ) - np.testing.assert_allclose(packed_out[0], static_out[0], atol=self.atol) - np.testing.assert_allclose(packed_out[1], static_out[1], atol=self.atol) + np.testing.assert_allclose(packed_out[0], static_out[0], atol=self.atol) + np.testing.assert_allclose(packed_out[1], static_out[1], atol=self.atol) - valid_edge_mask = np.reshape(self.nlist != -1, (-1,)) - assert static_out[2].shape[0] == self.nf * self.nloc * sum(self.sel) - np.testing.assert_allclose( - packed_out[2], static_out[2][valid_edge_mask], atol=self.atol - ) - np.testing.assert_allclose( - packed_out[3], static_out[3][valid_edge_mask], atol=self.atol - ) - np.testing.assert_allclose( - packed_out[4], static_out[4][valid_edge_mask], atol=self.atol - ) + valid_edge_mask = np.reshape(self.nlist != -1, (-1,)) + assert static_out[2].shape[0] == self.nf * self.nloc * sum(self.sel) + np.testing.assert_allclose( + packed_out[2], static_out[2][valid_edge_mask], atol=self.atol + ) + np.testing.assert_allclose( + packed_out[3], static_out[3][valid_edge_mask], atol=self.atol + ) + np.testing.assert_allclose( + packed_out[4], static_out[4][valid_edge_mask], atol=self.atol + )Also applies to: 935-968
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@source/tests/universal/dpmodel/descriptor/test_descriptor.py` around lines 915 - 933, The _make_dpa3 helper method currently only constructs the default mapped configuration, but _get_static_graph_index() has different indexing behavior when use_loc_mapping is disabled. Extend the _make_dpa3 method to accept a parameter for use_loc_mapping (similar to how it accepts use_static_dynamic_sel) and ensure it applies this parameter when constructing DescriptorParamDPA3. Then update all test methods that use _make_dpa3 (including those at lines 935-968) to run test assertions with both use_loc_mapping=True and use_loc_mapping=False so that both branches of the static-dynamic layout logic are covered.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Nitpick comments:
In `@source/tests/universal/dpmodel/descriptor/test_descriptor.py`:
- Around line 915-933: The _make_dpa3 helper method currently only constructs
the default mapped configuration, but _get_static_graph_index() has different
indexing behavior when use_loc_mapping is disabled. Extend the _make_dpa3 method
to accept a parameter for use_loc_mapping (similar to how it accepts
use_static_dynamic_sel) and ensure it applies this parameter when constructing
DescriptorParamDPA3. Then update all test methods that use _make_dpa3 (including
those at lines 935-968) to run test assertions with both use_loc_mapping=True
and use_loc_mapping=False so that both branches of the static-dynamic layout
logic are covered.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: b5b5ca58-1ff0-4fbd-beea-b662619d0e7b
📒 Files selected for processing (4)
deepmd/dpmodel/descriptor/repflows.pydeepmd/dpmodel/utils/network.pydeepmd/jax/descriptor/repflows.pysource/tests/universal/dpmodel/descriptor/test_descriptor.py
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## master #5533 +/- ##
==========================================
- Coverage 82.18% 82.18% -0.01%
==========================================
Files 890 890
Lines 101357 101399 +42
Branches 4240 4240
==========================================
+ Hits 83301 83335 +34
- Misses 16754 16761 +7
- Partials 1302 1303 +1 ☔ View full report in Codecov by Harness. 🚀 New features to boost your workflow:
|
Summary
Validation
Summary by CodeRabbit
Chores
Tests