From a4f2a841eb9003610d09fb0f021fab2bd7ab1ab1 Mon Sep 17 00:00:00 2001
From: Wuwei Lin
Date: Tue, 1 Oct 2024 12:22:19 -0700
Subject: [PATCH 1/3] [TVMScript][TIR] Add source kernel integration via call_kernel

---
 .../script/ir_builder/tir/external_kernel.py | 62 +++++++++++++-
 .../relax/test_tir_call_source_kernel.py | 82 +++++++++++++++++++
 2 files changed, 142 insertions(+), 2 deletions(-)
 create mode 100644 tests/python/relax/test_tir_call_source_kernel.py

diff --git a/python/tvm/script/ir_builder/tir/external_kernel.py b/python/tvm/script/ir_builder/tir/external_kernel.py
index 8c2467fad330..405e1e6cbf93 100644
--- a/python/tvm/script/ir_builder/tir/external_kernel.py
+++ b/python/tvm/script/ir_builder/tir/external_kernel.py
@@ -18,14 +18,16 @@
 import json
 import logging
 import tempfile
+from pathlib import Path
 from typing import Any, Dict, List, Tuple, Union
 
 from tvm import __version__ as tvm_version
 from tvm import tir
-from tvm.runtime import Module, load_module
+from tvm.runtime import Module, load_module, const
+from tvm.contrib import nvcc
 
 
-class BaseKernel:
+class BaseKernel:  # pylint: disable=too-few-public-methods
     """Base class for external kernels."""
 
     def compile_to_device_module(
@@ -91,6 +93,60 @@ def _create_cuda_module(self, ptx, kernel_arg_types, launch_param_tags, kernel_n
         return kernel_module
 
 
+class SourceKernel(BaseKernel):  # pylint: disable=too-few-public-methods
+    """A kernel from source code."""
+
+    def __init__(self, source_code: str):
+        self.source_code = source_code
+
+    def compile_to_device_module(  # pylint: disable=arguments-differ
+        self, grid: List[List[Union[int, tir.PrimExpr]]], *args: List[Any], **kwargs: Dict[str, Any]
+    ) -> Tuple[str, Module, List[Any]]:
+        """Compile the kernel to a device module."""
+        from tvm.relax.frontend.nn import SourceModule  # pylint: disable=import-outside-toplevel
+
+        kernel_name = kwargs["kernel_name"]
+        assert len(grid) == 2, (
+            "grid should be two lists of integers, representing the dimensions of "
+            "['blockIdx.x', 'blockIdx.y', 'blockIdx.z'] and "
+            "['threadIdx.x', 'threadIdx.y', 'threadIdx.z']"
+        )
+        assert isinstance(grid[0], (list, tuple)) and isinstance(grid[1], (list, tuple))
+        launch_param_tags = ["blockIdx.x", "blockIdx.y", "blockIdx.z"][: len(grid[0])] + [
+            "threadIdx.x",
+            "threadIdx.y",
+            "threadIdx.z",
+        ][: len(grid[1])]
+        runtime_args = [arg if hasattr(arg, "dtype") else const(arg) for arg in args]
+        kernel_arg_types = [arg.dtype for arg in runtime_args]
+        runtime_args = runtime_args + list(grid[0]) + list(grid[1])
+
+        # Reuse compilation path from SourceModule
+        compile_options = SourceModule.get_compile_options("cu")
+        source_code = self.source_code
+        try:
+            source_path = Path(source_code)
+            if source_path.is_file():
+                with open(source_path, "r") as f:
+                    source_code = f.read()
+        except:  # pylint: disable=bare-except
+            pass
+
+        with tempfile.TemporaryDirectory() as temp_dir:
+            ptx_path = f"{temp_dir}/{kernel_name}.ptx"
+            nvcc.compile_cuda(
+                source_code, target_format="ptx", options=compile_options, path_target=ptx_path
+            )
+            with open(ptx_path, "r") as f:
+                ptx = f.read()
+
+        kernel_module = self._create_cuda_module(
+            ptx, kernel_arg_types, launch_param_tags, kernel_name
+        )
+
+        return kernel_name, kernel_module, runtime_args
+
+
 def call_kernel(
     kernel,
     launch_args: List[Union[int, tir.PrimExpr, List[Union[int, tir.PrimExpr]]]],
@@ -123,6 +179,8 @@ def call_kernel(
         from .triton import TritonKernel  # pylint: disable=import-outside-toplevel
 
         kernel = TritonKernel(kernel)
+    elif kernel_type == "builtins.str":
+        kernel = SourceKernel(kernel)
     else:
         raise ValueError("Unsupported kernel type {}".format(kernel_type))
 
diff --git a/tests/python/relax/test_tir_call_source_kernel.py b/tests/python/relax/test_tir_call_source_kernel.py
new file mode 100644
index 000000000000..048ef40f8837
--- /dev/null
+++ b/tests/python/relax/test_tir_call_source_kernel.py
@@ -0,0 +1,82 @@
+import numpy as np
+
+import tvm
+import tvm.testing
+from tvm import relax
+from tvm.script import tir as T, ir as I, relax as R
+
+add_cuda_source = """
+extern "C" __global__ void add_kernel(float* x, float* y, float* output, int n_elements) {
+    int i = blockIdx.x * blockDim.x + threadIdx.x;
+    if (i < n_elements) {
+        output[i] = x[i] + y[i];
+    }
+}
+"""
+
+@tvm.testing.requires_cuda
+def test_tir_call_source_kernel():
+    @I.ir_module
+    class Module:
+        @T.prim_func
+        def add(x_handle: T.handle, y_handle: T.handle, output_handle: T.handle) -> None:
+            T.func_attr({"global_symbol": "add"})
+            m = T.int64()
+            x = T.match_buffer(x_handle, (m,), "float32")
+            y = T.match_buffer(y_handle, (m,), "float32")
+            output = T.match_buffer(output_handle, (m,), "float32")
+            with T.block("root"):
+                T.reads(x[0:m], y[0:m])
+                T.writes(output[0:m])
+                BLOCK_SIZE = T.meta_var(64)
+                T.call_kernel(
+                    add_cuda_source,
+                    ((T.ceildiv(m, BLOCK_SIZE),), (BLOCK_SIZE,)),
+                    x.data,
+                    y.data,
+                    output.data,
+                    m,
+                    kernel_name="add_kernel",
+                )
+
+        @R.function
+        def main(x: R.Tensor(("m",), "float32"), y: R.Tensor(("m",), "float32")):
+            m = T.int64()
+            with R.dataflow():
+                output = R.call_tir(Module.add, [x, y], relax.TensorStructInfo((m,), "float32"))
+                R.output(output)
+            return output
+
+    @I.ir_module
+    class Parsed:
+        @T.prim_func
+        def add(x_handle: T.handle, y_handle: T.handle, output_handle: T.handle):
+            m = T.int64()
+            x = T.match_buffer(x_handle, (m,))
+            y = T.match_buffer(y_handle, (m,))
+            output = T.match_buffer(output_handle, (m,))
+            with T.block("root"):
+                T.reads(x[0:m], y[0:m])
+                T.writes(output[0:m])
+                T.call_packed(
+                    "add_kernel",
+                    x.data,
+                    y.data,
+                    output.data,
+                    m,
+                    (m + T.int64(64) - T.int64(1)) // T.int64(64),
+                    64,
+                )
+
+    tvm.ir.assert_structural_equal(Module["add"], Parsed["add"])
+    assert len(Module.get_attr("external_mods")) == 1
+
+    device = tvm.cuda(0)
+    x_nd = tvm.nd.array(np.random.rand(256).astype(np.float32), device)
+    y_nd = tvm.nd.array(np.random.rand(256).astype(np.float32), device)
+    output_np = x_nd.numpy() + y_nd.numpy()
+
+    with tvm.target.Target("cuda"):
+        lib = relax.build(Module)
+    output_nd = tvm.runtime.relax_vm.VirtualMachine(lib, device)["main"](x_nd, y_nd)
+    tvm.testing.assert_allclose(output_nd.numpy(), output_np, rtol=1e-5)

From d531c7646027272ba0dc1915363d0747d452c857 Mon Sep 17 00:00:00 2001
From: Wuwei Lin
Date: Wed, 2 Oct 2024 10:51:38 -0700
Subject: [PATCH 2/3] lint

---
 tests/python/relax/test_tir_call_source_kernel.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tests/python/relax/test_tir_call_source_kernel.py b/tests/python/relax/test_tir_call_source_kernel.py
index 048ef40f8837..a370aba6e22a 100644
--- a/tests/python/relax/test_tir_call_source_kernel.py
+++ b/tests/python/relax/test_tir_call_source_kernel.py
@@ -14,6 +14,7 @@
 }
 """
 
+
 @tvm.testing.requires_cuda
 def test_tir_call_source_kernel():
     @I.ir_module

From 5f71ec92e47fa325370af5e88eca32ec7827a198 Mon Sep 17 00:00:00 2001
From: Wuwei Lin
Date: Wed, 2 Oct 2024 19:40:10 -0700
Subject: [PATCH 3/3] lint

---
 .../python/relax/test_tir_call_source_kernel.py | 17 +++++++++++++++++
 1 file changed, 17 insertions(+)

diff --git a/tests/python/relax/test_tir_call_source_kernel.py b/tests/python/relax/test_tir_call_source_kernel.py
index a370aba6e22a..9a877ad35f8f 100644
--- a/tests/python/relax/test_tir_call_source_kernel.py
+++ b/tests/python/relax/test_tir_call_source_kernel.py
@@ -1,3 +1,20 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
 import numpy as np
 
 import tvm
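
Usage note (illustrative sketch, not part of the patches above): besides an inline CUDA string, as exercised by the test, SourceKernel also treats the string as a file path when Path(source).is_file() holds, so T.call_kernel can point at a standalone .cu file. The sketch below shows that variant under stated assumptions: the file name "add_kernel.cu" is hypothetical, it must already exist and define add_kernel when the module is parsed, and nvcc plus a CUDA-enabled TVM build are required because compilation happens at parse time.

from tvm.script import ir as I, tir as T


@I.ir_module
class AddFromFile:
    @T.prim_func
    def add(x_handle: T.handle, y_handle: T.handle, output_handle: T.handle) -> None:
        T.func_attr({"global_symbol": "add"})
        m = T.int64()
        x = T.match_buffer(x_handle, (m,), "float32")
        y = T.match_buffer(y_handle, (m,), "float32")
        output = T.match_buffer(output_handle, (m,), "float32")
        with T.block("root"):
            T.reads(x[0:m], y[0:m])
            T.writes(output[0:m])
            BLOCK_SIZE = T.meta_var(64)
            T.call_kernel(
                # Hypothetical path: because Path("add_kernel.cu").is_file() is true,
                # SourceKernel reads the file and compiles its contents with nvcc.
                "add_kernel.cu",
                ((T.ceildiv(m, BLOCK_SIZE),), (BLOCK_SIZE,)),
                x.data,
                y.data,
                output.data,
                m,
                kernel_name="add_kernel",
            )

As in the test, the compiled kernel is attached to the module as an external module at parse time and can then be wired into a Relax function with R.call_tir and built for the "cuda" target.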