diff --git a/cuda_core/cuda/core/experimental/_program.py b/cuda_core/cuda/core/experimental/_program.py index cdef7c3be6..96ce859948 100644 --- a/cuda_core/cuda/core/experimental/_program.py +++ b/cuda_core/cuda/core/experimental/_program.py @@ -126,6 +126,118 @@ def _process_define_macro(formatted_options, macro): raise RuntimeError(f"Expected define_macro {union_type}, list[{union_type}], got {macro}") +def _format_options_for_backend(options_dict: dict, backend: str) -> list[str]: + """Format compilation options for a specific backend. + + This helper function converts a dictionary of option names and values into + properly formatted string options for the specified backend. Different backends + (NVRTC, NVVM, nvJitLink) use slightly different option naming conventions and + value formats. + + Parameters + ---------- + options_dict : dict + Dictionary mapping option names to their values. The keys should be + generic option names (e.g., "arch", "debug", "ftz"). + backend : str + The backend to format options for. Must be one of "NVRTC", "NVVM", or "nvJitLink". + + Returns + ------- + list[str] + List of formatted option strings suitable for the specified backend. + + Raises + ------ + ValueError + If an unsupported backend is specified. + + Notes + ----- + - NVRTC uses `--` prefix and "true"/"false" for booleans + - NVVM uses `-` prefix and "1"/"0" for booleans + - nvJitLink uses `-` prefix and "true"/"false" for booleans + """ + if backend not in ("NVRTC", "NVVM", "nvJitLink"): + raise ValueError(f"Unsupported backend '{backend}'. Must be one of: NVRTC, NVVM, nvJitLink") + + formatted = [] + + for key, value in options_dict.items(): + if value is None: + continue + + if backend == "NVRTC": + # NVRTC uses -- prefix + if key == "arch": + formatted.append(f"-arch={value}") + elif key == "debug" and value: + formatted.append("--device-debug") + elif key == "lineinfo" and value: + formatted.append("--generate-line-info") + elif key == "max_register_count": + formatted.append(f"--maxrregcount={value}") + elif key in ("ftz", "prec_sqrt", "prec_div"): + bool_val = "true" if value else "false" + # NVRTC uses hyphens in option names + option_name = key.replace("_", "-") + formatted.append(f"--{option_name}={bool_val}") + elif key == "fma": + bool_val = "true" if value else "false" + formatted.append(f"--fmad={bool_val}") + elif key == "device_code_optimize" and value: + formatted.append("--dopt=on") + elif key == "use_fast_math" and value: + formatted.append("--use_fast_math") + elif key == "link_time_optimization" and value: + formatted.append("--dlink-time-opt") + # Add more NVRTC-specific options as needed + + elif backend == "NVVM": + # NVVM uses - prefix and 1/0 for booleans + if key == "arch": + # NVVM uses compute_ instead of sm_ + arch_val = value + if arch_val.startswith("sm_"): + arch_val = f"compute_{arch_val[3:]}" + formatted.append(f"-arch={arch_val}") + elif key == "debug" and value: + formatted.append("-g") + elif key == "device_code_optimize": + # NVVM explicitly handles both True and False + if value is False: + formatted.append("-opt=0") + elif value is True: + formatted.append("-opt=3") + elif key in ("ftz", "prec_sqrt", "prec_div", "fma"): + bool_val = "1" if value else "0" + # NVVM uses hyphens in option names + option_name = key.replace("_", "-") + formatted.append(f"-{option_name}={bool_val}") + # lineinfo and link_time_optimization are not supported by NVVM, skip them + + elif backend == "nvJitLink": + # nvJitLink uses - prefix and true/false for booleans + if key == "arch": + formatted.append(f"-arch={value}") + elif key == "debug" and value: + formatted.append("-g") + elif key == "lineinfo" and value: + formatted.append("-lineinfo") + elif key == "max_register_count": + formatted.append(f"-maxrregcount={value}") + elif key in ("ftz", "prec_sqrt", "prec_div", "fma"): + bool_val = "true" if value else "false" + # nvJitLink uses hyphens in option names + option_name = key.replace("_", "-") + formatted.append(f"-{option_name}={bool_val}") + elif key == "link_time_optimization" and value: + formatted.append("-lto") + # device_code_optimize is not supported by nvJitLink, skip it + + return formatted + + @dataclass class ProgramOptions: """Customizable options for configuring `Program`. @@ -422,9 +534,91 @@ def __post_init__(self): if self.numba_debug: self._formatted_options.append("--numba-debug") - def _as_bytes(self): - # TODO: allow tuples once NVIDIA/cuda-python#72 is resolved - return list(o.encode() for o in self._formatted_options) + def as_bytes(self, backend: str = "NVRTC") -> list[bytes]: + """Convert the formatted program options to a list of byte strings. + + This method encodes the options stored in this `ProgramOptions` instance + into byte strings formatted for the specified backend, suitable for passing + to C libraries that calls the underlying compiler library. + + Parameters + ---------- + backend : str, optional + The compiler backend to format options for. Must be one of: + + - "NVRTC" (default): NVIDIA NVRTC compiler, supports all ProgramOptions + - "NVVM": NVIDIA NVVM compiler, supports a subset of options + - "nvJitLink": NVIDIA nvJitLink linker, supports a subset of options + + Different backends use different option naming conventions and support + different option subsets. This method will format and filter options + appropriately for the chosen backend. + + Returns + ------- + list[bytes] + A list of byte-encoded option strings. Each element represents + a single compilation option in the format expected by the underlying compiler library. + + Raises + ------ + ValueError + If an unsupported backend is specified. + + Examples + -------- + >>> options = ProgramOptions(arch="sm_80", debug=True) + >>> # Get options for NVRTC (default) + >>> nvrtc_options = options.as_bytes() + >>> print(nvrtc_options) + [b'-arch=sm_80', b'--device-debug'] + >>> + >>> # Get options for NVVM + >>> nvvm_options = options.as_bytes("NVVM") + >>> print(nvvm_options) + [b'-arch=compute_80', b'-g'] + >>> + >>> # Get options for nvJitLink + >>> nvjitlink_options = options.as_bytes("nvJitLink") + >>> print(nvjitlink_options) + [b'-arch=sm_80', b'-g'] + """ + if backend == "NVRTC": + # For NVRTC, use the pre-formatted options (backward compatible) + # TODO: allow tuples once NVIDIA/cuda-python#72 is resolved + return list(o.encode() for o in self._formatted_options) + + elif backend in ("NVVM", "nvJitLink"): + # For NVVM and nvJitLink, extract common options and format appropriately + options_dict = {} + + # Common options supported by multiple backends + if self.arch is not None: + options_dict["arch"] = self.arch + if self.debug is not None: + options_dict["debug"] = self.debug + if self.lineinfo is not None: + options_dict["lineinfo"] = self.lineinfo + if self.max_register_count is not None: + options_dict["max_register_count"] = self.max_register_count + if self.ftz is not None: + options_dict["ftz"] = self.ftz + if self.prec_sqrt is not None: + options_dict["prec_sqrt"] = self.prec_sqrt + if self.prec_div is not None: + options_dict["prec_div"] = self.prec_div + if self.fma is not None: + options_dict["fma"] = self.fma + if self.device_code_optimize is not None: + options_dict["device_code_optimize"] = self.device_code_optimize + if self.link_time_optimization is not None: + options_dict["link_time_optimization"] = self.link_time_optimization + + formatted_options = _format_options_for_backend(options_dict, backend) + return list(o.encode() for o in formatted_options) + + else: + raise ValueError(f"Unsupported backend '{backend}'. Must be one of: NVRTC, NVVM, nvJitLink") def __repr__(self): # __TODO__ improve this @@ -609,7 +803,7 @@ def compile(self, target_type, name_expressions=(), logs=None): nvrtc.nvrtcAddNameExpression(self._mnff.handle, n.encode()), handle=self._mnff.handle, ) - options = self._options._as_bytes() + options = self._options.as_bytes() handle_return( nvrtc.nvrtcCompileProgram(self._mnff.handle, len(options), options), handle=self._mnff.handle, diff --git a/cuda_core/tests/test_program.py b/cuda_core/tests/test_program.py index 8a6526fcc2..1388d472d3 100644 --- a/cuda_core/tests/test_program.py +++ b/cuda_core/tests/test_program.py @@ -411,3 +411,109 @@ def test_nvvm_program_options(init_cuda, nvvm_ir, options): assert ".visible .entry simple(" in ptx_text program.close() + + +def test_program_options_as_bytes(): + """Test that ProgramOptions.as_bytes() returns correct byte strings""" + # Test with various options + options = ProgramOptions( + arch="sm_80", + debug=True, + lineinfo=True, + max_register_count=32, + ftz=True, + use_fast_math=True, + ) + + byte_options = options.as_bytes() + + # Verify the return type + assert isinstance(byte_options, list) + assert all(isinstance(opt, bytes) for opt in byte_options) + + # Verify specific options are present in byte format + assert b"-arch=sm_80" in byte_options + assert b"--device-debug" in byte_options + assert b"--generate-line-info" in byte_options + assert b"--maxrregcount=32" in byte_options + assert b"--ftz=true" in byte_options + assert b"--use_fast_math" in byte_options + + +def test_program_options_as_bytes_empty(): + """Test that ProgramOptions.as_bytes() works with minimal options""" + # Test with minimal options (only defaults) + options = ProgramOptions() + + byte_options = options.as_bytes() + + # Should at least have arch option (automatically set based on Device if not provided) + assert isinstance(byte_options, list) + assert len(byte_options) > 0 + assert all(isinstance(opt, bytes) for opt in byte_options) + # The arch option should be present (automatically determined from current device) + assert any(b"-arch=" in opt for opt in byte_options) + + +def test_program_options_as_bytes_nvvm_backend(): + """Test that ProgramOptions.as_bytes() formats options correctly for NVVM backend""" + options = ProgramOptions( + arch="sm_80", + debug=True, + ftz=True, + prec_sqrt=False, + prec_div=True, + fma=False, + device_code_optimize=True, + ) + + byte_options = options.as_bytes("NVVM") + + # Verify the return type + assert isinstance(byte_options, list) + assert all(isinstance(opt, bytes) for opt in byte_options) + + # NVVM uses compute_ instead of sm_ and 1/0 for booleans, with hyphens in option names + assert b"-arch=compute_80" in byte_options + assert b"-g" in byte_options + assert b"-ftz=1" in byte_options + assert b"-prec-sqrt=0" in byte_options + assert b"-prec-div=1" in byte_options + assert b"-fma=0" in byte_options + assert b"-opt=3" in byte_options + + +def test_program_options_as_bytes_nvjitlink_backend(): + """Test that ProgramOptions.as_bytes() formats options correctly for nvJitLink backend""" + options = ProgramOptions( + arch="sm_80", + debug=True, + lineinfo=True, + max_register_count=32, + ftz=False, + prec_sqrt=True, + link_time_optimization=True, + ) + + byte_options = options.as_bytes("nvJitLink") + + # Verify the return type + assert isinstance(byte_options, list) + assert all(isinstance(opt, bytes) for opt in byte_options) + + # nvJitLink uses - prefix and true/false for booleans, with hyphens in option names + assert b"-arch=sm_80" in byte_options + assert b"-g" in byte_options + assert b"-lineinfo" in byte_options + assert b"-maxrregcount=32" in byte_options + assert b"-ftz=false" in byte_options + assert b"-prec-sqrt=true" in byte_options + assert b"-lto" in byte_options + + +def test_program_options_as_bytes_invalid_backend(): + """Test that ProgramOptions.as_bytes() raises error for invalid backend""" + options = ProgramOptions() + + with pytest.raises(ValueError, match="Unsupported backend 'invalid'"): + options.as_bytes("invalid")