From 2e0d65f83db0d256f4d3860a1feac69e75d7ae88 Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Fri, 14 May 2021 20:16:55 +0800 Subject: [PATCH 1/4] build tir --- include/tvm/tir/analysis.h | 5 +- python/tvm/contrib/nvcc.py | 2 +- python/tvm/driver/build_module.py | 94 ++++++++++---- .../schedule/schedule_postproc_to_primfunc.cc | 4 +- .../analysis/buffer_access_lca_detector.cc | 14 ++- .../plan_update_buffer_allocation_location.cc | 4 +- tests/python/unittest/test_lower_build.py | 117 ++++++++++++++++++ ...t_tir_analysis_detect_buffer_access_lca.py | 14 +++ 8 files changed, 216 insertions(+), 38 deletions(-) create mode 100644 tests/python/unittest/test_lower_build.py diff --git a/include/tvm/tir/analysis.h b/include/tvm/tir/analysis.h index c2b3148e5eb9..da620abc71c1 100644 --- a/include/tvm/tir/analysis.h +++ b/include/tvm/tir/analysis.h @@ -189,9 +189,10 @@ TVM_DLL size_t CalculateWorkspaceBytes(const PrimFunc& func); * access(BufferLoad, BufferStore) and low-level access(Load, Store and opaque access). * The LCA may be a For loop or a Block. * \param func The PrimFunc to be detected. - * \return The Map from buffer to the LCA of all access to it. + * \return The Map from buffer to the LCA of all access to it. The lca is function root if the + * return stmt is NullOpt. */ -TVM_DLL Map DetectBufferAccessLCA(const PrimFunc& func); +TVM_DLL Map> DetectBufferAccessLCA(const PrimFunc& func); // Pass variants of verification analysis // directly throws RuntimeError when verification fails. 
diff --git a/python/tvm/contrib/nvcc.py b/python/tvm/contrib/nvcc.py index 30b5e6dffdc2..612be292e873 100644 --- a/python/tvm/contrib/nvcc.py +++ b/python/tvm/contrib/nvcc.py @@ -194,7 +194,7 @@ def find_libdevice_path(arch): selected_ver = 0 selected_path = None cuda_ver = get_cuda_version(cuda_path) - if cuda_ver in (9.0, 9.1, 10.0, 10.1, 10.2, 11.0, 11.1, 11.2): + if cuda_ver in (9.0, 9.1, 10.0, 10.1, 10.2, 11.0, 11.1, 11.2, 11.3): path = os.path.join(lib_path, "libdevice.10.bc") else: for fn in os.listdir(lib_path): diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index 4682e344461d..0ace3682df23 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -18,6 +18,8 @@ # pylint: disable=invalid-name """The build utils in python. """ + +from typing import Union, Optional, List, Mapping import warnings import tvm.tir @@ -25,11 +27,15 @@ from tvm.runtime import ndarray from tvm.ir import container from tvm.ir import CallingConv +from tvm.tir import PrimFunc +from tvm.ir.module import IRModule from tvm.ir.transform import PassContext from tvm.target import codegen from tvm.te import tensor from tvm.te import schedule from tvm.target import Target +from tvm.tir.buffer import Buffer +from tvm.tir.expr import Var def get_binds(args, compact=False, binds=None): @@ -119,32 +125,39 @@ def form_irmodule(sch, args, name, binds): return tvm.IRModule({name: func}) -def lower(sch, args, name="main", binds=None, simple_mode=False): +def lower( + inputs: Union[schedule.Schedule, PrimFunc, IRModule], + args: Optional[List[Union[Buffer, tensor.Tensor, Var]]] = None, + name: str = "main", + binds: Optional[Mapping[tensor.Tensor, Buffer]] = None, + simple_mode: bool = False, +) -> IRModule: """Lowering step before build into target. 
Parameters ---------- - sch : tvm.te.schedule.Schedule - The schedule to be built + inputs : Union[schedule.Schedule, PrimFunc, IRModule] + The TE schedule or TensorIR PrimFunc/IRModule to be built - args : list of Buffer or Tensor or Var - The argument lists to the function. + args : Optional[List[Union[Buffer, tensor.Tensor, Var]]] + The argument lists to the function for TE schedule. + It should be None if we want to lower TensorIR. - name : str, optional + name : str The name of result function. - binds : dict of :any:`Tensor` to :any:`Buffer`, optional + binds : Optional[Mapping[tensor.Tensor, Buffer]] Dictionary that maps the Tensor to Buffer which specified the data layout requirement of the function. By default, a new compact buffer is created for each tensor in the argument. - simple_mode : bool, optional + simple_mode : bool Whether only output simple and compact statement, this will skip LoopPartition, api wrapper generation and Unrolling. Returns ------- - m : IRModule or Stmt + m : IRModule The result IRModule, if simple_mode=False Then the Stmt before make api is returned. 
""" @@ -160,16 +173,38 @@ def lower(sch, args, name="main", binds=None, simple_mode=False): lower_phase3 = [x[1] for x in add_lower_pass if x[0] > 2] # Phase 0 - if isinstance(sch, schedule.Schedule): - mod = form_irmodule(sch, args, name, binds) + pass_list = lower_phase0 + is_legacy_te_schedule: bool = False + + if isinstance(inputs, schedule.Schedule): + if args is None: + raise ValueError("args must be given for lowering from TE schedule") + mod = form_irmodule(inputs, args, name, binds) + is_legacy_te_schedule = True + elif isinstance(inputs, PrimFunc): + func = inputs.with_attr("global_symbol", name) + if pass_ctx.config.get("tir.noalias", True): + func = func.with_attr("tir.noalias", True) + mod = tvm.IRModule({name: func}) + elif isinstance(inputs, IRModule): + mod = inputs else: - mod = sch + raise TypeError( + f"tvm.lower expected te.Schedule, PrimFunc or IRModule, but got {type(inputs)}" + ) - pass_list = lower_phase0 # Phase 1 + if is_legacy_te_schedule: + pass_list += [ + tvm.tir.transform.InjectPrefetch(), + tvm.tir.transform.StorageFlatten(64, instrument_bound_checkers), + ] pass_list += [ - tvm.tir.transform.InjectPrefetch(), - tvm.tir.transform.StorageFlatten(64, instrument_bound_checkers), + tvm.tir.transform.LowerInitBlock(), + tvm.tir.transform.PlanAndUpdateBufferAllocationLocation(), + tvm.tir.transform.ConvertBlocksToOpaque(), + tvm.tir.transform.CompactBufferAllocation(), + tvm.tir.transform.FlattenBuffer(), tvm.tir.transform.BF16Legalize(), tvm.tir.transform.NarrowDataType(32), tvm.tir.transform.Simplify(), @@ -297,22 +332,29 @@ def _build_for_device(input_mod, target, target_host): return mod_host, rt_mod_dev -def build(inputs, args=None, target=None, target_host=None, name="default_function", binds=None): +def build( + inputs: Union[schedule.Schedule, PrimFunc, IRModule, Mapping[str, IRModule]], + args: Optional[List[Union[Buffer, tensor.Tensor, Var]]] = None, + target: Optional[Union[str, Target]] = None, + target_host: 
Optional[Union[str, Target]] = None, + name: Optional[str] = "default_function", + binds: Optional[Mapping[tensor.Tensor, Buffer]] = None, +): """Build a function with arguments as signature. Code will be generated for devices coupled with target information. Parameters ---------- - inputs : tvm.te.Schedule, IRModule, or dict of target to IRModule - The schedule to be built + inputs : Union[schedule.Schedule, PrimFunc, IRModule, Mapping[str, IRModule]] + The input to be built - args : list of Buffer or Tensor or Var, optional + args : Optional[List[Union[Buffer, tensor.Tensor, Var]]] The argument lists to the function. - target : str or :any:`tvm.target.Target`, optional + target : Optional[Union[str, Target]] The target and option of the compilation. - target_host : str or :any:`tvm.target.Target` optional + target_host : Optional[Union[str, Target]] Host compilation target, if target is device. When TVM compiles device specific program such as CUDA, we also need host(CPU) side code to interact with the driver @@ -321,10 +363,10 @@ def build(inputs, args=None, target=None, target_host=None, name="default_functi By default, llvm is used if it is enabled, otherwise a stackvm intepreter is used. - name : str, optional + name : Optional[str] The name of result function. - binds : dict, optional + binds : Optional[Mapping[tensor.Tensor, Buffer]] Dictionary that maps the binding of symbolic buffer to Tensor. By default, a new buffer is created for each tensor in the argument. 
@@ -375,10 +417,10 @@ def build(inputs, args=None, target=None, target_host=None, name="default_functi elif isinstance(inputs, (list, tuple, container.Array)): merged_mod = tvm.IRModule({}) for x in inputs: - merged_mod.update(x) + merged_mod.update(lower(x)) input_mod = merged_mod - elif isinstance(inputs, tvm.IRModule): - input_mod = inputs + elif isinstance(inputs, (tvm.IRModule, PrimFunc)): + input_mod = lower(inputs) elif not isinstance(inputs, (dict, container.Map)): raise ValueError( f"Inputs must be Schedule, IRModule or dict of target to IRModule, " diff --git a/src/te/schedule/schedule_postproc_to_primfunc.cc b/src/te/schedule/schedule_postproc_to_primfunc.cc index 1710a91c6985..32cc51039be0 100644 --- a/src/te/schedule/schedule_postproc_to_primfunc.cc +++ b/src/te/schedule/schedule_postproc_to_primfunc.cc @@ -159,13 +159,13 @@ PrimFunc SchedulePostProcToPrimFunc(Array arg_list, Stmt body, ICHECK(!extern_buffer.count(tensor)); tir::Buffer buffer = CreateBufferFor(tensor); - tir::Var bptr(buffer->name, DataType::Handle()); + tir::Var bptr(buffer->name, PrimType(DataType::Handle())); params.push_back(bptr); buffer_map.Set(bptr, buffer); extern_buffer[tensor] = buffer; } else { tir::Buffer buffer = Downcast(var); - tir::Var bptr(buffer->name, DataType::Handle()); + tir::Var bptr(buffer->name, PrimType(DataType::Handle())); params.push_back(bptr); buffer_map.Set(bptr, buffer); } diff --git a/src/tir/analysis/buffer_access_lca_detector.cc b/src/tir/analysis/buffer_access_lca_detector.cc index 23e60e16fc62..6f2622f3a61e 100644 --- a/src/tir/analysis/buffer_access_lca_detector.cc +++ b/src/tir/analysis/buffer_access_lca_detector.cc @@ -36,17 +36,20 @@ namespace tir { */ class LCADetector : public StmtExprVisitor { public: - static Map Detect(const PrimFunc& func) { + static Map> Detect(const PrimFunc& func) { LCADetector detector; for (const auto& kv : func->buffer_map) { const Buffer& buffer = kv.second; detector.buffer_var_map_.emplace(buffer->data.get(), 
buffer.get()); } + detector(func->body); // Prepare the return - Map buffer_lca; + Map> buffer_lca; for (const auto& kv : detector.buffer_lca_) { - buffer_lca.Set(GetRef(kv.first), GetRef(kv.second->stmt)); + const Buffer& buffer = GetRef(kv.first); + const Optional stmt = kv.second ? GetRef>(kv.second->stmt) : NullOpt; + buffer_lca.Set(buffer, stmt); } return buffer_lca; } @@ -131,7 +134,6 @@ class LCADetector : public StmtExprVisitor { } static const ScopeInfo* LowestCommonAncestor(const ScopeInfo* lhs, const ScopeInfo* rhs) { - ICHECK(lhs || rhs); if (lhs == nullptr) return rhs; if (rhs == nullptr) return lhs; while (lhs->parent_scope_info != nullptr && // @@ -166,7 +168,9 @@ class LCADetector : public StmtExprVisitor { support::Arena arena_; }; -Map DetectBufferAccessLCA(const PrimFunc& func) { return LCADetector::Detect(func); } +Map> DetectBufferAccessLCA(const PrimFunc& func) { + return LCADetector::Detect(func); +} TVM_REGISTER_GLOBAL("tir.analysis.detect_buffer_access_lca").set_body_typed(DetectBufferAccessLCA); } // namespace tir diff --git a/src/tir/transforms/plan_update_buffer_allocation_location.cc b/src/tir/transforms/plan_update_buffer_allocation_location.cc index ecedaa64d7df..2a7b984a20c2 100644 --- a/src/tir/transforms/plan_update_buffer_allocation_location.cc +++ b/src/tir/transforms/plan_update_buffer_allocation_location.cc @@ -32,7 +32,7 @@ namespace tir { class BufferAllocationLocator : public StmtExprMutator { public: explicit BufferAllocationLocator(const PrimFunc& func) { - Map buffer_lca = DetectBufferAccessLCA(func); + Map> buffer_lca = DetectBufferAccessLCA(func); std::unordered_set arg_buffers; for (const auto& kv : func->buffer_map) { const Buffer& buffer = kv.second; @@ -42,7 +42,7 @@ class BufferAllocationLocator : public StmtExprMutator { // create buffers to be allocated at each stmts for (const auto& kv : buffer_lca) { const Buffer& buffer = kv.first; - const StmtNode* stmt = kv.second.get(); + const StmtNode* stmt = 
kv.second.defined()? kv.second.value().get() : nullptr; if (arg_buffers.count(buffer.get())) { continue; } diff --git a/tests/python/unittest/test_lower_build.py b/tests/python/unittest/test_lower_build.py new file mode 100644 index 000000000000..21f41321887f --- /dev/null +++ b/tests/python/unittest/test_lower_build.py @@ -0,0 +1,117 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import numpy as np + +import tvm +from tvm import te, tir +from tvm.ir.module import IRModule +from tvm.script import ty +import tvm.testing + + +def _check_module_with_numpy(mod, shape=(128, 128, 128)): + m, n, k = shape + a = tvm.nd.array(np.random.rand(m, k).astype("float32")) + b = tvm.nd.array(np.random.rand(n, k).astype("float32")) + c = tvm.nd.array(np.zeros((m, n), dtype="float32")) + c_np = np.dot(a.asnumpy(), b.asnumpy().transpose()) + mod(a, b, c) + tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5) + + +# pylint: disable=no-self-argument, missing-class-docstring, missing-function-docstring +@tvm.script.tir +def matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, [128, 128]) + B = tir.match_buffer(b, [128, 128]) + C = tir.match_buffer(c, [128, 128]) + for i, j in tir.grid(128, 128): + with tir.block([128, 128], "init") as [vi, vj]: + C[vi, vj] = tir.float32(0) + for k in range(0, 128): + with tir.block([128, 128, tir.reduce_axis(0, 128)], "update") as [vi, vj, vk]: + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + + +@tvm.script.tir +class LoweredModule: + def main(a: ty.handle, b: ty.handle, c: ty.handle) -> None: + # function attr dict + tir.func_attr({"global_symbol": "main", "tir.noalias": True}) + A = tir.match_buffer(a, [128, 128]) + B = tir.match_buffer(b, [128, 128]) + C = tir.match_buffer(c, [128, 128]) + # body + for x, y in tir.grid(128, 128): + C.data[x * 128 + y] = 0.0 + for k in tir.serial(0, 128): + C.data[x * 128 + y] = tir.load("float32", C.data, x * 128 + y) + tir.load( + "float32", A.data, x * 128 + k + ) * tir.load("float32", B.data, y * 128 + k) + + +def test_lower_build_te_schedule(): + m, n, k = 128, 128, 128 + axis_k = te.reduce_axis((0, k), "k") + A = te.placeholder((m, k), name="A") + B = te.placeholder((k, n), name="B") + C = te.compute((m, n), lambda x, y: te.sum(A[x, axis_k] * B[y, axis_k], axis=axis_k), name="C") + s = te.create_schedule(C.op) + # check lowering + ir_mod = 
tvm.lower(s, [A, B, C]) + tvm.ir.assert_structural_equal(ir_mod, LoweredModule()) + # check building + mod = tvm.build(s, [A, B, C], target="llvm") + _check_module_with_numpy(mod) + + +def test_lower_build_tir_func(): + # check lowering + ir_mod = tvm.lower(matmul) + tvm.ir.assert_structural_equal(ir_mod, LoweredModule()) + # check building + mod = tvm.build(matmul, target="llvm") + _check_module_with_numpy(mod) + + +def test_lower_build_tir_module(): + func = matmul.with_attr("global_symbol", "main") + func = func.with_attr("tir.noalias", True) + ir_mod = IRModule({"main": func}) + # check lowering + lowered_mod = tvm.lower(ir_mod) + tvm.ir.assert_structural_equal(lowered_mod, LoweredModule()) + # check building + mod = tvm.build(ir_mod, target="llvm") + _check_module_with_numpy(mod) + + +def test_lower_build_lowered_module(): + # check lowering + ir_mod = tvm.lower(LoweredModule()) + tvm.ir.assert_structural_equal(ir_mod, LoweredModule()) + # check building + mod = tvm.build(ir_mod, target="llvm") + _check_module_with_numpy(mod) + + +if __name__ == "__main__": + test_lower_build_te_schedule() + test_lower_build_tir_func() + test_lower_build_tir_module() + test_lower_build_lowered_module() diff --git a/tests/python/unittest/test_tir_analysis_detect_buffer_access_lca.py b/tests/python/unittest/test_tir_analysis_detect_buffer_access_lca.py index 7ac61a705fbd..36fd80fd07de 100644 --- a/tests/python/unittest/test_tir_analysis_detect_buffer_access_lca.py +++ b/tests/python/unittest/test_tir_analysis_detect_buffer_access_lca.py @@ -64,6 +64,12 @@ def buffer_opaque_access(b: ty.handle, c: ty.handle) -> None: C[vi, vj] = B[vi, vj] +@tvm.script.tir +def lca_is_func_root(a: ty.handle) -> None: + A = tir.match_buffer(a, [0, 0], "float32") + A.data[0] = 1.0 + + def test_buffer_load_store(): func = buffer_load_store_func A, B = [func.buffer_map[x] for x in func.params] @@ -102,6 +108,14 @@ def test_opaque_access(): assert lca[C] == root_block.body[1].body.body.block +def 
test_lca_func_root(): + func = lca_is_func_root + (A,) = [func.buffer_map[x] for x in func.params] + lca = tir.analysis.detect_buffer_access_lca(func) + assert lca[A] is None + + if __name__ == "__main__": test_buffer_load_store() test_opaque_access() + test_lca_func_root() From ed6c424820258d743b7f0e721cd912620ad7ea9b Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Fri, 14 May 2021 10:38:03 -0400 Subject: [PATCH 2/4] Update plan_update_buffer_allocation_location.cc --- src/tir/transforms/plan_update_buffer_allocation_location.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tir/transforms/plan_update_buffer_allocation_location.cc b/src/tir/transforms/plan_update_buffer_allocation_location.cc index 2a7b984a20c2..949c955b2dfe 100644 --- a/src/tir/transforms/plan_update_buffer_allocation_location.cc +++ b/src/tir/transforms/plan_update_buffer_allocation_location.cc @@ -42,7 +42,7 @@ class BufferAllocationLocator : public StmtExprMutator { // create buffers to be allocated at each stmts for (const auto& kv : buffer_lca) { const Buffer& buffer = kv.first; - const StmtNode* stmt = kv.second.defined()? kv.second.value().get() : nullptr; + const StmtNode* stmt = kv.second.defined() ? 
kv.second.value().get() : nullptr; if (arg_buffers.count(buffer.get())) { continue; } From 4919ae1be267b99d030623487644541348221cb6 Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Sat, 15 May 2021 01:30:41 +0000 Subject: [PATCH 3/4] fix unit loop replace --- src/tir/transforms/flatten_buffer.cc | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/tir/transforms/flatten_buffer.cc b/src/tir/transforms/flatten_buffer.cc index 82035cb7fdf5..07f7b42fe2eb 100644 --- a/src/tir/transforms/flatten_buffer.cc +++ b/src/tir/transforms/flatten_buffer.cc @@ -113,7 +113,11 @@ class BufferFlattener : public StmtExprMutator { if (it == unit_loop_vars_.end()) { return std::move(var); } else { - return it->second; + PrimExpr expr = it->second; + if (expr.dtype() != var.dtype()) { + expr = Cast(var.dtype(), std::move(expr)); + } + return expr; } } From 2036473745f3e955e12aa5ee860c563b75641a54 Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Sat, 15 May 2021 15:17:35 +0000 Subject: [PATCH 4/4] address --- python/tvm/driver/build_module.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index 0ace3682df23..a3d0bb656736 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -158,8 +158,7 @@ def lower( Returns ------- m : IRModule - The result IRModule, if simple_mode=False - Then the Stmt before make api is returned. 
+ The result IRModule """ # config setup pass_ctx = PassContext.current() @@ -199,16 +198,20 @@ def lower( tvm.tir.transform.InjectPrefetch(), tvm.tir.transform.StorageFlatten(64, instrument_bound_checkers), ] + else: + pass_list += [ + tvm.tir.transform.LowerInitBlock(), + tvm.tir.transform.PlanAndUpdateBufferAllocationLocation(), + tvm.tir.transform.ConvertBlocksToOpaque(), + tvm.tir.transform.CompactBufferAllocation(), + tvm.tir.transform.FlattenBuffer(), + ] pass_list += [ - tvm.tir.transform.LowerInitBlock(), - tvm.tir.transform.PlanAndUpdateBufferAllocationLocation(), - tvm.tir.transform.ConvertBlocksToOpaque(), - tvm.tir.transform.CompactBufferAllocation(), - tvm.tir.transform.FlattenBuffer(), tvm.tir.transform.BF16Legalize(), tvm.tir.transform.NarrowDataType(32), tvm.tir.transform.Simplify(), ] + pass_list += lower_phase1 # Phase 2