From 598c1057490f63eb3d85ece56ca760a3bd371e3d Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Fri, 25 Apr 2025 13:31:01 -0700 Subject: [PATCH 01/16] PyTorch example --- cuda_core/examples/pytorch_example.py | 121 ++++++++++++++++++++++++++ 1 file changed, 121 insertions(+) create mode 100644 cuda_core/examples/pytorch_example.py diff --git a/cuda_core/examples/pytorch_example.py b/cuda_core/examples/pytorch_example.py new file mode 100644 index 0000000000..9893d17b07 --- /dev/null +++ b/cuda_core/examples/pytorch_example.py @@ -0,0 +1,121 @@ +# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED. +# +# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE +import sys +import torch +from cuda.core.experimental import Device, LaunchConfig, Program, ProgramOptions, launch, Stream + +# SAXPY kernel - passing a as a pointer to avoid any type issues +code = """ +template +__global__ void saxpy_kernel(const T* a, const T* x, const T* y, T* out, size_t N) { + const unsigned int tid = threadIdx.x + blockIdx.x * blockDim.x; + if (tid < N) { + // Dereference a to get the scalar value + out[tid] = (*a) * x[tid] + y[tid]; + } +} +""" + +dev = Device() +dev.set_current() + +# Get PyTorch's current stream +pt_stream = torch.cuda.current_stream() +print(f"PyTorch stream: {pt_stream}") + +# Create a wrapper class that implements __cuda_stream__ +class PyTorchStreamWrapper: + def __init__(self, pt_stream): + self.pt_stream = pt_stream + + def __cuda_stream__(self): + # Extract the stream ID from PyTorch's stream object + if hasattr(self.pt_stream, 'cuda_stream'): + stream_id = self.pt_stream.cuda_stream + else: + # Try to extract from string representation + stream_str = str(self.pt_stream) + try: + stream_id = int(stream_str.split('cuda_stream=0x')[1].strip('>'), 16) + except (IndexError, ValueError): + stream_id = 0 # Default to 0 if we can't extract it + + print(f"Using PyTorch stream ID: {stream_id}") + return (0, stream_id) # Return format required by CUDA Python + +# Create a wrapper for the PyTorch stream +pt_stream_wrapper = PyTorchStreamWrapper(pt_stream) + +# Initialize a CUDA Python Stream from the PyTorch stream +s = Stream._init(obj=pt_stream_wrapper) +print(f"Successfully created CUDA Python stream from PyTorch stream") + +# prepare program +arch = "".join(f"{i}" for i in dev.compute_capability) +program_options = ProgramOptions(std="c++11", arch=f"sm_{arch}") +prog = Program(code, code_type="c++", options=program_options) +mod = prog.compile( + "cubin", + logs=sys.stdout, + name_expressions=("saxpy_kernel", "saxpy_kernel"), +) + +# Run in single precision +ker = mod.get_kernel("saxpy_kernel") +dtype = torch.float32 + +# prepare input/output +size = 64 +# Use a single element tensor for 'a' +a = torch.tensor([10.0], dtype=dtype, device='cuda') +x = torch.rand(size, dtype=dtype, device='cuda') +y = torch.rand(size, dtype=dtype, device='cuda') +out = torch.empty_like(x) + +# prepare launch +block = 32 +grid = int((size + block - 1) // block) +config = LaunchConfig(grid=grid, block=block) +ker_args = (a.data_ptr(), x.data_ptr(), y.data_ptr(), out.data_ptr(), size) + +# launch kernel on our stream +launch(s, config, ker, *ker_args) + +# Wait for our CUDA kernel to complete +s.sync() + +# check result +assert torch.allclose(out, a.item() * x + y) +print("Single precision test passed!") + +# let's repeat again with double precision +ker = mod.get_kernel("saxpy_kernel") +dtype = torch.float64 + +# prepare input +size = 128 +# Use a single element tensor for 'a' +a = torch.tensor([42.0], dtype=dtype, device='cuda') +x = torch.rand(size, dtype=dtype, device='cuda') +y = torch.rand(size, dtype=dtype, device='cuda') + +# prepare output +out = torch.empty_like(x) + +# prepare launch +block = 64 +grid = int((size + block - 1) // block) +config = LaunchConfig(grid=grid, block=block) +ker_args = (a.data_ptr(), x.data_ptr(), y.data_ptr(), out.data_ptr(), size) + +# launch kernel on PyTorch's stream +launch(s, config, ker, *ker_args) + +# Wait for our CUDA kernel to complete +s.sync() + +# check result +assert torch.allclose(out, a.item() * x + y) +print("Double precision test passed!") +print("All tests passed successfully!") From ff161bd75fdc56a68d44880f17ddf7b9d39fb54c Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Fri, 25 Apr 2025 13:38:17 -0700 Subject: [PATCH 02/16] lint --- cuda_core/examples/pytorch_example.py | 31 ++++++++++++++++----------- 1 file changed, 19 insertions(+), 12 deletions(-) diff --git a/cuda_core/examples/pytorch_example.py b/cuda_core/examples/pytorch_example.py index 9893d17b07..143d6d9727 100644 --- a/cuda_core/examples/pytorch_example.py +++ b/cuda_core/examples/pytorch_example.py @@ -1,9 +1,14 @@ # Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED. # # SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE + +## Usage: pip install "cuda-core[cu12]" +## python python_example.py import sys + import torch -from cuda.core.experimental import Device, LaunchConfig, Program, ProgramOptions, launch, Stream + +from cuda.core.experimental import Device, LaunchConfig, Program, ProgramOptions, Stream, launch # SAXPY kernel - passing a as a pointer to avoid any type issues code = """ @@ -24,32 +29,34 @@ pt_stream = torch.cuda.current_stream() print(f"PyTorch stream: {pt_stream}") + # Create a wrapper class that implements __cuda_stream__ class PyTorchStreamWrapper: def __init__(self, pt_stream): self.pt_stream = pt_stream - + def __cuda_stream__(self): # Extract the stream ID from PyTorch's stream object - if hasattr(self.pt_stream, 'cuda_stream'): + if hasattr(self.pt_stream, "cuda_stream"): stream_id = self.pt_stream.cuda_stream else: # Try to extract from string representation stream_str = str(self.pt_stream) try: - stream_id = int(stream_str.split('cuda_stream=0x')[1].strip('>'), 16) + stream_id = int(stream_str.split("cuda_stream=0x")[1].strip(">"), 16) except (IndexError, ValueError): stream_id = 0 # Default to 0 if we can't extract it - + print(f"Using PyTorch stream ID: {stream_id}") return (0, stream_id) # Return format required by CUDA Python + # Create a wrapper for the PyTorch stream pt_stream_wrapper = PyTorchStreamWrapper(pt_stream) # Initialize a CUDA Python Stream from the PyTorch stream s = Stream._init(obj=pt_stream_wrapper) -print(f"Successfully created CUDA Python stream from PyTorch stream") +print("Successfully created CUDA Python stream from PyTorch stream") # prepare program arch = "".join(f"{i}" for i in dev.compute_capability) @@ -68,9 +75,9 @@ def __cuda_stream__(self): # prepare input/output size = 64 # Use a single element tensor for 'a' -a = torch.tensor([10.0], dtype=dtype, device='cuda') -x = torch.rand(size, dtype=dtype, device='cuda') -y = torch.rand(size, dtype=dtype, device='cuda') +a = torch.tensor([10.0], dtype=dtype, device="cuda") +x = torch.rand(size, dtype=dtype, device="cuda") +y = torch.rand(size, dtype=dtype, device="cuda") out = torch.empty_like(x) # prepare launch @@ -96,9 +103,9 @@ def __cuda_stream__(self): # prepare input size = 128 # Use a single element tensor for 'a' -a = torch.tensor([42.0], dtype=dtype, device='cuda') -x = torch.rand(size, dtype=dtype, device='cuda') -y = torch.rand(size, dtype=dtype, device='cuda') +a = torch.tensor([42.0], dtype=dtype, device="cuda") +x = torch.rand(size, dtype=dtype, device="cuda") +y = torch.rand(size, dtype=dtype, device="cuda") # prepare output out = torch.empty_like(x) From e33c2c262ea8b3e04a23f0d6fe9ec8043dfd6d88 Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Fri, 25 Apr 2025 14:20:32 -0700 Subject: [PATCH 03/16] simplify example --- cuda_core/examples/pytorch_example.py | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/cuda_core/examples/pytorch_example.py b/cuda_core/examples/pytorch_example.py index 143d6d9727..9c5b02d352 100644 --- a/cuda_core/examples/pytorch_example.py +++ b/cuda_core/examples/pytorch_example.py @@ -36,18 +36,7 @@ def __init__(self, pt_stream): self.pt_stream = pt_stream def __cuda_stream__(self): - # Extract the stream ID from PyTorch's stream object - if hasattr(self.pt_stream, "cuda_stream"): - stream_id = self.pt_stream.cuda_stream - else: - # Try to extract from string representation - stream_str = str(self.pt_stream) - try: - stream_id = int(stream_str.split("cuda_stream=0x")[1].strip(">"), 16) - except (IndexError, ValueError): - stream_id = 0 # Default to 0 if we can't extract it - - print(f"Using PyTorch stream ID: {stream_id}") + stream_id = self.pt_stream.cuda_stream return (0, stream_id) # Return format required by CUDA Python From aa58a6e1397365dcb0c8eaf4c270882e7a3bf019 Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Fri, 25 Apr 2025 14:22:03 -0700 Subject: [PATCH 04/16] signoff Signed-off-by: Mark Saroufim From 1378e6a93470d360e6f1a01e4826501bed05ae99 Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Tue, 29 Apr 2025 11:37:11 -0700 Subject: [PATCH 05/16] test suite changes --- cuda_core/tests/example_tests/utils.py | 2 +- cuda_core/tests/requirements-cu11.txt | 1 + cuda_core/tests/requirements-cu12.txt | 1 + 3 files changed, 3 insertions(+), 1 deletion(-) diff --git a/cuda_core/tests/example_tests/utils.py b/cuda_core/tests/example_tests/utils.py index 3b8e9b2365..229851a2f7 100644 --- a/cuda_core/tests/example_tests/utils.py +++ b/cuda_core/tests/example_tests/utils.py @@ -37,7 +37,7 @@ def run_example(samples_path, filename, env=None): exec(script, env if env else {}) # nosec B102 except ImportError as e: # for samples requiring any of optional dependencies - for m in ("cupy",): + for m in ("cupy", "torch"): if f"No module named '{m}'" in str(e): pytest.skip(f"{m} not installed, skipping related tests") break diff --git a/cuda_core/tests/requirements-cu11.txt b/cuda_core/tests/requirements-cu11.txt index d9bd566c76..ce97a6681f 100644 --- a/cuda_core/tests/requirements-cu11.txt +++ b/cuda_core/tests/requirements-cu11.txt @@ -2,3 +2,4 @@ pytest # TODO: remove this hack once cupy has a cp313 build cupy-cuda11x; python_version < "3.13" nvidia-cuda-runtime-cu11 # headers consumed by CuPy +torch # For PyTorch example diff --git a/cuda_core/tests/requirements-cu12.txt b/cuda_core/tests/requirements-cu12.txt index 18f6736033..cc0c1b76d5 100644 --- a/cuda_core/tests/requirements-cu12.txt +++ b/cuda_core/tests/requirements-cu12.txt @@ -2,3 +2,4 @@ pytest # TODO: remove this hack once cupy has a cp313 build cupy-cuda12x; python_version < "3.13" nvidia-cuda-runtime-cu12 # headers consumed by CuPy +torch # For PyTorch example From 4609a2f341df3e214c7622863fbc1081eebd4ae2 Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Tue, 29 Apr 2025 12:13:16 -0700 Subject: [PATCH 06/16] Update cuda_core/examples/pytorch_example.py Co-authored-by: Keith Kraus --- cuda_core/examples/pytorch_example.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cuda_core/examples/pytorch_example.py b/cuda_core/examples/pytorch_example.py index 9c5b02d352..471064aa2c 100644 --- a/cuda_core/examples/pytorch_example.py +++ b/cuda_core/examples/pytorch_example.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED. +# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED. # # SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE From c265c2ae4eed0dda448102a35a0da335ecf6b16c Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Tue, 29 Apr 2025 12:13:23 -0700 Subject: [PATCH 07/16] Update cuda_core/examples/pytorch_example.py Co-authored-by: Keith Kraus --- cuda_core/examples/pytorch_example.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cuda_core/examples/pytorch_example.py b/cuda_core/examples/pytorch_example.py index 471064aa2c..1dd29fb06e 100644 --- a/cuda_core/examples/pytorch_example.py +++ b/cuda_core/examples/pytorch_example.py @@ -1,6 +1,6 @@ # Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED. # -# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE +# SPDX-License-Identifier: Apache-2.0 ## Usage: pip install "cuda-core[cu12]" ## python python_example.py From 3917b40341e7997b92311d6cac45cfb47d79bd36 Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Tue, 29 Apr 2025 13:10:36 -0700 Subject: [PATCH 08/16] Update requirements-cu12.txt --- cuda_core/tests/requirements-cu12.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/cuda_core/tests/requirements-cu12.txt b/cuda_core/tests/requirements-cu12.txt index cc0c1b76d5..613e0c81f7 100644 --- a/cuda_core/tests/requirements-cu12.txt +++ b/cuda_core/tests/requirements-cu12.txt @@ -1,3 +1,4 @@ +--index-url https://download.pytorch.org/whl/cu126 pytest # TODO: remove this hack once cupy has a cp313 build cupy-cuda12x; python_version < "3.13" From 3bc93b0eb83aa6bc524538b75ff6267247e91e4e Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Tue, 29 Apr 2025 13:10:54 -0700 Subject: [PATCH 09/16] Update requirements-cu11.txt --- cuda_core/tests/requirements-cu11.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/cuda_core/tests/requirements-cu11.txt b/cuda_core/tests/requirements-cu11.txt index ce97a6681f..07c2647639 100644 --- a/cuda_core/tests/requirements-cu11.txt +++ b/cuda_core/tests/requirements-cu11.txt @@ -1,3 +1,4 @@ + --index-url https://download.pytorch.org/whl/cu118 pytest # TODO: remove this hack once cupy has a cp313 build cupy-cuda11x; python_version < "3.13" From cef69e5fa978b994abf18d3f899386225c51158e Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Tue, 29 Apr 2025 14:07:47 -0700 Subject: [PATCH 10/16] Update requirements-cu12.txt Co-authored-by: Keith Kraus --- cuda_core/tests/requirements-cu12.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cuda_core/tests/requirements-cu12.txt b/cuda_core/tests/requirements-cu12.txt index 613e0c81f7..2cfe6ec317 100644 --- a/cuda_core/tests/requirements-cu12.txt +++ b/cuda_core/tests/requirements-cu12.txt @@ -1,4 +1,4 @@ ---index-url https://download.pytorch.org/whl/cu126 +--extra-index-url https://download.pytorch.org/whl/cu126 pytest # TODO: remove this hack once cupy has a cp313 build cupy-cuda12x; python_version < "3.13" From f840546a0b2c2ce3f26c1e7523623367c490313b Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Tue, 29 Apr 2025 14:07:53 -0700 Subject: [PATCH 11/16] Update requirements-cu11.txt Co-authored-by: Keith Kraus --- cuda_core/tests/requirements-cu11.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cuda_core/tests/requirements-cu11.txt b/cuda_core/tests/requirements-cu11.txt index 07c2647639..cc618bdfd2 100644 --- a/cuda_core/tests/requirements-cu11.txt +++ b/cuda_core/tests/requirements-cu11.txt @@ -1,4 +1,4 @@ - --index-url https://download.pytorch.org/whl/cu118 + --extra-index-url https://download.pytorch.org/whl/cu118 pytest # TODO: remove this hack once cupy has a cp313 build cupy-cuda11x; python_version < "3.13" From 5f68dbe0979fc739c33097e8c902866d7b74b139 Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Tue, 29 Apr 2025 14:08:38 -0700 Subject: [PATCH 12/16] Update pytorch_example.py Co-authored-by: Leo Fang --- cuda_core/examples/pytorch_example.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/cuda_core/examples/pytorch_example.py b/cuda_core/examples/pytorch_example.py index 1dd29fb06e..02fafb3879 100644 --- a/cuda_core/examples/pytorch_example.py +++ b/cuda_core/examples/pytorch_example.py @@ -40,12 +40,7 @@ def __cuda_stream__(self): return (0, stream_id) # Return format required by CUDA Python -# Create a wrapper for the PyTorch stream -pt_stream_wrapper = PyTorchStreamWrapper(pt_stream) - -# Initialize a CUDA Python Stream from the PyTorch stream -s = Stream._init(obj=pt_stream_wrapper) -print("Successfully created CUDA Python stream from PyTorch stream") +s = PyTorchStreamWrapper(pt_stream) # prepare program arch = "".join(f"{i}" for i in dev.compute_capability) From 66964ee1f6d5a17ddf5bae08a2f2cca8181294b9 Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Tue, 29 Apr 2025 14:08:46 -0700 Subject: [PATCH 13/16] Update pytorch_example.py Co-authored-by: Leo Fang --- cuda_core/examples/pytorch_example.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/cuda_core/examples/pytorch_example.py b/cuda_core/examples/pytorch_example.py index 02fafb3879..a58627d943 100644 --- a/cuda_core/examples/pytorch_example.py +++ b/cuda_core/examples/pytorch_example.py @@ -73,9 +73,6 @@ def __cuda_stream__(self): # launch kernel on our stream launch(s, config, ker, *ker_args) -# Wait for our CUDA kernel to complete -s.sync() - # check result assert torch.allclose(out, a.item() * x + y) print("Single precision test passed!") From a750cdc1b909951964e668cdf0b752aaca42c09f Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Tue, 29 Apr 2025 14:42:12 -0700 Subject: [PATCH 14/16] Update cuda_core/examples/pytorch_example.py Co-authored-by: Leo Fang --- cuda_core/examples/pytorch_example.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/cuda_core/examples/pytorch_example.py b/cuda_core/examples/pytorch_example.py index a58627d943..bd85946ebe 100644 --- a/cuda_core/examples/pytorch_example.py +++ b/cuda_core/examples/pytorch_example.py @@ -100,9 +100,6 @@ def __cuda_stream__(self): # launch kernel on PyTorch's stream launch(s, config, ker, *ker_args) -# Wait for our CUDA kernel to complete -s.sync() - # check result assert torch.allclose(out, a.item() * x + y) print("Double precision test passed!") From 7113750affce1a2c95388f0717507413ab557949 Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Tue, 29 Apr 2025 14:43:46 -0700 Subject: [PATCH 15/16] remove .item() call --- cuda_core/examples/pytorch_example.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cuda_core/examples/pytorch_example.py b/cuda_core/examples/pytorch_example.py index bd85946ebe..76a9728594 100644 --- a/cuda_core/examples/pytorch_example.py +++ b/cuda_core/examples/pytorch_example.py @@ -8,7 +8,7 @@ import torch -from cuda.core.experimental import Device, LaunchConfig, Program, ProgramOptions, Stream, launch +from cuda.core.experimental import Device, LaunchConfig, Program, ProgramOptions, launch # SAXPY kernel - passing a as a pointer to avoid any type issues code = """ @@ -101,6 +101,6 @@ def __cuda_stream__(self): launch(s, config, ker, *ker_args) # check result -assert torch.allclose(out, a.item() * x + y) +assert torch.allclose(out, a * x + y) print("Double precision test passed!") print("All tests passed successfully!") From 7d3582f5b0ebfd8c0c1b16acad25727380036c31 Mon Sep 17 00:00:00 2001 From: Leo Fang Date: Tue, 29 Apr 2025 18:50:24 -0400 Subject: [PATCH 16/16] Defer CI setup to later --- cuda_core/tests/requirements-cu11.txt | 2 -- cuda_core/tests/requirements-cu12.txt | 2 -- 2 files changed, 4 deletions(-) diff --git a/cuda_core/tests/requirements-cu11.txt b/cuda_core/tests/requirements-cu11.txt index cc618bdfd2..d9bd566c76 100644 --- a/cuda_core/tests/requirements-cu11.txt +++ b/cuda_core/tests/requirements-cu11.txt @@ -1,6 +1,4 @@ - --extra-index-url https://download.pytorch.org/whl/cu118 pytest # TODO: remove this hack once cupy has a cp313 build cupy-cuda11x; python_version < "3.13" nvidia-cuda-runtime-cu11 # headers consumed by CuPy -torch # For PyTorch example diff --git a/cuda_core/tests/requirements-cu12.txt b/cuda_core/tests/requirements-cu12.txt index 2cfe6ec317..18f6736033 100644 --- a/cuda_core/tests/requirements-cu12.txt +++ b/cuda_core/tests/requirements-cu12.txt @@ -1,6 +1,4 @@ ---extra-index-url https://download.pytorch.org/whl/cu126 pytest # TODO: remove this hack once cupy has a cp313 build cupy-cuda12x; python_version < "3.13" nvidia-cuda-runtime-cu12 # headers consumed by CuPy -torch # For PyTorch example