From 7ce48aa152d8025d59c543faae56eeb0521918a4 Mon Sep 17 00:00:00 2001
From: Yaoyao Ding
Date: Fri, 5 Sep 2025 00:09:53 -0400
Subject: [PATCH 1/5] initial support

---
 ffi/examples/inline_module/main.py    |  69 ++++++
 ffi/python/tvm_ffi/cpp/__init__.py    |   1 +
 ffi/python/tvm_ffi/cpp/load_inline.py | 324 ++++++++++++++++++++++++++
 ffi/python/tvm_ffi/utils/__init__.py  |   1 +
 ffi/python/tvm_ffi/utils/lockfile.py  |  92 ++++++++
 ffi/tests/python/test_load_inline.py  |  65 ++++++
 6 files changed, 552 insertions(+)
 create mode 100644 ffi/examples/inline_module/main.py
 create mode 100644 ffi/python/tvm_ffi/cpp/__init__.py
 create mode 100644 ffi/python/tvm_ffi/cpp/load_inline.py
 create mode 100644 ffi/python/tvm_ffi/utils/__init__.py
 create mode 100644 ffi/python/tvm_ffi/utils/lockfile.py
 create mode 100644 ffi/tests/python/test_load_inline.py

diff --git a/ffi/examples/inline_module/main.py b/ffi/examples/inline_module/main.py
new file mode 100644
index 000000000000..b2588261373e
--- /dev/null
+++ b/ffi/examples/inline_module/main.py
@@ -0,0 +1,69 @@
+import torch
+import tvm_ffi.cpp
+from tvm_ffi.module import Module
+
+
+def main():
+    mod: Module = tvm_ffi.cpp.load_inline(
+        name='hello',
+        cpp_source=r"""
+        void AddOne(DLTensor* x, DLTensor* y) {
+          // implementation of a library function
+          TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
+          DLDataType f32_dtype{kDLFloat, 32, 1};
+          TVM_FFI_ICHECK(x->dtype == f32_dtype) << "x must be a float tensor";
+          TVM_FFI_ICHECK(y->ndim == 1) << "y must be a 1D tensor";
+          TVM_FFI_ICHECK(y->dtype == f32_dtype) << "y must be a float tensor";
+          TVM_FFI_ICHECK(x->shape[0] == y->shape[0]) << "x and y must have the same shape";
+          for (int i = 0; i < x->shape[0]; ++i) {
+            static_cast<float*>(y->data)[i] = static_cast<float*>(x->data)[i] + 1;
+          }
+        }
+        """,
+        cuda_source=r"""
+        __global__ void AddOneKernel(float* x, float* y, int n) {
+          int idx = blockIdx.x * blockDim.x + threadIdx.x;
+          if (idx < n) {
+            y[idx] = x[idx] + 1;
+          }
+        }
+
+        void AddOneCUDA(DLTensor* x, DLTensor* y) {
+          // implementation of a library function
+          TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
+          DLDataType f32_dtype{kDLFloat, 32, 1};
+          TVM_FFI_ICHECK(x->dtype == f32_dtype) << "x must be a float tensor";
+          TVM_FFI_ICHECK(y->ndim == 1) << "y must be a 1D tensor";
+          TVM_FFI_ICHECK(y->dtype == f32_dtype) << "y must be a float tensor";
+          TVM_FFI_ICHECK(x->shape[0] == y->shape[0]) << "x and y must have the same shape";
+
+          int64_t n = x->shape[0];
+          int64_t nthread_per_block = 256;
+          int64_t nblock = (n + nthread_per_block - 1) / nthread_per_block;
+          // Obtain the current stream from the environment
+          // it will be set to torch.cuda.current_stream() when calling the function
+          // with torch.Tensors
+          cudaStream_t stream = static_cast<cudaStream_t>(
+              TVMFFIEnvGetCurrentStream(x->device.device_type, x->device.device_id));
+          // launch the kernel
+          AddOneKernel<<<nblock, nthread_per_block, 0, stream>>>(static_cast<float*>(x->data),
+                                                                 static_cast<float*>(y->data), n);
+        }
+        """,
+        cpp_functions={'add_one_cpu': 'AddOne'},
+        cuda_functions={'add_one_cuda': 'AddOneCUDA'},
+    )
+
+    x = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float32)
+    y = torch.empty_like(x)
+    mod.add_one_cpu(x, y)
+    torch.testing.assert_close(x + 1, y)
+
+    x_cuda = x.cuda()
+    y_cuda = torch.empty_like(x_cuda)
+    mod.add_one_cuda(x_cuda, y_cuda)
+    torch.testing.assert_close(x_cuda + 1, y_cuda)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/ffi/python/tvm_ffi/cpp/__init__.py b/ffi/python/tvm_ffi/cpp/__init__.py
new file mode 100644
index 000000000000..fa1644ef7b0a
--- /dev/null
+++ b/ffi/python/tvm_ffi/cpp/__init__.py
@@ -0,0 +1 @@
+from .load_inline import load_inline
diff --git a/ffi/python/tvm_ffi/cpp/load_inline.py b/ffi/python/tvm_ffi/cpp/load_inline.py
new file mode 100644
index 000000000000..f03f6bda7471
--- /dev/null
+++ b/ffi/python/tvm_ffi/cpp/load_inline.py
@@ -0,0 +1,324 @@
+from typing import Sequence, Optional, Mapping
+import os
+import sys
+import glob
+import hashlib
+import shutil
+import subprocess
+import functools
+
+from tvm_ffi.module import Module, load_module
+from tvm_ffi.utils import FileLock
+from tvm_ffi.libinfo import find_include_path, find_dlpack_include_path
+
+IS_WINDOWS = sys.platform == "win32"
+
+def _hash_sources(
+    cpp_source: str,
+    cuda_source: str,
+    cpp_functions: Mapping[str, str],
+    cuda_functions: Mapping[str, str],
+    extra_cflags: Sequence[str],
+    extra_cuda_cflags: Sequence[str],
+    extra_ldflags: Sequence[str],
+    extra_include_paths: Sequence[str],
+) -> str:
+    """Generate a unique hash for the given sources and functions."""
+    m = hashlib.sha256()
+    m.update(cpp_source.encode("utf-8"))
+    m.update(cuda_source.encode("utf-8"))
+    for name, doc in sorted(cpp_functions.items()):
+        m.update(name.encode("utf-8"))
+        m.update(doc.encode("utf-8"))
+    for name, doc in sorted(cuda_functions.items()):
+        m.update(name.encode("utf-8"))
+        m.update(doc.encode("utf-8"))
+    for flag in extra_cflags:
+        m.update(flag.encode("utf-8"))
+    for flag in extra_cuda_cflags:
+        m.update(flag.encode("utf-8"))
+    for flag in extra_ldflags:
+        m.update(flag.encode("utf-8"))
+    for path in extra_include_paths:
+        m.update(path.encode("utf-8"))
+    return m.hexdigest()[:16]
+
+def _maybe_write(path: str, content: str) -> None:
+    """Write content to path if it does not already exist with the same content."""
+    if os.path.exists(path):
+        with open(path, "r") as f:
+            existing_content = f.read()
+        if existing_content == content:
+            return
+    with open(path, "w") as f:
+        f.write(content)
+
+
+@functools.lru_cache
+def _find_cuda_home() -> Optional[str]:
+    """Find the CUDA install path."""
+    # Guess #1
+    cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH')
+    if cuda_home is None:
+        # Guess #2
+        nvcc_path = shutil.which("nvcc")
+        if nvcc_path is not None:
+            cuda_home = os.path.dirname(os.path.dirname(nvcc_path))
+        else:
+            # Guess #3
+            if IS_WINDOWS:
+                cuda_homes = glob.glob(
+                    'C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v*.*')
+                if len(cuda_homes) == 0:
+                    cuda_home = ''
+                else:
+                    cuda_home = cuda_homes[0]
+            else:
+                cuda_home = '/usr/local/cuda'
+    if not os.path.exists(cuda_home):
+        raise RuntimeError("Could not find CUDA installation. "
+                           "Please set CUDA_HOME environment variable.")
+    return cuda_home
+
+def _generate_ninja_build(
+    name: str,
+    build_dir: str,
+    with_cuda: bool,
+    extra_cflags: Sequence[str],
+    extra_cuda_cflags: Sequence[str],
+    extra_ldflags: Sequence[str],
+    extra_include_paths: Sequence[str],
+) -> str:
+    """ Generate the content of build.ninja for building the module.
+    """
+    default_include_paths = [
+        find_include_path(),
+        find_dlpack_include_path(),
+    ]
+
+    if IS_WINDOWS:
+        default_cflags = ['/std:c++17']
+        default_cuda_cflags = ['-Xcompiler', '/std:c++17', '/O2']
+        default_ldflags = ['/DLL']
+    else:
+        default_cflags = ['-std=c++17', '-fPIC', '-O2']
+        default_cuda_cflags = ['-Xcompiler', '-fPIC', '-std=c++17', '-O2']
+        default_ldflags = ['-shared']
+
+    if with_cuda:
+        default_ldflags += ['-L{}'.format(os.path.join(_find_cuda_home(), 'lib64')), '-lcudart']
+
+    cflags = default_cflags + [flag.strip() for flag in extra_cflags]
+    cuda_cflags = default_cuda_cflags + [flag.strip() for flag in extra_cuda_cflags]
+    ldflags = default_ldflags + [flag.strip() for flag in extra_ldflags]
+    include_paths = default_include_paths + [os.path.abspath(path) for path in extra_include_paths]
+
+    # append include paths
+    for path in include_paths:
+        cflags.append('-I{}'.format(path))
+        cuda_cflags.append('-I{}'.format(path))
+
+    # flags
+    ninja = []
+    ninja.append('ninja_required_version = 1.3')
+    ninja.append('cxx = {}'.format(os.environ.get("CXX", 'cl' if IS_WINDOWS else 'c++')))
+    ninja.append('cflags = {}'.format(' '.join(cflags)))
+    if with_cuda:
+        ninja.append('nvcc = {}'.format(os.path.join(_find_cuda_home(), 'bin', 'nvcc')))
+        ninja.append('cuda_cflags = {}'.format(' '.join(cuda_cflags)))
+    ninja.append('ldflags = {}'.format(' '.join(ldflags)))
+
+    # rules
+    ninja.append('')
+    ninja.append('rule compile')
+    ninja.append('  depfile = $out.d')
+    ninja.append('  deps = gcc')
+    ninja.append('  command = $cxx -MMD -MF $out.d $cflags -c $in -o $out')
+    ninja.append('')
+
+    if with_cuda:
+        ninja.append('rule compile_cuda')
+        ninja.append('  depfile = $out.d')
+        ninja.append('  deps = gcc')
+        ninja.append('  command = $nvcc --generate-dependencies-with-compile --dependency-output $out.d $cuda_cflags -c $in -o $out')
+        ninja.append('')
+
+    ninja.append('rule link')
+    ninja.append('  command = $cxx $in $ldflags -o $out')
+    ninja.append('')
+
+    # build targets
+    ninja.append('build main.o: compile {}'.format(os.path.abspath(os.path.join(build_dir, 'main.cpp'))))
+    if with_cuda:
+        ninja.append('build cuda.o: compile_cuda {}'.format(os.path.abspath(os.path.join(build_dir, 'cuda.cu'))))
+    ninja.append('build {}.so: link main.o{}'.format(name, ' cuda.o' if with_cuda else ''))
+    ninja.append('')
+
+    # default target
+    ninja.append('default {}.so'.format(name))
+    ninja.append('')
+    return '\n'.join(ninja)
+
+
+def _build_ninja(build_dir: str) -> None:
+    """ Build the module in the given build directory using ninja. """
+    command = ['ninja', '-v']
+    num_workers = os.environ.get("MAX_JOBS", None)
+    if num_workers is not None:
+        command += ['-j', num_workers]
+    status = subprocess.run(
+        args=command,
+        cwd=build_dir,
+        capture_output=True,
+    )
+    if status.returncode != 0:
+        msg = ['ninja exited with status {}'.format(status.returncode)]
+        if status.stdout:
+            msg.append('stdout:\n{}'.format(status.stdout.decode('utf-8')))
+        if status.stderr:
+            msg.append('stderr:\n{}'.format(status.stderr.decode('utf-8')))
+
+        raise RuntimeError('\n'.join(msg))
+
+
+def _decorate_with_tvm_ffi(source: str, functions: Mapping[str, str]) -> str:
+    """ Decorate the given source code with TVM FFI export macros.
""" + sources = [ + '#include ', + '#include ', + '#include ', + '#include ', + '', + source, + ] + + for exported_name, func_name_in_source in functions.items(): + sources.append(f'TVM_FFI_DLL_EXPORT_TYPED_FUNC({exported_name}, {func_name_in_source});') + sources.append('') + + return '\n'.join(sources) + + +def load_inline( + name: str, + *, + cpp_source: str | None = None, + cuda_source: str | None = None, + cpp_functions: Mapping[str, str] | None = None, + cuda_functions: Mapping[str, str] | None = None, + extra_cflags: Sequence[str] | None = None, + extra_cuda_cflags: Sequence[str] | None = None, + extra_ldflags: Sequence[str] | None = None, + extra_include_paths: Sequence[str] | None = None, +) -> Module: + """ Compile and load a C++/CUDA tvm ffi module from inline source code. + + This function compiles the given C++ and/or CUDA source code into a shared library. Both cpp_source and cuda_source + are compiled to an object file, and then linked together into a shared library. It's possible to only provide + cpp_source or cuda_source. + + The `cpp_functions` and `cuda_functions` parameters are used to specify which functions in the source code + should be exported to the tvm ffi module. The keys of the mapping are the names of the exported functions, and the + values are the names of the functions in the source code. The exported name and the function name in the source code + must be different. The exported name must be a valid C identifier while the function name in the source code can + contain namespace qualifiers. + + Extra compiler and linker flags can be provided via the `extra_cflags`, `extra_cuda_cflags`, and `extra_ldflags` + parameters. The default flags are generally sufficient for most use cases, but you may need to provide additional + flags for your specific use case. + + The include dir of tvm ffi and dlpack are used by default for linker to find the headers. Thus, you can include + any header from tvm ffi and dlpack in your source code. You can also provide additional include paths via the + `extra_include_paths` parameter and include custom headers in your source code. + + The compiled shared library is cached in a cache directory to avoid recompilation. The cache directory can be + specified via the `TVM_FFI_CACHE_DIR` environment variable. If not specified, the default cache directory is + `~/.cache/tvm-ffi`. + + Parameters + ---------- + name: str + The name of the tvm ffi module. + cpp_source: str, optional + The C++ source code. + cuda_source: str, optional + The CUDA source code. + cpp_functions: Mapping[str, str], optional + The mapping from the exported function name to the function name in the C++ source code. + cuda_functions: Mapping[str, str], optional + The mapping from the exported function name to the function name in the CUDA source code. + extra_cflags: Sequence[str], optional + The extra compiler flags for C++ compilation. + The default flags are: + - On Linux/macOS: ['-std=c++17', '-fPIC', '-O2'] + - On Windows: ['/std:c++17'] + extra_cuda_cflags: + The extra compiler flags for CUDA compilation. + The default flags are: + - On Linux/macOS: ['-Xcompiler', '-fPIC', '-std=c++17', '-O2'] + - On Windows: ['-Xcompiler', '/std:c++17', '/O2'] + extra_ldflags: Sequence[str], optional + The extra linker flags. + The default flags are: + - On Linux/macOS: ['-shared'] + - On Windows: ['/DLL'] + extra_include_paths: Sequence[str], optional + The extra include paths. 
+
+    Parameters
+    ----------
+    name: str
+        The name of the tvm ffi module.
+    cpp_source: str, optional
+        The C++ source code.
+    cuda_source: str, optional
+        The CUDA source code.
+    cpp_functions: Mapping[str, str], optional
+        The mapping from the exported function name to the function name in the C++ source code.
+    cuda_functions: Mapping[str, str], optional
+        The mapping from the exported function name to the function name in the CUDA source code.
+    extra_cflags: Sequence[str], optional
+        The extra compiler flags for C++ compilation.
+        The default flags are:
+        - On Linux/macOS: ['-std=c++17', '-fPIC', '-O2']
+        - On Windows: ['/std:c++17']
+    extra_cuda_cflags: Sequence[str], optional
+        The extra compiler flags for CUDA compilation.
+        The default flags are:
+        - On Linux/macOS: ['-Xcompiler', '-fPIC', '-std=c++17', '-O2']
+        - On Windows: ['-Xcompiler', '/std:c++17', '/O2']
+    extra_ldflags: Sequence[str], optional
+        The extra linker flags.
+        The default flags are:
+        - On Linux/macOS: ['-shared']
+        - On Windows: ['/DLL']
+    extra_include_paths: Sequence[str], optional
+        The extra include paths.
+        The default include paths are:
+        - The include path of tvm ffi
+        - The include path of dlpack
+
+    Returns
+    -------
+    mod: Module
+        The loaded tvm ffi module.
+    """
+    if cpp_source is None:
+        cpp_source = ''
+    if cuda_source is None:
+        cuda_source = ''
+    if cpp_functions is None:
+        cpp_functions = {}
+    if cuda_functions is None:
+        cuda_functions = {}
+    extra_ldflags = extra_ldflags or []
+    extra_cflags = extra_cflags or []
+    extra_cuda_cflags = extra_cuda_cflags or []
+    extra_include_paths = extra_include_paths or []
+
+    # whether we have cuda source in this module
+    with_cuda = len(cuda_source.strip()) > 0
+
+    # add function registration code to sources
+    cpp_source = _decorate_with_tvm_ffi(cpp_source, cpp_functions)
+    cuda_source = _decorate_with_tvm_ffi(cuda_source, cuda_functions)
+
+    # determine the cache dir for the built module
+    cache_dir = os.path.join(
+        os.environ.get("TVM_FFI_CACHE_DIR", os.path.expanduser("~/.cache/tvm-ffi"))
+    )
+    source_hash: str = _hash_sources(
+        cpp_source, cuda_source, cpp_functions, cuda_functions, extra_cflags, extra_cuda_cflags, extra_ldflags, extra_include_paths
+    )
+    build_dir: str = os.path.join(cache_dir, '{}_{}'.format(name, source_hash))
+    os.makedirs(build_dir, exist_ok=True)
+
+    # generate build.ninja
+    ninja_source = _generate_ninja_build(
+        name=name,
+        build_dir=build_dir,
+        with_cuda=with_cuda,
+        extra_cflags=extra_cflags,
+        extra_cuda_cflags=extra_cuda_cflags,
+        extra_ldflags=extra_ldflags,
+        extra_include_paths=extra_include_paths
+    )
+
+    with FileLock(os.path.join(build_dir, "lock")):
+        # write source files and build.ninja if they do not already exist
+        _maybe_write(os.path.join(build_dir, "main.cpp"), cpp_source)
+        if with_cuda:
+            _maybe_write(os.path.join(build_dir, "cuda.cu"), cuda_source)
+        _maybe_write(os.path.join(build_dir, "build.ninja"), ninja_source)
+
+        # build the module
+        _build_ninja(build_dir)
+
+    return load_module(os.path.join(build_dir, "{}.so".format(name)))
diff --git a/ffi/python/tvm_ffi/utils/__init__.py b/ffi/python/tvm_ffi/utils/__init__.py
new file mode 100644
index 000000000000..119952765f7a
--- /dev/null
+++ b/ffi/python/tvm_ffi/utils/__init__.py
@@ -0,0 +1 @@
+from .lockfile import FileLock
diff --git a/ffi/python/tvm_ffi/utils/lockfile.py b/ffi/python/tvm_ffi/utils/lockfile.py
new file mode 100644
index 000000000000..8fb015a53c56
--- /dev/null
+++ b/ffi/python/tvm_ffi/utils/lockfile.py
@@ -0,0 +1,92 @@
+import os
+import sys
+import time
+
+# Platform-specific imports for file locking
+if sys.platform == "win32":
+    import msvcrt
+else:
+    import fcntl
+
+
+class FileLock:
+    """
+    A cross-platform file locking mechanism using Python's standard library.
+    This class implements an advisory lock, which must be respected by all
+    cooperating processes.
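+
+    Example (a sketch; the lock file path is arbitrary):
+
+        with FileLock('/tmp/build.lock'):
+            ...  # only one cooperating process runs this block at a time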
+ """ + try: + if sys.platform == "win32": + self._file_descriptor = os.open(self.lock_file_path, os.O_RDWR | os.O_CREAT | os.O_BINARY) + msvcrt.locking(self._file_descriptor, msvcrt.LK_NBLCK, 1) + else: # Unix-like systems + self._file_descriptor = os.open(self.lock_file_path, os.O_WRONLY | os.O_CREAT) + fcntl.flock(self._file_descriptor, fcntl.LOCK_EX | fcntl.LOCK_NB) + return True + except (IOError, BlockingIOError): + if self._file_descriptor is not None: + os.close(self._file_descriptor) + self._file_descriptor = None + return False + except Exception as e: + if self._file_descriptor is not None: + os.close(self._file_descriptor) + self._file_descriptor = None + raise RuntimeError(f"An unexpected error occurred: {e}") + + def blocking_acquire(self, timeout=None, poll_interval=0.1): + """ + Waits until an exclusive lock can be acquired, with an optional timeout. + + Args: + timeout (float): The maximum time to wait for the lock in seconds. + A value of None means wait indefinitely. + poll_interval (float): The time to wait between lock attempts in seconds. + """ + start_time = time.time() + while True: + if self.acquire(): + return True + + # Check for timeout + if timeout is not None and (time.time() - start_time) > timeout: + raise TimeoutError(f"Failed to acquire lock on '{self.lock_file_path}' after {timeout} seconds.") + + time.sleep(poll_interval) + + def release(self): + """ + Releases the lock and closes the file descriptor. + """ + if self._file_descriptor is not None: + if sys.platform == "win32": + msvcrt.locking(self._file_descriptor, msvcrt.LK_UNLCK, 1) + else: + fcntl.flock(self._file_descriptor, fcntl.LOCK_UN) + os.close(self._file_descriptor) + self._file_descriptor = None diff --git a/ffi/tests/python/test_load_inline.py b/ffi/tests/python/test_load_inline.py new file mode 100644 index 000000000000..c9f8103421ea --- /dev/null +++ b/ffi/tests/python/test_load_inline.py @@ -0,0 +1,65 @@ +import torch +import tvm_ffi.cpp +from tvm_ffi.module import Module + + +def test_load_inline(): + mod: Module = tvm_ffi.cpp.load_inline( + name='hello', + cpp_source=r""" + void AddOne(DLTensor* x, DLTensor* y) { + // implementation of a library function + TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor"; + DLDataType f32_dtype{kDLFloat, 32, 1}; + TVM_FFI_ICHECK(x->dtype == f32_dtype) << "x must be a float tensor"; + TVM_FFI_ICHECK(y->ndim == 1) << "y must be a 1D tensor"; + TVM_FFI_ICHECK(y->dtype == f32_dtype) << "y must be a float tensor"; + TVM_FFI_ICHECK(x->shape[0] == y->shape[0]) << "x and y must have the same shape"; + for (int i = 0; i < x->shape[0]; ++i) { + static_cast(y->data)[i] = static_cast(x->data)[i] + 1; + } + } + """, + cuda_source=r""" + __global__ void AddOneKernel(float* x, float* y, int n) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + y[idx] = x[idx] + 1; + } + } + + void AddOneCUDA(DLTensor* x, DLTensor* y) { + // implementation of a library function + TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor"; + DLDataType f32_dtype{kDLFloat, 32, 1}; + TVM_FFI_ICHECK(x->dtype == f32_dtype) << "x must be a float tensor"; + TVM_FFI_ICHECK(y->ndim == 1) << "y must be a 1D tensor"; + TVM_FFI_ICHECK(y->dtype == f32_dtype) << "y must be a float tensor"; + TVM_FFI_ICHECK(x->shape[0] == y->shape[0]) << "x and y must have the same shape"; + + int64_t n = x->shape[0]; + int64_t nthread_per_block = 256; + int64_t nblock = (n + nthread_per_block - 1) / nthread_per_block; + // Obtain the current stream from the environment + // it will be set to 
+          // it will be set to torch.cuda.current_stream() when calling the function
+          // with torch.Tensors
+          cudaStream_t stream = static_cast<cudaStream_t>(
+              TVMFFIEnvGetCurrentStream(x->device.device_type, x->device.device_id));
+          // launch the kernel
+          AddOneKernel<<<nblock, nthread_per_block, 0, stream>>>(static_cast<float*>(x->data),
+                                                                 static_cast<float*>(y->data), n);
+        }
+        """,
+        cpp_functions={'add_one_cpu': 'AddOne'},
+        cuda_functions={'add_one_cuda': 'AddOneCUDA'},
+    )
+
+    x = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float32)
+    y = torch.empty_like(x)
+    mod.add_one_cpu(x, y)
+    torch.testing.assert_close(x + 1, y)
+
+    x_cuda = x.cuda()
+    y_cuda = torch.empty_like(x_cuda)
+    mod.add_one_cuda(x_cuda, y_cuda)
+    torch.testing.assert_close(x_cuda + 1, y_cuda)

From d27030256507c20089d1e53eeaecd981aef5cd01 Mon Sep 17 00:00:00 2001
From: Yaoyao Ding
Date: Fri, 5 Sep 2025 00:33:14 -0400
Subject: [PATCH 2/5] format & lint

---
 ffi/examples/inline_module/main.py    |  23 +++-
 ffi/python/tvm_ffi/cpp/__init__.py    |  17 +++
 ffi/python/tvm_ffi/cpp/load_inline.py | 181 +++++++++++++++-----------
 ffi/python/tvm_ffi/utils/__init__.py  |  17 +++
 ffi/python/tvm_ffi/utils/lockfile.py  |  25 +++-
 ffi/tests/python/test_load_inline.py  |  23 +++-
 6 files changed, 202 insertions(+), 84 deletions(-)

diff --git a/ffi/examples/inline_module/main.py b/ffi/examples/inline_module/main.py
index b2588261373e..574d55c67824 100644
--- a/ffi/examples/inline_module/main.py
+++ b/ffi/examples/inline_module/main.py
@@ -1,3 +1,20 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
 import torch
 import tvm_ffi.cpp
 from tvm_ffi.module import Module
@@ -5,7 +22,7 @@
 
 def main():
     mod: Module = tvm_ffi.cpp.load_inline(
-        name='hello',
+        name="hello",
         cpp_source=r"""
         void AddOne(DLTensor* x, DLTensor* y) {
           // implementation of a library function
@@ -50,8 +67,8 @@ def main():
                                                                  static_cast<float*>(y->data), n);
         }
         """,
-        cpp_functions={'add_one_cpu': 'AddOne'},
-        cuda_functions={'add_one_cuda': 'AddOneCUDA'},
+        cpp_functions={"add_one_cpu": "AddOne"},
+        cuda_functions={"add_one_cuda": "AddOneCUDA"},
     )
 
     x = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float32)
diff --git a/ffi/python/tvm_ffi/cpp/__init__.py b/ffi/python/tvm_ffi/cpp/__init__.py
index fa1644ef7b0a..632698f4431a 100644
--- a/ffi/python/tvm_ffi/cpp/__init__.py
+++ b/ffi/python/tvm_ffi/cpp/__init__.py
@@ -1 +1,18 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
 from .load_inline import load_inline
diff --git a/ffi/python/tvm_ffi/cpp/load_inline.py b/ffi/python/tvm_ffi/cpp/load_inline.py
index f03f6bda7471..1d83894f652c 100644
--- a/ffi/python/tvm_ffi/cpp/load_inline.py
+++ b/ffi/python/tvm_ffi/cpp/load_inline.py
@@ -1,3 +1,20 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
 from typing import Sequence, Optional, Mapping
 import os
 import sys
@@ -13,6 +30,7 @@
 
 IS_WINDOWS = sys.platform == "win32"
 
+
 def _hash_sources(
     cpp_source: str,
     cuda_source: str,
@@ -43,6 +61,7 @@ def _hash_sources(
         m.update(path.encode("utf-8"))
     return m.hexdigest()[:16]
 
+
 def _maybe_write(path: str, content: str) -> None:
     """Write content to path if it does not already exist with the same content."""
     if os.path.exists(path):
@@ -58,7 +77,7 @@ def _maybe_write(path: str, content: str) -> None:
 def _find_cuda_home() -> Optional[str]:
     """Find the CUDA install path."""
     # Guess #1
-    cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH')
+    cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH")
     if cuda_home is None:
         # Guess #2
         nvcc_path = shutil.which("nvcc")
@@ -67,19 +86,21 @@ def _find_cuda_home() -> Optional[str]:
         else:
             # Guess #3
             if IS_WINDOWS:
-                cuda_homes = glob.glob(
-                    'C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v*.*')
+                cuda_homes = glob.glob("C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v*.*")
                 if len(cuda_homes) == 0:
-                    cuda_home = ''
+                    cuda_home = ""
                 else:
                     cuda_home = cuda_homes[0]
             else:
-                cuda_home = '/usr/local/cuda'
+                cuda_home = "/usr/local/cuda"
     if not os.path.exists(cuda_home):
-        raise RuntimeError("Could not find CUDA installation. "
-                           "Please set CUDA_HOME environment variable.")
+        raise RuntimeError(
+            "Could not find CUDA installation. " "Please set CUDA_HOME environment variable."
+        )
     return cuda_home
 
+
 def _generate_ninja_build(
     name: str,
     build_dir: str,
@@ -89,23 +110,20 @@ def _generate_ninja_build(
     extra_ldflags: Sequence[str],
     extra_include_paths: Sequence[str],
 ) -> str:
-    """ Generate the content of build.ninja for building the module.
""" - default_include_paths = [ - find_include_path(), - find_dlpack_include_path(), - ] + """Generate the content of build.ninja for building the module.""" + default_include_paths = [find_include_path(), find_dlpack_include_path()] if IS_WINDOWS: - default_cflags = ['/std:c++17'] - default_cuda_cflags = ['-Xcompiler', '/std:c++17', '/O2'] - default_ldflags = ['/DLL'] + default_cflags = ["/std:c++17"] + default_cuda_cflags = ["-Xcompiler", "/std:c++17", "/O2"] + default_ldflags = ["/DLL"] else: - default_cflags = ['-std=c++17', '-fPIC', '-O2'] - default_cuda_cflags = ['-Xcompiler', '-fPIC', '-std=c++17', '-O2'] - default_ldflags = ['-shared'] + default_cflags = ["-std=c++17", "-fPIC", "-O2"] + default_cuda_cflags = ["-Xcompiler", "-fPIC", "-std=c++17", "-O2"] + default_ldflags = ["-shared"] if with_cuda: - default_ldflags += ['-L{}'.format(os.path.join(_find_cuda_home(), 'lib64')), '-lcudart'] + default_ldflags += ["-L{}".format(os.path.join(_find_cuda_home(), "lib64")), "-lcudart"] cflags = default_cflags + [flag.strip() for flag in extra_cflags] cuda_cflags = default_cuda_cflags + [flag.strip() for flag in extra_cuda_cflags] @@ -114,88 +132,92 @@ def _generate_ninja_build( # append include paths for path in include_paths: - cflags.append('-I{}'.format(path)) - cuda_cflags.append('-I{}'.format(path)) + cflags.append("-I{}".format(path)) + cuda_cflags.append("-I{}".format(path)) # flags ninja = [] - ninja.append('ninja_required_version = 1.3') - ninja.append('cxx = {}'.format(os.environ.get("CXX", 'cl' if IS_WINDOWS else 'c++'))) - ninja.append('cflags = {}'.format(' '.join(cflags))) + ninja.append("ninja_required_version = 1.3") + ninja.append("cxx = {}".format(os.environ.get("CXX", "cl" if IS_WINDOWS else "c++"))) + ninja.append("cflags = {}".format(" ".join(cflags))) if with_cuda: - ninja.append('nvcc = {}'.format(os.path.join(_find_cuda_home(), 'bin', 'nvcc'))) - ninja.append('cuda_cflags = {}'.format(' '.join(cuda_cflags))) - ninja.append('ldflags = {}'.format(' '.join(ldflags))) + ninja.append("nvcc = {}".format(os.path.join(_find_cuda_home(), "bin", "nvcc"))) + ninja.append("cuda_cflags = {}".format(" ".join(cuda_cflags))) + ninja.append("ldflags = {}".format(" ".join(ldflags))) # rules - ninja.append('') - ninja.append('rule compile') - ninja.append(' depfile = $out.d') - ninja.append(' deps = gcc') - ninja.append(' command = $cxx -MMD -MF $out.d $cflags -c $in -o $out') - ninja.append('') + ninja.append("") + ninja.append("rule compile") + ninja.append(" depfile = $out.d") + ninja.append(" deps = gcc") + ninja.append(" command = $cxx -MMD -MF $out.d $cflags -c $in -o $out") + ninja.append("") if with_cuda: - ninja.append('rule compile_cuda') - ninja.append(' depfile = $out.d') - ninja.append(' deps = gcc') - ninja.append(' command = $nvcc --generate-dependencies-with-compile --dependency-output $out.d $cuda_cflags -c $in -o $out') - ninja.append('') - - ninja.append('rule link') - ninja.append(' command = $cxx $in $ldflags -o $out') - ninja.append('') + ninja.append("rule compile_cuda") + ninja.append(" depfile = $out.d") + ninja.append(" deps = gcc") + ninja.append( + " command = $nvcc --generate-dependencies-with-compile --dependency-output $out.d $cuda_cflags -c $in -o $out" + ) + ninja.append("") + + ninja.append("rule link") + ninja.append(" command = $cxx $in $ldflags -o $out") + ninja.append("") # build targets - ninja.append('build main.o: compile {}'.format(os.path.abspath(os.path.join(build_dir, 'main.cpp')))) + ninja.append( + "build main.o: compile 
{}".format(os.path.abspath(os.path.join(build_dir, "main.cpp"))) + ) if with_cuda: - ninja.append('build cuda.o: compile_cuda {}'.format(os.path.abspath(os.path.join(build_dir, 'cuda.cu')))) - ninja.append('build {}.so: link main.o{}'.format(name, ' cuda.o' if with_cuda else '')) - ninja.append('') + ninja.append( + "build cuda.o: compile_cuda {}".format( + os.path.abspath(os.path.join(build_dir, "cuda.cu")) + ) + ) + ninja.append("build {}.so: link main.o{}".format(name, " cuda.o" if with_cuda else "")) + ninja.append("") # default target - ninja.append('default {}.so'.format(name)) - ninja.append('') - return '\n'.join(ninja) + ninja.append("default {}.so".format(name)) + ninja.append("") + return "\n".join(ninja) def _build_ninja(build_dir: str) -> None: - """ Build the module in the given build directory using ninja. """ - command = ['ninja', '-v'] + """Build the module in the given build directory using ninja.""" + command = ["ninja", "-v"] num_workers = os.environ.get("MAX_JOBS", None) if num_workers is not None: - command += ['-j', num_workers] - status = subprocess.run( - args=command, - cwd=build_dir, - capture_output=True, - ) + command += ["-j", num_workers] + status = subprocess.run(args=command, cwd=build_dir, capture_output=True) if status.returncode != 0: - msg = ['ninja exited with status {}'.format(status.returncode)] + msg = ["ninja exited with status {}".format(status.returncode)] if status.stdout: - msg.append('stdout:\n{}'.format(status.stdout.decode('utf-8'))) + msg.append("stdout:\n{}".format(status.stdout.decode("utf-8"))) if status.stderr: - msg.append('stderr:\n{}'.format(status.stderr.decode('utf-8'))) + msg.append("stderr:\n{}".format(status.stderr.decode("utf-8"))) - raise RuntimeError('\n'.join(msg)) + raise RuntimeError("\n".join(msg)) def _decorate_with_tvm_ffi(source: str, functions: Mapping[str, str]) -> str: - """ Decorate the given source code with TVM FFI export macros. """ + """Decorate the given source code with TVM FFI export macros.""" sources = [ - '#include ', - '#include ', - '#include ', - '#include ', - '', + "#include ", + "#include ", + "#include ", + "#include ", + "", source, ] for exported_name, func_name_in_source in functions.items(): - sources.append(f'TVM_FFI_DLL_EXPORT_TYPED_FUNC({exported_name}, {func_name_in_source});') - sources.append('') + sources.append(f"TVM_FFI_DLL_EXPORT_TYPED_FUNC({exported_name}, {func_name_in_source});") + sources.append("") - return '\n'.join(sources) + return "\n".join(sources) def load_inline( @@ -210,7 +232,7 @@ def load_inline( extra_ldflags: Sequence[str] | None = None, extra_include_paths: Sequence[str] | None = None, ) -> Module: - """ Compile and load a C++/CUDA tvm ffi module from inline source code. + """Compile and load a C++/CUDA tvm ffi module from inline source code. This function compiles the given C++ and/or CUDA source code into a shared library. Both cpp_source and cuda_source are compiled to an object file, and then linked together into a shared library. It's possible to only provide @@ -271,9 +293,9 @@ def load_inline( The loaded tvm ffi module. 
""" if cpp_source is None: - cpp_source = '' + cpp_source = "" if cuda_source is None: - cuda_source = '' + cuda_source = "" if cpp_functions is None: cpp_functions = {} if cuda_functions is None: @@ -295,9 +317,16 @@ def load_inline( os.environ.get("TVM_FFI_CACHE_DIR", os.path.expanduser("~/.cache/tvm-ffi")) ) source_hash: str = _hash_sources( - cpp_source, cuda_source, cpp_functions, cuda_functions, extra_cflags, extra_cuda_cflags, extra_ldflags, extra_include_paths + cpp_source, + cuda_source, + cpp_functions, + cuda_functions, + extra_cflags, + extra_cuda_cflags, + extra_ldflags, + extra_include_paths, ) - build_dir: str = os.path.join(cache_dir, '{}_{}'.format(name, source_hash)) + build_dir: str = os.path.join(cache_dir, "{}_{}".format(name, source_hash)) os.makedirs(build_dir, exist_ok=True) # generate build.ninja @@ -308,7 +337,7 @@ def load_inline( extra_cflags=extra_cflags, extra_cuda_cflags=extra_cuda_cflags, extra_ldflags=extra_ldflags, - extra_include_paths=extra_include_paths + extra_include_paths=extra_include_paths, ) with FileLock(os.path.join(build_dir, "lock")): diff --git a/ffi/python/tvm_ffi/utils/__init__.py b/ffi/python/tvm_ffi/utils/__init__.py index 119952765f7a..543bd0f84100 100644 --- a/ffi/python/tvm_ffi/utils/__init__.py +++ b/ffi/python/tvm_ffi/utils/__init__.py @@ -1 +1,18 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + from .lockfile import FileLock diff --git a/ffi/python/tvm_ffi/utils/lockfile.py b/ffi/python/tvm_ffi/utils/lockfile.py index 8fb015a53c56..3b3197e2d8e0 100644 --- a/ffi/python/tvm_ffi/utils/lockfile.py +++ b/ffi/python/tvm_ffi/utils/lockfile.py @@ -1,3 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+
 import os
 import sys
 import time
@@ -42,7 +59,9 @@ def acquire(self):
         """
         try:
             if sys.platform == "win32":
-                self._file_descriptor = os.open(self.lock_file_path, os.O_RDWR | os.O_CREAT | os.O_BINARY)
+                self._file_descriptor = os.open(
+                    self.lock_file_path, os.O_RDWR | os.O_CREAT | os.O_BINARY
+                )
                 msvcrt.locking(self._file_descriptor, msvcrt.LK_NBLCK, 1)
             else:  # Unix-like systems
                 self._file_descriptor = os.open(self.lock_file_path, os.O_WRONLY | os.O_CREAT)
@@ -75,7 +94,9 @@ def blocking_acquire(self, timeout=None, poll_interval=0.1):
 
             # Check for timeout
             if timeout is not None and (time.time() - start_time) > timeout:
-                raise TimeoutError(f"Failed to acquire lock on '{self.lock_file_path}' after {timeout} seconds.")
+                raise TimeoutError(
+                    f"Failed to acquire lock on '{self.lock_file_path}' after {timeout} seconds."
+                )
 
             time.sleep(poll_interval)
 
diff --git a/ffi/tests/python/test_load_inline.py b/ffi/tests/python/test_load_inline.py
index c9f8103421ea..6298d2b9c219 100644
--- a/ffi/tests/python/test_load_inline.py
+++ b/ffi/tests/python/test_load_inline.py
@@ -1,3 +1,20 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
 import torch
 import tvm_ffi.cpp
 from tvm_ffi.module import Module
@@ -5,7 +22,7 @@
 
 def test_load_inline():
     mod: Module = tvm_ffi.cpp.load_inline(
-        name='hello',
+        name="hello",
         cpp_source=r"""
         void AddOne(DLTensor* x, DLTensor* y) {
           // implementation of a library function
@@ -50,8 +67,8 @@ def test_load_inline():
                                                                  static_cast<float*>(y->data), n);
         }
         """,
-        cpp_functions={'add_one_cpu': 'AddOne'},
-        cuda_functions={'add_one_cuda': 'AddOneCUDA'},
+        cpp_functions={"add_one_cpu": "AddOne"},
+        cuda_functions={"add_one_cuda": "AddOneCUDA"},
     )
 
     x = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float32)

From 1d7fc2aa55401fbfbe251b35d119193b91ed10eb Mon Sep 17 00:00:00 2001
From: Yaoyao Ding
Date: Fri, 5 Sep 2025 13:27:45 -0400
Subject: [PATCH 3/5] resolve cuda targets

---
 ffi/python/tvm_ffi/cpp/load_inline.py | 29 ++++++
 ffi/tests/python/test_load_inline.py  | 95 ++++++++++++++++++++++----
 2 files changed, 115 insertions(+), 9 deletions(-)

diff --git a/ffi/python/tvm_ffi/cpp/load_inline.py b/ffi/python/tvm_ffi/cpp/load_inline.py
index 1d83894f652c..b402d2b63c09 100644
--- a/ffi/python/tvm_ffi/cpp/load_inline.py
+++ b/ffi/python/tvm_ffi/cpp/load_inline.py
@@ -101,6 +101,33 @@ def _find_cuda_home() -> Optional[str]:
     return cuda_home
 
 
+def _get_cuda_target() -> str:
+    """Get the CUDA target architecture flag."""
+    if 'TVM_FFI_CUDA_ARCH_LIST' in os.environ:
+        arch_list = os.environ['TVM_FFI_CUDA_ARCH_LIST'].split()  # e.g., "8.9 9.0a"
+        flags = []
+        for arch in arch_list:
+            if len(arch.split('.')) != 2:
+                raise ValueError(f"Invalid CUDA architecture: {arch}")
+            major, minor = arch.split('.')
+            flags.append(f"-gencode=arch=compute_{major}{minor},code=sm_{major}{minor}")
+        return " ".join(flags)
+    else:
+        # query the compute capability of the current GPU via nvidia-smi
+        try:
+            status = subprocess.run(
+                args=["nvidia-smi", "--query-gpu=compute_cap", "--format=csv,noheader"],
+                capture_output=True,
+                check=True,
+            )
+            compute_cap = status.stdout.decode("utf-8").strip().split("\n")[0]
+            major, minor = compute_cap.split(".")
+            return f"-gencode=arch=compute_{major}{minor},code=sm_{major}{minor}"
+        except Exception:
+            # fallback to a reasonable default
+            return "-gencode=arch=compute_70,code=sm_70"
+
+
 def _generate_ninja_build(
     name: str,
     build_dir: str,
@@ -123,6 +150,8 @@ def _generate_ninja_build(
         default_ldflags = ["-shared"]
 
     if with_cuda:
+        # determine the compute capability of the current GPU
+        default_cuda_cflags += [_get_cuda_target()]
         default_ldflags += ["-L{}".format(os.path.join(_find_cuda_home(), "lib64")), "-lcudart"]
 
     cflags = default_cflags + [flag.strip() for flag in extra_cflags]
diff --git a/ffi/tests/python/test_load_inline.py b/ffi/tests/python/test_load_inline.py
index 6298d2b9c219..97062d164df1 100644
--- a/ffi/tests/python/test_load_inline.py
+++ b/ffi/tests/python/test_load_inline.py
@@ -15,12 +15,88 @@
 # specific language governing permissions and limitations
 # under the License.
 
-import torch
+import numpy
+
+try:
+    import torch
+except ImportError:
+    torch = None
+
 import tvm_ffi.cpp
 from tvm_ffi.module import Module
 
 
-def test_load_inline():
+def test_load_inline_cpp():
+    mod: Module = tvm_ffi.cpp.load_inline(
+        name="hello",
+        cpp_source=r"""
+        void AddOne(DLTensor* x, DLTensor* y) {
+          // implementation of a library function
+          TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
+          DLDataType f32_dtype{kDLFloat, 32, 1};
+          TVM_FFI_ICHECK(x->dtype == f32_dtype) << "x must be a float tensor";
+          TVM_FFI_ICHECK(y->ndim == 1) << "y must be a 1D tensor";
+          TVM_FFI_ICHECK(y->dtype == f32_dtype) << "y must be a float tensor";
+          TVM_FFI_ICHECK(x->shape[0] == y->shape[0]) << "x and y must have the same shape";
+          for (int i = 0; i < x->shape[0]; ++i) {
+            static_cast<float*>(y->data)[i] = static_cast<float*>(x->data)[i] + 1;
+          }
+        }
+        """,
+        cpp_functions={"add_one_cpu": "AddOne"},
+    )
+
+    x = numpy.array([1, 2, 3, 4, 5], dtype=numpy.float32)
+    y = numpy.empty_like(x)
+    mod.add_one_cpu(x, y)
+    numpy.testing.assert_equal(x + 1, y)
+
+
+
+def test_load_inline_cuda():
+    mod: Module = tvm_ffi.cpp.load_inline(
+        name="hello",
+        cuda_source=r"""
+        __global__ void AddOneKernel(float* x, float* y, int n) {
+          int idx = blockIdx.x * blockDim.x + threadIdx.x;
+          if (idx < n) {
+            y[idx] = x[idx] + 1;
+          }
+        }
+
+        void AddOneCUDA(DLTensor* x, DLTensor* y) {
+          // implementation of a library function
+          TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
+          DLDataType f32_dtype{kDLFloat, 32, 1};
+          TVM_FFI_ICHECK(x->dtype == f32_dtype) << "x must be a float tensor";
+          TVM_FFI_ICHECK(y->ndim == 1) << "y must be a 1D tensor";
+          TVM_FFI_ICHECK(y->dtype == f32_dtype) << "y must be a float tensor";
+          TVM_FFI_ICHECK(x->shape[0] == y->shape[0]) << "x and y must have the same shape";
+
+          int64_t n = x->shape[0];
+          int64_t nthread_per_block = 256;
+          int64_t nblock = (n + nthread_per_block - 1) / nthread_per_block;
+          // Obtain the current stream from the environment
+          // it will be set to torch.cuda.current_stream() when calling the function
+          // with torch.Tensors
+          cudaStream_t stream = static_cast<cudaStream_t>(
+              TVMFFIEnvGetCurrentStream(x->device.device_type, x->device.device_id));
+          // launch the kernel
+          AddOneKernel<<<nblock, nthread_per_block, 0, stream>>>(static_cast<float*>(x->data),
+                                                                 static_cast<float*>(y->data), n);
+        }
+        """,
+        cuda_functions={"add_one_cuda": "AddOneCUDA"},
+    )
+
+    if torch is not None:
+        x_cuda = torch.asarray([1, 2, 3, 4, 5], dtype=torch.float32, device="cuda")
+        y_cuda = torch.empty_like(x_cuda)
+        mod.add_one_cuda(x_cuda, y_cuda)
+        torch.testing.assert_close(x_cuda + 1, y_cuda)
+
+
+def test_load_inline_both():
+    mod: Module = tvm_ffi.cpp.load_inline(
         name="hello",
         cpp_source=r"""
@@ -71,12 +147,13 @@ def test_load_inline():
         cuda_functions={"add_one_cuda": "AddOneCUDA"},
     )
 
-    x = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float32)
-    y = torch.empty_like(x)
+    x = numpy.array([1, 2, 3, 4, 5], dtype=numpy.float32)
+    y = numpy.empty_like(x)
     mod.add_one_cpu(x, y)
-    torch.testing.assert_close(x + 1, y)
+    numpy.testing.assert_equal(x + 1, y)
 
-    x_cuda = x.cuda()
-    y_cuda = torch.empty_like(x_cuda)
-    mod.add_one_cuda(x_cuda, y_cuda)
-    torch.testing.assert_close(x_cuda + 1, y_cuda)
+    if torch is not None:
+        x_cuda = torch.asarray([1, 2, 3, 4, 5], dtype=torch.float32, device="cuda")
+        y_cuda = torch.empty_like(x_cuda)
+        mod.add_one_cuda(x_cuda, y_cuda)
+        torch.testing.assert_close(x_cuda + 1, y_cuda)

From ebf401d4bb223b8b218a332db650fb1162ea5746 Mon Sep 17 00:00:00 2001
From: Yaoyao Ding
Date: Fri, 5 Sep 2025 13:36:47 -0400
Subject: [PATCH 4/5] skip tests that require cuda

---
 ffi/tests/python/test_load_inline.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/ffi/tests/python/test_load_inline.py b/ffi/tests/python/test_load_inline.py
index 97062d164df1..bb14ae9792c2 100644
--- a/ffi/tests/python/test_load_inline.py
+++ b/ffi/tests/python/test_load_inline.py
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 
+import pytest
 import numpy
 
 try:
@@ -52,7 +53,7 @@ def test_load_inline_cpp():
     numpy.testing.assert_equal(x + 1, y)
 
 
-
+@pytest.mark.skip(reason="Requires CUDA")
 def test_load_inline_cuda():
     mod: Module = tvm_ffi.cpp.load_inline(
         name="hello",
@@ -96,6 +97,7 @@ def test_load_inline_cuda():
         torch.testing.assert_close(x_cuda + 1, y_cuda)
 
 
+@pytest.mark.skip(reason="Requires CUDA")
 def test_load_inline_both():
     mod: Module = tvm_ffi.cpp.load_inline(
         name="hello",

From 14529ef5ea0426370a56fdfc4d09a16b06967c06 Mon Sep 17 00:00:00 2001
From: Yaoyao Ding
Date: Fri, 5 Sep 2025 13:39:46 -0400
Subject: [PATCH 5/5] format & lint

---
 ffi/python/tvm_ffi/convert.py         | 3 +--
 ffi/python/tvm_ffi/cpp/load_inline.py | 8 ++++----
 2 files changed, 5 insertions(+), 6 deletions(-)

diff --git a/ffi/python/tvm_ffi/convert.py b/ffi/python/tvm_ffi/convert.py
index 5b25ddae259b..5b28fa354537 100644
--- a/ffi/python/tvm_ffi/convert.py
+++ b/ffi/python/tvm_ffi/convert.py
@@ -56,8 +56,7 @@ def convert(value: Any) -> Any:
         return None
     elif hasattr(value, "__dlpack__"):
         return core.from_dlpack(
-            value,
-            required_alignment=core.__dlpack_auto_import_required_alignment__,
+            value, required_alignment=core.__dlpack_auto_import_required_alignment__
         )
     elif isinstance(value, Exception):
         return core._convert_to_ffi_error(value)
diff --git a/ffi/python/tvm_ffi/cpp/load_inline.py b/ffi/python/tvm_ffi/cpp/load_inline.py
index b402d2b63c09..a9ec1c39977d 100644
--- a/ffi/python/tvm_ffi/cpp/load_inline.py
+++ b/ffi/python/tvm_ffi/cpp/load_inline.py
@@ -103,13 +103,13 @@ def _find_cuda_home() -> Optional[str]:
 
 def _get_cuda_target() -> str:
     """Get the CUDA target architecture flag."""
-    if 'TVM_FFI_CUDA_ARCH_LIST' in os.environ:
-        arch_list = os.environ['TVM_FFI_CUDA_ARCH_LIST'].split()  # e.g., "8.9 9.0a"
+    if "TVM_FFI_CUDA_ARCH_LIST" in os.environ:
+        arch_list = os.environ["TVM_FFI_CUDA_ARCH_LIST"].split()  # e.g., "8.9 9.0a"
         flags = []
         for arch in arch_list:
-            if len(arch.split('.')) != 2:
+            if len(arch.split(".")) != 2:
                 raise ValueError(f"Invalid CUDA architecture: {arch}")
-            major, minor = arch.split('.')
+            major, minor = arch.split(".")
             flags.append(f"-gencode=arch=compute_{major}{minor},code=sm_{major}{minor}")
         return " ".join(flags)
     else: