From 7bbb25362e89055e9e63b5af403d1af3b9477e18 Mon Sep 17 00:00:00 2001
From: Celve
Date: Wed, 3 Jan 2024 12:24:29 +0000
Subject: [PATCH 1/6] [Unity][DLight] Introduce Specific Rule for RMSNorm

---
 include/tvm/topi/nn/rms_norm.h          |  24 +-
 python/tvm/dlight/gpu/__init__.py       |   1 +
 python/tvm/dlight/gpu/rmsnorm.py        | 141 ++++++++++++
 tests/python/dlight/test_gpu_rmsnorm.py | 277 ++++++++++++++++++++++++
 4 files changed, 440 insertions(+), 3 deletions(-)
 create mode 100644 python/tvm/dlight/gpu/rmsnorm.py
 create mode 100644 tests/python/dlight/test_gpu_rmsnorm.py

diff --git a/include/tvm/topi/nn/rms_norm.h b/include/tvm/topi/nn/rms_norm.h
index ba2f7e49ac98..7e95000f1ee2 100644
--- a/include/tvm/topi/nn/rms_norm.h
+++ b/include/tvm/topi/nn/rms_norm.h
@@ -67,6 +67,25 @@ inline Tensor rms_norm(const Tensor& data, const Tensor& weight, const Array<Integer>& axis,
   for (int i : real_axis) {
     reduce_extent *= data_fp32->shape[i];
   }
+  auto rsqrt_func = [&](const Array<Var>& indices) {
+    Array<PrimExpr> non_reduce_indices;
+    for (int i = 0, n = static_cast<int>(indices.size()); i < n; ++i) {
+      if (std::find(real_axis.begin(), real_axis.end(), i) == real_axis.end()) {
+        non_reduce_indices.push_back(indices[i]);
+      }
+    }
+    auto output =
+        tvm::rsqrt(square_sum(non_reduce_indices) / reduce_extent + make_const(data_type, epsilon));
+    return output;
+  };
+  auto rsqrt_shape = Array<PrimExpr>();
+  for (int i = 0, n = static_cast<int>(data_fp32->shape.size()); i < n; ++i) {
+    if (std::find(real_axis.begin(), real_axis.end(), i) == real_axis.end()) {
+      rsqrt_shape.push_back(data_fp32->shape[i]);
+    }
+  }
+  auto rsqrt = tvm::te::compute(rsqrt_shape, rsqrt_func, "rsqrt", tag);
+
   auto rms_norm_func = [&](const Array<Var>& indices) {
     Array<PrimExpr> reduce_indices, non_reduce_indices;
     for (int i = 0, n = static_cast<int>(indices.size()); i < n; ++i) {
@@ -76,12 +95,11 @@ inline Tensor rms_norm(const Tensor& data, const Tensor& weight, const Array<Integer>& axis,
         non_reduce_indices.push_back(indices[i]);
       }
     }
-    auto output =
-        data_fp32(indices) * weight_fp32(non_reduce_indices) *
-        tvm::rsqrt(square_sum(non_reduce_indices) / reduce_extent + make_const(data_type, epsilon));
+    auto output =
+        rsqrt(non_reduce_indices) * data_fp32(indices) * weight_fp32(non_reduce_indices);
     return output;
   };
   auto rms_norm = tvm::te::compute(data_fp32->shape, rms_norm_func, name, tag);
-  return cast(rms_norm, data->dtype);
+  return cast(rms_norm, data_type);
 }

diff --git a/python/tvm/dlight/gpu/__init__.py b/python/tvm/dlight/gpu/__init__.py
index f48bdb2c8182..7db383a161cd 100644
--- a/python/tvm/dlight/gpu/__init__.py
+++ b/python/tvm/dlight/gpu/__init__.py
@@ -24,3 +24,4 @@
 from .reduction import Reduction
 from .transpose import Transpose
 from .general_reduction import GeneralReduction
+from .rmsnorm import RMSNorm
diff --git a/python/tvm/dlight/gpu/rmsnorm.py b/python/tvm/dlight/gpu/rmsnorm.py
new file mode 100644
index 000000000000..e1e056843095
--- /dev/null
+++ b/python/tvm/dlight/gpu/rmsnorm.py
@@ -0,0 +1,141 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring +"""A RMS norm schedule rule for GPU operators.""" + +import tvm +from tvm import tir +from tvm.tir import Block, BufferStore +from tvm.tir.expr import Cast, BufferLoad +from tvm.target import Target + +from ..base import ScheduleRule + + +def identify_cast_or_load_block(block: Block) -> bool: + if len(block.reads) != 1 or len(block.writes) != 1: + return False + + if not isinstance(block.body, BufferStore): + return False + store = block.body + + # check types + if isinstance(store.value, BufferLoad): + load = store.value + elif isinstance(store.value, Cast): + load = store.value.value + if not isinstance(load, BufferLoad): + return False + else: + return False + + # check indices + if len(load.indices) != len(store.indices): + return False + + for lhs, rhs in zip(load.indices, store.indices): + if not lhs.same_as(rhs): + return False + + return True + + +def identify_rsqrt_block(block: Block) -> bool: + if len(block.reads) != 1 or len(block.writes) != 1: + return False + try: + op = block.body.value.op + except Exception: + return False + + return op == tvm.ir.op.Op.get("tir.rsqrt") + + +class RMSNorm(ScheduleRule): + """A rule for RMS norm.""" + + def apply( # pylint: disable=too-many-locals,missing-docstring + self, + func: tir.PrimFunc, + target: Target, + _: bool, + ) -> tir.Schedule: + if target.kind.name == "cuda": + tx = 512 + else: + tx = 64 + + sch = tir.Schedule(func) + root = sch.get_block(name="root", func_name="main") + + blocks = sch.get_child_blocks(root) + + if not any([identify_rsqrt_block(sch.get(block)) for block in blocks]): + return None + + read = sch.cache_read(block=blocks[0], read_buffer_index=0, storage_scope="local") + write = sch.cache_write(block=blocks[-1], write_buffer_index=0, storage_scope="local") + + for block in blocks: + if identify_cast_or_load_block(sch.get(block)): + sch.compute_inline(block) + + blocks = sch.get_child_blocks(root) + + read, sqr, redsum, rsqrt, norm, write = blocks + + if not identify_rsqrt_block(sch.get(rsqrt)): + return None + + for name in [read, sqr, redsum, norm]: + sch.transform_block_layout( + block=name, + index_map=lambda v_ax0, v_ax1, v_ax2: ( + v_ax1, + v_ax2, + ), + ) + sch.transform_block_layout(block=rsqrt, index_map=lambda v_ax0, v_ax1: (v_ax1,)) + + block_loop, loops = sch.get_loops(block=read) + thread_loop, repeated_loop, vec_loop = sch.split( + loop=loops, factors=[tx, None, 8], preserve_unit_iters=True + ) + sch.bind(block_loop, thread_axis="blockIdx.x") + sch.bind(thread_loop, thread_axis="threadIdx.x") + sch.vectorize(sch.get_loops(block=read)[-1]) + sch.reverse_compute_at(block=sqr, loop=thread_loop) + sch.reverse_compute_at(block=redsum, loop=thread_loop) + + sch.reverse_compute_at(block=rsqrt, loop=block_loop, index=-1) + sch.reverse_compute_at(block=norm, loop=block_loop, index=-1) + block_loop, loops = sch.get_loops(block=norm) + thread_loop, repeated_loop, vec_loop = sch.split( + loop=loops, factors=[tx, None, 8], preserve_unit_iters=True + ) + sch.bind(thread_loop, thread_axis="threadIdx.x") + + sch.reverse_compute_at(block=write, loop=thread_loop, index=-1) + sch.vectorize(sch.get_loops(block=write)[-1]) + + sch.set_scope(block=sqr, buffer_index=0, storage_scope="local") + sch.set_scope(block=redsum, buffer_index=0, storage_scope="local") + sch.set_scope(block=rsqrt, buffer_index=0, storage_scope="shared") + sch.set_scope(block=norm, buffer_index=0, storage_scope="local") + + return sch diff --git a/tests/python/dlight/test_gpu_rmsnorm.py 
b/tests/python/dlight/test_gpu_rmsnorm.py new file mode 100644 index 000000000000..f128c48c06b3 --- /dev/null +++ b/tests/python/dlight/test_gpu_rmsnorm.py @@ -0,0 +1,277 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring +import tvm.testing + +from tvm.ir import IRModule, assert_structural_equal +from tvm import dlight as dl +from tvm.script import ir as I +from tvm.target import Target +from tvm.script import tir as T + + +def _check(mod_before: IRModule, mod_after: IRModule): + target = Target("nvidia/geforce-rtx-3090-ti") + with target: + mod = dl.ApplyDefaultSchedule( # pylint: disable=not-callable + dl.gpu.RMSNorm(), + )(mod_before) + assert_structural_equal(mod, mod_after) + + +def test_rms_norm_with_casting(): + # fmt: off + @I.ir_module + class Before: + @T.prim_func + def main(var_data: T.handle, weight: T.Buffer((4096,), "float16"), var_T_cast: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int32() + data = T.match_buffer(var_data, (1, n, 4096), "float16") + T_cast = T.match_buffer(var_T_cast, (1, n, 4096), "float16") + # with T.block("root"): + T_cast_1 = T.alloc_buffer((1, n, 4096)) + T_multiply = T.alloc_buffer((1, n, 4096)) + T_multiply_red = T.alloc_buffer((1, n)) + rsqrt = T.alloc_buffer((1, n)) + T_cast_2 = T.alloc_buffer((4096,)) + T_rms_norm = T.alloc_buffer((1, n, 4096)) + for ax0, ax1, ax2 in T.grid(1, n, 4096): + with T.block("T_cast"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(data[v_ax0, v_ax1, v_ax2]) + T.writes(T_cast_1[v_ax0, v_ax1, v_ax2]) + T_cast_1[v_ax0, v_ax1, v_ax2] = T.Cast("float32", data[v_ax0, v_ax1, v_ax2]) + for ax0, ax1, ax2 in T.grid(1, n, 4096): + with T.block("T_multiply"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(T_cast_1[v_ax0, v_ax1, v_ax2]) + T.writes(T_multiply[v_ax0, v_ax1, v_ax2]) + T_multiply[v_ax0, v_ax1, v_ax2] = T_cast_1[v_ax0, v_ax1, v_ax2] * T_cast_1[v_ax0, v_ax1, v_ax2] + for ax0, ax1, k2 in T.grid(1, n, 4096): + with T.block("T_multiply_red"): + v_ax0, v_ax1, v_k2 = T.axis.remap("SSR", [ax0, ax1, k2]) + T.reads(T_multiply[v_ax0, v_ax1, v_k2]) + T.writes(T_multiply_red[v_ax0, v_ax1]) + with T.init(): + T_multiply_red[v_ax0, v_ax1] = T.float32(0) + T_multiply_red[v_ax0, v_ax1] = T_multiply_red[v_ax0, v_ax1] + T_multiply[v_ax0, v_ax1, v_k2] + for ax0, ax1 in T.grid(1, n): + with T.block("rsqrt"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(T_multiply_red[v_ax0, v_ax1]) + T.writes(rsqrt[v_ax0, v_ax1]) + rsqrt[v_ax0, v_ax1] = T.rsqrt(T_multiply_red[v_ax0, v_ax1] * T.float32(0.000244140625) + T.float32(9.9999999999999995e-07)) + for ax0 in range(4096): + with T.block("T_cast_1"): + v_ax0 = T.axis.spatial(4096, ax0) + T.reads(weight[v_ax0]) + T.writes(T_cast_2[v_ax0]) + 
T_cast_2[v_ax0] = T.Cast("float32", weight[v_ax0]) + for ax0, ax1, ax2 in T.grid(1, n, 4096): + with T.block("T_rms_norm"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(rsqrt[v_ax0, v_ax1], T_cast_1[v_ax0, v_ax1, v_ax2], T_cast_2[v_ax2]) + T.writes(T_rms_norm[v_ax0, v_ax1, v_ax2]) + T_rms_norm[v_ax0, v_ax1, v_ax2] = rsqrt[v_ax0, v_ax1] * T_cast_1[v_ax0, v_ax1, v_ax2] * T_cast_2[v_ax2] + for ax0, ax1, ax2 in T.grid(1, n, 4096): + with T.block("T_cast_2"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(T_rms_norm[v_ax0, v_ax1, v_ax2]) + T.writes(T_cast[v_ax0, v_ax1, v_ax2]) + T_cast[v_ax0, v_ax1, v_ax2] = T.Cast("float16", T_rms_norm[v_ax0, v_ax1, v_ax2]) + + @I.ir_module + class After: + @T.prim_func + def main(var_data: T.handle, weight: T.Buffer((4096,), "float16"), var_T_cast: T.handle): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + n = T.int32() + data = T.match_buffer(var_data, (1, n, 4096), "float16") + T_cast = T.match_buffer(var_T_cast, (1, n, 4096), "float16") + # with T.block("root"): + T_multiply_local = T.alloc_buffer((1, n, 4096), scope="local") + T_multiply_red_local = T.alloc_buffer((1, n), scope="local") + rsqrt_shared = T.alloc_buffer((1, n), scope="shared") + T_rms_norm_local = T.alloc_buffer((1, n, 4096), scope="local") + data_local = T.alloc_buffer((1, n, 4096), "float16", scope="local") + for ax0 in T.thread_binding(n, thread="blockIdx.x"): + for ax1_0 in T.thread_binding(512, thread="threadIdx.x"): + for ax1_1 in range(1): + for ax1_2 in T.vectorized(8): + with T.block("data_local"): + v0 = T.axis.spatial(n, ax0) + v1 = T.axis.spatial(4096, ax1_0 * 8 + ax1_1 * 8 + ax1_2) + T.reads(data[0, v0, v1]) + T.writes(data_local[0, v0, v1]) + data_local[0, v0, v1] = data[0, v0, v1] + for ax0_1 in range(8): + with T.block("T_multiply"): + v0 = T.axis.spatial(n, ax0) + v1 = T.axis.spatial(4096, ax1_0 * 8 + ax0_1) + T.reads(data_local[0, v0, v1]) + T.writes(T_multiply_local[0, v0, v1]) + T_multiply_local[0, v0, v1] = T.Cast("float32", data_local[0, v0, v1]) * T.Cast("float32", data_local[0, v0, v1]) + for ax0_1 in range(8): + with T.block("T_multiply_red"): + v0 = T.axis.spatial(n, ax0) + v1 = T.axis.reduce(4096, ax1_0 * 8 + ax0_1) + T.reads(T_multiply_local[0, v0, v1]) + T.writes(T_multiply_red_local[0, v0]) + with T.init(): + T_multiply_red_local[0, v0] = T.float32(0) + T_multiply_red_local[0, v0] = T_multiply_red_local[0, v0] + T_multiply_local[0, v0, v1] + with T.block("rsqrt"): + v0 = T.axis.spatial(n, ax0) + T.reads(T_multiply_red_local[0, v0]) + T.writes(rsqrt_shared[0, v0]) + rsqrt_shared[0, v0] = T.rsqrt(T_multiply_red_local[0, v0] * T.float32(0.000244140625) + T.float32(9.9999999999999995e-07)) + for ax0_0 in T.thread_binding(512, thread="threadIdx.x"): + for ax0_1, ax0_2 in T.grid(1, 8): + with T.block("T_rms_norm"): + v0 = T.axis.spatial(n, ax0) + v1 = T.axis.spatial(4096, ax0_0 * 8 + ax0_1 * 8 + ax0_2) + T.reads(rsqrt_shared[0, v0], data_local[0, v0, v1], weight[v1]) + T.writes(T_rms_norm_local[0, v0, v1]) + T_rms_norm_local[0, v0, v1] = rsqrt_shared[0, v0] * T.Cast("float32", data_local[0, v0, v1]) * T.Cast("float32", weight[v1]) + for ax0_1 in T.vectorized(8): + with T.block("T_cast_local"): + v0 = T.axis.spatial(1, 0) + v1 = T.axis.spatial(n, ax0) + v2 = T.axis.spatial(4096, ax0_0 * 8 + ax0_1) + T.reads(T_rms_norm_local[v0, v1, v2]) + T.writes(T_cast[v0, v1, v2]) + T_cast[v0, v1, v2] = T.Cast("float16", T_rms_norm_local[v0, v1, v2]) + # fmt: on + _check(Before, After) + + +def 
test_rms_norm_without_casting(): + # fmt: off + @I.ir_module + class Before: + @T.prim_func + def main(var_data: T.handle, weight: T.Buffer((4096,), "float32"), var_T_cast: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int32() + data = T.match_buffer(var_data, (1, n, 4096)) + T_cast = T.match_buffer(var_T_cast, (1, n, 4096)) + # with T.block("root"): + T_multiply = T.alloc_buffer((1, n, 4096)) + T_multiply_red = T.alloc_buffer((1, n)) + rsqrt = T.alloc_buffer((1, n)) + T_rms_norm = T.alloc_buffer((1, n, 4096)) + for ax0, ax1, ax2 in T.grid(1, n, 4096): + with T.block("T_multiply"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(data[v_ax0, v_ax1, v_ax2]) + T.writes(T_multiply[v_ax0, v_ax1, v_ax2]) + T_multiply[v_ax0, v_ax1, v_ax2] = data[v_ax0, v_ax1, v_ax2] * data[v_ax0, v_ax1, v_ax2] + for ax0, ax1, k2 in T.grid(1, n, 4096): + with T.block("T_multiply_red"): + v_ax0, v_ax1, v_k2 = T.axis.remap("SSR", [ax0, ax1, k2]) + T.reads(T_multiply[v_ax0, v_ax1, v_k2]) + T.writes(T_multiply_red[v_ax0, v_ax1]) + with T.init(): + T_multiply_red[v_ax0, v_ax1] = T.float32(0) + T_multiply_red[v_ax0, v_ax1] = T_multiply_red[v_ax0, v_ax1] + T_multiply[v_ax0, v_ax1, v_k2] + for ax0, ax1 in T.grid(1, n): + with T.block("rsqrt"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(T_multiply_red[v_ax0, v_ax1]) + T.writes(rsqrt[v_ax0, v_ax1]) + rsqrt[v_ax0, v_ax1] = T.rsqrt(T_multiply_red[v_ax0, v_ax1] * T.float32(0.000244140625) + T.float32(9.9999999999999995e-07)) + for ax0, ax1, ax2 in T.grid(1, n, 4096): + with T.block("T_rms_norm"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(rsqrt[v_ax0, v_ax1], data[v_ax0, v_ax1, v_ax2], weight[v_ax2]) + T.writes(T_rms_norm[v_ax0, v_ax1, v_ax2]) + T_rms_norm[v_ax0, v_ax1, v_ax2] = rsqrt[v_ax0, v_ax1] * data[v_ax0, v_ax1, v_ax2] * weight[v_ax2] + for ax0, ax1, ax2 in T.grid(1, n, 4096): + with T.block("T_cast_2"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(T_rms_norm[v_ax0, v_ax1, v_ax2]) + T.writes(T_cast[v_ax0, v_ax1, v_ax2]) + T_cast[v_ax0, v_ax1, v_ax2] = T_rms_norm[v_ax0, v_ax1, v_ax2] + + @I.ir_module + class After: + @T.prim_func + def main(var_data: T.handle, weight: T.Buffer((4096,), "float32"), var_T_cast: T.handle): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + n = T.int32() + data = T.match_buffer(var_data, (1, n, 4096)) + T_cast = T.match_buffer(var_T_cast, (1, n, 4096)) + # with T.block("root"): + T_multiply_local = T.alloc_buffer((1, n, 4096), scope="local") + T_multiply_red_local = T.alloc_buffer((1, n), scope="local") + rsqrt_shared = T.alloc_buffer((1, n), scope="shared") + T_rms_norm_local = T.alloc_buffer((1, n, 4096), scope="local") + data_local = T.alloc_buffer((1, n, 4096), scope="local") + for ax0 in T.thread_binding(n, thread="blockIdx.x"): + for ax1_0 in T.thread_binding(512, thread="threadIdx.x"): + for ax1_1 in range(1): + for ax1_2 in T.vectorized(8): + with T.block("data_local"): + v0 = T.axis.spatial(n, ax0) + v1 = T.axis.spatial(4096, ax1_0 * 8 + ax1_1 * 8 + ax1_2) + T.reads(data[0, v0, v1]) + T.writes(data_local[0, v0, v1]) + data_local[0, v0, v1] = data[0, v0, v1] + for ax0_1 in range(8): + with T.block("T_multiply"): + v0 = T.axis.spatial(n, ax0) + v1 = T.axis.spatial(4096, ax1_0 * 8 + ax0_1) + T.reads(data_local[0, v0, v1]) + T.writes(T_multiply_local[0, v0, v1]) + T_multiply_local[0, v0, v1] = data_local[0, v0, v1] * data_local[0, v0, v1] + for ax0_1 in range(8): + with T.block("T_multiply_red"): + v0 = 
T.axis.spatial(n, ax0) + v1 = T.axis.reduce(4096, ax1_0 * 8 + ax0_1) + T.reads(T_multiply_local[0, v0, v1]) + T.writes(T_multiply_red_local[0, v0]) + with T.init(): + T_multiply_red_local[0, v0] = T.float32(0) + T_multiply_red_local[0, v0] = T_multiply_red_local[0, v0] + T_multiply_local[0, v0, v1] + with T.block("rsqrt"): + v0 = T.axis.spatial(n, ax0) + T.reads(T_multiply_red_local[0, v0]) + T.writes(rsqrt_shared[0, v0]) + rsqrt_shared[0, v0] = T.rsqrt(T_multiply_red_local[0, v0] * T.float32(0.000244140625) + T.float32(9.9999999999999995e-07)) + for ax0_0 in T.thread_binding(512, thread="threadIdx.x"): + for ax0_1, ax0_2 in T.grid(1, 8): + with T.block("T_rms_norm"): + v0 = T.axis.spatial(n, ax0) + v1 = T.axis.spatial(4096, ax0_0 * 8 + ax0_1 * 8 + ax0_2) + T.reads(rsqrt_shared[0, v0], data_local[0, v0, v1], weight[v1]) + T.writes(T_rms_norm_local[0, v0, v1]) + T_rms_norm_local[0, v0, v1] = rsqrt_shared[0, v0] * data_local[0, v0, v1] * weight[v1] + for ax0_1 in T.vectorized(8): + with T.block("T_cast_local"): + v0 = T.axis.spatial(1, 0) + v1 = T.axis.spatial(n, ax0) + v2 = T.axis.spatial(4096, ax0_0 * 8 + ax0_1) + T.reads(T_rms_norm_local[v0, v1, v2]) + T.writes(T_cast[v0, v1, v2]) + T_cast[v0, v1, v2] = T_rms_norm_local[v0, v1, v2] + # fmt: on + _check(Before, After) + + +if __name__ == "__main__": + tvm.testing.main() From 84f407a73d3b3d9afd8dac24be4c2d6f004bc270 Mon Sep 17 00:00:00 2001 From: Celve Date: Thu, 4 Jan 2024 07:20:44 +0000 Subject: [PATCH 2/6] fix: remove unused variables --- python/tvm/dlight/gpu/rmsnorm.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/python/tvm/dlight/gpu/rmsnorm.py b/python/tvm/dlight/gpu/rmsnorm.py index e1e056843095..6295bcb2dea6 100644 --- a/python/tvm/dlight/gpu/rmsnorm.py +++ b/python/tvm/dlight/gpu/rmsnorm.py @@ -113,9 +113,7 @@ def apply( # pylint: disable=too-many-locals,missing-docstring sch.transform_block_layout(block=rsqrt, index_map=lambda v_ax0, v_ax1: (v_ax1,)) block_loop, loops = sch.get_loops(block=read) - thread_loop, repeated_loop, vec_loop = sch.split( - loop=loops, factors=[tx, None, 8], preserve_unit_iters=True - ) + thread_loop, _, _ = sch.split(loop=loops, factors=[tx, None, 8], preserve_unit_iters=True) sch.bind(block_loop, thread_axis="blockIdx.x") sch.bind(thread_loop, thread_axis="threadIdx.x") sch.vectorize(sch.get_loops(block=read)[-1]) @@ -125,9 +123,7 @@ def apply( # pylint: disable=too-many-locals,missing-docstring sch.reverse_compute_at(block=rsqrt, loop=block_loop, index=-1) sch.reverse_compute_at(block=norm, loop=block_loop, index=-1) block_loop, loops = sch.get_loops(block=norm) - thread_loop, repeated_loop, vec_loop = sch.split( - loop=loops, factors=[tx, None, 8], preserve_unit_iters=True - ) + thread_loop, _, _ = sch.split(loop=loops, factors=[tx, None, 8], preserve_unit_iters=True) sch.bind(thread_loop, thread_axis="threadIdx.x") sch.reverse_compute_at(block=write, loop=thread_loop, index=-1) From 802d97a949be298740d4a6e2ca3f930197cecca7 Mon Sep 17 00:00:00 2001 From: Celve Date: Thu, 4 Jan 2024 08:18:02 +0000 Subject: [PATCH 3/6] fix: rename invalid variables --- python/tvm/dlight/gpu/rmsnorm.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/python/tvm/dlight/gpu/rmsnorm.py b/python/tvm/dlight/gpu/rmsnorm.py index 6295bcb2dea6..7bd49e1a24c5 100644 --- a/python/tvm/dlight/gpu/rmsnorm.py +++ b/python/tvm/dlight/gpu/rmsnorm.py @@ -76,9 +76,9 @@ def apply( # pylint: disable=too-many-locals,missing-docstring _: bool, ) -> tir.Schedule: if target.kind.name == 
"cuda": - tx = 512 + num_tx = 512 else: - tx = 64 + num_tx = 64 sch = tir.Schedule(func) root = sch.get_block(name="root", func_name="main") @@ -113,7 +113,9 @@ def apply( # pylint: disable=too-many-locals,missing-docstring sch.transform_block_layout(block=rsqrt, index_map=lambda v_ax0, v_ax1: (v_ax1,)) block_loop, loops = sch.get_loops(block=read) - thread_loop, _, _ = sch.split(loop=loops, factors=[tx, None, 8], preserve_unit_iters=True) + thread_loop, _, _ = sch.split( + loop=loops, factors=[num_tx, None, 8], preserve_unit_iters=True + ) sch.bind(block_loop, thread_axis="blockIdx.x") sch.bind(thread_loop, thread_axis="threadIdx.x") sch.vectorize(sch.get_loops(block=read)[-1]) @@ -123,7 +125,9 @@ def apply( # pylint: disable=too-many-locals,missing-docstring sch.reverse_compute_at(block=rsqrt, loop=block_loop, index=-1) sch.reverse_compute_at(block=norm, loop=block_loop, index=-1) block_loop, loops = sch.get_loops(block=norm) - thread_loop, _, _ = sch.split(loop=loops, factors=[tx, None, 8], preserve_unit_iters=True) + thread_loop, _, _ = sch.split( + loop=loops, factors=[num_tx, None, 8], preserve_unit_iters=True + ) sch.bind(thread_loop, thread_axis="threadIdx.x") sch.reverse_compute_at(block=write, loop=thread_loop, index=-1) From 4bedd836574351efb89d2d49f2097599a636cc1b Mon Sep 17 00:00:00 2001 From: Celve Date: Thu, 4 Jan 2024 08:43:25 +0000 Subject: [PATCH 4/6] fix: deal with too general exception --- python/tvm/dlight/gpu/rmsnorm.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/python/tvm/dlight/gpu/rmsnorm.py b/python/tvm/dlight/gpu/rmsnorm.py index 7bd49e1a24c5..4f6960f3aef4 100644 --- a/python/tvm/dlight/gpu/rmsnorm.py +++ b/python/tvm/dlight/gpu/rmsnorm.py @@ -20,7 +20,7 @@ import tvm from tvm import tir from tvm.tir import Block, BufferStore -from tvm.tir.expr import Cast, BufferLoad +from tvm.tir.expr import Cast, BufferLoad, Call from tvm.target import Target from ..base import ScheduleRule @@ -58,10 +58,15 @@ def identify_cast_or_load_block(block: Block) -> bool: def identify_rsqrt_block(block: Block) -> bool: if len(block.reads) != 1 or len(block.writes) != 1: return False - try: - op = block.body.value.op - except Exception: + + if not isinstance(block.body, BufferStore): + return False + store = block.body + + if not isinstance(store.value, Call): return False + call = store.value + op = call.op return op == tvm.ir.op.Op.get("tir.rsqrt") From 0f030cd0c44d1c64d8285160994264f57235904e Mon Sep 17 00:00:00 2001 From: Celve Date: Fri, 5 Jan 2024 05:39:32 +0000 Subject: [PATCH 5/6] fix: update tests --- .../relax/test_transform_legalize_ops_nn.py | 100 +++++++++++------- 1 file changed, 64 insertions(+), 36 deletions(-) diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py b/tests/python/relax/test_transform_legalize_ops_nn.py index 74da77f7d8c5..07fbc3419b98 100644 --- a/tests/python/relax/test_transform_legalize_ops_nn.py +++ b/tests/python/relax/test_transform_legalize_ops_nn.py @@ -2773,9 +2773,10 @@ def rms_norm(A: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "floa T.func_attr({"tir.noalias": T.bool(True)}) # with T.block("root"): T_cast_1 = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5))) - T_cast_2 = T.alloc_buffer((T.int64(4), T.int64(5))) T_multiply = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5))) T_multiply_red = T.alloc_buffer((T.int64(2), T.int64(3))) + rsqrt = T.alloc_buffer((T.int64(2), T.int64(3))) + T_cast_2 = T.alloc_buffer((T.int64(4), T.int64(5))) T_rms_norm = 
T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5))) for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): with T.block("T_cast"): @@ -2783,12 +2784,6 @@ def rms_norm(A: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "floa T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3]) T.writes(T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3]) T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax0, v_ax1, v_ax2, v_ax3] - for ax0, ax1 in T.grid(T.int64(4), T.int64(5)): - with T.block("T_cast_1"): - v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) - T.reads(B[v_ax0, v_ax1]) - T.writes(T_cast_2[v_ax0, v_ax1]) - T_cast_2[v_ax0, v_ax1] = B[v_ax0, v_ax1] for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): with T.block("T_multiply"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) @@ -2803,12 +2798,24 @@ def rms_norm(A: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "floa with T.init(): T_multiply_red[v_ax0, v_ax1] = T.float32(0) T_multiply_red[v_ax0, v_ax1] = T_multiply_red[v_ax0, v_ax1] + T_multiply[v_ax0, v_ax1, v_k2, v_k3] + for ax0, ax1 in T.grid(T.int64(2), T.int64(3)): + with T.block("rsqrt"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(T_multiply_red[v_ax0, v_ax1]) + T.writes(rsqrt[v_ax0, v_ax1]) + rsqrt[v_ax0, v_ax1] = T.rsqrt(T_multiply_red[v_ax0, v_ax1] * T.float32(0.050000000000000003) + T.float32(1.0000000000000001e-05)) + for ax0, ax1 in T.grid(T.int64(4), T.int64(5)): + with T.block("T_cast_1"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(B[v_ax0, v_ax1]) + T.writes(T_cast_2[v_ax0, v_ax1]) + T_cast_2[v_ax0, v_ax1] = B[v_ax0, v_ax1] for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): with T.block("T_rms_norm"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3], T_cast_2[v_ax2, v_ax3], T_multiply_red[v_ax0, v_ax1]) + T.reads(rsqrt[v_ax0, v_ax1], T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3], T_cast_2[v_ax2, v_ax3]) T.writes(T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3]) - T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3] = T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3] * T_cast_2[v_ax2, v_ax3] * T.rsqrt(T_multiply_red[v_ax0, v_ax1] * T.float32(0.050000000000000003) + T.float32(1.0000000000000001e-05)) + T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3] = rsqrt[v_ax0, v_ax1] * T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3] * T_cast_2[v_ax2, v_ax3] for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): with T.block("T_cast_2"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) @@ -2842,9 +2849,10 @@ def rms_norm(A: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "floa T.func_attr({"tir.noalias": T.bool(True)}) # with T.block("root"): T_cast_1 = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5))) - T_cast_2 = T.alloc_buffer((T.int64(4), T.int64(5))) T_multiply = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5))) T_multiply_red = T.alloc_buffer((T.int64(2), T.int64(3))) + rsqrt = T.alloc_buffer((T.int64(2), T.int64(3))) + T_cast_2 = T.alloc_buffer((T.int64(4), T.int64(5))) T_rms_norm = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5))) for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): with T.block("T_cast"): @@ -2852,12 +2860,6 @@ def rms_norm(A: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "floa T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3]) T.writes(T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3]) T_cast_1[v_ax0, v_ax1, 
v_ax2, v_ax3] = T.Cast("float32", A[v_ax0, v_ax1, v_ax2, v_ax3]) - for ax0, ax1 in T.grid(T.int64(4), T.int64(5)): - with T.block("T_cast_1"): - v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) - T.reads(B[v_ax0, v_ax1]) - T.writes(T_cast_2[v_ax0, v_ax1]) - T_cast_2[v_ax0, v_ax1] = T.Cast("float32", B[v_ax0, v_ax1]) for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): with T.block("T_multiply"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) @@ -2872,12 +2874,24 @@ def rms_norm(A: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "floa with T.init(): T_multiply_red[v_ax0, v_ax1] = T.float32(0) T_multiply_red[v_ax0, v_ax1] = T_multiply_red[v_ax0, v_ax1] + T_multiply[v_ax0, v_ax1, v_k2, v_k3] + for ax0, ax1 in T.grid(T.int64(2), T.int64(3)): + with T.block("rsqrt"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(T_multiply_red[v_ax0, v_ax1]) + T.writes(rsqrt[v_ax0, v_ax1]) + rsqrt[v_ax0, v_ax1] = T.rsqrt(T_multiply_red[v_ax0, v_ax1] * T.float32(0.050000000000000003) + T.float32(1.0000000000000001e-05)) + for ax0, ax1 in T.grid(T.int64(4), T.int64(5)): + with T.block("T_cast_1"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(B[v_ax0, v_ax1]) + T.writes(T_cast_2[v_ax0, v_ax1]) + T_cast_2[v_ax0, v_ax1] = T.Cast("float32", B[v_ax0, v_ax1]) for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): with T.block("T_rms_norm"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3], T_cast_2[v_ax2, v_ax3], T_multiply_red[v_ax0, v_ax1]) + T.reads(rsqrt[v_ax0, v_ax1], T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3], T_cast_2[v_ax2, v_ax3]) T.writes(T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3]) - T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3] = T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3] * T_cast_2[v_ax2, v_ax3] * T.rsqrt(T_multiply_red[v_ax0, v_ax1] * T.float32(0.050000000000000003) + T.float32(1.0000000000000001e-05)) + T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3] = rsqrt[v_ax0, v_ax1] * T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3] * T_cast_2[v_ax2, v_ax3] for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): with T.block("T_cast_2"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) @@ -2918,9 +2932,10 @@ def rms_norm(var_A: T.handle, var_B: T.handle, var_T_cast: T.handle): T_cast = T.match_buffer(var_T_cast, (n, s, f)) # with T.block("root"): T_cast_1 = T.alloc_buffer((n, s, f)) - T_cast_2 = T.alloc_buffer((s, f)) T_multiply = T.alloc_buffer((n, s, f)) T_multiply_red = T.alloc_buffer((n,)) + rsqrt = T.alloc_buffer((n,)) + T_cast_2 = T.alloc_buffer((s, f)) T_rms_norm = T.alloc_buffer((n, s, f)) for ax0, ax1, ax2 in T.grid(n, s, f): with T.block("T_cast"): @@ -2928,12 +2943,6 @@ def rms_norm(var_A: T.handle, var_B: T.handle, var_T_cast: T.handle): T.reads(A[v_ax0, v_ax1, v_ax2]) T.writes(T_cast_1[v_ax0, v_ax1, v_ax2]) T_cast_1[v_ax0, v_ax1, v_ax2] = A[v_ax0, v_ax1, v_ax2] - for ax0, ax1 in T.grid(s, f): - with T.block("T_cast_1"): - v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) - T.reads(B[v_ax0, v_ax1]) - T.writes(T_cast_2[v_ax0, v_ax1]) - T_cast_2[v_ax0, v_ax1] = B[v_ax0, v_ax1] for ax0, ax1, ax2 in T.grid(n, s, f): with T.block("T_multiply"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) @@ -2948,12 +2957,24 @@ def rms_norm(var_A: T.handle, var_B: T.handle, var_T_cast: T.handle): with T.init(): T_multiply_red[v_ax0] = T.float32(0) T_multiply_red[v_ax0] = T_multiply_red[v_ax0] + T_multiply[v_ax0, v_k1, v_k2] + for 
ax0 in range(n): + with T.block("rsqrt"): + v_ax0 = T.axis.spatial(n, ax0) + T.reads(T_multiply_red[v_ax0]) + T.writes(rsqrt[v_ax0]) + rsqrt[v_ax0] = T.rsqrt(T_multiply_red[v_ax0] / (T.Cast("float32", s) * T.Cast("float32", f)) + T.float32(1.0000000000000001e-05)) + for ax0, ax1 in T.grid(s, f): + with T.block("T_cast_1"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(B[v_ax0, v_ax1]) + T.writes(T_cast_2[v_ax0, v_ax1]) + T_cast_2[v_ax0, v_ax1] = B[v_ax0, v_ax1] for ax0, ax1, ax2 in T.grid(n, s, f): with T.block("T_rms_norm"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(T_cast_1[v_ax0, v_ax1, v_ax2], T_cast_2[v_ax1, v_ax2], T_multiply_red[v_ax0]) + T.reads(rsqrt[v_ax0], T_cast_1[v_ax0, v_ax1, v_ax2], T_cast_2[v_ax1, v_ax2]) T.writes(T_rms_norm[v_ax0, v_ax1, v_ax2]) - T_rms_norm[v_ax0, v_ax1, v_ax2] = T_cast_1[v_ax0, v_ax1, v_ax2] * T_cast_2[v_ax1, v_ax2] * T.rsqrt(T_multiply_red[v_ax0] / (T.Cast("float32", s) * T.Cast("float32", f)) + T.float32(1.0000000000000001e-05)) + T_rms_norm[v_ax0, v_ax1, v_ax2] = rsqrt[v_ax0] * T_cast_1[v_ax0, v_ax1, v_ax2] * T_cast_2[v_ax1, v_ax2] for ax0, ax1, ax2 in T.grid(n, s, f): with T.block("T_cast_2"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) @@ -2990,9 +3011,10 @@ def rms_norm(A: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "floa T.func_attr({"tir.noalias": T.bool(True)}) # with T.block("root"): T_cast_1 = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5))) - T_cast_2 = T.alloc_buffer((T.int64(4), T.int64(5))) T_multiply = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5))) T_multiply_red = T.alloc_buffer((T.int64(2), T.int64(3))) + rsqrt = T.alloc_buffer((T.int64(2), T.int64(3))) + T_cast_2 = T.alloc_buffer((T.int64(4), T.int64(5))) T_rms_norm = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5))) for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): with T.block("T_cast"): @@ -3000,12 +3022,6 @@ def rms_norm(A: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "floa T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3]) T.writes(T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3]) T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax0, v_ax1, v_ax2, v_ax3] - for ax0, ax1 in T.grid(T.int64(4), T.int64(5)): - with T.block("T_cast_1"): - v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) - T.reads(B[v_ax0, v_ax1]) - T.writes(T_cast_2[v_ax0, v_ax1]) - T_cast_2[v_ax0, v_ax1] = B[v_ax0, v_ax1] for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): with T.block("T_multiply"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) @@ -3020,12 +3036,24 @@ def rms_norm(A: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "floa with T.init(): T_multiply_red[v_ax0, v_ax1] = T.float32(0) T_multiply_red[v_ax0, v_ax1] = T_multiply_red[v_ax0, v_ax1] + T_multiply[v_ax0, v_ax1, v_k2, v_k3] + for ax0, ax1 in T.grid(T.int64(2), T.int64(3)): + with T.block("rsqrt"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(T_multiply_red[v_ax0, v_ax1]) + T.writes(rsqrt[v_ax0, v_ax1]) + rsqrt[v_ax0, v_ax1] = T.rsqrt(T_multiply_red[v_ax0, v_ax1] * T.float32(0.050000000000000003) + T.float32(1.0000000000000001e-05)) + for ax0, ax1 in T.grid(T.int64(4), T.int64(5)): + with T.block("T_cast_1"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(B[v_ax0, v_ax1]) + T.writes(T_cast_2[v_ax0, v_ax1]) + T_cast_2[v_ax0, v_ax1] = B[v_ax0, v_ax1] for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): with 
T.block("T_rms_norm"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3], T_cast_2[v_ax2, v_ax3], T_multiply_red[v_ax0, v_ax1]) + T.reads(rsqrt[v_ax0, v_ax1], T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3], T_cast_2[v_ax2, v_ax3]) T.writes(T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3]) - T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3] = T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3] * T_cast_2[v_ax2, v_ax3] * T.rsqrt(T_multiply_red[v_ax0, v_ax1] * T.float32(0.050000000000000003) + T.float32(1.0000000000000001e-05)) + T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3] = rsqrt[v_ax0, v_ax1] * T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3] * T_cast_2[v_ax2, v_ax3] for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): with T.block("T_cast_2"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) From 84ba0065b8dda6fa618f12eee7b48a7705a47a32 Mon Sep 17 00:00:00 2001 From: Celve Date: Fri, 5 Jan 2024 16:43:26 +0000 Subject: [PATCH 6/6] feat: make rule more general --- python/tvm/dlight/gpu/rmsnorm.py | 12 +- tests/python/dlight/test_gpu_rmsnorm.py | 146 +++++++++++++----------- 2 files changed, 81 insertions(+), 77 deletions(-) diff --git a/python/tvm/dlight/gpu/rmsnorm.py b/python/tvm/dlight/gpu/rmsnorm.py index 4f6960f3aef4..f8b2bb4a172d 100644 --- a/python/tvm/dlight/gpu/rmsnorm.py +++ b/python/tvm/dlight/gpu/rmsnorm.py @@ -107,15 +107,9 @@ def apply( # pylint: disable=too-many-locals,missing-docstring if not identify_rsqrt_block(sch.get(rsqrt)): return None - for name in [read, sqr, redsum, norm]: - sch.transform_block_layout( - block=name, - index_map=lambda v_ax0, v_ax1, v_ax2: ( - v_ax1, - v_ax2, - ), - ) - sch.transform_block_layout(block=rsqrt, index_map=lambda v_ax0, v_ax1: (v_ax1,)) + for name in [read, sqr, redsum, rsqrt, norm, write]: + loops = sch.get_loops(name) + sch.fuse(*loops[:-1]) block_loop, loops = sch.get_loops(block=read) thread_loop, _, _ = sch.split( diff --git a/tests/python/dlight/test_gpu_rmsnorm.py b/tests/python/dlight/test_gpu_rmsnorm.py index f128c48c06b3..301dac5c66ac 100644 --- a/tests/python/dlight/test_gpu_rmsnorm.py +++ b/tests/python/dlight/test_gpu_rmsnorm.py @@ -109,50 +109,55 @@ def main(var_data: T.handle, weight: T.Buffer((4096,), "float16"), var_T_cast: T rsqrt_shared = T.alloc_buffer((1, n), scope="shared") T_rms_norm_local = T.alloc_buffer((1, n, 4096), scope="local") data_local = T.alloc_buffer((1, n, 4096), "float16", scope="local") - for ax0 in T.thread_binding(n, thread="blockIdx.x"): - for ax1_0 in T.thread_binding(512, thread="threadIdx.x"): - for ax1_1 in range(1): - for ax1_2 in T.vectorized(8): + for ax0_ax1_fused in T.thread_binding(n, thread="blockIdx.x"): + for ax2_0 in T.thread_binding(512, thread="threadIdx.x"): + for ax2_1 in range(1): + for ax2_2 in T.vectorized(8): with T.block("data_local"): - v0 = T.axis.spatial(n, ax0) - v1 = T.axis.spatial(4096, ax1_0 * 8 + ax1_1 * 8 + ax1_2) - T.reads(data[0, v0, v1]) - T.writes(data_local[0, v0, v1]) - data_local[0, v0, v1] = data[0, v0, v1] - for ax0_1 in range(8): + v0 = T.axis.spatial(1, 0) + v1 = T.axis.spatial(n, ax0_ax1_fused) + v2 = T.axis.spatial(4096, ax2_0 * 8 + ax2_1 * 8 + ax2_2) + T.reads(data[v0, v1, v2]) + T.writes(data_local[v0, v1, v2]) + data_local[v0, v1, v2] = data[v0, v1, v2] + for ax0 in range(8): with T.block("T_multiply"): - v0 = T.axis.spatial(n, ax0) - v1 = T.axis.spatial(4096, ax1_0 * 8 + ax0_1) - T.reads(data_local[0, v0, v1]) - T.writes(T_multiply_local[0, v0, v1]) - T_multiply_local[0, v0, v1] = 
T.Cast("float32", data_local[0, v0, v1]) * T.Cast("float32", data_local[0, v0, v1]) - for ax0_1 in range(8): + v_ax0 = T.axis.spatial(1, 0) + v_ax1 = T.axis.spatial(n, ax0_ax1_fused) + v_ax2 = T.axis.spatial(4096, ax2_0 * 8 + ax0) + T.reads(data_local[v_ax0, v_ax1, v_ax2]) + T.writes(T_multiply_local[v_ax0, v_ax1, v_ax2]) + T_multiply_local[v_ax0, v_ax1, v_ax2] = T.Cast("float32", data_local[v_ax0, v_ax1, v_ax2]) * T.Cast("float32", data_local[v_ax0, v_ax1, v_ax2]) + for ax0 in range(8): with T.block("T_multiply_red"): - v0 = T.axis.spatial(n, ax0) - v1 = T.axis.reduce(4096, ax1_0 * 8 + ax0_1) - T.reads(T_multiply_local[0, v0, v1]) - T.writes(T_multiply_red_local[0, v0]) + v_ax0 = T.axis.spatial(1, 0) + v_ax1 = T.axis.spatial(n, ax0_ax1_fused) + v_k2 = T.axis.reduce(4096, ax2_0 * 8 + ax0) + T.reads(T_multiply_local[v_ax0, v_ax1, v_k2]) + T.writes(T_multiply_red_local[v_ax0, v_ax1]) with T.init(): - T_multiply_red_local[0, v0] = T.float32(0) - T_multiply_red_local[0, v0] = T_multiply_red_local[0, v0] + T_multiply_local[0, v0, v1] + T_multiply_red_local[v_ax0, v_ax1] = T.float32(0) + T_multiply_red_local[v_ax0, v_ax1] = T_multiply_red_local[v_ax0, v_ax1] + T_multiply_local[v_ax0, v_ax1, v_k2] with T.block("rsqrt"): - v0 = T.axis.spatial(n, ax0) - T.reads(T_multiply_red_local[0, v0]) - T.writes(rsqrt_shared[0, v0]) - rsqrt_shared[0, v0] = T.rsqrt(T_multiply_red_local[0, v0] * T.float32(0.000244140625) + T.float32(9.9999999999999995e-07)) + v_ax0 = T.axis.spatial(1, 0) + v_ax1 = T.axis.spatial(n, ax0_ax1_fused) + T.reads(T_multiply_red_local[v_ax0, v_ax1]) + T.writes(rsqrt_shared[v_ax0, v_ax1]) + rsqrt_shared[v_ax0, v_ax1] = T.rsqrt(T_multiply_red_local[v_ax0, v_ax1] * T.float32(0.000244140625) + T.float32(9.9999999999999995e-07)) for ax0_0 in T.thread_binding(512, thread="threadIdx.x"): for ax0_1, ax0_2 in T.grid(1, 8): with T.block("T_rms_norm"): - v0 = T.axis.spatial(n, ax0) - v1 = T.axis.spatial(4096, ax0_0 * 8 + ax0_1 * 8 + ax0_2) - T.reads(rsqrt_shared[0, v0], data_local[0, v0, v1], weight[v1]) - T.writes(T_rms_norm_local[0, v0, v1]) - T_rms_norm_local[0, v0, v1] = rsqrt_shared[0, v0] * T.Cast("float32", data_local[0, v0, v1]) * T.Cast("float32", weight[v1]) - for ax0_1 in T.vectorized(8): + v_ax0 = T.axis.spatial(1, 0) + v_ax1 = T.axis.spatial(n, ax0_ax1_fused) + v_ax2 = T.axis.spatial(4096, ax0_0 * 8 + ax0_1 * 8 + ax0_2) + T.reads(rsqrt_shared[v_ax0, v_ax1], data_local[v_ax0, v_ax1, v_ax2], weight[v_ax2]) + T.writes(T_rms_norm_local[v_ax0, v_ax1, v_ax2]) + T_rms_norm_local[v_ax0, v_ax1, v_ax2] = rsqrt_shared[v_ax0, v_ax1] * T.Cast("float32", data_local[v_ax0, v_ax1, v_ax2]) * T.Cast("float32", weight[v_ax2]) + for ax0 in T.vectorized(8): with T.block("T_cast_local"): v0 = T.axis.spatial(1, 0) - v1 = T.axis.spatial(n, ax0) - v2 = T.axis.spatial(4096, ax0_0 * 8 + ax0_1) + v1 = T.axis.spatial(n, ax0_ax1_fused) + v2 = T.axis.spatial(4096, ax0_0 * 8 + ax0) T.reads(T_rms_norm_local[v0, v1, v2]) T.writes(T_cast[v0, v1, v2]) T_cast[v0, v1, v2] = T.Cast("float16", T_rms_norm_local[v0, v1, v2]) @@ -222,50 +227,55 @@ def main(var_data: T.handle, weight: T.Buffer((4096,), "float32"), var_T_cast: T rsqrt_shared = T.alloc_buffer((1, n), scope="shared") T_rms_norm_local = T.alloc_buffer((1, n, 4096), scope="local") data_local = T.alloc_buffer((1, n, 4096), scope="local") - for ax0 in T.thread_binding(n, thread="blockIdx.x"): - for ax1_0 in T.thread_binding(512, thread="threadIdx.x"): - for ax1_1 in range(1): - for ax1_2 in T.vectorized(8): + for ax0_ax1_fused in T.thread_binding(n, 
thread="blockIdx.x"): + for ax2_0 in T.thread_binding(512, thread="threadIdx.x"): + for ax2_1 in range(1): + for ax2_2 in T.vectorized(8): with T.block("data_local"): - v0 = T.axis.spatial(n, ax0) - v1 = T.axis.spatial(4096, ax1_0 * 8 + ax1_1 * 8 + ax1_2) - T.reads(data[0, v0, v1]) - T.writes(data_local[0, v0, v1]) - data_local[0, v0, v1] = data[0, v0, v1] - for ax0_1 in range(8): + v0 = T.axis.spatial(1, 0) + v1 = T.axis.spatial(n, ax0_ax1_fused) + v2 = T.axis.spatial(4096, ax2_0 * 8 + ax2_1 * 8 + ax2_2) + T.reads(data[v0, v1, v2]) + T.writes(data_local[v0, v1, v2]) + data_local[v0, v1, v2] = data[v0, v1, v2] + for ax0 in range(8): with T.block("T_multiply"): - v0 = T.axis.spatial(n, ax0) - v1 = T.axis.spatial(4096, ax1_0 * 8 + ax0_1) - T.reads(data_local[0, v0, v1]) - T.writes(T_multiply_local[0, v0, v1]) - T_multiply_local[0, v0, v1] = data_local[0, v0, v1] * data_local[0, v0, v1] - for ax0_1 in range(8): + v_ax0 = T.axis.spatial(1, 0) + v_ax1 = T.axis.spatial(n, ax0_ax1_fused) + v_ax2 = T.axis.spatial(4096, ax2_0 * 8 + ax0) + T.reads(data_local[v_ax0, v_ax1, v_ax2]) + T.writes(T_multiply_local[v_ax0, v_ax1, v_ax2]) + T_multiply_local[v_ax0, v_ax1, v_ax2] = data_local[v_ax0, v_ax1, v_ax2] * data_local[v_ax0, v_ax1, v_ax2] + for ax0 in range(8): with T.block("T_multiply_red"): - v0 = T.axis.spatial(n, ax0) - v1 = T.axis.reduce(4096, ax1_0 * 8 + ax0_1) - T.reads(T_multiply_local[0, v0, v1]) - T.writes(T_multiply_red_local[0, v0]) + v_ax0 = T.axis.spatial(1, 0) + v_ax1 = T.axis.spatial(n, ax0_ax1_fused) + v_k2 = T.axis.reduce(4096, ax2_0 * 8 + ax0) + T.reads(T_multiply_local[v_ax0, v_ax1, v_k2]) + T.writes(T_multiply_red_local[v_ax0, v_ax1]) with T.init(): - T_multiply_red_local[0, v0] = T.float32(0) - T_multiply_red_local[0, v0] = T_multiply_red_local[0, v0] + T_multiply_local[0, v0, v1] + T_multiply_red_local[v_ax0, v_ax1] = T.float32(0) + T_multiply_red_local[v_ax0, v_ax1] = T_multiply_red_local[v_ax0, v_ax1] + T_multiply_local[v_ax0, v_ax1, v_k2] with T.block("rsqrt"): - v0 = T.axis.spatial(n, ax0) - T.reads(T_multiply_red_local[0, v0]) - T.writes(rsqrt_shared[0, v0]) - rsqrt_shared[0, v0] = T.rsqrt(T_multiply_red_local[0, v0] * T.float32(0.000244140625) + T.float32(9.9999999999999995e-07)) + v_ax0 = T.axis.spatial(1, 0) + v_ax1 = T.axis.spatial(n, ax0_ax1_fused) + T.reads(T_multiply_red_local[v_ax0, v_ax1]) + T.writes(rsqrt_shared[v_ax0, v_ax1]) + rsqrt_shared[v_ax0, v_ax1] = T.rsqrt(T_multiply_red_local[v_ax0, v_ax1] * T.float32(0.000244140625) + T.float32(9.9999999999999995e-07)) for ax0_0 in T.thread_binding(512, thread="threadIdx.x"): for ax0_1, ax0_2 in T.grid(1, 8): with T.block("T_rms_norm"): - v0 = T.axis.spatial(n, ax0) - v1 = T.axis.spatial(4096, ax0_0 * 8 + ax0_1 * 8 + ax0_2) - T.reads(rsqrt_shared[0, v0], data_local[0, v0, v1], weight[v1]) - T.writes(T_rms_norm_local[0, v0, v1]) - T_rms_norm_local[0, v0, v1] = rsqrt_shared[0, v0] * data_local[0, v0, v1] * weight[v1] - for ax0_1 in T.vectorized(8): + v_ax0 = T.axis.spatial(1, 0) + v_ax1 = T.axis.spatial(n, ax0_ax1_fused) + v_ax2 = T.axis.spatial(4096, ax0_0 * 8 + ax0_1 * 8 + ax0_2) + T.reads(rsqrt_shared[v_ax0, v_ax1], data_local[v_ax0, v_ax1, v_ax2], weight[v_ax2]) + T.writes(T_rms_norm_local[v_ax0, v_ax1, v_ax2]) + T_rms_norm_local[v_ax0, v_ax1, v_ax2] = rsqrt_shared[v_ax0, v_ax1] * data_local[v_ax0, v_ax1, v_ax2] * weight[v_ax2] + for ax0 in T.vectorized(8): with T.block("T_cast_local"): v0 = T.axis.spatial(1, 0) - v1 = T.axis.spatial(n, ax0) - v2 = T.axis.spatial(4096, ax0_0 * 8 + ax0_1) + v1 = T.axis.spatial(n, 
ax0_ax1_fused) + v2 = T.axis.spatial(4096, ax0_0 * 8 + ax0) T.reads(T_rms_norm_local[v0, v1, v2]) T.writes(T_cast[v0, v1, v2]) T_cast[v0, v1, v2] = T_rms_norm_local[v0, v1, v2]
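
Usage note: the new rule is consumed through `dl.ApplyDefaultSchedule`, exactly as the
`_check` helper in the tests does. Below is a minimal end-to-end sketch, assuming `Mod`
is a hypothetical IRModule containing an RMS-norm PrimFunc like the `Before` modules
above; `dl.gpu.Fallback` is dlight's existing catch-all rule, included here only so that
functions the RMSNorm rule turns down (its `apply` returns `None`) still get scheduled:

    import tvm
    from tvm import dlight as dl
    from tvm.target import Target

    # `Mod` stands in for an IRModule holding an RMS-norm PrimFunc,
    # e.g. one of the `Before` modules from test_gpu_rmsnorm.py.
    target = Target("nvidia/geforce-rtx-3090-ti")
    with target:
        # Rules are tried in order: RMSNorm matches the rsqrt-based
        # pattern produced by the TOPI change in patch 1/6, and
        # Fallback schedules anything the rule does not recognize.
        scheduled = dl.ApplyDefaultSchedule(
            dl.gpu.RMSNorm(),
            dl.gpu.Fallback(),
        )(Mod)

    # Compile the scheduled module for the CUDA target.
    runtime_mod = tvm.build(scheduled, target=target)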