From 59ce5807415f9f14e1c09d84cf79573f534d5e2e Mon Sep 17 00:00:00 2001 From: Soumya Snigdha Kundu Date: Mon, 1 Jun 2026 18:58:10 +0100 Subject: [PATCH] Fix 3D decomposed relative position embedding for anisotropic inputs The 3D branch of add_decomposed_rel_pos was broken for non-cubic spatial sizes in two ways: 1. rel_d contracted the depth-axis embedding against the width axis ('bhwdc,wkc->bhwdk' instead of 'bhwdc,dkc->bhwdk'). 2. The broadcast-add into the (b, qh, qw, qd, kh, kw, kd) attention tensor used 4 leading colons instead of 5 for each rel_* term, so the key axis landed in the wrong slot (e.g. rel_h's kh fell into the kd position). Both were masked by the only 3D test using a cubic input_size (8, 8, 8): with kh == kw == kd every misplaced axis still broadcasts. For anisotropic volumes the forward pass raised a shape error (or, where conformable, added the wrong positional bias) in SABlock/CrossAttentionBlock with rel_pos_embedding='decomposed'. Adds a regression test parametrized over several anisotropic 3D input sizes, including permutations of distinct dims so each spatial axis is exercised in a different position. Signed-off-by: Soumya Snigdha Kundu --- monai/networks/blocks/attention_utils.py | 8 ++++---- tests/networks/blocks/test_selfattention.py | 18 ++++++++++++++++++ 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/monai/networks/blocks/attention_utils.py b/monai/networks/blocks/attention_utils.py index a8dfcd7df3..caf2bcc155 100644 --- a/monai/networks/blocks/attention_utils.py +++ b/monai/networks/blocks/attention_utils.py @@ -114,13 +114,13 @@ def add_decomposed_rel_pos( r_q = q.reshape(batch, q_h, q_w, q_d, dim) rel_h = torch.einsum("bhwdc,hkc->bhwdk", r_q, rh) rel_w = torch.einsum("bhwdc,wkc->bhwdk", r_q, rw) - rel_d = torch.einsum("bhwdc,wkc->bhwdk", r_q, rd) + rel_d = torch.einsum("bhwdc,dkc->bhwdk", r_q, rd) attn = ( attn.view(batch, q_h, q_w, q_d, k_h, k_w, k_d) - + rel_h[:, :, :, :, None, None] - + rel_w[:, :, :, None, :, None] - + rel_d[:, :, :, None, None, :] + + rel_h[:, :, :, :, :, None, None] + + rel_w[:, :, :, :, None, :, None] + + rel_d[:, :, :, :, None, None, :] ).view(batch, q_h * q_w * q_d, k_h * k_w * k_d) return attn diff --git a/tests/networks/blocks/test_selfattention.py b/tests/networks/blocks/test_selfattention.py index af52918612..c960e30e58 100644 --- a/tests/networks/blocks/test_selfattention.py +++ b/tests/networks/blocks/test_selfattention.py @@ -67,6 +67,24 @@ def test_ill_arg(self): with self.assertRaises(ValueError): SABlock(hidden_size=620, num_heads=8, dropout_rate=0.4) + # anisotropic 3D shapes with all-distinct, permuted dims: a cubic input_size hides + # the decomposed rel-pos axis handling because every misplaced axis still broadcasts. + @parameterized.expand([[(4, 8, 16)], [(16, 8, 4)], [(2, 3, 4)], [(4, 3, 2)], [(2, 4, 3)]]) + @skipUnless(has_einops, "Requires einops") + def test_decomposed_rel_pos_anisotropic_3d(self, input_size): + hidden_size = 120 + net = SABlock( + hidden_size=hidden_size, + num_heads=6, + dropout_rate=0.1, + rel_pos_embedding=RelPosEmbedding.DECOMPOSED, + input_size=input_size, + ) + seq_len = input_size[0] * input_size[1] * input_size[2] + with eval_mode(net): + result = net(torch.randn(2, seq_len, hidden_size)) + self.assertEqual(result.shape, (2, seq_len, hidden_size)) + def test_rel_pos_embedding_with_flash_attention(self): with self.assertRaises(ValueError): SABlock(