diff --git a/include/tvm/driver/driver_api.h b/include/tvm/driver/driver_api.h index 71a69a000944..418d532fdd5f 100644 --- a/include/tvm/driver/driver_api.h +++ b/include/tvm/driver/driver_api.h @@ -34,6 +34,7 @@ #include #include #include +#include #include #include @@ -42,17 +43,68 @@ #include namespace tvm { + +/*! + * \brief Lower an IRModule (optimize it with the pass list defined in CreatePassList) + * \param mod The IRModule to lower + * \param simple_mode Disables the loop partition pass. Defaults to false. + * \return The result module. + */ +TVM_DLL IRModule LowerModule(IRModule mod, bool simple_mode = false); + +/*! + * \brief Lower a primfunc and name (convert to IRModule, and optimize it with the pass list + * defined in CreatePassList) + * \param func The PrimFunc to lower + * \param name The name of the lowered function. + * \param simple_mode Disables the loop partition pass. Defaults to false. + * \return The result module. + */ +TVM_DLL IRModule LowerPrimFunc(tvm::tir::PrimFunc func, const std::string& name, + bool simple_mode = false); + /*! - * \brief Build an IRModule given a schedule, args and binds - * \param sch The schedule to lower. + * \brief Build an IRModule given a TE schedule, args and binds. This function also applies + * the lowering passes defined in CreatePassList. + * \param sch The TE schedule to lower. * \param args The arguments to the function. * \param name The name of the lowered function. * \param binds Buffer assignments. + * \param simple_mode Disables the loop partition pass. Defaults to false. + * \return The result module. */ -TVM_DLL IRModule lower(te::Schedule sch, const Array& args, const std::string& name, - const std::unordered_map& binds); +TVM_DLL IRModule LowerSchedule(te::Schedule sch, const Array& args, + const std::string& name, + const std::unordered_map& binds, + bool simple_mode = false); + +/*! + * \brief Build an IRModule given a TE schedule, args and binds. 
This function also applies + * the lowering passes defined in CreatePassList. + * \param sch The TE schedule to lower. + * \param args The arguments to the function (Array of Tensor, Buffer and Vars) + * \param name The name of the lowered function. + * \param binds Buffer assignments. + * \param simple_mode Disables the loop partition pass. Defaults to false. + * \return The result module. + */ +TVM_DLL IRModule LowerSchedule(te::Schedule sch, const Array& args, + const std::string& name, + const std::unordered_map& binds, + bool simple_mode = false); + +/*! + * \brief Create an IRModule out of a TE Schedule. It does not apply lowering passes. If you want + * to apply lowering passes as well, use LowerSchedule. + * \param sch The schedule + * \param args The arguments to the function. + * \param name The name of the lowered function. + * \param binds Buffer assignments. + * \return The result module. + */ +IRModule ScheduleToModule(te::Schedule sch, const Array& args, const std::string& name, + const std::unordered_map& binds); /*! * \brief Build a device and host module for a specific target from an IRModule. * \param funcs The functions to be built. diff --git a/include/tvm/te/schedule_pass.h b/include/tvm/te/schedule_pass.h index 32e74f6ef9d5..0ba7421ce409 100644 --- a/include/tvm/te/schedule_pass.h +++ b/include/tvm/te/schedule_pass.h @@ -88,18 +88,6 @@ bool VerifyCompactBuffer(const Stmt& stmt); */ Stmt ScheduleOps(Schedule s, Map dom_map, bool debug_keep_trivial_loop); -/*! - * \brief Try to modify the AST generated by ScheduleOps to support TensorCore. - * - * \param stmt The stmt to be trasnformed. - * \param schedule The original schedule. - * \param extern_buffer Map specifies external - * buffer assignment of input and outputs. - * \return Transformed stmt. - */ -Stmt SchedulePostProcRewriteForTensorCore(Stmt stmt, Schedule schedule, - Map extern_buffer); - /*! 
* \brief Postprocessing the Stmt generated by ScheduleOps to create * a PrimFunc that can then be used for further TIR optimizations. diff --git a/python/tvm/autotvm/feature.py b/python/tvm/autotvm/feature.py index dff0f098d84a..8d2591dce50b 100644 --- a/python/tvm/autotvm/feature.py +++ b/python/tvm/autotvm/feature.py @@ -39,7 +39,7 @@ def ana_lower(sch, args, binds=None, simple_mode=True): """Do lower while keeping all axes in IR i.e. Do not eliminate loop with extent of 1, do not vectorize, unroll or inject virtual threads """ - binds, _ = build_module.get_binds(args, binds) + binds, _ = build_module.get_binds(args, compact=False, binds=binds) sch = sch.normalize() # Phase 0 bounds = schedule.InferBound(sch) diff --git a/python/tvm/driver/_ffi_api.py b/python/tvm/driver/_ffi_api.py new file mode 100644 index 000000000000..c423656d78f5 --- /dev/null +++ b/python/tvm/driver/_ffi_api.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""FFI APIs for tvm.driver""" +import tvm._ffi + +tvm._ffi._init_api("driver", __name__) diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index a3d0bb656736..a4df63f225b2 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -37,96 +37,58 @@ from tvm.tir.buffer import Buffer from tvm.tir.expr import Var +from . import _ffi_api as ffi + def get_binds(args, compact=False, binds=None): """Internal function to get binds and arg_list given arguments. - Parameters ---------- args : list of Buffer or Tensor or Var The argument lists to the function. - compact : bool If the statement has already bound to a compact buffer. - binds : dict of :any:`Tensor` to :any:`Buffer`, optional Dictionary that maps the Tensor to Buffer which specified the data layout requirement of the function. By default, a new compact buffer is created for each tensor in the argument. - Returns ------- binds: dict The bind specification - arg_list: list The list of symbolic buffers of arguments. 
""" - binds = {} if binds is None else binds.copy() - arg_list = [] - for x in args: - if isinstance(x, tensor.Tensor): - any_dim = any(isinstance(i, tvm.tir.Var) for i in x.shape) - buffer_type = "auto_broadcast" if any_dim and not compact else "" - if x not in binds: - buf = tvm.tir.decl_buffer( - x.shape, dtype=x.dtype, name=x.name, buffer_type=buffer_type - ) - binds[x] = buf - arg_list.append(buf) - else: - arg_list.append(binds[x]) - elif isinstance(x, schedule.Buffer): - arg_list.append(x) - elif isinstance(x, tvm.tir.Var): - arg_list.append(x) - else: - raise ValueError("args must be Tensor, Buffer or Var") + binds, arg_list = ffi.get_binds(args, compact, binds) return binds, arg_list -def form_irmodule(sch, args, name, binds): +def schedule_to_module( + sch: schedule.Schedule, + args: Optional[List[Union[Buffer, tensor.Tensor, Var]]] = None, + name: str = "main", + binds: Optional[Mapping[tensor.Tensor, Buffer]] = None, +) -> IRModule: """According to the given schedule, form a function. - Parameters ---------- sch : tvm.te.schedule.Schedule The given scheduler to form the raw body - args : list of Buffer or Tensor or Var The argument lists to the function. - name : str - The name of result function. 
- + The name of result function, default name is "main" binds : dict of :any:`Tensor` to :any:`Buffer`, optional The binds information - Returns ------- The body formed according to the given schedule """ - # normalize schedule first - pass_ctx = PassContext.current() - sch = sch.normalize() - bounds = schedule.InferBound(sch) - stmt = schedule.ScheduleOps(sch, bounds) - - compact = schedule.VerifyCompactBuffer(stmt) - binds, arg_list = get_binds(args, compact, binds) - - stmt = schedule.SchedulePostProcRewriteForTensorCore(stmt, sch, binds) - func = schedule.SchedulePostProcToPrimFunc(arg_list, stmt, binds) - - func = func.with_attr("global_symbol", name) - - if pass_ctx.config.get("tir.noalias", True): - func = func.with_attr("tir.noalias", True) - return tvm.IRModule({name: func}) + return ffi.schedule_to_module(sch, args, name, binds) def lower( - inputs: Union[schedule.Schedule, PrimFunc, IRModule], + inp: Union[schedule.Schedule, PrimFunc, IRModule], args: Optional[List[Union[Buffer, tensor.Tensor, Var]]] = None, name: str = "main", binds: Optional[Mapping[tensor.Tensor, Buffer]] = None, @@ -136,7 +98,7 @@ def lower( Parameters ---------- - input : Union[schedule.Schedule, PrimFunc, IRModule] + inp : Union[schedule.Schedule, PrimFunc, IRModule] The TE schedule or TensorIR PrimFunc/IRModule to be built args : Optional[List[Union[Buffer, tensor.Tensor, Var]]] @@ -160,90 +122,13 @@ def lower( m : IRModule The result IRModule """ - # config setup - pass_ctx = PassContext.current() - instrument_bound_checkers = bool(pass_ctx.config.get("tir.instrument_bound_checkers", False)) - disable_vectorize = bool(pass_ctx.config.get("tir.disable_vectorize", False)) - add_lower_pass = pass_ctx.config.get("tir.add_lower_pass", []) - - lower_phase0 = [x[1] for x in add_lower_pass if x[0] == 0] - lower_phase1 = [x[1] for x in add_lower_pass if x[0] == 1] - lower_phase2 = [x[1] for x in add_lower_pass if x[0] == 2] - lower_phase3 = [x[1] for x in add_lower_pass if x[0] > 2] - 
- # Phase 0 - pass_list = lower_phase0 - is_legacy_te_schedule: bool = False - - if isinstance(inputs, schedule.Schedule): - if args is None: - raise ValueError("args must be given for lowering from TE schedule") - mod = form_irmodule(inputs, args, name, binds) - is_legacy_te_schedule = True - elif isinstance(inputs, PrimFunc): - func = inputs.with_attr("global_symbol", name) - if pass_ctx.config.get("tir.noalias", True): - func = func.with_attr("tir.noalias", True) - mod = tvm.IRModule({name: func}) - elif isinstance(inputs, IRModule): - mod = inputs - else: - raise TypeError( - f"tvm.lower expected te.Schedule, PrimFunc or IRModule, but got {type(inputs)}" - ) - - # Phase 1 - if is_legacy_te_schedule: - pass_list += [ - tvm.tir.transform.InjectPrefetch(), - tvm.tir.transform.StorageFlatten(64, instrument_bound_checkers), - ] - else: - pass_list += [ - tvm.tir.transform.LowerInitBlock(), - tvm.tir.transform.PlanAndUpdateBufferAllocationLocation(), - tvm.tir.transform.ConvertBlocksToOpaque(), - tvm.tir.transform.CompactBufferAllocation(), - tvm.tir.transform.FlattenBuffer(), - ] - pass_list += [ - tvm.tir.transform.BF16Legalize(), - tvm.tir.transform.NarrowDataType(32), - tvm.tir.transform.Simplify(), - ] - - pass_list += lower_phase1 - - # Phase 2 - if not simple_mode: - pass_list += [(tvm.tir.transform.LoopPartition())] - - pass_list += [ - tvm.tir.transform.VectorizeLoop(not disable_vectorize), - tvm.tir.transform.InjectVirtualThread(), - tvm.tir.transform.InjectDoubleBuffer(), - tvm.tir.transform.StorageRewrite(), - tvm.tir.transform.UnrollLoop(), - ] - pass_list += lower_phase2 - - # Phase 3 - pass_list += [ - tvm.tir.transform.Simplify(), - tvm.tir.transform.RemoveNoOp(), - ] - - pass_list += [tvm.tir.transform.RewriteUnsafeSelect()] - pass_list += [tvm.tir.transform.HoistIfThenElse()] - pass_list += lower_phase3 - - # Instrument BoundCheckers - if instrument_bound_checkers: - pass_list += [tvm.tir.transform.InstrumentBoundCheckers()] - - optimize = 
tvm.transform.Sequential(pass_list) - mod = optimize(mod) - return mod + if isinstance(inp, IRModule): + return ffi.lower_module(inp, simple_mode) + if isinstance(inp, PrimFunc): + return ffi.lower_primfunc(inp, name, simple_mode) + if isinstance(inp, schedule.Schedule): + return ffi.lower_schedule(inp, args, name, binds, simple_mode) + raise ValueError(f"Expected input to be an IRModule, PrimFunc or Schedule, but got {type(inp)}") def _build_for_device(input_mod, target, target_host): diff --git a/python/tvm/driver/tvmc/compiler.py b/python/tvm/driver/tvmc/compiler.py index 0cc26872cd47..071474a31594 100644 --- a/python/tvm/driver/tvmc/compiler.py +++ b/python/tvm/driver/tvmc/compiler.py @@ -38,7 +38,7 @@ @register_parser def add_compile_parser(subparsers): - """ Include parser for 'compile' subcommand """ + """Include parser for 'compile' subcommand""" parser = subparsers.add_parser("compile", help="compile a model.") parser.set_defaults(func=drive_compile) diff --git a/python/tvm/relay/backend/_backend.py b/python/tvm/relay/backend/_backend.py index 9460e23a5357..7378ed6beb8a 100644 --- a/python/tvm/relay/backend/_backend.py +++ b/python/tvm/relay/backend/_backend.py @@ -20,45 +20,6 @@ from tvm.target import Target -@tvm._ffi.register_func("relay.backend.lower") -def lower(sch, inputs, func_name, source_func): - """Backend function for lowering. - - Parameters - ---------- - sch : tvm.te.Schedule - The schedule. - - inputs : List[tvm.te.Tensor] - The inputs to the function. - - func_name : str - The name of the function. - - source-func : tvm.relay.Function - The source function to be lowered. - - Returns - ------- - mod : tvm.IRModule - The result of lowering. 
- """ - # pylint: disable=broad-except, import-outside-toplevel - import traceback - - try: - f = tvm.driver.lower(sch, inputs, name=func_name) - # logging.debug("lower function %s", func_name) - # logging.debug("%s", _build.lower(sch, inputs, simple_mode=True)) - except Exception: - msg = traceback.format_exc() - msg += "Error during compile function\n" - msg += "-----------------------------\n" - msg += source_func.astext() - raise RuntimeError(msg) - return f - - @tvm._ffi.register_func("relay.backend.build") def build(mod, target, target_host=None): """Backend build function. diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index f30cecbf7f05..22e4bfc52796 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -93,22 +93,62 @@ tir::Buffer BufferWithOffsetAlignment(Array shape, DataType dtype, std offset_factor, buffer_type); } -void GetBinds(const Array& args, bool compact, +void GetBinds(const Array& args, bool compact, const std::unordered_map& binds, Map* out_binds, Array* out_arg_list) { *out_binds = binds; - for (const auto& x : args) { - if (out_binds->find(x) == out_binds->end()) { - auto buf = BufferWithOffsetAlignment(x->shape, x->dtype, x->op->name, -1, 0, compact); - out_binds->Set(x, buf); - out_arg_list->push_back(buf); + for (const ObjectRef& x : args) { + if (const te::TensorNode* tensor_node = x.as()) { + te::Tensor x_ref = GetRef(tensor_node); + if (out_binds->find(x_ref) == out_binds->end()) { + tir::Buffer buf = + BufferWithOffsetAlignment(x_ref->shape, x_ref->dtype, x_ref->op->name, -1, 0, compact); + out_binds->Set(x_ref, buf); + out_arg_list->push_back(buf); + } else { + out_arg_list->push_back((*out_binds)[x_ref]); + } + } else if (x.as() || x.as()) { + out_arg_list->push_back(x); } else { - out_arg_list->push_back((*out_binds)[x]); + LOG(FATAL) + << "Expected type of the elements of args to be te::Tensor, te::Buffer or tir::Var, " + << "but got a " << x->GetTypeKey(); } } } +void GetBinds(const Array& 
args, bool compact, + const std::unordered_map& binds, + Map* out_binds, Array* out_arg_list) { + Array ref_args; + for (ObjectRef x : args) { + ref_args.push_back(x); + } + GetBinds(ref_args, compact, binds, out_binds, out_arg_list); +} + +TVM_REGISTER_GLOBAL("driver.get_binds") + .set_body_typed([](const Array& args, bool compact, + const Map& binds) { + std::unordered_map c_binds; + // Check to make sure binds is not null before doing the conversion; + if (binds.get() != nullptr) { + for (auto kv : binds) { + c_binds.insert({kv.first, kv.second}); + } + } + Map out_binds; + Array out_arg_list; + GetBinds(args, compact, c_binds, &out_binds, &out_arg_list); + + // TVM object system doesn't have a pair object, so we'll put both ret values in an array + // and return that. + Array out_arr = {out_binds, out_arg_list}; + return out_arr; + }); + transform::Pass BindTarget(Target target) { auto fpass = [target](tir::PrimFunc f, IRModule m, transform::PassContext ctx) { return WithAttr(std::move(f), tvm::attr::kTarget, target); @@ -128,63 +168,208 @@ transform::Pass Filter(FCond fcond) { return tir::transform::CreatePrimFuncPass(fpass, 0, "Filter", {}); } -IRModule lower(te::Schedule sch, const Array& args, const std::string& name, - const std::unordered_map& binds) { - Array out_arg_list; - auto pass_ctx = transform::PassContext::Current(); - - sch = sch.normalize(); - - // Before TIR transformation. 
- auto bounds = te::InferBound(sch); - auto stmt = te::ScheduleOps(sch, bounds, false); - bool compact = te::VerifyCompactBuffer(stmt); - - Map out_binds; - GetBinds(args, compact, binds, &out_binds, &out_arg_list); - - // build the function - tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, std::move(stmt), out_binds); - f = WithAttr(std::move(f), "global_symbol", runtime::String(name)); +Array CreatePassList(bool disable_loop_partition, bool for_te_schedule) { + transform::PassContext pass_ctx = transform::PassContext::Current(); - bool noalias = pass_ctx->GetConfig("tir.noalias", Bool(true)).value(); bool disable_vectorize = pass_ctx->GetConfig("tir.disable_vectorize", Bool(false)).value(); bool instrument_bound_checkers = pass_ctx->GetConfig("tir.instrument_bound_checkers", Bool(false)).value(); - if (noalias) { - f = WithAttr(std::move(f), "tir.noalias", Bool(true)); + // Get any user-added passes + Array> add_lower_pass = + pass_ctx->GetConfig>>("tir.add_lower_pass", Array>()) + .value(); + + Array user_lower_phase0 = Array(); + Array user_lower_phase1 = Array(); + Array user_lower_phase2 = Array(); + Array user_lower_phase3 = Array(); + + // phase pasees is of the form + // [[phase_number, pass], [phase_number, pass]... 
] + for (Array phase_pass : add_lower_pass) { + const IntImmNode* phase_num = phase_pass[0].as(); + ICHECK(phase_num) + << "Expected the first entry in the inner Array of tir.add_lower_pass to be an integer"; + int phase_num_val = phase_num->value; + + CHECK_GE(phase_num_val, 0); + + const tvm::transform::PassNode* pass_node = phase_pass[1].as(); + // Reject non-Pass entries here; otherwise GetRef would yield a null Pass that ends up in the pass list + ICHECK(pass_node) + << "Expected the second entry in the inner Array of tir.add_lower_pass to be a Pass"; + tvm::transform::Pass pass = GetRef(pass_node); + // Copy the pass into the correct phase + if (phase_num_val == 0) { + user_lower_phase0.push_back(pass); + } else if (phase_num_val == 1) { + user_lower_phase1.push_back(pass); + } else if (phase_num_val == 2) { + user_lower_phase2.push_back(pass); + } else if (phase_num_val >= 3) { + user_lower_phase3.push_back(pass); + } } - auto mod = IRModule(Map({{GlobalVar(name), f}})); - auto pass_list = Array(); + // Construct the pass list, inserting the user provided passes at the end of the phase + + // PHASE 0 + Array pass_list = user_lower_phase0; - // Phase 0 - pass_list.push_back(tir::transform::InjectPrefetch()); - pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers)); - // Phase 1 + // PHASE 1 + if (for_te_schedule) { + pass_list.push_back(tir::transform::InjectPrefetch()); + pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers)); + } else { + pass_list.push_back(tir::transform::LowerInitBlock()); + pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation()); + pass_list.push_back(tir::transform::ConvertBlocksToOpaque()); + pass_list.push_back(tir::transform::CompactBufferAllocation()); + pass_list.push_back(tir::transform::FlattenBuffer()); + } pass_list.push_back(tir::transform::BF16Legalize()); pass_list.push_back(tir::transform::NarrowDataType(32)); pass_list.push_back(tir::transform::Simplify()); - pass_list.push_back(tir::transform::LoopPartition()); + + // Add user-defined phase-1 passes + pass_list.insert(pass_list.end(), user_lower_phase1.begin(), user_lower_phase1.end()); + + // PHASE 2 +
if (!disable_loop_partition) { + pass_list.push_back(tir::transform::LoopPartition()); + } + pass_list.push_back(tir::transform::VectorizeLoop(!disable_vectorize)); pass_list.push_back(tir::transform::InjectVirtualThread()); pass_list.push_back(tir::transform::InjectDoubleBuffer()); pass_list.push_back(tir::transform::StorageRewrite()); pass_list.push_back(tir::transform::UnrollLoop()); - // Phase 2 + + // Add user-defined phase-2 passes + pass_list.insert(pass_list.end(), user_lower_phase2.begin(), user_lower_phase2.end()); + + // PHASE 3 pass_list.push_back(tir::transform::Simplify()); pass_list.push_back(tir::transform::RemoveNoOp()); pass_list.push_back(tir::transform::RewriteUnsafeSelect()); + pass_list.push_back(tir::transform::HoistIfThenElse()); + + // Add user-defined phase-3 passes + pass_list.insert(pass_list.end(), user_lower_phase3.begin(), user_lower_phase3.end()); + if (instrument_bound_checkers) { pass_list.push_back(tir::transform::InstrumentBoundCheckers()); } - // run - auto optimize = transform::Sequential(pass_list); + return pass_list; +} + +IRModule LowerWithPassList(IRModule mod, Array pass_list) { + auto optimize = tvm::transform::Sequential(pass_list); mod = optimize(std::move(mod)); return mod; } +IRModule ScheduleToModule(te::Schedule sch, const Array& args, const std::string& name, + const std::unordered_map& binds) { + // Convert te schedule to IRModule + Array out_arg_list; + transform::PassContext pass_ctx = transform::PassContext::Current(); + + sch = sch.normalize(); + + // Before TIR transformation. 
+ Map bounds = te::InferBound(sch); + tir::Stmt stmt = te::ScheduleOps(sch, std::move(bounds), false); + bool compact = te::VerifyCompactBuffer(stmt); + + Map out_binds; + GetBinds(args, compact, binds, &out_binds, &out_arg_list); + + // Build the function + // At this point binds is only te::Tensors + tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, std::move(stmt), out_binds); + f = WithAttr(std::move(f), "global_symbol", runtime::String(name)); + + bool noalias = pass_ctx->GetConfig("tir.noalias", Bool(true)).value(); + + if (noalias) { + f = WithAttr(std::move(f), "tir.noalias", Bool(true)); + } + return IRModule(Map({{GlobalVar(name), f}})); +} + +TVM_REGISTER_GLOBAL("driver.schedule_to_module") + .set_body_typed([](te::Schedule sch, const Array& args, const String& name, + const Map& binds) { + std::unordered_map c_binds; + // Check to make sure binds is not null before doing the conversion; + if (binds.get() != nullptr) { + for (auto kv : binds) { + c_binds.insert({kv.first, kv.second}); + } + } + IRModule mod = ScheduleToModule(std::move(sch), args, name, c_binds); + return mod; + }); + +IRModule LowerModule(IRModule mod, bool simple_mode) { + Array pass_list = CreatePassList(simple_mode, false); + return LowerWithPassList(std::move(mod), pass_list); +} + +TVM_REGISTER_GLOBAL("driver.lower_module").set_body_typed([](IRModule mod, bool simple_mode) { + return LowerModule(std::move(mod), simple_mode); +}); + +IRModule LowerPrimFunc(tir::PrimFunc func, const std::string& name, bool simple_mode) { + transform::PassContext pass_ctx = transform::PassContext::Current(); + tir::PrimFunc f = WithAttr(std::move(func), "global_symbol", runtime::String(name)); + + bool noalias = pass_ctx->GetConfig("tir.noalias", Bool(true)).value(); + + if (noalias) { + f = WithAttr(std::move(f), "tir.noalias", Bool(true)); + } + IRModule mod = IRModule(Map({{GlobalVar(name), f}})); + + // Get the pass list + Array pass_list = CreatePassList(simple_mode, false); + return 
LowerWithPassList(std::move(mod), pass_list); +} + +TVM_REGISTER_GLOBAL("driver.lower_primfunc") + .set_body_typed([](te::PrimFunc func, const String& name, bool simple_mode) { + return LowerPrimFunc(std::move(func), name, simple_mode); + }); + +IRModule LowerSchedule(te::Schedule sch, const Array& args, const std::string& name, + const std::unordered_map& binds, bool simple_mode) { + Array ref_args; + for (ObjectRef x : args) { + ref_args.push_back(x); + } + return LowerSchedule(std::move(sch), ref_args, name, binds, simple_mode); +} + +IRModule LowerSchedule(te::Schedule sch, const Array& args, const std::string& name, + const std::unordered_map& binds, bool simple_mode) { + IRModule mod = ScheduleToModule(std::move(sch), args, name, binds); + // Get the legacy TE pass list + Array pass_list = CreatePassList(simple_mode, true); + return LowerWithPassList(std::move(mod), pass_list); +} + +TVM_REGISTER_GLOBAL("driver.lower_schedule") + .set_body_typed([](te::Schedule sch, const Array& args, const String& name, + const Map& binds, bool simple_mode) { + std::unordered_map c_binds; + // Check to make sure binds is not null before doing the conversion; + if (binds.get() != nullptr) { + for (auto kv : binds) { + c_binds.insert({kv.first, kv.second}); + } + } + return LowerSchedule(std::move(sch), args, name, c_binds, simple_mode); + }); + std::pair SplitDevHostFuncs(IRModule mod_mixed, const Target& target_arg, const Target& target_host_arg, const transform::PassContext& pass_ctx) { diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index 5e3b66b3ae15..5146c90f3bac 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -763,15 +763,9 @@ class CompileEngineImpl : public CompileEngineNode { all_args.push_back(arg); } // lower the function - if (const auto* f = runtime::Registry::Get("relay.backend.lower")) { - cache_node->funcs = (*f)(cfunc->schedule, all_args, cache_node->func_name, key->source_func); - } else { - using
tvm::transform::PassContext; - With fresh_pass_ctx_scope(PassContext::Create()); + std::unordered_map binds; + cache_node->funcs = tvm::LowerSchedule(cfunc->schedule, all_args, cache_node->func_name, binds); - std::unordered_map binds; - cache_node->funcs = tvm::lower(cfunc->schedule, all_args, cache_node->func_name, binds); - } value->cached_func = CachedFunc(cache_node); return value; } @@ -807,7 +801,7 @@ class CompileEngineImpl : public CompileEngineNode { With fresh_pass_ctx_scope(PassContext::Create()); std::unordered_map binds; - cache_node->funcs = tvm::lower(spair.first, all_args, cache_node->func_name, binds); + cache_node->funcs = tvm::LowerSchedule(spair.first, all_args, cache_node->func_name, binds); value->cached_func = CachedFunc(cache_node); return value; } diff --git a/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc b/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc deleted file mode 100644 index 951bd6c18706..000000000000 --- a/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc +++ /dev/null @@ -1,1124 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! 
- * \file schedule_postproc_rewrite_for_tensor_core.cc - * - * \brief Rewrite the Stmt generated by ScheduleOps - * to accomondate tensorcore. - */ -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -#include "../../runtime/thread_storage_scope.h" - -namespace tvm { -namespace te { - -using namespace te; -using runtime::StorageRank; -using runtime::StorageScope; -using runtime::ThreadScope; - -struct Tile { - int m{-1}; - int n{-1}; - int k{-1}; -}; - -std::string simplify_name(std::string input) { - auto pos = input.find("."); - if (pos != std::string::npos) { - return input.substr(0, pos); - } else { - return input; - } -} - -PrimExpr unpack_type_cast(const PrimExpr& input, const DataType& target_type) { - auto cast = input.as(); - if (cast == nullptr) { - return input; - } else if (cast->dtype == target_type) { - return cast->value; - } - return PrimExpr(); -} - -// MMAMatcher matches C = Cast(A)*Cast(B)+C, -// where A & B are fp16/int8 local buffers, -// and C is fp32/int32 local buffer. 
-class MMAMatcher : public StmtVisitor { - public: - explicit MMAMatcher(Map extern_buffer) { - for (auto kv : extern_buffer) { - BufferInfo bi; - bi.name = kv.second->name; - bi.dtype = kv.second->dtype; - bi.external = true; - buf_map_[kv.first] = bi; - } - } - - void VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == tir::attr::pragma_tensor_core) { - tensor_core_on_ = true; - StmtVisitor::VisitStmt_(op); - } else if (op->attr_key == tir::attr::realize_scope) { - storage_scope_[op->node.get()] = op->value.as()->value; - this->VisitStmt(op->body); - } else { - StmtVisitor::VisitStmt_(op); - } - } - - void VisitStmt_(const ProducerStoreNode* op) final { - StmtVisitor::VisitStmt_(op); - auto it = buf_map_.find(Downcast(op->producer)); - if (it == buf_map_.end()) { - return; - } - const BufferInfo& bi = it->second; - if (bi.released) { - return; - } - if (tensor_core_on_ && mma_sync_match_(op, bi)) { - matched_ = true; - } - } - - void VisitStmt_(const ProducerRealizeNode* op) final { - auto key = Downcast(op->producer); - if (buf_map_.count(key)) { - if (!buf_map_.at(key).external) { - return; - } - this->VisitStmt(op->body); - } else { - BufferInfo bi; - bi.name = key->GetNameHint(); - bi.dtype = key->dtype; - buf_map_[key] = bi; - this->VisitStmt(op->body); - buf_map_[key].released = true; - } - } - - inline bool Matched() const { return matched_; } - - friend class ScheduleAnalyser; - friend class BufferAnalyser; - - private: - struct BufferInfo { - std::string name; - DataType dtype; - bool external{false}; - bool released{false}; - bool same_as(const BufferInfo& bi) { - if (this->dtype != bi.dtype) return false; - if (this->name != bi.name) return false; - if (this->external != bi.external) return false; - if (this->released != bi.released) return false; - return true; - } - }; - - // Check whether the storage scope is local - bool check_local_buffer_(const ProducerLoadNode* op, BufferInfo* bi) { - auto tensor = Downcast(op->producer); - auto it 
= storage_scope_.find(tensor.get()); - if (it == storage_scope_.end()) { - return false; - } - const std::string& strkey = it->second; - if (strkey != "local") { - return false; - } - auto it1 = buf_map_.find(tensor); - if (it1 == buf_map_.end()) { - return false; - } - *bi = it1->second; - if (bi->released) { - return false; - } - return true; - } - - // Do the pattern matching - bool mma_sync_match_(const ProducerStoreNode* op, BufferInfo store_buffer) { - auto* add = op->value.as(); - if (add == nullptr) { - return false; - } - - auto* load_c = add->a.as(); - BufferInfo buffer_c; - if (!check_local_buffer_(load_c, &buffer_c) || !buffer_c.same_as(store_buffer) || - !(buffer_c.dtype == DataType::Float(32) || buffer_c.dtype == DataType::Int(32))) { - return false; - } - - auto mul = unpack_type_cast(add->b, buffer_c.dtype).as(); - if (mul == nullptr) { - return false; - } - - auto load_a_expr = unpack_type_cast(mul->a, buffer_c.dtype); - auto load_a = load_a_expr.as(); - BufferInfo buffer_a; - if (!check_local_buffer_(load_a, &buffer_a) || - !(buffer_a.dtype == DataType::Float(16) || buffer_a.dtype == DataType::Int(8) || - buffer_a.dtype == DataType::UInt(8) || buffer_a.dtype == DataType::Int(4) || - buffer_a.dtype == DataType::UInt(4) || buffer_a.dtype == DataType::Int(1))) { - return false; - } - - auto load_b_expr = unpack_type_cast(mul->b, buffer_c.dtype); - auto load_b = load_b_expr.as(); - BufferInfo buffer_b; - if (!check_local_buffer_(load_b, &buffer_b) || - !(buffer_b.dtype == DataType::Float(16) || buffer_b.dtype == DataType::Int(8) || - buffer_b.dtype == DataType::UInt(8) || buffer_b.dtype == DataType::Int(4) || - buffer_a.dtype == DataType::UInt(4) || buffer_a.dtype == DataType::Int(1))) { - return false; - } - - frag_reg_.insert(buffer_c.name); - frag_reg_.insert(buffer_a.name); - frag_reg_.insert(buffer_b.name); - buf_name_.insert(std::make_pair(load_a, buffer_a.name)); - buf_name_.insert(std::make_pair(load_b, buffer_b.name)); - 
mma_sync_.insert(std::make_pair(op, Array{load_a_expr, load_b_expr, add->a})); - - return true; - } - - std::unordered_map buf_map_; - std::unordered_map storage_scope_; - std::unordered_map> mma_sync_; - std::unordered_map buf_name_; - std::unordered_set frag_reg_; - bool matched_{false}; - bool tensor_core_on_{false}; -}; - -// BodyVisitor visits the body stmt of original ComputeOp -// to get the access indices of input matrices, -// if it is recognized as matrix multiply. -class BodyVisitor : public StmtExprVisitor { - public: - BodyVisitor() {} - - void VisitExpr_(const ReduceNode* op) final { - auto* comm_add = op->combiner->result[0].as(); - if (comm_add == nullptr || op->combiner->result.size() > 1) { - return; - } - for (PrimExpr source : op->source) { - auto mul_0 = unpack_type_cast(source, DataType::Float(32)).as(); - auto mul_1 = unpack_type_cast(source, DataType::Int(32)).as(); - if (mul_0 == nullptr && mul_1 == nullptr) { - continue; - } - - tensorcore_candidate_ = true; - StmtExprVisitor::VisitExpr(source); - } - } - - void VisitExpr_(const ProducerLoadNode* op) final { - StmtExprVisitor::VisitExpr_(op); - args_.insert(std::make_pair(op->producer->GetNameHint(), op->indices)); - } - - friend class ScheduleAnalyser; - - private: - std::unordered_map> args_; - bool tensorcore_candidate_{false}; -}; - -// ScheduleAnalyser figures out matrix_a/matrix_b and row_major/col_major -class ScheduleAnalyser { - public: - explicit ScheduleAnalyser(const MMAMatcher& mma_matcher) - : mma_sync_(mma_matcher.mma_sync_), buf_name_(mma_matcher.buf_name_) {} - - bool MatrixIdentify(Schedule schedule) { - // TODO(minmin): handle the case where MatMul is not the output stage - for (Operation output : schedule->outputs) { - const ComputeOpNode* compute = output.as(); - if (compute == nullptr) { - // Not a ComputeOp - continue; - } - auto axis = compute->axis; - auto reduce_axis = compute->reduce_axis; - if (axis.size() < 2 || reduce_axis.size() != 1) { - continue; - } - 
const VarNode* axis_var[2]; - const VarNode* reduce_axis_var; - axis_var[0] = axis[axis.size() - 2]->var.as(); - axis_var[1] = axis[axis.size() - 1]->var.as(); - reduce_axis_var = reduce_axis[0]->var.as(); - - BodyVisitor body_visitor; - for (PrimExpr expr : compute->body) { - body_visitor(expr); - } - if (!body_visitor.tensorcore_candidate_) { - continue; - } - for (auto iter : body_visitor.args_) { - auto name = iter.first; - auto args = iter.second; - if (args.size() < 2) { - continue; - } - const VarNode* var0 = args[args.size() - 2].as(); - const VarNode* var1 = args[args.size() - 1].as(); - if (var0 == nullptr || var1 == nullptr) { - continue; - } - std::string matrix_abc, major; - if (var0 == reduce_axis_var && var1 == axis_var[1]) { - matrix_abc = "matrix_a"; - major = "col_major"; - } else if (var0 == reduce_axis_var && var1 == axis_var[0]) { - matrix_abc = "matrix_b"; - major = "row_major"; - } else if (var0 == axis_var[1] && var1 == reduce_axis_var) { - matrix_abc = "matrix_a"; - major = "row_major"; - } else if (var0 == axis_var[0] && var1 == reduce_axis_var) { - matrix_abc = "matrix_b"; - major = "col_major"; - } - matrix_abc_.insert(std::make_pair(name, matrix_abc)); - matrix_major_.insert(std::make_pair(name, major)); - } - matrix_abc_.insert(std::make_pair(compute->name, "accumulator")); - matrix_major_.insert(std::make_pair(compute->name, "col_major")); - } - - for (auto& mma_sync : mma_sync_) { - auto& operands = mma_sync.second; - auto* load_a = operands[0].as(); - auto* load_b = operands[1].as(); - auto input0 = simplify_name(buf_name_.find(load_a)->second); - auto input1 = simplify_name(buf_name_.find(load_b)->second); - auto it0 = matrix_abc_.find(input0); - auto it1 = matrix_abc_.find(input1); - - if (it0 == matrix_abc_.end() || it1 == matrix_abc_.end()) { - return false; - } - if (it0->second == "matrix_a" && it1->second == "matrix_b") { - return true; - } else if (it0->second == "matrix_b" && it1->second == "matrix_a") { - mma_sync.second = 
Array{operands[1], operands[0], operands[2]}; - } else { - return false; - } - } - return true; - } - - friend class BufferAnalyser; - friend class TensorCoreIRMutator; - - private: - std::unordered_map matrix_abc_; - std::unordered_map matrix_major_; - std::unordered_map> mma_sync_; - std::unordered_map buf_name_; -}; - -// IndexVisitor visits access index of fragment -// to record variable for loop scaling -class IndexVisitor : public StmtExprVisitor { - public: - IndexVisitor() {} - - void VisitExpr_(const VarNode* op) final { - loop_scaling_.insert(std::make_pair(op, scaling_factor_)); - } - - friend class BufferAnalyser; - friend class TensorCoreIRMutator; - - private: - std::unordered_map loop_scaling_; - unsigned scaling_factor_{0}; -}; - -// BufferAnalyser gets buffer info, -// e.g. thread tile and warp tile, for TensorCore CodeGen -class BufferAnalyser : public StmtExprVisitor { - public: - explicit BufferAnalyser(Map extern_buffer, - const ScheduleAnalyser& schedule_analyser, const MMAMatcher& mma_matcher) - : matrix_abc_(schedule_analyser.matrix_abc_), - matrix_major_(schedule_analyser.matrix_major_), - frag_reg_(mma_matcher.frag_reg_) { - for (auto kv : extern_buffer) { - BufferInfo bi; - bi.name = kv.second->name; - bi.dtype = kv.second->dtype; - bi.strides = kv.second->strides; - bi.shape = kv.second->shape; - bi.external = true; - buf_map_[kv.first] = bi; - } - } - - void VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == tir::attr::thread_extent) { - if (const IntImmNode* value = op->value.as()) { - thread_extent_.insert( - std::make_pair(op->node.as()->var->name_hint, value->value)); - } - StmtExprVisitor::VisitStmt_(op); - } else if (op->attr_key == tir::attr::realize_scope) { - storage_scope_[op->node.get()] = op->value.as()->value; - this->VisitStmt(op->body); - } else if (op->attr_key == tir::attr::buffer_dim_align) { - te::Tensor tensor = Downcast(op->node); - const CallNode* tuple = op->value.as(); - ICHECK(tuple && 
tuple->op.same_as(builtin::tvm_tuple())); - auto& vinfo = dim_align_[tensor]; - size_t dim = tuple->args[0].as()->value; - if (dim >= vinfo.size()) { - vinfo.resize(dim + 1); - } - vinfo[dim].align_factor = tuple->args[1].as()->value; - vinfo[dim].align_offset = tuple->args[2].as()->value; - this->VisitStmt(op->body); - } else { - StmtExprVisitor::VisitStmt_(op); - } - } - - void VisitStmt_(const ProducerStoreNode* op) final { - StmtExprVisitor::VisitStmt_(op); - auto key = Downcast(op->producer); - auto it = buf_map_.find(key); - ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << key->GetNameHint(); - const BufferInfo& bi = it->second; - ICHECK(!bi.released) << "Read a buffer that is already out of scope"; - - if (matrix_abc_.count(key->GetNameHint())) { - if (bi.shape.size() < 2) { - invalid_ = true; - return; - } - for (auto i = bi.shape.size() - 1; i + 2 >= bi.shape.size(); --i) { - const IntImmNode* shape = bi.shape[i].as(); - if (shape == nullptr || shape->value % 16 != 0) { - invalid_ = true; - return; - } - } - } - - Array strides; - if (bi.strides.size() > 0) { - strides = bi.strides; - } else { - for (size_t i = 1; i < bi.shape.size(); ++i) { - PrimExpr stride = IntImm(DataType::Int(32), 1); - for (size_t j = bi.shape.size() - 1; j >= i; --j) { - stride = Mul(stride, bi.shape[j]); - } - strides.push_back(stride); - } - strides.push_back(make_const(DataType::Int(32), 1)); - } - strides_.insert(std::make_pair(key->GetNameHint(), strides)); - - if (frag_reg_.count(bi.name)) { - PrimExpr dst = ProducerLoad(op->producer, op->indices); - frag_load_.insert(std::make_pair(op, dst)); - - auto rel_index = bi.RelIndex(op->indices); - if (op->indices.size() < 2) { - invalid_ = true; - return; - } - std::vector tile_size; - for (auto i = op->indices.size() - 1; i + 2 >= op->indices.size(); --i) { - index_visitor.scaling_factor_ = 16; - if (const IntImmNode* shape = bi.shape[i].as()) { - tile_size.push_back(shape->value); - 
index_visitor.scaling_factor_ = shape->value; - } else { - invalid_ = true; - return; - } - auto index = rel_index[i]; - auto simplified_index = analyzer_.Simplify(index); - index_visitor(simplified_index); - } - - std::string input_name = simplify_name(bi.name); - auto it = matrix_abc_.find(input_name); - auto it2 = matrix_major_.find(input_name); - bool ret = true; - if (it != matrix_abc_.end() && it2 != matrix_major_.end()) { - if (it->second == "matrix_a" && it2->second == "col_major") { - ret &= assign_or_check_(&thread_tile_.m, tile_size[0]); - ret &= assign_or_check_(&thread_tile_.k, tile_size[1]); - } - if (it->second == "matrix_a" && it2->second == "row_major") { - ret &= assign_or_check_(&thread_tile_.k, tile_size[0]); - ret &= assign_or_check_(&thread_tile_.m, tile_size[1]); - } - if (it->second == "matrix_b" && it2->second == "col_major") { - ret &= assign_or_check_(&thread_tile_.k, tile_size[0]); - ret &= assign_or_check_(&thread_tile_.n, tile_size[1]); - } - if (it->second == "matrix_b" && it2->second == "row_major") { - ret &= assign_or_check_(&thread_tile_.n, tile_size[0]); - ret &= assign_or_check_(&thread_tile_.k, tile_size[1]); - } - if (it->second == "accumulator") { - ret &= assign_or_check_(&thread_tile_.m, tile_size[0]); - ret &= assign_or_check_(&thread_tile_.n, tile_size[1]); - } - if (!ret) { - invalid_ = true; - return; - } - } - } - - const ProducerLoadNode* value = op->value.as(); - // TODO(tvm-team): string matching is dangerous, consider other means. 
- if (value != nullptr && frag_reg_.count(value->producer->GetNameHint())) { - PrimExpr dst = ProducerLoad(op->producer, op->indices); - frag_store_.insert(std::make_pair(op, dst)); - } - } - - void VisitExpr_(const ProducerLoadNode* op) final { - StmtExprVisitor::VisitExpr_(op); - - auto tensor = Downcast(op->producer); - auto it = buf_map_.find(tensor); - ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << tensor->GetNameHint(); - const BufferInfo& bi = it->second; - ICHECK(!bi.released) << "Read a buffer that is already out of scope"; - - if (matrix_abc_.count(tensor->op->name)) { - if (bi.shape.size() < 2) { - invalid_ = true; - return; - } - for (auto i = bi.shape.size() - 1; i + 2 >= bi.shape.size(); --i) { - const IntImmNode* shape = bi.shape[i].as(); - if (shape == nullptr || shape->value % 16 != 0) { - invalid_ = true; - return; - } - } - } - - Array strides; - if (bi.strides.size() > 0) { - strides = bi.strides; - } else { - for (size_t i = 1; i < bi.shape.size(); ++i) { - PrimExpr stride = IntImm(DataType::Int(32), 1); - for (size_t j = bi.shape.size() - 1; j >= i; --j) { - stride = Mul(stride, bi.shape[j]); - } - strides.push_back(stride); - } - strides.push_back(make_const(DataType::Int(32), 1)); - } - strides_.insert(std::make_pair(tensor->GetNameHint(), strides)); - - if (!frag_reg_.count(bi.name)) { - return; - } - - auto rel_index = bi.RelIndex(op->indices); - if (op->indices.size() < 2) { - invalid_ = true; - return; - } - for (auto i = op->indices.size() - 1; i + 2 >= op->indices.size(); --i) { - index_visitor.scaling_factor_ = 16; - if (const IntImmNode* shape = bi.shape[i].as()) { - index_visitor.scaling_factor_ = shape->value; - } - auto index = rel_index[i]; - auto simplified_index = analyzer_.Simplify(index); - index_visitor(simplified_index); - } - } - - void VisitStmt_(const ProducerRealizeNode* op) final { - auto key = Downcast(op->producer); - if (buf_map_.count(key)) { - ICHECK(buf_map_.at(key).external); - 
this->VisitStmt(op->body); - } else { - // create a buffer entry - BufferInfo bi; - - bi.bounds = op->bounds; - Array shape; - for (auto r : bi.bounds) { - shape.push_back(r->extent); - } - - Array strides; - if (dim_align_.count(key) != 0 && shape.size() != 0) { - std::vector rstrides; - const std::vector& avec = dim_align_[key]; - int first_dim = 0; - PrimExpr stride = make_const(shape[first_dim].dtype(), 1); - for (size_t i = shape.size(); i != 0; --i) { - size_t dim = i - 1; - if (dim < avec.size() && avec[dim].align_factor != 0) { - PrimExpr factor = make_const(stride.dtype(), avec[dim].align_factor); - PrimExpr offset = make_const(stride.dtype(), avec[dim].align_offset); - stride = stride + indexmod(factor + offset - indexmod(stride, factor), factor); - stride = analyzer_.Simplify(stride); - } - rstrides.push_back(stride); - stride = stride * shape[dim]; - } - strides = Array(rstrides.rbegin(), rstrides.rend()); - } - - bi.name = key->GetNameHint(); - bi.dtype = key->dtype; - bi.strides = strides; - bi.shape = shape; - - buf_map_[key] = bi; - this->VisitStmt(op->body); - buf_map_[key].released = true; - } - } - - // Derive warp tile from thread tile, - // and check whether it is qualified for TensorCore. 
- bool QualifiedForTensorCore() { - if (invalid_) { - return false; - } - auto itx = thread_extent_.find("threadIdx.x"); - if (itx == thread_extent_.end()) { - return false; - } - int warp_threads_x = itx->second; - warp_tile_.m = warp_threads_x * thread_tile_.m; - warp_threads_y_ = 32 / warp_threads_x; - auto ity = thread_extent_.find("threadIdx.y"); - if (ity == thread_extent_.end()) { - return false; - } - if (ity->second < warp_threads_y_ || ity->second % warp_threads_y_ != 0) { - return false; - } - warp_tile_.n = warp_threads_y_ * thread_tile_.n; - warp_tile_.k = thread_tile_.k; - return supported_warp_tile_(); - } - - friend class TensorCoreIRMutator; - - private: - struct DimAlignInfo { - int align_factor{0}; - int align_offset{0}; - }; - - struct BufferInfo { - std::string name; - DataType dtype; - Array strides; - Array shape; - Region bounds; - bool external{false}; - bool released{false}; - inline Array RelIndex(Array args) const { - if (bounds.size() != 0) { - Array index; - ICHECK_EQ(bounds.size(), args.size()); - for (size_t i = 0; i < bounds.size(); ++i) { - index.push_back(args[i] - bounds[i]->min); - } - return index; - } else { - return args; - } - } - }; - - bool assign_or_check_(int* dst, int src) { - if (*dst <= 0) { - *dst = src; - return true; - } - if (*dst == src) { - return true; - } - return false; - } - - bool supported_warp_tile_() { - if (warp_tile_.m == 16 && warp_tile_.n == 16 && warp_tile_.k == 16) { - return true; - } - if (warp_tile_.m == 8 && warp_tile_.n == 32 && warp_tile_.k == 16) { - return true; - } - if (warp_tile_.m == 32 && warp_tile_.n == 8 && warp_tile_.k == 16) { - return true; - } - if (warp_tile_.m == 8 && warp_tile_.n == 8 && warp_tile_.k == 32) { - return true; - } - if (warp_tile_.m == 8 && warp_tile_.n == 8 && warp_tile_.k == 128) { - return true; - } - - return false; - } - - std::unordered_map buf_map_; - std::unordered_map> dim_align_; - std::unordered_map storage_scope_; - std::unordered_map matrix_abc_; - 
std::unordered_map matrix_major_; - std::unordered_set frag_reg_; - std::unordered_map> strides_; - std::unordered_map frag_load_; - std::unordered_map frag_store_; - std::unordered_map thread_extent_; - IndexVisitor index_visitor; - Tile warp_tile_; - Tile thread_tile_; - arith::Analyzer analyzer_; - int warp_threads_y_{-1}; - bool invalid_{false}; -}; - -// ThreadIdxMutator does the thread index unification inside a warp -class ThreadIdxMutator : public StmtExprMutator { - public: - explicit ThreadIdxMutator(PrimExpr warp_y) : warp_y_(warp_y) {} - - PrimExpr VisitExpr_(const VarNode* op) final { - PrimExpr expr = StmtExprMutator::VisitExpr_(op); - op = expr.as(); - if (op != nullptr) { - if (op->name_hint == "threadIdx.x") { - PrimExpr zero = IntImm(DataType::Int(32), 0); - return zero; - } - if (op->name_hint == "threadIdx.y") { - PrimExpr div = Div(expr, warp_y_); - PrimExpr mul = Mul(div, warp_y_); - return mul; - } - } - return expr; - } - - private: - PrimExpr warp_y_; -}; - -// TensorCoreIRMutator mutates the AST for TensorCore CodeGen -// based on tensor core intrinsics -class TensorCoreIRMutator : public StmtExprMutator { - public: - explicit TensorCoreIRMutator(const ScheduleAnalyser& schedule_analyser, - const BufferAnalyser& buffer_analyser) - : matrix_abc_(schedule_analyser.matrix_abc_), - matrix_major_(schedule_analyser.matrix_major_), - mma_sync_(schedule_analyser.mma_sync_), - strides_(buffer_analyser.strides_), - frag_reg_(buffer_analyser.frag_reg_), - loop_scaling_(buffer_analyser.index_visitor.loop_scaling_), - frag_load_(buffer_analyser.frag_load_), - frag_store_(buffer_analyser.frag_store_), - warp_tile_(buffer_analyser.warp_tile_), - warp_threads_y_(buffer_analyser.warp_threads_y_) {} - - Stmt VisitStmt_(const ProducerRealizeNode* op) final { - auto key = Downcast(op->producer); - bounds_[key] = op->bounds; - Stmt stmt = StmtExprMutator::VisitStmt_(op); - op = stmt.as(); - if (op != nullptr) { - if (!frag_reg_.count(key->GetNameHint())) { - 
return stmt; - } - - auto new_extents = get_tile_size_(simplify_name(key->GetNameHint())); - - Region new_bounds; - for (size_t i = 0; i < op->bounds.size() - 2; ++i) { - new_bounds.push_back(op->bounds[i]); - } - ICHECK_GE(op->bounds.size(), 2) << "Less than 2 dimensions for matrix " << key->GetNameHint(); - new_bounds.push_back( - Range::FromMinExtent(op->bounds[op->bounds.size() - 2]->min, new_extents[0])); - new_bounds.push_back( - Range::FromMinExtent(op->bounds[op->bounds.size() - 1]->min, new_extents[1])); - - return ProducerRealize(op->producer, new_bounds, op->condition, op->body); - } - return stmt; - } - - Stmt VisitStmt_(const AttrStmtNode* op) final { - Stmt stmt = StmtExprMutator::VisitStmt_(op); - if (op->attr_key == tir::attr::realize_scope) { - auto node = op->node.as(); - if (node != nullptr) { - if (!frag_reg_.count(node->name)) { - return stmt; - } - - auto it = matrix_abc_.find(simplify_name(node->name)); - ICHECK(it != matrix_abc_.end()) << "Cannot find matrix info for " << node->name; - auto matrix_abc = tvm::tir::StringImm("wmma." 
+ it->second); - Stmt body = this->VisitStmt(op->body); - return AttrStmt(op->node, op->attr_key, matrix_abc, body); - } - } - return stmt; - } - - Stmt VisitStmt_(const ProducerStoreNode* op) final { - Stmt stmt = StmtExprMutator::VisitStmt_(op); - auto it = mma_sync_.find(op); - if (it != mma_sync_.end()) { - const auto& operands = it->second; - PrimExpr a = operands[0]; - auto ca = a.as(); - PrimExpr b = operands[1]; - auto cb = b.as(); - PrimExpr c = operands[2]; - auto cc = c.as(); - - ObjectPtr buffer_node_a = make_object(); - ObjectPtr buffer_node_b = make_object(); - ObjectPtr buffer_node_c = make_object(); - - auto mma_sync_call = [&buffer_node_a, &buffer_node_b, &ca, &cb](const Buffer& buffer) { - Buffer buffer_a(buffer_node_a); - Buffer buffer_b(buffer_node_b); - if (ca->dtype == DataType::Int(1) && cb->dtype == DataType::Int(1)) { - return Evaluate( - Call(DataType::Handle(), builtin::tvm_bmma_sync(), - {buffer->data, buffer->elem_offset, buffer_a->data, buffer_a->elem_offset, - buffer_b->data, buffer_b->elem_offset, buffer->data, buffer->elem_offset})); - } else { - return Evaluate( - Call(DataType::Handle(), builtin::tvm_mma_sync(), - {buffer->data, buffer->elem_offset, buffer_a->data, buffer_a->elem_offset, - buffer_b->data, buffer_b->elem_offset, buffer->data, buffer->elem_offset})); - } - }; - - auto call_add_c = [this, &cc, &buffer_node_c, &mma_sync_call](const Buffer& buffer) { - return add_buffer_bind_scope_(cc, buffer_node_c, mma_sync_call); - }; - - auto call_add_b = [this, &cb, &buffer_node_b, &call_add_c](const Buffer& buffer) { - return add_buffer_bind_scope_(cb, buffer_node_b, call_add_c); - }; - - return add_buffer_bind_scope_(ca, buffer_node_a, call_add_b); - } - - auto it2 = frag_load_.find(op); - if (it2 != frag_load_.end()) { - PrimExpr dst = it2->second; - if (op->value.as() != nullptr || op->value.as() != nullptr) { - auto pload = dst.as(); - - auto fill_fragment_call = [this, &op](const Buffer& buffer) { - return 
Evaluate(Call(DataType::Handle(), builtin::tvm_fill_fragment(), - {buffer->data, warp_tile_.m, warp_tile_.n, warp_tile_.k, - buffer->elem_offset, op->value})); - }; - - ObjectPtr buffer_node = make_object(); - return add_buffer_bind_scope_(pload, buffer_node, fill_fragment_call); - } - - const ProducerLoadNode* value = op->value.as(); - ICHECK(value != nullptr) << "Can only load fragment from a buffer"; - - auto it = strides_.find(value->producer->GetNameHint()); - ICHECK(it != strides_.end()) << "Cannot find stride for " << value->producer->GetNameHint(); - auto strides = it->second; - ICHECK_GE(strides.size(), 2); - PrimExpr stride = strides[strides.size() - 2]; - - // thread index unification inside a warp - PrimExpr warp_y = IntImm(DataType::Int(32), warp_threads_y_); - ThreadIdxMutator thread_idx_mutator(warp_y); - PrimExpr mutated_value = thread_idx_mutator(op->value); - // TODO(tvm-team) The extern function name seems to be a hack. - PrimExpr src = Call(value->dtype, builtin::call_extern(), {StringImm("&"), mutated_value}); - - auto pload = dst.as(); - PrimExpr matrix_major; - auto iter2 = matrix_major_.find(simplify_name(pload->producer->GetNameHint())); - ICHECK(iter2 != matrix_major_.end()) - << "Can not determine matrix major for " << pload->producer->GetNameHint(); - if (iter2->second == "col_major") { - matrix_major = StringImm("col_major"); - } else if (iter2->second == "row_major") { - matrix_major = StringImm("row_major"); - } else { - LOG(FATAL) << "invalid matrix major for " << pload->producer->GetNameHint(); - } - - auto load_matrix_call = [this, &src, &stride, &matrix_major](const Buffer& buffer) { - return Evaluate(Call(DataType::Handle(), builtin::tvm_load_matrix_sync(), - {buffer->data, warp_tile_.m, warp_tile_.n, warp_tile_.k, - buffer->elem_offset, src, stride, matrix_major})); - }; - - ObjectPtr buffer_node = make_object(); - return add_buffer_bind_scope_(pload, buffer_node, load_matrix_call); - } - - auto it3 = frag_store_.find(op); - if 
(it3 != frag_store_.end()) { - auto it = strides_.find(op->producer->GetNameHint()); - ICHECK(it != strides_.end()) << "Cannot find stride for " << op->producer->GetNameHint(); - auto strides = it->second; - ICHECK_GE(strides.size(), 2); - PrimExpr stride = strides[strides.size() - 2]; - - PrimExpr dst = it3->second; - // thread index unification inside a warp - PrimExpr warp_y = IntImm(DataType::Int(32), warp_threads_y_); - ThreadIdxMutator thread_idx_mutator(warp_y); - dst = thread_idx_mutator(dst); - dst = Call(DataType::Handle(), builtin::call_extern(), {StringImm("&"), dst}); - - auto pload = op->value.as(); - - auto store_matrix_call = [this, &dst, &stride](const Buffer& buffer) { - return Evaluate(Call(DataType::Handle(), builtin::tvm_store_matrix_sync(), - {buffer->data, warp_tile_.m, warp_tile_.n, warp_tile_.k, - buffer->elem_offset, dst, stride, StringImm("col_major")})); - }; - - ObjectPtr buffer_node = make_object(); - return add_buffer_bind_scope_(pload, buffer_node, store_matrix_call); - } - - return stmt; - } - - Stmt VisitStmt_(const ForNode* op) final { - Stmt stmt = StmtExprMutator::VisitStmt_(op); - op = stmt.as(); - if (op != nullptr) { - auto it = loop_scaling_.find(op->loop_var.get()); - if (it != loop_scaling_.end()) { - int scale_factor = it->second; - int scaled_extent_value = 1; - if (const IntImmNode* ori_extent = op->extent.as()) { - int ori_extent_value = ori_extent->value; - scaled_extent_value = ori_extent_value / scale_factor; - } - PrimExpr scaled_extent = make_const(op->extent.dtype(), scaled_extent_value); - stmt = For(op->loop_var, op->min, scaled_extent, op->kind, op->body, op->thread_binding, - op->annotations); - } - } - return stmt; - } - - private: - Array get_tile_size_(const std::string& name) { - auto it = matrix_abc_.find(name); - auto it2 = matrix_major_.find(name); - ICHECK(it != matrix_abc_.end() && it2 != matrix_major_.end()) - << "Cannot find matrix info for " << name; - PrimExpr size0 = 
make_const(DataType::Int(32), 16); - PrimExpr size1 = make_const(DataType::Int(32), 16); - if (it->second == "matrix_a" && it2->second == "col_major") { - size0 = make_const(DataType::Int(32), warp_tile_.k); - size1 = make_const(DataType::Int(32), warp_tile_.m); - } - if (it->second == "matrix_a" && it2->second == "row_major") { - size0 = make_const(DataType::Int(32), warp_tile_.m); - size1 = make_const(DataType::Int(32), warp_tile_.k); - } - if (it->second == "matrix_b" && it2->second == "row_major") { - size0 = make_const(DataType::Int(32), warp_tile_.k); - size1 = make_const(DataType::Int(32), warp_tile_.n); - } - if (it->second == "matrix_b" && it2->second == "col_major") { - size0 = make_const(DataType::Int(32), warp_tile_.n); - size1 = make_const(DataType::Int(32), warp_tile_.k); - } - if (it->second == "matrix_c") { - size0 = make_const(DataType::Int(32), warp_tile_.n); - size1 = make_const(DataType::Int(32), warp_tile_.m); - } - Array tile_size = {size0, size1}; - return tile_size; - } - - Stmt add_buffer_bind_scope_(const ProducerLoadNode* pload, - const ObjectPtr& buffer_node, - const std::function& call_back) { - auto tensor = Downcast(pload->producer); - auto it = bounds_.find(tensor); - ICHECK(it != bounds_.end()); - Array min_bound; - for (auto i : it->second) { - min_bound.push_back(i->min); - } - - ICHECK_GE(it->second.size(), 2); - Array shape; - for (size_t i = 0; i < it->second.size() - 2; ++i) { - shape.push_back(it->second[i]->extent); - } - auto tile_size = get_tile_size_(simplify_name(tensor->op->name)); - shape.push_back(tile_size[0]); - shape.push_back(tile_size[1]); - - Array strides; - for (size_t i = 1; i < shape.size(); ++i) { - PrimExpr stride = IntImm(DataType::Int(32), 1); - for (size_t j = shape.size() - 1; j >= i; --j) { - stride = Mul(stride, shape[j]); - } - strides.push_back(stride); - } - strides.push_back(make_const(DataType::Int(32), 1)); - - PrimExpr elem_offset = IntImm(DataType::Int(32), 0); - 
ICHECK_EQ(pload->indices.size(), min_bound.size()); - for (size_t i = 0; i < min_bound.size(); i++) { - elem_offset = Add(elem_offset, Mul(strides[i], Sub(pload->indices[i], min_bound[i]))); - } - - auto it2 = matrix_abc_.find(simplify_name(tensor->op->name)); - ICHECK(it2 != matrix_abc_.end()) << "Cannot find matrix info for " << tensor->op->name; - buffer_node->data = Var(tensor->op->name, DataType::Handle()); - buffer_node->name = tensor->op->name; - buffer_node->scope = "wmma." + it2->second; - buffer_node->dtype = tensor->dtype; - buffer_node->strides = strides; - buffer_node->shape = shape; - buffer_node->data_alignment = 1; - buffer_node->elem_offset = analyzer_.Simplify(elem_offset); - buffer_node->offset_factor = 1; - Buffer buffer(buffer_node); - - Array args; - for (size_t i = 0; i < pload->indices.size(); ++i) { - args.push_back(pload->indices[i]); - args.push_back(shape[i]); - } - auto tuple = Call(DataType::Handle(), builtin::tvm_tuple(), args); - Array node = {buffer, tensor}; - return AttrStmt(node, "buffer_bind_scope", tuple, call_back(buffer)); - } - - std::unordered_map matrix_abc_; - std::unordered_map matrix_major_; - std::unordered_map> mma_sync_; - std::unordered_map> strides_; - std::unordered_set frag_reg_; - std::unordered_map loop_scaling_; - std::unordered_map frag_load_; - std::unordered_map frag_store_; - std::unordered_map bounds_; - arith::Analyzer analyzer_; - Tile warp_tile_; - int warp_threads_y_{-1}; -}; - -Stmt SchedulePostProcRewriteForTensorCore(Stmt stmt, Schedule schedule, - Map extern_buffer) { - // Check if current lower target is CUDA - auto target = tvm::Target::Current(true); - if (target.defined() && target->kind->name != "cuda") { - return stmt; - } - - // Check if current runtime support GPU CUDA - Device dev{kDLCUDA, 0}; - auto api = tvm::runtime::DeviceAPI::Get(dev, true); - if (api == nullptr) { - return stmt; - } - - MMAMatcher mma_matcher(extern_buffer); - mma_matcher(stmt); - if (!mma_matcher.Matched()) { - 
return stmt; - } - - ScheduleAnalyser schedule_analyser(mma_matcher); - if (!schedule_analyser.MatrixIdentify(schedule)) { - return stmt; - } - - BufferAnalyser buffer_analyser(extern_buffer, schedule_analyser, mma_matcher); - buffer_analyser(stmt); - if (!buffer_analyser.QualifiedForTensorCore()) { - return stmt; - } - - return TensorCoreIRMutator(schedule_analyser, buffer_analyser)(std::move(stmt)); -} - -TVM_REGISTER_GLOBAL("schedule.SchedulePostProcRewriteForTensorCore") - .set_body_typed([](Stmt stmt, Schedule schedule, Map extern_buffer) { - return SchedulePostProcRewriteForTensorCore(stmt, schedule, extern_buffer); - }); - -} // namespace te -} // namespace tvm diff --git a/tests/cpp/build_module_test.cc b/tests/cpp/build_module_test.cc index 8cc5c4bc0a3a..204a824f9248 100644 --- a/tests/cpp/build_module_test.cc +++ b/tests/cpp/build_module_test.cc @@ -52,7 +52,7 @@ TEST(BuildModule, Basic) { auto target = Target("llvm"); - auto lowered = lower(s, args, "func", binds); + auto lowered = LowerSchedule(s, args, "func", binds); auto module = build(lowered, target, Target()); auto mali_target = Target("opencl -model=Mali-T860MP4@800Mhz -device=mali"); @@ -116,8 +116,8 @@ TEST(BuildModule, Heterogeneous) { auto args2 = Array({copy, C, elemwise_sub}); std::unordered_map binds; - auto lowered_s1 = lower(s1, args1, "elemwise_add", binds); - auto lowered_s2 = lower(s2, args2, "elemwise_sub", binds); + auto lowered_s1 = LowerSchedule(s1, args1, "elemwise_add", binds); + auto lowered_s2 = LowerSchedule(s2, args2, "elemwise_sub", binds); Map inputs = {{target_cuda, lowered_s1}, {target_llvm, lowered_s2}}; auto module = build(inputs, Target()); diff --git a/tests/python/frontend/tensorflow2/test_functional_models.py b/tests/python/frontend/tensorflow2/test_functional_models.py index 40d42a28025a..fe31a753746c 100644 --- a/tests/python/frontend/tensorflow2/test_functional_models.py +++ b/tests/python/frontend/tensorflow2/test_functional_models.py @@ -63,7 +63,7 @@ def 
run_model_graph(TestClass): def test_add_one(): class AddOne(tf.Module): - """ simple function to test x=x+1; scalar as input""" + """simple function to test x=x+1; scalar as input""" def get_input(self): return np.array(1.0, dtype="float32") diff --git a/tests/python/unittest/test_tir_transform_bf16_legalize.py b/tests/python/unittest/test_tir_transform_bf16_legalize.py index c5163a8457af..1e3c8061e029 100644 --- a/tests/python/unittest/test_tir_transform_bf16_legalize.py +++ b/tests/python/unittest/test_tir_transform_bf16_legalize.py @@ -20,7 +20,7 @@ def lower_stmt(sche, params, passfunc): - func = tvm.driver.build_module.form_irmodule(sche, params, "main", None)["main"] + func = tvm.driver.build_module.schedule_to_module(sche, params, "main", None)["main"] func = passfunc()(tvm.IRModule.from_expr(func))["main"] stmt = func.body return stmt @@ -42,7 +42,7 @@ def get_promoted(op): lambda i: topi.cast(op(topi.cast(a[i], "float"), topi.cast(b[i], "float")), "bfloat16"), ) s = te.create_schedule(c.op) - func = tvm.driver.build_module.form_irmodule(s, [a, b, c], "main", None)["main"] + func = tvm.driver.build_module.schedule_to_module(s, [a, b, c], "main", None)["main"] return func.body def test_promoted(op): @@ -111,7 +111,7 @@ def get_target(): ), ) s = te.create_schedule(c.op) - func = tvm.driver.build_module.form_irmodule(s, [a, b, c], "main", None)["main"] + func = tvm.driver.build_module.schedule_to_module(s, [a, b, c], "main", None)["main"] return func.body tvm.ir.assert_structural_equal(get_eliminated(), get_target()) @@ -151,7 +151,7 @@ def check(fcompute_before, fcompute_after): b = te.placeholder((100,), dtype="uint16", name="B") c = te.compute((100,), fcompute_after(a, b), name="C") s = te.create_schedule(c.op) - func = tvm.driver.build_module.form_irmodule(s, [a, b, c], "main", None)["main"] + func = tvm.driver.build_module.schedule_to_module(s, [a, b, c], "main", None)["main"] tvm.ir.assert_structural_equal(stmt, func.body) def orig1(a, b): diff --git 
a/tests/python/unittest/test_tir_transform_hoist_if.py b/tests/python/unittest/test_tir_transform_hoist_if.py index 7d02e4f12c1d..252a187dbdc5 100644 --- a/tests/python/unittest/test_tir_transform_hoist_if.py +++ b/tests/python/unittest/test_tir_transform_hoist_if.py @@ -522,7 +522,7 @@ def test_hoisting_block_scope_1(): s[B.op].bind(xi, te.thread_axis("threadIdx.y")) s[B].bind(s[B].op.reduce_axis[0], te.thread_axis("threadIdx.x")) s[BF].compute_at(s[B], s[B].op.reduce_axis[0]) - func = tvm.driver.build_module.form_irmodule(s, [A, B], "main", None)["main"] + func = tvm.driver.build_module.schedule_to_module(s, [A, B], "main", None)["main"] stmt = func.body new_stmt = tvm.tir.transform.HoistIfThenElse()(tvm.IRModule.from_expr(func))["main"].body tvm.ir.assert_structural_equal(new_stmt, stmt) @@ -622,7 +622,7 @@ def test_hoisting_block_scope_4(): s[C].pragma(xo2, "parallel_stride_pattern") s[C].pragma(xo2, "parallel_barrier_when_finish") s[C].vectorize(xi) - func = tvm.driver.build_module.form_irmodule(s, [A, B, C], "main", None)["main"] + func = tvm.driver.build_module.schedule_to_module(s, [A, B, C], "main", None)["main"] stmt = func.body new_stmt = tvm.tir.transform.HoistIfThenElse()(tvm.IRModule.from_expr(func))["main"].body tvm.ir.assert_structural_equal(new_stmt, stmt)