From 280c0b24ea2f1f3f47fd341c2f4fbcee79c89157 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Mon, 15 Apr 2024 13:23:06 -0700 Subject: [PATCH] Add check for matmul dtype and fix reduction rule --- python/tvm/dlight/gpu/matmul.py | 3 ++- python/tvm/dlight/gpu/reduction.py | 16 ++++++++-------- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/python/tvm/dlight/gpu/matmul.py b/python/tvm/dlight/gpu/matmul.py index 0f224b89f9e4..73c87cb2ff81 100644 --- a/python/tvm/dlight/gpu/matmul.py +++ b/python/tvm/dlight/gpu/matmul.py @@ -841,9 +841,10 @@ def apply( # pylint: disable=too-many-locals,missing-docstring if apply_tensorization: # Analyze read/write buffers and choose correct tensorizer: int8 or fp16. in_dtype, out_dtype = get_in_out_dtypes(block_stmt) + tensorize_sch = None if in_dtype == "int8" and out_dtype == "int32": tensorize_sch = MatmulInt8Tensorization().apply(func, target, _) - else: + elif in_dtype == "float16" and out_dtype in ["float16", "float32"]: tensorize_sch = MatmulTensorization().apply(func, target, _) if tensorize_sch is not None: return tensorize_sch diff --git a/python/tvm/dlight/gpu/reduction.py b/python/tvm/dlight/gpu/reduction.py index 4cc142ab1614..fc63e4836849 100644 --- a/python/tvm/dlight/gpu/reduction.py +++ b/python/tvm/dlight/gpu/reduction.py @@ -16,17 +16,17 @@ # under the License. """A rule for reduction. """ # TODO: combine reduction rule and general reduction rule into one file. -from typing import List, Optional, Tuple, Union +from typing import List, Mapping, Optional, Tuple, Union from tvm import arith, ir, tir from tvm.target import Target from ..base import ( BlockInfo, - normalize_prim_func, - try_inline_contiguous_spatial, detect_dominant_read, is_broadcast_epilogue, + normalize_prim_func, + try_inline_contiguous_spatial, ) from . import utils from .base import GPUScheduleRule @@ -111,9 +111,9 @@ def _normalize( # pylint: disable=too-many-branches sch: tir.Schedule, block_info: BlockInfo, access: arith.IterSumExpr, - ) -> Tuple[Optional[bool], Optional[int]]: + ) -> Tuple[Optional[bool], Optional[int], Optional[Mapping[int, int]], Optional[int]]: if access.base != 0: - return None, None + return None, None, None, None iter_to_info = {i.var: i for i in block_info.iters} s_loops, r_loops, c_loops, c_factor = [], [], [], None s_split_loop, s_split_index = None, None @@ -124,7 +124,7 @@ def _normalize( # pylint: disable=too-many-branches is_inner_reduction = info.kind == "R" if split_expr.lower_factor > 1: if c_loops: - return None, None + return None, None, None, None s_split_loop = loop s_split_index = len(s_loops) loop, c_loop = sch.split(loop, factors=[None, split_expr.lower_factor]) @@ -141,7 +141,7 @@ def _normalize( # pylint: disable=too-many-branches if info.kind == "S" and info.dom == 1: s_loops.append(info.loop_rv) else: - return None, None + return None, None, None, None loop_order = {} s_block_var_loops = [] @@ -161,7 +161,7 @@ def _normalize( # pylint: disable=too-many-branches assert s_loops assert r_loops if len(s_loops) != len([i for i in block_info.iters if i.kind == "S"]): - return None, None + return None, None, None, None if not c_loops: c_loops = [sch.add_unit_loop(block_info.block_rv)] sch.reorder(*s_loops, *r_loops, *c_loops)