From ae28f582c0838af9393d5fa1bd07c3c3d7840e27 Mon Sep 17 00:00:00 2001 From: Jocelyn Date: Fri, 7 May 2021 18:46:47 -0400 Subject: [PATCH 01/52] attempt of c->python --- python/tvm/driver/_ffi_api.py | 3 + python/tvm/driver/build_module.py | 119 ++++++++++++++++-------------- src/driver/driver_api.cc | 6 ++ 3 files changed, 71 insertions(+), 57 deletions(-) create mode 100644 python/tvm/driver/_ffi_api.py diff --git a/python/tvm/driver/_ffi_api.py b/python/tvm/driver/_ffi_api.py new file mode 100644 index 000000000000..9a1a7f414ab1 --- /dev/null +++ b/python/tvm/driver/_ffi_api.py @@ -0,0 +1,3 @@ +import tvm._ffi + +tvm._ffi._init_api("lower",__name__) \ No newline at end of file diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index 4682e344461d..7913f402f4d2 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -31,6 +31,8 @@ from tvm.te import schedule from tvm.target import Target +from . import _ffi_api + def get_binds(args, compact=False, binds=None): """Internal function to get binds and arg_list given arguments. @@ -120,6 +122,9 @@ def form_irmodule(sch, args, name, binds): def lower(sch, args, name="main", binds=None, simple_mode=False): + + mod = self.__init_handler_by_constructor__(_ffi_api.lower, sch, args, name, binds) + """Lowering step before build into target. Parameters @@ -149,62 +154,62 @@ def lower(sch, args, name="main", binds=None, simple_mode=False): Then the Stmt before make api is returned. """ # config setup - pass_ctx = PassContext.current() - instrument_bound_checkers = bool(pass_ctx.config.get("tir.instrument_bound_checkers", False)) - disable_vectorize = bool(pass_ctx.config.get("tir.disable_vectorize", False)) - add_lower_pass = pass_ctx.config.get("tir.add_lower_pass", []) - - lower_phase0 = [x[1] for x in add_lower_pass if x[0] == 0] - lower_phase1 = [x[1] for x in add_lower_pass if x[0] == 1] - lower_phase2 = [x[1] for x in add_lower_pass if x[0] == 2] - lower_phase3 = [x[1] for x in add_lower_pass if x[0] > 2] - - # Phase 0 - if isinstance(sch, schedule.Schedule): - mod = form_irmodule(sch, args, name, binds) - else: - mod = sch - - pass_list = lower_phase0 - # Phase 1 - pass_list += [ - tvm.tir.transform.InjectPrefetch(), - tvm.tir.transform.StorageFlatten(64, instrument_bound_checkers), - tvm.tir.transform.BF16Legalize(), - tvm.tir.transform.NarrowDataType(32), - tvm.tir.transform.Simplify(), - ] - pass_list += lower_phase1 - - # Phase 2 - if not simple_mode: - pass_list += [(tvm.tir.transform.LoopPartition())] - - pass_list += [ - tvm.tir.transform.VectorizeLoop(not disable_vectorize), - tvm.tir.transform.InjectVirtualThread(), - tvm.tir.transform.InjectDoubleBuffer(), - tvm.tir.transform.StorageRewrite(), - tvm.tir.transform.UnrollLoop(), - ] - pass_list += lower_phase2 - - # Phase 3 - pass_list += [ - tvm.tir.transform.Simplify(), - tvm.tir.transform.RemoveNoOp(), - ] - - pass_list += [tvm.tir.transform.RewriteUnsafeSelect()] - pass_list += [tvm.tir.transform.HoistIfThenElse()] - pass_list += lower_phase3 - - # Instrument BoundCheckers - if instrument_bound_checkers: - pass_list += [tvm.tir.transform.InstrumentBoundCheckers()] - - optimize = tvm.transform.Sequential(pass_list) - mod = optimize(mod) + # pass_ctx = PassContext.current() + # instrument_bound_checkers = bool(pass_ctx.config.get("tir.instrument_bound_checkers", False)) + # disable_vectorize = bool(pass_ctx.config.get("tir.disable_vectorize", False)) + # add_lower_pass = pass_ctx.config.get("tir.add_lower_pass", []) + + # lower_phase0 = [x[1] for x in add_lower_pass if x[0] == 0] + # lower_phase1 = [x[1] for x in add_lower_pass if x[0] == 1] + # lower_phase2 = [x[1] for x in add_lower_pass if x[0] == 2] + # lower_phase3 = [x[1] for x in add_lower_pass if x[0] > 2] + + # # Phase 0 + # if isinstance(sch, schedule.Schedule): + # mod = form_irmodule(sch, args, name, binds) + # else: + # mod = sch + + # pass_list = lower_phase0 + # # Phase 1 + # pass_list += [ + # tvm.tir.transform.InjectPrefetch(), + # tvm.tir.transform.StorageFlatten(64, instrument_bound_checkers), + # tvm.tir.transform.BF16Legalize(), + # tvm.tir.transform.NarrowDataType(32), + # tvm.tir.transform.Simplify(), + # ] + # pass_list += lower_phase1 + + # # Phase 2 + # if not simple_mode: + # pass_list += [(tvm.tir.transform.LoopPartition())] + + # pass_list += [ + # tvm.tir.transform.VectorizeLoop(not disable_vectorize), + # tvm.tir.transform.InjectVirtualThread(), + # tvm.tir.transform.InjectDoubleBuffer(), + # tvm.tir.transform.StorageRewrite(), + # tvm.tir.transform.UnrollLoop(), + # ] + # pass_list += lower_phase2 + + # # Phase 3 + # pass_list += [ + # tvm.tir.transform.Simplify(), + # tvm.tir.transform.RemoveNoOp(), + # ] + + # pass_list += [tvm.tir.transform.RewriteUnsafeSelect()] + # pass_list += [tvm.tir.transform.HoistIfThenElse()] + # pass_list += lower_phase3 + + # # Instrument BoundCheckers + # if instrument_bound_checkers: + # pass_list += [tvm.tir.transform.InstrumentBoundCheckers()] + + # optimize = tvm.transform.Sequential(pass_list) + # mod = optimize(mod) return mod @@ -436,7 +441,7 @@ def build(inputs, args=None, target=None, target_host=None, name="default_functi target_host = Target(target_host) if ( target_host.attrs.get("runtime", tvm.runtime.String("c++")) == "c" - and target_host.attrs.get("system-lib", 0) == 1 + and target_host.attrs.get("system-lib", 0).value == 1 ): if target_host.kind.name == "c": create_csource_crt_metadata_module = tvm._ffi.get_global_func( diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index f30cecbf7f05..0d58b43b5b28 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -185,6 +185,11 @@ IRModule lower(te::Schedule sch, const Array& args, const std::strin return mod; } +TVM_REGISTER_GLOBAL("lower").set_body_typed([](te::Schedule sch, const Array& args, const std::string& name, + const std::unordered_map& binds) { + return lower(sch, args, name, binds); +}); + std::pair SplitDevHostFuncs(IRModule mod_mixed, const Target& target_arg, const Target& target_host_arg, const transform::PassContext& pass_ctx) { @@ -339,3 +344,4 @@ runtime::Module build(const IRModule& funcs, const Target& target_arg, } } // namespace tvm + From 125822388cf944da9c86e699e07e7df402650098 Mon Sep 17 00:00:00 2001 From: Jocelyn Date: Mon, 10 May 2021 16:35:19 -0400 Subject: [PATCH 02/52] name change --- python/tvm/driver/_ffi_api.py | 2 +- src/driver/driver_api.cc | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/driver/_ffi_api.py b/python/tvm/driver/_ffi_api.py index 9a1a7f414ab1..8371a51190b1 100644 --- a/python/tvm/driver/_ffi_api.py +++ b/python/tvm/driver/_ffi_api.py @@ -1,3 +1,3 @@ import tvm._ffi -tvm._ffi._init_api("lower",__name__) \ No newline at end of file +tvm._ffi._init_api("tvm.driver",__name__) \ No newline at end of file diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 0d58b43b5b28..0d04398645fa 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -185,7 +185,7 @@ IRModule lower(te::Schedule sch, const Array& args, const std::strin return mod; } -TVM_REGISTER_GLOBAL("lower").set_body_typed([](te::Schedule sch, const Array& args, const std::string& name, +TVM_REGISTER_GLOBAL("tvm.driver.lower").set_body_typed([](te::Schedule sch, const Array& args, const std::string& name, const std::unordered_map& binds) { return lower(sch, args, name, binds); }); From f57ae0e2bc21b7a2f702bf22db8071f1f2577ac8 Mon Sep 17 00:00:00 2001 From: Jocelyn Date: Wed, 12 May 2021 17:44:26 -0400 Subject: [PATCH 03/52] build update --- python/tvm/driver/build_module.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index 7913f402f4d2..b2532ef2c0c0 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -123,7 +123,7 @@ def form_irmodule(sch, args, name, binds): def lower(sch, args, name="main", binds=None, simple_mode=False): - mod = self.__init_handler_by_constructor__(_ffi_api.lower, sch, args, name, binds) + mod = self.__init_handler_by_constructor__(_ffi_api.tvm.driver.lower, sch, args, name, binds) """Lowering step before build into target. From 8a4f40e19ad985865ce297f2197e345b0b464bb4 Mon Sep 17 00:00:00 2001 From: Jocelyn Date: Wed, 19 May 2021 16:32:34 -0400 Subject: [PATCH 04/52] import fix --- python/tvm/driver/_ffi_api.py | 2 +- src/driver/driver_api.cc | 9 ++++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/python/tvm/driver/_ffi_api.py b/python/tvm/driver/_ffi_api.py index 8371a51190b1..5606064487d3 100644 --- a/python/tvm/driver/_ffi_api.py +++ b/python/tvm/driver/_ffi_api.py @@ -1,3 +1,3 @@ import tvm._ffi -tvm._ffi._init_api("tvm.driver",__name__) \ No newline at end of file +tvm._ffi._init_api("tvm.driver", __name__) \ No newline at end of file diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 0d04398645fa..342812ec3815 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -185,9 +185,12 @@ IRModule lower(te::Schedule sch, const Array& args, const std::strin return mod; } -TVM_REGISTER_GLOBAL("tvm.driver.lower").set_body_typed([](te::Schedule sch, const Array& args, const std::string& name, - const std::unordered_map& binds) { - return lower(sch, args, name, binds); +TVM_REGISTER_GLOBAL("tvm.driver.lower").set_body_typed([](te::Schedule sch, const Array& args, const String& name, const Map& binds) { + std::unordered_map c_binds; + for (auto kv : binds) { + c_binds.insert(std::pair(kv.first, kv.second)); + } + return lower(sch, args, name, c_binds); }); std::pair SplitDevHostFuncs(IRModule mod_mixed, const Target& target_arg, From f7e8bbd74fea992ce814996d08b156d1620348a2 Mon Sep 17 00:00:00 2001 From: Jocelyn Date: Thu, 20 May 2021 18:01:57 -0400 Subject: [PATCH 05/52] build working, import still needs work --- python/tvm/driver/build_module.py | 67 +++++++++++++++++-------------- 1 file changed, 37 insertions(+), 30 deletions(-) diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index b2532ef2c0c0..04eb4b1872d9 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -31,7 +31,7 @@ from tvm.te import schedule from tvm.target import Target -from . import _ffi_api +from . import _ffi_api as ffi def get_binds(args, compact=False, binds=None): @@ -120,39 +120,46 @@ def form_irmodule(sch, args, name, binds): func = func.with_attr("tir.noalias", True) return tvm.IRModule({name: func}) +# bla = tvm.driver.lower(sch, fhdkslaje, ewakl fjelka) def lower(sch, args, name="main", binds=None, simple_mode=False): + return ffi.lower(sch, args, name, binds) - mod = self.__init_handler_by_constructor__(_ffi_api.tvm.driver.lower, sch, args, name, binds) - - """Lowering step before build into target. - - Parameters - ---------- - sch : tvm.te.schedule.Schedule - The schedule to be built +# def lower(sch, args, name="main", binds=None, simple_mode=False): +# return _ffi_api.lower(sch, args, name, binds) - args : list of Buffer or Tensor or Var - The argument lists to the function. - - name : str, optional - The name of result function. +# def lower(sch, args, name="main", binds=None, simple_mode=False): - binds : dict of :any:`Tensor` to :any:`Buffer`, optional - Dictionary that maps the Tensor to Buffer which specified the data layout - requirement of the function. By default, a new compact buffer is created - for each tensor in the argument. - - simple_mode : bool, optional - Whether only output simple and compact statement, this will skip - LoopPartition, api wrapper generation and Unrolling. - - Returns - ------- - m : IRModule or Stmt - The result IRModule, if simple_mode=False - Then the Stmt before make api is returned. - """ +# mod = self.__init_handler_by_constructor__(_ffi_api.tvm.driver.lower, sch, args, name, binds) + +# """Lowering step before build into target. + +# Parameters +# ---------- +# sch : tvm.te.schedule.Schedule +# The schedule to be built + +# args : list of Buffer or Tensor or Var +# The argument lists to the function. + +# name : str, optional +# The name of result function. + +# binds : dict of :any:`Tensor` to :any:`Buffer`, optional +# Dictionary that maps the Tensor to Buffer which specified the data layout +# requirement of the function. By default, a new compact buffer is created +# for each tensor in the argument. + +# simple_mode : bool, optional +# Whether only output simple and compact statement, this will skip +# LoopPartition, api wrapper generation and Unrolling. + +# Returns +# ------- +# m : IRModule or Stmt +# The result IRModule, if simple_mode=False +# Then the Stmt before make api is returned. +# """ # config setup # pass_ctx = PassContext.current() # instrument_bound_checkers = bool(pass_ctx.config.get("tir.instrument_bound_checkers", False)) @@ -210,7 +217,7 @@ def lower(sch, args, name="main", binds=None, simple_mode=False): # optimize = tvm.transform.Sequential(pass_list) # mod = optimize(mod) - return mod + # return mod def _build_for_device(input_mod, target, target_host): From e4267fb19a48a39388854394dfbc702d163ce803 Mon Sep 17 00:00:00 2001 From: Jocelyn Date: Fri, 21 May 2021 16:03:53 -0400 Subject: [PATCH 06/52] returning null binds in driver_api.cc --- python/tvm/driver/_ffi_api.py | 2 +- src/driver/driver_api.cc | 9 ++++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/python/tvm/driver/_ffi_api.py b/python/tvm/driver/_ffi_api.py index 5606064487d3..3d72edfa70c6 100644 --- a/python/tvm/driver/_ffi_api.py +++ b/python/tvm/driver/_ffi_api.py @@ -1,3 +1,3 @@ import tvm._ffi -tvm._ffi._init_api("tvm.driver", __name__) \ No newline at end of file +tvm._ffi._init_api("driver", __name__) \ No newline at end of file diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 342812ec3815..182897ce264b 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -185,10 +185,13 @@ IRModule lower(te::Schedule sch, const Array& args, const std::strin return mod; } -TVM_REGISTER_GLOBAL("tvm.driver.lower").set_body_typed([](te::Schedule sch, const Array& args, const String& name, const Map& binds) { +TVM_REGISTER_GLOBAL("driver.lower").set_body_typed([](te::Schedule sch, const Array& args, const String& name, const Map& binds) { std::unordered_map c_binds; - for (auto kv : binds) { - c_binds.insert(std::pair(kv.first, kv.second)); + // Check to make sure binds is not null before doing hte conversion; + if (binds->count) { // TODO: figure out why this is not compiling C++ sucks + for (auto kv : binds) { + c_binds.insert(std::pair(kv.first, kv.second)); + } } return lower(sch, args, name, c_binds); }); From fa322f62f67116e32ee745f8b71770d3b9de9d43 Mon Sep 17 00:00:00 2001 From: Jocelyn Date: Fri, 21 May 2021 18:39:47 -0400 Subject: [PATCH 07/52] tests pass woohoo00! --- python/tvm/driver/_ffi_api.py | 20 ++++- python/tvm/driver/build_module.py | 135 ------------------------------ src/driver/driver_api.cc | 32 ++++--- 3 files changed, 39 insertions(+), 148 deletions(-) diff --git a/python/tvm/driver/_ffi_api.py b/python/tvm/driver/_ffi_api.py index 3d72edfa70c6..eb078801008a 100644 --- a/python/tvm/driver/_ffi_api.py +++ b/python/tvm/driver/_ffi_api.py @@ -1,3 +1,21 @@ + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""FFI APIs for tvm.driver""" import tvm._ffi -tvm._ffi._init_api("driver", __name__) \ No newline at end of file +tvm._ffi._init_api("driver", __name__) diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index 04eb4b1872d9..507224f334f0 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -81,144 +81,9 @@ def get_binds(args, compact=False, binds=None): return binds, arg_list -def form_irmodule(sch, args, name, binds): - """According to the given schedule, form a function. - - Parameters - ---------- - sch : tvm.te.schedule.Schedule - The given scheduler to form the raw body - - args : list of Buffer or Tensor or Var - The argument lists to the function. - - name : str - The name of result function. - - binds : dict of :any:`Tensor` to :any:`Buffer`, optional - The binds information - - Returns - ------- - The body formed according to the given schedule - """ - # normalize schedule first - pass_ctx = PassContext.current() - sch = sch.normalize() - bounds = schedule.InferBound(sch) - stmt = schedule.ScheduleOps(sch, bounds) - - compact = schedule.VerifyCompactBuffer(stmt) - binds, arg_list = get_binds(args, compact, binds) - - stmt = schedule.SchedulePostProcRewriteForTensorCore(stmt, sch, binds) - func = schedule.SchedulePostProcToPrimFunc(arg_list, stmt, binds) - - func = func.with_attr("global_symbol", name) - - if pass_ctx.config.get("tir.noalias", True): - func = func.with_attr("tir.noalias", True) - return tvm.IRModule({name: func}) - -# bla = tvm.driver.lower(sch, fhdkslaje, ewakl fjelka) - def lower(sch, args, name="main", binds=None, simple_mode=False): return ffi.lower(sch, args, name, binds) -# def lower(sch, args, name="main", binds=None, simple_mode=False): -# return _ffi_api.lower(sch, args, name, binds) - -# def lower(sch, args, name="main", binds=None, simple_mode=False): - -# mod = self.__init_handler_by_constructor__(_ffi_api.tvm.driver.lower, sch, args, name, binds) - -# """Lowering step before build into target. - -# Parameters -# ---------- -# sch : tvm.te.schedule.Schedule -# The schedule to be built - -# args : list of Buffer or Tensor or Var -# The argument lists to the function. - -# name : str, optional -# The name of result function. - -# binds : dict of :any:`Tensor` to :any:`Buffer`, optional -# Dictionary that maps the Tensor to Buffer which specified the data layout -# requirement of the function. By default, a new compact buffer is created -# for each tensor in the argument. - -# simple_mode : bool, optional -# Whether only output simple and compact statement, this will skip -# LoopPartition, api wrapper generation and Unrolling. - -# Returns -# ------- -# m : IRModule or Stmt -# The result IRModule, if simple_mode=False -# Then the Stmt before make api is returned. -# """ - # config setup - # pass_ctx = PassContext.current() - # instrument_bound_checkers = bool(pass_ctx.config.get("tir.instrument_bound_checkers", False)) - # disable_vectorize = bool(pass_ctx.config.get("tir.disable_vectorize", False)) - # add_lower_pass = pass_ctx.config.get("tir.add_lower_pass", []) - - # lower_phase0 = [x[1] for x in add_lower_pass if x[0] == 0] - # lower_phase1 = [x[1] for x in add_lower_pass if x[0] == 1] - # lower_phase2 = [x[1] for x in add_lower_pass if x[0] == 2] - # lower_phase3 = [x[1] for x in add_lower_pass if x[0] > 2] - - # # Phase 0 - # if isinstance(sch, schedule.Schedule): - # mod = form_irmodule(sch, args, name, binds) - # else: - # mod = sch - - # pass_list = lower_phase0 - # # Phase 1 - # pass_list += [ - # tvm.tir.transform.InjectPrefetch(), - # tvm.tir.transform.StorageFlatten(64, instrument_bound_checkers), - # tvm.tir.transform.BF16Legalize(), - # tvm.tir.transform.NarrowDataType(32), - # tvm.tir.transform.Simplify(), - # ] - # pass_list += lower_phase1 - - # # Phase 2 - # if not simple_mode: - # pass_list += [(tvm.tir.transform.LoopPartition())] - - # pass_list += [ - # tvm.tir.transform.VectorizeLoop(not disable_vectorize), - # tvm.tir.transform.InjectVirtualThread(), - # tvm.tir.transform.InjectDoubleBuffer(), - # tvm.tir.transform.StorageRewrite(), - # tvm.tir.transform.UnrollLoop(), - # ] - # pass_list += lower_phase2 - - # # Phase 3 - # pass_list += [ - # tvm.tir.transform.Simplify(), - # tvm.tir.transform.RemoveNoOp(), - # ] - - # pass_list += [tvm.tir.transform.RewriteUnsafeSelect()] - # pass_list += [tvm.tir.transform.HoistIfThenElse()] - # pass_list += lower_phase3 - - # # Instrument BoundCheckers - # if instrument_bound_checkers: - # pass_list += [tvm.tir.transform.InstrumentBoundCheckers()] - - # optimize = tvm.transform.Sequential(pass_list) - # mod = optimize(mod) - # return mod - def _build_for_device(input_mod, target, target_host): """Build the lowered functions for a device with the given compilation diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 182897ce264b..67098f8d3542 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -130,6 +130,7 @@ transform::Pass Filter(FCond fcond) { IRModule lower(te::Schedule sch, const Array& args, const std::string& name, const std::unordered_map& binds) { + bool simple_mode = false; // TODO(@electriclilies): add as argument to IRModule Lower Array out_arg_list; auto pass_ctx = transform::PassContext::Current(); @@ -166,7 +167,11 @@ IRModule lower(te::Schedule sch, const Array& args, const std::strin pass_list.push_back(tir::transform::BF16Legalize()); pass_list.push_back(tir::transform::NarrowDataType(32)); pass_list.push_back(tir::transform::Simplify()); - pass_list.push_back(tir::transform::LoopPartition()); + + if (!simple_mode) { + pass_list.push_back(tir::transform::LoopPartition()); + } + pass_list.push_back(tir::transform::VectorizeLoop(!disable_vectorize)); pass_list.push_back(tir::transform::InjectVirtualThread()); pass_list.push_back(tir::transform::InjectDoubleBuffer()); @@ -176,6 +181,8 @@ IRModule lower(te::Schedule sch, const Array& args, const std::strin pass_list.push_back(tir::transform::Simplify()); pass_list.push_back(tir::transform::RemoveNoOp()); pass_list.push_back(tir::transform::RewriteUnsafeSelect()); + // HoistIfThenElse + pass_list.push_back(tir::transform::HoistIfThenElse()); if (instrument_bound_checkers) { pass_list.push_back(tir::transform::InstrumentBoundCheckers()); } @@ -185,16 +192,18 @@ IRModule lower(te::Schedule sch, const Array& args, const std::strin return mod; } -TVM_REGISTER_GLOBAL("driver.lower").set_body_typed([](te::Schedule sch, const Array& args, const String& name, const Map& binds) { - std::unordered_map c_binds; - // Check to make sure binds is not null before doing hte conversion; - if (binds->count) { // TODO: figure out why this is not compiling C++ sucks - for (auto kv : binds) { - c_binds.insert(std::pair(kv.first, kv.second)); - } - } - return lower(sch, args, name, c_binds); -}); +TVM_REGISTER_GLOBAL("driver.lower") + .set_body_typed([](te::Schedule sch, const Array& args, const String& name, + const Map& binds) { + std::unordered_map c_binds; + // Check to make sure binds is not null before doing the conversion; + if (binds.get() != NULL) { + for (auto kv : binds) { + c_binds.insert(std::pair(kv.first, kv.second)); + } + } + return lower(sch, args, name, c_binds); + }); std::pair SplitDevHostFuncs(IRModule mod_mixed, const Target& target_arg, const Target& target_host_arg, @@ -350,4 +359,3 @@ runtime::Module build(const IRModule& funcs, const Target& target_arg, } } // namespace tvm - From 8e3011624ca452e741157983ba71d76448600d27 Mon Sep 17 00:00:00 2001 From: Jocelyn Date: Fri, 21 May 2021 20:34:51 -0400 Subject: [PATCH 08/52] black'd _ffi_api.py --- python/tvm/driver/_ffi_api.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/tvm/driver/_ffi_api.py b/python/tvm/driver/_ffi_api.py index eb078801008a..c423656d78f5 100644 --- a/python/tvm/driver/_ffi_api.py +++ b/python/tvm/driver/_ffi_api.py @@ -1,4 +1,3 @@ - # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information From ca11f46b69a5128aa7490a6d8035b1c34b1942dd Mon Sep 17 00:00:00 2001 From: Jocelyn Date: Mon, 24 May 2021 13:20:10 -0400 Subject: [PATCH 09/52] remove simple_mode arg from lower in build_module.py for now --- python/tvm/driver/build_module.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index 4264aaa974c9..00d1f7bd9e26 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -87,7 +87,7 @@ def get_binds(args, compact=False, binds=None): return binds, arg_list -def lower(sch, args, name="main", binds=None, simple_mode=False): +def lower(sch, args, name="main", binds=None): return ffi.lower(sch, args, name, binds) From b1174788de82900838fc31936c4ccdb02912b5ec Mon Sep 17 00:00:00 2001 From: Jocelyn Date: Mon, 24 May 2021 13:32:47 -0400 Subject: [PATCH 10/52] attempt add simple_mode arg in lower c++ --- python/tvm/driver/build_module.py | 4 ++-- src/driver/driver_api.cc | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index 00d1f7bd9e26..d4e4afd60b53 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -87,8 +87,8 @@ def get_binds(args, compact=False, binds=None): return binds, arg_list -def lower(sch, args, name="main", binds=None): - return ffi.lower(sch, args, name, binds) +def lower(sch, args, name="main", binds=None, simple_mode=False): + return ffi.lower(sch, args, name, binds, simple_mode) def _build_for_device(input_mod, target, target_host): diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 67098f8d3542..8091df60e577 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -129,8 +129,8 @@ transform::Pass Filter(FCond fcond) { } IRModule lower(te::Schedule sch, const Array& args, const std::string& name, - const std::unordered_map& binds) { - bool simple_mode = false; // TODO(@electriclilies): add as argument to IRModule Lower + const std::unordered_map& binds, bool simple_mode) { + // bool simple_mode = false; // TODO(@electriclilies): add as argument to IRModule Lower Array out_arg_list; auto pass_ctx = transform::PassContext::Current(); From 41640a185ec8cc976f4476f5c8e93d09ecb247cc Mon Sep 17 00:00:00 2001 From: Jocelyn Date: Mon, 24 May 2021 18:18:53 -0400 Subject: [PATCH 11/52] lower now can take in schedule or IRModule --- python/tvm/driver/build_module.py | 4 +- src/driver/driver_api.cc | 70 +++++++++++++++++++------------ 2 files changed, 46 insertions(+), 28 deletions(-) diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index d4e4afd60b53..415fbd94d450 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -87,8 +87,8 @@ def get_binds(args, compact=False, binds=None): return binds, arg_list -def lower(sch, args, name="main", binds=None, simple_mode=False): - return ffi.lower(sch, args, name, binds, simple_mode) +def lower(sch, args= None, name="main", binds=None, simple_mode=False): + return ffi.lower(sch, args, name, binds) def _build_for_device(input_mod, target, target_host): diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 8091df60e577..06de9e5a634c 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -128,36 +128,15 @@ transform::Pass Filter(FCond fcond) { return tir::transform::CreatePrimFuncPass(fpass, 0, "Filter", {}); } -IRModule lower(te::Schedule sch, const Array& args, const std::string& name, - const std::unordered_map& binds, bool simple_mode) { - // bool simple_mode = false; // TODO(@electriclilies): add as argument to IRModule Lower - Array out_arg_list; +IRModule lower(IRModule mod, const Array& args, const std::string& name, + const std::unordered_map& binds) { + bool simple_mode = false; // TODO(@electriclilies): add as argument to IRModule Lower auto pass_ctx = transform::PassContext::Current(); - sch = sch.normalize(); - - // Before TIR transformation. - auto bounds = te::InferBound(sch); - auto stmt = te::ScheduleOps(sch, bounds, false); - bool compact = te::VerifyCompactBuffer(stmt); - - Map out_binds; - GetBinds(args, compact, binds, &out_binds, &out_arg_list); - - // build the function - tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, std::move(stmt), out_binds); - f = WithAttr(std::move(f), "global_symbol", runtime::String(name)); - - bool noalias = pass_ctx->GetConfig("tir.noalias", Bool(true)).value(); bool disable_vectorize = pass_ctx->GetConfig("tir.disable_vectorize", Bool(false)).value(); bool instrument_bound_checkers = pass_ctx->GetConfig("tir.instrument_bound_checkers", Bool(false)).value(); - if (noalias) { - f = WithAttr(std::move(f), "tir.noalias", Bool(true)); - } - - auto mod = IRModule(Map({{GlobalVar(name), f}})); auto pass_list = Array(); // Phase 0 @@ -192,8 +171,37 @@ IRModule lower(te::Schedule sch, const Array& args, const std::strin return mod; } +IRModule lower(te::Schedule sch, const Array& args, const std::string& name, + const std::unordered_map& binds) { + // Convert te schedule to IRModule + Array out_arg_list; + auto pass_ctx = transform::PassContext::Current(); + + sch = sch.normalize(); + + // Before TIR transformation. + auto bounds = te::InferBound(sch); + auto stmt = te::ScheduleOps(sch, bounds, false); + bool compact = te::VerifyCompactBuffer(stmt); + + Map out_binds; + GetBinds(args, compact, binds, &out_binds, &out_arg_list); + + // build the function + tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, std::move(stmt), out_binds); + f = WithAttr(std::move(f), "global_symbol", runtime::String(name)); + + bool noalias = pass_ctx->GetConfig("tir.noalias", Bool(true)).value(); + + if (noalias) { + f = WithAttr(std::move(f), "tir.noalias", Bool(true)); + } + IRModule mod = IRModule(Map({{GlobalVar(name), f}})); + return lower(mod, args, name, binds); +} + TVM_REGISTER_GLOBAL("driver.lower") - .set_body_typed([](te::Schedule sch, const Array& args, const String& name, + .set_body_typed([](ObjectRef obj, const Array& args, const String& name, const Map& binds) { std::unordered_map c_binds; // Check to make sure binds is not null before doing the conversion; @@ -202,9 +210,19 @@ TVM_REGISTER_GLOBAL("driver.lower") c_binds.insert(std::pair(kv.first, kv.second)); } } - return lower(sch, args, name, c_binds); + + if (const auto* p_mod = obj.as()) { + IRModule mod = GetRef(p_mod); + return lower(mod, args, name, c_binds); + } else if (const auto* p_sch = obj.as()) { + te::Schedule sch = GetRef(p_sch); + return lower(sch, args, name, c_binds); + } else { + ICHECK(false) << "driver.lower expected the first argument to be a te::Schedule or IRModule"; + } }); + std::pair SplitDevHostFuncs(IRModule mod_mixed, const Target& target_arg, const Target& target_host_arg, const transform::PassContext& pass_ctx) { From 724db085b04a3e224dd0ea16e3d2e1980d6c06b0 Mon Sep 17 00:00:00 2001 From: Jocelyn Date: Mon, 24 May 2021 19:12:09 -0400 Subject: [PATCH 12/52] add simple_mode parameter in c++ --- include/tvm/driver/driver_api.h | 13 ++++++++++++- python/tvm/driver/build_module.py | 2 +- src/driver/driver_api.cc | 14 +++++++------- 3 files changed, 20 insertions(+), 9 deletions(-) diff --git a/include/tvm/driver/driver_api.h b/include/tvm/driver/driver_api.h index 71a69a000944..78d089c0254a 100644 --- a/include/tvm/driver/driver_api.h +++ b/include/tvm/driver/driver_api.h @@ -51,7 +51,18 @@ namespace tvm { * \return The result module. */ TVM_DLL IRModule lower(te::Schedule sch, const Array& args, const std::string& name, - const std::unordered_map& binds); + const std::unordered_map& binds, bool simple_mode = false); + +/*! + * \brief Build an IRModule given a module, args and binds + * \param sch The module to lower + * \param args The arguments to the function. + * \param name The name of the lowered function. + * \param binds Buffer assignments. + * \return The result module. + */ +TVM_DLL IRModule lower(IRModule mod, const Array& args, const std::string& name, + const std::unordered_map& binds, bool simple_mode = false); /*! * \brief Build a device and host module for a specific target from an IRModule. diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index 415fbd94d450..625f85765550 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -88,7 +88,7 @@ def get_binds(args, compact=False, binds=None): def lower(sch, args= None, name="main", binds=None, simple_mode=False): - return ffi.lower(sch, args, name, binds) + return ffi.lower(sch, args, name, binds, simple_mode) def _build_for_device(input_mod, target, target_host): diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 06de9e5a634c..df046f0c29ca 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -129,8 +129,7 @@ transform::Pass Filter(FCond fcond) { } IRModule lower(IRModule mod, const Array& args, const std::string& name, - const std::unordered_map& binds) { - bool simple_mode = false; // TODO(@electriclilies): add as argument to IRModule Lower + const std::unordered_map& binds, bool simple_mode) { auto pass_ctx = transform::PassContext::Current(); bool disable_vectorize = pass_ctx->GetConfig("tir.disable_vectorize", Bool(false)).value(); @@ -172,7 +171,7 @@ IRModule lower(IRModule mod, const Array& args, const std::string& n } IRModule lower(te::Schedule sch, const Array& args, const std::string& name, - const std::unordered_map& binds) { + const std::unordered_map& binds, bool simple_mode) { // Convert te schedule to IRModule Array out_arg_list; auto pass_ctx = transform::PassContext::Current(); @@ -197,12 +196,12 @@ IRModule lower(te::Schedule sch, const Array& args, const std::strin f = WithAttr(std::move(f), "tir.noalias", Bool(true)); } IRModule mod = IRModule(Map({{GlobalVar(name), f}})); - return lower(mod, args, name, binds); + return lower(mod, args, name, binds, simple_mode); } TVM_REGISTER_GLOBAL("driver.lower") .set_body_typed([](ObjectRef obj, const Array& args, const String& name, - const Map& binds) { + const Map& binds, bool simple_mode) { std::unordered_map c_binds; // Check to make sure binds is not null before doing the conversion; if (binds.get() != NULL) { @@ -213,12 +212,13 @@ TVM_REGISTER_GLOBAL("driver.lower") if (const auto* p_mod = obj.as()) { IRModule mod = GetRef(p_mod); - return lower(mod, args, name, c_binds); + return lower(mod, args, name, c_binds, simple_mode); } else if (const auto* p_sch = obj.as()) { te::Schedule sch = GetRef(p_sch); - return lower(sch, args, name, c_binds); + return lower(sch, args, name, c_binds, simple_mode); } else { ICHECK(false) << "driver.lower expected the first argument to be a te::Schedule or IRModule"; + throw; } }); From 47ea819702c77706d3fa27c9b452badbf03dbde4 Mon Sep 17 00:00:00 2001 From: Jocelyn Date: Mon, 24 May 2021 19:21:53 -0400 Subject: [PATCH 13/52] reformat for lint --- include/tvm/driver/driver_api.h | 6 ++++-- python/tvm/driver/build_module.py | 2 +- src/driver/driver_api.cc | 4 ++-- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/include/tvm/driver/driver_api.h b/include/tvm/driver/driver_api.h index 78d089c0254a..957875972ab7 100644 --- a/include/tvm/driver/driver_api.h +++ b/include/tvm/driver/driver_api.h @@ -51,7 +51,8 @@ namespace tvm { * \return The result module. */ TVM_DLL IRModule lower(te::Schedule sch, const Array& args, const std::string& name, - const std::unordered_map& binds, bool simple_mode = false); + const std::unordered_map& binds, + bool simple_mode = false); /*! * \brief Build an IRModule given a module, args and binds @@ -62,7 +63,8 @@ TVM_DLL IRModule lower(te::Schedule sch, const Array& args, const st * \return The result module. */ TVM_DLL IRModule lower(IRModule mod, const Array& args, const std::string& name, - const std::unordered_map& binds, bool simple_mode = false); + const std::unordered_map& binds, + bool simple_mode = false); /*! * \brief Build a device and host module for a specific target from an IRModule. diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index 625f85765550..a7e698dd7d35 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -87,7 +87,7 @@ def get_binds(args, compact=False, binds=None): return binds, arg_list -def lower(sch, args= None, name="main", binds=None, simple_mode=False): +def lower(sch, args=None, name="main", binds=None, simple_mode=False): return ffi.lower(sch, args, name, binds, simple_mode) diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index df046f0c29ca..2b06eba2bd0e 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -217,12 +217,12 @@ TVM_REGISTER_GLOBAL("driver.lower") te::Schedule sch = GetRef(p_sch); return lower(sch, args, name, c_binds, simple_mode); } else { - ICHECK(false) << "driver.lower expected the first argument to be a te::Schedule or IRModule"; + ICHECK(false) << "driver.lower expects the first argument to be a te::Schedule or " + << "IRModule"; throw; } }); - std::pair SplitDevHostFuncs(IRModule mod_mixed, const Target& target_arg, const Target& target_host_arg, const transform::PassContext& pass_ctx) { From c01ca1e1909d4c5172a46db4d2ab349ca2028610 Mon Sep 17 00:00:00 2001 From: Jocelyn Date: Mon, 24 May 2021 19:36:35 -0400 Subject: [PATCH 14/52] include header details for lint --- include/tvm/driver/driver_api.h | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/include/tvm/driver/driver_api.h b/include/tvm/driver/driver_api.h index 957875972ab7..e0aa455fe0af 100644 --- a/include/tvm/driver/driver_api.h +++ b/include/tvm/driver/driver_api.h @@ -48,6 +48,7 @@ namespace tvm { * \param args The arguments to the function. * \param name The name of the lowered function. * \param binds Buffer assignments. + * \param simple_mode Skips the LoopPartition pass if true. Defaults to false. * \return The result module. */ TVM_DLL IRModule lower(te::Schedule sch, const Array& args, const std::string& name, @@ -56,10 +57,11 @@ TVM_DLL IRModule lower(te::Schedule sch, const Array& args, const st /*! * \brief Build an IRModule given a module, args and binds - * \param sch The module to lower + * \param mod The IRmodule to lower * \param args The arguments to the function. * \param name The name of the lowered function. * \param binds Buffer assignments. + * \param simple_mode Skips the LoopPartition pass if true. Defaults to false. * \return The result module. */ TVM_DLL IRModule lower(IRModule mod, const Array& args, const std::string& name, From 49d6c7ffed8303b047694c095a47e47e716980ce Mon Sep 17 00:00:00 2001 From: Jocelyn Date: Tue, 25 May 2021 18:07:12 -0400 Subject: [PATCH 15/52] ast lhs and rhs not matching, refactoring driver_api.cc --- include/tvm/driver/driver_api.h | 14 +++++++++ python/tvm/driver/build_module.py | 37 +++++++++++++++++++++-- src/driver/driver_api.cc | 27 +++++++++++++++-- tests/python/unittest/test_lower_build.py | 1 + 4 files changed, 75 insertions(+), 4 deletions(-) diff --git a/include/tvm/driver/driver_api.h b/include/tvm/driver/driver_api.h index e0aa455fe0af..42e340d097b0 100644 --- a/include/tvm/driver/driver_api.h +++ b/include/tvm/driver/driver_api.h @@ -34,6 +34,7 @@ #include #include #include +#include #include #include @@ -68,6 +69,19 @@ TVM_DLL IRModule lower(IRModule mod, const Array& args, const std::s const std::unordered_map& binds, bool simple_mode = false); +/*! + * \brief Build an IRModule given a module, args and binds + * \param func The PrimFunc to lower + * \param args The arguments to the function. + * \param name The name of the lowered function. + * \param binds Buffer assignments. + * \param simple_mode Skips the LoopPartition pass if true. Defaults to false. + * \return The result module. + */ +TVM_DLL IRModule lower(tvm::tir::PrimFunc func, const Array& args, const std::string& name, + const std::unordered_map& binds, + bool simple_mode = false); + /*! * \brief Build a device and host module for a specific target from an IRModule. * \param funcs The functions to be built. diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index a7e698dd7d35..9b6d064fa270 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -86,9 +86,42 @@ def get_binds(args, compact=False, binds=None): raise ValueError("args must be Tensor, Buffer or Var") return binds, arg_list +def lower( + inputs: Union[schedule.Schedule, PrimFunc, IRModule], + args: Optional[List[Union[Buffer, tensor.Tensor, Var]]] = None, + name: str = "main", + binds: Optional[Mapping[tensor.Tensor, Buffer]] = None, + simple_mode: bool = False, +) -> IRModule: + """Lowering step before build into target. + + Parameters + ---------- + input : Union[schedule.Schedule, PrimFunc, IRModule] + The TE schedule or TensorIR PrimFunc/IRModule to be built + + args : Optional[List[Union[Buffer, tensor.Tensor, Var]]] + The argument lists to the function for TE schedule. + It should be None if we want to lower TensorIR. + + name : str + The name of result function. -def lower(sch, args=None, name="main", binds=None, simple_mode=False): - return ffi.lower(sch, args, name, binds, simple_mode) + binds : Optional[Mapping[tensor.Tensor, Buffer]] + Dictionary that maps the Tensor to Buffer which specified the data layout + requirement of the function. By default, a new compact buffer is created + for each tensor in the argument. + + simple_mode : bool + Whether only output simple and compact statement, this will skip + LoopPartition, api wrapper generation and Unrolling. + + Returns + ------- + m : IRModule + The result IRModule + """ + return ffi.lower(inputs, args, name, binds, simple_mode) def _build_for_device(input_mod, target, target_host): diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 2b06eba2bd0e..9e1f4158a4e2 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -128,6 +128,11 @@ transform::Pass Filter(FCond fcond) { return tir::transform::CreatePrimFuncPass(fpass, 0, "Filter", {}); } +Array LegacyTEPassList() { + auto pass_list = Array(); + +} + IRModule lower(IRModule mod, const Array& args, const std::string& name, const std::unordered_map& binds, bool simple_mode) { auto pass_ctx = transform::PassContext::Current(); @@ -199,6 +204,21 @@ IRModule lower(te::Schedule sch, const Array& args, const std::strin return lower(mod, args, name, binds, simple_mode); } +IRModule lower(tvm::tir::PrimFunc func, const Array& args, const std::string& name, + const std::unordered_map& binds, bool simple_mode) { + auto pass_ctx = transform::PassContext::Current(); + auto f = WithAttr(std::move(func), "global_symbol", runtime::String(name)); + + bool noalias = pass_ctx->GetConfig("tir.noalias", Bool(true)).value(); + + if (noalias) { + f = WithAttr(std::move(f), "tir.noalias", Bool(true)); + } + IRModule mod = IRModule(Map({{GlobalVar(name), f}})); + return lower(mod, args, name, binds, simple_mode); + +} + TVM_REGISTER_GLOBAL("driver.lower") .set_body_typed([](ObjectRef obj, const Array& args, const String& name, const Map& binds, bool simple_mode) { @@ -216,9 +236,12 @@ TVM_REGISTER_GLOBAL("driver.lower") } else if (const auto* p_sch = obj.as()) { te::Schedule sch = GetRef(p_sch); return lower(sch, args, name, c_binds, simple_mode); + } else if (const auto* p_func = obj.as()) { + tvm::tir::PrimFunc func = GetRef(p_func); + return lower(func, args, name, c_binds, simple_mode); } else { - ICHECK(false) << "driver.lower expects the first argument to be a te::Schedule or " - << "IRModule"; + ICHECK(false) << "driver.lower expects the first argument to be a te::Schedule, " + << "PrimFunc, or IRModule"; throw; } }); diff --git a/tests/python/unittest/test_lower_build.py b/tests/python/unittest/test_lower_build.py index 4505a7bed244..3742df7f3b8a 100644 --- a/tests/python/unittest/test_lower_build.py +++ b/tests/python/unittest/test_lower_build.py @@ -82,6 +82,7 @@ def test_lower_build_te_schedule(): def test_lower_build_tir_func(): # check lowering + print("Type of matmul: ", type(matmul)) ir_mod = tvm.lower(matmul) tvm.ir.assert_structural_equal(ir_mod, LoweredModule()) # check building From 5e43b18a7cd64d6a3e9c838203df232d3b90203f Mon Sep 17 00:00:00 2001 From: electriclilies Date: Tue, 25 May 2021 18:22:58 -0700 Subject: [PATCH 16/52] Added user-defined passes, still failing some tests because python lets more types in for args --- include/tvm/driver/driver_api.h | 7 ++- src/driver/driver_api.cc | 104 ++++++++++++++++++++++++++------ 2 files changed, 90 insertions(+), 21 deletions(-) diff --git a/include/tvm/driver/driver_api.h b/include/tvm/driver/driver_api.h index 42e340d097b0..ce35cd068120 100644 --- a/include/tvm/driver/driver_api.h +++ b/include/tvm/driver/driver_api.h @@ -78,9 +78,10 @@ TVM_DLL IRModule lower(IRModule mod, const Array& args, const std::s * \param simple_mode Skips the LoopPartition pass if true. Defaults to false. * \return The result module. */ -TVM_DLL IRModule lower(tvm::tir::PrimFunc func, const Array& args, const std::string& name, - const std::unordered_map& binds, - bool simple_mode = false); +TVM_DLL IRModule lower(tvm::tir::PrimFunc func, const Array& args, + const std::string& name, + const std::unordered_map& binds, + bool simple_mode = false); /*! * \brief Build a device and host module for a specific target from an IRModule. diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 9e1f4158a4e2..d105bfb76bfa 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -128,29 +128,75 @@ transform::Pass Filter(FCond fcond) { return tir::transform::CreatePrimFuncPass(fpass, 0, "Filter", {}); } -Array LegacyTEPassList() { - auto pass_list = Array(); - -} - -IRModule lower(IRModule mod, const Array& args, const std::string& name, - const std::unordered_map& binds, bool simple_mode) { +Array CreatePassList(bool simple_mode, bool legacy_te_pass) { auto pass_ctx = transform::PassContext::Current(); bool disable_vectorize = pass_ctx->GetConfig("tir.disable_vectorize", Bool(false)).value(); bool instrument_bound_checkers = pass_ctx->GetConfig("tir.instrument_bound_checkers", Bool(false)).value(); - auto pass_list = Array(); + // Get any user-added passes + auto add_lower_pass = + pass_ctx->GetConfig>>("tir.add_lower_pass", Array>()) + .value(); + + auto user_lower_phase0 = Array(); + auto user_lower_phase1 = Array(); + auto user_lower_phase2 = Array(); + auto user_lower_phase3 = Array(); + + // phase pasees is of the form + // [[phase_number, pass], [phase_number, pass]... ] + for (auto phase_pass : add_lower_pass) { + auto phase_num = phase_pass[0].as(); + ICHECK(phase_num) + << "Expected the first entry in the inner Array of tir.add_lower_pass to be an integer"; + int phase_num_val = phase_num->value; + + CHECK_GT(phase_num_val, 0); + + // TODO(electriclilies): is there a cleaner way to do this? + auto pass_node = phase_pass[1].as(); + auto pass = GetRef(pass_node); + // Copy the pass into the correct phase + if (phase_num_val == 0) { + user_lower_phase0.push_back(pass); + } else if (phase_num_val == 1) { + user_lower_phase1.push_back(pass); + } else if (phase_num_val == 2) { + user_lower_phase2.push_back(pass); + } else if (phase_num_val >= 3) { + user_lower_phase3.push_back(pass); + } + } - // Phase 0 - pass_list.push_back(tir::transform::InjectPrefetch()); - pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers)); - // Phase 1 + // Construct the pass list, inserting the user provided passes at the end of the phase + // TODO(electriclilies): I'm not sure if they should go at the beginning or the end of the phase. + // The code is inconsistent with what passes are in which phase as well. For now I have coped the + // python behavior exactly. + + // PHASE 0 + auto pass_list = user_lower_phase0; + + // PHASE 1 + if (legacy_te_pass) { + pass_list.push_back(tir::transform::InjectPrefetch()); + pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers)); + } else { + pass_list.push_back(tir::transform::LowerInitBlock()); + pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation()); + pass_list.push_back(tir::transform::ConvertBlocksToOpaque()); + pass_list.push_back(tir::transform::CompactBufferAllocation()); + pass_list.push_back(tir::transform::FlattenBuffer()); + } pass_list.push_back(tir::transform::BF16Legalize()); pass_list.push_back(tir::transform::NarrowDataType(32)); pass_list.push_back(tir::transform::Simplify()); + // Add user-defined phase-1 passes + pass_list.insert(pass_list.end(), user_lower_phase1.begin(), user_lower_phase1.end()); + + // PHASE 2 if (!simple_mode) { pass_list.push_back(tir::transform::LoopPartition()); } @@ -160,21 +206,38 @@ IRModule lower(IRModule mod, const Array& args, const std::string& n pass_list.push_back(tir::transform::InjectDoubleBuffer()); pass_list.push_back(tir::transform::StorageRewrite()); pass_list.push_back(tir::transform::UnrollLoop()); - // Phase 2 + + // Add user-defined phase-2 passes + pass_list.insert(pass_list.end(), user_lower_phase2.begin(), user_lower_phase2.end()); + + // PHASE 3 pass_list.push_back(tir::transform::Simplify()); pass_list.push_back(tir::transform::RemoveNoOp()); pass_list.push_back(tir::transform::RewriteUnsafeSelect()); // HoistIfThenElse pass_list.push_back(tir::transform::HoistIfThenElse()); + + // Add user-defined phase-3 passes + pass_list.insert(pass_list.end(), user_lower_phase3.begin(), user_lower_phase3.end()); + if (instrument_bound_checkers) { pass_list.push_back(tir::transform::InstrumentBoundCheckers()); } - // run + return pass_list; +} + +IRModule LowerWithPassList(IRModule mod, Array pass_list) { auto optimize = transform::Sequential(pass_list); mod = optimize(std::move(mod)); return mod; } +IRModule lower(IRModule mod, const Array& args, const std::string& name, + const std::unordered_map& binds, bool simple_mode) { + auto pass_list = CreatePassList(simple_mode, false); + return LowerWithPassList(mod, pass_list); +} + IRModule lower(te::Schedule sch, const Array& args, const std::string& name, const std::unordered_map& binds, bool simple_mode) { // Convert te schedule to IRModule @@ -201,22 +264,27 @@ IRModule lower(te::Schedule sch, const Array& args, const std::strin f = WithAttr(std::move(f), "tir.noalias", Bool(true)); } IRModule mod = IRModule(Map({{GlobalVar(name), f}})); - return lower(mod, args, name, binds, simple_mode); + + // Get the legacy TE pass list + auto pass_list = CreatePassList(simple_mode, true); + return LowerWithPassList(mod, pass_list); } IRModule lower(tvm::tir::PrimFunc func, const Array& args, const std::string& name, const std::unordered_map& binds, bool simple_mode) { auto pass_ctx = transform::PassContext::Current(); auto f = WithAttr(std::move(func), "global_symbol", runtime::String(name)); - + bool noalias = pass_ctx->GetConfig("tir.noalias", Bool(true)).value(); if (noalias) { f = WithAttr(std::move(f), "tir.noalias", Bool(true)); } IRModule mod = IRModule(Map({{GlobalVar(name), f}})); - return lower(mod, args, name, binds, simple_mode); - + + // Get the pass list + auto pass_list = CreatePassList(simple_mode, false); + return LowerWithPassList(mod, pass_list); } TVM_REGISTER_GLOBAL("driver.lower") From 59534cbe6eab4e961837a048989aac73d744ed5c Mon Sep 17 00:00:00 2001 From: electriclilies Date: Thu, 27 May 2021 15:57:43 -0700 Subject: [PATCH 17/52] tests are green --- include/tvm/driver/driver_api.h | 17 +++- python/tvm/driver/build_module.py | 46 +++++++++ src/driver/driver_api.cc | 97 ++++++++++++++++--- .../test_tir_transform_bf16_legalize.py | 8 +- .../unittest/test_tir_transform_hoist_if.py | 4 +- 5 files changed, 152 insertions(+), 20 deletions(-) diff --git a/include/tvm/driver/driver_api.h b/include/tvm/driver/driver_api.h index ce35cd068120..020f13818bf2 100644 --- a/include/tvm/driver/driver_api.h +++ b/include/tvm/driver/driver_api.h @@ -52,10 +52,14 @@ namespace tvm { * \param simple_mode Skips the LoopPartition pass if true. Defaults to false. * \return The result module. */ -TVM_DLL IRModule lower(te::Schedule sch, const Array& args, const std::string& name, +TVM_DLL IRModule legacyLower(IRModule mod, const Array& args, const std::string& name, const std::unordered_map& binds, bool simple_mode = false); + +TVM_DLL IRModule lower(te::Schedule sch, const Array& args, const std::string& name, + const std::unordered_map& binds, + bool simple_mode = false); /*! * \brief Build an IRModule given a module, args and binds * \param mod The IRmodule to lower @@ -83,6 +87,17 @@ TVM_DLL IRModule lower(tvm::tir::PrimFunc func, const Array& args, const std::unordered_map& binds, bool simple_mode = false); +/*! + * \brief Create an IRModule out of a Schedule + * \param sch The schedule + * \param args The arguments to the function. + * \param name The name of the lowered function. + * \param binds Buffer assignments. + * \param simple_mode Skips the LoopPartition pass if true. Defaults to false. + * \return The result module. + */ +IRModule ScheduleToModule(te::Schedule sch, const Array& args, const std::string& name, + const std::unordered_map& binds); /*! * \brief Build a device and host module for a specific target from an IRModule. * \param funcs The functions to be built. diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index 9b6d064fa270..d6f99f2ae6e2 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -86,6 +86,49 @@ def get_binds(args, compact=False, binds=None): raise ValueError("args must be Tensor, Buffer or Var") return binds, arg_list + +def schedule_to_module( + sch: schedule.Schedule, + args: Optional[List[Union[Buffer, tensor.Tensor, Var]]] = None, + name: str = "main", + binds: Optional[Mapping[tensor.Tensor, Buffer]] = None, +) -> IRModule: + """According to the given schedule, form a function. + Parameters + ---------- + sch : tvm.te.schedule.Schedule + The given scheduler to form the raw body + args : list of Buffer or Tensor or Var + The argument lists to the function. + name : str + The name of result function. + binds : dict of :any:`Tensor` to :any:`Buffer`, optional + The binds information + Returns + ------- + The body formed according to the given schedule + """ + # normalize schedule first + """ + pass_ctx = PassContext.current() + sch = sch.normalize() + bounds = schedule.InferBound(sch) + stmt = schedule.ScheduleOps(sch, bounds) + + compact = schedule.VerifyCompactBuffer(stmt) + binds, arg_list = get_binds(args, compact, binds) + + stmt = schedule.SchedulePostProcRewriteForTensorCore(stmt, sch, binds) + func = schedule.SchedulePostProcToPrimFunc(arg_list, stmt, binds) + + func = func.with_attr("global_symbol", name) + + if pass_ctx.config.get("tir.noalias", True): + func = func.with_attr("tir.noalias", True) + return tvm.IRModule({name: func}) + """ + return ffi.schedule_to_module(sch, args, name, binds) + def lower( inputs: Union[schedule.Schedule, PrimFunc, IRModule], args: Optional[List[Union[Buffer, tensor.Tensor, Var]]] = None, @@ -121,6 +164,9 @@ def lower( m : IRModule The result IRModule """ + if isinstance(inputs, schedule.Schedule): + mod = schedule_to_module(inputs, args, name, binds) + return ffi.legacy_lower(mod, None, name, None, simple_mode) return ffi.lower(inputs, args, name, binds, simple_mode) diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index d105bfb76bfa..8140ccd4e789 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -93,6 +93,8 @@ tir::Buffer BufferWithOffsetAlignment(Array shape, DataType dtype, std offset_factor, buffer_type); } + +// comment to try to remove this void GetBinds(const Array& args, bool compact, const std::unordered_map& binds, Map* out_binds, Array* out_arg_list) { @@ -109,6 +111,30 @@ void GetBinds(const Array& args, bool compact, } } + +void GetBinds(const Array& args, bool compact, + const std::unordered_map& binds, + Map* out_binds, Array* out_arg_list) { + *out_binds = binds; + + for (const ObjectRef& x : args) { + if (const auto* tensor_node = x.as()) { + auto x_ref = GetRef(tensor_node); + if (out_binds->find(x_ref) == out_binds->end()) { + auto buf = BufferWithOffsetAlignment(x_ref->shape, x_ref->dtype, x_ref->op->name, -1, 0, compact); + out_binds->Set(x_ref, buf); + out_arg_list->push_back(buf); + } else { + out_arg_list->push_back((*out_binds)[x_ref]); + } + } else if (x.as() || x.as()) { + out_arg_list->push_back(x); + } else { + ICHECK(false) << "Expected type of the elements of args to be te::Tensor, te::Buffer or tir::Var"; + } + } +} + transform::Pass BindTarget(Target target) { auto fpass = [target](tir::PrimFunc f, IRModule m, transform::PassContext ctx) { return WithAttr(std::move(f), tvm::attr::kTarget, target); @@ -232,14 +258,8 @@ IRModule LowerWithPassList(IRModule mod, Array pass_list) return mod; } -IRModule lower(IRModule mod, const Array& args, const std::string& name, - const std::unordered_map& binds, bool simple_mode) { - auto pass_list = CreatePassList(simple_mode, false); - return LowerWithPassList(mod, pass_list); -} - -IRModule lower(te::Schedule sch, const Array& args, const std::string& name, - const std::unordered_map& binds, bool simple_mode) { +IRModule ScheduleToModule(te::Schedule sch, const Array& args, const std::string& name, + const std::unordered_map& binds) { // Convert te schedule to IRModule Array out_arg_list; auto pass_ctx = transform::PassContext::Current(); @@ -255,6 +275,9 @@ IRModule lower(te::Schedule sch, const Array& args, const std::strin GetBinds(args, compact, binds, &out_binds, &out_arg_list); // build the function + // At this point binds is only te::Tensors + + stmt = te::SchedulePostProcRewriteForTensorCore(stmt, sch, binds); // TODO(electriclilies): Should this be in here? Was in python but not C++ version. tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, std::move(stmt), out_binds); f = WithAttr(std::move(f), "global_symbol", runtime::String(name)); @@ -263,8 +286,47 @@ IRModule lower(te::Schedule sch, const Array& args, const std::strin if (noalias) { f = WithAttr(std::move(f), "tir.noalias", Bool(true)); } - IRModule mod = IRModule(Map({{GlobalVar(name), f}})); + return IRModule(Map({{GlobalVar(name), f}})); +} + +TVM_REGISTER_GLOBAL("driver.schedule_to_module") + .set_body_typed([](te::Schedule sch, const Array& args, const String& name, + const Map& binds) { + std::unordered_map c_binds; + // Check to make sure binds is not null before doing the conversion; + if (binds.get() != NULL) { + for (auto kv : binds) { + c_binds.insert(std::pair(kv.first, kv.second)); + } + } + IRModule mod = ScheduleToModule(sch, args, name, c_binds); + return mod; + }); + + +IRModule lower(IRModule mod, const Array& args, const std::string& name, + const std::unordered_map& binds, bool simple_mode) { + auto pass_list = CreatePassList(simple_mode, false); + return LowerWithPassList(mod, pass_list); +} +IRModule lower(te::Schedule sch, const Array& args, const std::string& name, + const std::unordered_map& binds, bool simple_mode) { + + Array ref_arr; + for (auto x : args) { + ref_arr.push_back(x); + } + IRModule mod = ScheduleToModule(sch, ref_arr, name, binds); + // Get the legacy TE pass list + auto pass_list = CreatePassList(simple_mode, true); + return LowerWithPassList(mod, pass_list); +} + + + +IRModule legacyLower(IRModule mod, const Array& args, const std::string& name, + const std::unordered_map& binds, bool simple_mode) { // Get the legacy TE pass list auto pass_list = CreatePassList(simple_mode, true); return LowerWithPassList(mod, pass_list); @@ -297,13 +359,9 @@ TVM_REGISTER_GLOBAL("driver.lower") c_binds.insert(std::pair(kv.first, kv.second)); } } - if (const auto* p_mod = obj.as()) { IRModule mod = GetRef(p_mod); return lower(mod, args, name, c_binds, simple_mode); - } else if (const auto* p_sch = obj.as()) { - te::Schedule sch = GetRef(p_sch); - return lower(sch, args, name, c_binds, simple_mode); } else if (const auto* p_func = obj.as()) { tvm::tir::PrimFunc func = GetRef(p_func); return lower(func, args, name, c_binds, simple_mode); @@ -314,6 +372,19 @@ TVM_REGISTER_GLOBAL("driver.lower") } }); +TVM_REGISTER_GLOBAL("driver.legacy_lower") + .set_body_typed([](IRModule mod, const Array& args, const String& name, + const Map& binds, bool simple_mode) { + std::unordered_map c_binds; + // Check to make sure binds is not null before doing the conversion; + if (binds.get() != NULL) { + for (auto kv : binds) { + c_binds.insert(std::pair(kv.first, kv.second)); + } + } + return legacyLower(mod, args, name, c_binds, simple_mode); + }); + std::pair SplitDevHostFuncs(IRModule mod_mixed, const Target& target_arg, const Target& target_host_arg, const transform::PassContext& pass_ctx) { diff --git a/tests/python/unittest/test_tir_transform_bf16_legalize.py b/tests/python/unittest/test_tir_transform_bf16_legalize.py index c5163a8457af..1e3c8061e029 100644 --- a/tests/python/unittest/test_tir_transform_bf16_legalize.py +++ b/tests/python/unittest/test_tir_transform_bf16_legalize.py @@ -20,7 +20,7 @@ def lower_stmt(sche, params, passfunc): - func = tvm.driver.build_module.form_irmodule(sche, params, "main", None)["main"] + func = tvm.driver.build_module.schedule_to_module(sche, params, "main", None)["main"] func = passfunc()(tvm.IRModule.from_expr(func))["main"] stmt = func.body return stmt @@ -42,7 +42,7 @@ def get_promoted(op): lambda i: topi.cast(op(topi.cast(a[i], "float"), topi.cast(b[i], "float")), "bfloat16"), ) s = te.create_schedule(c.op) - func = tvm.driver.build_module.form_irmodule(s, [a, b, c], "main", None)["main"] + func = tvm.driver.build_module.schedule_to_module(s, [a, b, c], "main", None)["main"] return func.body def test_promoted(op): @@ -111,7 +111,7 @@ def get_target(): ), ) s = te.create_schedule(c.op) - func = tvm.driver.build_module.form_irmodule(s, [a, b, c], "main", None)["main"] + func = tvm.driver.build_module.schedule_to_module(s, [a, b, c], "main", None)["main"] return func.body tvm.ir.assert_structural_equal(get_eliminated(), get_target()) @@ -151,7 +151,7 @@ def check(fcompute_before, fcompute_after): b = te.placeholder((100,), dtype="uint16", name="B") c = te.compute((100,), fcompute_after(a, b), name="C") s = te.create_schedule(c.op) - func = tvm.driver.build_module.form_irmodule(s, [a, b, c], "main", None)["main"] + func = tvm.driver.build_module.schedule_to_module(s, [a, b, c], "main", None)["main"] tvm.ir.assert_structural_equal(stmt, func.body) def orig1(a, b): diff --git a/tests/python/unittest/test_tir_transform_hoist_if.py b/tests/python/unittest/test_tir_transform_hoist_if.py index 7d02e4f12c1d..252a187dbdc5 100644 --- a/tests/python/unittest/test_tir_transform_hoist_if.py +++ b/tests/python/unittest/test_tir_transform_hoist_if.py @@ -522,7 +522,7 @@ def test_hoisting_block_scope_1(): s[B.op].bind(xi, te.thread_axis("threadIdx.y")) s[B].bind(s[B].op.reduce_axis[0], te.thread_axis("threadIdx.x")) s[BF].compute_at(s[B], s[B].op.reduce_axis[0]) - func = tvm.driver.build_module.form_irmodule(s, [A, B], "main", None)["main"] + func = tvm.driver.build_module.schedule_to_module(s, [A, B], "main", None)["main"] stmt = func.body new_stmt = tvm.tir.transform.HoistIfThenElse()(tvm.IRModule.from_expr(func))["main"].body tvm.ir.assert_structural_equal(new_stmt, stmt) @@ -622,7 +622,7 @@ def test_hoisting_block_scope_4(): s[C].pragma(xo2, "parallel_stride_pattern") s[C].pragma(xo2, "parallel_barrier_when_finish") s[C].vectorize(xi) - func = tvm.driver.build_module.form_irmodule(s, [A, B, C], "main", None)["main"] + func = tvm.driver.build_module.schedule_to_module(s, [A, B, C], "main", None)["main"] stmt = func.body new_stmt = tvm.tir.transform.HoistIfThenElse()(tvm.IRModule.from_expr(func))["main"].body tvm.ir.assert_structural_equal(new_stmt, stmt) From 5cf0f78b5b5360e4116e36ad8bb6d2f44c61e92b Mon Sep 17 00:00:00 2001 From: electriclilies Date: Thu, 27 May 2021 19:09:08 -0700 Subject: [PATCH 18/52] Split lower api into 3 --- include/tvm/driver/driver_api.h | 6 ++--- python/tvm/driver/build_module.py | 5 ++++ src/driver/driver_api.cc | 38 +++++++++++++++-------------- src/relay/backend/compile_engine.cc | 4 +-- 4 files changed, 30 insertions(+), 23 deletions(-) diff --git a/include/tvm/driver/driver_api.h b/include/tvm/driver/driver_api.h index 020f13818bf2..32ed9dfed27f 100644 --- a/include/tvm/driver/driver_api.h +++ b/include/tvm/driver/driver_api.h @@ -57,7 +57,7 @@ TVM_DLL IRModule legacyLower(IRModule mod, const Array& args, const bool simple_mode = false); -TVM_DLL IRModule lower(te::Schedule sch, const Array& args, const std::string& name, +TVM_DLL IRModule lower_schedule(te::Schedule sch, const Array& args, const std::string& name, const std::unordered_map& binds, bool simple_mode = false); /*! @@ -69,7 +69,7 @@ TVM_DLL IRModule lower(te::Schedule sch, const Array& args, const st * \param simple_mode Skips the LoopPartition pass if true. Defaults to false. * \return The result module. */ -TVM_DLL IRModule lower(IRModule mod, const Array& args, const std::string& name, +TVM_DLL IRModule lower_module(IRModule mod, const Array& args, const std::string& name, const std::unordered_map& binds, bool simple_mode = false); @@ -82,7 +82,7 @@ TVM_DLL IRModule lower(IRModule mod, const Array& args, const std::s * \param simple_mode Skips the LoopPartition pass if true. Defaults to false. * \return The result module. */ -TVM_DLL IRModule lower(tvm::tir::PrimFunc func, const Array& args, +TVM_DLL IRModule lower_primfunc(tvm::tir::PrimFunc func, const Array& args, const std::string& name, const std::unordered_map& binds, bool simple_mode = false); diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index d6f99f2ae6e2..89c1dbfb4895 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -164,9 +164,14 @@ def lower( m : IRModule The result IRModule """ + if isinstance(inputs, IRModule): + return ffi.lower_module(inputs) + if isinstance(inputs, PrimFunc): + return ffi.lower_primfunc(inputs) if isinstance(inputs, schedule.Schedule): mod = schedule_to_module(inputs, args, name, binds) return ffi.legacy_lower(mod, None, name, None, simple_mode) + return ffi.lower(inputs, args, name, binds, simple_mode) diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 8140ccd4e789..cfd4c9d25c84 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -304,13 +304,13 @@ TVM_REGISTER_GLOBAL("driver.schedule_to_module") }); -IRModule lower(IRModule mod, const Array& args, const std::string& name, +IRModule lower_module(IRModule mod, const Array& args, const std::string& name, const std::unordered_map& binds, bool simple_mode) { auto pass_list = CreatePassList(simple_mode, false); return LowerWithPassList(mod, pass_list); } -IRModule lower(te::Schedule sch, const Array& args, const std::string& name, +IRModule lower_schedule(te::Schedule sch, const Array& args, const std::string& name, const std::unordered_map& binds, bool simple_mode) { Array ref_arr; @@ -323,8 +323,6 @@ IRModule lower(te::Schedule sch, const Array& args, const std::strin return LowerWithPassList(mod, pass_list); } - - IRModule legacyLower(IRModule mod, const Array& args, const std::string& name, const std::unordered_map& binds, bool simple_mode) { // Get the legacy TE pass list @@ -332,7 +330,7 @@ IRModule legacyLower(IRModule mod, const Array& args, const std::str return LowerWithPassList(mod, pass_list); } -IRModule lower(tvm::tir::PrimFunc func, const Array& args, const std::string& name, +IRModule lower_primfunc(tvm::tir::PrimFunc func, const Array& args, const std::string& name, const std::unordered_map& binds, bool simple_mode) { auto pass_ctx = transform::PassContext::Current(); auto f = WithAttr(std::move(func), "global_symbol", runtime::String(name)); @@ -349,8 +347,8 @@ IRModule lower(tvm::tir::PrimFunc func, const Array& args, const std return LowerWithPassList(mod, pass_list); } -TVM_REGISTER_GLOBAL("driver.lower") - .set_body_typed([](ObjectRef obj, const Array& args, const String& name, +TVM_REGISTER_GLOBAL("driver.lower_schedule") + .set_body_typed([](te::Schedule sch, const Array& args, const String& name, const Map& binds, bool simple_mode) { std::unordered_map c_binds; // Check to make sure binds is not null before doing the conversion; @@ -359,17 +357,21 @@ TVM_REGISTER_GLOBAL("driver.lower") c_binds.insert(std::pair(kv.first, kv.second)); } } - if (const auto* p_mod = obj.as()) { - IRModule mod = GetRef(p_mod); - return lower(mod, args, name, c_binds, simple_mode); - } else if (const auto* p_func = obj.as()) { - tvm::tir::PrimFunc func = GetRef(p_func); - return lower(func, args, name, c_binds, simple_mode); - } else { - ICHECK(false) << "driver.lower expects the first argument to be a te::Schedule, " - << "PrimFunc, or IRModule"; - throw; - } + return lower_schedule(sch, args, name, c_binds, simple_mode); + }); + +TVM_REGISTER_GLOBAL("driver.lower_module") + .set_body_typed([](IRModule mod, const Array& args, const String& name, + const Map& binds, bool simple_mode) { + std::unordered_map c_binds; + return lower_module(mod, args, name, c_binds, simple_mode); + }); + +TVM_REGISTER_GLOBAL("driver.lower_primfunc") + .set_body_typed([](te::PrimFunc func, const Array& args, const String& name, + const Map& binds, bool simple_mode) { + std::unordered_map c_binds; + return lower_primfunc(func, args, name, c_binds, simple_mode); }); TVM_REGISTER_GLOBAL("driver.legacy_lower") diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index 5e3b66b3ae15..269713cc6be3 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -770,7 +770,7 @@ class CompileEngineImpl : public CompileEngineNode { With fresh_pass_ctx_scope(PassContext::Create()); std::unordered_map binds; - cache_node->funcs = tvm::lower(cfunc->schedule, all_args, cache_node->func_name, binds); + cache_node->funcs = tvm::lower_schedule(cfunc->schedule, all_args, cache_node->func_name, binds); } value->cached_func = CachedFunc(cache_node); return value; @@ -807,7 +807,7 @@ class CompileEngineImpl : public CompileEngineNode { With fresh_pass_ctx_scope(PassContext::Create()); std::unordered_map binds; - cache_node->funcs = tvm::lower(spair.first, all_args, cache_node->func_name, binds); + cache_node->funcs = tvm::lower_schedule(spair.first, all_args, cache_node->func_name, binds); value->cached_func = CachedFunc(cache_node); return value; } From b10d6ad484f1891f7316c2b56f59aecdadd754e9 Mon Sep 17 00:00:00 2001 From: electriclilies Date: Thu, 27 May 2021 19:28:09 -0700 Subject: [PATCH 19/52] got rid of legacy lower --- include/tvm/driver/driver_api.h | 6 +++--- python/tvm/driver/build_module.py | 5 +---- src/driver/driver_api.cc | 27 ++++++--------------------- 3 files changed, 10 insertions(+), 28 deletions(-) diff --git a/include/tvm/driver/driver_api.h b/include/tvm/driver/driver_api.h index 32ed9dfed27f..6b3c81207571 100644 --- a/include/tvm/driver/driver_api.h +++ b/include/tvm/driver/driver_api.h @@ -52,12 +52,12 @@ namespace tvm { * \param simple_mode Skips the LoopPartition pass if true. Defaults to false. * \return The result module. */ -TVM_DLL IRModule legacyLower(IRModule mod, const Array& args, const std::string& name, + +TVM_DLL IRModule lower_schedule(te::Schedule sch, const Array& args, const std::string& name, const std::unordered_map& binds, bool simple_mode = false); - -TVM_DLL IRModule lower_schedule(te::Schedule sch, const Array& args, const std::string& name, +TVM_DLL IRModule lower_schedule(te::Schedule sch, const Array& args, const std::string& name, const std::unordered_map& binds, bool simple_mode = false); /*! diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index 89c1dbfb4895..588b5b332eb6 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -169,10 +169,7 @@ def lower( if isinstance(inputs, PrimFunc): return ffi.lower_primfunc(inputs) if isinstance(inputs, schedule.Schedule): - mod = schedule_to_module(inputs, args, name, binds) - return ffi.legacy_lower(mod, None, name, None, simple_mode) - - return ffi.lower(inputs, args, name, binds, simple_mode) + return ffi.lower_schedule(inputs, args, name, binds, simple_mode) def _build_for_device(input_mod, target, target_host): diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index cfd4c9d25c84..dafd4e51abbf 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -313,19 +313,17 @@ IRModule lower_module(IRModule mod, const Array& args, const std::st IRModule lower_schedule(te::Schedule sch, const Array& args, const std::string& name, const std::unordered_map& binds, bool simple_mode) { - Array ref_arr; + Array ref_args; for (auto x : args) { - ref_arr.push_back(x); + ref_args.push_back(x); } - IRModule mod = ScheduleToModule(sch, ref_arr, name, binds); - // Get the legacy TE pass list - auto pass_list = CreatePassList(simple_mode, true); - return LowerWithPassList(mod, pass_list); + return lower_schedule(sch, ref_args, name, binds); } -IRModule legacyLower(IRModule mod, const Array& args, const std::string& name, +IRModule lower_schedule(te::Schedule sch, const Array& args, const std::string& name, const std::unordered_map& binds, bool simple_mode) { // Get the legacy TE pass list + IRModule mod = ScheduleToModule(sch, args, name, binds); auto pass_list = CreatePassList(simple_mode, true); return LowerWithPassList(mod, pass_list); } @@ -348,7 +346,7 @@ IRModule lower_primfunc(tvm::tir::PrimFunc func, const Array& args, } TVM_REGISTER_GLOBAL("driver.lower_schedule") - .set_body_typed([](te::Schedule sch, const Array& args, const String& name, + .set_body_typed([](te::Schedule sch, const Array& args, const String& name, const Map& binds, bool simple_mode) { std::unordered_map c_binds; // Check to make sure binds is not null before doing the conversion; @@ -374,19 +372,6 @@ TVM_REGISTER_GLOBAL("driver.lower_primfunc") return lower_primfunc(func, args, name, c_binds, simple_mode); }); -TVM_REGISTER_GLOBAL("driver.legacy_lower") - .set_body_typed([](IRModule mod, const Array& args, const String& name, - const Map& binds, bool simple_mode) { - std::unordered_map c_binds; - // Check to make sure binds is not null before doing the conversion; - if (binds.get() != NULL) { - for (auto kv : binds) { - c_binds.insert(std::pair(kv.first, kv.second)); - } - } - return legacyLower(mod, args, name, c_binds, simple_mode); - }); - std::pair SplitDevHostFuncs(IRModule mod_mixed, const Target& target_arg, const Target& target_host_arg, const transform::PassContext& pass_ctx) { From 5d5c0683a9c34e733ebf571d5dad3ad02359e2f1 Mon Sep 17 00:00:00 2001 From: electriclilies Date: Thu, 27 May 2021 19:49:26 -0700 Subject: [PATCH 20/52] renamed lower funcs --- include/tvm/driver/driver_api.h | 8 ++-- src/driver/driver_api.cc | 66 ++++++++++++++--------------- src/relay/backend/compile_engine.cc | 4 +- 3 files changed, 37 insertions(+), 41 deletions(-) diff --git a/include/tvm/driver/driver_api.h b/include/tvm/driver/driver_api.h index 6b3c81207571..d3264ee7ba1d 100644 --- a/include/tvm/driver/driver_api.h +++ b/include/tvm/driver/driver_api.h @@ -53,11 +53,11 @@ namespace tvm { * \return The result module. */ -TVM_DLL IRModule lower_schedule(te::Schedule sch, const Array& args, const std::string& name, +TVM_DLL IRModule LowerSchedule(te::Schedule sch, const Array& args, const std::string& name, const std::unordered_map& binds, bool simple_mode = false); -TVM_DLL IRModule lower_schedule(te::Schedule sch, const Array& args, const std::string& name, +TVM_DLL IRModule LowerSchedule(te::Schedule sch, const Array& args, const std::string& name, const std::unordered_map& binds, bool simple_mode = false); /*! @@ -69,7 +69,7 @@ TVM_DLL IRModule lower_schedule(te::Schedule sch, const Array& args, * \param simple_mode Skips the LoopPartition pass if true. Defaults to false. * \return The result module. */ -TVM_DLL IRModule lower_module(IRModule mod, const Array& args, const std::string& name, +TVM_DLL IRModule LowerModule(IRModule mod, const Array& args, const std::string& name, const std::unordered_map& binds, bool simple_mode = false); @@ -82,7 +82,7 @@ TVM_DLL IRModule lower_module(IRModule mod, const Array& args, const * \param simple_mode Skips the LoopPartition pass if true. Defaults to false. * \return The result module. */ -TVM_DLL IRModule lower_primfunc(tvm::tir::PrimFunc func, const Array& args, +TVM_DLL IRModule LowerPrimFunc(tvm::tir::PrimFunc func, const Array& args, const std::string& name, const std::unordered_map& binds, bool simple_mode = false); diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index dafd4e51abbf..8d5b9c0c4cc3 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -304,32 +304,18 @@ TVM_REGISTER_GLOBAL("driver.schedule_to_module") }); -IRModule lower_module(IRModule mod, const Array& args, const std::string& name, - const std::unordered_map& binds, bool simple_mode) { +IRModule LowerModule(IRModule mod, bool simple_mode) { auto pass_list = CreatePassList(simple_mode, false); return LowerWithPassList(mod, pass_list); } -IRModule lower_schedule(te::Schedule sch, const Array& args, const std::string& name, - const std::unordered_map& binds, bool simple_mode) { +TVM_REGISTER_GLOBAL("driver.lower_module") + .set_body_typed([](IRModule mod, bool simple_mode) { + return LowerModule(mod, simple_mode); + }); - Array ref_args; - for (auto x : args) { - ref_args.push_back(x); - } - return lower_schedule(sch, ref_args, name, binds); -} -IRModule lower_schedule(te::Schedule sch, const Array& args, const std::string& name, - const std::unordered_map& binds, bool simple_mode) { - // Get the legacy TE pass list - IRModule mod = ScheduleToModule(sch, args, name, binds); - auto pass_list = CreatePassList(simple_mode, true); - return LowerWithPassList(mod, pass_list); -} - -IRModule lower_primfunc(tvm::tir::PrimFunc func, const Array& args, const std::string& name, - const std::unordered_map& binds, bool simple_mode) { +IRModule LowerPrimFunc(tvm::tir::PrimFunc func, const std::string& name, bool simple_mode) { auto pass_ctx = transform::PassContext::Current(); auto f = WithAttr(std::move(func), "global_symbol", runtime::String(name)); @@ -345,6 +331,30 @@ IRModule lower_primfunc(tvm::tir::PrimFunc func, const Array& args, return LowerWithPassList(mod, pass_list); } +TVM_REGISTER_GLOBAL("driver.lower_primfunc") + .set_body_typed([](te::PrimFunc func, const String& name, bool simple_mode) { + return LowerPrimFunc(func, name, simple_mode); + }); + + +IRModule LowerSchedule(te::Schedule sch, const Array& args, const std::string& name, + const std::unordered_map& binds, bool simple_mode) { + + Array ref_args; + for (auto x : args) { + ref_args.push_back(x); + } + return LowerSchedule(sch, ref_args, name, binds); +} + +IRModule LowerSchedule(te::Schedule sch, const Array& args, const std::string& name, + const std::unordered_map& binds, bool simple_mode) { + // Get the legacy TE pass list + IRModule mod = ScheduleToModule(sch, args, name, binds); + auto pass_list = CreatePassList(simple_mode, true); + return LowerWithPassList(mod, pass_list); +} + TVM_REGISTER_GLOBAL("driver.lower_schedule") .set_body_typed([](te::Schedule sch, const Array& args, const String& name, const Map& binds, bool simple_mode) { @@ -355,21 +365,7 @@ TVM_REGISTER_GLOBAL("driver.lower_schedule") c_binds.insert(std::pair(kv.first, kv.second)); } } - return lower_schedule(sch, args, name, c_binds, simple_mode); - }); - -TVM_REGISTER_GLOBAL("driver.lower_module") - .set_body_typed([](IRModule mod, const Array& args, const String& name, - const Map& binds, bool simple_mode) { - std::unordered_map c_binds; - return lower_module(mod, args, name, c_binds, simple_mode); - }); - -TVM_REGISTER_GLOBAL("driver.lower_primfunc") - .set_body_typed([](te::PrimFunc func, const Array& args, const String& name, - const Map& binds, bool simple_mode) { - std::unordered_map c_binds; - return lower_primfunc(func, args, name, c_binds, simple_mode); + return LowerSchedule(sch, args, name, c_binds, simple_mode); }); std::pair SplitDevHostFuncs(IRModule mod_mixed, const Target& target_arg, diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index 269713cc6be3..7dcace457c97 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -770,7 +770,7 @@ class CompileEngineImpl : public CompileEngineNode { With fresh_pass_ctx_scope(PassContext::Create()); std::unordered_map binds; - cache_node->funcs = tvm::lower_schedule(cfunc->schedule, all_args, cache_node->func_name, binds); + cache_node->funcs = tvm::LowerSchedule(cfunc->schedule, all_args, cache_node->func_name, binds); } value->cached_func = CachedFunc(cache_node); return value; @@ -807,7 +807,7 @@ class CompileEngineImpl : public CompileEngineNode { With fresh_pass_ctx_scope(PassContext::Create()); std::unordered_map binds; - cache_node->funcs = tvm::lower_schedule(spair.first, all_args, cache_node->func_name, binds); + cache_node->funcs = tvm::LowerSchedule(spair.first, all_args, cache_node->func_name, binds); value->cached_func = CachedFunc(cache_node); return value; } From f2596c50bf295315fa9781820deae5ec7615cc3e Mon Sep 17 00:00:00 2001 From: electriclilies Date: Thu, 27 May 2021 20:00:45 -0700 Subject: [PATCH 21/52] remove python get_binds and rename flags --- python/tvm/driver/build_module.py | 67 ------------------------------- src/driver/driver_api.cc | 34 ++++++++-------- 2 files changed, 17 insertions(+), 84 deletions(-) diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index 588b5b332eb6..4a12f3813f05 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -39,54 +39,6 @@ from . import _ffi_api as ffi - -def get_binds(args, compact=False, binds=None): - """Internal function to get binds and arg_list given arguments. - - Parameters - ---------- - args : list of Buffer or Tensor or Var - The argument lists to the function. - - compact : bool - If the statement has already bound to a compact buffer. - - binds : dict of :any:`Tensor` to :any:`Buffer`, optional - Dictionary that maps the Tensor to Buffer which specified the data layout - requirement of the function. By default, a new compact buffer is created - for each tensor in the argument. - - Returns - ------- - binds: dict - The bind specification - - arg_list: list - The list of symbolic buffers of arguments. - """ - binds = {} if binds is None else binds.copy() - arg_list = [] - for x in args: - if isinstance(x, tensor.Tensor): - any_dim = any(isinstance(i, tvm.tir.Var) for i in x.shape) - buffer_type = "auto_broadcast" if any_dim and not compact else "" - if x not in binds: - buf = tvm.tir.decl_buffer( - x.shape, dtype=x.dtype, name=x.name, buffer_type=buffer_type - ) - binds[x] = buf - arg_list.append(buf) - else: - arg_list.append(binds[x]) - elif isinstance(x, schedule.Buffer): - arg_list.append(x) - elif isinstance(x, tvm.tir.Var): - arg_list.append(x) - else: - raise ValueError("args must be Tensor, Buffer or Var") - return binds, arg_list - - def schedule_to_module( sch: schedule.Schedule, args: Optional[List[Union[Buffer, tensor.Tensor, Var]]] = None, @@ -108,25 +60,6 @@ def schedule_to_module( ------- The body formed according to the given schedule """ - # normalize schedule first - """ - pass_ctx = PassContext.current() - sch = sch.normalize() - bounds = schedule.InferBound(sch) - stmt = schedule.ScheduleOps(sch, bounds) - - compact = schedule.VerifyCompactBuffer(stmt) - binds, arg_list = get_binds(args, compact, binds) - - stmt = schedule.SchedulePostProcRewriteForTensorCore(stmt, sch, binds) - func = schedule.SchedulePostProcToPrimFunc(arg_list, stmt, binds) - - func = func.with_attr("global_symbol", name) - - if pass_ctx.config.get("tir.noalias", True): - func = func.with_attr("tir.noalias", True) - return tvm.IRModule({name: func}) - """ return ffi.schedule_to_module(sch, args, name, binds) def lower( diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 8d5b9c0c4cc3..8b1e5c8cc97a 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -111,7 +111,6 @@ void GetBinds(const Array& args, bool compact, } } - void GetBinds(const Array& args, bool compact, const std::unordered_map& binds, Map* out_binds, Array* out_arg_list) { @@ -135,6 +134,8 @@ void GetBinds(const Array& args, bool compact, } } + + transform::Pass BindTarget(Target target) { auto fpass = [target](tir::PrimFunc f, IRModule m, transform::PassContext ctx) { return WithAttr(std::move(f), tvm::attr::kTarget, target); @@ -154,7 +155,7 @@ transform::Pass Filter(FCond fcond) { return tir::transform::CreatePrimFuncPass(fpass, 0, "Filter", {}); } -Array CreatePassList(bool simple_mode, bool legacy_te_pass) { +Array CreatePassList(bool enable_loop_partition, bool for_te_schedule) { auto pass_ctx = transform::PassContext::Current(); bool disable_vectorize = pass_ctx->GetConfig("tir.disable_vectorize", Bool(false)).value(); @@ -181,7 +182,6 @@ Array CreatePassList(bool simple_mode, bool legacy_te_pass CHECK_GT(phase_num_val, 0); - // TODO(electriclilies): is there a cleaner way to do this? auto pass_node = phase_pass[1].as(); auto pass = GetRef(pass_node); // Copy the pass into the correct phase @@ -223,7 +223,7 @@ Array CreatePassList(bool simple_mode, bool legacy_te_pass pass_list.insert(pass_list.end(), user_lower_phase1.begin(), user_lower_phase1.end()); // PHASE 2 - if (!simple_mode) { + if (!enable_loop_partition) { pass_list.push_back(tir::transform::LoopPartition()); } @@ -304,18 +304,18 @@ TVM_REGISTER_GLOBAL("driver.schedule_to_module") }); -IRModule LowerModule(IRModule mod, bool simple_mode) { - auto pass_list = CreatePassList(simple_mode, false); +IRModule LowerModule(IRModule mod, bool enable_loop_partition) { + auto pass_list = CreatePassList(enable_loop_partition, false); return LowerWithPassList(mod, pass_list); } TVM_REGISTER_GLOBAL("driver.lower_module") - .set_body_typed([](IRModule mod, bool simple_mode) { - return LowerModule(mod, simple_mode); + .set_body_typed([](IRModule mod, bool enable_loop_partition) { + return LowerModule(mod, enable_loop_partition); }); -IRModule LowerPrimFunc(tvm::tir::PrimFunc func, const std::string& name, bool simple_mode) { +IRModule LowerPrimFunc(tvm::tir::PrimFunc func, const std::string& name, bool enable_loop_partition) { auto pass_ctx = transform::PassContext::Current(); auto f = WithAttr(std::move(func), "global_symbol", runtime::String(name)); @@ -327,18 +327,18 @@ IRModule LowerPrimFunc(tvm::tir::PrimFunc func, const std::string& name, bool si IRModule mod = IRModule(Map({{GlobalVar(name), f}})); // Get the pass list - auto pass_list = CreatePassList(simple_mode, false); + auto pass_list = CreatePassList(enable_loop_partition, false); return LowerWithPassList(mod, pass_list); } TVM_REGISTER_GLOBAL("driver.lower_primfunc") - .set_body_typed([](te::PrimFunc func, const String& name, bool simple_mode) { - return LowerPrimFunc(func, name, simple_mode); + .set_body_typed([](te::PrimFunc func, const String& name, bool enable_loop_partition) { + return LowerPrimFunc(func, name, enable_loop_partition); }); IRModule LowerSchedule(te::Schedule sch, const Array& args, const std::string& name, - const std::unordered_map& binds, bool simple_mode) { + const std::unordered_map& binds, bool enable_loop_partition) { Array ref_args; for (auto x : args) { @@ -348,16 +348,16 @@ IRModule LowerSchedule(te::Schedule sch, const Array& args, const st } IRModule LowerSchedule(te::Schedule sch, const Array& args, const std::string& name, - const std::unordered_map& binds, bool simple_mode) { + const std::unordered_map& binds, bool enable_loop_partition) { // Get the legacy TE pass list IRModule mod = ScheduleToModule(sch, args, name, binds); - auto pass_list = CreatePassList(simple_mode, true); + auto pass_list = CreatePassList(enable_loop_partition, true); return LowerWithPassList(mod, pass_list); } TVM_REGISTER_GLOBAL("driver.lower_schedule") .set_body_typed([](te::Schedule sch, const Array& args, const String& name, - const Map& binds, bool simple_mode) { + const Map& binds, bool enable_loop_partition) { std::unordered_map c_binds; // Check to make sure binds is not null before doing the conversion; if (binds.get() != NULL) { @@ -365,7 +365,7 @@ TVM_REGISTER_GLOBAL("driver.lower_schedule") c_binds.insert(std::pair(kv.first, kv.second)); } } - return LowerSchedule(sch, args, name, c_binds, simple_mode); + return LowerSchedule(sch, args, name, c_binds, enable_loop_partition); }); std::pair SplitDevHostFuncs(IRModule mod_mixed, const Target& target_arg, From d6be36e5bb11d83d5c6f9b7eb51dacf623443a4c Mon Sep 17 00:00:00 2001 From: electriclilies Date: Thu, 27 May 2021 20:01:48 -0700 Subject: [PATCH 22/52] fix typo --- src/driver/driver_api.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 8b1e5c8cc97a..7dc38ac05da4 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -205,7 +205,7 @@ Array CreatePassList(bool enable_loop_partition, bool for_ auto pass_list = user_lower_phase0; // PHASE 1 - if (legacy_te_pass) { + if (for_te_schedule) { pass_list.push_back(tir::transform::InjectPrefetch()); pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers)); } else { From 06532d350cce795d6430065ece46d0fa0b0ecaec Mon Sep 17 00:00:00 2001 From: electriclilies Date: Thu, 27 May 2021 20:26:34 -0700 Subject: [PATCH 23/52] clean up doc and formatting --- include/tvm/driver/driver_api.h | 59 +++++++++++++++-------------- python/tvm/driver/build_module.py | 19 ++++++---- src/driver/driver_api.cc | 39 ++++++++++--------- src/relay/backend/compile_engine.cc | 3 +- 4 files changed, 64 insertions(+), 56 deletions(-) diff --git a/include/tvm/driver/driver_api.h b/include/tvm/driver/driver_api.h index d3264ee7ba1d..9aef10d67afa 100644 --- a/include/tvm/driver/driver_api.h +++ b/include/tvm/driver/driver_api.h @@ -43,49 +43,53 @@ #include namespace tvm { + /*! - * \brief Build an IRModule given a schedule, args and binds - * \param sch The schedule to lower. - * \param args The arguments to the function. - * \param name The name of the lowered function. - * \param binds Buffer assignments. - * \param simple_mode Skips the LoopPartition pass if true. Defaults to false. + * \brief Build an IRModule given a module, args and binds + * \param mod The IRmodule to lower + * \param enable_loop_partition Enables the loop partition pass. Defaults to true. * \return The result module. */ +TVM_DLL IRModule LowerModule(IRModule mod, bool enable_loop_partition = true); -TVM_DLL IRModule LowerSchedule(te::Schedule sch, const Array& args, const std::string& name, - const std::unordered_map& binds, - bool simple_mode = false); - -TVM_DLL IRModule LowerSchedule(te::Schedule sch, const Array& args, const std::string& name, - const std::unordered_map& binds, - bool simple_mode = false); /*! * \brief Build an IRModule given a module, args and binds - * \param mod The IRmodule to lower + * \param func The PrimFunc to lower + * \param name The name of the lowered function. + * \param enable_loop_partition Enables the loop partition pass. Defaults to true. + * \return The result module. + */ +TVM_DLL IRModule LowerPrimFunc(tvm::tir::PrimFunc func, const std::string& name, + bool enable_loop_partition = true); + +/*! + * \brief Build an IRModule given a schedule, args and binds + * \param sch The schedule to lower. * \param args The arguments to the function. * \param name The name of the lowered function. * \param binds Buffer assignments. - * \param simple_mode Skips the LoopPartition pass if true. Defaults to false. + * \param enable_loop_partition Enables the loop partition pass. Defaults to true. * \return The result module. */ -TVM_DLL IRModule LowerModule(IRModule mod, const Array& args, const std::string& name, - const std::unordered_map& binds, - bool simple_mode = false); + +TVM_DLL IRModule LowerSchedule(te::Schedule sch, const Array& args, + const std::string& name, + const std::unordered_map& binds, + bool enable_loop_partition = true); /*! - * \brief Build an IRModule given a module, args and binds - * \param func The PrimFunc to lower - * \param args The arguments to the function. + * \brief Build an IRModule given a schedule, args and binds + * \param sch The schedule to lower. + * \param args The arguments to the function (Array of Union of Tensor, Buffer and Vars) * \param name The name of the lowered function. * \param binds Buffer assignments. - * \param simple_mode Skips the LoopPartition pass if true. Defaults to false. + * \param enable_loop_partition Enables the loop partition pass. Defaults to true. * \return The result module. */ -TVM_DLL IRModule LowerPrimFunc(tvm::tir::PrimFunc func, const Array& args, - const std::string& name, - const std::unordered_map& binds, - bool simple_mode = false); +TVM_DLL IRModule LowerSchedule(te::Schedule sch, const Array& args, + const std::string& name, + const std::unordered_map& binds, + bool enable_loop_partition = true); /*! * \brief Create an IRModule out of a Schedule @@ -93,11 +97,10 @@ TVM_DLL IRModule LowerPrimFunc(tvm::tir::PrimFunc func, const Array& * \param args The arguments to the function. * \param name The name of the lowered function. * \param binds Buffer assignments. - * \param simple_mode Skips the LoopPartition pass if true. Defaults to false. * \return The result module. */ IRModule ScheduleToModule(te::Schedule sch, const Array& args, const std::string& name, - const std::unordered_map& binds); + const std::unordered_map& binds); /*! * \brief Build a device and host module for a specific target from an IRModule. * \param funcs The functions to be built. diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index 4a12f3813f05..e0a1ac02ee4a 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -39,6 +39,7 @@ from . import _ffi_api as ffi + def schedule_to_module( sch: schedule.Schedule, args: Optional[List[Union[Buffer, tensor.Tensor, Var]]] = None, @@ -62,8 +63,9 @@ def schedule_to_module( """ return ffi.schedule_to_module(sch, args, name, binds) + def lower( - inputs: Union[schedule.Schedule, PrimFunc, IRModule], + input: Union[schedule.Schedule, PrimFunc, IRModule], args: Optional[List[Union[Buffer, tensor.Tensor, Var]]] = None, name: str = "main", binds: Optional[Mapping[tensor.Tensor, Buffer]] = None, @@ -97,12 +99,15 @@ def lower( m : IRModule The result IRModule """ - if isinstance(inputs, IRModule): - return ffi.lower_module(inputs) - if isinstance(inputs, PrimFunc): - return ffi.lower_primfunc(inputs) - if isinstance(inputs, schedule.Schedule): - return ffi.lower_schedule(inputs, args, name, binds, simple_mode) + if isinstance(input, IRModule): + return ffi.lower_module(input) + if isinstance(input, PrimFunc): + return ffi.lower_primfunc(input) + if isinstance(input, schedule.Schedule): + return ffi.lower_schedule(input, args, name, binds, simple_mode) + raise ValueError( + "Expected input to be an IRModule, PrimFunc or Schedule, but got, ", type(inputs) + ) def _build_for_device(input_mod, target, target_host): diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 7dc38ac05da4..ac887ed1069e 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -93,7 +93,6 @@ tir::Buffer BufferWithOffsetAlignment(Array shape, DataType dtype, std offset_factor, buffer_type); } - // comment to try to remove this void GetBinds(const Array& args, bool compact, const std::unordered_map& binds, @@ -120,22 +119,22 @@ void GetBinds(const Array& args, bool compact, if (const auto* tensor_node = x.as()) { auto x_ref = GetRef(tensor_node); if (out_binds->find(x_ref) == out_binds->end()) { - auto buf = BufferWithOffsetAlignment(x_ref->shape, x_ref->dtype, x_ref->op->name, -1, 0, compact); + auto buf = + BufferWithOffsetAlignment(x_ref->shape, x_ref->dtype, x_ref->op->name, -1, 0, compact); out_binds->Set(x_ref, buf); out_arg_list->push_back(buf); } else { out_arg_list->push_back((*out_binds)[x_ref]); } } else if (x.as() || x.as()) { - out_arg_list->push_back(x); + out_arg_list->push_back(x); } else { - ICHECK(false) << "Expected type of the elements of args to be te::Tensor, te::Buffer or tir::Var"; + ICHECK(false) + << "Expected type of the elements of args to be te::Tensor, te::Buffer or tir::Var"; } } } - - transform::Pass BindTarget(Target target) { auto fpass = [target](tir::PrimFunc f, IRModule m, transform::PassContext ctx) { return WithAttr(std::move(f), tvm::attr::kTarget, target); @@ -223,7 +222,7 @@ Array CreatePassList(bool enable_loop_partition, bool for_ pass_list.insert(pass_list.end(), user_lower_phase1.begin(), user_lower_phase1.end()); // PHASE 2 - if (!enable_loop_partition) { + if (enable_loop_partition) { pass_list.push_back(tir::transform::LoopPartition()); } @@ -259,7 +258,7 @@ IRModule LowerWithPassList(IRModule mod, Array pass_list) } IRModule ScheduleToModule(te::Schedule sch, const Array& args, const std::string& name, - const std::unordered_map& binds) { + const std::unordered_map& binds) { // Convert te schedule to IRModule Array out_arg_list; auto pass_ctx = transform::PassContext::Current(); @@ -274,10 +273,11 @@ IRModule ScheduleToModule(te::Schedule sch, const Array& args, const Map out_binds; GetBinds(args, compact, binds, &out_binds, &out_arg_list); - // build the function + // Build the function // At this point binds is only te::Tensors - - stmt = te::SchedulePostProcRewriteForTensorCore(stmt, sch, binds); // TODO(electriclilies): Should this be in here? Was in python but not C++ version. + stmt = te::SchedulePostProcRewriteForTensorCore( + stmt, sch, + binds); // TODO(electriclilies): Should this be in here? Was in python but not C++ version. tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, std::move(stmt), out_binds); f = WithAttr(std::move(f), "global_symbol", runtime::String(name)); @@ -299,11 +299,10 @@ TVM_REGISTER_GLOBAL("driver.schedule_to_module") c_binds.insert(std::pair(kv.first, kv.second)); } } - IRModule mod = ScheduleToModule(sch, args, name, c_binds); - return mod; + IRModule mod = ScheduleToModule(sch, args, name, c_binds); + return mod; }); - IRModule LowerModule(IRModule mod, bool enable_loop_partition) { auto pass_list = CreatePassList(enable_loop_partition, false); return LowerWithPassList(mod, pass_list); @@ -314,8 +313,8 @@ TVM_REGISTER_GLOBAL("driver.lower_module") return LowerModule(mod, enable_loop_partition); }); - -IRModule LowerPrimFunc(tvm::tir::PrimFunc func, const std::string& name, bool enable_loop_partition) { +IRModule LowerPrimFunc(tvm::tir::PrimFunc func, const std::string& name, + bool enable_loop_partition) { auto pass_ctx = transform::PassContext::Current(); auto f = WithAttr(std::move(func), "global_symbol", runtime::String(name)); @@ -336,10 +335,9 @@ TVM_REGISTER_GLOBAL("driver.lower_primfunc") return LowerPrimFunc(func, name, enable_loop_partition); }); - IRModule LowerSchedule(te::Schedule sch, const Array& args, const std::string& name, - const std::unordered_map& binds, bool enable_loop_partition) { - + const std::unordered_map& binds, + bool enable_loop_partition) { Array ref_args; for (auto x : args) { ref_args.push_back(x); @@ -348,7 +346,8 @@ IRModule LowerSchedule(te::Schedule sch, const Array& args, const st } IRModule LowerSchedule(te::Schedule sch, const Array& args, const std::string& name, - const std::unordered_map& binds, bool enable_loop_partition) { + const std::unordered_map& binds, + bool enable_loop_partition) { // Get the legacy TE pass list IRModule mod = ScheduleToModule(sch, args, name, binds); auto pass_list = CreatePassList(enable_loop_partition, true); diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index 7dcace457c97..472971c0d521 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -770,7 +770,8 @@ class CompileEngineImpl : public CompileEngineNode { With fresh_pass_ctx_scope(PassContext::Create()); std::unordered_map binds; - cache_node->funcs = tvm::LowerSchedule(cfunc->schedule, all_args, cache_node->func_name, binds); + cache_node->funcs = + tvm::LowerSchedule(cfunc->schedule, all_args, cache_node->func_name, binds); } value->cached_func = CachedFunc(cache_node); return value; From 4d7a6021dc1f34e88a6610a337b29c1d38793488 Mon Sep 17 00:00:00 2001 From: electriclilies Date: Thu, 27 May 2021 20:38:20 -0700 Subject: [PATCH 24/52] fix typos --- python/tvm/driver/build_module.py | 6 +++--- tests/python/unittest/test_lower_build.py | 1 - 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index e0a1ac02ee4a..df2e67522963 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -69,7 +69,7 @@ def lower( args: Optional[List[Union[Buffer, tensor.Tensor, Var]]] = None, name: str = "main", binds: Optional[Mapping[tensor.Tensor, Buffer]] = None, - simple_mode: bool = False, + enable_loop_partition: bool = True, ) -> IRModule: """Lowering step before build into target. @@ -104,7 +104,7 @@ def lower( if isinstance(input, PrimFunc): return ffi.lower_primfunc(input) if isinstance(input, schedule.Schedule): - return ffi.lower_schedule(input, args, name, binds, simple_mode) + return ffi.lower_schedule(input, args, name, binds, enable_loop_partition) raise ValueError( "Expected input to be an IRModule, PrimFunc or Schedule, but got, ", type(inputs) ) @@ -345,7 +345,7 @@ def build( target_host = Target(target_host) if ( target_host.attrs.get("runtime", tvm.runtime.String("c++")) == "c" - and target_host.attrs.get("system-lib", 0).value == 1 + and target_host.attrs.get("system-lib", 0) == 1 ): if target_host.kind.name == "c": create_csource_crt_metadata_module = tvm._ffi.get_global_func( diff --git a/tests/python/unittest/test_lower_build.py b/tests/python/unittest/test_lower_build.py index 3742df7f3b8a..4505a7bed244 100644 --- a/tests/python/unittest/test_lower_build.py +++ b/tests/python/unittest/test_lower_build.py @@ -82,7 +82,6 @@ def test_lower_build_te_schedule(): def test_lower_build_tir_func(): # check lowering - print("Type of matmul: ", type(matmul)) ir_mod = tvm.lower(matmul) tvm.ir.assert_structural_equal(ir_mod, LoweredModule()) # check building From 2ed255d60bda52e07e7d92dc5449e36e44d12927 Mon Sep 17 00:00:00 2001 From: electriclilies Date: Fri, 28 May 2021 09:42:28 -0700 Subject: [PATCH 25/52] fix lint --- python/tvm/autotvm/record.py | 2 +- python/tvm/driver/build_module.py | 16 ++++++++-------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/python/tvm/autotvm/record.py b/python/tvm/autotvm/record.py index 4f11aea2911f..72c402b540c0 100644 --- a/python/tvm/autotvm/record.py +++ b/python/tvm/autotvm/record.py @@ -344,7 +344,7 @@ def pick_best(in_file, out_file): if args.ir: with inp.target: - print(lower(s, arg_bufs, simple_mode=True)) + print(lower(s, arg_bufs, enable_loop_partition=False)) if args.code: with inp.target: diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index df2e67522963..6eb0de1f1dba 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -65,7 +65,7 @@ def schedule_to_module( def lower( - input: Union[schedule.Schedule, PrimFunc, IRModule], + inputs: Union[schedule.Schedule, PrimFunc, IRModule], args: Optional[List[Union[Buffer, tensor.Tensor, Var]]] = None, name: str = "main", binds: Optional[Mapping[tensor.Tensor, Buffer]] = None, @@ -75,7 +75,7 @@ def lower( Parameters ---------- - input : Union[schedule.Schedule, PrimFunc, IRModule] + inputs : Union[schedule.Schedule, PrimFunc, IRModule] The TE schedule or TensorIR PrimFunc/IRModule to be built args : Optional[List[Union[Buffer, tensor.Tensor, Var]]] @@ -99,12 +99,12 @@ def lower( m : IRModule The result IRModule """ - if isinstance(input, IRModule): - return ffi.lower_module(input) - if isinstance(input, PrimFunc): - return ffi.lower_primfunc(input) - if isinstance(input, schedule.Schedule): - return ffi.lower_schedule(input, args, name, binds, enable_loop_partition) + if isinstance(inputs, IRModule): + return ffi.lower_module(inputs) + if isinstance(inputs, PrimFunc): + return ffi.lower_primfunc(inputs) + if isinstance(inputs, schedule.Schedule): + return ffi.lower_schedule(inputs, args, name, binds, enable_loop_partition) raise ValueError( "Expected input to be an IRModule, PrimFunc or Schedule, but got, ", type(inputs) ) From cec345623d49104b8e393d1f4b00f2b6583e0f72 Mon Sep 17 00:00:00 2001 From: electriclilies Date: Fri, 28 May 2021 10:20:05 -0700 Subject: [PATCH 26/52] fix calls to lower in build_module_test --- tests/cpp/build_module_test.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/cpp/build_module_test.cc b/tests/cpp/build_module_test.cc index 8cc5c4bc0a3a..204a824f9248 100644 --- a/tests/cpp/build_module_test.cc +++ b/tests/cpp/build_module_test.cc @@ -52,7 +52,7 @@ TEST(BuildModule, Basic) { auto target = Target("llvm"); - auto lowered = lower(s, args, "func", binds); + auto lowered = LowerSchedule(s, args, "func", binds); auto module = build(lowered, target, Target()); auto mali_target = Target("opencl -model=Mali-T860MP4@800Mhz -device=mali"); @@ -116,8 +116,8 @@ TEST(BuildModule, Heterogeneous) { auto args2 = Array({copy, C, elemwise_sub}); std::unordered_map binds; - auto lowered_s1 = lower(s1, args1, "elemwise_add", binds); - auto lowered_s2 = lower(s2, args2, "elemwise_sub", binds); + auto lowered_s1 = LowerSchedule(s1, args1, "elemwise_add", binds); + auto lowered_s2 = LowerSchedule(s2, args2, "elemwise_sub", binds); Map inputs = {{target_cuda, lowered_s1}, {target_llvm, lowered_s2}}; auto module = build(inputs, Target()); From b92dfe35e72e7a17256f7cccbb49e9187799c446 Mon Sep 17 00:00:00 2001 From: electriclilies Date: Fri, 28 May 2021 12:27:13 -0700 Subject: [PATCH 27/52] change enable_loop_partition back to simple_mode for consistency with rest of codebase --- include/tvm/driver/driver_api.h | 16 ++++++++-------- python/tvm/autotvm/record.py | 2 +- python/tvm/driver/build_module.py | 4 ++-- src/driver/driver_api.cc | 30 +++++++++++++++--------------- 4 files changed, 26 insertions(+), 26 deletions(-) diff --git a/include/tvm/driver/driver_api.h b/include/tvm/driver/driver_api.h index 9aef10d67afa..aaeafd257278 100644 --- a/include/tvm/driver/driver_api.h +++ b/include/tvm/driver/driver_api.h @@ -47,20 +47,20 @@ namespace tvm { /*! * \brief Build an IRModule given a module, args and binds * \param mod The IRmodule to lower - * \param enable_loop_partition Enables the loop partition pass. Defaults to true. + * \param simple_mode Disables the loop partition pass. Defaults to false. * \return The result module. */ -TVM_DLL IRModule LowerModule(IRModule mod, bool enable_loop_partition = true); +TVM_DLL IRModule LowerModule(IRModule mod, bool simple_mode = false); /*! * \brief Build an IRModule given a module, args and binds * \param func The PrimFunc to lower * \param name The name of the lowered function. - * \param enable_loop_partition Enables the loop partition pass. Defaults to true. + * \param simple_mode Disables the loop partition pass. Defaults to false. * \return The result module. */ TVM_DLL IRModule LowerPrimFunc(tvm::tir::PrimFunc func, const std::string& name, - bool enable_loop_partition = true); + bool simple_mode = false); /*! * \brief Build an IRModule given a schedule, args and binds @@ -68,14 +68,14 @@ TVM_DLL IRModule LowerPrimFunc(tvm::tir::PrimFunc func, const std::string& name, * \param args The arguments to the function. * \param name The name of the lowered function. * \param binds Buffer assignments. - * \param enable_loop_partition Enables the loop partition pass. Defaults to true. + * \param simple_mode Disables the loop partition pass. Defaults to false. * \return The result module. */ TVM_DLL IRModule LowerSchedule(te::Schedule sch, const Array& args, const std::string& name, const std::unordered_map& binds, - bool enable_loop_partition = true); + bool simple_mode = false); /*! * \brief Build an IRModule given a schedule, args and binds @@ -83,13 +83,13 @@ TVM_DLL IRModule LowerSchedule(te::Schedule sch, const Array& args, * \param args The arguments to the function (Array of Union of Tensor, Buffer and Vars) * \param name The name of the lowered function. * \param binds Buffer assignments. - * \param enable_loop_partition Enables the loop partition pass. Defaults to true. + * \param simple_mode Disables the loop partition pass. Defaults to false. * \return The result module. */ TVM_DLL IRModule LowerSchedule(te::Schedule sch, const Array& args, const std::string& name, const std::unordered_map& binds, - bool enable_loop_partition = true); + bool simple_mode = false); /*! * \brief Create an IRModule out of a Schedule diff --git a/python/tvm/autotvm/record.py b/python/tvm/autotvm/record.py index 72c402b540c0..4f11aea2911f 100644 --- a/python/tvm/autotvm/record.py +++ b/python/tvm/autotvm/record.py @@ -344,7 +344,7 @@ def pick_best(in_file, out_file): if args.ir: with inp.target: - print(lower(s, arg_bufs, enable_loop_partition=False)) + print(lower(s, arg_bufs, simple_mode=True)) if args.code: with inp.target: diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index 6eb0de1f1dba..e9fbc184d37e 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -69,7 +69,7 @@ def lower( args: Optional[List[Union[Buffer, tensor.Tensor, Var]]] = None, name: str = "main", binds: Optional[Mapping[tensor.Tensor, Buffer]] = None, - enable_loop_partition: bool = True, + simple_mode: bool = False, ) -> IRModule: """Lowering step before build into target. @@ -104,7 +104,7 @@ def lower( if isinstance(inputs, PrimFunc): return ffi.lower_primfunc(inputs) if isinstance(inputs, schedule.Schedule): - return ffi.lower_schedule(inputs, args, name, binds, enable_loop_partition) + return ffi.lower_schedule(inputs, args, name, binds, simple_modes) raise ValueError( "Expected input to be an IRModule, PrimFunc or Schedule, but got, ", type(inputs) ) diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index ac887ed1069e..a2fc6b8dcb87 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -154,7 +154,7 @@ transform::Pass Filter(FCond fcond) { return tir::transform::CreatePrimFuncPass(fpass, 0, "Filter", {}); } -Array CreatePassList(bool enable_loop_partition, bool for_te_schedule) { +Array CreatePassList(bool disable_loop_partition, bool for_te_schedule) { auto pass_ctx = transform::PassContext::Current(); bool disable_vectorize = pass_ctx->GetConfig("tir.disable_vectorize", Bool(false)).value(); @@ -222,7 +222,7 @@ Array CreatePassList(bool enable_loop_partition, bool for_ pass_list.insert(pass_list.end(), user_lower_phase1.begin(), user_lower_phase1.end()); // PHASE 2 - if (enable_loop_partition) { + if (!disable_loop_partition) { pass_list.push_back(tir::transform::LoopPartition()); } @@ -303,18 +303,18 @@ TVM_REGISTER_GLOBAL("driver.schedule_to_module") return mod; }); -IRModule LowerModule(IRModule mod, bool enable_loop_partition) { - auto pass_list = CreatePassList(enable_loop_partition, false); +IRModule LowerModule(IRModule mod, bool simple_mode) { + auto pass_list = CreatePassList(simple_mode, false); return LowerWithPassList(mod, pass_list); } TVM_REGISTER_GLOBAL("driver.lower_module") - .set_body_typed([](IRModule mod, bool enable_loop_partition) { - return LowerModule(mod, enable_loop_partition); + .set_body_typed([](IRModule mod, bool simple_mode) { + return LowerModule(mod, simple_mode); }); IRModule LowerPrimFunc(tvm::tir::PrimFunc func, const std::string& name, - bool enable_loop_partition) { + bool simple_mode) { auto pass_ctx = transform::PassContext::Current(); auto f = WithAttr(std::move(func), "global_symbol", runtime::String(name)); @@ -326,18 +326,18 @@ IRModule LowerPrimFunc(tvm::tir::PrimFunc func, const std::string& name, IRModule mod = IRModule(Map({{GlobalVar(name), f}})); // Get the pass list - auto pass_list = CreatePassList(enable_loop_partition, false); + auto pass_list = CreatePassList(simple_mode, false); return LowerWithPassList(mod, pass_list); } TVM_REGISTER_GLOBAL("driver.lower_primfunc") - .set_body_typed([](te::PrimFunc func, const String& name, bool enable_loop_partition) { - return LowerPrimFunc(func, name, enable_loop_partition); + .set_body_typed([](te::PrimFunc func, const String& name, bool simple_mode) { + return LowerPrimFunc(func, name, simple_mode); }); IRModule LowerSchedule(te::Schedule sch, const Array& args, const std::string& name, const std::unordered_map& binds, - bool enable_loop_partition) { + bool simple_mode) { Array ref_args; for (auto x : args) { ref_args.push_back(x); @@ -347,16 +347,16 @@ IRModule LowerSchedule(te::Schedule sch, const Array& args, const st IRModule LowerSchedule(te::Schedule sch, const Array& args, const std::string& name, const std::unordered_map& binds, - bool enable_loop_partition) { + bool simple_mode) { // Get the legacy TE pass list IRModule mod = ScheduleToModule(sch, args, name, binds); - auto pass_list = CreatePassList(enable_loop_partition, true); + auto pass_list = CreatePassList(simple_mode, true); return LowerWithPassList(mod, pass_list); } TVM_REGISTER_GLOBAL("driver.lower_schedule") .set_body_typed([](te::Schedule sch, const Array& args, const String& name, - const Map& binds, bool enable_loop_partition) { + const Map& binds, bool simple_mode) { std::unordered_map c_binds; // Check to make sure binds is not null before doing the conversion; if (binds.get() != NULL) { @@ -364,7 +364,7 @@ TVM_REGISTER_GLOBAL("driver.lower_schedule") c_binds.insert(std::pair(kv.first, kv.second)); } } - return LowerSchedule(sch, args, name, c_binds, enable_loop_partition); + return LowerSchedule(sch, args, name, c_binds, simple_mode); }); std::pair SplitDevHostFuncs(IRModule mod_mixed, const Target& target_arg, From 0d69397bcdc45f5a6ff4fa525d118f1f188a110a Mon Sep 17 00:00:00 2001 From: electriclilies Date: Fri, 28 May 2021 13:03:35 -0700 Subject: [PATCH 28/52] clang format --- src/driver/driver_api.cc | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index a2fc6b8dcb87..2394782c3b6d 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -308,13 +308,11 @@ IRModule LowerModule(IRModule mod, bool simple_mode) { return LowerWithPassList(mod, pass_list); } -TVM_REGISTER_GLOBAL("driver.lower_module") - .set_body_typed([](IRModule mod, bool simple_mode) { - return LowerModule(mod, simple_mode); - }); +TVM_REGISTER_GLOBAL("driver.lower_module").set_body_typed([](IRModule mod, bool simple_mode) { + return LowerModule(mod, simple_mode); +}); -IRModule LowerPrimFunc(tvm::tir::PrimFunc func, const std::string& name, - bool simple_mode) { +IRModule LowerPrimFunc(tvm::tir::PrimFunc func, const std::string& name, bool simple_mode) { auto pass_ctx = transform::PassContext::Current(); auto f = WithAttr(std::move(func), "global_symbol", runtime::String(name)); @@ -336,8 +334,7 @@ TVM_REGISTER_GLOBAL("driver.lower_primfunc") }); IRModule LowerSchedule(te::Schedule sch, const Array& args, const std::string& name, - const std::unordered_map& binds, - bool simple_mode) { + const std::unordered_map& binds, bool simple_mode) { Array ref_args; for (auto x : args) { ref_args.push_back(x); @@ -346,8 +343,7 @@ IRModule LowerSchedule(te::Schedule sch, const Array& args, const st } IRModule LowerSchedule(te::Schedule sch, const Array& args, const std::string& name, - const std::unordered_map& binds, - bool simple_mode) { + const std::unordered_map& binds, bool simple_mode) { // Get the legacy TE pass list IRModule mod = ScheduleToModule(sch, args, name, binds); auto pass_list = CreatePassList(simple_mode, true); From 30ad4075b88219e7f9a6abf472cc29f76812be64 Mon Sep 17 00:00:00 2001 From: electriclilies Date: Fri, 28 May 2021 14:19:58 -0700 Subject: [PATCH 29/52] fix typo --- python/tvm/driver/build_module.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index e9fbc184d37e..9217b40280ca 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -104,7 +104,7 @@ def lower( if isinstance(inputs, PrimFunc): return ffi.lower_primfunc(inputs) if isinstance(inputs, schedule.Schedule): - return ffi.lower_schedule(inputs, args, name, binds, simple_modes) + return ffi.lower_schedule(inputs, args, name, binds, simple_mode) raise ValueError( "Expected input to be an IRModule, PrimFunc or Schedule, but got, ", type(inputs) ) From 51bcd3f7b9b30bdc0129113a7af0e06bde565bc0 Mon Sep 17 00:00:00 2001 From: electriclilies Date: Fri, 28 May 2021 14:34:16 -0700 Subject: [PATCH 30/52] retrigger From eddec7b25804fbafa4b3e050184afae7de3ba0fe Mon Sep 17 00:00:00 2001 From: electriclilies Date: Fri, 28 May 2021 16:01:12 -0700 Subject: [PATCH 31/52] fix calls to ffi lower --- python/tvm/driver/build_module.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index 9217b40280ca..aebe4fdde4e0 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -100,9 +100,9 @@ def lower( The result IRModule """ if isinstance(inputs, IRModule): - return ffi.lower_module(inputs) + return ffi.lower_module(inputs, simple_mode) if isinstance(inputs, PrimFunc): - return ffi.lower_primfunc(inputs) + return ffi.lower_primfunc(inputs, simple_mode) if isinstance(inputs, schedule.Schedule): return ffi.lower_schedule(inputs, args, name, binds, simple_mode) raise ValueError( From 1bb8b9daf316eab0467df4c8ebfeabed9fb742c2 Mon Sep 17 00:00:00 2001 From: electriclilies Date: Tue, 1 Jun 2021 12:55:36 -0700 Subject: [PATCH 32/52] Add get binds to the FFI --- python/tvm/autotvm/feature.py | 2 +- python/tvm/driver/build_module.py | 23 +++++++++++++++++++++++ src/driver/driver_api.cc | 23 ++++++++++++++++++++++- 3 files changed, 46 insertions(+), 2 deletions(-) diff --git a/python/tvm/autotvm/feature.py b/python/tvm/autotvm/feature.py index dff0f098d84a..6c3d81f3c63d 100644 --- a/python/tvm/autotvm/feature.py +++ b/python/tvm/autotvm/feature.py @@ -39,7 +39,7 @@ def ana_lower(sch, args, binds=None, simple_mode=True): """Do lower while keeping all axes in IR i.e. Do not eliminate loop with extent of 1, do not vectorize, unroll or inject virtual threads """ - binds, _ = build_module.get_binds(args, binds) + binds, _ = build_module.get_binds(args, False, binds) sch = sch.normalize() # Phase 0 bounds = schedule.InferBound(sch) diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index aebe4fdde4e0..753936c8de32 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -40,6 +40,29 @@ from . import _ffi_api as ffi +def get_binds(args, compact=False, binds=None): + """Internal function to get binds and arg_list given arguments. + Parameters + ---------- + args : list of Buffer or Tensor or Var + The argument lists to the function. + compact : bool + If the statement has already bound to a compact buffer. + binds : dict of :any:`Tensor` to :any:`Buffer`, optional + Dictionary that maps the Tensor to Buffer which specified the data layout + requirement of the function. By default, a new compact buffer is created + for each tensor in the argument. + Returns + ------- + binds: dict + The bind specification + arg_list: list + The list of symbolic buffers of arguments. + """ + out_arr = ffi.get_binds(args, compact, binds) + return out_arr[0], out_arr[1] + + def schedule_to_module( sch: schedule.Schedule, args: Optional[List[Union[Buffer, tensor.Tensor, Var]]] = None, diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 2394782c3b6d..b46b6d402e8b 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -135,6 +135,27 @@ void GetBinds(const Array& args, bool compact, } } +TVM_REGISTER_GLOBAL("driver.get_binds") + .set_body_typed([](const Array& args, bool compact, + const Map& binds) { + std::unordered_map c_binds; + // Check to make sure binds is not null before doing the conversion; + if (binds.get() != NULL) { + for (auto kv : binds) { + c_binds.insert(std::pair(kv.first, kv.second)); + } + } + Map out_binds; + Array out_arg_list; + GetBinds(args, compact, c_binds, &out_binds, &out_arg_list); + + // TODO(electriclilies): is there a way to return a pair? + Array out_arr; + out_arr.push_back(out_binds); + out_arr.push_back(out_arg_list); + return out_arr; + }); + transform::Pass BindTarget(Target target) { auto fpass = [target](tir::PrimFunc f, IRModule m, transform::PassContext ctx) { return WithAttr(std::move(f), tvm::attr::kTarget, target); @@ -344,8 +365,8 @@ IRModule LowerSchedule(te::Schedule sch, const Array& args, const st IRModule LowerSchedule(te::Schedule sch, const Array& args, const std::string& name, const std::unordered_map& binds, bool simple_mode) { - // Get the legacy TE pass list IRModule mod = ScheduleToModule(sch, args, name, binds); + // Get the legacy TE pass list auto pass_list = CreatePassList(simple_mode, true); return LowerWithPassList(mod, pass_list); } From 4b8a529a6bc19add95f87d097f5e6dba593bb014 Mon Sep 17 00:00:00 2001 From: electriclilies Date: Tue, 1 Jun 2021 13:20:33 -0700 Subject: [PATCH 33/52] fix black --- python/tvm/driver/tvmc/compiler.py | 2 +- tests/python/frontend/tensorflow2/test_functional_models.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/driver/tvmc/compiler.py b/python/tvm/driver/tvmc/compiler.py index 0cc26872cd47..071474a31594 100644 --- a/python/tvm/driver/tvmc/compiler.py +++ b/python/tvm/driver/tvmc/compiler.py @@ -38,7 +38,7 @@ @register_parser def add_compile_parser(subparsers): - """ Include parser for 'compile' subcommand """ + """Include parser for 'compile' subcommand""" parser = subparsers.add_parser("compile", help="compile a model.") parser.set_defaults(func=drive_compile) diff --git a/tests/python/frontend/tensorflow2/test_functional_models.py b/tests/python/frontend/tensorflow2/test_functional_models.py index 40d42a28025a..fe31a753746c 100644 --- a/tests/python/frontend/tensorflow2/test_functional_models.py +++ b/tests/python/frontend/tensorflow2/test_functional_models.py @@ -63,7 +63,7 @@ def run_model_graph(TestClass): def test_add_one(): class AddOne(tf.Module): - """ simple function to test x=x+1; scalar as input""" + """simple function to test x=x+1; scalar as input""" def get_input(self): return np.array(1.0, dtype="float32") From 6115f3c4c6902000cacb418914f5c1aae3842d26 Mon Sep 17 00:00:00 2001 From: electriclilies Date: Tue, 1 Jun 2021 13:35:33 -0700 Subject: [PATCH 34/52] comment out SchedulePostProcRewriteForTensorCore --- src/driver/driver_api.cc | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index b46b6d402e8b..73ead15fbce8 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -296,9 +296,8 @@ IRModule ScheduleToModule(te::Schedule sch, const Array& args, const // Build the function // At this point binds is only te::Tensors - stmt = te::SchedulePostProcRewriteForTensorCore( - stmt, sch, - binds); // TODO(electriclilies): Should this be in here? Was in python but not C++ version. + // TODO(electriclilies): Should this be in here? Was in python but not C++ version. + // stmt = te::SchedulePostProcRewriteForTensorCore(stmt, sch,binds); tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, std::move(stmt), out_binds); f = WithAttr(std::move(f), "global_symbol", runtime::String(name)); From b9c8cb6fc56a9b90b3e4cb834a9e6c1c0dba6fe8 Mon Sep 17 00:00:00 2001 From: electriclilies Date: Tue, 1 Jun 2021 14:12:39 -0700 Subject: [PATCH 35/52] DELETE schedule postproc rewrite for tensorcore --- include/tvm/te/schedule_pass.h | 12 - src/driver/driver_api.cc | 2 - ...hedule_postproc_rewrite_for_tensor_core.cc | 1124 ----------------- 3 files changed, 1138 deletions(-) delete mode 100644 src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc diff --git a/include/tvm/te/schedule_pass.h b/include/tvm/te/schedule_pass.h index 32e74f6ef9d5..0ba7421ce409 100644 --- a/include/tvm/te/schedule_pass.h +++ b/include/tvm/te/schedule_pass.h @@ -88,18 +88,6 @@ bool VerifyCompactBuffer(const Stmt& stmt); */ Stmt ScheduleOps(Schedule s, Map dom_map, bool debug_keep_trivial_loop); -/*! - * \brief Try to modify the AST generated by ScheduleOps to support TensorCore. - * - * \param stmt The stmt to be trasnformed. - * \param schedule The original schedule. - * \param extern_buffer Map specifies external - * buffer assignment of input and outputs. - * \return Transformed stmt. - */ -Stmt SchedulePostProcRewriteForTensorCore(Stmt stmt, Schedule schedule, - Map extern_buffer); - /*! * \brief Postprocessing the Stmt generated by ScheduleOps to create * a PrimFunc that can then be used for further TIR optimizations. diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 73ead15fbce8..0c75491f16a9 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -296,8 +296,6 @@ IRModule ScheduleToModule(te::Schedule sch, const Array& args, const // Build the function // At this point binds is only te::Tensors - // TODO(electriclilies): Should this be in here? Was in python but not C++ version. - // stmt = te::SchedulePostProcRewriteForTensorCore(stmt, sch,binds); tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, std::move(stmt), out_binds); f = WithAttr(std::move(f), "global_symbol", runtime::String(name)); diff --git a/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc b/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc deleted file mode 100644 index 951bd6c18706..000000000000 --- a/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc +++ /dev/null @@ -1,1124 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file schedule_postproc_rewrite_for_tensor_core.cc - * - * \brief Rewrite the Stmt generated by ScheduleOps - * to accomondate tensorcore. - */ -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -#include "../../runtime/thread_storage_scope.h" - -namespace tvm { -namespace te { - -using namespace te; -using runtime::StorageRank; -using runtime::StorageScope; -using runtime::ThreadScope; - -struct Tile { - int m{-1}; - int n{-1}; - int k{-1}; -}; - -std::string simplify_name(std::string input) { - auto pos = input.find("."); - if (pos != std::string::npos) { - return input.substr(0, pos); - } else { - return input; - } -} - -PrimExpr unpack_type_cast(const PrimExpr& input, const DataType& target_type) { - auto cast = input.as(); - if (cast == nullptr) { - return input; - } else if (cast->dtype == target_type) { - return cast->value; - } - return PrimExpr(); -} - -// MMAMatcher matches C = Cast(A)*Cast(B)+C, -// where A & B are fp16/int8 local buffers, -// and C is fp32/int32 local buffer. -class MMAMatcher : public StmtVisitor { - public: - explicit MMAMatcher(Map extern_buffer) { - for (auto kv : extern_buffer) { - BufferInfo bi; - bi.name = kv.second->name; - bi.dtype = kv.second->dtype; - bi.external = true; - buf_map_[kv.first] = bi; - } - } - - void VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == tir::attr::pragma_tensor_core) { - tensor_core_on_ = true; - StmtVisitor::VisitStmt_(op); - } else if (op->attr_key == tir::attr::realize_scope) { - storage_scope_[op->node.get()] = op->value.as()->value; - this->VisitStmt(op->body); - } else { - StmtVisitor::VisitStmt_(op); - } - } - - void VisitStmt_(const ProducerStoreNode* op) final { - StmtVisitor::VisitStmt_(op); - auto it = buf_map_.find(Downcast(op->producer)); - if (it == buf_map_.end()) { - return; - } - const BufferInfo& bi = it->second; - if (bi.released) { - return; - } - if (tensor_core_on_ && mma_sync_match_(op, bi)) { - matched_ = true; - } - } - - void VisitStmt_(const ProducerRealizeNode* op) final { - auto key = Downcast(op->producer); - if (buf_map_.count(key)) { - if (!buf_map_.at(key).external) { - return; - } - this->VisitStmt(op->body); - } else { - BufferInfo bi; - bi.name = key->GetNameHint(); - bi.dtype = key->dtype; - buf_map_[key] = bi; - this->VisitStmt(op->body); - buf_map_[key].released = true; - } - } - - inline bool Matched() const { return matched_; } - - friend class ScheduleAnalyser; - friend class BufferAnalyser; - - private: - struct BufferInfo { - std::string name; - DataType dtype; - bool external{false}; - bool released{false}; - bool same_as(const BufferInfo& bi) { - if (this->dtype != bi.dtype) return false; - if (this->name != bi.name) return false; - if (this->external != bi.external) return false; - if (this->released != bi.released) return false; - return true; - } - }; - - // Check whether the storage scope is local - bool check_local_buffer_(const ProducerLoadNode* op, BufferInfo* bi) { - auto tensor = Downcast(op->producer); - auto it = storage_scope_.find(tensor.get()); - if (it == storage_scope_.end()) { - return false; - } - const std::string& strkey = it->second; - if (strkey != "local") { - return false; - } - auto it1 = buf_map_.find(tensor); - if (it1 == buf_map_.end()) { - return false; - } - *bi = it1->second; - if (bi->released) { - return false; - } - return true; - } - - // Do the pattern matching - bool mma_sync_match_(const ProducerStoreNode* op, BufferInfo store_buffer) { - auto* add = op->value.as(); - if (add == nullptr) { - return false; - } - - auto* load_c = add->a.as(); - BufferInfo buffer_c; - if (!check_local_buffer_(load_c, &buffer_c) || !buffer_c.same_as(store_buffer) || - !(buffer_c.dtype == DataType::Float(32) || buffer_c.dtype == DataType::Int(32))) { - return false; - } - - auto mul = unpack_type_cast(add->b, buffer_c.dtype).as(); - if (mul == nullptr) { - return false; - } - - auto load_a_expr = unpack_type_cast(mul->a, buffer_c.dtype); - auto load_a = load_a_expr.as(); - BufferInfo buffer_a; - if (!check_local_buffer_(load_a, &buffer_a) || - !(buffer_a.dtype == DataType::Float(16) || buffer_a.dtype == DataType::Int(8) || - buffer_a.dtype == DataType::UInt(8) || buffer_a.dtype == DataType::Int(4) || - buffer_a.dtype == DataType::UInt(4) || buffer_a.dtype == DataType::Int(1))) { - return false; - } - - auto load_b_expr = unpack_type_cast(mul->b, buffer_c.dtype); - auto load_b = load_b_expr.as(); - BufferInfo buffer_b; - if (!check_local_buffer_(load_b, &buffer_b) || - !(buffer_b.dtype == DataType::Float(16) || buffer_b.dtype == DataType::Int(8) || - buffer_b.dtype == DataType::UInt(8) || buffer_b.dtype == DataType::Int(4) || - buffer_a.dtype == DataType::UInt(4) || buffer_a.dtype == DataType::Int(1))) { - return false; - } - - frag_reg_.insert(buffer_c.name); - frag_reg_.insert(buffer_a.name); - frag_reg_.insert(buffer_b.name); - buf_name_.insert(std::make_pair(load_a, buffer_a.name)); - buf_name_.insert(std::make_pair(load_b, buffer_b.name)); - mma_sync_.insert(std::make_pair(op, Array{load_a_expr, load_b_expr, add->a})); - - return true; - } - - std::unordered_map buf_map_; - std::unordered_map storage_scope_; - std::unordered_map> mma_sync_; - std::unordered_map buf_name_; - std::unordered_set frag_reg_; - bool matched_{false}; - bool tensor_core_on_{false}; -}; - -// BodyVisitor visits the body stmt of original ComputeOp -// to get the access indices of input matrices, -// if it is recognized as matrix multiply. -class BodyVisitor : public StmtExprVisitor { - public: - BodyVisitor() {} - - void VisitExpr_(const ReduceNode* op) final { - auto* comm_add = op->combiner->result[0].as(); - if (comm_add == nullptr || op->combiner->result.size() > 1) { - return; - } - for (PrimExpr source : op->source) { - auto mul_0 = unpack_type_cast(source, DataType::Float(32)).as(); - auto mul_1 = unpack_type_cast(source, DataType::Int(32)).as(); - if (mul_0 == nullptr && mul_1 == nullptr) { - continue; - } - - tensorcore_candidate_ = true; - StmtExprVisitor::VisitExpr(source); - } - } - - void VisitExpr_(const ProducerLoadNode* op) final { - StmtExprVisitor::VisitExpr_(op); - args_.insert(std::make_pair(op->producer->GetNameHint(), op->indices)); - } - - friend class ScheduleAnalyser; - - private: - std::unordered_map> args_; - bool tensorcore_candidate_{false}; -}; - -// ScheduleAnalyser figures out matrix_a/matrix_b and row_major/col_major -class ScheduleAnalyser { - public: - explicit ScheduleAnalyser(const MMAMatcher& mma_matcher) - : mma_sync_(mma_matcher.mma_sync_), buf_name_(mma_matcher.buf_name_) {} - - bool MatrixIdentify(Schedule schedule) { - // TODO(minmin): handle the case where MatMul is not the output stage - for (Operation output : schedule->outputs) { - const ComputeOpNode* compute = output.as(); - if (compute == nullptr) { - // Not a ComputeOp - continue; - } - auto axis = compute->axis; - auto reduce_axis = compute->reduce_axis; - if (axis.size() < 2 || reduce_axis.size() != 1) { - continue; - } - const VarNode* axis_var[2]; - const VarNode* reduce_axis_var; - axis_var[0] = axis[axis.size() - 2]->var.as(); - axis_var[1] = axis[axis.size() - 1]->var.as(); - reduce_axis_var = reduce_axis[0]->var.as(); - - BodyVisitor body_visitor; - for (PrimExpr expr : compute->body) { - body_visitor(expr); - } - if (!body_visitor.tensorcore_candidate_) { - continue; - } - for (auto iter : body_visitor.args_) { - auto name = iter.first; - auto args = iter.second; - if (args.size() < 2) { - continue; - } - const VarNode* var0 = args[args.size() - 2].as(); - const VarNode* var1 = args[args.size() - 1].as(); - if (var0 == nullptr || var1 == nullptr) { - continue; - } - std::string matrix_abc, major; - if (var0 == reduce_axis_var && var1 == axis_var[1]) { - matrix_abc = "matrix_a"; - major = "col_major"; - } else if (var0 == reduce_axis_var && var1 == axis_var[0]) { - matrix_abc = "matrix_b"; - major = "row_major"; - } else if (var0 == axis_var[1] && var1 == reduce_axis_var) { - matrix_abc = "matrix_a"; - major = "row_major"; - } else if (var0 == axis_var[0] && var1 == reduce_axis_var) { - matrix_abc = "matrix_b"; - major = "col_major"; - } - matrix_abc_.insert(std::make_pair(name, matrix_abc)); - matrix_major_.insert(std::make_pair(name, major)); - } - matrix_abc_.insert(std::make_pair(compute->name, "accumulator")); - matrix_major_.insert(std::make_pair(compute->name, "col_major")); - } - - for (auto& mma_sync : mma_sync_) { - auto& operands = mma_sync.second; - auto* load_a = operands[0].as(); - auto* load_b = operands[1].as(); - auto input0 = simplify_name(buf_name_.find(load_a)->second); - auto input1 = simplify_name(buf_name_.find(load_b)->second); - auto it0 = matrix_abc_.find(input0); - auto it1 = matrix_abc_.find(input1); - - if (it0 == matrix_abc_.end() || it1 == matrix_abc_.end()) { - return false; - } - if (it0->second == "matrix_a" && it1->second == "matrix_b") { - return true; - } else if (it0->second == "matrix_b" && it1->second == "matrix_a") { - mma_sync.second = Array{operands[1], operands[0], operands[2]}; - } else { - return false; - } - } - return true; - } - - friend class BufferAnalyser; - friend class TensorCoreIRMutator; - - private: - std::unordered_map matrix_abc_; - std::unordered_map matrix_major_; - std::unordered_map> mma_sync_; - std::unordered_map buf_name_; -}; - -// IndexVisitor visits access index of fragment -// to record variable for loop scaling -class IndexVisitor : public StmtExprVisitor { - public: - IndexVisitor() {} - - void VisitExpr_(const VarNode* op) final { - loop_scaling_.insert(std::make_pair(op, scaling_factor_)); - } - - friend class BufferAnalyser; - friend class TensorCoreIRMutator; - - private: - std::unordered_map loop_scaling_; - unsigned scaling_factor_{0}; -}; - -// BufferAnalyser gets buffer info, -// e.g. thread tile and warp tile, for TensorCore CodeGen -class BufferAnalyser : public StmtExprVisitor { - public: - explicit BufferAnalyser(Map extern_buffer, - const ScheduleAnalyser& schedule_analyser, const MMAMatcher& mma_matcher) - : matrix_abc_(schedule_analyser.matrix_abc_), - matrix_major_(schedule_analyser.matrix_major_), - frag_reg_(mma_matcher.frag_reg_) { - for (auto kv : extern_buffer) { - BufferInfo bi; - bi.name = kv.second->name; - bi.dtype = kv.second->dtype; - bi.strides = kv.second->strides; - bi.shape = kv.second->shape; - bi.external = true; - buf_map_[kv.first] = bi; - } - } - - void VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == tir::attr::thread_extent) { - if (const IntImmNode* value = op->value.as()) { - thread_extent_.insert( - std::make_pair(op->node.as()->var->name_hint, value->value)); - } - StmtExprVisitor::VisitStmt_(op); - } else if (op->attr_key == tir::attr::realize_scope) { - storage_scope_[op->node.get()] = op->value.as()->value; - this->VisitStmt(op->body); - } else if (op->attr_key == tir::attr::buffer_dim_align) { - te::Tensor tensor = Downcast(op->node); - const CallNode* tuple = op->value.as(); - ICHECK(tuple && tuple->op.same_as(builtin::tvm_tuple())); - auto& vinfo = dim_align_[tensor]; - size_t dim = tuple->args[0].as()->value; - if (dim >= vinfo.size()) { - vinfo.resize(dim + 1); - } - vinfo[dim].align_factor = tuple->args[1].as()->value; - vinfo[dim].align_offset = tuple->args[2].as()->value; - this->VisitStmt(op->body); - } else { - StmtExprVisitor::VisitStmt_(op); - } - } - - void VisitStmt_(const ProducerStoreNode* op) final { - StmtExprVisitor::VisitStmt_(op); - auto key = Downcast(op->producer); - auto it = buf_map_.find(key); - ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << key->GetNameHint(); - const BufferInfo& bi = it->second; - ICHECK(!bi.released) << "Read a buffer that is already out of scope"; - - if (matrix_abc_.count(key->GetNameHint())) { - if (bi.shape.size() < 2) { - invalid_ = true; - return; - } - for (auto i = bi.shape.size() - 1; i + 2 >= bi.shape.size(); --i) { - const IntImmNode* shape = bi.shape[i].as(); - if (shape == nullptr || shape->value % 16 != 0) { - invalid_ = true; - return; - } - } - } - - Array strides; - if (bi.strides.size() > 0) { - strides = bi.strides; - } else { - for (size_t i = 1; i < bi.shape.size(); ++i) { - PrimExpr stride = IntImm(DataType::Int(32), 1); - for (size_t j = bi.shape.size() - 1; j >= i; --j) { - stride = Mul(stride, bi.shape[j]); - } - strides.push_back(stride); - } - strides.push_back(make_const(DataType::Int(32), 1)); - } - strides_.insert(std::make_pair(key->GetNameHint(), strides)); - - if (frag_reg_.count(bi.name)) { - PrimExpr dst = ProducerLoad(op->producer, op->indices); - frag_load_.insert(std::make_pair(op, dst)); - - auto rel_index = bi.RelIndex(op->indices); - if (op->indices.size() < 2) { - invalid_ = true; - return; - } - std::vector tile_size; - for (auto i = op->indices.size() - 1; i + 2 >= op->indices.size(); --i) { - index_visitor.scaling_factor_ = 16; - if (const IntImmNode* shape = bi.shape[i].as()) { - tile_size.push_back(shape->value); - index_visitor.scaling_factor_ = shape->value; - } else { - invalid_ = true; - return; - } - auto index = rel_index[i]; - auto simplified_index = analyzer_.Simplify(index); - index_visitor(simplified_index); - } - - std::string input_name = simplify_name(bi.name); - auto it = matrix_abc_.find(input_name); - auto it2 = matrix_major_.find(input_name); - bool ret = true; - if (it != matrix_abc_.end() && it2 != matrix_major_.end()) { - if (it->second == "matrix_a" && it2->second == "col_major") { - ret &= assign_or_check_(&thread_tile_.m, tile_size[0]); - ret &= assign_or_check_(&thread_tile_.k, tile_size[1]); - } - if (it->second == "matrix_a" && it2->second == "row_major") { - ret &= assign_or_check_(&thread_tile_.k, tile_size[0]); - ret &= assign_or_check_(&thread_tile_.m, tile_size[1]); - } - if (it->second == "matrix_b" && it2->second == "col_major") { - ret &= assign_or_check_(&thread_tile_.k, tile_size[0]); - ret &= assign_or_check_(&thread_tile_.n, tile_size[1]); - } - if (it->second == "matrix_b" && it2->second == "row_major") { - ret &= assign_or_check_(&thread_tile_.n, tile_size[0]); - ret &= assign_or_check_(&thread_tile_.k, tile_size[1]); - } - if (it->second == "accumulator") { - ret &= assign_or_check_(&thread_tile_.m, tile_size[0]); - ret &= assign_or_check_(&thread_tile_.n, tile_size[1]); - } - if (!ret) { - invalid_ = true; - return; - } - } - } - - const ProducerLoadNode* value = op->value.as(); - // TODO(tvm-team): string matching is dangerous, consider other means. - if (value != nullptr && frag_reg_.count(value->producer->GetNameHint())) { - PrimExpr dst = ProducerLoad(op->producer, op->indices); - frag_store_.insert(std::make_pair(op, dst)); - } - } - - void VisitExpr_(const ProducerLoadNode* op) final { - StmtExprVisitor::VisitExpr_(op); - - auto tensor = Downcast(op->producer); - auto it = buf_map_.find(tensor); - ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << tensor->GetNameHint(); - const BufferInfo& bi = it->second; - ICHECK(!bi.released) << "Read a buffer that is already out of scope"; - - if (matrix_abc_.count(tensor->op->name)) { - if (bi.shape.size() < 2) { - invalid_ = true; - return; - } - for (auto i = bi.shape.size() - 1; i + 2 >= bi.shape.size(); --i) { - const IntImmNode* shape = bi.shape[i].as(); - if (shape == nullptr || shape->value % 16 != 0) { - invalid_ = true; - return; - } - } - } - - Array strides; - if (bi.strides.size() > 0) { - strides = bi.strides; - } else { - for (size_t i = 1; i < bi.shape.size(); ++i) { - PrimExpr stride = IntImm(DataType::Int(32), 1); - for (size_t j = bi.shape.size() - 1; j >= i; --j) { - stride = Mul(stride, bi.shape[j]); - } - strides.push_back(stride); - } - strides.push_back(make_const(DataType::Int(32), 1)); - } - strides_.insert(std::make_pair(tensor->GetNameHint(), strides)); - - if (!frag_reg_.count(bi.name)) { - return; - } - - auto rel_index = bi.RelIndex(op->indices); - if (op->indices.size() < 2) { - invalid_ = true; - return; - } - for (auto i = op->indices.size() - 1; i + 2 >= op->indices.size(); --i) { - index_visitor.scaling_factor_ = 16; - if (const IntImmNode* shape = bi.shape[i].as()) { - index_visitor.scaling_factor_ = shape->value; - } - auto index = rel_index[i]; - auto simplified_index = analyzer_.Simplify(index); - index_visitor(simplified_index); - } - } - - void VisitStmt_(const ProducerRealizeNode* op) final { - auto key = Downcast(op->producer); - if (buf_map_.count(key)) { - ICHECK(buf_map_.at(key).external); - this->VisitStmt(op->body); - } else { - // create a buffer entry - BufferInfo bi; - - bi.bounds = op->bounds; - Array shape; - for (auto r : bi.bounds) { - shape.push_back(r->extent); - } - - Array strides; - if (dim_align_.count(key) != 0 && shape.size() != 0) { - std::vector rstrides; - const std::vector& avec = dim_align_[key]; - int first_dim = 0; - PrimExpr stride = make_const(shape[first_dim].dtype(), 1); - for (size_t i = shape.size(); i != 0; --i) { - size_t dim = i - 1; - if (dim < avec.size() && avec[dim].align_factor != 0) { - PrimExpr factor = make_const(stride.dtype(), avec[dim].align_factor); - PrimExpr offset = make_const(stride.dtype(), avec[dim].align_offset); - stride = stride + indexmod(factor + offset - indexmod(stride, factor), factor); - stride = analyzer_.Simplify(stride); - } - rstrides.push_back(stride); - stride = stride * shape[dim]; - } - strides = Array(rstrides.rbegin(), rstrides.rend()); - } - - bi.name = key->GetNameHint(); - bi.dtype = key->dtype; - bi.strides = strides; - bi.shape = shape; - - buf_map_[key] = bi; - this->VisitStmt(op->body); - buf_map_[key].released = true; - } - } - - // Derive warp tile from thread tile, - // and check whether it is qualified for TensorCore. - bool QualifiedForTensorCore() { - if (invalid_) { - return false; - } - auto itx = thread_extent_.find("threadIdx.x"); - if (itx == thread_extent_.end()) { - return false; - } - int warp_threads_x = itx->second; - warp_tile_.m = warp_threads_x * thread_tile_.m; - warp_threads_y_ = 32 / warp_threads_x; - auto ity = thread_extent_.find("threadIdx.y"); - if (ity == thread_extent_.end()) { - return false; - } - if (ity->second < warp_threads_y_ || ity->second % warp_threads_y_ != 0) { - return false; - } - warp_tile_.n = warp_threads_y_ * thread_tile_.n; - warp_tile_.k = thread_tile_.k; - return supported_warp_tile_(); - } - - friend class TensorCoreIRMutator; - - private: - struct DimAlignInfo { - int align_factor{0}; - int align_offset{0}; - }; - - struct BufferInfo { - std::string name; - DataType dtype; - Array strides; - Array shape; - Region bounds; - bool external{false}; - bool released{false}; - inline Array RelIndex(Array args) const { - if (bounds.size() != 0) { - Array index; - ICHECK_EQ(bounds.size(), args.size()); - for (size_t i = 0; i < bounds.size(); ++i) { - index.push_back(args[i] - bounds[i]->min); - } - return index; - } else { - return args; - } - } - }; - - bool assign_or_check_(int* dst, int src) { - if (*dst <= 0) { - *dst = src; - return true; - } - if (*dst == src) { - return true; - } - return false; - } - - bool supported_warp_tile_() { - if (warp_tile_.m == 16 && warp_tile_.n == 16 && warp_tile_.k == 16) { - return true; - } - if (warp_tile_.m == 8 && warp_tile_.n == 32 && warp_tile_.k == 16) { - return true; - } - if (warp_tile_.m == 32 && warp_tile_.n == 8 && warp_tile_.k == 16) { - return true; - } - if (warp_tile_.m == 8 && warp_tile_.n == 8 && warp_tile_.k == 32) { - return true; - } - if (warp_tile_.m == 8 && warp_tile_.n == 8 && warp_tile_.k == 128) { - return true; - } - - return false; - } - - std::unordered_map buf_map_; - std::unordered_map> dim_align_; - std::unordered_map storage_scope_; - std::unordered_map matrix_abc_; - std::unordered_map matrix_major_; - std::unordered_set frag_reg_; - std::unordered_map> strides_; - std::unordered_map frag_load_; - std::unordered_map frag_store_; - std::unordered_map thread_extent_; - IndexVisitor index_visitor; - Tile warp_tile_; - Tile thread_tile_; - arith::Analyzer analyzer_; - int warp_threads_y_{-1}; - bool invalid_{false}; -}; - -// ThreadIdxMutator does the thread index unification inside a warp -class ThreadIdxMutator : public StmtExprMutator { - public: - explicit ThreadIdxMutator(PrimExpr warp_y) : warp_y_(warp_y) {} - - PrimExpr VisitExpr_(const VarNode* op) final { - PrimExpr expr = StmtExprMutator::VisitExpr_(op); - op = expr.as(); - if (op != nullptr) { - if (op->name_hint == "threadIdx.x") { - PrimExpr zero = IntImm(DataType::Int(32), 0); - return zero; - } - if (op->name_hint == "threadIdx.y") { - PrimExpr div = Div(expr, warp_y_); - PrimExpr mul = Mul(div, warp_y_); - return mul; - } - } - return expr; - } - - private: - PrimExpr warp_y_; -}; - -// TensorCoreIRMutator mutates the AST for TensorCore CodeGen -// based on tensor core intrinsics -class TensorCoreIRMutator : public StmtExprMutator { - public: - explicit TensorCoreIRMutator(const ScheduleAnalyser& schedule_analyser, - const BufferAnalyser& buffer_analyser) - : matrix_abc_(schedule_analyser.matrix_abc_), - matrix_major_(schedule_analyser.matrix_major_), - mma_sync_(schedule_analyser.mma_sync_), - strides_(buffer_analyser.strides_), - frag_reg_(buffer_analyser.frag_reg_), - loop_scaling_(buffer_analyser.index_visitor.loop_scaling_), - frag_load_(buffer_analyser.frag_load_), - frag_store_(buffer_analyser.frag_store_), - warp_tile_(buffer_analyser.warp_tile_), - warp_threads_y_(buffer_analyser.warp_threads_y_) {} - - Stmt VisitStmt_(const ProducerRealizeNode* op) final { - auto key = Downcast(op->producer); - bounds_[key] = op->bounds; - Stmt stmt = StmtExprMutator::VisitStmt_(op); - op = stmt.as(); - if (op != nullptr) { - if (!frag_reg_.count(key->GetNameHint())) { - return stmt; - } - - auto new_extents = get_tile_size_(simplify_name(key->GetNameHint())); - - Region new_bounds; - for (size_t i = 0; i < op->bounds.size() - 2; ++i) { - new_bounds.push_back(op->bounds[i]); - } - ICHECK_GE(op->bounds.size(), 2) << "Less than 2 dimensions for matrix " << key->GetNameHint(); - new_bounds.push_back( - Range::FromMinExtent(op->bounds[op->bounds.size() - 2]->min, new_extents[0])); - new_bounds.push_back( - Range::FromMinExtent(op->bounds[op->bounds.size() - 1]->min, new_extents[1])); - - return ProducerRealize(op->producer, new_bounds, op->condition, op->body); - } - return stmt; - } - - Stmt VisitStmt_(const AttrStmtNode* op) final { - Stmt stmt = StmtExprMutator::VisitStmt_(op); - if (op->attr_key == tir::attr::realize_scope) { - auto node = op->node.as(); - if (node != nullptr) { - if (!frag_reg_.count(node->name)) { - return stmt; - } - - auto it = matrix_abc_.find(simplify_name(node->name)); - ICHECK(it != matrix_abc_.end()) << "Cannot find matrix info for " << node->name; - auto matrix_abc = tvm::tir::StringImm("wmma." + it->second); - Stmt body = this->VisitStmt(op->body); - return AttrStmt(op->node, op->attr_key, matrix_abc, body); - } - } - return stmt; - } - - Stmt VisitStmt_(const ProducerStoreNode* op) final { - Stmt stmt = StmtExprMutator::VisitStmt_(op); - auto it = mma_sync_.find(op); - if (it != mma_sync_.end()) { - const auto& operands = it->second; - PrimExpr a = operands[0]; - auto ca = a.as(); - PrimExpr b = operands[1]; - auto cb = b.as(); - PrimExpr c = operands[2]; - auto cc = c.as(); - - ObjectPtr buffer_node_a = make_object(); - ObjectPtr buffer_node_b = make_object(); - ObjectPtr buffer_node_c = make_object(); - - auto mma_sync_call = [&buffer_node_a, &buffer_node_b, &ca, &cb](const Buffer& buffer) { - Buffer buffer_a(buffer_node_a); - Buffer buffer_b(buffer_node_b); - if (ca->dtype == DataType::Int(1) && cb->dtype == DataType::Int(1)) { - return Evaluate( - Call(DataType::Handle(), builtin::tvm_bmma_sync(), - {buffer->data, buffer->elem_offset, buffer_a->data, buffer_a->elem_offset, - buffer_b->data, buffer_b->elem_offset, buffer->data, buffer->elem_offset})); - } else { - return Evaluate( - Call(DataType::Handle(), builtin::tvm_mma_sync(), - {buffer->data, buffer->elem_offset, buffer_a->data, buffer_a->elem_offset, - buffer_b->data, buffer_b->elem_offset, buffer->data, buffer->elem_offset})); - } - }; - - auto call_add_c = [this, &cc, &buffer_node_c, &mma_sync_call](const Buffer& buffer) { - return add_buffer_bind_scope_(cc, buffer_node_c, mma_sync_call); - }; - - auto call_add_b = [this, &cb, &buffer_node_b, &call_add_c](const Buffer& buffer) { - return add_buffer_bind_scope_(cb, buffer_node_b, call_add_c); - }; - - return add_buffer_bind_scope_(ca, buffer_node_a, call_add_b); - } - - auto it2 = frag_load_.find(op); - if (it2 != frag_load_.end()) { - PrimExpr dst = it2->second; - if (op->value.as() != nullptr || op->value.as() != nullptr) { - auto pload = dst.as(); - - auto fill_fragment_call = [this, &op](const Buffer& buffer) { - return Evaluate(Call(DataType::Handle(), builtin::tvm_fill_fragment(), - {buffer->data, warp_tile_.m, warp_tile_.n, warp_tile_.k, - buffer->elem_offset, op->value})); - }; - - ObjectPtr buffer_node = make_object(); - return add_buffer_bind_scope_(pload, buffer_node, fill_fragment_call); - } - - const ProducerLoadNode* value = op->value.as(); - ICHECK(value != nullptr) << "Can only load fragment from a buffer"; - - auto it = strides_.find(value->producer->GetNameHint()); - ICHECK(it != strides_.end()) << "Cannot find stride for " << value->producer->GetNameHint(); - auto strides = it->second; - ICHECK_GE(strides.size(), 2); - PrimExpr stride = strides[strides.size() - 2]; - - // thread index unification inside a warp - PrimExpr warp_y = IntImm(DataType::Int(32), warp_threads_y_); - ThreadIdxMutator thread_idx_mutator(warp_y); - PrimExpr mutated_value = thread_idx_mutator(op->value); - // TODO(tvm-team) The extern function name seems to be a hack. - PrimExpr src = Call(value->dtype, builtin::call_extern(), {StringImm("&"), mutated_value}); - - auto pload = dst.as(); - PrimExpr matrix_major; - auto iter2 = matrix_major_.find(simplify_name(pload->producer->GetNameHint())); - ICHECK(iter2 != matrix_major_.end()) - << "Can not determine matrix major for " << pload->producer->GetNameHint(); - if (iter2->second == "col_major") { - matrix_major = StringImm("col_major"); - } else if (iter2->second == "row_major") { - matrix_major = StringImm("row_major"); - } else { - LOG(FATAL) << "invalid matrix major for " << pload->producer->GetNameHint(); - } - - auto load_matrix_call = [this, &src, &stride, &matrix_major](const Buffer& buffer) { - return Evaluate(Call(DataType::Handle(), builtin::tvm_load_matrix_sync(), - {buffer->data, warp_tile_.m, warp_tile_.n, warp_tile_.k, - buffer->elem_offset, src, stride, matrix_major})); - }; - - ObjectPtr buffer_node = make_object(); - return add_buffer_bind_scope_(pload, buffer_node, load_matrix_call); - } - - auto it3 = frag_store_.find(op); - if (it3 != frag_store_.end()) { - auto it = strides_.find(op->producer->GetNameHint()); - ICHECK(it != strides_.end()) << "Cannot find stride for " << op->producer->GetNameHint(); - auto strides = it->second; - ICHECK_GE(strides.size(), 2); - PrimExpr stride = strides[strides.size() - 2]; - - PrimExpr dst = it3->second; - // thread index unification inside a warp - PrimExpr warp_y = IntImm(DataType::Int(32), warp_threads_y_); - ThreadIdxMutator thread_idx_mutator(warp_y); - dst = thread_idx_mutator(dst); - dst = Call(DataType::Handle(), builtin::call_extern(), {StringImm("&"), dst}); - - auto pload = op->value.as(); - - auto store_matrix_call = [this, &dst, &stride](const Buffer& buffer) { - return Evaluate(Call(DataType::Handle(), builtin::tvm_store_matrix_sync(), - {buffer->data, warp_tile_.m, warp_tile_.n, warp_tile_.k, - buffer->elem_offset, dst, stride, StringImm("col_major")})); - }; - - ObjectPtr buffer_node = make_object(); - return add_buffer_bind_scope_(pload, buffer_node, store_matrix_call); - } - - return stmt; - } - - Stmt VisitStmt_(const ForNode* op) final { - Stmt stmt = StmtExprMutator::VisitStmt_(op); - op = stmt.as(); - if (op != nullptr) { - auto it = loop_scaling_.find(op->loop_var.get()); - if (it != loop_scaling_.end()) { - int scale_factor = it->second; - int scaled_extent_value = 1; - if (const IntImmNode* ori_extent = op->extent.as()) { - int ori_extent_value = ori_extent->value; - scaled_extent_value = ori_extent_value / scale_factor; - } - PrimExpr scaled_extent = make_const(op->extent.dtype(), scaled_extent_value); - stmt = For(op->loop_var, op->min, scaled_extent, op->kind, op->body, op->thread_binding, - op->annotations); - } - } - return stmt; - } - - private: - Array get_tile_size_(const std::string& name) { - auto it = matrix_abc_.find(name); - auto it2 = matrix_major_.find(name); - ICHECK(it != matrix_abc_.end() && it2 != matrix_major_.end()) - << "Cannot find matrix info for " << name; - PrimExpr size0 = make_const(DataType::Int(32), 16); - PrimExpr size1 = make_const(DataType::Int(32), 16); - if (it->second == "matrix_a" && it2->second == "col_major") { - size0 = make_const(DataType::Int(32), warp_tile_.k); - size1 = make_const(DataType::Int(32), warp_tile_.m); - } - if (it->second == "matrix_a" && it2->second == "row_major") { - size0 = make_const(DataType::Int(32), warp_tile_.m); - size1 = make_const(DataType::Int(32), warp_tile_.k); - } - if (it->second == "matrix_b" && it2->second == "row_major") { - size0 = make_const(DataType::Int(32), warp_tile_.k); - size1 = make_const(DataType::Int(32), warp_tile_.n); - } - if (it->second == "matrix_b" && it2->second == "col_major") { - size0 = make_const(DataType::Int(32), warp_tile_.n); - size1 = make_const(DataType::Int(32), warp_tile_.k); - } - if (it->second == "matrix_c") { - size0 = make_const(DataType::Int(32), warp_tile_.n); - size1 = make_const(DataType::Int(32), warp_tile_.m); - } - Array tile_size = {size0, size1}; - return tile_size; - } - - Stmt add_buffer_bind_scope_(const ProducerLoadNode* pload, - const ObjectPtr& buffer_node, - const std::function& call_back) { - auto tensor = Downcast(pload->producer); - auto it = bounds_.find(tensor); - ICHECK(it != bounds_.end()); - Array min_bound; - for (auto i : it->second) { - min_bound.push_back(i->min); - } - - ICHECK_GE(it->second.size(), 2); - Array shape; - for (size_t i = 0; i < it->second.size() - 2; ++i) { - shape.push_back(it->second[i]->extent); - } - auto tile_size = get_tile_size_(simplify_name(tensor->op->name)); - shape.push_back(tile_size[0]); - shape.push_back(tile_size[1]); - - Array strides; - for (size_t i = 1; i < shape.size(); ++i) { - PrimExpr stride = IntImm(DataType::Int(32), 1); - for (size_t j = shape.size() - 1; j >= i; --j) { - stride = Mul(stride, shape[j]); - } - strides.push_back(stride); - } - strides.push_back(make_const(DataType::Int(32), 1)); - - PrimExpr elem_offset = IntImm(DataType::Int(32), 0); - ICHECK_EQ(pload->indices.size(), min_bound.size()); - for (size_t i = 0; i < min_bound.size(); i++) { - elem_offset = Add(elem_offset, Mul(strides[i], Sub(pload->indices[i], min_bound[i]))); - } - - auto it2 = matrix_abc_.find(simplify_name(tensor->op->name)); - ICHECK(it2 != matrix_abc_.end()) << "Cannot find matrix info for " << tensor->op->name; - buffer_node->data = Var(tensor->op->name, DataType::Handle()); - buffer_node->name = tensor->op->name; - buffer_node->scope = "wmma." + it2->second; - buffer_node->dtype = tensor->dtype; - buffer_node->strides = strides; - buffer_node->shape = shape; - buffer_node->data_alignment = 1; - buffer_node->elem_offset = analyzer_.Simplify(elem_offset); - buffer_node->offset_factor = 1; - Buffer buffer(buffer_node); - - Array args; - for (size_t i = 0; i < pload->indices.size(); ++i) { - args.push_back(pload->indices[i]); - args.push_back(shape[i]); - } - auto tuple = Call(DataType::Handle(), builtin::tvm_tuple(), args); - Array node = {buffer, tensor}; - return AttrStmt(node, "buffer_bind_scope", tuple, call_back(buffer)); - } - - std::unordered_map matrix_abc_; - std::unordered_map matrix_major_; - std::unordered_map> mma_sync_; - std::unordered_map> strides_; - std::unordered_set frag_reg_; - std::unordered_map loop_scaling_; - std::unordered_map frag_load_; - std::unordered_map frag_store_; - std::unordered_map bounds_; - arith::Analyzer analyzer_; - Tile warp_tile_; - int warp_threads_y_{-1}; -}; - -Stmt SchedulePostProcRewriteForTensorCore(Stmt stmt, Schedule schedule, - Map extern_buffer) { - // Check if current lower target is CUDA - auto target = tvm::Target::Current(true); - if (target.defined() && target->kind->name != "cuda") { - return stmt; - } - - // Check if current runtime support GPU CUDA - Device dev{kDLCUDA, 0}; - auto api = tvm::runtime::DeviceAPI::Get(dev, true); - if (api == nullptr) { - return stmt; - } - - MMAMatcher mma_matcher(extern_buffer); - mma_matcher(stmt); - if (!mma_matcher.Matched()) { - return stmt; - } - - ScheduleAnalyser schedule_analyser(mma_matcher); - if (!schedule_analyser.MatrixIdentify(schedule)) { - return stmt; - } - - BufferAnalyser buffer_analyser(extern_buffer, schedule_analyser, mma_matcher); - buffer_analyser(stmt); - if (!buffer_analyser.QualifiedForTensorCore()) { - return stmt; - } - - return TensorCoreIRMutator(schedule_analyser, buffer_analyser)(std::move(stmt)); -} - -TVM_REGISTER_GLOBAL("schedule.SchedulePostProcRewriteForTensorCore") - .set_body_typed([](Stmt stmt, Schedule schedule, Map extern_buffer) { - return SchedulePostProcRewriteForTensorCore(stmt, schedule, extern_buffer); - }); - -} // namespace te -} // namespace tvm From 09d78063d77c8548f808321060729bec5cffc17a Mon Sep 17 00:00:00 2001 From: electriclilies Date: Tue, 1 Jun 2021 14:23:33 -0700 Subject: [PATCH 36/52] Fix call fo lower_primfunc --- python/tvm/driver/build_module.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index 753936c8de32..30154bb8405d 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -125,7 +125,7 @@ def lower( if isinstance(inputs, IRModule): return ffi.lower_module(inputs, simple_mode) if isinstance(inputs, PrimFunc): - return ffi.lower_primfunc(inputs, simple_mode) + return ffi.lower_primfunc(inputs, name, simple_mode) if isinstance(inputs, schedule.Schedule): return ffi.lower_schedule(inputs, args, name, binds, simple_mode) raise ValueError( From 8bfc97e3f06575a0cb136c8fb17e8838e65676ce Mon Sep 17 00:00:00 2001 From: electriclilies Date: Tue, 1 Jun 2021 15:00:10 -0700 Subject: [PATCH 37/52] Clean up comments --- src/driver/driver_api.cc | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 0c75491f16a9..003067fc0aa3 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -93,7 +93,6 @@ tir::Buffer BufferWithOffsetAlignment(Array shape, DataType dtype, std offset_factor, buffer_type); } -// comment to try to remove this void GetBinds(const Array& args, bool compact, const std::unordered_map& binds, Map* out_binds, Array* out_arg_list) { @@ -149,7 +148,8 @@ TVM_REGISTER_GLOBAL("driver.get_binds") Array out_arg_list; GetBinds(args, compact, c_binds, &out_binds, &out_arg_list); - // TODO(electriclilies): is there a way to return a pair? + // TVM object system doesn't have a pair object, so we'll put both ret values in an array + // and return that. Array out_arr; out_arr.push_back(out_binds); out_arr.push_back(out_arg_list); @@ -217,9 +217,6 @@ Array CreatePassList(bool disable_loop_partition, bool for } // Construct the pass list, inserting the user provided passes at the end of the phase - // TODO(electriclilies): I'm not sure if they should go at the beginning or the end of the phase. - // The code is inconsistent with what passes are in which phase as well. For now I have coped the - // python behavior exactly. // PHASE 0 auto pass_list = user_lower_phase0; From 566de6831090800586509795751602542de7c964 Mon Sep 17 00:00:00 2001 From: electriclilies Date: Tue, 1 Jun 2021 15:15:52 -0700 Subject: [PATCH 38/52] change return of ffi get_binds --- src/driver/driver_api.cc | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 003067fc0aa3..a0cc3d233780 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -150,9 +150,7 @@ TVM_REGISTER_GLOBAL("driver.get_binds") // TVM object system doesn't have a pair object, so we'll put both ret values in an array // and return that. - Array out_arr; - out_arr.push_back(out_binds); - out_arr.push_back(out_arg_list); + Array out_arr = {out_binds, out_arg_list}; return out_arr; }); From cc9458eb80d0994b6ca53e98d8e039611593d969 Mon Sep 17 00:00:00 2001 From: electriclilies Date: Tue, 1 Jun 2021 17:43:40 -0700 Subject: [PATCH 39/52] Fix off by1 --- src/driver/driver_api.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index a0cc3d233780..eae00fc406fa 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -198,7 +198,7 @@ Array CreatePassList(bool disable_loop_partition, bool for << "Expected the first entry in the inner Array of tir.add_lower_pass to be an integer"; int phase_num_val = phase_num->value; - CHECK_GT(phase_num_val, 0); + CHECK_GE(phase_num_val, 0); auto pass_node = phase_pass[1].as(); auto pass = GetRef(pass_node); From ca4683d416c1866f3f139cfb05840cf9c6c2d6ce Mon Sep 17 00:00:00 2001 From: electriclilies Date: Mon, 7 Jun 2021 14:33:21 -0700 Subject: [PATCH 40/52] respond to tristan's comments --- include/tvm/driver/driver_api.h | 12 ++++++------ python/tvm/driver/build_module.py | 2 +- src/driver/driver_api.cc | 5 +++-- 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/include/tvm/driver/driver_api.h b/include/tvm/driver/driver_api.h index aaeafd257278..f6c05dc19edd 100644 --- a/include/tvm/driver/driver_api.h +++ b/include/tvm/driver/driver_api.h @@ -53,7 +53,7 @@ namespace tvm { TVM_DLL IRModule LowerModule(IRModule mod, bool simple_mode = false); /*! - * \brief Build an IRModule given a module, args and binds + * \brief Build an IRModule given a primfunc, args and binds * \param func The PrimFunc to lower * \param name The name of the lowered function. * \param simple_mode Disables the loop partition pass. Defaults to false. @@ -63,8 +63,8 @@ TVM_DLL IRModule LowerPrimFunc(tvm::tir::PrimFunc func, const std::string& name, bool simple_mode = false); /*! - * \brief Build an IRModule given a schedule, args and binds - * \param sch The schedule to lower. + * \brief Build an IRModule given a TE schedule, args and binds + * \param sch The TE schedule to lower. * \param args The arguments to the function. * \param name The name of the lowered function. * \param binds Buffer assignments. @@ -78,8 +78,8 @@ TVM_DLL IRModule LowerSchedule(te::Schedule sch, const Array& args, bool simple_mode = false); /*! - * \brief Build an IRModule given a schedule, args and binds - * \param sch The schedule to lower. + * \brief Build an IRModule given a TE schedule, args and binds + * \param sch The TE schedule to lower. * \param args The arguments to the function (Array of Union of Tensor, Buffer and Vars) * \param name The name of the lowered function. * \param binds Buffer assignments. @@ -92,7 +92,7 @@ TVM_DLL IRModule LowerSchedule(te::Schedule sch, const Array& args, bool simple_mode = false); /*! - * \brief Create an IRModule out of a Schedule + * \brief Create an IRModule out of a TE Schedule (without applying lowering passes) * \param sch The schedule * \param args The arguments to the function. * \param name The name of the lowered function. diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index 30154bb8405d..2c00aca3bea1 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -77,7 +77,7 @@ def schedule_to_module( args : list of Buffer or Tensor or Var The argument lists to the function. name : str - The name of result function. + The name of result function, default name is "main" binds : dict of :any:`Tensor` to :any:`Buffer`, optional The binds information Returns diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index eae00fc406fa..9f1d2c646f98 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -128,8 +128,9 @@ void GetBinds(const Array& args, bool compact, } else if (x.as() || x.as()) { out_arg_list->push_back(x); } else { - ICHECK(false) - << "Expected type of the elements of args to be te::Tensor, te::Buffer or tir::Var"; + LOG(FATAL) + << "Expected type of the elements of args to be te::Tensor, te::Buffer or tir::Var, " + << "but got a " << typeid(x).name(); } } } From 00fcd7a8d5803dcf2d956c11980984edaef8772a Mon Sep 17 00:00:00 2001 From: electriclilies Date: Mon, 7 Jun 2021 15:03:57 -0700 Subject: [PATCH 41/52] update driver_api.h docs --- include/tvm/driver/driver_api.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/tvm/driver/driver_api.h b/include/tvm/driver/driver_api.h index f6c05dc19edd..30ff4f8ef78c 100644 --- a/include/tvm/driver/driver_api.h +++ b/include/tvm/driver/driver_api.h @@ -45,7 +45,7 @@ namespace tvm { /*! - * \brief Build an IRModule given a module, args and binds + * \brief Build an IRModule given an input IRModule * \param mod The IRmodule to lower * \param simple_mode Disables the loop partition pass. Defaults to false. * \return The result module. @@ -53,7 +53,7 @@ namespace tvm { TVM_DLL IRModule LowerModule(IRModule mod, bool simple_mode = false); /*! - * \brief Build an IRModule given a primfunc, args and binds + * \brief Build an IRModule given a primfunc and name * \param func The PrimFunc to lower * \param name The name of the lowered function. * \param simple_mode Disables the loop partition pass. Defaults to false. From 7a2b404276607e6da7a8a741e34f154763ea5a54 Mon Sep 17 00:00:00 2001 From: Lily Orth-Smith Date: Mon, 7 Jun 2021 15:28:55 -0700 Subject: [PATCH 42/52] Apply suggestions from code review Co-authored-by: Chris Sullivan --- include/tvm/driver/driver_api.h | 2 +- python/tvm/autotvm/feature.py | 2 +- src/driver/driver_api.cc | 13 ++++++------- 3 files changed, 8 insertions(+), 9 deletions(-) diff --git a/include/tvm/driver/driver_api.h b/include/tvm/driver/driver_api.h index 30ff4f8ef78c..576286a57200 100644 --- a/include/tvm/driver/driver_api.h +++ b/include/tvm/driver/driver_api.h @@ -80,7 +80,7 @@ TVM_DLL IRModule LowerSchedule(te::Schedule sch, const Array& args, /*! * \brief Build an IRModule given a TE schedule, args and binds * \param sch The TE schedule to lower. - * \param args The arguments to the function (Array of Union of Tensor, Buffer and Vars) + * \param args The arguments to the function (Array of Tensor, Buffer and Vars) * \param name The name of the lowered function. * \param binds Buffer assignments. * \param simple_mode Disables the loop partition pass. Defaults to false. diff --git a/python/tvm/autotvm/feature.py b/python/tvm/autotvm/feature.py index 6c3d81f3c63d..8d2591dce50b 100644 --- a/python/tvm/autotvm/feature.py +++ b/python/tvm/autotvm/feature.py @@ -39,7 +39,7 @@ def ana_lower(sch, args, binds=None, simple_mode=True): """Do lower while keeping all axes in IR i.e. Do not eliminate loop with extent of 1, do not vectorize, unroll or inject virtual threads """ - binds, _ = build_module.get_binds(args, False, binds) + binds, _ = build_module.get_binds(args, compact=False, binds=binds) sch = sch.normalize() # Phase 0 bounds = schedule.InferBound(sch) diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 9f1d2c646f98..bc8d3abcb415 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -140,9 +140,9 @@ TVM_REGISTER_GLOBAL("driver.get_binds") const Map& binds) { std::unordered_map c_binds; // Check to make sure binds is not null before doing the conversion; - if (binds.get() != NULL) { + if (binds.get() != nullptr) { for (auto kv : binds) { - c_binds.insert(std::pair(kv.first, kv.second)); + c_binds.insert({kv.first, kv.second}); } } Map out_binds; @@ -256,7 +256,6 @@ Array CreatePassList(bool disable_loop_partition, bool for pass_list.push_back(tir::transform::Simplify()); pass_list.push_back(tir::transform::RemoveNoOp()); pass_list.push_back(tir::transform::RewriteUnsafeSelect()); - // HoistIfThenElse pass_list.push_back(tir::transform::HoistIfThenElse()); // Add user-defined phase-3 passes @@ -308,9 +307,9 @@ TVM_REGISTER_GLOBAL("driver.schedule_to_module") const Map& binds) { std::unordered_map c_binds; // Check to make sure binds is not null before doing the conversion; - if (binds.get() != NULL) { + if (binds.get() != nullptr) { for (auto kv : binds) { - c_binds.insert(std::pair(kv.first, kv.second)); + c_binds.insert({kv.first, kv.second}); } } IRModule mod = ScheduleToModule(sch, args, name, c_binds); @@ -369,9 +368,9 @@ TVM_REGISTER_GLOBAL("driver.lower_schedule") const Map& binds, bool simple_mode) { std::unordered_map c_binds; // Check to make sure binds is not null before doing the conversion; - if (binds.get() != NULL) { + if (binds.get() != nullptr) { for (auto kv : binds) { - c_binds.insert(std::pair(kv.first, kv.second)); + c_binds.insert({kv.first, kv.second}); } } return LowerSchedule(sch, args, name, c_binds, simple_mode); From 8988368a5849e6635d5a3fbd529b3a9cc6956703 Mon Sep 17 00:00:00 2001 From: electriclilies Date: Mon, 7 Jun 2021 15:32:30 -0700 Subject: [PATCH 43/52] remove relay.backend.lower --- python/tvm/driver/build_module.py | 4 +-- python/tvm/relay/backend/_backend.py | 39 ---------------------------- src/relay/backend/compile_engine.cc | 15 +++++------ 3 files changed, 8 insertions(+), 50 deletions(-) diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index 2c00aca3bea1..15196a106f51 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -59,8 +59,8 @@ def get_binds(args, compact=False, binds=None): arg_list: list The list of symbolic buffers of arguments. """ - out_arr = ffi.get_binds(args, compact, binds) - return out_arr[0], out_arr[1] + binds, arg_list = ffi.get_binds(args, compact, binds) + return binds, arg_list def schedule_to_module( diff --git a/python/tvm/relay/backend/_backend.py b/python/tvm/relay/backend/_backend.py index 9460e23a5357..7378ed6beb8a 100644 --- a/python/tvm/relay/backend/_backend.py +++ b/python/tvm/relay/backend/_backend.py @@ -20,45 +20,6 @@ from tvm.target import Target -@tvm._ffi.register_func("relay.backend.lower") -def lower(sch, inputs, func_name, source_func): - """Backend function for lowering. - - Parameters - ---------- - sch : tvm.te.Schedule - The schedule. - - inputs : List[tvm.te.Tensor] - The inputs to the function. - - func_name : str - The name of the function. - - source-func : tvm.relay.Function - The source function to be lowered. - - Returns - ------- - mod : tvm.IRModule - The result of lowering. - """ - # pylint: disable=broad-except, import-outside-toplevel - import traceback - - try: - f = tvm.driver.lower(sch, inputs, name=func_name) - # logging.debug("lower function %s", func_name) - # logging.debug("%s", _build.lower(sch, inputs, simple_mode=True)) - except Exception: - msg = traceback.format_exc() - msg += "Error during compile function\n" - msg += "-----------------------------\n" - msg += source_func.astext() - raise RuntimeError(msg) - return f - - @tvm._ffi.register_func("relay.backend.build") def build(mod, target, target_host=None): """Backend build function. diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index 472971c0d521..7043d82e5d02 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -763,16 +763,13 @@ class CompileEngineImpl : public CompileEngineNode { all_args.push_back(arg); } // lower the function - if (const auto* f = runtime::Registry::Get("relay.backend.lower")) { - cache_node->funcs = (*f)(cfunc->schedule, all_args, cache_node->func_name, key->source_func); - } else { - using tvm::transform::PassContext; - With fresh_pass_ctx_scope(PassContext::Create()); + using tvm::transform::PassContext; + With fresh_pass_ctx_scope(PassContext::Create()); - std::unordered_map binds; - cache_node->funcs = - tvm::LowerSchedule(cfunc->schedule, all_args, cache_node->func_name, binds); - } + std::unordered_map binds; + cache_node->funcs = + tvm::LowerSchedule(cfunc->schedule, all_args, cache_node->func_name, binds); + value->cached_func = CachedFunc(cache_node); return value; } From 38a121baa679450d220f7959ca05fdae3d655d8b Mon Sep 17 00:00:00 2001 From: electriclilies Date: Tue, 8 Jun 2021 11:32:25 -0700 Subject: [PATCH 44/52] Respond to feedback --- include/tvm/driver/driver_api.h | 9 +++-- python/tvm/driver/build_module.py | 14 +++---- src/driver/driver_api.cc | 66 +++++++++++++++---------------- 3 files changed, 46 insertions(+), 43 deletions(-) diff --git a/include/tvm/driver/driver_api.h b/include/tvm/driver/driver_api.h index 576286a57200..e8004bbb6b0f 100644 --- a/include/tvm/driver/driver_api.h +++ b/include/tvm/driver/driver_api.h @@ -63,7 +63,8 @@ TVM_DLL IRModule LowerPrimFunc(tvm::tir::PrimFunc func, const std::string& name, bool simple_mode = false); /*! - * \brief Build an IRModule given a TE schedule, args and binds + * \brief Build an IRModule given a TE schedule, args and binds. This function also applies + * the lowering passes defined in CreatePassList. * \param sch The TE schedule to lower. * \param args The arguments to the function. * \param name The name of the lowered function. @@ -78,7 +79,8 @@ TVM_DLL IRModule LowerSchedule(te::Schedule sch, const Array& args, bool simple_mode = false); /*! - * \brief Build an IRModule given a TE schedule, args and binds + * \brief Build an IRModule given a TE schedule, args and binds. This function also applies + * the lowering passes defined in CreatePassList. * \param sch The TE schedule to lower. * \param args The arguments to the function (Array of Tensor, Buffer and Vars) * \param name The name of the lowered function. @@ -92,7 +94,8 @@ TVM_DLL IRModule LowerSchedule(te::Schedule sch, const Array& args, bool simple_mode = false); /*! - * \brief Create an IRModule out of a TE Schedule (without applying lowering passes) + * \brief Create an IRModule out of a TE Schedule. It does not apply lowering passes. If you want + * to apply lowering passes as well, use LowerSchedule. * \param sch The schedule * \param args The arguments to the function. * \param name The name of the lowered function. diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index 15196a106f51..d2c54d4804da 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -88,7 +88,7 @@ def schedule_to_module( def lower( - inputs: Union[schedule.Schedule, PrimFunc, IRModule], + inp: Union[schedule.Schedule, PrimFunc, IRModule], args: Optional[List[Union[Buffer, tensor.Tensor, Var]]] = None, name: str = "main", binds: Optional[Mapping[tensor.Tensor, Buffer]] = None, @@ -122,12 +122,12 @@ def lower( m : IRModule The result IRModule """ - if isinstance(inputs, IRModule): - return ffi.lower_module(inputs, simple_mode) - if isinstance(inputs, PrimFunc): - return ffi.lower_primfunc(inputs, name, simple_mode) - if isinstance(inputs, schedule.Schedule): - return ffi.lower_schedule(inputs, args, name, binds, simple_mode) + if isinstance(inp, IRModule): + return ffi.lower_module(input, simple_mode) + if isinstance(inp, PrimFunc): + return ffi.lower_primfunc(input, name, simple_mode) + if isinstance(inp, schedule.Schedule): + return ffi.lower_schedule(input, args, name, binds, simple_mode) raise ValueError( "Expected input to be an IRModule, PrimFunc or Schedule, but got, ", type(inputs) ) diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index bc8d3abcb415..d5635724177c 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -115,10 +115,10 @@ void GetBinds(const Array& args, bool compact, *out_binds = binds; for (const ObjectRef& x : args) { - if (const auto* tensor_node = x.as()) { - auto x_ref = GetRef(tensor_node); + if (const te::TensorNode* tensor_node = x.as()) { + te::Tensor x_ref = GetRef(tensor_node); if (out_binds->find(x_ref) == out_binds->end()) { - auto buf = + tir::Buffer buf = BufferWithOffsetAlignment(x_ref->shape, x_ref->dtype, x_ref->op->name, -1, 0, compact); out_binds->Set(x_ref, buf); out_arg_list->push_back(buf); @@ -175,34 +175,34 @@ transform::Pass Filter(FCond fcond) { } Array CreatePassList(bool disable_loop_partition, bool for_te_schedule) { - auto pass_ctx = transform::PassContext::Current(); + transform::PassContext pass_ctx = transform::PassContext::Current(); bool disable_vectorize = pass_ctx->GetConfig("tir.disable_vectorize", Bool(false)).value(); bool instrument_bound_checkers = pass_ctx->GetConfig("tir.instrument_bound_checkers", Bool(false)).value(); // Get any user-added passes - auto add_lower_pass = + Array> add_lower_pass = pass_ctx->GetConfig>>("tir.add_lower_pass", Array>()) .value(); - auto user_lower_phase0 = Array(); - auto user_lower_phase1 = Array(); - auto user_lower_phase2 = Array(); - auto user_lower_phase3 = Array(); + Array user_lower_phase0 = Array(); + Array user_lower_phase1 = Array(); + Array user_lower_phase2 = Array(); + Array user_lower_phase3 = Array(); // phase pasees is of the form // [[phase_number, pass], [phase_number, pass]... ] - for (auto phase_pass : add_lower_pass) { - auto phase_num = phase_pass[0].as(); + for (Array phase_pass : add_lower_pass) { + const IntImmNode* phase_num = phase_pass[0].as(); ICHECK(phase_num) << "Expected the first entry in the inner Array of tir.add_lower_pass to be an integer"; int phase_num_val = phase_num->value; CHECK_GE(phase_num_val, 0); - auto pass_node = phase_pass[1].as(); - auto pass = GetRef(pass_node); + const tvm::transform::PassNode* pass_node = phase_pass[1].as(); + tvm::transform::Pass pass = GetRef(pass_node); // Copy the pass into the correct phase if (phase_num_val == 0) { user_lower_phase0.push_back(pass); @@ -218,7 +218,7 @@ Array CreatePassList(bool disable_loop_partition, bool for // Construct the pass list, inserting the user provided passes at the end of the phase // PHASE 0 - auto pass_list = user_lower_phase0; + Array pass_list = user_lower_phase0; // PHASE 1 if (for_te_schedule) { @@ -268,7 +268,7 @@ Array CreatePassList(bool disable_loop_partition, bool for } IRModule LowerWithPassList(IRModule mod, Array pass_list) { - auto optimize = transform::Sequential(pass_list); + auto optimize = tvm::transform::Sequential(pass_list); mod = optimize(std::move(mod)); return mod; } @@ -277,13 +277,13 @@ IRModule ScheduleToModule(te::Schedule sch, const Array& args, const const std::unordered_map& binds) { // Convert te schedule to IRModule Array out_arg_list; - auto pass_ctx = transform::PassContext::Current(); + transform::PassContext pass_ctx = transform::PassContext::Current(); sch = sch.normalize(); // Before TIR transformation. - auto bounds = te::InferBound(sch); - auto stmt = te::ScheduleOps(sch, bounds, false); + Map bounds = te::InferBound(sch); + tir::Stmt stmt = te::ScheduleOps(sch, std::move(bounds), false); bool compact = te::VerifyCompactBuffer(stmt); Map out_binds; @@ -312,22 +312,22 @@ TVM_REGISTER_GLOBAL("driver.schedule_to_module") c_binds.insert({kv.first, kv.second}); } } - IRModule mod = ScheduleToModule(sch, args, name, c_binds); + IRModule mod = ScheduleToModule(std::move(sch), args, name, c_binds); return mod; }); IRModule LowerModule(IRModule mod, bool simple_mode) { - auto pass_list = CreatePassList(simple_mode, false); - return LowerWithPassList(mod, pass_list); + Array pass_list = CreatePassList(simple_mode, false); + return LowerWithPassList(std::move(mod), pass_list); } TVM_REGISTER_GLOBAL("driver.lower_module").set_body_typed([](IRModule mod, bool simple_mode) { - return LowerModule(mod, simple_mode); + return LowerModule(std::move(mod), simple_mode); }); -IRModule LowerPrimFunc(tvm::tir::PrimFunc func, const std::string& name, bool simple_mode) { - auto pass_ctx = transform::PassContext::Current(); - auto f = WithAttr(std::move(func), "global_symbol", runtime::String(name)); +IRModule LowerPrimFunc(tir::PrimFunc func, const std::string& name, bool simple_mode) { + transform::PassContext pass_ctx = transform::PassContext::Current(); + tir::PrimFunc f = WithAttr(std::move(func), "global_symbol", runtime::String(name)); bool noalias = pass_ctx->GetConfig("tir.noalias", Bool(true)).value(); @@ -337,29 +337,29 @@ IRModule LowerPrimFunc(tvm::tir::PrimFunc func, const std::string& name, bool si IRModule mod = IRModule(Map({{GlobalVar(name), f}})); // Get the pass list - auto pass_list = CreatePassList(simple_mode, false); - return LowerWithPassList(mod, pass_list); + Array pass_list = CreatePassList(simple_mode, false); + return LowerWithPassList(std::move(mod), pass_list); } TVM_REGISTER_GLOBAL("driver.lower_primfunc") .set_body_typed([](te::PrimFunc func, const String& name, bool simple_mode) { - return LowerPrimFunc(func, name, simple_mode); + return LowerPrimFunc(std::move(func), name, simple_mode); }); IRModule LowerSchedule(te::Schedule sch, const Array& args, const std::string& name, const std::unordered_map& binds, bool simple_mode) { Array ref_args; - for (auto x : args) { + for (ObjectRef x : args) { ref_args.push_back(x); } - return LowerSchedule(sch, ref_args, name, binds); + return LowerSchedule(std::move(sch), ref_args, name, binds); } IRModule LowerSchedule(te::Schedule sch, const Array& args, const std::string& name, const std::unordered_map& binds, bool simple_mode) { - IRModule mod = ScheduleToModule(sch, args, name, binds); + IRModule mod = ScheduleToModule(std::move(sch), args, name, binds); // Get the legacy TE pass list - auto pass_list = CreatePassList(simple_mode, true); + Array pass_list = CreatePassList(simple_mode, true); return LowerWithPassList(mod, pass_list); } @@ -373,7 +373,7 @@ TVM_REGISTER_GLOBAL("driver.lower_schedule") c_binds.insert({kv.first, kv.second}); } } - return LowerSchedule(sch, args, name, c_binds, simple_mode); + return LowerSchedule(std::move(sch), args, name, c_binds, simple_mode); }); std::pair SplitDevHostFuncs(IRModule mod_mixed, const Target& target_arg, From 55e90cf276b0e76c3fd3b5ef4209fe40e3ac0eae Mon Sep 17 00:00:00 2001 From: electriclilies Date: Tue, 8 Jun 2021 11:54:37 -0700 Subject: [PATCH 45/52] remove 2nd get binds impl --- src/driver/driver_api.cc | 26 ++++++++++---------------- 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index d5635724177c..56e288df2ea3 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -93,22 +93,6 @@ tir::Buffer BufferWithOffsetAlignment(Array shape, DataType dtype, std offset_factor, buffer_type); } -void GetBinds(const Array& args, bool compact, - const std::unordered_map& binds, - Map* out_binds, Array* out_arg_list) { - *out_binds = binds; - - for (const auto& x : args) { - if (out_binds->find(x) == out_binds->end()) { - auto buf = BufferWithOffsetAlignment(x->shape, x->dtype, x->op->name, -1, 0, compact); - out_binds->Set(x, buf); - out_arg_list->push_back(buf); - } else { - out_arg_list->push_back((*out_binds)[x]); - } - } -} - void GetBinds(const Array& args, bool compact, const std::unordered_map& binds, Map* out_binds, Array* out_arg_list) { @@ -135,6 +119,16 @@ void GetBinds(const Array& args, bool compact, } } +void GetBinds(const Array& args, bool compact, + const std::unordered_map& binds, + Map* out_binds, Array* out_arg_list) { + Array ref_args; + for (ObjectRef x : args) { + ref_args.push_back(x); + } + GetBinds(ref_args, compact, binds, out_binds, out_arg_list); +} + TVM_REGISTER_GLOBAL("driver.get_binds") .set_body_typed([](const Array& args, bool compact, const Map& binds) { From d3e18c2b7a3f935f2c7e1bba793c43dcca651a91 Mon Sep 17 00:00:00 2001 From: electriclilies Date: Tue, 8 Jun 2021 13:16:06 -0700 Subject: [PATCH 46/52] fix lint --- src/relay/backend/compile_engine.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index 7043d82e5d02..3d5f0f4d4f3c 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -769,7 +769,7 @@ class CompileEngineImpl : public CompileEngineNode { std::unordered_map binds; cache_node->funcs = tvm::LowerSchedule(cfunc->schedule, all_args, cache_node->func_name, binds); - + value->cached_func = CachedFunc(cache_node); return value; } From 6e5a2ff7fa03f81bc93813911ce2a9836c1e7018 Mon Sep 17 00:00:00 2001 From: Lily Orth-Smith Date: Tue, 8 Jun 2021 13:21:29 -0700 Subject: [PATCH 47/52] Update src/driver/driver_api.cc Co-authored-by: Tristan Konolige --- src/driver/driver_api.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 56e288df2ea3..22e4bfc52796 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -114,7 +114,7 @@ void GetBinds(const Array& args, bool compact, } else { LOG(FATAL) << "Expected type of the elements of args to be te::Tensor, te::Buffer or tir::Var, " - << "but got a " << typeid(x).name(); + << "but got a " << x->GetTypeKey(); } } } From 4933a5e5d790e3c66d8ae167e282c131ff4652dd Mon Sep 17 00:00:00 2001 From: electriclilies Date: Tue, 8 Jun 2021 17:16:54 -0700 Subject: [PATCH 48/52] clang format --- src/relay/backend/compile_engine.cc | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index 3d5f0f4d4f3c..9ecf1f886019 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -767,8 +767,7 @@ class CompileEngineImpl : public CompileEngineNode { With fresh_pass_ctx_scope(PassContext::Create()); std::unordered_map binds; - cache_node->funcs = - tvm::LowerSchedule(cfunc->schedule, all_args, cache_node->func_name, binds); + cache_node->funcs = tvm::LowerSchedule(cfunc->schedule, all_args, cache_node->func_name, binds); value->cached_func = CachedFunc(cache_node); return value; From b42fb19ca5349a9fb0b9470e46feedece2af2a4b Mon Sep 17 00:00:00 2001 From: electriclilies Date: Tue, 8 Jun 2021 17:53:47 -0700 Subject: [PATCH 49/52] update doc --- include/tvm/driver/driver_api.h | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/include/tvm/driver/driver_api.h b/include/tvm/driver/driver_api.h index e8004bbb6b0f..418d532fdd5f 100644 --- a/include/tvm/driver/driver_api.h +++ b/include/tvm/driver/driver_api.h @@ -45,7 +45,7 @@ namespace tvm { /*! - * \brief Build an IRModule given an input IRModule + * \brief Lower an IRModule (optimize with it with the pass list defined in CreatePassList) * \param mod The IRmodule to lower * \param simple_mode Disables the loop partition pass. Defaults to false. * \return The result module. @@ -53,7 +53,8 @@ namespace tvm { TVM_DLL IRModule LowerModule(IRModule mod, bool simple_mode = false); /*! - * \brief Build an IRModule given a primfunc and name + * \brief Lower a primfunc and name (convert to IRModule, and optimize it with the pass list + * defined in CreatePassList) * \param func The PrimFunc to lower * \param name The name of the lowered function. * \param simple_mode Disables the loop partition pass. Defaults to false. From e462dab970ce5f499be3760e0ffbd873b9a094e0 Mon Sep 17 00:00:00 2001 From: electriclilies Date: Tue, 8 Jun 2021 18:56:32 -0700 Subject: [PATCH 50/52] fix typo --- python/tvm/driver/build_module.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index d2c54d4804da..ef6fdbfc1733 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -123,13 +123,13 @@ def lower( The result IRModule """ if isinstance(inp, IRModule): - return ffi.lower_module(input, simple_mode) + return ffi.lower_module(inp, simple_mode) if isinstance(inp, PrimFunc): - return ffi.lower_primfunc(input, name, simple_mode) + return ffi.lower_primfunc(inp, name, simple_mode) if isinstance(inp, schedule.Schedule): - return ffi.lower_schedule(input, args, name, binds, simple_mode) + return ffi.lower_schedule(inp, args, name, binds, simple_mode) raise ValueError( - "Expected input to be an IRModule, PrimFunc or Schedule, but got, ", type(inputs) + "Expected input to be an IRModule, PrimFunc or Schedule, but got, ", type(inp) ) From d067a67052b6350ac342b834bc440a8809da02d4 Mon Sep 17 00:00:00 2001 From: electriclilies Date: Wed, 9 Jun 2021 11:11:11 -0700 Subject: [PATCH 51/52] black --- python/tvm/driver/build_module.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index ef6fdbfc1733..a4df63f225b2 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -128,9 +128,7 @@ def lower( return ffi.lower_primfunc(inp, name, simple_mode) if isinstance(inp, schedule.Schedule): return ffi.lower_schedule(inp, args, name, binds, simple_mode) - raise ValueError( - "Expected input to be an IRModule, PrimFunc or Schedule, but got, ", type(inp) - ) + raise ValueError("Expected input to be an IRModule, PrimFunc or Schedule, but got, ", type(inp)) def _build_for_device(input_mod, target, target_host): From 794c606f6e7345167e7f06d1c04186fc00c9b638 Mon Sep 17 00:00:00 2001 From: electriclilies Date: Thu, 10 Jun 2021 11:55:46 -0700 Subject: [PATCH 52/52] fix pass ctx --- src/relay/backend/compile_engine.cc | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index 9ecf1f886019..5146c90f3bac 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -763,9 +763,6 @@ class CompileEngineImpl : public CompileEngineNode { all_args.push_back(arg); } // lower the function - using tvm::transform::PassContext; - With fresh_pass_ctx_scope(PassContext::Create()); - std::unordered_map binds; cache_node->funcs = tvm::LowerSchedule(cfunc->schedule, all_args, cache_node->func_name, binds);