62 changes: 60 additions & 2 deletions python/tvm/script/ir_builder/tir/external_kernel.py
@@ -18,14 +18,16 @@
import json
import logging
import tempfile
from pathlib import Path
from typing import Any, Dict, List, Tuple, Union

from tvm import __version__ as tvm_version
from tvm import tir
from tvm.runtime import Module, load_module
from tvm.runtime import Module, load_module, const
from tvm.contrib import nvcc


class BaseKernel:
class BaseKernel: # pylint: disable=too-few-public-methods
"""Base class for external kernels."""

def compile_to_device_module(
@@ -91,6 +93,60 @@ def _create_cuda_module(self, ptx, kernel_arg_types, launch_param_tags, kernel_n
return kernel_module


class SourceKernel(BaseKernel): # pylint: disable=too-few-public-methods
"""A kernel from source code."""

def __init__(self, source_code: str):
self.source_code = source_code

def compile_to_device_module( # pylint: disable=arguments-differ
self, grid: List[List[Union[int, tir.PrimExpr]]], *args: List[Any], **kwargs: Dict[str, Any]
) -> Tuple[str, Module, List[Any]]:
"""Compile the kernel to a device module."""
from tvm.relax.frontend.nn import SourceModule # pylint: disable=import-outside-toplevel

kernel_name = kwargs["kernel_name"]
assert len(grid) == 2, (
"grid should be two list of integers, representing the dimension of "
"['blockIdx.x', 'blockIdx.y', 'blockIdx.z'] and "
"['threadIdx.x', 'threadIdx.y', 'threadIdx.z']"
)
assert isinstance(grid[0], (list, tuple)) and isinstance(grid[1], (list, tuple))
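# Keep only as many blockIdx/threadIdx launch tags as the grid actually specifies.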
launch_param_tags = ["blockIdx.x", "blockIdx.y", "blockIdx.z"][: len(grid[0])] + [
"threadIdx.x",
"threadIdx.y",
"threadIdx.z",
][: len(grid[1])]
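# Wrap plain Python scalars with const() so every kernel argument carries a dtype;
# the grid extents are appended as trailing launch arguments.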
runtime_args = [arg if hasattr(arg, "dtype") else const(arg) for arg in args]
kernel_arg_types = [arg.dtype for arg in runtime_args]
runtime_args = runtime_args + list(grid[0]) + list(grid[1])

# Reuse compilation path from SourceModule
compile_options = SourceModule.get_compile_options("cu")
source_code = self.source_code
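# The string may be a path to a source file; if so, read the CUDA source from disk.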
try:
source_path = Path(source_code)
if source_path.is_file():
with open(source_path, "r") as f:
source_code = f.read()
except: # pylint: disable=bare-except
pass # not a readable file path; treat source_code as inline CUDA source

with tempfile.TemporaryDirectory() as temp_dir:
ptx_path = f"{temp_dir}/{kernel_name}.ptx"
nvcc.compile_cuda(
source_code, target_format="ptx", options=compile_options, path_target=ptx_path
)
with open(ptx_path, "r") as f:
ptx = f.read()

kernel_module = self._create_cuda_module(
ptx, kernel_arg_types, launch_param_tags, kernel_name
)

return kernel_name, kernel_module, runtime_args


def call_kernel(
kernel,
launch_args: List[Union[int, tir.PrimExpr, List[Union[int, tir.PrimExpr]]]],
Expand Down Expand Up @@ -123,6 +179,8 @@ def call_kernel(
from .triton import TritonKernel # pylint: disable=import-outside-toplevel

kernel = TritonKernel(kernel)
elif kernel_type == "builtins.str":
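# A plain Python string is treated as CUDA source code (or a path to it).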
kernel = SourceKernel(kernel)
else:
raise ValueError("Unsupported kernel type {}".format(kernel_type))

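For reference, a minimal sketch of the call shape this enables inside a T.prim_func body (identifiers such as x, y, output, and m are illustrative; the complete end-to-end flow is exercised by the test added below). The launch argument is a pair of tuples giving the blockIdx and threadIdx extents, and scalar arguments without a dtype are wrapped via const:

T.call_kernel(
    add_cuda_source,                 # raw CUDA source string, or a path to a .cu file
    ((T.ceildiv(m, 64),), (64,)),    # ((blockIdx.x,), (threadIdx.x,))
    x.data, y.data, output.data, m,  # kernel arguments
    kernel_name="add_kernel",        # must match the __global__ symbol in the source
)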
100 changes: 100 additions & 0 deletions tests/python/relax/test_tir_call_source_kernel.py
@@ -0,0 +1,100 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import numpy as np

import tvm
import tvm.testing
from tvm import relax
from tvm.script import tir as T, ir as I, relax as R

add_cuda_source = """
extern "C" __global__ void add_kernel(float* x, float* y, float* output, int n_elements) {
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < n_elements) {
output[i] = x[i] + y[i];
}
}
"""


@tvm.testing.requires_cuda
def test_tir_call_source_kernel():
@I.ir_module
class Module:
@T.prim_func
def add(x_handle: T.handle, y_handle: T.handle, output_handle: T.handle) -> None:
T.func_attr({"global_symbol": "add"})
m = T.int64()
x = T.match_buffer(x_handle, (m,), "float32")
y = T.match_buffer(y_handle, (m,), "float32")
output = T.match_buffer(output_handle, (m,), "float32")
with T.block("root"):
T.reads(x[0:m], y[0:m])
T.writes(output[0:m])
BLOCK_SIZE = T.meta_var(64)
T.call_kernel(
add_cuda_source,
((T.ceildiv(m, BLOCK_SIZE),), (BLOCK_SIZE,)),
x.data,
y.data,
output.data,
m,
kernel_name="add_kernel",
)

@R.function
def main(x: R.Tensor(("m",), "float32"), y: R.Tensor(("m",), "float32")):
m = T.int64()
with R.dataflow():
output = R.call_tir(Module.add, [x, y], relax.TensorStructInfo((m,), "float32"))
R.output(output)
return output

@I.ir_module
class Parsed:
@T.prim_func
def add(x_handle: T.handle, y_handle: T.handle, output_handle: T.handle):
m = T.int64()
x = T.match_buffer(x_handle, (m,))
y = T.match_buffer(y_handle, (m,))
output = T.match_buffer(output_handle, (m,))
with T.block("root"):
T.reads(x[0:m], y[0:m])
T.writes(output[0:m])
T.call_packed(
"add_kernel",
x.data,
y.data,
output.data,
m,
(m + T.int64(64) - T.int64(1)) // T.int64(64),
64,
)

tvm.ir.assert_structural_equal(Module["add"], Parsed["add"])
assert len(Module.get_attr("external_mods")) == 1

device = tvm.cuda(0)
x_nd = tvm.nd.array(np.random.rand(256).astype(np.float32), device)
y_nd = tvm.nd.array(np.random.rand(256).astype(np.float32), device)
output_np = x_nd.numpy() + y_nd.numpy()

with tvm.target.Target("cuda"):
lib = relax.build(Module)
output_nd = tvm.runtime.relax_vm.VirtualMachine(lib, device)["main"](x_nd, y_nd)
tvm.testing.assert_allclose(output_nd.numpy(), output_np, rtol=1e-5)