fix(pt): enable second-order autograd for tabulate descriptors #5537

Open
njzjz wants to merge 1 commit into deepmodeling:master from njzjz:fix/pt-tabulate-second-order-autograd

@njzjz njzjz commented Jun 15, 2026

Member

Summary

  • wrap PyTorch tabulate descriptor first-derivative kernels in autograd Functions
  • connect se_a, se_atten, se_t, se_r, and se_t_tebd backward paths to existing grad-grad kernels
  • add second-order backward regression tests for all affected tabulate descriptor ops

Fixes #4994.

Tests

  • pytest source/tests/pt/test_tabulate_fusion_se_a.py source/tests/pt/test_tabulate_fusion_se_atten.py source/tests/pt/test_tabulate_fusion_se_r.py source/tests/pt/test_tabulate_fusion_se_t.py source/tests/pt/test_tabulate_fusion_se_t_tebd.py -q
  • ruff check .
  • ruff format .
  • commit hook suite, including clang-format

Summary by CodeRabbit

  • New Features

    • Implemented second-order gradient support for the tabulate fusion operations (se_a, se_atten, se_t, se_r, and se_t_tebd), so that gradients of first-order gradients can be propagated through these ops
  • Tests

    • Added second-order backward pass tests across multiple operations to verify higher-order differentiation accuracy and numerical stability

coderabbitai Bot commented Jun 15, 2026

Contributor

📝 Walkthrough

Adds second-order autograd support to all five PyTorch TabulateFusion embedding paths (SeA, SeAtten, SeT, SeR, SeTTebd) by introducing GradOp/GradGradOp torch::autograd::Function wrappers. Each Op's backward_t now delegates via apply() instead of calling *GradForward kernels directly. Five corresponding test_second_order_backward test methods are added.
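
The wrapper pattern described above can be illustrated with a minimal Python sketch. It is not the PR's C++ code: `Cube` and `CubeGrad` are hypothetical toy ops standing in for a `TabulateFusionSe*Op` and its `*GradOp`, and the closed-form expressions stand in for the hand-written forward, grad, and grad-grad kernels. The point it shows is that the outer op's backward calls `apply()` on a second `torch.autograd.Function`, so the first-derivative computation is itself recorded in the graph and can be differentiated again.

```python
import torch


class CubeGrad(torch.autograd.Function):
    """Toy first-derivative op (hypothetical stand-in for a *GradOp)."""

    @staticmethod
    def forward(ctx, x, dy):
        ctx.save_for_backward(x, dy)
        # Stand-in for a first-derivative kernel: dx = 3 * x^2 * dy.
        return 3 * x**2 * dy

    @staticmethod
    def backward(ctx, ddx):
        x, dy = ctx.saved_tensors
        # Stand-in for a grad-grad kernel: derivatives of dx w.r.t. x and dy.
        grad_x = 6 * x * dy * ddx
        grad_dy = 3 * x**2 * ddx
        return grad_x, grad_dy


class Cube(torch.autograd.Function):
    """Toy forward op (hypothetical stand-in for a TabulateFusionSe*Op)."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x**3

    @staticmethod
    def backward(ctx, dy):
        (x,) = ctx.saved_tensors
        # Delegate through apply() so this step is tracked by autograd.
        return CubeGrad.apply(x, dy)


x = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float64, requires_grad=True)
y = Cube.apply(x)
# First-order gradient, keeping the graph so it can be differentiated again.
(g,) = torch.autograd.grad(y.sum(), x, create_graph=True)
# Second-order gradient, routed through CubeGrad.backward.
(gg,) = torch.autograd.grad(g.sum(), x)
```

For this toy function the first-order gradient is 3x² and the second-order gradient is 6x. If `Cube.backward` computed `3 * x**2 * dy` with an untracked kernel instead of calling `CubeGrad.apply`, the second `torch.autograd.grad` call would have no path back to `x`.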

Changes

Second-order autograd for all FusionSe* ops

| Layer / File(s) | Summary |
| --- | --- |
| **SeA GradOp and GradGradOp introduction and wiring**<br>source/op/pt/tabulate_multi_device.cc | Introduces TabulateFusionSeAGradOp (forward_t calls TabulateFusionSeAGradForward, backward_t calls TabulateFusionSeAGradGradForward) and TabulateFusionSeAGradGradOp (dtype-dispatched forward_t computing dz_dy_tensor). TabulateFusionSeAOp::backward_t is rewired to delegate to TabulateFusionSeAGradOp::apply(...). |
| **SeAtten GradOp refactor and is_sorted persistence**<br>source/op/pt/tabulate_multi_device.cc | Refactors TabulateFusionSeAttenGradOp: forward_t saves is_sorted into ctx->saved_data; backward_t retrieves it, applies zeros_like for undefined grad outputs, and calls TabulateFusionSeAGradGradForward. TabulateFusionSeAttenOp::forward_t adds ctx->saved_data["is_sorted"]; backward_t delegates to TabulateFusionSeAttenGradOp::apply(...). |
| **SeT, SeR, and SeTTebd GradOp wrappers**<br>source/op/pt/tabulate_multi_device.cc | Introduces TabulateFusionSeTGradOp, TabulateFusionSeRGradOp, and TabulateFusionSeTTebdGradOp, each with forward_t calling the matching `*GradForward` kernel and backward_t calling the matching `*GradGradForward` kernel. The three parent Ops' backward_t bodies are updated to delegate via apply(...). |
| **Second-order backward tests for all five ops**<br>source/tests/pt/test_tabulate_fusion_se_a.py, source/tests/pt/test_tabulate_fusion_se_atten.py, source/tests/pt/test_tabulate_fusion_se_r.py, source/tests/pt/test_tabulate_fusion_se_t.py, source/tests/pt/test_tabulate_fusion_se_t_tebd.py | Adds test_second_order_backward to each test class. Each test chains two torch.autograd.grad calls with create_graph=True and asserts the second-order gradient is non-None, shape-matched to the descriptor, finite, and has nonzero maximum absolute value. |
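
The test shape summarized in the last row can be sketched roughly as follows. This is an illustration under assumptions, not the PR's test code: `torch.tanh` is a stand-in for a tabulate fusion op, and differentiating the first-order gradient with respect to the upstream gradient `dy` is a guess at why the result is "shape-matched to the descriptor".

```python
import torch

# Stand-in input and "descriptor"; torch.tanh replaces the real custom op here.
x = torch.linspace(-1.0, 1.0, 8, dtype=torch.float64).reshape(2, 4).requires_grad_(True)
descriptor = torch.tanh(x)

# Upstream gradient that itself requires grad, so it can be differentiated against.
dy = torch.ones_like(descriptor, requires_grad=True)

# First call: gradient of the descriptor w.r.t. x, keeping the graph alive.
(grad_x,) = torch.autograd.grad(descriptor, x, grad_outputs=dy, create_graph=True)

# Second call: differentiate the first-order gradient w.r.t. dy.
(grad_dy,) = torch.autograd.grad(grad_x.sum(), dy)

# The four checks listed in the summary.
assert grad_dy is not None
assert grad_dy.shape == descriptor.shape
assert torch.isfinite(grad_dy).all()
assert grad_dy.abs().max() > 0
```

These assertions only confirm that a second-order path exists and produces sane values. They do not compare against a reference value, which a tool such as `torch.autograd.gradgradcheck` would do for double-precision inputs.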

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 0.00% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (4 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title check | ✅ Passed | The PR title accurately summarizes the main objective: enabling second-order autograd support for tabulate descriptors by implementing missing autograd wrappers. |
| Linked Issues check | ✅ Passed | The PR fully addresses issue #4994 by implementing the missing autograd wrappers (TabulateFusionSeTGradOp and others) with backward methods that call grad-grad kernels, and adds comprehensive second-order backward tests. |
| Out of Scope Changes check | ✅ Passed | All changes are scoped to implementing second-order autograd support: new autograd wrapper classes, corresponding grad-grad operations, and second-order backward tests. |

@coderabbitai coderabbitai Bot left a comment

🧹 Nitpick comments (1)
source/op/pt/tabulate_multi_device.cc (1)

607-608: ⚡ Quick win

Remove unused private member variable.

The device member is declared but never used anywhere in TabulateFusionSeAGradOp. This appears to be leftover from an earlier implementation.

 class TabulateFusionSeAGradOp
     : public torch::autograd::Function<TabulateFusionSeAGradOp> {
- private:
-  std::string device;
-
  public:
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@source/op/pt/tabulate_multi_device.cc` around lines 607 - 608, Remove the
unused private member variable `device` from the `TabulateFusionSeAGradOp` class
since it is declared but never referenced anywhere in the implementation. Simply
delete the line containing `std::string device;` from the private section.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 10e5cae3-359e-4d0a-b565-f7fc90453f06

📥 Commits

Reviewing files that changed from the base of the PR and between 87d8557 and de7ebac.

📒 Files selected for processing (6)
  • source/op/pt/tabulate_multi_device.cc
  • source/tests/pt/test_tabulate_fusion_se_a.py
  • source/tests/pt/test_tabulate_fusion_se_atten.py
  • source/tests/pt/test_tabulate_fusion_se_r.py
  • source/tests/pt/test_tabulate_fusion_se_t.py
  • source/tests/pt/test_tabulate_fusion_se_t_tebd.py

codecov Bot commented Jun 15, 2026

Codecov Report

❌ Patch coverage is 87.07865% with 23 lines in your changes missing coverage. Please review.
✅ Project coverage is 82.20%. Comparing base (87d8557) to head (de7ebac).

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| source/op/pt/tabulate_multi_device.cc | 87.07% | 18 Missing and 5 partials ⚠️ |
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #5537      +/-   ##
==========================================
+ Coverage   82.18%   82.20%   +0.02%     
==========================================
  Files         890      890              
  Lines      101358   101616     +258     
  Branches     4240     4266      +26     
==========================================
+ Hits        83301    83534     +233     
- Misses      16756    16760       +4     
- Partials     1301     1322      +21     
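
The figures in the report can be cross-checked with a few lines of arithmetic. The sketch below uses only the numbers shown above, with two assumptions: percentages are truncated (not rounded) to two decimals, which is what reproduces the reported 82.20%, and the patch size of 178 lines is inferred from the 23 missing lines and the 87.07865% patch coverage rather than stated by Codecov.

```python
from math import floor


def pct_truncated(hits: int, lines: int, digits: int = 2) -> float:
    """Coverage percentage truncated (assumed, not rounded) to `digits` decimals."""
    scale = 10**digits
    return floor(hits / lines * 100 * scale) / scale


base = {"lines": 101358, "hits": 83301, "misses": 16756, "partials": 1301}
head = {"lines": 101616, "hits": 83534, "misses": 16760, "partials": 1322}

# Every tracked line is exactly one of hit, missed, or partially covered.
for report in (base, head):
    assert report["hits"] + report["misses"] + report["partials"] == report["lines"]

base_cov = pct_truncated(base["hits"], base["lines"])  # 82.18
head_cov = pct_truncated(head["hits"], head["lines"])  # 82.2
delta = round(head_cov - base_cov, 2)                  # 0.02

# Patch coverage: 18 missing + 5 partial lines count as uncovered.
patch_uncovered = 18 + 5
patch_lines = 178  # inferred, not reported: 23 / (1 - 0.8707865) is about 178
patch_cov = (patch_lines - patch_uncovered) / patch_lines * 100

print(base_cov, head_cov, delta, round(patch_cov, 5))
```

Under these assumptions the project coverage, the +0.02 point change, and the 87.07865% patch coverage in the report are mutually consistent.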

@njzjz njzjz requested review from OutisLi and wanghan-iapcm June 15, 2026 04:38

Development

Successfully merging this pull request may close these issues.

PyTorch Backend: Missing autograd wrapper for se_t descriptor's second-order derivatives