[JAX] Fallback to old triton ffi for autotuned kernels #3077

Open
jberchtold-nvidia wants to merge 2 commits into NVIDIA:main from jberchtold-nvidia:jberchtold/disable-new-triton-ffi-for-autotuned-kernels

Conversation

@jberchtold-nvidia
Collaborator

Description

Disables the new "triton_kernel_call_ffi" target for autotuned kernels and falls back to the old "triton_kernel_call" FFI, due to CUDA IMA (illegal memory access) issues observed with autotuned kernels on the new "triton_kernel_call_ffi".

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Fall back to "triton_kernel_call" for autotuned kernels

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@greptile-apps
Contributor

greptile-apps Bot commented Jun 2, 2026

Greptile Summary

This PR fixes a CUDA IMA (Illegal Memory Access) issue by preventing autotuned kernels from using the newer triton_kernel_call_ffi custom call target, falling back to the older triton_kernel_call FFI instead.

  • Introduces a used_autotuned_launch boolean flag in triton_call_lowering; the flag is set to True only when a TritonAutotunedKernelCall is built, and gates the FFI selection so autotuned paths always use the legacy triton_kernel_call target.
  • Non-autotuned kernels and the pre-existing compatibility-fallback path (old JAX, is_triton_autotuned_alias_safe() returns False) are unaffected and continue to reach triton_kernel_call_ffi when the JAX version requirement is met.
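
The gating described in the two points above can be sketched roughly as follows. This is a minimal illustration, not the code in utils.py: the function name `select_call_target` and its two boolean parameters are hypothetical stand-ins for the `used_autotuned_launch` flag and for the result of the JAX version check; only the two target names are taken from the PR.

```python
# Target names as they appear in the PR description.
NEW_FFI_TARGET = "triton_kernel_call_ffi"
LEGACY_TARGET = "triton_kernel_call"


def select_call_target(used_autotuned_launch: bool, jax_meets_cuda_graph_min: bool) -> str:
    """Hypothetical helper: pick the custom call target for a lowered Triton kernel.

    used_autotuned_launch: True only if a TritonAutotunedKernelCall was built.
    jax_meets_cuda_graph_min: assumed result of the JAX minimum-version check.
    """
    # Autotuned launches never take the new FFI target, regardless of JAX version.
    if not used_autotuned_launch and jax_meets_cuda_graph_min:
        return NEW_FFI_TARGET
    return LEGACY_TARGET


if __name__ == "__main__":
    for autotuned in (False, True):
        for new_jax in (False, True):
            print(autotuned, new_jax, select_call_target(autotuned, new_jax))
```

Under this sketch, the new target is reachable only when the kernel was not launched through the autotuned path and the JAX version requirement is met; every other combination resolves to the legacy target.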

Confidence Score: 5/5

Safe to merge — the change is narrowly scoped to routing autotuned kernels away from the new FFI target, and all other dispatch paths are unchanged.

The boolean flag cleanly separates the autotuned and non-autotuned code paths with no risk of misclassification. Non-autotuned kernels and the JAX-version-based compatibility fallback both leave used_autotuned_launch as False, preserving their existing behavior. The only trade-off is that autotuned kernels no longer benefit from CUDA graph support via triton_kernel_call_ffi, but that is the explicit intent of the fix given the CUDA IMA bug.

No files require special attention.

Important Files Changed

Filename | Overview
transformer_engine/jax/triton_extensions/utils.py | Adds used_autotuned_launch flag to route autotuned kernels away from triton_kernel_call_ffi and onto the legacy triton_kernel_call target; logic is correct and all non-autotuned paths are unaffected.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[triton_call_lowering] --> B{isinstance kernel_fn\nAutotuner?}
    B -- No --> C[is_autotuned = False\nused_autotuned_launch = False]
    B -- Yes --> D{is_triton_autotuned\n_alias_safe?}
    D -- No --> E[Compatibility fallback:\nis_autotuned = False\nused_autotuned_launch = False]
    D -- Yes --> F[is_autotuned = True]
    F --> G[Build TritonAutotunedKernelCall\nfor all configs]
    G --> H[used_autotuned_launch = True]
    C --> I{FFI selection}
    E --> I
    H --> I
    I --> J{not used_autotuned_launch\nAND jax_version_meets\nCUDA_GRAPH_MIN?}
    J -- Yes --> K[triton_kernel_call_ffi\nnew FFI - CUDA graph support]
    J -- No --> L[triton_kernel_call\nlegacy FFI - avoids CUDA IMA]
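
The upper half of the flowchart (nodes A through H) can be read as a small classification step. The sketch below is an assumption-laden model of that step only: `LaunchKind`, `classify_launch`, and the boolean inputs are hypothetical names, standing in for the `isinstance` check against the Triton autotuner and for the result of `is_triton_autotuned_alias_safe()`.

```python
from typing import NamedTuple


class LaunchKind(NamedTuple):
    """Hypothetical record of the two flags shown in the flowchart."""
    is_autotuned: bool
    used_autotuned_launch: bool


def classify_launch(kernel_is_autotuner: bool, alias_safe: bool) -> LaunchKind:
    """Model flowchart nodes B through H with plain booleans."""
    if not kernel_is_autotuner:
        # Node C: an ordinary kernel, no autotuned launch is built.
        return LaunchKind(is_autotuned=False, used_autotuned_launch=False)
    if not alias_safe:
        # Node E: compatibility fallback, treated like a non-autotuned kernel.
        return LaunchKind(is_autotuned=False, used_autotuned_launch=False)
    # Nodes F to H: a TritonAutotunedKernelCall is built, so the flag is set.
    return LaunchKind(is_autotuned=True, used_autotuned_launch=True)


if __name__ == "__main__":
    for is_autotuner in (False, True):
        for safe in (False, True):
            print(is_autotuner, safe, classify_launch(is_autotuner, safe))
```

In this model only one of the four input combinations sets `used_autotuned_launch`, which matches the summary's claim that the non-autotuned and compatibility-fallback paths keep their existing behavior.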

Reviews (2): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."

Comment on lines +635 to +638
if (
    not used_autotuned_launch
    and jax_version_meet_requirement(TRITON_EXTENSION_CUDA_GRAPH_MIN_JAX_VERSION)
):
Contributor

P2 The new branch condition does not carry a comment explaining why autotuned launches must skip the new FFI. Without it, a future reader will only see the mechanism (the flag) but not the root cause (CUDA IMA with triton_kernel_call_ffi on autotuned kernels), making it easy to inadvertently remove the guard when refactoring.

Suggested change
- if (
-     not used_autotuned_launch
-     and jax_version_meet_requirement(TRITON_EXTENSION_CUDA_GRAPH_MIN_JAX_VERSION)
- ):
+ # Autotuned kernels must use the older "triton_kernel_call" FFI: the newer
+ # "triton_kernel_call_ffi" path triggers CUDA IMA (Illegal Memory Access)
+ # errors for autotuned kernels and must be bypassed until the upstream issue
+ # is resolved.
+ if (
+     not used_autotuned_launch
+     and jax_version_meet_requirement(TRITON_EXTENSION_CUDA_GRAPH_MIN_JAX_VERSION)
+ ):

@jberchtold-nvidia
Collaborator Author

/te-ci jax

Collaborator

@tdophung left a comment

LGTM
