
fix(jax): stabilize repflow dynamic selection export #5533

Open
njzjz wants to merge 1 commit into deepmodeling:master from njzjz:fix/jax-repflow-static-dynamic-sel


@njzjz njzjz commented Jun 14, 2026

Member

Summary

  • add an internal fixed-capacity dynamic-selection layout for repflows so JAX/jax2tf export avoids runtime-sized edge/angle tensors
  • skip unnecessary bincount in sum-only aggregate calls with a known owner count
  • add regression coverage comparing compact and static dynamic selection outputs
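To make the first bullet concrete, here is a minimal NumPy sketch of the difference between a compact (boolean-mask) layout and a fixed-capacity layout. The toy `nlist`, the feature width of 4, and the choice to zero the padded slots are assumptions for illustration; none of them are taken from the PR's code. The point is only that the compact result has a data-dependent leading dimension, which cannot be traced to a static shape under `jax.jit` or jax2tf, while the fixed-capacity result always has `nf * nloc * nnei` rows.

```python
import numpy as np

# Toy neighbor list: 1 frame, 3 local atoms, capacity of 2 neighbors each.
# -1 marks an empty (padded) neighbor slot, as in the PR's test mask.
nlist = np.array([[[1, 2], [0, -1], [-1, -1]]])
nf, nloc, nnei = nlist.shape

rng = np.random.default_rng(0)
edge_feat = rng.normal(size=(nf, nloc, nnei, 4))  # one feature row per slot

valid = nlist != -1

# Compact ("packed") layout: boolean-mask indexing keeps only real edges,
# so the number of rows depends on the data.
packed = edge_feat[valid]

# Fixed-capacity ("static") layout: keep every slot and zero the padding
# (zeroing is an assumption of this sketch), so the row count is constant.
static = np.where(valid[..., None], edge_feat, 0.0).reshape(nf * nloc * nnei, 4)

# Both layouts agree once the padded slots are masked out.
valid_flat = valid.reshape(-1)
np.testing.assert_allclose(packed, static[valid_flat])
```

The final comparison mirrors the shape of the regression test described in the third bullet: the packed output is compared against the static output restricted to valid slots.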

Validation

  • ruff check .
  • ruff format .
  • pytest source/tests/universal/dpmodel/descriptor/test_descriptor.py::TestDPA3StaticDynamicSelDP::test_static_dynamic_sel_matches_packed_dynamic_sel -q
  • dp convert-backend DPA-3.2-5M.pth DPA-3.2-5M.savedmodel

Summary by CodeRabbit

  • Chores

    • Optimized descriptor block handling for dynamic neighbor selection with fixed-capacity layout support.
    • Refactored aggregation computation logic for enhanced efficiency.
    • Updated JAX backend to support improved export functionality.
  • Tests

    • Added validation tests for dynamic selection implementation consistency.


coderabbitai Bot commented Jun 14, 2026

Contributor


📝 Walkthrough


Adds a backend-internal _use_static_dynamic_sel flag to DescrptBlockRepflows and RepFlowLayer, enabling a fixed-capacity (padded) execution path for edges and angles as an alternative to compact boolean-mask compaction. A new _get_static_graph_index helper builds the padded index tensors. The JAX subclasses override the flag to True. The aggregate utility is refactored to skip bin_count computation for pure sum reductions. An equivalence test verifies both paths produce matching outputs.
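The flag mechanics described above (class-level default, instance snapshot, propagation to owned layers, backend override) can be sketched in plain Python. The class names `BlockSketch`, `LayerSketch`, and `JaxBlockSketch` are hypothetical stand-ins, not the deepmd classes, and the `call()` body is a placeholder that only reports which path would run.

```python
class LayerSketch:
    """Hypothetical stand-in for a layer owned by the block."""

    def __init__(self, use_static: bool) -> None:
        self._use_static_dynamic_sel = use_static


class BlockSketch:
    """Hypothetical stand-in for a descriptor block; not the deepmd class."""

    # Class-level, backend-internal switch; the compact path is the default.
    _use_static_dynamic_sel: bool = False

    def __init__(self, nlayers: int = 2) -> None:
        # Snapshot the class-level value so that later changes to the class
        # attribute do not alter an already-constructed instance.
        self._use_static_dynamic_sel = type(self)._use_static_dynamic_sel
        # Propagate the snapshot to owned layers at construction time.
        self.layers = [
            LayerSketch(self._use_static_dynamic_sel) for _ in range(nlayers)
        ]

    def call(self) -> str:
        # Placeholder for the real branching between the two layouts.
        return "static" if self._use_static_dynamic_sel else "compact"


class JaxBlockSketch(BlockSketch):
    # A backend subclass flips the default to the fixed-capacity path.
    _use_static_dynamic_sel = True
```

Because the value is snapshotted in `__init__`, a test can toggle the class attribute around construction and restore it afterwards without affecting instances that already exist.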

Changes

Static dynamic selection for RepFlows

| Layer | File(s) | Summary |
|---|---|---|
| Flag definition, instance snapshot, and propagation | deepmd/dpmodel/descriptor/repflows.py | Defines _use_static_dynamic_sel as a class-level bool on DescrptBlockRepflows and RepFlowLayer, snapshots it onto instances in __init__, propagates it to owned RepFlowLayer instances at construction time, and re-copies it to restored layers after deserialization. |
| _get_static_graph_index helper and call() branching | deepmd/dpmodel/descriptor/repflows.py | Implements _get_static_graph_index returning fixed-capacity edge_index (2×n_edges) and angle_index (3×n_angles) with padded slots. Extends DescrptBlockRepflows.call() to branch on the flag: static path reshapes to flattened fixed capacities with (j,k) angle gating; compact path keeps get_graph_index with boolean-mask compaction. Updates RepFlowLayer.call() to read n_edge from h2.shape[0] in static mode vs a masked sum in compact mode. |
| aggregate: conditional bin_count computation | deepmd/dpmodel/utils/network.py | Refactors aggregate to compute bin_count only when num_owner is absent or averaging is requested; allocates output directly with (num_owner, feature_dim) and asserts bin_count is not None before the averaging divide. |
| JAX backend override | deepmd/jax/descriptor/repflows.py | Sets _use_static_dynamic_sel = True on both DescrptBlockRepflows and RepFlowLayer JAX subclasses to activate fixed-capacity layout for JAX/jax2tf export. |
| Equivalence test | source/tests/universal/dpmodel/descriptor/test_descriptor.py | Adds TestDPA3StaticDynamicSelDP with a _make_dpa3 helper that toggles _use_static_dynamic_sel at construction, and a test that asserts packed dynamic and static dynamic outputs match by masking padded slots via nlist != -1. |
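The `aggregate` change summarized above can be illustrated with a small NumPy sketch. The function name `aggregate_sketch`, its signature, and the use of `np.add.at` are assumptions made for this example and may differ from the actual implementation in deepmd/dpmodel/utils/network.py; the sketch only shows the control flow in which the count is computed when the owner count is unknown or an average is requested.

```python
import numpy as np


def aggregate_sketch(data, owners, average=True, num_owner=None):
    """Sum (or average) rows of `data` grouped by `owners` (hypothetical sketch)."""
    bin_count = None
    if num_owner is None or average:
        # Counting is only needed to infer the output size or to average.
        bin_count = np.bincount(owners, minlength=num_owner or 0)
        if num_owner is None:
            num_owner = bin_count.shape[0]
    # Allocate the output directly as (num_owner, feature_dim).
    output = np.zeros((num_owner, data.shape[1]), dtype=data.dtype)
    np.add.at(output, owners, data)  # unbuffered scatter-add per owner
    if average:
        assert bin_count is not None
        # Guard against division by zero for owners that received no rows.
        output = output / np.maximum(bin_count, 1)[:, None]
    return output


# Usage: row 2 is a zeroed padded slot owned by atom 1.
data = np.array([[1.0, 2.0], [3.0, 4.0], [0.0, 0.0], [5.0, 6.0]])
owners = np.array([0, 0, 1, 2])
summed = aggregate_sketch(data, owners, average=False, num_owner=3)
```

In this sketch a zeroed padded row leaves a sum unchanged but would still be counted by an average, which is one reason a fixed-capacity layout pairs naturally with sum-only aggregation and a known owner count.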

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

  • deepmodeling/deepmd-kit#5355: Modifies the same DescrptBlockRepflows/RepFlowLayer classes in deepmd/dpmodel/descriptor/repflows.py to add runtime behavior flags that affect call() logic, directly overlapping with the flag propagation and call() branching in this PR.

Suggested labels

Python

Suggested reviewers

  • wanghan-iapcm
🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 42.86% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (4 passed)

| Check name | Status | Explanation |
|---|---|---|
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title check | ✅ Passed | The title 'fix(jax): stabilize repflow dynamic selection export' clearly and specifically describes the main change: stabilizing repflow dynamic selection export for JAX, which is the primary objective of the PR. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |


@coderabbitai coderabbitai Bot left a comment

Contributor


🧹 Nitpick comments (1)
source/tests/universal/dpmodel/descriptor/test_descriptor.py (1)

915-933: ⚡ Quick win

Cover the use_loc_mapping=False static-index branch too.

_get_static_graph_index() changes its indexing stride when use_loc_mapping is disabled, but this helper always constructs the default mapped configuration. Extending this test to run both values would cover the second branch of the new static-dynamic layout logic that JAX now enables by default.
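As a rough illustration of what a changing indexing stride means here, the sketch below flattens per-frame neighbor indices into one global index. The helper name `flat_neighbor_index`, the stride rule (`nloc` with local mapping, `nall` without), and the choice of 0 as the placeholder index for padded slots are all assumptions of this example; the actual `_get_static_graph_index()` may differ.

```python
import numpy as np


def flat_neighbor_index(nlist, nall, use_loc_mapping):
    """Flatten per-frame neighbor indices into one global index (hypothetical)."""
    nf, nloc, nnei = nlist.shape
    # Assumed rule: frames are nloc apart with local mapping, nall apart without.
    stride = nloc if use_loc_mapping else nall
    frame_shift = np.arange(nf) * stride
    shifted = nlist + frame_shift[:, None, None]
    # Padded slots get an in-range placeholder (0); callers must mask them.
    return np.where(nlist != -1, shifted, 0).reshape(-1)


# 2 frames, 2 local atoms, capacity of 2 neighbors, 5 extended atoms per frame.
nlist = np.array([[[1, -1], [0, -1]], [[1, -1], [0, 1]]])
idx_loc = flat_neighbor_index(nlist, nall=5, use_loc_mapping=True)
idx_ext = flat_neighbor_index(nlist, nall=5, use_loc_mapping=False)
```

The two results differ only in the second frame, which is why a test that exercises a single value of `use_loc_mapping` would leave the other stride unchecked.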

♻️ Suggested test expansion
-    def _make_dpa3(self, use_static_dynamic_sel: bool) -> DescrptDPA3:
+    def _make_dpa3(
+        self,
+        use_static_dynamic_sel: bool,
+        *,
+        use_loc_mapping: bool,
+    ) -> DescrptDPA3:
         # The switch is intentionally class-level and internal, so tests toggle
         # it only around construction and then restore the previous backend mode.
         old_use_static_dynamic_sel = DescrptBlockRepflows._use_static_dynamic_sel
         DescrptBlockRepflows._use_static_dynamic_sel = use_static_dynamic_sel
         try:
             return DescrptDPA3(
                 **DescriptorParamDPA3(
                     self.nt,
                     self.rcut,
                     self.rcut_smth,
                     self.sel,
                     ["O", "H"],
                     smooth_edge_update=True,
                     use_dynamic_sel=True,
+                    use_loc_mapping=use_loc_mapping,
                 )
             )
         finally:
             DescrptBlockRepflows._use_static_dynamic_sel = old_use_static_dynamic_sel

     def test_static_dynamic_sel_matches_packed_dynamic_sel(self) -> None:
-        packed = self._make_dpa3(False)
-        static = self._make_dpa3(True)
+        for use_loc_mapping in (True, False):
+            packed = self._make_dpa3(False, use_loc_mapping=use_loc_mapping)
+            static = self._make_dpa3(True, use_loc_mapping=use_loc_mapping)
 
-        packed_out = packed(
-            self.coord_ext,
-            self.atype_ext,
-            self.nlist,
-            mapping=self.mapping,
-        )
-        static_out = static(
-            self.coord_ext,
-            self.atype_ext,
-            self.nlist,
-            mapping=self.mapping,
-        )
+            packed_out = packed(
+                self.coord_ext,
+                self.atype_ext,
+                self.nlist,
+                mapping=self.mapping,
+            )
+            static_out = static(
+                self.coord_ext,
+                self.atype_ext,
+                self.nlist,
+                mapping=self.mapping,
+            )

-        np.testing.assert_allclose(packed_out[0], static_out[0], atol=self.atol)
-        np.testing.assert_allclose(packed_out[1], static_out[1], atol=self.atol)
+            np.testing.assert_allclose(packed_out[0], static_out[0], atol=self.atol)
+            np.testing.assert_allclose(packed_out[1], static_out[1], atol=self.atol)

-        valid_edge_mask = np.reshape(self.nlist != -1, (-1,))
-        assert static_out[2].shape[0] == self.nf * self.nloc * sum(self.sel)
-        np.testing.assert_allclose(
-            packed_out[2], static_out[2][valid_edge_mask], atol=self.atol
-        )
-        np.testing.assert_allclose(
-            packed_out[3], static_out[3][valid_edge_mask], atol=self.atol
-        )
-        np.testing.assert_allclose(
-            packed_out[4], static_out[4][valid_edge_mask], atol=self.atol
-        )
+            valid_edge_mask = np.reshape(self.nlist != -1, (-1,))
+            assert static_out[2].shape[0] == self.nf * self.nloc * sum(self.sel)
+            np.testing.assert_allclose(
+                packed_out[2], static_out[2][valid_edge_mask], atol=self.atol
+            )
+            np.testing.assert_allclose(
+                packed_out[3], static_out[3][valid_edge_mask], atol=self.atol
+            )
+            np.testing.assert_allclose(
+                packed_out[4], static_out[4][valid_edge_mask], atol=self.atol
+            )

Also applies to: 935-968

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@source/tests/universal/dpmodel/descriptor/test_descriptor.py` around lines
915 - 933, The _make_dpa3 helper method currently only constructs the default
mapped configuration, but _get_static_graph_index() has different indexing
behavior when use_loc_mapping is disabled. Extend the _make_dpa3 method to
accept a parameter for use_loc_mapping (similar to how it accepts
use_static_dynamic_sel) and ensure it applies this parameter when constructing
DescriptorParamDPA3. Then update all test methods that use _make_dpa3 (including
those at lines 935-968) to run test assertions with both use_loc_mapping=True
and use_loc_mapping=False so that both branches of the static-dynamic layout
logic are covered.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: b5b5ca58-1ff0-4fbd-beea-b662619d0e7b

📥 Commits

Reviewing files that changed from the base of the PR and between c0b0319 and b338bb1.

📒 Files selected for processing (4)
  • deepmd/dpmodel/descriptor/repflows.py
  • deepmd/dpmodel/utils/network.py
  • deepmd/jax/descriptor/repflows.py
  • source/tests/universal/dpmodel/descriptor/test_descriptor.py


codecov Bot commented Jun 14, 2026


Codecov Report

❌ Patch coverage is 87.27273% with 7 lines in your changes missing coverage. Please review.
✅ Project coverage is 82.18%. Comparing base (c0b0319) to head (b338bb1).
⚠️ Report is 1 commit behind head on master.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| deepmd/dpmodel/utils/network.py | 30.00% | 7 Missing ⚠️ |
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #5533      +/-   ##
==========================================
- Coverage   82.18%   82.18%   -0.01%     
==========================================
  Files         890      890              
  Lines      101357   101399      +42     
  Branches     4240     4240              
==========================================
+ Hits        83301    83335      +34     
- Misses      16754    16761       +7     
- Partials     1302     1303       +1     
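The percentages in this report can be cross-checked with a little arithmetic. The short sketch below infers the total number of patch lines from the quoted coverage and missing-line count; the inferred totals are my own derivation and are not stated anywhere in the report.

```python
# Figures quoted in the report above.
patch_pct, patch_missing = 87.27273, 7
file_pct, file_missing = 30.00, 7


def implied_total(pct, missing):
    """Total lines implied by a coverage percentage and a missing-line count."""
    return round(missing / (1 - pct / 100))


patch_total = implied_total(patch_pct, patch_missing)  # inferred, not reported
file_total = implied_total(file_pct, file_missing)     # inferred, not reported

# The diff table's deltas should add up: hits + misses + partials == lines.
d_hits, d_misses, d_partials, d_lines = 34, 7, 1, 42
consistent = d_hits + d_misses + d_partials == d_lines
```

Under these assumptions the uncovered lines are all in deepmd/dpmodel/utils/network.py, which matches the single entry in the table of files with missing lines.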


@njzjz njzjz requested review from iProzd and wanghan-iapcm June 14, 2026 17:32