From 7d3889ddbb2be07b6621c315cb664195bed9190d Mon Sep 17 00:00:00 2001
From: Mengshiun Yu
Date: Mon, 17 Feb 2025 17:15:14 -0500
Subject: [PATCH 1/6] [Dlight][CPU] Add CPU Backend Support for GEMV Optimization

This PR adds a Dlight CPU backend with an optimized GEMV schedule rule,
covering pattern detection, loop tiling, vectorization, and parallel
execution. It also improves maintainability by refining target checks,
reduction handling, and scheduling logic.

CPU: AMD Ryzen 9 7950X 16-Core Processor
MODEL: Qwen2-0.5B-q4f16_1-MLC
Prompt: What is the meaning of life?

Results:
Baseline: prompt_tokens=27 completion_tokens=235 total_tokens=262 extra={'prompt_tokens': 27, 'completion_tokens': 235, 'prefill_tokens': 27, 'decode_tokens': 234, 'jump_forward_tokens': 0, 'prefill_tokens_per_s': 0.9777329325367138, 'decode_tokens_per_s': 0.558195154052001, 'end_to_end_latency_s': 446.823128383, 'ttft_s': 27.614902906, 'inter_token_latency_s': 1.9013750143957446}
Optimized: prompt_tokens=27 completion_tokens=227 total_tokens=254 extra={'prompt_tokens': 27, 'completion_tokens': 227, 'prefill_tokens': 27, 'decode_tokens': 226, 'jump_forward_tokens': 0, 'prefill_tokens_per_s': 1.0010420333327994, 'decode_tokens_per_s': 2.9349053824023454, 'end_to_end_latency_s': 103.976080401, 'ttft_s': 26.971894387, 'inter_token_latency_s': 0.4580444070528635}

Decode throughput improves from ~0.56 to ~2.93 tokens/s (about 5.3x), and
end-to-end latency drops from ~447 s to ~104 s.
---
 python/tvm/dlight/cpu/__init__.py             |  20 ++
 python/tvm/dlight/cpu/base.py                 |  40 +++
 python/tvm/dlight/cpu/gemv.py                 | 255 ++++++++++++++++++
 python/tvm/dlight/cpu/utils.py                |  43 +++
 python/tvm/relax/frontend/nn/llm/tree_attn.py |   2 +
 5 files changed, 360 insertions(+)
 create mode 100644 python/tvm/dlight/cpu/__init__.py
 create mode 100644 python/tvm/dlight/cpu/base.py
 create mode 100644 python/tvm/dlight/cpu/gemv.py
 create mode 100644 python/tvm/dlight/cpu/utils.py

diff --git a/python/tvm/dlight/cpu/__init__.py b/python/tvm/dlight/cpu/__init__.py
new file mode 100644
index 000000000000..3282275862f3
--- /dev/null
+++ b/python/tvm/dlight/cpu/__init__.py
@@ -0,0 +1,20 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+CPU-generic schedule rules.
+"""
+from .gemv import GEMV
diff --git a/python/tvm/dlight/cpu/base.py b/python/tvm/dlight/cpu/base.py
new file mode 100644
index 000000000000..4d16f9726bff
--- /dev/null
+++ b/python/tvm/dlight/cpu/base.py
@@ -0,0 +1,40 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Base schedule rule for CPU operators.""" + +from tvm.target import Target + +from ..base import ScheduleRule + + +class CPUScheduleRule(ScheduleRule): # pylint: disable=too-few-public-methods + """The Schedule Rule specific to CPU targets, will return None if the target is not CPU.""" + + def is_target_available(self, target: Target) -> bool: + """Check whether the target is available for gpu rule. + + Parameters + ---------- + target : Target + The compilation target to check. + + Returns + ------- + available : bool + Whether the target is available for this rule. + """ + return super().is_target_available(target) and "llvm" == target.kind.name diff --git a/python/tvm/dlight/cpu/gemv.py b/python/tvm/dlight/cpu/gemv.py new file mode 100644 index 000000000000..b0ca12236302 --- /dev/null +++ b/python/tvm/dlight/cpu/gemv.py @@ -0,0 +1,255 @@ +from functools import reduce +from typing import List, Optional, Union + +from tvm.target import Target + +from ..base import ( + BlockInfo, + collect_block_iter_vars_used_in_access_region, + collect_vars_used_in_prim_expr, + detect_dominant_read, + is_broadcast_epilogue, + normalize_prim_func, + try_inline_contiguous_spatial, +) + +from tvm import arith, ir, tir +from .base import CPUScheduleRule + +from tvm.target import Target + +from .utils import auto_vectorize, get_bytes, get_extent + + +def _get_reduction_expr(block: tir.Block) -> Optional[tir.PrimExpr]: + # Detect and return `Y` in `X[...] = X[...] + Y` + buffer_store = block.body + if not isinstance(buffer_store, tir.BufferStore): + return None + if not isinstance(buffer_store.value, tir.Add): + return None + if not ir.structural_equal( + buffer_store.value.a, + tir.BufferLoad(buffer_store.buffer, block.body.indices), + map_free_vars=True, + ): + return None + return buffer_store.value.b + + +def is_gemv(sch: tir.Schedule, block_info: BlockInfo) -> Optional[List[tir.Buffer]]: + """Check if the block is a GEMV. + + Parameters + ---------- + + sch : tir.Schedule + The schedule + + block_info : BlockInfo + The block info to be checked + + + Returns + ------- + ret : Optional[List[tir.Buffer]] + The vector buffers used in the GEMV if it is a GEMV, otherwise None. 
+ """ + block = block_info.block_rv + block_stmt = sch.get(block) + conditions = [] + conditions.append(block_info.is_reduction()) + conditions.append(len(block_stmt.reads) >= 2) + conditions.append(len(block_stmt.writes) == 1) + conditions.append(_get_reduction_expr(block_stmt) is not None) + conditions.append( + len(collect_block_iter_vars_used_in_access_region(block_stmt, block_stmt.writes[0].region)) + > 0 + ) + if not all(conditions): + return None + + iter_num = len(block_stmt.iter_vars) + ret = [ + read.buffer + for read in block_stmt.reads + if len(collect_block_iter_vars_used_in_access_region(block_stmt, read.region)) < iter_num + and len(collect_block_iter_vars_used_in_access_region(block_stmt, read.region)) > 0 + ] + return ret if 0 < len(ret) < len(block_stmt.reads) else None + + +def normalize( + sch: tir.Schedule, + block_info: BlockInfo, +) -> Optional[bool]: + """Normalize the main block.""" + block_stmt: tir.Block = sch.get(block_info.block_rv) + access = arith.normalize_to_iter_sum( + detect_dominant_read(block_stmt), + input_iters={i.var: i.dom for i in block_stmt.iter_vars}, + ) + buffers_use_vars = [ + collect_block_iter_vars_used_in_access_region(block_stmt, buf.region) + for buf in block_stmt.writes + ] + buffers_use_vars.extend( + [ + collect_block_iter_vars_used_in_access_region(block_stmt, buf.region) + for buf in block_stmt.reads + ] + ) + if collect_vars_used_in_prim_expr(access.base) & set( + iter_var.var for iter_var in block_stmt.iter_vars + ): + return None + iter_to_info = {i.var: i for i in block_info.iters} + batch_loops, s_loops, r_loops, c_loops = [], [], [], [] + inner_axis = access.args[-1].source.source + is_inner_reduction = iter_to_info[inner_axis].kind == "R" + + for split_expr in access.args: + var = split_expr.source.source + info = iter_to_info.get(var) + loop = info.loop_rv + is_reduction = info.kind == "R" + if split_expr.lower_factor > 1: + if c_loops: + return None + loop, c_loop = sch.split(loop, factors=[None, split_expr.lower_factor]) + # we only support the reduction dim being grouped atm + if not is_reduction: + return None + c_loops.append(c_loop) + if is_reduction: + r_loops.append(loop) + elif all([var in buf_vars for buf_vars in buffers_use_vars]): + batch_loops.append(loop) + else: + s_loops.append(loop) + + assert s_loops + assert r_loops + if not c_loops: + c_loops = [sch.add_unit_loop(block_info.block_rv)] + if not batch_loops: + batch_loops = [sch.add_unit_loop(block_info.block_rv)] + sch.reorder(*batch_loops, *s_loops, *r_loops, *c_loops) + sch.fuse(*batch_loops) + sch.fuse(*s_loops) + sch.fuse(*r_loops) + return is_inner_reduction + + +class GEMV(CPUScheduleRule): + """A rule for GEMV and DecodeGEMV.""" + + def apply( # pylint: disable=too-many-locals,too-many-branches,too-many-return-statements + self, + func: tir.PrimFunc, + target: Target, + _: bool, + ) -> Union[None, tir.Schedule, List[tir.Schedule]]: + if not isinstance(func, tir.PrimFunc) or not self.is_target_available(target): + return None + sch = tir.Schedule(func) + sch = tir.Schedule(func) + block_infos = normalize_prim_func(sch) + block_infos = try_inline_contiguous_spatial(sch, block_infos) + if block_infos is None: + return None + if len(block_infos) == 1: + epilogue = None + elif len(block_infos) == 2: + epilogue = block_infos[1] + if not epilogue.is_injective(): + return None + else: + return None + + block_info = block_infos[0] + if len(block_info.iters) not in [2, 3]: + # either [B, S, R] = [B, S, R] * [B, R] + # or [S, R] = [S, R] * [R] + return None + 
block = block_info.block_rv + vector_input_buffers = is_gemv(sch, block_info) + if vector_input_buffers is None: + return None + + # Step 1. Normalize the block, merge spatial and reduction iters + is_inner_reduction = normalize(sch, block_info) + + # Step 2. Do the scheduling + if is_inner_reduction is None: + return None + elif is_inner_reduction: + return self.sch_inner_reduction(sch, target, block, vector_input_buffers, epilogue) + else: + # sch_outer reduction + return None + + def sch_inner_reduction( # pylint: disable=too-many-arguments, invalid-name, unused-argument + self, + sch: tir.Schedule, + target: Target, + block: tir.schedule.BlockRV, + vector_input_buffers: List[tir.Buffer], + epilogue_info: Optional[BlockInfo], + ): + """Schedule the inner reduction block.""" + + def get_max_factor(n, factors): + factors = sorted(factors, reverse=True) + for factor in factors: + if n % factor == 0: + return factor + return 1 + + def apply( + sch: tir.Schedule, + gemv, + vector_width: int = 8, + parallel_threads: int = 8, + unroll_factor: int = 256 + ): + batch, s, r, c = sch.get_loops(block) + len_batch, len_s, len_r, len_c = ( + get_extent(sch, batch), + get_extent(sch, s), + get_extent(sch, r), + get_extent(sch, c), + ) + len_S = len_batch * len_s + len_R = len_r * len_c + + if isinstance(len_S, int) and isinstance(len_R, int): + if len_S > len_R: + tile_s, tile_r = 128, 64 # Larger tiling for s-axis when len_S is larger + else: + tile_s, tile_r = 64, 128 # Larger tiling for r-axis when len_R is larger + else: + tile_s, tile_r = 64, 64 # Default tile sizes for unknown extents + + tile_c = min(vector_width, len_c) # Ensure c-axis tiling aligns with SIMD vector width + + # Apply loop tiling (improves cache locality) + s_outer, s_inner = sch.split(s, factors=[None, tile_s]) + r_outer, r_inner = sch.split(r, factors=[None, tile_r]) + c_outer, c_inner = sch.split(c, factors=[None, tile_c]) + + # Apply vectorization (SIMD optimization) + sch.vectorize(s_inner) # Vectorize computation along c-axis for AVX/NEON + + # Enable parallel execution + sch.parallel(s_outer) # Parallelize along the s-axis (major computation loop) + + # Apply loop unrolling for better CPU performance + sch.annotate(r_outer, "pragma_auto_unroll_max_step", unroll_factor) + sch.annotate(r_outer, "pragma_unroll_explicit", 1) + return sch + + return apply( + sch, + gemv=block, + ) diff --git a/python/tvm/dlight/cpu/utils.py b/python/tvm/dlight/cpu/utils.py new file mode 100644 index 000000000000..589a86453659 --- /dev/null +++ b/python/tvm/dlight/cpu/utils.py @@ -0,0 +1,43 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
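For reference, the loop transformation that sch_inner_reduction performs above can be reproduced by hand on a standalone GEMV. The sketch below is illustrative only and is not part of the patch: the PrimFunc, its 4096x4096 shape, and the tile sizes are assumptions chosen for the example (the rule itself derives tile sizes from the loop extents, and it also keeps a unit "c" loop for grouped reductions, which is omitted here).

import tvm
from tvm.script import tir as T


@T.prim_func
def gemv(a: T.handle, x: T.handle, y: T.handle) -> None:
    # Illustrative fp32 GEMV: Y[i] = sum_k A[i, k] * X[k] (shapes are assumed).
    A = T.match_buffer(a, (4096, 4096), "float32")
    X = T.match_buffer(x, (4096,), "float32")
    Y = T.match_buffer(y, (4096,), "float32")
    for i, k in T.grid(4096, 4096):
        with T.block("gemv"):
            vi, vk = T.axis.remap("SR", [i, k])
            with T.init():
                Y[vi] = T.float32(0)
            Y[vi] = Y[vi] + A[vi, vk] * X[vk]


sch = tvm.tir.Schedule(gemv)
i, k = sch.get_loops(sch.get_block("gemv"))
i_outer, i_inner = sch.split(i, factors=[None, 64])  # spatial tiling for cache locality
k_outer, k_inner = sch.split(k, factors=[None, 8])   # reduction tiling near the SIMD width
sch.parallel(i_outer)       # threads over the spatial axis
sch.vectorize(i_inner)      # SIMD lanes along the spatial tile
sch.annotate(k_outer, "pragma_auto_unroll_max_step", 256)
sch.annotate(k_outer, "pragma_unroll_explicit", 1)
print(sch.mod.script())     # inspect the scheduled TIR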
+# pylint: disable=missing-docstring +"""Utility methods for generic CPU.""" +from typing import List, Optional, Union + +from tvm import DataType, tir +from tvm.target import Target + + +def get_bytes(dtype: Union[DataType, str]) -> int: + if isinstance(dtype, str): + dtype = DataType(dtype) + return dtype.itemsize() + + +def get_extent(sch: tir.Schedule, loop_rv: tir.schedule.LoopRV): + loop: tir.For = sch.get(loop_rv) + return loop.extent.value if isinstance(loop.extent, tir.IntImm) else loop.extent + + +def auto_vectorize(sch: tir.Schedule, loop: tir.schedule.LoopRV, max_vec: int): + """Auto vectorize the loop.""" + extent = get_extent(sch, loop) + if not isinstance(extent, int): + return + v = loop if extent <= max_vec else sch.split(loop, factors=[None, max_vec])[-1] + sch.vectorize(v) + diff --git a/python/tvm/relax/frontend/nn/llm/tree_attn.py b/python/tvm/relax/frontend/nn/llm/tree_attn.py index 33614633fc77..9aa27ca83d70 100644 --- a/python/tvm/relax/frontend/nn/llm/tree_attn.py +++ b/python/tvm/relax/frontend/nn/llm/tree_attn.py @@ -762,6 +762,8 @@ def tree_attn_paged_kv_cpu( for h_qo in T.serial(h_q): for b_idx in T.serial(batch_size): with T.block("attn"): + T.reads() + T.writes() O_local = T.alloc_buffer((d, ), "float32") Q_local = T.alloc_buffer((d, ), "float32") K_local = T.alloc_buffer((d, ), "float32") From e09b1522b2b4cdcac961c61b9eba4cc904539808 Mon Sep 17 00:00:00 2001 From: Mengshiun Yu Date: Mon, 17 Feb 2025 17:38:06 -0500 Subject: [PATCH 2/6] lint --- python/tvm/dlight/cpu/gemv.py | 52 ++++++++++++++++++---------------- python/tvm/dlight/cpu/utils.py | 4 +-- 2 files changed, 29 insertions(+), 27 deletions(-) diff --git a/python/tvm/dlight/cpu/gemv.py b/python/tvm/dlight/cpu/gemv.py index b0ca12236302..8a3fa8b814ca 100644 --- a/python/tvm/dlight/cpu/gemv.py +++ b/python/tvm/dlight/cpu/gemv.py @@ -1,6 +1,23 @@ -from functools import reduce +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""A rule for GEMV and DecodeGEMV.""" from typing import List, Optional, Union +from tvm import arith, ir, tir from tvm.target import Target from ..base import ( @@ -8,17 +25,11 @@ collect_block_iter_vars_used_in_access_region, collect_vars_used_in_prim_expr, detect_dominant_read, - is_broadcast_epilogue, normalize_prim_func, try_inline_contiguous_spatial, ) - -from tvm import arith, ir, tir from .base import CPUScheduleRule - -from tvm.target import Target - -from .utils import auto_vectorize, get_bytes, get_extent +from .utils import get_extent def _get_reduction_expr(block: tir.Block) -> Optional[tir.PrimExpr]: @@ -79,7 +90,7 @@ def is_gemv(sch: tir.Schedule, block_info: BlockInfo) -> Optional[List[tir.Buffe return ret if 0 < len(ret) < len(block_stmt.reads) else None -def normalize( +def normalize( # pylint: disable=too-many-locals, use-a-generator sch: tir.Schedule, block_info: BlockInfo, ) -> Optional[bool]: @@ -144,11 +155,11 @@ def normalize( class GEMV(CPUScheduleRule): """A rule for GEMV and DecodeGEMV.""" - def apply( # pylint: disable=too-many-locals,too-many-branches,too-many-return-statements - self, - func: tir.PrimFunc, - target: Target, - _: bool, + def apply( # pylint: disable=too-many-locals,too-many-branches,too-many-return-statements, no-else-return + self, + func: tir.PrimFunc, + target: Target, + _: bool, ) -> Union[None, tir.Schedule, List[tir.Schedule]]: if not isinstance(func, tir.PrimFunc) or not self.is_target_available(target): return None @@ -189,7 +200,7 @@ def apply( # pylint: disable=too-many-locals,too-many-branches,too-many-return- # sch_outer reduction return None - def sch_inner_reduction( # pylint: disable=too-many-arguments, invalid-name, unused-argument + def sch_inner_reduction( # pylint: disable=too-many-arguments, too-many-positional-arguments, invalid-name, unused-argument self, sch: tir.Schedule, target: Target, @@ -199,19 +210,12 @@ def sch_inner_reduction( # pylint: disable=too-many-arguments, invalid-name, un ): """Schedule the inner reduction block.""" - def get_max_factor(n, factors): - factors = sorted(factors, reverse=True) - for factor in factors: - if n % factor == 0: - return factor - return 1 - - def apply( + def apply( # pylint: disable=unused-variable, too-many-locals sch: tir.Schedule, gemv, vector_width: int = 8, parallel_threads: int = 8, - unroll_factor: int = 256 + unroll_factor: int = 256, ): batch, s, r, c = sch.get_loops(block) len_batch, len_s, len_r, len_c = ( diff --git a/python/tvm/dlight/cpu/utils.py b/python/tvm/dlight/cpu/utils.py index 589a86453659..478baa89bf02 100644 --- a/python/tvm/dlight/cpu/utils.py +++ b/python/tvm/dlight/cpu/utils.py @@ -16,10 +16,9 @@ # under the License. 
# pylint: disable=missing-docstring """Utility methods for generic CPU.""" -from typing import List, Optional, Union +from typing import Union from tvm import DataType, tir -from tvm.target import Target def get_bytes(dtype: Union[DataType, str]) -> int: @@ -40,4 +39,3 @@ def auto_vectorize(sch: tir.Schedule, loop: tir.schedule.LoopRV, max_vec: int): return v = loop if extent <= max_vec else sch.split(loop, factors=[None, max_vec])[-1] sch.vectorize(v) - From 3314c61ac8854d74a1a7417f141dc05865e9993f Mon Sep 17 00:00:00 2001 From: Mengshiun Yu Date: Tue, 18 Feb 2025 15:33:08 -0500 Subject: [PATCH 3/6] Add unit test --- python/tvm/dlight/__init__.py | 1 + tests/python/dlight/test_cpu_gemv.py | 595 +++++++++++++++++++++++++++ 2 files changed, 596 insertions(+) create mode 100644 tests/python/dlight/test_cpu_gemv.py diff --git a/python/tvm/dlight/__init__.py b/python/tvm/dlight/__init__.py index 421d4017d1bd..4c895368fc82 100644 --- a/python/tvm/dlight/__init__.py +++ b/python/tvm/dlight/__init__.py @@ -16,6 +16,7 @@ # under the License. """DLight package provides efficient schedules out-of-box for deep learning workloads.""" from . import gpu +from . import cpu from .base import ( ApplyDefaultSchedule, BlockInfo, diff --git a/tests/python/dlight/test_cpu_gemv.py b/tests/python/dlight/test_cpu_gemv.py new file mode 100644 index 000000000000..c7eebf58aa9e --- /dev/null +++ b/tests/python/dlight/test_cpu_gemv.py @@ -0,0 +1,595 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
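The tests below exercise the rule through dlight's ApplyDefaultSchedule transform under an llvm target. For orientation, a minimal standalone invocation follows the same pattern; this sketch assumes the patch series is applied, and the TE workload and its shapes are illustrative only:

import tvm
from tvm import dlight as dl
from tvm import te
from tvm.target import Target

# Build a small fp32 GEMV PrimFunc with TE (shapes chosen for illustration).
A = te.placeholder((256, 256), "float32", name="A")
X = te.placeholder((256,), "float32", name="X")
k = te.reduce_axis((0, 256), name="k")
Y = te.compute((256,), lambda i: te.sum(A[i, k] * X[k], axis=k), name="Y")
mod = tvm.IRModule({"main": te.create_prim_func([A, X, Y])})

# Apply the CPU GEMV rule; functions that do not match the pattern are left untouched.
with Target("llvm"):
    mod = dl.ApplyDefaultSchedule(dl.cpu.GEMV())(mod)
print(mod["main"].script())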
+# pylint: disable=missing-docstring +import pytest + +import tvm.testing +from tvm import dlight as dl +from tvm.script import tir as T +from tvm.target import Target + + +class BaseBeforeAfter(tvm.testing.CompareBeforeAfter): + @pytest.fixture + def transform(self): + def transform(mod): + with Target("llvm"): + return dl.ApplyDefaultSchedule(dl.cpu.GEMV())(mod) + + return transform + + +class TestGEMV(BaseBeforeAfter): + # fmt: off + + @T.prim_func + def before(lv1637: T.Buffer((1, 32, 1, 128), "float16"), p_lv1638: T.handle, p_lv1614: T.handle, p_output0: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int32() + lv1638 = T.match_buffer(p_lv1638, (1, 32, n, 128), "float16") + lv1614 = T.match_buffer(p_lv1614, (1, 1, 1, n), "float16") + var_compute_intermediate = T.match_buffer(p_output0, (1, 32, 1, n)) + # with T.block("root"): + var_NT_matmul_intermediate = T.alloc_buffer((1, 32, 1, n), "float16") + var_T_divide_intermediate = T.alloc_buffer((1, 32, 1, n), "float16") + var_T_maximum_intermediate = T.alloc_buffer((1, 32, 1, n), "float16") + var_T_minimum_intermediate = T.alloc_buffer((1, 32, 1, n), "float16") + for i0, i1, i2, i3, k in T.grid(1, 32, 1, n, 128): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) + T.reads(lv1637[v_i0, v_i1, v_i2, v_k], lv1638[v_i0, v_i1, v_i3, v_k]) + T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3]) + with T.init(): + var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = T.float16(0) + var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] + lv1637[v_i0, v_i1, v_i2, v_k] * lv1638[v_i0, v_i1, v_i3, v_k] + for ax0, ax1, ax2, ax3 in T.grid(1, 32, 1, n): + with T.block("T_divide"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + T.writes(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] * T.float16(0.088397790055248615) + for ax0, ax1, ax2, ax3 in T.grid(1, 32, 1, n): + with T.block("T_maximum"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + T.writes(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.max(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], T.float16(-65504)) + for ax0, ax1, ax2, ax3 in T.grid(1, 32, 1, n): + with T.block("T_minimum"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv1614[v_ax0, 0, v_ax2, v_ax3]) + T.writes(var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.min(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv1614[v_ax0, 0, v_ax2, v_ax3]) + for i0, i1, i2, i3 in T.grid(1, 32, 1, n): + with T.block("compute"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(var_T_minimum_intermediate[v_i0, v_i1, v_i2, v_i3]) + T.writes(var_compute_intermediate[v_i0, v_i1, v_i2, v_i3]) + var_compute_intermediate[v_i0, v_i1, v_i2, v_i3] = T.Cast("float32", var_T_minimum_intermediate[v_i0, v_i1, v_i2, v_i3]) + + @T.prim_func + def expected(lv1637: T.Buffer((1, 32, 1, 128), "float16"), p_lv1638: T.handle, p_lv1614: T.handle, p_output0: T.handle): + 
T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + n = T.int32() + lv1638 = T.match_buffer(p_lv1638, (1, 32, n, 128), "float16") + lv1614 = T.match_buffer(p_lv1614, (1, 1, 1, n), "float16") + var_compute_intermediate = T.match_buffer(p_output0, (1, 32, 1, n)) + # with T.block("root"): + var_NT_matmul_intermediate = T.alloc_buffer((1, 32, 1, n), "float16") + for ax0_fused in range(32): + for ax1_fused_0 in T.parallel((n + 63) // 64): + for ax1_fused_1 in T.vectorized(64): + for ax2_fused_0 in T.serial(2, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): + for ax2_fused_1, u_0, u_1 in T.grid(64, 1, 1): + with T.block("NT_matmul"): + v0 = T.axis.spatial(32, ax0_fused) + v1 = T.axis.spatial(n, ax1_fused_0 * 64 + ax1_fused_1) + v2 = T.axis.reduce(128, ax2_fused_0 * 64 + ax2_fused_1) + T.where(ax1_fused_0 * 64 + ax1_fused_1 < n) + T.reads(lv1637[0, v0, 0, v2], lv1638[0, v0, v1, v2]) + T.writes(var_NT_matmul_intermediate[0, v0, 0, v1]) + with T.init(): + var_NT_matmul_intermediate[0, v0, 0, v1] = T.float16(0.0) + var_NT_matmul_intermediate[0, v0, 0, v1] = var_NT_matmul_intermediate[0, v0, 0, v1] + lv1637[0, v0, 0, v2] * lv1638[0, v0, v1, v2] + for ax0, ax1 in T.grid(32, n): + with T.block("compute"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(var_NT_matmul_intermediate[0, v0, 0, v1], lv1614[0, 0, 0, v1]) + T.writes(var_compute_intermediate[0, v0, 0, v1]) + var_compute_intermediate[0, v0, 0, v1] = T.Cast("float32", T.min(T.max(var_NT_matmul_intermediate[0, v0, 0, v1] * T.float16(0.088397790055248615), T.float16(-65504.0)), lv1614[0, 0, 0, v1])) + + # fmt: on + + +def test_decode_gemv_256_threads(): + # fmt: off + @T.prim_func(private=True) + def before(lv571: T.Buffer((22016, 512), "uint32"), lv572: T.Buffer((22016, 128), "float16"), lv1654: T.Buffer((1, 1, 4096), "float16"), var_NT_matmul_intermediate: T.Buffer((1, 1, 22016), "float16")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + p_output0_intermediate = T.alloc_buffer((22016, 4096), "float16") + for i, j in T.grid(22016, 4096): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv571[v_i, v_j // 8], lv572[v_i, v_j // 32]) + T.writes(p_output0_intermediate[v_i, v_j]) + p_output0_intermediate[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv571[v_i, v_j // 8], T.Cast("uint32", v_j % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv572[v_i, v_j // 32] + for i0, i1, i2, k in T.grid(1, 1, 22016, 4096): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv1654[v_i0, v_i1, v_k], p_output0_intermediate[v_i2, v_k]) + T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv1654[v_i0, v_i1, v_k] * p_output0_intermediate[v_i2, v_k] + + @T.prim_func(private=True) + def expected(lv571: T.Buffer((22016, 512), "uint32"), lv572: T.Buffer((22016, 128), "float16"), lv1654: T.Buffer((1, 1, 4096), "float16"), var_NT_matmul_intermediate: T.Buffer((1, 1, 22016), "float16")): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + # with T.block("root"): + for u_fused in range(1): + for ax0_fused_0 in T.parallel(172): + for ax0_fused_1 in T.vectorized(128): + for ax1_0_fused_0 in T.serial(8, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): + for ax1_0_fused_1, ax1_1_0, 
ax1_1_1 in T.grid(64, 1, 8): + with T.block("NT_matmul"): + v0 = T.axis.spatial(22016, ax0_fused_0 * 128 + ax0_fused_1) + v1 = T.axis.reduce(4096, ax1_0_fused_0 * 512 + ax1_0_fused_1 * 8 + ax1_1_0 * 8 + ax1_1_1) + T.reads(lv1654[0, 0, v1], lv571[v0, v1 // 8], lv572[v0, v1 // 32]) + T.writes(var_NT_matmul_intermediate[0, 0, v0]) + with T.init(): + var_NT_matmul_intermediate[0, 0, v0] = T.float16(0.0) + var_NT_matmul_intermediate[0, 0, v0] = var_NT_matmul_intermediate[0, 0, v0] + lv1654[0, 0, v1] * ((T.Cast("float16", T.bitwise_and(T.shift_right(lv571[v0, v1 // 8], T.Cast("uint32", v1 % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7.0)) * lv572[v0, v1 // 32]) + # fmt: on + + mod = tvm.IRModule({"main": before}) + with Target("llvm"): + mod = dl.ApplyDefaultSchedule(dl.cpu.GEMV())(mod) + tvm.ir.assert_structural_equal(mod["main"], expected) + + +def test_decode_gemv1(): + # fmt: off + + @T.prim_func(private=True) + def before(lv571: T.Buffer((22016, 512), "uint32"), lv572: T.Buffer((22016, 128), "float16"), lv1654: T.Buffer((1, 1, 4096), "float16"), var_NT_matmul_intermediate: T.Buffer((1, 1, 22016), "float16")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + p_output0_intermediate = T.alloc_buffer((22016, 4096), "float16") + for i, j in T.grid(22016, 4096): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv571[v_i, v_j // 8], lv572[v_i, v_j // 32]) + T.writes(p_output0_intermediate[v_i, v_j]) + p_output0_intermediate[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv571[v_i, v_j // 8], T.Cast("uint32", v_j % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv572[v_i, v_j // 32] + for i0, i1, i2, k in T.grid(1, 1, 22016, 4096): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv1654[v_i0, v_i1, v_k], p_output0_intermediate[v_i2, v_k]) + T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv1654[v_i0, v_i1, v_k] * p_output0_intermediate[v_i2, v_k] + + @T.prim_func(private=True) + def expected(lv571: T.Buffer((22016, 512), "uint32"), lv572: T.Buffer((22016, 128), "float16"), lv1654: T.Buffer((1, 1, 4096), "float16"), var_NT_matmul_intermediate: T.Buffer((1, 1, 22016), "float16")): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + # with T.block("root"): + for u_fused in range(1): + for ax0_fused_0 in T.parallel(172): + for ax0_fused_1 in T.vectorized(128): + for ax1_0_fused_0 in T.serial(8, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): + for ax1_0_fused_1, ax1_1_0, ax1_1_1 in T.grid(64, 1, 8): + with T.block("NT_matmul"): + v0 = T.axis.spatial(22016, ax0_fused_0 * 128 + ax0_fused_1) + v1 = T.axis.reduce(4096, ax1_0_fused_0 * 512 + ax1_0_fused_1 * 8 + ax1_1_0 * 8 + ax1_1_1) + T.reads(lv1654[0, 0, v1], lv571[v0, v1 // 8], lv572[v0, v1 // 32]) + T.writes(var_NT_matmul_intermediate[0, 0, v0]) + with T.init(): + var_NT_matmul_intermediate[0, 0, v0] = T.float16(0.0) + var_NT_matmul_intermediate[0, 0, v0] = var_NT_matmul_intermediate[0, 0, v0] + lv1654[0, 0, v1] * ((T.Cast("float16", T.bitwise_and(T.shift_right(lv571[v0, v1 // 8], T.Cast("uint32", v1 % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7.0)) * lv572[v0, v1 // 32]) + # fmt: on + + mod = tvm.IRModule({"main": before}) + with Target("llvm"): + mod = dl.ApplyDefaultSchedule(dl.cpu.GEMV())(mod) + 
tvm.ir.assert_structural_equal(mod["main"], expected) + + +def test_decode_gemv2(): + # fmt: off + + @T.prim_func(private=True) + def before(lv771: T.Buffer((32000, 512), "uint32"), lv772: T.Buffer((32000, 128), "float16"), lv3216: T.Buffer((1, 1, 4096), "float16"), p_output0_intermediate: T.Buffer((1, 1, 32000), "float32")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + p_output0_intermediate_1 = T.alloc_buffer((32000, 4096), "float16") + var_NT_matmul_intermediate = T.alloc_buffer((1, 1, 32000), "float16") + for i, j in T.grid(32000, 4096): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv771[v_i, v_j // 8], lv772[v_i, v_j // 32]) + T.writes(p_output0_intermediate_1[v_i, v_j]) + p_output0_intermediate_1[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv771[v_i, v_j // 8], T.Cast("uint32", v_j % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv772[v_i, v_j // 32] + for i0, i1, i2, k in T.grid(1, 1, 32000, 4096): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv3216[v_i0, v_i1, v_k], p_output0_intermediate_1[v_i2, v_k]) + T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv3216[v_i0, v_i1, v_k] * p_output0_intermediate_1[v_i2, v_k] + for i0, i1, i2 in T.grid(1, 1, 32000): + with T.block("compute"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) + T.writes(p_output0_intermediate[v_i0, v_i1, v_i2]) + p_output0_intermediate[v_i0, v_i1, v_i2] = T.Cast("float32", var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) + + @T.prim_func(private=True) + def expected(lv771: T.Buffer((32000, 512), "uint32"), lv772: T.Buffer((32000, 128), "float16"), lv3216: T.Buffer((1, 1, 4096), "float16"), p_output0_intermediate: T.Buffer((1, 1, 32000), "float32")): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + # with T.block("root"): + var_NT_matmul_intermediate = T.alloc_buffer((1, 1, 32000), "float16") + for u_fused in range(1): + for ax0_fused_0 in T.parallel(250): + for ax0_fused_1 in T.vectorized(128): + for ax1_0_fused_0 in T.serial(8, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): + for ax1_0_fused_1, ax1_1_0, ax1_1_1 in T.grid(64, 1, 8): + with T.block("NT_matmul"): + v0 = T.axis.spatial(32000, ax0_fused_0 * 128 + ax0_fused_1) + v1 = T.axis.reduce(4096, ax1_0_fused_0 * 512 + ax1_0_fused_1 * 8 + ax1_1_0 * 8 + ax1_1_1) + T.reads(lv3216[0, 0, v1], lv771[v0, v1 // 8], lv772[v0, v1 // 32]) + T.writes(var_NT_matmul_intermediate[0, 0, v0]) + with T.init(): + var_NT_matmul_intermediate[0, 0, v0] = T.float16(0.0) + var_NT_matmul_intermediate[0, 0, v0] = var_NT_matmul_intermediate[0, 0, v0] + lv3216[0, 0, v1] * ((T.Cast("float16", T.bitwise_and(T.shift_right(lv771[v0, v1 // 8], T.Cast("uint32", v1 % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7.0)) * lv772[v0, v1 // 32]) + for ax0 in range(32000): + with T.block("compute"): + v0 = T.axis.spatial(32000, ax0) + T.reads(var_NT_matmul_intermediate[0, 0, v0]) + T.writes(p_output0_intermediate[0, 0, v0]) + p_output0_intermediate[0, 0, v0] = T.Cast("float32", var_NT_matmul_intermediate[0, 0, v0]) + # fmt: on + + mod = tvm.IRModule({"main": before}) + with Target("llvm"): + mod = dl.ApplyDefaultSchedule(dl.cpu.GEMV())(mod) + 
tvm.ir.assert_structural_equal(mod["main"], expected) + + +def test_decode_gemv3(): + # fmt: off + + @T.prim_func(private=True) + def before(lv575: T.Buffer((T.int64(4096), T.int64(1376)), "uint32"), lv576: T.Buffer((T.int64(4096), T.int64(344)), "float16"), lv574: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16"), lv570: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + p_output0_intermediate_1 = T.alloc_buffer((T.int64(4096), T.int64(11008)), "float16") + var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16") + for i, j in T.grid(T.int64(4096), T.int64(11008)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv575[v_i, v_j // T.int64(8)], lv576[v_i, v_j // T.int64(32)]) + T.writes(p_output0_intermediate_1[v_i, v_j]) + p_output0_intermediate_1[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv575[v_i, v_j // T.int64(8)], T.Cast("uint32", v_j % T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv576[v_i, v_j // T.int64(32)] + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(4096), T.int64(11008)): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv574[v_i0, v_i1, v_k], p_output0_intermediate_1[v_i2, v_k]) + T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv574[v_i0, v_i1, v_k] * p_output0_intermediate_1[v_i2, v_k] + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(4096)): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(lv570[v_ax0, v_ax1, v_ax2], var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2]) + T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) + p_output0_intermediate[v_ax0, v_ax1, v_ax2] = lv570[v_ax0, v_ax1, v_ax2] + var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + + @T.prim_func(private=True) + def expected(lv575: T.Buffer((T.int64(4096), T.int64(1376)), "uint32"), lv576: T.Buffer((T.int64(4096), T.int64(344)), "float16"), lv574: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16"), lv570: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + # with T.block("root"): + var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16") + for u_fused in range(1): + for ax0_fused_0 in T.parallel(T.int64(64)): + for ax0_fused_1 in T.vectorized(T.int64(64)): + for ax1_0_fused_0 in T.serial(T.int64(11), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): + for ax1_0_fused_1, ax1_1_0, ax1_1_1 in T.grid(T.int64(128), T.int64(1), T.int64(8)): + with T.block("NT_matmul"): + v0 = T.axis.spatial(T.int64(4096), ax0_fused_0 * T.int64(64) + ax0_fused_1) + v1 = T.axis.reduce(T.int64(11008), (ax1_0_fused_0 * T.int64(128) + ax1_0_fused_1) * T.int64(8) + ax1_1_0 * T.int64(8) + ax1_1_1) + T.where(ax1_0_fused_0 * T.int64(128) + ax1_0_fused_1 < T.int64(1376)) + T.reads(lv574[T.int64(0), T.int64(0), v1], lv575[v0, v1 // T.int64(8)], lv576[v0, v1 // T.int64(32)]) + 
T.writes(var_NT_matmul_intermediate[T.int64(0), T.int64(0), v0]) + with T.init(): + var_NT_matmul_intermediate[T.int64(0), T.int64(0), v0] = T.float16(0.0) + var_NT_matmul_intermediate[T.int64(0), T.int64(0), v0] = var_NT_matmul_intermediate[T.int64(0), T.int64(0), v0] + lv574[T.int64(0), T.int64(0), v1] * ((T.Cast("float16", T.bitwise_and(T.shift_right(lv575[v0, v1 // T.int64(8)], T.Cast("uint32", v1 % T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7.0)) * lv576[v0, v1 // T.int64(32)]) + for ax0 in range(T.int64(4096)): + with T.block("T_add"): + v0 = T.axis.spatial(T.int64(4096), ax0) + T.reads(lv570[T.int64(0), T.int64(0), v0], var_NT_matmul_intermediate[T.int64(0), T.int64(0), v0]) + T.writes(p_output0_intermediate[T.int64(0), T.int64(0), v0]) + p_output0_intermediate[T.int64(0), T.int64(0), v0] = lv570[T.int64(0), T.int64(0), v0] + var_NT_matmul_intermediate[T.int64(0), T.int64(0), v0] + # fmt: on + + mod = tvm.IRModule({"main": before}) + with Target("llvm"): + mod = dl.ApplyDefaultSchedule(dl.cpu.GEMV())(mod) + tvm.ir.assert_structural_equal(mod["main"], expected) + + +def test_autogptq_decode_gemv(): + # fmt: off + @T.prim_func(private=True) + def func(lv9: T.Buffer((T.int64(512), T.int64(4096)), "uint32"), lv10: T.Buffer((T.int64(32), T.int64(512)), "uint32"), lv11: T.Buffer((T.int64(32), T.int64(4096)), "float16"), lv12: T.Buffer((T.int64(4096),), "uint32"), lv8: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), lv1613: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(4096)), "float16") + var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16") + for i, j in T.grid(T.int64(4096), T.int64(4096)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv9[v_i // T.int64(8), v_j], lv10[lv12[v_i], v_j // T.int64(8)], lv12[v_i], lv11[lv12[v_i], v_j]) + T.writes(decode_intermediate[v_i, v_j]) + decode_intermediate[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv9[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8) * T.int64(4))), T.uint32(15))) - (T.Cast("float16", T.bitwise_and(T.shift_right(lv10[lv12[v_i], v_j // T.int64(8)], T.Cast("uint32", v_j % T.int64(8) * T.int64(4))), T.uint32(15))) + T.float16(1))) * lv11[lv12[v_i], v_j] + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(4096), T.int64(4096)): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv8[v_i0, v_i1, v_k], decode_intermediate[v_k, v_i2]) + T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) + var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv8[v_i0, v_i1, v_k] * decode_intermediate[v_k, v_i2] + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(4096)): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(lv1613[v_ax0, v_ax1, v_ax2], var_matmul_intermediate[v_ax0, v_ax1, v_ax2]) + T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) + p_output0_intermediate[v_ax0, v_ax1, v_ax2] = lv1613[v_ax0, v_ax1, v_ax2] + var_matmul_intermediate[v_ax0, v_ax1, v_ax2] + # fmt: on + + # The GeMV rule does not yet support the inner dim being grouped. 
+ # So the rule is expected to skip transforming this function. + mod = tvm.IRModule({"main": func}) + with Target("llvm"): + mod = dl.ApplyDefaultSchedule(dl.cpu.GEMV())(mod) + tvm.ir.assert_structural_equal(mod["main"], func) + + +def test_outer_reduction_adreno(): + # fmt: off + @T.prim_func(private=True) + def before( + lv575: T.Buffer((1376, 4096), "uint32"), + lv576: T.Buffer((344, 4096), "float16"), + lv574: T.Buffer((1, 1, 11008), "float16"), + lv570: T.Buffer((1, 1, 4096), "float16"), + p_output0_intermediate: T.Buffer((1, 1, 4096), "float16"), + ): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + p_output0_intermediate_1 = T.alloc_buffer((11008, 4096), "float16") + var_matmul_intermediate = T.alloc_buffer((1, 1, 4096), "float16") + for i, j in T.grid(11008, 4096): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + p_output0_intermediate_1[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv575[v_i // 8, v_j], T.Cast("uint32", v_i % 8) * T.uint32(4)), T.uint32(15)))- T.float16(7)) * lv576[v_i // 32, v_j] + for i0, i1, i2, k in T.grid(1, 1, 4096, 11008): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + with T.init(): + var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) + var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv574[v_i0, v_i1, v_k] * p_output0_intermediate_1[v_k, v_i2] + for ax0, ax1, ax2 in T.grid(1, 1, 4096): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + p_output0_intermediate[v_ax0, v_ax1, v_ax2] = lv570[v_ax0, v_ax1, v_ax2] + var_matmul_intermediate[v_ax0, v_ax1, v_ax2] + + @T.prim_func(private=True) + def expected(lv575: T.Buffer((1376, 4096), "uint32"), lv576: T.Buffer((344, 4096), "float16"), lv574: T.Buffer((1, 1, 11008), "float16"), lv570: T.Buffer((1, 1, 4096), "float16"), p_output0_intermediate: T.Buffer((1, 1, 4096), "float16")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + p_output0_intermediate_1 = T.alloc_buffer((11008, 4096), "float16") + var_matmul_intermediate = T.alloc_buffer((1, 1, 4096), "float16") + for i, j in T.grid(11008, 4096): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv575[v_i // 8, v_j], lv576[v_i // 32, v_j]) + T.writes(p_output0_intermediate_1[v_i, v_j]) + p_output0_intermediate_1[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv575[v_i // 8, v_j], T.Cast("uint32", v_i % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7.0)) * lv576[v_i // 32, v_j] + for i0, i1, i2, k in T.grid(1, 1, 4096, 11008): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv574[v_i0, v_i1, v_k], p_output0_intermediate_1[v_k, v_i2]) + T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0.0) + var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv574[v_i0, v_i1, v_k] * p_output0_intermediate_1[v_k, v_i2] + for ax0, ax1, ax2 in T.grid(1, 1, 4096): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(lv570[v_ax0, v_ax1, v_ax2], var_matmul_intermediate[v_ax0, v_ax1, v_ax2]) + T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) + p_output0_intermediate[v_ax0, v_ax1, v_ax2] = lv570[v_ax0, v_ax1, v_ax2] + var_matmul_intermediate[v_ax0, v_ax1, v_ax2] + # fmt: on + mod = tvm.IRModule({"main": before}) + with Target("llvm"): + mod 
= dl.ApplyDefaultSchedule(dl.cpu.GEMV())(mod) + tvm.ir.assert_structural_equal(mod["main"], expected) + + +def test_outer_reduction_adreno_dynamic(): + # fmt: off + @T.prim_func(private=True) + def before(p_lv612: T.handle, p_lv613: T.handle, lv1607: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + v = T.int64() + lv612 = T.match_buffer(p_lv612, (T.int64(512), v), "uint32") + lv613 = T.match_buffer(p_lv613, (T.int64(128), v), "float16") + p_output0_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(1), v)) + # with T.block("root"): + p_output0_intermediate_1 = T.alloc_buffer((T.int64(4096), v), "float16") + var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), v), "float16") + for i, j in T.grid(T.int64(4096), v): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv612[v_i // T.int64(8), v_j], lv613[v_i // T.int64(32), v_j]) + T.writes(p_output0_intermediate_1[v_i, v_j]) + p_output0_intermediate_1[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv612[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv613[v_i // T.int64(32), v_j] + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), v, T.int64(4096)): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv1607[v_i0, v_i1, v_k], p_output0_intermediate_1[v_k, v_i2]) + T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) + var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv1607[v_i0, v_i1, v_k] * p_output0_intermediate_1[v_k, v_i2] + for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), v): + with T.block("compute"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(var_matmul_intermediate[v_i0, v_i1, v_i2]) + T.writes(p_output0_intermediate[v_i0, v_i1, v_i2]) + p_output0_intermediate[v_i0, v_i1, v_i2] = T.Cast("float32", var_matmul_intermediate[v_i0, v_i1, v_i2]) + + @T.prim_func(private=True) + def expected(p_lv612: T.handle, p_lv613: T.handle, lv1607: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + v = T.int64() + lv612 = T.match_buffer(p_lv612, (T.int64(512), v), "uint32") + lv613 = T.match_buffer(p_lv613, (T.int64(128), v), "float16") + p_output0_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(1), v)) + # with T.block("root"): + p_output0_intermediate_1 = T.alloc_buffer((T.int64(4096), v), "float16") + var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), v), "float16") + for i, j in T.grid(T.int64(4096), v): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv612[v_i // T.int64(8), v_j], lv613[v_i // T.int64(32), v_j]) + T.writes(p_output0_intermediate_1[v_i, v_j]) + p_output0_intermediate_1[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv612[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7.0)) * lv613[v_i // T.int64(32), v_j] + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), v, T.int64(4096)): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv1607[v_i0, v_i1, v_k], p_output0_intermediate_1[v_k, v_i2]) + T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_matmul_intermediate[v_i0, v_i1, v_i2] = 
T.float16(0.0) + var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv1607[v_i0, v_i1, v_k] * p_output0_intermediate_1[v_k, v_i2] + for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), v): + with T.block("compute"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(var_matmul_intermediate[v_i0, v_i1, v_i2]) + T.writes(p_output0_intermediate[v_i0, v_i1, v_i2]) + p_output0_intermediate[v_i0, v_i1, v_i2] = T.Cast("float32", var_matmul_intermediate[v_i0, v_i1, v_i2]) + # fmt: on + + mod = tvm.IRModule({"main": before}) + with Target("llvm"): + mod = dl.ApplyDefaultSchedule(dl.cpu.GEMV())(mod) + tvm.ir.assert_structural_equal(mod["main"], expected) + + +def test_blockized_gemv(): + # fmt: off + @T.prim_func(private=True) + def before(x: T.Buffer((1, 4096), "float16"), w: T.Buffer((8, 16384, 4096), "float16"), indptr: T.Buffer((2,), "int32"), o: T.Buffer((2, 16384), "float16")): + # with T.block("root"): + for expert_id in T.thread_binding(2, thread="blockIdx.y"): + with T.block("gemv_o"): + v_expert_id_o = T.axis.spatial(2, expert_id) + vi_o = T.axis.spatial(1, 0) + vj_o = T.axis.reduce(1, 0) + T.reads(x[0, 0:4096], w[indptr[v_expert_id_o], 0:16384, 0:4096], indptr[v_expert_id_o]) + T.writes(o[v_expert_id_o, 0:16384]) + for i, j in T.grid(16384, 4096): + with T.block("gemv"): + vi_i, vj_i = T.axis.remap("SR", [i, j]) + T.reads(x[0, vj_i], w[indptr[v_expert_id_o], vi_i, vj_i], indptr[v_expert_id_o]) + T.writes(o[v_expert_id_o, vi_i]) + with T.init(): + o[v_expert_id_o, vi_i] = T.float16(0) + o[v_expert_id_o, vi_i] = o[v_expert_id_o, vi_i] + x[0, vj_i] * w[indptr[v_expert_id_o], vi_i, vj_i] + + @T.prim_func(private=True) + def expected(x: T.Buffer((1, 4096), "float16"), w: T.Buffer((8, 16384, 4096), "float16"), indptr: T.Buffer((2,), "int32"), o: T.Buffer((2, 16384), "float16")): + T.func_attr({"tir.is_scheduled": 1}) + # with T.block("root"): + for expert_id in T.thread_binding(2, thread="blockIdx.y"): + with T.block("gemv_o"): + v_expert_id_o = T.axis.spatial(2, expert_id) + vi_o = T.axis.spatial(1, 0) + vj_o = T.axis.reduce(1, 0) + T.reads(x[0, 0:4096], w[indptr[v_expert_id_o], 0:16384, 0:4096], indptr[v_expert_id_o]) + T.writes(o[v_expert_id_o, 0:16384]) + for u_fused in range(1): + for ax0_fused_0 in T.parallel(128): + for ax0_fused_1 in T.vectorized(128): + for ax1_fused_0 in T.serial(64, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): + for ax1_fused_1, u_0, u_1 in T.grid(64, 1, 1): + with T.block("gemv"): + v0 = T.axis.spatial(16384, ax0_fused_0 * 128 + ax0_fused_1) + v1 = T.axis.reduce(4096, ax1_fused_0 * 64 + ax1_fused_1) + T.reads(x[0, v1], w[indptr[v_expert_id_o], v0, v1], indptr[v_expert_id_o]) + T.writes(o[v_expert_id_o, v0]) + with T.init(): + o[v_expert_id_o, v0] = T.float16(0.0) + o[v_expert_id_o, v0] = o[v_expert_id_o, v0] + x[0, v1] * w[indptr[v_expert_id_o], v0, v1] + # fmt: on + mod = tvm.IRModule({"main": before}) + with Target("llvm"): + mod = dl.ApplyDefaultSchedule(dl.cpu.GEMV())(mod) + tvm.ir.assert_structural_equal(mod["main"], expected) + + +def test_func_to_skip(): + @T.prim_func + def before(var_A: T.handle, var_exclusive_scan_thrust: T.handle, seq_len: T.int64): + data_buf = T.match_buffer(var_A, (seq_len * T.int64(8),), "int32", align=8) + output_buf = T.match_buffer( + var_exclusive_scan_thrust, (seq_len * T.int64(8),), "int32", align=8 + ) + with T.block("exclusive_scan_thrust"): + T.reads() + T.writes() + T.call_packed( + "tvm.contrib.thrust.sum_scan", + T.tvm_stack_make_array( + 
data_buf.data, T.tvm_stack_make_shape(seq_len * T.int64(8)), 0, 1, 0, T.int64(0) + ), + T.tvm_stack_make_array( + output_buf.data, + T.tvm_stack_make_shape(seq_len * T.int64(8)), + 0, + 1, + 0, + T.int64(0), + ), + T.bool(False), + ) + + # This function should be skipped. + mod = tvm.IRModule({"main": before}) + with Target("llvm"): + mod = dl.ApplyDefaultSchedule(dl.cpu.GEMV())(mod) + tvm.ir.assert_structural_equal(mod["main"], before) + + +if __name__ == "__main__": + tvm.testing.main() From 1da944cf5452fd6fd8fb0d1cf5b4a464ca55ef5c Mon Sep 17 00:00:00 2001 From: Mengshiun Yu Date: Wed, 19 Feb 2025 12:06:34 -0500 Subject: [PATCH 4/6] Refactor analysis and scheduling utilities --- python/tvm/dlight/__init__.py | 8 +- python/tvm/dlight/analysis/__init__.py | 31 ++++ .../common_analysis.py} | 0 python/tvm/dlight/analysis/gemv.py | 147 ++++++++++++++++++ python/tvm/dlight/base/__init__.py | 9 -- python/tvm/dlight/base/common_schedules.py | 2 +- python/tvm/dlight/cpu/gemv.py | 130 +--------------- python/tvm/dlight/gpu/fallback.py | 3 +- python/tvm/dlight/gpu/gemv.py | 126 +-------------- python/tvm/dlight/gpu/general_reduction.py | 3 +- python/tvm/dlight/gpu/low_batch_gemv.py | 4 +- python/tvm/dlight/gpu/matmul.py | 13 +- python/tvm/dlight/gpu/reduction.py | 4 +- python/tvm/dlight/gpu/rmsnorm.py | 4 +- python/tvm/dlight/gpu/transpose.py | 7 +- 15 files changed, 210 insertions(+), 281 deletions(-) create mode 100644 python/tvm/dlight/analysis/__init__.py rename python/tvm/dlight/{base/analysis.py => analysis/common_analysis.py} (100%) create mode 100644 python/tvm/dlight/analysis/gemv.py diff --git a/python/tvm/dlight/__init__.py b/python/tvm/dlight/__init__.py index 4c895368fc82..bd70acf00f90 100644 --- a/python/tvm/dlight/__init__.py +++ b/python/tvm/dlight/__init__.py @@ -17,12 +17,14 @@ """DLight package provides efficient schedules out-of-box for deep learning workloads.""" from . import gpu from . import cpu -from .base import ( - ApplyDefaultSchedule, +from .analysis import ( BlockInfo, IterInfo, - ScheduleRule, normalize_prim_func, +) +from .base import ( + ApplyDefaultSchedule, + ScheduleRule, try_inline, try_inline_contiguous_spatial, ) diff --git a/python/tvm/dlight/analysis/__init__.py b/python/tvm/dlight/analysis/__init__.py new file mode 100644 index 000000000000..bf68d0855015 --- /dev/null +++ b/python/tvm/dlight/analysis/__init__.py @@ -0,0 +1,31 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
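With this refactor, the block/iter analysis and the GEMV pattern helpers are shared between the GPU and CPU rules through the new tvm.dlight.analysis package. A short sketch of the resulting import surface, assuming the patch series is applied:

# Shared block and iteration-domain analysis, moved out of tvm.dlight.base:
from tvm.dlight.analysis import BlockInfo, IterInfo, normalize_prim_func
# GEMV-specific pattern detection and loop normalization:
from tvm.dlight.analysis.gemv import is_gemv, normalize
# Scheduling-only infrastructure stays under tvm.dlight.base:
from tvm.dlight.base import ScheduleRule, try_inline_contiguous_spatial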
+"""Base infra""" +from .common_analysis import ( + BlockInfo, + IterInfo, + collect_block_iter_vars_used_in_access_region, + collect_vars_used_in_prim_expr, + detect_dominant_read, + is_broadcast_epilogue, + normalize_prim_func, + get_root_block, +) +from .gemv import ( + is_gemv, + normalize, +) diff --git a/python/tvm/dlight/base/analysis.py b/python/tvm/dlight/analysis/common_analysis.py similarity index 100% rename from python/tvm/dlight/base/analysis.py rename to python/tvm/dlight/analysis/common_analysis.py diff --git a/python/tvm/dlight/analysis/gemv.py b/python/tvm/dlight/analysis/gemv.py new file mode 100644 index 000000000000..ccde7a042e32 --- /dev/null +++ b/python/tvm/dlight/analysis/gemv.py @@ -0,0 +1,147 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Analysis for GEMV.""" +from typing import List, Optional + +from tvm import arith, ir, tir + +from .common_analysis import ( + BlockInfo, + collect_block_iter_vars_used_in_access_region, + collect_vars_used_in_prim_expr, + detect_dominant_read, +) + + +def get_reduction_expr(block: tir.Block) -> Optional[tir.PrimExpr]: + # Detect and return `Y` in `X[...] = X[...] + Y` + buffer_store = block.body + if not isinstance(buffer_store, tir.BufferStore): + return None + if not isinstance(buffer_store.value, tir.Add): + return None + if not ir.structural_equal( + buffer_store.value.a, + tir.BufferLoad(buffer_store.buffer, block.body.indices), + map_free_vars=True, + ): + return None + return buffer_store.value.b + + +def is_gemv(sch: tir.Schedule, block_info: BlockInfo) -> Optional[List[tir.Buffer]]: + """Check if the block is a GEMV. + + Parameters + ---------- + + sch : tir.Schedule + The schedule + + block_info : BlockInfo + The block info to be checked + + + Returns + ------- + ret : Optional[List[tir.Buffer]] + The vector buffers used in the GEMV if it is a GEMV, otherwise None. 
+ """ + block = block_info.block_rv + block_stmt = sch.get(block) + conditions = [] + conditions.append(block_info.is_reduction()) + conditions.append(len(block_stmt.reads) >= 2) + conditions.append(len(block_stmt.writes) == 1) + conditions.append(get_reduction_expr(block_stmt) is not None) + conditions.append( + len(collect_block_iter_vars_used_in_access_region(block_stmt, block_stmt.writes[0].region)) + > 0 + ) + if not all(conditions): + return None + + iter_num = len(block_stmt.iter_vars) + ret = [ + read.buffer + for read in block_stmt.reads + if len(collect_block_iter_vars_used_in_access_region(block_stmt, read.region)) < iter_num + and len(collect_block_iter_vars_used_in_access_region(block_stmt, read.region)) > 0 + ] + return ret if 0 < len(ret) < len(block_stmt.reads) else None + + +def normalize( + sch: tir.Schedule, + block_info: BlockInfo, +) -> Optional[bool]: + """Normalize the main block.""" + block_stmt: tir.Block = sch.get(block_info.block_rv) + access = arith.normalize_to_iter_sum( + detect_dominant_read(block_stmt), + input_iters={i.var: i.dom for i in block_stmt.iter_vars}, + ) + buffers_use_vars = [ + collect_block_iter_vars_used_in_access_region(block_stmt, buf.region) + for buf in block_stmt.writes + ] + buffers_use_vars.extend( + [ + collect_block_iter_vars_used_in_access_region(block_stmt, buf.region) + for buf in block_stmt.reads + ] + ) + if collect_vars_used_in_prim_expr(access.base) & set( + iter_var.var for iter_var in block_stmt.iter_vars + ): + return None + iter_to_info = {i.var: i for i in block_info.iters} + batch_loops, s_loops, r_loops, c_loops = [], [], [], [] + inner_axis = access.args[-1].source.source + is_inner_reduction = iter_to_info[inner_axis].kind == "R" + + for split_expr in access.args: + var = split_expr.source.source + info = iter_to_info.get(var) + loop = info.loop_rv + is_reduction = info.kind == "R" + if split_expr.lower_factor > 1: + if c_loops: + return None + loop, c_loop = sch.split(loop, factors=[None, split_expr.lower_factor]) + # we only support the reduction dim being grouped atm + if not is_reduction: + return None + c_loops.append(c_loop) + if is_reduction: + r_loops.append(loop) + elif all([var in buf_vars for buf_vars in buffers_use_vars]): + batch_loops.append(loop) + else: + s_loops.append(loop) + + assert s_loops + assert r_loops + if not c_loops: + c_loops = [sch.add_unit_loop(block_info.block_rv)] + if not batch_loops: + batch_loops = [sch.add_unit_loop(block_info.block_rv)] + sch.reorder(*batch_loops, *s_loops, *r_loops, *c_loops) + sch.fuse(*batch_loops) + sch.fuse(*s_loops) + sch.fuse(*r_loops) + return is_inner_reduction diff --git a/python/tvm/dlight/base/__init__.py b/python/tvm/dlight/base/__init__.py index a19a292fa13e..1e47edf3800a 100644 --- a/python/tvm/dlight/base/__init__.py +++ b/python/tvm/dlight/base/__init__.py @@ -15,15 +15,6 @@ # specific language governing permissions and limitations # under the License. 
"""Base infra""" -from .analysis import ( - BlockInfo, - IterInfo, - collect_block_iter_vars_used_in_access_region, - collect_vars_used_in_prim_expr, - detect_dominant_read, - is_broadcast_epilogue, - normalize_prim_func, -) from .common_schedules import try_inline, try_inline_contiguous_spatial from .schedule_rule import ScheduleRule from .transform import ApplyDefaultSchedule diff --git a/python/tvm/dlight/base/common_schedules.py b/python/tvm/dlight/base/common_schedules.py index fe005cec5d70..c205b78390bc 100644 --- a/python/tvm/dlight/base/common_schedules.py +++ b/python/tvm/dlight/base/common_schedules.py @@ -19,7 +19,7 @@ from tvm import tir -from .analysis import BlockInfo +from ..analysis import BlockInfo def try_inline( diff --git a/python/tvm/dlight/cpu/gemv.py b/python/tvm/dlight/cpu/gemv.py index 8a3fa8b814ca..1599cdb1577f 100644 --- a/python/tvm/dlight/cpu/gemv.py +++ b/python/tvm/dlight/cpu/gemv.py @@ -17,141 +17,19 @@ """A rule for GEMV and DecodeGEMV.""" from typing import List, Optional, Union -from tvm import arith, ir, tir +from tvm import tir from tvm.target import Target -from ..base import ( +from ..analysis import ( BlockInfo, - collect_block_iter_vars_used_in_access_region, - collect_vars_used_in_prim_expr, - detect_dominant_read, normalize_prim_func, - try_inline_contiguous_spatial, ) +from ..analysis.gemv import is_gemv, normalize +from ..base import try_inline_contiguous_spatial from .base import CPUScheduleRule from .utils import get_extent -def _get_reduction_expr(block: tir.Block) -> Optional[tir.PrimExpr]: - # Detect and return `Y` in `X[...] = X[...] + Y` - buffer_store = block.body - if not isinstance(buffer_store, tir.BufferStore): - return None - if not isinstance(buffer_store.value, tir.Add): - return None - if not ir.structural_equal( - buffer_store.value.a, - tir.BufferLoad(buffer_store.buffer, block.body.indices), - map_free_vars=True, - ): - return None - return buffer_store.value.b - - -def is_gemv(sch: tir.Schedule, block_info: BlockInfo) -> Optional[List[tir.Buffer]]: - """Check if the block is a GEMV. - - Parameters - ---------- - - sch : tir.Schedule - The schedule - - block_info : BlockInfo - The block info to be checked - - - Returns - ------- - ret : Optional[List[tir.Buffer]] - The vector buffers used in the GEMV if it is a GEMV, otherwise None. 
- """ - block = block_info.block_rv - block_stmt = sch.get(block) - conditions = [] - conditions.append(block_info.is_reduction()) - conditions.append(len(block_stmt.reads) >= 2) - conditions.append(len(block_stmt.writes) == 1) - conditions.append(_get_reduction_expr(block_stmt) is not None) - conditions.append( - len(collect_block_iter_vars_used_in_access_region(block_stmt, block_stmt.writes[0].region)) - > 0 - ) - if not all(conditions): - return None - - iter_num = len(block_stmt.iter_vars) - ret = [ - read.buffer - for read in block_stmt.reads - if len(collect_block_iter_vars_used_in_access_region(block_stmt, read.region)) < iter_num - and len(collect_block_iter_vars_used_in_access_region(block_stmt, read.region)) > 0 - ] - return ret if 0 < len(ret) < len(block_stmt.reads) else None - - -def normalize( # pylint: disable=too-many-locals, use-a-generator - sch: tir.Schedule, - block_info: BlockInfo, -) -> Optional[bool]: - """Normalize the main block.""" - block_stmt: tir.Block = sch.get(block_info.block_rv) - access = arith.normalize_to_iter_sum( - detect_dominant_read(block_stmt), - input_iters={i.var: i.dom for i in block_stmt.iter_vars}, - ) - buffers_use_vars = [ - collect_block_iter_vars_used_in_access_region(block_stmt, buf.region) - for buf in block_stmt.writes - ] - buffers_use_vars.extend( - [ - collect_block_iter_vars_used_in_access_region(block_stmt, buf.region) - for buf in block_stmt.reads - ] - ) - if collect_vars_used_in_prim_expr(access.base) & set( - iter_var.var for iter_var in block_stmt.iter_vars - ): - return None - iter_to_info = {i.var: i for i in block_info.iters} - batch_loops, s_loops, r_loops, c_loops = [], [], [], [] - inner_axis = access.args[-1].source.source - is_inner_reduction = iter_to_info[inner_axis].kind == "R" - - for split_expr in access.args: - var = split_expr.source.source - info = iter_to_info.get(var) - loop = info.loop_rv - is_reduction = info.kind == "R" - if split_expr.lower_factor > 1: - if c_loops: - return None - loop, c_loop = sch.split(loop, factors=[None, split_expr.lower_factor]) - # we only support the reduction dim being grouped atm - if not is_reduction: - return None - c_loops.append(c_loop) - if is_reduction: - r_loops.append(loop) - elif all([var in buf_vars for buf_vars in buffers_use_vars]): - batch_loops.append(loop) - else: - s_loops.append(loop) - - assert s_loops - assert r_loops - if not c_loops: - c_loops = [sch.add_unit_loop(block_info.block_rv)] - if not batch_loops: - batch_loops = [sch.add_unit_loop(block_info.block_rv)] - sch.reorder(*batch_loops, *s_loops, *r_loops, *c_loops) - sch.fuse(*batch_loops) - sch.fuse(*s_loops) - sch.fuse(*r_loops) - return is_inner_reduction - - class GEMV(CPUScheduleRule): """A rule for GEMV and DecodeGEMV.""" diff --git a/python/tvm/dlight/gpu/fallback.py b/python/tvm/dlight/gpu/fallback.py index 7139c7ea4199..60fc373e8c4d 100644 --- a/python/tvm/dlight/gpu/fallback.py +++ b/python/tvm/dlight/gpu/fallback.py @@ -21,7 +21,8 @@ from tvm import tir from tvm.target import Target -from ..base import normalize_prim_func, try_inline +from ..analysis import normalize_prim_func +from ..base import try_inline from . 
import utils from .base import GPUScheduleRule diff --git a/python/tvm/dlight/gpu/gemv.py b/python/tvm/dlight/gpu/gemv.py index cff234140e50..c3c4ee7f7817 100644 --- a/python/tvm/dlight/gpu/gemv.py +++ b/python/tvm/dlight/gpu/gemv.py @@ -21,139 +21,21 @@ from tvm import arith, ir, tir from tvm.target import Target -from ..base import ( +from ..analysis import ( BlockInfo, collect_block_iter_vars_used_in_access_region, collect_vars_used_in_prim_expr, detect_dominant_read, is_broadcast_epilogue, + is_gemv, + normalize, normalize_prim_func, - try_inline_contiguous_spatial, ) +from ..base import try_inline_contiguous_spatial from .base import GPUScheduleRule from .utils import auto_vectorize, get_bytes, get_extent -def _get_reduction_expr(block: tir.Block) -> Optional[tir.PrimExpr]: - # Detect and return `Y` in `X[...] = X[...] + Y` - buffer_store = block.body - if not isinstance(buffer_store, tir.BufferStore): - return None - if not isinstance(buffer_store.value, tir.Add): - return None - if not ir.structural_equal( - buffer_store.value.a, - tir.BufferLoad(buffer_store.buffer, block.body.indices), - map_free_vars=True, - ): - return None - return buffer_store.value.b - - -def is_gemv(sch: tir.Schedule, block_info: BlockInfo) -> Optional[List[tir.Buffer]]: - """Check if the block is a GEMV. - - Parameters - ---------- - - sch : tir.Schedule - The schedule - - block_info : BlockInfo - The block info to be checked - - - Returns - ------- - ret : Optional[List[tir.Buffer]] - The vector buffers used in the GEMV if it is a GEMV, otherwise None. - """ - block = block_info.block_rv - block_stmt = sch.get(block) - conditions = [] - conditions.append(block_info.is_reduction()) - conditions.append(len(block_stmt.reads) >= 2) - conditions.append(len(block_stmt.writes) == 1) - conditions.append(_get_reduction_expr(block_stmt) is not None) - conditions.append( - len(collect_block_iter_vars_used_in_access_region(block_stmt, block_stmt.writes[0].region)) - > 0 - ) - if not all(conditions): - return None - - iter_num = len(block_stmt.iter_vars) - ret = [ - read.buffer - for read in block_stmt.reads - if len(collect_block_iter_vars_used_in_access_region(block_stmt, read.region)) < iter_num - and len(collect_block_iter_vars_used_in_access_region(block_stmt, read.region)) > 0 - ] - return ret if 0 < len(ret) < len(block_stmt.reads) else None - - -def normalize( - sch: tir.Schedule, - block_info: BlockInfo, -) -> Optional[bool]: - """Normalize the main block.""" - block_stmt: tir.Block = sch.get(block_info.block_rv) - access = arith.normalize_to_iter_sum( - detect_dominant_read(block_stmt), - input_iters={i.var: i.dom for i in block_stmt.iter_vars}, - ) - buffers_use_vars = [ - collect_block_iter_vars_used_in_access_region(block_stmt, buf.region) - for buf in block_stmt.writes - ] - buffers_use_vars.extend( - [ - collect_block_iter_vars_used_in_access_region(block_stmt, buf.region) - for buf in block_stmt.reads - ] - ) - if collect_vars_used_in_prim_expr(access.base) & set( - iter_var.var for iter_var in block_stmt.iter_vars - ): - return None - iter_to_info = {i.var: i for i in block_info.iters} - batch_loops, s_loops, r_loops, c_loops = [], [], [], [] - inner_axis = access.args[-1].source.source - is_inner_reduction = iter_to_info[inner_axis].kind == "R" - - for split_expr in access.args: - var = split_expr.source.source - info = iter_to_info.get(var) - loop = info.loop_rv - is_reduction = info.kind == "R" - if split_expr.lower_factor > 1: - if c_loops: - return None - loop, c_loop = sch.split(loop, 
factors=[None, split_expr.lower_factor]) - # we only support the reduction dim being grouped atm - if not is_reduction: - return None - c_loops.append(c_loop) - if is_reduction: - r_loops.append(loop) - elif all([var in buf_vars for buf_vars in buffers_use_vars]): - batch_loops.append(loop) - else: - s_loops.append(loop) - - assert s_loops - assert r_loops - if not c_loops: - c_loops = [sch.add_unit_loop(block_info.block_rv)] - if not batch_loops: - batch_loops = [sch.add_unit_loop(block_info.block_rv)] - sch.reorder(*batch_loops, *s_loops, *r_loops, *c_loops) - sch.fuse(*batch_loops) - sch.fuse(*s_loops) - sch.fuse(*r_loops) - return is_inner_reduction - - class GEMV(GPUScheduleRule): """A rule for GEMV and DecodeGEMV.""" diff --git a/python/tvm/dlight/gpu/general_reduction.py b/python/tvm/dlight/gpu/general_reduction.py index 404b73a6f0cc..a068e732b986 100644 --- a/python/tvm/dlight/gpu/general_reduction.py +++ b/python/tvm/dlight/gpu/general_reduction.py @@ -21,7 +21,8 @@ from tvm import arith, tir from tvm.target import Target -from ..base import normalize_prim_func, try_inline_contiguous_spatial +from ..analysis import normalize_prim_func +from ..base import try_inline_contiguous_spatial from .base import GPUScheduleRule diff --git a/python/tvm/dlight/gpu/low_batch_gemv.py b/python/tvm/dlight/gpu/low_batch_gemv.py index b528086a1626..f2dfa3f50e55 100644 --- a/python/tvm/dlight/gpu/low_batch_gemv.py +++ b/python/tvm/dlight/gpu/low_batch_gemv.py @@ -21,14 +21,14 @@ from tvm import arith, ir, tir from tvm.target import Target -from ..base import ( +from ..analysis import ( BlockInfo, collect_block_iter_vars_used_in_access_region, collect_vars_used_in_prim_expr, is_broadcast_epilogue, normalize_prim_func, - try_inline_contiguous_spatial, ) +from ..base import try_inline_contiguous_spatial from .base import GPUScheduleRule from .utils import auto_vectorize, get_bytes, get_extent diff --git a/python/tvm/dlight/gpu/matmul.py b/python/tvm/dlight/gpu/matmul.py index d9d4b7ebd4d2..368552c88d43 100644 --- a/python/tvm/dlight/gpu/matmul.py +++ b/python/tvm/dlight/gpu/matmul.py @@ -22,13 +22,13 @@ from tvm import tir from tvm.ir import Range +from tvm.script import tir as T from tvm.target import Target from tvm.tir import IterVar, PrimExpr, Var from tvm.tir.analysis import undefined_vars from tvm.tir.schedule.schedule import BlockRV -from tvm.script import tir as T -from ..base import analysis, BlockInfo, IterInfo +from ..analysis import BlockInfo, IterInfo, get_root_block from .base import GPUScheduleRule @@ -358,7 +358,7 @@ def apply( # pylint: disable=too-many-locals,missing-docstring if not isinstance(func, tir.PrimFunc) or not self.is_target_available(target): return None sch = tir.Schedule(func) - root_block = analysis.get_root_block(sch) + root_block = get_root_block(sch) blocks = sch.get_child_blocks(root_block) reduction_blocks = get_reduction_blocks(sch, blocks) @@ -499,7 +499,7 @@ def apply( # pylint: disable=too-many-locals,missing-docstring if not isinstance(func, tir.PrimFunc) or not self.is_target_available(target): return None sch = tir.Schedule(func) - root_block = analysis.get_root_block(sch) + root_block = get_root_block(sch) blocks = sch.get_child_blocks(root_block) if "dlight.do_not_tensorize" in func.attrs.keys(): @@ -720,7 +720,7 @@ def apply( # pylint: disable=too-many-locals,missing-docstring if not isinstance(func, tir.PrimFunc) or not self.is_target_available(target): return None sch = tir.Schedule(func) - root_block = analysis.get_root_block(sch) + root_block = 
get_root_block(sch) blocks = sch.get_child_blocks(root_block) if "dlight.do_not_tensorize" in func.attrs.keys(): @@ -971,7 +971,7 @@ def apply( # pylint: disable=too-many-locals,missing-docstring return None sch = tir.Schedule(func) config = self.get_configs(target) - root_block = analysis.get_root_block(sch) + root_block = get_root_block(sch) blocks = sch.get_child_blocks(root_block) reduction_blocks = get_reduction_blocks(sch, blocks) @@ -1130,7 +1130,6 @@ def sch_outer_reduction( reduction_block: tir.schedule.BlockRV, blocks: List[tir.schedule.BlockRV], ) -> Optional[tir.Schedule]: - """Get vectorization factor""" def get_max_factor(n, factors): diff --git a/python/tvm/dlight/gpu/reduction.py b/python/tvm/dlight/gpu/reduction.py index fc63e4836849..4bc0d2b3efa7 100644 --- a/python/tvm/dlight/gpu/reduction.py +++ b/python/tvm/dlight/gpu/reduction.py @@ -21,13 +21,13 @@ from tvm import arith, ir, tir from tvm.target import Target -from ..base import ( +from ..analysis import ( BlockInfo, detect_dominant_read, is_broadcast_epilogue, normalize_prim_func, - try_inline_contiguous_spatial, ) +from ..base import try_inline_contiguous_spatial from . import utils from .base import GPUScheduleRule diff --git a/python/tvm/dlight/gpu/rmsnorm.py b/python/tvm/dlight/gpu/rmsnorm.py index 4047721c9aa8..5dc6887c782c 100644 --- a/python/tvm/dlight/gpu/rmsnorm.py +++ b/python/tvm/dlight/gpu/rmsnorm.py @@ -19,9 +19,9 @@ import tvm from tvm import tir -from tvm.tir import Block, BufferStore -from tvm.tir.expr import Cast, BufferLoad, Call from tvm.target import Target +from tvm.tir import Block, BufferStore +from tvm.tir.expr import BufferLoad, Call, Cast from ..base import ScheduleRule diff --git a/python/tvm/dlight/gpu/transpose.py b/python/tvm/dlight/gpu/transpose.py index 3bef3d61e536..125af538cdb8 100644 --- a/python/tvm/dlight/gpu/transpose.py +++ b/python/tvm/dlight/gpu/transpose.py @@ -22,11 +22,8 @@ from tvm.tir import Schedule from tvm.tir.schedule import BlockRV -from ..base import ( - detect_dominant_read, - normalize_prim_func, - try_inline_contiguous_spatial, -) +from ..analysis import detect_dominant_read, normalize_prim_func +from ..base import try_inline_contiguous_spatial from .base import GPUScheduleRule From 33b406be517da287cf8789ed1c5aa28d7eae89d8 Mon Sep 17 00:00:00 2001 From: Mengshiun Yu Date: Wed, 19 Feb 2025 12:34:24 -0500 Subject: [PATCH 5/6] lint --- python/tvm/dlight/analysis/gemv.py | 17 ++++++++++++++++- python/tvm/dlight/gpu/gemv.py | 5 +---- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/python/tvm/dlight/analysis/gemv.py b/python/tvm/dlight/analysis/gemv.py index ccde7a042e32..c502081ba320 100644 --- a/python/tvm/dlight/analysis/gemv.py +++ b/python/tvm/dlight/analysis/gemv.py @@ -28,7 +28,22 @@ def get_reduction_expr(block: tir.Block) -> Optional[tir.PrimExpr]: - # Detect and return `Y` in `X[...] = X[...] + Y` + """Extracts the reduction expression from a TIR block. + + This function checks whether the given TIR block follows a reduction pattern + of the form `X[...] = X[...] + Y` and returns `Y` as the reduction expression. + + Parameters: + ---------- + block : tir.Block + The TIR block to analyze. + + Returns: + ------- + Optional[tir.PrimExpr] + The reduction expression (`Y`) if detected, otherwise None. 
+ """ + buffer_store = block.body if not isinstance(buffer_store, tir.BufferStore): return None diff --git a/python/tvm/dlight/gpu/gemv.py b/python/tvm/dlight/gpu/gemv.py index c3c4ee7f7817..32546af6a89e 100644 --- a/python/tvm/dlight/gpu/gemv.py +++ b/python/tvm/dlight/gpu/gemv.py @@ -18,14 +18,11 @@ from functools import reduce from typing import List, Optional, Union -from tvm import arith, ir, tir +from tvm import tir from tvm.target import Target from ..analysis import ( BlockInfo, - collect_block_iter_vars_used_in_access_region, - collect_vars_used_in_prim_expr, - detect_dominant_read, is_broadcast_epilogue, is_gemv, normalize, From 6180d2a32bb893b10cf3dc66d9c0e430fb3d6956 Mon Sep 17 00:00:00 2001 From: Mengshiun Yu Date: Mon, 24 Feb 2025 13:18:39 -0500 Subject: [PATCH 6/6] Fix duplicated schedule creation and utils.py --- python/tvm/dlight/base/__init__.py | 7 ++++ python/tvm/dlight/{gpu => base}/utils.py | 0 python/tvm/dlight/cpu/gemv.py | 9 ++---- python/tvm/dlight/cpu/utils.py | 41 ------------------------ python/tvm/dlight/gpu/fallback.py | 4 +-- python/tvm/dlight/gpu/gemv.py | 3 +- python/tvm/dlight/gpu/low_batch_gemv.py | 3 +- python/tvm/dlight/gpu/reduction.py | 5 ++- 8 files changed, 15 insertions(+), 57 deletions(-) rename python/tvm/dlight/{gpu => base}/utils.py (100%) delete mode 100644 python/tvm/dlight/cpu/utils.py diff --git a/python/tvm/dlight/base/__init__.py b/python/tvm/dlight/base/__init__.py index 1e47edf3800a..9d90c4f8e171 100644 --- a/python/tvm/dlight/base/__init__.py +++ b/python/tvm/dlight/base/__init__.py @@ -18,3 +18,10 @@ from .common_schedules import try_inline, try_inline_contiguous_spatial from .schedule_rule import ScheduleRule from .transform import ApplyDefaultSchedule +from .utils import ( + auto_vectorize, + get_bytes, + get_extent, + max_threads_per_block, + suggest_threads_per_block, +) diff --git a/python/tvm/dlight/gpu/utils.py b/python/tvm/dlight/base/utils.py similarity index 100% rename from python/tvm/dlight/gpu/utils.py rename to python/tvm/dlight/base/utils.py diff --git a/python/tvm/dlight/cpu/gemv.py b/python/tvm/dlight/cpu/gemv.py index 1599cdb1577f..15b47de919a7 100644 --- a/python/tvm/dlight/cpu/gemv.py +++ b/python/tvm/dlight/cpu/gemv.py @@ -20,14 +20,10 @@ from tvm import tir from tvm.target import Target -from ..analysis import ( - BlockInfo, - normalize_prim_func, -) +from ..analysis import BlockInfo, normalize_prim_func from ..analysis.gemv import is_gemv, normalize -from ..base import try_inline_contiguous_spatial +from ..base import get_extent, try_inline_contiguous_spatial from .base import CPUScheduleRule -from .utils import get_extent class GEMV(CPUScheduleRule): @@ -42,7 +38,6 @@ def apply( # pylint: disable=too-many-locals,too-many-branches,too-many-return- if not isinstance(func, tir.PrimFunc) or not self.is_target_available(target): return None sch = tir.Schedule(func) - sch = tir.Schedule(func) block_infos = normalize_prim_func(sch) block_infos = try_inline_contiguous_spatial(sch, block_infos) if block_infos is None: diff --git a/python/tvm/dlight/cpu/utils.py b/python/tvm/dlight/cpu/utils.py deleted file mode 100644 index 478baa89bf02..000000000000 --- a/python/tvm/dlight/cpu/utils.py +++ /dev/null @@ -1,41 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. 
The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=missing-docstring -"""Utility methods for generic CPU.""" -from typing import Union - -from tvm import DataType, tir - - -def get_bytes(dtype: Union[DataType, str]) -> int: - if isinstance(dtype, str): - dtype = DataType(dtype) - return dtype.itemsize() - - -def get_extent(sch: tir.Schedule, loop_rv: tir.schedule.LoopRV): - loop: tir.For = sch.get(loop_rv) - return loop.extent.value if isinstance(loop.extent, tir.IntImm) else loop.extent - - -def auto_vectorize(sch: tir.Schedule, loop: tir.schedule.LoopRV, max_vec: int): - """Auto vectorize the loop.""" - extent = get_extent(sch, loop) - if not isinstance(extent, int): - return - v = loop if extent <= max_vec else sch.split(loop, factors=[None, max_vec])[-1] - sch.vectorize(v) diff --git a/python/tvm/dlight/gpu/fallback.py b/python/tvm/dlight/gpu/fallback.py index 60fc373e8c4d..bcbfda791fb3 100644 --- a/python/tvm/dlight/gpu/fallback.py +++ b/python/tvm/dlight/gpu/fallback.py @@ -21,9 +21,9 @@ from tvm import tir from tvm.target import Target +from .. import base from ..analysis import normalize_prim_func from ..base import try_inline -from . import utils from .base import GPUScheduleRule @@ -41,7 +41,7 @@ def apply( # pylint: disable=too-many-locals,missing-docstring ) -> tir.Schedule: if not isinstance(func, tir.PrimFunc) or not self.is_target_available(target): return None - max_threads_per_block = utils.max_threads_per_block(target) + max_threads_per_block = base.max_threads_per_block(target) sch = tir.Schedule(func) block_infos = normalize_prim_func(sch) diff --git a/python/tvm/dlight/gpu/gemv.py b/python/tvm/dlight/gpu/gemv.py index 32546af6a89e..ebb19ad72c3a 100644 --- a/python/tvm/dlight/gpu/gemv.py +++ b/python/tvm/dlight/gpu/gemv.py @@ -28,9 +28,8 @@ normalize, normalize_prim_func, ) -from ..base import try_inline_contiguous_spatial +from ..base import auto_vectorize, get_bytes, get_extent, try_inline_contiguous_spatial from .base import GPUScheduleRule -from .utils import auto_vectorize, get_bytes, get_extent class GEMV(GPUScheduleRule): diff --git a/python/tvm/dlight/gpu/low_batch_gemv.py b/python/tvm/dlight/gpu/low_batch_gemv.py index f2dfa3f50e55..f5e3669ad0f3 100644 --- a/python/tvm/dlight/gpu/low_batch_gemv.py +++ b/python/tvm/dlight/gpu/low_batch_gemv.py @@ -28,9 +28,8 @@ is_broadcast_epilogue, normalize_prim_func, ) -from ..base import try_inline_contiguous_spatial +from ..base import auto_vectorize, get_bytes, get_extent, try_inline_contiguous_spatial from .base import GPUScheduleRule -from .utils import auto_vectorize, get_bytes, get_extent def _get_reduction_expr(block: tir.Block) -> Optional[tir.PrimExpr]: diff --git a/python/tvm/dlight/gpu/reduction.py b/python/tvm/dlight/gpu/reduction.py index 4bc0d2b3efa7..9851bb9800fa 100644 --- a/python/tvm/dlight/gpu/reduction.py +++ b/python/tvm/dlight/gpu/reduction.py @@ -27,8 +27,7 @@ is_broadcast_epilogue, normalize_prim_func, ) -from ..base import try_inline_contiguous_spatial -from . 
import utils +from ..base import suggest_threads_per_block, try_inline_contiguous_spatial from .base import GPUScheduleRule @@ -181,7 +180,7 @@ def _sch_inner_reduction( # pylint: disable=too-many-arguments ): # pylint: disable=invalid-name _, r, _ = sch.get_loops(block) - (len_tx,) = utils.suggest_threads_per_block( # pylint: disable=unbalanced-tuple-unpacking + (len_tx,) = suggest_threads_per_block( # pylint: disable=unbalanced-tuple-unpacking target, [sch.get(r)] )
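
Note for reviewers (illustrative, not part of the patches): after this series, the GEMV pattern helpers live in `tvm.dlight.analysis` (`is_gemv`, `normalize`) and the shared scheduling utilities (`get_extent`, `auto_vectorize`, `suggest_threads_per_block`, ...) live in `tvm.dlight.base`. Below is a minimal sketch of driving the new CPU rule end to end. The toy `gemv` PrimFunc and the 4096 sizes are invented for illustration and are not taken from the PR; only the `ApplyDefaultSchedule(dl.cpu.GEMV())` usage mirrors the test earlier in the series.

import tvm
from tvm import dlight as dl
from tvm.script import tir as T
from tvm.target import Target


@T.prim_func
def gemv(
    A: T.Buffer((1, 4096), "float32"),     # activation vector (batch of 1)
    W: T.Buffer((4096, 4096), "float32"),  # weight matrix
    C: T.Buffer((1, 4096), "float32"),     # output vector
):
    # Spatial iters i, j and a single reduction iter k: the shape of block
    # that the relocated is_gemv()/normalize() helpers are meant to detect.
    for i, j, k in T.grid(1, 4096, 4096):
        with T.block("gemv"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            with T.init():
                C[vi, vj] = T.float32(0)
            C[vi, vj] = C[vi, vj] + A[vi, vk] * W[vj, vk]


mod = tvm.IRModule({"main": gemv})
with Target("llvm"):
    # Applies the CPU GEMV rule added in this series (loop tiling,
    # vectorization and parallel execution, per the PR description).
    mod = dl.ApplyDefaultSchedule(dl.cpu.GEMV())(mod)
mod.show()

Unlike the structural-equality test above, which checks that non-GEMV functions are skipped, this sketch exercises the scheduling path on an actual reduction block.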