[JAX] Fallback to old triton ffi for autotuned kernels#3077
[JAX] Fallback to old triton ffi for autotuned kernels#3077jberchtold-nvidia wants to merge 2 commits into
Conversation
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Greptile SummaryThis PR fixes a CUDA IMA (Illegal Memory Access) issue by preventing autotuned kernels from using the newer
Confidence Score: 5/5Safe to merge — the change is narrowly scoped to routing autotuned kernels away from the new FFI target, and all other dispatch paths are unchanged. The boolean flag cleanly separates the autotuned and non-autotuned code paths with no risk of misclassification. Non-autotuned kernels and the JAX-version-based compatibility fallback both leave No files require special attention. Important Files Changed
Flowchart%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[triton_call_lowering] --> B{isinstance kernel_fn\nAutotuner?}
B -- No --> C[is_autotuned = False\nused_autotuned_launch = False]
B -- Yes --> D{is_triton_autotuned\n_alias_safe?}
D -- No --> E[Compatibility fallback:\nis_autotuned = False\nused_autotuned_launch = False]
D -- Yes --> F[is_autotuned = True]
F --> G[Build TritonAutotunedKernelCall\nfor all configs]
G --> H[used_autotuned_launch = True]
C --> I{FFI selection}
E --> I
H --> I
I --> J{not used_autotuned_launch\nAND jax_version_meets\nCUDA_GRAPH_MIN?}
J -- Yes --> K[triton_kernel_call_ffi\nnew FFI - CUDA graph support]
J -- No --> L[triton_kernel_call\nlegacy FFI - avoids CUDA IMA]
Reviews (2): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..." | Re-trigger Greptile |
| if ( | ||
| not used_autotuned_launch | ||
| and jax_version_meet_requirement(TRITON_EXTENSION_CUDA_GRAPH_MIN_JAX_VERSION) | ||
| ): |
There was a problem hiding this comment.
The new branch condition does not carry a comment explaining why autotuned launches must skip the new FFI. Without it, a future reader will only see the mechanism (the flag) but not the root cause (CUDA IMA with
triton_kernel_call_ffi on autotuned kernels), making it easy to inadvertently remove the guard when refactoring.
| if ( | |
| not used_autotuned_launch | |
| and jax_version_meet_requirement(TRITON_EXTENSION_CUDA_GRAPH_MIN_JAX_VERSION) | |
| ): | |
| # Autotuned kernels must use the older "triton_kernel_call" FFI: the newer | |
| # "triton_kernel_call_ffi" path triggers CUDA IMA (Illegal Memory Access) | |
| # errors for autotuned kernels and must be bypassed until the upstream issue | |
| # is resolved. | |
| if ( | |
| not used_autotuned_launch | |
| and jax_version_meet_requirement(TRITON_EXTENSION_CUDA_GRAPH_MIN_JAX_VERSION) | |
| ): |
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
for more information, see https://pre-commit.ci
|
/te-ci jax |
Description
Disables new "triton_kernel_call_ffi" and falls back to old "triton_kernel_call" ffi due to CUDA IMA issues observed with autotuned kernels on new "triton_kernel_call_ffi"
Type of change
Changes
Checklist: