Fix 3D decomposed relative position embedding for anisotropic inputs#8890
Fix 3D decomposed relative position embedding for anisotropic inputs#8890aymuos15 wants to merge 1 commit into
Conversation
The 3D branch of add_decomposed_rel_pos was broken for non-cubic spatial
sizes in two ways:
1. rel_d contracted the depth-axis embedding against the width axis
('bhwdc,wkc->bhwdk' instead of 'bhwdc,dkc->bhwdk').
2. The broadcast-add into the (b, qh, qw, qd, kh, kw, kd) attention tensor
used 4 leading colons instead of 5 for each rel_* term, so the key axis
landed in the wrong slot (e.g. rel_h's kh fell into the kd position).
Both were masked by the only 3D test using a cubic input_size (8, 8, 8):
with kh == kw == kd every misplaced axis still broadcasts. For anisotropic
volumes the forward pass raised a shape error (or, where conformable, added
the wrong positional bias) in SABlock/CrossAttentionBlock with
rel_pos_embedding='decomposed'.
Adds a regression test parametrized over several anisotropic 3D input
sizes, including permutations of distinct dims so each spatial axis is
exercised in a different position.
Signed-off-by: Soumya Snigdha Kundu <soumya_snigdha.kundu@kcl.ac.uk>
📝 WalkthroughWalkthroughThis PR fixes the depth-axis relative position embedding computation in 3D attention. The Estimated code review effort🎯 2 (Simple) | ⏱️ ~10 minutes 🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tests/networks/blocks/test_selfattention.py (1)
251-251:⚠️ Potential issue | 🟠 Major | ⚡ Quick winPre-existing bug:
use_combined_linearis undefined.This variable is not defined in
test_no_extra_weights_if_no_fc- it's only a loop variable intest_script. Test will raiseNameError.Proposed fix
`@parameterized.expand`([[True], [False]]) - def test_no_extra_weights_if_no_fc(self, include_fc): + def test_no_extra_weights_if_no_fc(self, include_fc, use_combined_linear=True): input_param = {Or update
@parameterized.expandto includeuse_combined_linearvalues.🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/networks/blocks/test_selfattention.py` at line 251, The test fails because use_combined_linear is referenced but not defined in test_no_extra_weights_if_no_fc; fix by explicitly defining use_combined_linear (e.g., insert use_combined_linear = False) inside test_no_extra_weights_if_no_fc before building the kwargs dictionary that includes "use_combined_linear", or alternatively update the `@parameterized.expand` on test_no_extra_weights_if_no_fc to include the use_combined_linear values (matching how test_script parametrizes it) so the symbol is provided to the test.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Outside diff comments:
In `@tests/networks/blocks/test_selfattention.py`:
- Line 251: The test fails because use_combined_linear is referenced but not
defined in test_no_extra_weights_if_no_fc; fix by explicitly defining
use_combined_linear (e.g., insert use_combined_linear = False) inside
test_no_extra_weights_if_no_fc before building the kwargs dictionary that
includes "use_combined_linear", or alternatively update the
`@parameterized.expand` on test_no_extra_weights_if_no_fc to include the
use_combined_linear values (matching how test_script parametrizes it) so the
symbol is provided to the test.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: e35d751b-ce63-42f8-b2a3-201db9d1633f
📒 Files selected for processing (2)
monai/networks/blocks/attention_utils.pytests/networks/blocks/test_selfattention.py
The 3D branch of add_decomposed_rel_pos was broken for non-cubic spatial sizes in two ways:
Both were masked by the only 3D test using a cubic input_size (8, 8, 8): with kh == kw == kd every misplaced axis still broadcasts. For anisotropic volumes the forward pass raised a shape error (or, where conformable, added the wrong positional bias) in SABlock/CrossAttentionBlock with rel_pos_embedding='decomposed'.
Adds a regression test parametrized over several anisotropic 3D input sizes, including permutations of distinct dims so each spatial axis is exercised in a different position.
Types of changes
./runtests.sh -f -u --net --coverage../runtests.sh --quick --unittests --disttests.make htmlcommand in thedocs/folder.