Skip to content
Open
14 changes: 13 additions & 1 deletion cuda_core/cuda/core/_linker.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,28 @@
#
# SPDX-License-Identifier: Apache-2.0

from libcpp.vector cimport vector

from cuda.bindings cimport cydriver

from ._resource_handles cimport NvJitLinkHandle, CuLinkHandle


cdef class Linker:
    # C-level attribute layout for Linker. Each attribute is declared exactly
    # once; the stale duplicate declaration of _drv_log_bufs has been dropped.
    cdef:
        NvJitLinkHandle _nvjitlink_handle
        CuLinkHandle _culink_handle
        # WARNING: the driver backend passes raw pointers from _drv_jit_values
        # and _drv_log_bufs to cuLinkCreate. cuLinkDestroy may still dereference
        # them, so _close_noexcept must reset _culink_handle before releasing
        # these retainers. Do not bypass or weaken that teardown order.
        vector[cydriver.CUjit_option] _drv_jit_keys
        vector[void*] _drv_jit_values
        bint _use_nvjitlink
        object _drv_log_bufs  # formatted_options list (driver); None for nvjitlink
        str _info_log  # decoded log; None until link() or pre-link get_*_log()
        str _error_log  # decoded log; None until link() or pre-link get_*_log()
        object _options  # LinkerOptions
        object __weakref__

    cdef void _close_noexcept(self) noexcept
336 changes: 252 additions & 84 deletions cuda_core/cuda/core/_linker.pyx

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion cuda_core/cuda/core/_program.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -785,7 +785,7 @@ cdef inline int Program_init(Program self, object code, str code_type, object op
self._linker = Linker(
ObjectCode._init(code_bytes, code_type), options=_translate_program_options(options)
)
self._backend = str(Linker.which_backend())
self._backend = str(self._linker.backend)

elif code_type == "nvvm":
_get_nvvm_module() # Validate NVVM availability
Expand Down
35 changes: 23 additions & 12 deletions cuda_core/cuda/core/utils/_program_cache/_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,12 +173,11 @@ def _nvrtc_version() -> tuple[int, int]:
def _linker_backend_and_version(use_driver: bool) -> tuple[str, str]:
"""Return ``(backend, version)`` for the linker used on PTX inputs.

``use_driver`` is the result of ``_decide_nvjitlink_or_driver()`` and
must be passed in so a single ``make_program_cache_key`` call shares
one probe across :meth:`_LinkerBackend.validate`,
:meth:`option_fingerprint`, and :meth:`hash_version_probe` (otherwise
a transient probe flap could write inconsistent fields into the same
key).
``use_driver`` is the cached result of the Linker's ``_choose_backend()``
decision and must be passed in so a single ``make_program_cache_key`` call
shares one probe across :meth:`_LinkerBackend.validate`,
:meth:`option_fingerprint`, and :meth:`hash_version_probe` (otherwise a
transient probe flap could write inconsistent fields into the same key).

Raises any underlying probe exception. ``make_program_cache_key`` catches
and mixes the exception's class name into the digest, so the same probe
Expand All @@ -187,9 +186,9 @@ def _linker_backend_and_version(use_driver: bool) -> tuple[str, str]:
working probe (``_probe_failed`` label vs. ``driver``/``nvrtc``/...).

nvJitLink version lookup goes through ``sys.modules`` first so we hit the
same module ``_decide_nvjitlink_or_driver()`` already loaded. That keeps
fingerprinting aligned with whichever ``cuda.bindings.nvjitlink`` import
path the linker actually uses.
same module the Linker probe already loaded. That keeps fingerprinting
aligned with whichever ``cuda.bindings.nvjitlink`` import path the linker
actually uses.
"""
import sys

Expand Down Expand Up @@ -483,9 +482,21 @@ def _decide_driver(self) -> bool | None:
"""
if self._cached_decision is _DECISION_UNSET:
try:
from cuda.core._linker import _decide_nvjitlink_or_driver

self._cached_decision = _decide_nvjitlink_or_driver()
from cuda.core import _linker

try:
driver_major = _linker.driver_version()[0]
except _linker.CUDAError:
driver_major = None
self._cached_decision = (
_linker._choose_backend(
driver_major,
_linker._probe_nvjitlink(),
inputs_have_ltoir=False,
lto_requested=False,
)
== "driver"
)
except Exception as exc:
self._cached_decision = None
self._cached_decision_exc = exc
Expand Down
187 changes: 169 additions & 18 deletions cuda_core/tests/test_linker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@
#
# SPDX-License-Identifier: Apache-2.0

import gc
import inspect

import pytest

from cuda.core import Device, Linker, LinkerOptions, Program, ProgramOptions, _linker
from cuda.core._module import ObjectCode
from cuda.core._program import _can_load_generated_ptx
from cuda.core._utils.cuda_utils import CUDAError
from cuda.core._utils.version import driver_version

ARCH = "sm_" + "".join(f"{i}" for i in Device().compute_capability)

Expand All @@ -20,7 +23,22 @@
device_function_b = "__device__ int B() { return 0; }"
device_function_c = "__device__ int C(int a, int b) { return a + b; }"

is_culink_backend = _linker._decide_nvjitlink_or_driver()

def _current_env_backend() -> str:
    """Report which backend a plain PTX link (no LTO anywhere) selects on this machine.

    The driver major version is forwarded as ``None`` when it cannot be queried.
    """
    major = None
    try:
        major = driver_version()[0]
    except Exception:
        # Best effort: an unknown driver version is reported to the chooser as None.
        pass
    nvjitlink_probe = _linker._probe_nvjitlink()
    return _linker._choose_backend(major, nvjitlink_probe, inputs_have_ltoir=False, lto_requested=False)


# Resolved once at import time: True when a default (PTX input, no LTO) Linker
# on this machine uses the driver backend.
is_culink_backend = _current_env_backend() == "driver"
if not is_culink_backend:
    # nvjitlink is imported only when the default backend is not the driver.
    from cuda.bindings import nvjitlink

Expand Down Expand Up @@ -94,11 +112,15 @@ def test_linker_init(compile_ptx_functions, options):
linker = Linker(*compile_ptx_functions, options=options)
object_code = linker.link("cubin")
assert isinstance(object_code, ObjectCode)
assert Linker.which_backend() == ("driver" if is_culink_backend else "nvJitLink")
assert linker.backend == ("driver" if is_culink_backend else "nvJitLink")


def test_linker_init_invalid_arch(compile_ptx_functions):
err = AttributeError if is_culink_backend else nvjitlink.nvJitLinkError
# With the driver backend, ptx=True (which implies link-time optimization)
# cannot be satisfied at all, so dispatch raises RuntimeError before the
# arch string is ever parsed. With the nvJitLink backend, the arch string
# is validated by nvJitLink itself.
err = RuntimeError if is_culink_backend else nvjitlink.nvJitLinkError
with pytest.raises(err):
options = LinkerOptions(arch="99", ptx=True)
Linker(*compile_ptx_functions, options=options)
Expand Down Expand Up @@ -207,7 +229,7 @@ def test_linker_options_as_bytes_invalid_backend():
def test_linker_options_as_bytes_driver_not_supported():
    """Test that as_bytes() is not supported for driver backend.

    Requesting the "driver" backend must raise ``ValueError`` whose message
    states that only the 'nvjitlink' backend is supported.
    """
    options = LinkerOptions(arch="sm_80")
    with pytest.raises(ValueError, match="as_bytes\\(\\) only supports 'nvjitlink' backend"):
        options.as_bytes("driver")


Expand All @@ -234,6 +256,69 @@ def test_linker_handle(compile_ptx_functions):
assert int(handle) != 0


def test_driver_linker_lifetime_no_heap_corruption(monkeypatch, compile_ptx_functions):
    """Tearing down a driver-backend Linker must not leave dangling cuLinkCreate inputs.

    Regression coverage for two earlier heap-corruption bugs in driver-linker
    teardown: the log buffer bytearrays were cleared before cuLinkDestroy ran,
    and the optionValues array was a stack-local vector that died when
    Linker_init returned. Both showed up in the NEXT CUDA operation after the
    Linker was destroyed rather than at destruction itself, so this test forces
    the driver backend, lets one Linker die through __dealloc__, performs more
    CUDA work, and then closes a second Linker explicitly.
    """
    if not _can_load_generated_ptx():
        pytest.skip("PTX version too new for current driver")
    # Report nvJitLink as unavailable so the Linkers below select the driver backend.
    monkeypatch.setattr(_linker, "_probe_nvjitlink", lambda: None)

    # First Linker: released implicitly (del + garbage collection).
    dealloc_linker = Linker(*compile_ptx_functions, options=LinkerOptions(arch=ARCH))
    assert dealloc_linker.backend == "driver"
    dealloc_linker.link("cubin")
    del dealloc_linker
    gc.collect()

    # Follow-up CUDA work after the first teardown, then a second Linker that
    # is released through an explicit close().
    sources = (kernel_a, device_function_b, device_function_c)
    ptx_objects = [
        Program(src, "c++", ProgramOptions(relocatable_device_code=True)).compile("ptx") for src in sources
    ]
    closed_linker = Linker(*ptx_objects, options=LinkerOptions(arch=ARCH))
    assert closed_linker.backend == "driver"
    closed_linker.link("cubin")
    closed_linker.close()
    del closed_linker


def test_driver_linker_get_error_log_after_close_on_failed_link(init_cuda, monkeypatch):
    """The error log of a failed link() must remain readable after close().

    link() caches _info_log/_error_log only on its success path, so when
    cuLinkComplete fails the driver log buffers are the sole source of the
    diagnostic. close() releases those buffers; callers must still get the
    captured error text afterwards.
    """
    if not _can_load_generated_ptx():
        pytest.skip("PTX version too new for current driver")
    # Report nvJitLink as unavailable so the Linker below selects the driver backend.
    monkeypatch.setattr(_linker, "_probe_nvjitlink", lambda: None)

    # Z() is declared but never defined, so linking this kernel alone fails.
    unresolved_src = """
extern __device__ int Z();
__global__ void A() { int r = Z(); }
"""
    rdc_options = ProgramOptions(relocatable_device_code=True)
    unresolved_obj = Program(unresolved_src, "c++", rdc_options).compile("ptx")
    linker = Linker(unresolved_obj, options=LinkerOptions(arch=ARCH))
    assert linker.backend == "driver"
    with pytest.raises(CUDAError):
        linker.link("cubin")

    log_before_close = linker.get_error_log()
    assert isinstance(log_before_close, str)
    assert log_before_close  # failed link must have produced a diagnostic

    linker.close()
    # The raw driver buffers are gone after close(); the decoded logs must
    # still be available and unchanged.
    assert linker.get_error_log() == log_before_close
    assert isinstance(linker.get_info_log(), str)


@pytest.mark.skipif(is_culink_backend, reason="nvjitlink options only tested with nvjitlink backend")
def test_linker_options_nvjitlink_options_as_str():
"""_prepare_nvjitlink_options(as_bytes=False) returns plain strings."""
Expand All @@ -246,37 +331,103 @@ def test_linker_options_nvjitlink_options_as_str():
assert "-lineinfo" in options


# ---------------------------------------------------------------------------
# Per-instance dispatch tests
#
# The full _choose_backend() decision matrix lives in test_linker_dispatch.py as
# GPU-free unit tests. The tests below drive the same dispatch logic through the
# real Linker constructor (with patched version probes) to confirm that the
# dispatch is invoked before any backend handle is created.
# ---------------------------------------------------------------------------


class TestLinkerDispatch:
    """Backend dispatch driven through the real ``Linker`` constructor.

    Each test pins :func:`driver_version` (via the name imported into
    ``_linker``) and :func:`_probe_nvjitlink` to fixed values so the decision is
    deterministic, then asserts that ``Linker.__init__`` raises for a
    combination that cannot be satisfied, before any backend handle is created.
    """

    @staticmethod
    def _pin_versions(monkeypatch, driver, probe_result):
        """Replace both version probes on ``_linker`` with constant-returning stubs."""
        monkeypatch.setattr(_linker, "driver_version", lambda: driver)
        monkeypatch.setattr(_linker, "_probe_nvjitlink", lambda: probe_result)

    @pytest.fixture
    def ltoir_object(self):
        # An ObjectCode merely tagged as ltoir is enough: _choose_backend runs
        # before any backend handle exists, so these bytes never reach a
        # linker library.
        return ObjectCode._init(b"\x00stub-ltoir-payload", "ltoir")

    @pytest.fixture
    def ptx_object(self):
        # Placeholder PTX text; dispatch raises before it reaches any backend.
        return ObjectCode._init(b"// stub ptx\n", "ptx")

    def test_ltoir_without_nvjitlink_raises(self, monkeypatch, ltoir_object):
        self._pin_versions(monkeypatch, (12, 9, 0), None)
        with pytest.raises(RuntimeError, match="nvJitLink is not available"):
            Linker(ltoir_object, options=LinkerOptions(arch=ARCH))

    def test_cross_major_with_ltoir_raises(self, monkeypatch, ltoir_object):
        self._pin_versions(monkeypatch, (13, 0, 0), (12, 9))
        with pytest.raises(RuntimeError, match="matching major versions"):
            Linker(ltoir_object, options=LinkerOptions(arch=ARCH))

    def test_cross_major_with_lto_option_raises(self, monkeypatch, ptx_object):
        self._pin_versions(monkeypatch, (12, 9, 0), (13, 0))
        with pytest.raises(RuntimeError, match="matching major versions"):
            Linker(ptx_object, options=LinkerOptions(arch=ARCH, link_time_optimization=True))

    def test_lto_without_nvjitlink_raises(self, monkeypatch, ptx_object):
        self._pin_versions(monkeypatch, (12, 9, 0), None)
        with pytest.raises(RuntimeError, match="nvJitLink is not available"):
            Linker(ptx_object, options=LinkerOptions(arch=ARCH, link_time_optimization=True))


class TestWhichBackendClassmethod:
    """Compatibility coverage for the deprecated ``Linker.which_backend()`` classmethod.

    Every call is wrapped in ``pytest.warns(DeprecationWarning)`` and the
    version probes on ``_linker`` are patched so the reported backend does not
    depend on the machine running the tests.
    """

    def test_which_backend_returns_nvjitlink(self, monkeypatch):
        # Driver 12.x with nvJitLink 12.9 is expected to report nvJitLink.
        monkeypatch.setattr(_linker, "driver_version", lambda: (12, 9, 0))
        monkeypatch.setattr(_linker, "_probe_nvjitlink", lambda: (12, 9))
        with pytest.warns(DeprecationWarning, match="linker.backend"):
            assert Linker.which_backend() == "nvJitLink"

    def test_which_backend_returns_driver(self, monkeypatch):
        # Driver 13.x with nvJitLink 12.9 is expected to report the driver backend.
        monkeypatch.setattr(_linker, "driver_version", lambda: (13, 0, 0))
        monkeypatch.setattr(_linker, "_probe_nvjitlink", lambda: (12, 9))
        with pytest.warns(DeprecationWarning, match="linker.backend"):
            assert Linker.which_backend() == "driver"

    def test_which_backend_invokes_probe(self, monkeypatch):
        called = []

        def fake_probe():
            # Record the call so the test can prove the probe was consulted.
            called.append(True)
            return (12, 9)

        monkeypatch.setattr(_linker, "driver_version", lambda: (12, 9, 0))
        monkeypatch.setattr(_linker, "_probe_nvjitlink", fake_probe)
        with pytest.warns(DeprecationWarning, match="linker.backend"):
            result = Linker.which_backend()
        assert result == "nvJitLink"
        assert called, "_probe_nvjitlink was not called"

    def test_which_backend_is_classmethod(self):
        # getattr_static bypasses the descriptor protocol so the raw
        # classmethod object is inspected.
        attr = inspect.getattr_static(Linker, "which_backend")
        assert isinstance(attr, classmethod)

    def test_which_backend_is_not_property(self):
        """which_backend remains a compatibility classmethod, not a property.

        It is deprecated in favor of the per-instance ``linker.backend`` property.
        """
        attr = inspect.getattr_static(Linker, "which_backend")
        assert not isinstance(attr, property)
Loading