4 changes: 2 additions & 2 deletions cmake/modules/CUDA.cmake
@@ -50,8 +50,8 @@ if(USE_CUDA)

if(USE_CUBLAS)
message(STATUS "Build with cuBLAS support")
tvm_file_glob(GLOB CUBLAS_RELAY_CONTRIB_SRC src/relay/backend/contrib/cublas/*.cc)
list(APPEND COMPILER_SRCS ${CUBLAS_RELAY_CONTRIB_SRC})
tvm_file_glob(GLOB CUBLAS_CONTRIB_SRC src/relay/backend/contrib/cublas/*.cc src/relax/backend/contrib/cublas/*.cc)
list(APPEND COMPILER_SRCS ${CUBLAS_CONTRIB_SRC})
tvm_file_glob(GLOB CONTRIB_CUBLAS_SRCS src/runtime/contrib/cublas/*.cc)
list(APPEND RUNTIME_SRCS ${CONTRIB_CUBLAS_SRCS})
list(APPEND TVM_RUNTIME_LINKER_LIBS ${CUDA_CUBLAS_LIBRARY})
19 changes: 3 additions & 16 deletions python/tvm/contrib/cutlass/build.py
@@ -560,22 +560,9 @@ def _extract_relax_function_signature(f):


def _extract_arg_idx(pattern_name, f):
pattern_entry = relax.backend.get_pattern(pattern_name)
if pattern_entry is None:
raise ValueError(f"Unsupported op_type {pattern_name}")
var2val = relax.analysis.get_var2val(f)
matched_expr = pattern_entry.pattern.extract_matched_expr(f.body.body, var2val)

func_args = list(f.params)

arg_idx = {}
for name, annotation_pattern in pattern_entry.annotation_patterns.items():
arg_expr = matched_expr[annotation_pattern]
if arg_expr not in func_args:
continue
arg_idx[name] = func_args.index(arg_expr)

return arg_idx
extract_func = tvm.get_global_func("relax.contrib.extract_arg_idx")
arg_indices = extract_func(pattern_name, f)
return {k: int(v) for k, v in arg_indices.items()}
cc @yelite this has been ported to cpp
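For reference, the argument-index extraction is now a single packed-function call into the C++ registry. A hypothetical invocation (the pattern name and the returned mapping are illustrative, not taken from this patch) would look like:

# f is a partitioned Relax function produced by FuseOpsByPattern.
arg_idx = _extract_arg_idx("cutlass.matmul_bias_relu", f)
# e.g. {"lhs": 0, "rhs": 1, "bias": 2}, mapping annotation names to parameter indices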



def is_shape_valid_for_cutlass_matmul(
154 changes: 154 additions & 0 deletions python/tvm/relax/backend/contrib/cublas.py
@@ -0,0 +1,154 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Pattern table for cuBLAS backend"""
import operator
from functools import reduce

import tvm
from tvm.relax import transform
from tvm.relax.transform import PatternCheckContext

from ..pattern_registry import get_patterns_with_prefix, register_patterns
from ..patterns import make_matmul_pattern


def _is_supported_dtype(lhs_dtype, rhs_dtype):
"""Check if dtypes in the given workload are supported by cuBLAS BYOC."""
return (lhs_dtype == "float16" and rhs_dtype == "float16") or (
lhs_dtype == "float32" and rhs_dtype == "float32"
)


def _check_matmul(context: PatternCheckContext) -> bool:
lhs = context.annotated_expr["lhs"]
rhs = context.annotated_expr["rhs"]

lhs_dtype = lhs.struct_info.dtype
rhs_dtype = rhs.struct_info.dtype
if not _is_supported_dtype(lhs_dtype, rhs_dtype):
return False

lhs_shape = lhs.struct_info.shape.values
rhs_shape = rhs.struct_info.shape.values

if not isinstance(lhs_shape[-1], (tvm.tir.expr.IntImm, int)):
# Reduction axis must be constant
return False

lhs_batches = reduce(operator.mul, lhs_shape[:-2], 1)
rhs_batches = reduce(operator.mul, rhs_shape[:-2], 1)

# cuBLASLt does not seem to support batched GEMM where only one of the matrices
# has a single batch (i.e. broadcast via batch_stride 0). So for batched GEMM,
# the two batch counts must be equal.
return (
(lhs_batches == 1 and rhs_batches == 1)
or isinstance(lhs_batches, tvm.tir.Var)
or isinstance(rhs_batches, tvm.tir.Var)
or (int(lhs_batches) == int(rhs_batches))
)
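# Illustration (made-up shapes, not part of this module): the check above flattens
# the leading dimensions of each operand into a batch count, e.g.
#   lhs_shape = (2, 4, 32, 64)  -> lhs_batches = reduce(operator.mul, (2, 4), 1) == 8
#   rhs_shape = (2, 4, 64, 128) -> rhs_batches = 8   # accepted, batch counts match
#   rhs_shape = (64, 128)       -> rhs_batches = 1   # rejected against lhs_batches == 8
# because cuBLASLt cannot broadcast the single-batch operand via batch_stride 0.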


register_patterns(
[
(
"cublas.matmul",
*make_matmul_pattern(
with_bias=False,
),
_check_matmul,
),
(
"cublas.matmul_bias",
*make_matmul_pattern(
with_bias=True,
),
_check_matmul,
),
(
"cublas.matmul_bias_relu",
*make_matmul_pattern(
with_bias=True,
activation="relax.nn.relu",
),
_check_matmul,
),
(
"cublas.matmul_bias_gelu",
*make_matmul_pattern(
with_bias=True,
activation="relax.nn.gelu",
),
_check_matmul,
),
(
"cublas.matmul_transposed",
*make_matmul_pattern(
with_bias=False,
transposed_rhs=True,
),
_check_matmul,
),
(
"cublas.matmul_transposed_bias",
*make_matmul_pattern(
with_bias=True,
transposed_rhs=True,
),
_check_matmul,
),
(
"cublas.matmul_transposed_bias_relu",
*make_matmul_pattern(
with_bias=True,
activation="relax.nn.relu",
transposed_rhs=True,
),
_check_matmul,
),
(
"cublas.matmul_transposed_bias_gelu",
*make_matmul_pattern(
with_bias=True,
activation="relax.nn.gelu",
transposed_rhs=True,
),
_check_matmul,
),
]
)


def partition_for_cublas(mod):
"""
Partition the input module into cuBLAS-supported subgraphs.

Parameters
----------
mod: tvm.IRModule
The IRModule to be partitioned.

Returns
-------
mod: tvm.IRModule
The resulting IRModule, containing partitioned subgraphs to be
offloaded to the cuBLAS backend.
"""

patterns = get_patterns_with_prefix("cublas")
return transform.FuseOpsByPattern(patterns, bind_constants=False, annotate_codegen=True)(mod)
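As a rough end-to-end sketch (the codegen and build calls reflect the usual Relax BYOC flow and are assumptions, not part of this patch), a module such as the one produced by get_relax_matmul_module added below would be offloaded as follows:

from tvm import relax
from tvm.relax.backend.contrib.cublas import partition_for_cublas

# mod: an input tvm.IRModule containing matmul (+ epilogue) workloads
mod = partition_for_cublas(mod)          # group cuBLAS-supported subgraphs into partitions
mod = relax.transform.RunCodegen()(mod)  # lower the annotated partitions to cuBLAS calls
ex = relax.build(mod, target="cuda")     # compile the remaining Relax program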
15 changes: 2 additions & 13 deletions python/tvm/relax/backend/contrib/cutlass.py
@@ -17,11 +17,10 @@

"""Pattern table for CUTLASS backend"""

from typing import Mapping, Optional, Sequence, Tuple
from typing import Mapping, Sequence

import tvm
from tvm.contrib.cutlass.build import is_shape_valid_for_cutlass_matmul
from tvm.relax import DataflowVar, ShapeExpr, Var, transform
from tvm.relax import DataflowVar, Var, transform
from tvm.relax.transform import PatternCheckContext

from ..pattern_registry import get_patterns_with_prefix, register_patterns
@@ -33,16 +32,6 @@
)


def _get_static_shape(shape: ShapeExpr) -> Optional[Tuple[int]]:
result = []
for dim in shape.values:
if isinstance(dim, tvm.tir.expr.IntImm):
result.append(int(dim))
else:
return None
return result


def _is_supported_dtype(lhs_dtype, rhs_dtype):
"""Check if dtypes in the given workload are supported by CUTLASS."""
return (
1 change: 1 addition & 0 deletions python/tvm/relax/testing/__init__.py
@@ -20,3 +20,4 @@
from .nn import *
from .relay_translator import *
from .ast_printer import dump_ast
from .matmul import *
66 changes: 66 additions & 0 deletions python/tvm/relax/testing/matmul.py
@@ -0,0 +1,66 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Utilities to construct matmul workloads."""
import tvm
from tvm.script import relax as R
from tvm.script.ir_builder import IRBuilder
from tvm.script.ir_builder import relax as relax_builder


def get_relax_matmul_module(
x_shape,
y_shape,
dtype,
transposed_y=False,
with_bias=False,
activation=None,
residual_bin_op=None,
residual_activation=None,
):
"""Create a matmul op followd by epilogue operations."""
if transposed_y:
n = y_shape[-2]
else:
n = y_shape[-1]

with IRBuilder() as builder:
with relax_builder.function():
R.func_name("main")
x = R.arg("x", R.Tensor(x_shape, dtype))
y = R.arg("y", R.Tensor(y_shape, dtype))
if with_bias:
bias = R.arg("bias", R.Tensor((n,), dtype))

with R.dataflow() as frame:
if transposed_y:
axes = list(range(len(y_shape) - 2)) + [-1, -2]
y = R.emit(R.permute_dims(y, axes=axes))
result = R.emit(R.matmul(x, y, out_dtype=dtype))
if with_bias:
result = R.emit(result + bias)
if activation is not None:
result = R.emit(activation(result))
if residual_bin_op is not None:
result = R.emit(residual_bin_op(result, x))
if residual_activation is not None:
result = R.emit(residual_activation(result))
R.output(result)

R.func_ret_value(frame.output_vars[0])

func = builder.get()
return tvm.IRModule({"main": func})
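For instance, a minimal sketch of how a test might construct a bias + ReLU workload with this helper (shapes and dtype are arbitrary):

from tvm.script import relax as R
from tvm.relax.testing import get_relax_matmul_module

mod = get_relax_matmul_module(
    (8, 32, 64),   # x_shape
    (64, 128),     # y_shape
    "float16",
    with_bias=True,
    activation=R.nn.relu,
)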