diff --git a/cuda_core/cuda/core/experimental/_linker.py b/cuda_core/cuda/core/experimental/_linker.py index cef778c9aa..eb850b5973 100644 --- a/cuda_core/cuda/core/experimental/_linker.py +++ b/cuda_core/cuda/core/experimental/_linker.py @@ -395,10 +395,9 @@ def __init__(self, *object_codes: ObjectCode, options: LinkerOptions = None): def _add_code_object(self, object_code: ObjectCode): data = object_code._module - assert_type(data, bytes) with _exception_manager(self): name_str = f"{object_code.name}" - if _nvjitlink: + if _nvjitlink and isinstance(data, bytes): _nvjitlink.add_data( self._mnff.handle, self._input_type_from_code_type(object_code._code_type), @@ -406,7 +405,13 @@ def _add_code_object(self, object_code: ObjectCode): len(data), name_str, ) - else: + elif _nvjitlink and isinstance(data, str): + _nvjitlink.add_file( + self._mnff.handle, + self._input_type_from_code_type(object_code._code_type), + data, + ) + elif (not _nvjitlink) and isinstance(data, bytes): name_bytes = name_str.encode() handle_return( _driver.cuLinkAddData( @@ -421,6 +426,21 @@ def _add_code_object(self, object_code: ObjectCode): ) ) self._mnff.const_char_keep_alive.append(name_bytes) + elif (not _nvjitlink) and isinstance(data, str): + name_bytes = name_str.encode() + handle_return( + _driver.cuLinkAddFile( + self._mnff.handle, + self._input_type_from_code_type(object_code._code_type), + data.encode(), + 0, + None, + None, + ) + ) + self._mnff.const_char_keep_alive.append(name_bytes) + else: + raise TypeError(f"Expected bytes or str, but got {type(data).__name__}") def link(self, target_type) -> ObjectCode: """ diff --git a/cuda_core/cuda/core/experimental/_module.py b/cuda_core/cuda/core/experimental/_module.py index 71293be4d1..2c7ea3a156 100644 --- a/cuda_core/cuda/core/experimental/_module.py +++ b/cuda_core/cuda/core/experimental/_module.py @@ -666,6 +666,11 @@ def name(self) -> str: """Return a human-readable name of this code object.""" return self._name + @property + def code_type(self) -> str: + """Return the type of the underlying code object.""" + return self._code_type + @property @precondition(_lazy_load_module) def handle(self): diff --git a/cuda_core/docs/source/release/0.X.Y-notes.rst b/cuda_core/docs/source/release/0.X.Y-notes.rst index e87cbdee31..2fb4093214 100644 --- a/cuda_core/docs/source/release/0.X.Y-notes.rst +++ b/cuda_core/docs/source/release/0.X.Y-notes.rst @@ -32,6 +32,7 @@ New features - CUDA 13.x testing support through new ``test-cu13`` dependency group. - Stream-ordered memory allocation can now be shared on Linux via :class:`DeviceMemoryResource`. - Added NVVM IR support to :class:`Program`. NVVM IR is now understood with ``code_type="nvvm"``. +- Added an :attr:`ObjectCode.code_type` attribute for querying the code type. New examples diff --git a/cuda_core/tests/test_module.py b/cuda_core/tests/test_module.py index 1629e826a4..49df966c08 100644 --- a/cuda_core/tests/test_module.py +++ b/cuda_core/tests/test_module.py @@ -60,20 +60,20 @@ def test_object_code_init_disabled(): @pytest.fixture(scope="function") -def get_saxpy_kernel(init_cuda): +def get_saxpy_kernel_cubin(init_cuda): # prepare program prog = Program(SAXPY_KERNEL, code_type="c++") mod = prog.compile( "cubin", name_expressions=("saxpy", "saxpy"), ) - # run in single precision return mod.get_kernel("saxpy"), mod @pytest.fixture(scope="function") def get_saxpy_kernel_ptx(init_cuda): + # prepare program prog = Program(SAXPY_KERNEL, code_type="c++") mod = prog.compile( "ptx", @@ -84,12 +84,10 @@ def get_saxpy_kernel_ptx(init_cuda): @pytest.fixture(scope="function") -def get_saxpy_object_code(init_cuda): - prog = Program(SAXPY_KERNEL, code_type="c++") - mod = prog.compile( - "cubin", - name_expressions=("saxpy", "saxpy"), - ) +def get_saxpy_kernel_ltoir(init_cuda): + # Create LTOIR code using link-time optimization + prog = Program(SAXPY_KERNEL, code_type="c++", options=ProgramOptions(link_time_optimization=True)) + mod = prog.compile("ltoir", name_expressions=("saxpy", "saxpy")) return mod @@ -129,8 +127,8 @@ def test_get_kernel(init_cuda): ("cluster_scheduling_policy_preference", int), ], ) -def test_read_only_kernel_attributes(get_saxpy_kernel, attr, expected_type): - kernel, _ = get_saxpy_kernel +def test_read_only_kernel_attributes(get_saxpy_kernel_cubin, attr, expected_type): + kernel, _ = get_saxpy_kernel_cubin method = getattr(kernel.attributes, attr) # get the value without providing a device ordinal value = method() @@ -142,16 +140,6 @@ def test_read_only_kernel_attributes(get_saxpy_kernel, attr, expected_type): assert isinstance(value, expected_type), f"Expected {attr} to be of type {expected_type}, but got {type(value)}" -def test_object_code_load_cubin(get_saxpy_kernel): - _, mod = get_saxpy_kernel - cubin = mod._module - sym_map = mod._sym_map - assert isinstance(cubin, bytes) - mod = ObjectCode.from_cubin(cubin, symbol_mapping=sym_map) - assert mod.code == cubin - mod.get_kernel("saxpy") # force loading - - def test_object_code_load_ptx(get_saxpy_kernel_ptx): ptx, mod = get_saxpy_kernel_ptx sym_map = mod._sym_map @@ -162,8 +150,32 @@ def test_object_code_load_ptx(get_saxpy_kernel_ptx): mod_obj.get_kernel("saxpy") # force loading -def test_object_code_load_cubin_from_file(get_saxpy_kernel, tmp_path): - _, mod = get_saxpy_kernel +def test_object_code_load_ptx_from_file(get_saxpy_kernel_ptx, tmp_path): + ptx, mod = get_saxpy_kernel_ptx + sym_map = mod._sym_map + assert isinstance(ptx, bytes) + ptx_file = tmp_path / "test.ptx" + ptx_file.write_bytes(ptx) + mod_obj = ObjectCode.from_ptx(str(ptx_file), symbol_mapping=sym_map) + assert mod_obj.code == str(ptx_file) + assert mod_obj.code_type == "ptx" + if not Program._can_load_generated_ptx(): + pytest.skip("PTX version too new for current driver") + mod_obj.get_kernel("saxpy") # force loading + + +def test_object_code_load_cubin(get_saxpy_kernel_cubin): + _, mod = get_saxpy_kernel_cubin + cubin = mod._module + sym_map = mod._sym_map + assert isinstance(cubin, bytes) + mod = ObjectCode.from_cubin(cubin, symbol_mapping=sym_map) + assert mod.code == cubin + mod.get_kernel("saxpy") # force loading + + +def test_object_code_load_cubin_from_file(get_saxpy_kernel_cubin, tmp_path): + _, mod = get_saxpy_kernel_cubin cubin = mod._module sym_map = mod._sym_map assert isinstance(cubin, bytes) @@ -174,13 +186,42 @@ def test_object_code_load_cubin_from_file(get_saxpy_kernel, tmp_path): mod.get_kernel("saxpy") # force loading -def test_object_code_handle(get_saxpy_object_code): - mod = get_saxpy_object_code +def test_object_code_handle(get_saxpy_kernel_cubin): + _, mod = get_saxpy_kernel_cubin assert mod.handle is not None -def test_saxpy_arguments(get_saxpy_kernel, cuda12_4_prerequisite_check): - krn, _ = get_saxpy_kernel +def test_object_code_load_ltoir(get_saxpy_kernel_ltoir): + mod = get_saxpy_kernel_ltoir + ltoir = mod._module + sym_map = mod._sym_map + assert isinstance(ltoir, bytes) + mod_obj = ObjectCode.from_ltoir(ltoir, symbol_mapping=sym_map) + assert mod_obj.code == ltoir + assert mod_obj.code_type == "ltoir" + # ltoir doesn't support kernel retrieval directly as it's used for linking + assert mod_obj._handle is None + # Test that get_kernel fails for unsupported code type + with pytest.raises(RuntimeError, match=r'Unsupported code type "ltoir"'): + mod_obj.get_kernel("saxpy") + + +def test_object_code_load_ltoir_from_file(get_saxpy_kernel_ltoir, tmp_path): + mod = get_saxpy_kernel_ltoir + ltoir = mod._module + sym_map = mod._sym_map + assert isinstance(ltoir, bytes) + ltoir_file = tmp_path / "test.ltoir" + ltoir_file.write_bytes(ltoir) + mod_obj = ObjectCode.from_ltoir(str(ltoir_file), symbol_mapping=sym_map) + assert mod_obj.code == str(ltoir_file) + assert mod_obj.code_type == "ltoir" + # ltoir doesn't support kernel retrieval directly as it's used for linking + assert mod_obj._handle is None + + +def test_saxpy_arguments(get_saxpy_kernel_cubin, cuda12_4_prerequisite_check): + krn, _ = get_saxpy_kernel_cubin if cuda12_4_prerequisite_check: assert krn.num_arguments == 5 @@ -258,8 +299,8 @@ def test_num_args_error_handling(deinit_all_contexts_function, cuda12_4_prerequi @pytest.mark.parametrize("block_size", [32, 64, 96, 120, 128, 256]) @pytest.mark.parametrize("smem_size_per_block", [0, 32, 4096]) -def test_occupancy_max_active_block_per_multiprocessor(get_saxpy_kernel, block_size, smem_size_per_block): - kernel, _ = get_saxpy_kernel +def test_occupancy_max_active_block_per_multiprocessor(get_saxpy_kernel_cubin, block_size, smem_size_per_block): + kernel, _ = get_saxpy_kernel_cubin dev_props = Device().properties assert block_size <= dev_props.max_threads_per_block assert smem_size_per_block <= dev_props.max_shared_memory_per_block @@ -275,9 +316,9 @@ def test_occupancy_max_active_block_per_multiprocessor(get_saxpy_kernel, block_s @pytest.mark.parametrize("block_size_limit", [32, 64, 96, 120, 128, 256, 0]) @pytest.mark.parametrize("smem_size_per_block", [0, 32, 4096]) -def test_occupancy_max_potential_block_size_constant(get_saxpy_kernel, block_size_limit, smem_size_per_block): +def test_occupancy_max_potential_block_size_constant(get_saxpy_kernel_cubin, block_size_limit, smem_size_per_block): """Tests use case when shared memory needed is independent on the block size""" - kernel, _ = get_saxpy_kernel + kernel, _ = get_saxpy_kernel_cubin dev_props = Device().properties assert block_size_limit <= dev_props.max_threads_per_block assert smem_size_per_block <= dev_props.max_shared_memory_per_block @@ -302,9 +343,9 @@ def test_occupancy_max_potential_block_size_constant(get_saxpy_kernel, block_siz @pytest.mark.skipif(numba is None, reason="Test requires numba to be installed") @pytest.mark.parametrize("block_size_limit", [32, 64, 96, 120, 128, 277, 0]) -def test_occupancy_max_potential_block_size_b2dsize(get_saxpy_kernel, block_size_limit): +def test_occupancy_max_potential_block_size_b2dsize(get_saxpy_kernel_cubin, block_size_limit): """Tests use case when shared memory needed depends on the block size""" - kernel, _ = get_saxpy_kernel + kernel, _ = get_saxpy_kernel_cubin def shared_memory_needed(block_size: numba.intc) -> numba.size_t: "Size of dynamic shared memory needed by kernel of this block size" @@ -329,8 +370,8 @@ def shared_memory_needed(block_size: numba.intc) -> numba.size_t: @pytest.mark.parametrize("num_blocks_per_sm, block_size", [(4, 32), (2, 64), (2, 96), (3, 120), (2, 128), (1, 256)]) -def test_occupancy_available_dynamic_shared_memory_per_block(get_saxpy_kernel, num_blocks_per_sm, block_size): - kernel, _ = get_saxpy_kernel +def test_occupancy_available_dynamic_shared_memory_per_block(get_saxpy_kernel_cubin, num_blocks_per_sm, block_size): + kernel, _ = get_saxpy_kernel_cubin dev_props = Device().properties assert block_size <= dev_props.max_threads_per_block assert num_blocks_per_sm * block_size <= dev_props.max_threads_per_multiprocessor @@ -340,8 +381,8 @@ def test_occupancy_available_dynamic_shared_memory_per_block(get_saxpy_kernel, n @pytest.mark.parametrize("cluster", [None, 2]) -def test_occupancy_max_active_clusters(get_saxpy_kernel, cluster): - kernel, _ = get_saxpy_kernel +def test_occupancy_max_active_clusters(get_saxpy_kernel_cubin, cluster): + kernel, _ = get_saxpy_kernel_cubin dev = Device() if dev.compute_capability < (9, 0): pytest.skip("Device with compute capability 90 or higher is required for cluster support") @@ -355,8 +396,8 @@ def test_occupancy_max_active_clusters(get_saxpy_kernel, cluster): assert max_active_clusters >= 0 -def test_occupancy_max_potential_cluster_size(get_saxpy_kernel): - kernel, _ = get_saxpy_kernel +def test_occupancy_max_potential_cluster_size(get_saxpy_kernel_cubin): + kernel, _ = get_saxpy_kernel_cubin dev = Device() if dev.compute_capability < (9, 0): pytest.skip("Device with compute capability 90 or higher is required for cluster support") @@ -370,11 +411,11 @@ def test_occupancy_max_potential_cluster_size(get_saxpy_kernel): assert max_potential_cluster_size >= 0 -def test_module_serialization_roundtrip(get_saxpy_kernel): - _, objcode = get_saxpy_kernel +def test_module_serialization_roundtrip(get_saxpy_kernel_cubin): + _, objcode = get_saxpy_kernel_cubin result = pickle.loads(pickle.dumps(objcode)) # noqa: S403, S301 assert isinstance(result, ObjectCode) assert objcode.code == result.code assert objcode._sym_map == result._sym_map - assert objcode._code_type == result._code_type + assert objcode.code_type == result.code_type