From 351dfa48f0529c1eb179bfce4b694691e9cd4af7 Mon Sep 17 00:00:00 2001
From: Eric Lunderberg
Date: Fri, 8 Oct 2021 08:42:06 -0500
Subject: [PATCH 1/3] [TE] Light refactoring of TE -> TIR paths.

- Added ScheduleToPrimFunc, extracting out common behavior in
  ScheduleToModule and auto_scheduler's feature extraction.

- Added `tvm.driver.build_module.schedule_to_primfunc`, to avoid the
  4-line boilerplate otherwise needed to convert a TE schedule into a
  PrimFunc.  Also makes deviations from the usual path
  (e.g. `debug_keep_trivial_loop`) much more explicit.
---
 python/tvm/autotvm/feature.py                 |  8 +--
 python/tvm/driver/__init__.py                 |  2 +-
 python/tvm/driver/build_module.py             | 36 +++++++++++
 .../backend/contrib/ethosu/tir/compiler.py    | 16 ++---
 src/auto_scheduler/feature.cc                 | 21 ++----
 src/driver/driver_api.cc                      | 60 ++++++++++++-----
 src/driver/utils.h                            | 64 +++++++++++++++++++
 tests/python/integration/test_reduce.py       |  7 +-
 tests/python/unittest/test_te_schedule_ops.py | 20 ++----
 .../test_tir_transform_inject_copy_intrin.py  | 15 ++---
 .../test_tir_transform_make_packed_api.py     |  8 +--
 ...merge_dynamic_shared_memory_allocations.py | 12 ++--
 .../test_tir_transform_narrow_datatype.py     |  8 +--
 .../test_tir_transform_storage_flatten.py     | 11 +---
 .../test_tir_transform_storage_rewrite.py     | 40 +++---------
 15 files changed, 195 insertions(+), 133 deletions(-)
 create mode 100644 src/driver/utils.h

diff --git a/python/tvm/autotvm/feature.py b/python/tvm/autotvm/feature.py
index 8d2591dce50b..f8576c1a33a5 100644
--- a/python/tvm/autotvm/feature.py
+++ b/python/tvm/autotvm/feature.py
@@ -31,7 +31,6 @@
 import tvm._ffi

 from tvm.target import Target
-from tvm.te import schedule
 from tvm.driver import build_module


@@ -39,12 +38,11 @@ def ana_lower(sch, args, binds=None, simple_mode=True):
     """Do lower while keeping all axes in IR
     i.e. Do not eliminate loop with extent of 1, do not vectorize, unroll or inject virtual threads
     """
-    binds, _ = build_module.get_binds(args, compact=False, binds=binds)
     sch = sch.normalize()
     # Phase 0
-    bounds = schedule.InferBound(sch)
-    stmt = schedule.ScheduleOps(sch, bounds, True)
-    func = schedule.SchedulePostProcToPrimFunc(args, stmt, None)
+    context = tvm.transform.PassContext(config={"tir.debug_keep_trivial_loop": True})
+    with context:
+        func = build_module.schedule_to_primfunc(sch, args, binds=binds)
     mod = tvm.IRModule.from_expr(func._move())
     mod = tvm.tir.transform.StorageFlatten(64)(mod._move())
     mod = tvm.tir.transform.Simplify()(mod._move())
diff --git a/python/tvm/driver/__init__.py b/python/tvm/driver/__init__.py
index 75e94cc91c83..3ef297990cbe 100644
--- a/python/tvm/driver/__init__.py
+++ b/python/tvm/driver/__init__.py
@@ -15,4 +15,4 @@
 # specific language governing permissions and limitations
 # under the License.
 """Namespace for driver APIs"""
-from .build_module import lower, build
+from .build_module import lower, build, schedule_to_primfunc
diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py
index 429b3e1727cc..bb2b8f068796 100644
--- a/python/tvm/driver/build_module.py
+++ b/python/tvm/driver/build_module.py
@@ -67,6 +67,11 @@ def schedule_to_module(
     binds: Optional[Mapping[tensor.Tensor, Buffer]] = None,
 ) -> IRModule:
     """According to the given schedule, form a function.
+
+    This is a low-level function intended for testing purposes, and
+    does not apply any optimization passes. In general, `tvm.lower`
+    and `tvm.build` should be used instead.
+ Parameters ---------- sch : tvm.te.schedule.Schedule @@ -84,6 +89,37 @@ def schedule_to_module( return ffi.schedule_to_module(sch, args, name, binds) +def schedule_to_primfunc( + sch: schedule.Schedule, + args: Optional[List[Union[Buffer, tensor.Tensor, Var]]] = None, + name: str = "main", + binds: Optional[Mapping[tensor.Tensor, Buffer]] = None, +) -> IRModule: + """According to the given schedule, form a function. + + This is a low-level function intended for testing purposes, and + does not apply any optimization passes. In general, `tvm.lower` + and `tvm.build` should be used instead. + + Parameters + ---------- + sch : tvm.te.schedule.Schedule + The given scheduler to form the raw body + args : list of Buffer or Tensor or Var + The argument lists to the function. + name : str + The name of result function, default name is "main" + binds : dict of :any:`Tensor` to :any:`Buffer`, optional + The binds information + + Returns + ------- + The body formed according to the given schedule + + """ + return ffi.schedule_to_primfunc(sch, args, name, binds) + + def lower( inp: Union[schedule.Schedule, PrimFunc, IRModule], args: Optional[List[Union[Buffer, tensor.Tensor, Var]]] = None, diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py b/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py index 3283e0515c72..54dc3be95c44 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py @@ -19,7 +19,7 @@ import tvm from tvm import relay from tvm.relay.expr_functor import ExprMutator -from tvm.driver.build_module import get_binds +from tvm.driver.build_module import schedule_to_primfunc from .passes import ReplaceOperators, RemoveZeroStores, EncodeConstants from .scheduler import schedule @@ -64,22 +64,18 @@ def lower_ethosu(sch, args, const_dict, name="main"): "no_unroll_loop_with_extent_one": True, }, "tir.UnrollLoop": {"auto_max_depth": -1}, + "tir.debug_keep_trivial_loop": True, } # Merge two configs curr_cfg = {**curr_cfg, **tir_compiler_cfg} sch = sch.normalize() - bounds = tvm.te.schedule.InferBound(sch) - stmt = tvm.te.schedule.ScheduleOps(sch, bounds, True) - compact = tvm.te.schedule.VerifyCompactBuffer(stmt) - binds, arg_list = get_binds(args, compact, None) - func = tvm.te.schedule.SchedulePostProcToPrimFunc(arg_list, stmt, binds) - - func = func.with_attr("global_symbol", name) - func = func.with_attr("tir.noalias", True) - mod = tvm.IRModule({name: func}) with tvm.transform.PassContext(config=curr_cfg): + func = schedule_to_primfunc(sch, args, name) + func = func.with_attr("tir.noalias", True) + mod = tvm.IRModule({name: func}) + mod = tvm.tir.transform.Simplify()(mod) mod = tvm.tir.transform.StorageFlatten(64)(mod) mod = tvm.tir.transform.UnrollLoop()(mod) diff --git a/src/auto_scheduler/feature.cc b/src/auto_scheduler/feature.cc index be78bc4aa9f9..73854a33e5ca 100755 --- a/src/auto_scheduler/feature.cc +++ b/src/auto_scheduler/feature.cc @@ -26,6 +26,7 @@ #include #include #include +#include #include #include #include @@ -41,16 +42,10 @@ #include #include +#include "../driver/utils.h" #include "search_policy/utils.h" #include "utils.h" -namespace tvm { -// import the function from driver_api.cc -void GetBinds(const Array& args, bool compact, - const std::unordered_map& binds, - Map* out_binds, Array* out_arg_list); -} // namespace tvm - namespace tvm { namespace auto_scheduler { @@ -1269,22 +1264,14 @@ void GetPerStoreFeaturesWorkerFunc(const SearchTask& task, const State& state, i std::tie(sch, 
tensors) = task->compute_dag.ApplySteps(state->transform_steps); sch = sch.normalize_for_feature_extraction(); - auto bounds = te::InferBound(sch); try { - auto stmt = te::ScheduleOps(sch, bounds, false); - Map out_binds; - Array out_arg_list; - bool compact = te::VerifyCompactBuffer(stmt); const std::string& name = "main"; GlobalVar global_var(name); - // Copied from driver_api.cc::lower auto pass_ctx = tvm::transform::PassContext::Current(); - GetBinds(tensors, compact, std::unordered_map(), &out_binds, - &out_arg_list); - tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, std::move(stmt), out_binds); - f = WithAttr(std::move(f), "global_symbol", runtime::String(name)); + tir::PrimFunc f = ScheduleToPrimFunc(sch, Array{tensors.begin(), tensors.end()}, + name, std::unordered_map()); bool noalias = pass_ctx->GetConfig("tir.noalias", Bool(true)).value(); bool disable_vectorize = diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index e659421c23c4..0c0d4872c5f9 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -34,6 +34,8 @@ #include #include +#include "utils.h" + namespace tvm { // Register build pipeline related options @@ -44,6 +46,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_assert", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_vectorize", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.is_entry_func", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.add_lower_pass", Array>); +TVM_REGISTER_PASS_CONFIG_OPTION("tir.debug_keep_trivial_loop", Bool); using runtime::PackedFunc; using runtime::TVMArgs; @@ -287,31 +290,18 @@ IRModule ApplyPasses(IRModule mod, transform::Sequential seq) { return mod; } +// Convert te schedule to IRModule IRModule ScheduleToModule(te::Schedule sch, const Array& args, const std::string& name, const std::unordered_map& binds) { - // Convert te schedule to IRModule - Array out_arg_list; - transform::PassContext pass_ctx = transform::PassContext::Current(); - sch = sch.normalize(); - // Before TIR transformation. - Map bounds = te::InferBound(sch); - tir::Stmt stmt = te::ScheduleOps(sch, std::move(bounds), false); - bool compact = te::VerifyCompactBuffer(stmt); - - Map out_binds; - GetBinds(args, compact, binds, &out_binds, &out_arg_list); - - // Build the function - // At this point binds is only te::Tensors - tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, std::move(stmt), out_binds); - f = WithAttr(std::move(f), "global_symbol", runtime::String(name)); + tir::PrimFunc f = ScheduleToPrimFunc(sch, args, name, binds); // Mark this schedule as being converted from an TE schedule. Makes sure that // the correct TE passes are run. 
f = WithAttr(std::move(f), "from_legacy_te_schedule", Bool(true)); + transform::PassContext pass_ctx = transform::PassContext::Current(); bool noalias = pass_ctx->GetConfig("tir.noalias", Bool(true)).value(); if (noalias) { @@ -325,7 +315,7 @@ TVM_REGISTER_GLOBAL("driver.schedule_to_module") const Map& binds) { std::unordered_map c_binds; // Check to make sure binds is not null before doing the conversion; - if (binds.get() != nullptr) { + if (binds.defined()) { for (auto kv : binds) { c_binds.insert({kv.first, kv.second}); } @@ -334,6 +324,42 @@ TVM_REGISTER_GLOBAL("driver.schedule_to_module") return mod; }); +tir::PrimFunc ScheduleToPrimFunc(te::Schedule sch, const Array& args, + const std::string& name, + const std::unordered_map& binds) { + transform::PassContext pass_ctx = transform::PassContext::Current(); + bool debug_keep_trivial_loop = + pass_ctx->GetConfig("tir.debug_keep_trivial_loop", Bool(false)).value(); + + // Before TIR transformation. + tir::Stmt stmt = te::ScheduleOps(sch, te::InferBound(sch), debug_keep_trivial_loop); + bool compact = te::VerifyCompactBuffer(stmt); + + Map out_binds; + Array out_arg_list; + GetBinds(args, compact, binds, &out_binds, &out_arg_list); + + // Build the function, converting from te::Tensor to tir::Buffer + tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, std::move(stmt), out_binds); + f = WithAttr(std::move(f), "global_symbol", runtime::String(name)); + + return f; +} + +TVM_REGISTER_GLOBAL("driver.schedule_to_primfunc") + .set_body_typed([](te::Schedule sch, const Array& args, const String& name, + const Map& binds) { + std::unordered_map c_binds; + // Check to make sure binds is not null before doing the conversion; + if (binds.defined()) { + for (auto kv : binds) { + c_binds.insert({kv.first, kv.second}); + } + } + tir::PrimFunc func = ScheduleToPrimFunc(std::move(sch), args, name, c_binds); + return func; + }); + IRModule LowerModule(IRModule mod, bool simple_mode) { Array pass_list = CreatePassList(simple_mode); return LowerWithPassList(std::move(mod), pass_list); diff --git a/src/driver/utils.h b/src/driver/utils.h new file mode 100644 index 000000000000..0aa86802a092 --- /dev/null +++ b/src/driver/utils.h @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * + * \file utils.h + * \brief Internal utilities for manipulating TE schedules. + */ + +#ifndef TVM_DRIVER_UTILS_H_ +#define TVM_DRIVER_UTILS_H_ + +#include +#include +#include + +#include +#include + +namespace tvm { + +/*! + * \brief Create an PrimFunc out of a TE Schedule. + * + * Generated PrimFunc expresses reads/writes using + * BufferLoad/BufferStore, with all Tensors and + * ProducerLoad/ProducerStore having been replaced. 
+ * + * Assumes that the schedule has already been normalized, either with + * `te::Schedule::normalize` or + * `te::Schedule::normalize_for_feature_extraction`. + * + * Does not apply lowering passes. If you want + * to apply lowering passes as well, use LowerSchedule. + * + * \param sch The schedule + * \param args The arguments to the function. + * \param name The name of the lowered function. + * \param binds Buffer assignments. + * \return The result module. + */ +tir::PrimFunc ScheduleToPrimFunc(te::Schedule sch, const Array& args, + const std::string& name, + const std::unordered_map& binds); + +} // namespace tvm + +#endif // TVM_DRIVER_UTILS_H_ diff --git a/tests/python/integration/test_reduce.py b/tests/python/integration/test_reduce.py index ca097734a9eb..dc01ffd81bbf 100644 --- a/tests/python/integration/test_reduce.py +++ b/tests/python/integration/test_reduce.py @@ -15,10 +15,11 @@ # specific language governing permissions and limitations # under the License. import pytest +import numpy as np import tvm from tvm import te, topi -import numpy as np +from tvm.driver.build_module import schedule_to_primfunc import tvm.testing import tvm.topi.testing @@ -532,9 +533,7 @@ def test_reduce_storage_reuse(): target = tvm.target.Target("cuda") def run_passes(sch, args): - bounds = tvm.te.schedule.InferBound(sch) - stmt = tvm.te.schedule.ScheduleOps(sch, bounds) - func = tvm.te.schedule.SchedulePostProcToPrimFunc(args, stmt, None) + func = schedule_to_primfunc(sch, args) mod = tvm.IRModule.from_expr(func) mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))(mod) return tvm.transform.Sequential( diff --git a/tests/python/unittest/test_te_schedule_ops.py b/tests/python/unittest/test_te_schedule_ops.py index bc4bc4f56e19..31707a46dbde 100644 --- a/tests/python/unittest/test_te_schedule_ops.py +++ b/tests/python/unittest/test_te_schedule_ops.py @@ -14,9 +14,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+import numpy as np + import tvm from tvm import te -import numpy as np +from tvm.driver.build_module import schedule_to_primfunc def test_schedule0(): @@ -26,10 +28,7 @@ def test_schedule0(): A1 = te.compute((m, l), lambda i, j: A[i, j], name="A1") s = te.create_schedule(A1.op) - bounds = tvm.te.schedule.InferBound(s) - assert isinstance(bounds, tvm.container.Map) - stmt = tvm.te.schedule.ScheduleOps(s, bounds) - func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, A1], stmt, None) + func = schedule_to_primfunc(s, [A, A1]) assert isinstance(func, tvm.tir.PrimFunc) @@ -42,11 +41,8 @@ def test_schedule1(): s = te.create_schedule(A1.op) xo, xi = s[A1].split(A1.op.axis[0], 8) s[A1].pragma(xo, "auto_unroll_max_step", 10) - bounds = tvm.te.schedule.InferBound(s) - assert isinstance(bounds, tvm.container.Map) - stmt = tvm.te.schedule.ScheduleOps(s, bounds) - func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, A1], stmt, None) + func = schedule_to_primfunc(s, [A, A1]) assert isinstance(func, tvm.tir.PrimFunc) @@ -60,10 +56,8 @@ def test_schedule2(): s = te.create_schedule(A2.op) xo, xi = s[A2].split(A2.op.axis[0], 8) s[A1].compute_at(s[A2], xo) - bounds = tvm.te.schedule.InferBound(s) - assert isinstance(bounds, tvm.container.Map) - stmt = tvm.te.schedule.ScheduleOps(s, bounds) - func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, A2], stmt, None) + + func = schedule_to_primfunc(s, [A, A2]) assert isinstance(func, tvm.tir.PrimFunc) diff --git a/tests/python/unittest/test_tir_transform_inject_copy_intrin.py b/tests/python/unittest/test_tir_transform_inject_copy_intrin.py index 86bf87d5fa85..ea858e9c8258 100644 --- a/tests/python/unittest/test_tir_transform_inject_copy_intrin.py +++ b/tests/python/unittest/test_tir_transform_inject_copy_intrin.py @@ -17,6 +17,7 @@ import tvm import tvm.testing from tvm import te +from tvm.driver.build_module import schedule_to_primfunc def test_copy2d(): @@ -53,10 +54,7 @@ def test_copy_pad(): ) s = te.create_schedule(B.op) s[B].pragma(B.op.axis[0], "memcpy") - bounds = tvm.te.schedule.InferBound(s) - stmt = tvm.te.schedule.ScheduleOps(s, bounds) - - func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None) + func = schedule_to_primfunc(s, [A, B]) mod = tvm.IRModule.from_expr(func) mod = tvm.tir.transform.StorageFlatten(64)(mod) @@ -77,10 +75,7 @@ def test_single_point_test(): B = te.compute((1,), lambda i: A[i], name="B") s = te.create_schedule(B.op) s[B].pragma(B.op.axis[0], "memcpy") - bounds = tvm.te.schedule.InferBound(s) - stmt = tvm.te.schedule.ScheduleOps(s, bounds) - - func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None) + func = schedule_to_primfunc(s, [A, B]) mod = tvm.IRModule.from_expr(func) mod = tvm.tir.transform.StorageFlatten(64)(mod) @@ -105,10 +100,8 @@ def test_copy_pad_split(): xo, xi = s[B].split(B.op.axis[0], factor=4) s[Apad].compute_at(s[B], xo) s[Apad].pragma(s[Apad].op.axis[0], "memcpy") - bounds = tvm.te.schedule.InferBound(s) - stmt = tvm.te.schedule.ScheduleOps(s, bounds) - func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None) + func = schedule_to_primfunc(s, [A, B]) mod = tvm.IRModule.from_expr(func) mod = tvm.tir.transform.StorageFlatten(64)(mod._move()) mod = tvm.tir.transform.Simplify()(mod._move()) diff --git a/tests/python/unittest/test_tir_transform_make_packed_api.py b/tests/python/unittest/test_tir_transform_make_packed_api.py index 15f994069abd..00ce78193dc5 100644 --- a/tests/python/unittest/test_tir_transform_make_packed_api.py +++ 
b/tests/python/unittest/test_tir_transform_make_packed_api.py @@ -14,9 +14,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import numpy + import tvm from tvm import te -import numpy +from tvm.driver.build_module import schedule_to_primfunc def test_makeapi(): @@ -27,9 +29,7 @@ def test_makeapi(): C = te.compute(A.shape, lambda *i: A(*i) + B(*i), name="C") s = te.create_schedule(C.op) - bounds = tvm.te.schedule.InferBound(s) - stmt = tvm.te.schedule.ScheduleOps(s, bounds) - func = tvm.te.schedule.SchedulePostProcToPrimFunc([n, A, B, C], stmt, None) + func = schedule_to_primfunc(s, [n, A, B, C]) mod = tvm.IRModule.from_expr(func) mod = tvm.tir.transform.StorageFlatten(64)(mod) mod = tvm.tir.transform.Apply( diff --git a/tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py b/tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py index 9c511f1de6b9..f070edefe3bb 100644 --- a/tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py +++ b/tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py @@ -14,19 +14,17 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import tvm -from tvm import te import numpy as np + +import tvm import tvm.testing +from tvm import te +from tvm.driver.build_module import schedule_to_primfunc from tvm.topi.math import cast def run_passes(sch, args): - bounds = tvm.te.schedule.InferBound(sch) - assert isinstance(bounds, tvm.container.Map) - stmt = tvm.te.schedule.ScheduleOps(sch, bounds) - - func = tvm.te.schedule.SchedulePostProcToPrimFunc(args, stmt, None) + func = schedule_to_primfunc(sch, args) mod = tvm.IRModule.from_expr(func) return tvm.transform.Sequential( [ diff --git a/tests/python/unittest/test_tir_transform_narrow_datatype.py b/tests/python/unittest/test_tir_transform_narrow_datatype.py index cb8968cfc880..0dfa25611453 100644 --- a/tests/python/unittest/test_tir_transform_narrow_datatype.py +++ b/tests/python/unittest/test_tir_transform_narrow_datatype.py @@ -15,8 +15,8 @@ # specific language governing permissions and limitations # under the License. import tvm -from tvm import te -from tvm import relay +from tvm import te, relay +from tvm.driver.build_module import schedule_to_primfunc from tvm.tir import const @@ -39,10 +39,8 @@ def lower_sch(sch, args, target_bits): else: raise ValueError("args must be Tensor, Buffer or Var") sch = sch.normalize() - bounds = te.schedule.InferBound(sch) - stmt = te.schedule.ScheduleOps(sch, bounds) - func = tvm.te.schedule.SchedulePostProcToPrimFunc(args, stmt, None) + func = schedule_to_primfunc(sch, args) mod = tvm.IRModule.from_expr(func) mod = tvm.tir.transform.StorageFlatten(64)(mod) return tvm.tir.transform.NarrowDataType(target_bits)(mod)["main"].body diff --git a/tests/python/unittest/test_tir_transform_storage_flatten.py b/tests/python/unittest/test_tir_transform_storage_flatten.py index 37223493a8b5..c08a21fee6ed 100644 --- a/tests/python/unittest/test_tir_transform_storage_flatten.py +++ b/tests/python/unittest/test_tir_transform_storage_flatten.py @@ -16,6 +16,7 @@ # under the License. 
import tvm from tvm import te +from tvm.driver.build_module import schedule_to_primfunc from tvm.script import tir as T from tvm.relay import GlobalVar @@ -30,13 +31,10 @@ def test_flatten2(): s = te.create_schedule(A2.op) xo, xi = s[A2].split(A2.op.axis[0], 8) s[A1].compute_at(s[A2], xo) - bounds = tvm.te.schedule.InferBound(s) - assert isinstance(bounds, tvm.container.Map) - stmt = tvm.te.schedule.ScheduleOps(s, bounds) Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name="A") A2b = tvm.tir.decl_buffer(A2.shape, A2.dtype, name="A2") - func = tvm.te.schedule.SchedulePostProcToPrimFunc([Ab, A2b], stmt, {A: Ab, A2: A2b}) + func = schedule_to_primfunc(s, [Ab, A2b], binds={A: Ab, A2: A2b}) mod = tvm.IRModule.from_expr(func) mod = tvm.tir.transform.StorageFlatten(64)(mod) @@ -70,11 +68,8 @@ def test_flatten_storage_align(): s = te.create_schedule(A2.op) s[A1].storage_align(A1.op.axis[0], 2, 1) - bounds = tvm.te.schedule.InferBound(s) - assert isinstance(bounds, tvm.container.Map) - stmt = tvm.te.schedule.ScheduleOps(s, bounds) - func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, A2], stmt, None) + func = schedule_to_primfunc(s, [A, A2]) mod = tvm.IRModule.from_expr(func) mod = tvm.transform.Sequential( [tvm.tir.transform.StorageFlatten(64), tvm.tir.transform.Simplify()] diff --git a/tests/python/unittest/test_tir_transform_storage_rewrite.py b/tests/python/unittest/test_tir_transform_storage_rewrite.py index 9e738b136b17..22222405f009 100644 --- a/tests/python/unittest/test_tir_transform_storage_rewrite.py +++ b/tests/python/unittest/test_tir_transform_storage_rewrite.py @@ -16,6 +16,7 @@ # under the License. import tvm from tvm import te +from tvm.driver.build_module import schedule_to_primfunc def test_storage_share(): @@ -28,11 +29,7 @@ def test_storage_share(): B = te.compute((m, l), lambda i, j: B[i, j] + (t + 1), name="A%d" % t) s = te.create_schedule(B.op) - bounds = tvm.te.schedule.InferBound(s) - assert isinstance(bounds, tvm.container.Map) - stmt = tvm.te.schedule.ScheduleOps(s, bounds) - - func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None) + func = schedule_to_primfunc(s, [A, B]) mod = tvm.IRModule.from_expr(func) mod = tvm.tir.transform.StorageFlatten(64)(mod) @@ -169,11 +166,7 @@ def test_inplace_rule(): AA = te.compute((m,), lambda i: A0[i] + A1[i] + A1[0], name="AA") B = te.compute((m,), lambda i: AA[i] + 1, name="B") s = te.create_schedule(B.op) - bounds = tvm.te.schedule.InferBound(s) - assert isinstance(bounds, tvm.container.Map) - stmt = tvm.te.schedule.ScheduleOps(s, bounds) - - func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None) + func = schedule_to_primfunc(s, [A, B]) mod = tvm.IRModule.from_expr(func) mod = tvm.tir.transform.StorageFlatten(64)(mod) @@ -206,10 +199,8 @@ def test_storage_combine(): s = te.create_schedule(B.op) for S in stages[:-1]: s[S].set_scope("global:tag") - bounds = tvm.te.schedule.InferBound(s) - assert isinstance(bounds, tvm.container.Map) - stmt = tvm.te.schedule.ScheduleOps(s, bounds) - func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None) + + func = schedule_to_primfunc(s, [A, B]) mod = tvm.IRModule.from_expr(func) mod = tvm.tir.transform.StorageFlatten(64)(mod) @@ -238,9 +229,7 @@ def test_storage_combine_with_vectorization(): BB = s.cache_read(B, "global:tag", readers=[C]) CC = s.cache_write(C, "global:tag") s[CC].vectorize(s[CC].op.axis[0]) - bounds = tvm.te.schedule.InferBound(s) - stmt = tvm.te.schedule.ScheduleOps(s, bounds) - func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B, C], 
stmt, None) + func = schedule_to_primfunc(s, [A, B, C]) mod = tvm.IRModule.from_expr(func) mod = tvm.tir.transform.StorageFlatten(64)(mod) mod = tvm.tir.transform.VectorizeLoop()(mod) @@ -285,10 +274,7 @@ def test_storage_share_gpu(): s[A[2 * t + 1]].compute_at(s[A[2 * t + 2]], tx) s[A[2 * t + 1]].set_scope("shared") - bounds = tvm.te.schedule.InferBound(s) - assert isinstance(bounds, tvm.container.Map) - stmt = tvm.te.schedule.ScheduleOps(s, bounds) - func = tvm.te.schedule.SchedulePostProcToPrimFunc([A[0], A[-1]], stmt, None) + func = schedule_to_primfunc(s, [A[0], A[-1]]) mod = tvm.IRModule.from_expr(func) mod = tvm.tir.transform.StorageFlatten(64)(mod) mod = tvm.tir.transform.Simplify()(mod) @@ -418,11 +404,7 @@ def test_inplace_rule2(scope_tb="local_TB2", max_bits=1024 * 1024 * 1024): A0L = s.cache_read(A0, scope_tb, [A2]) A1L = s.cache_read(A1, scope_tb, [A2]) A2L = s.cache_read(A2, scope_tb, [B]) - bounds = tvm.te.schedule.InferBound(s) - assert isinstance(bounds, tvm.container.Map) - stmt = tvm.te.schedule.ScheduleOps(s, bounds) - - func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B, C, D], stmt, None) + func = schedule_to_primfunc(s, [A, B, C, D]) mod = tvm.IRModule.from_expr(func) mod = tvm.tir.transform.StorageFlatten(64)(mod) @@ -511,11 +493,7 @@ def test_inplace_rule3(): s[B10].compute_inline() s = s.normalize() - bounds = tvm.te.schedule.InferBound(s) - assert isinstance(bounds, tvm.container.Map) - stmt = tvm.te.schedule.ScheduleOps(s, bounds) - - func = tvm.te.schedule.SchedulePostProcToPrimFunc([B0, B1, B2, B3, B4, B5, B], stmt, None) + func = schedule_to_primfunc(s, [B0, B1, B2, B3, B4, B5, B]) mod = tvm.IRModule.from_expr(func) mod = tvm.tir.transform.StorageFlatten(64)(mod) From 83723af418eaa82dbc77b34bdcf68e4b759bc42d Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 15 Oct 2021 09:46:51 -0500 Subject: [PATCH 2/3] Removed schedule_to_primfunc, replaced usage with schedule_to_module. 
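
A usage sketch for review (the trivial schedule below is illustrative,
not taken from this series; only `schedule_to_module` itself is):

    import tvm
    from tvm import te
    from tvm.driver.build_module import schedule_to_module

    # Build a trivial elementwise compute and schedule it.
    n = te.var("n")
    A = te.placeholder((n,), name="A")
    B = te.compute((n,), lambda i: A[i] + 1.0, name="B")
    s = te.create_schedule(B.op)

    # One call replaces the old InferBound / ScheduleOps /
    # SchedulePostProcToPrimFunc boilerplate, returning an IRModule
    # whose "main" PrimFunc is the unlowered schedule body.
    mod = schedule_to_module(s, [A, B], name="main")
    assert isinstance(mod["main"], tvm.tir.PrimFunc)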
--- python/tvm/autotvm/feature.py | 4 +- python/tvm/driver/__init__.py | 2 +- python/tvm/driver/build_module.py | 31 --------- .../backend/contrib/ethosu/tir/compiler.py | 7 +- src/auto_scheduler/feature.cc | 23 +++---- src/driver/driver_api.cc | 60 +++++++---------- src/driver/utils.h | 64 ------------------- tests/python/integration/test_reduce.py | 5 +- tests/python/unittest/test_te_schedule_ops.py | 14 ++-- .../test_tir_transform_inject_copy_intrin.py | 11 ++-- .../test_tir_transform_make_packed_api.py | 5 +- ...merge_dynamic_shared_memory_allocations.py | 5 +- .../test_tir_transform_narrow_datatype.py | 5 +- .../test_tir_transform_storage_flatten.py | 8 +-- .../test_tir_transform_storage_rewrite.py | 23 +++---- 15 files changed, 67 insertions(+), 200 deletions(-) delete mode 100644 src/driver/utils.h diff --git a/python/tvm/autotvm/feature.py b/python/tvm/autotvm/feature.py index f8576c1a33a5..f73c65fbd1d8 100644 --- a/python/tvm/autotvm/feature.py +++ b/python/tvm/autotvm/feature.py @@ -42,8 +42,8 @@ def ana_lower(sch, args, binds=None, simple_mode=True): # Phase 0 context = tvm.transform.PassContext(config={"tir.debug_keep_trivial_loop": True}) with context: - func = build_module.schedule_to_primfunc(sch, args, binds=binds) - mod = tvm.IRModule.from_expr(func._move()) + mod = build_module.schedule_to_module(sch, args, binds=binds) + mod = tvm.tir.transform.StorageFlatten(64)(mod._move()) mod = tvm.tir.transform.Simplify()(mod._move()) assert simple_mode diff --git a/python/tvm/driver/__init__.py b/python/tvm/driver/__init__.py index 3ef297990cbe..75e94cc91c83 100644 --- a/python/tvm/driver/__init__.py +++ b/python/tvm/driver/__init__.py @@ -15,4 +15,4 @@ # specific language governing permissions and limitations # under the License. """Namespace for driver APIs""" -from .build_module import lower, build, schedule_to_primfunc +from .build_module import lower, build diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index bb2b8f068796..29fff775150f 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -89,37 +89,6 @@ def schedule_to_module( return ffi.schedule_to_module(sch, args, name, binds) -def schedule_to_primfunc( - sch: schedule.Schedule, - args: Optional[List[Union[Buffer, tensor.Tensor, Var]]] = None, - name: str = "main", - binds: Optional[Mapping[tensor.Tensor, Buffer]] = None, -) -> IRModule: - """According to the given schedule, form a function. - - This is a low-level function intended for testing purposes, and - does not apply any optimization passes. In general, `tvm.lower` - and `tvm.build` should be used instead. - - Parameters - ---------- - sch : tvm.te.schedule.Schedule - The given scheduler to form the raw body - args : list of Buffer or Tensor or Var - The argument lists to the function. 
- name : str - The name of result function, default name is "main" - binds : dict of :any:`Tensor` to :any:`Buffer`, optional - The binds information - - Returns - ------- - The body formed according to the given schedule - - """ - return ffi.schedule_to_primfunc(sch, args, name, binds) - - def lower( inp: Union[schedule.Schedule, PrimFunc, IRModule], args: Optional[List[Union[Buffer, tensor.Tensor, Var]]] = None, diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py b/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py index 54dc3be95c44..c792ade06643 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py @@ -19,7 +19,7 @@ import tvm from tvm import relay from tvm.relay.expr_functor import ExprMutator -from tvm.driver.build_module import schedule_to_primfunc +from tvm.driver.build_module import schedule_to_module from .passes import ReplaceOperators, RemoveZeroStores, EncodeConstants from .scheduler import schedule @@ -64,6 +64,7 @@ def lower_ethosu(sch, args, const_dict, name="main"): "no_unroll_loop_with_extent_one": True, }, "tir.UnrollLoop": {"auto_max_depth": -1}, + "tir.noalias": True, "tir.debug_keep_trivial_loop": True, } # Merge two configs @@ -72,9 +73,7 @@ def lower_ethosu(sch, args, const_dict, name="main"): sch = sch.normalize() with tvm.transform.PassContext(config=curr_cfg): - func = schedule_to_primfunc(sch, args, name) - func = func.with_attr("tir.noalias", True) - mod = tvm.IRModule({name: func}) + mod = schedule_to_module(sch, args, name) mod = tvm.tir.transform.Simplify()(mod) mod = tvm.tir.transform.StorageFlatten(64)(mod) diff --git a/src/auto_scheduler/feature.cc b/src/auto_scheduler/feature.cc index 73854a33e5ca..aaf7d48b10c5 100755 --- a/src/auto_scheduler/feature.cc +++ b/src/auto_scheduler/feature.cc @@ -42,7 +42,6 @@ #include #include -#include "../driver/utils.h" #include "search_policy/utils.h" #include "utils.h" @@ -1263,27 +1262,25 @@ void GetPerStoreFeaturesWorkerFunc(const SearchTask& task, const State& state, i Array tensors; std::tie(sch, tensors) = task->compute_dag.ApplySteps(state->transform_steps); + + // When inlining, replace const matrices with const values. + // Produces wrong IR, but good enough for feature extraction, and + // can improve the speed of feature extraction/search. Must be + // called before ScheduleToModule to have an effect. 
sch = sch.normalize_for_feature_extraction(); try { const std::string& name = "main"; - GlobalVar global_var(name); - auto pass_ctx = tvm::transform::PassContext::Current(); - tir::PrimFunc f = ScheduleToPrimFunc(sch, Array{tensors.begin(), tensors.end()}, - name, std::unordered_map()); - bool noalias = pass_ctx->GetConfig("tir.noalias", Bool(true)).value(); + auto mod = ScheduleToModule(sch, Array{tensors.begin(), tensors.end()}, name, + std::unordered_map()); + bool disable_vectorize = pass_ctx->GetConfig("tir.disable_vectorize", Bool(false)).value(); bool instrument_bound_checkers = pass_ctx->GetConfig("tir.instrument_bound_checkers", Bool(false)).value(); - if (noalias) { - f = WithAttr(std::move(f), "tir.noalias", Bool(true)); - } - auto mod = IRModule(Map({{global_var, f}})); - if (IsGPUTask(task)) { auto pass_list = Array(); // Phase 0 @@ -1310,9 +1307,7 @@ void GetPerStoreFeaturesWorkerFunc(const SearchTask& task, const State& state, i const auto& optimize = tir::transform::Sequential(Array{tir::transform::Simplify()}); mod = optimize(std::move(mod)); - const auto& it = mod->functions.find(global_var); - ICHECK(it != mod->functions.end()); - const auto& prim_func = (*it).second.as(); + PrimFunc prim_func = Downcast(mod->Lookup(name)); GetPerStoreFeature(prim_func->body, task->hardware_params->cache_line_bytes, max_n_bufs, feature); } catch (Error& e) { diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 0c0d4872c5f9..aa5ad6c767b0 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -34,8 +34,6 @@ #include #include -#include "utils.h" - namespace tvm { // Register build pipeline related options @@ -290,6 +288,28 @@ IRModule ApplyPasses(IRModule mod, transform::Sequential seq) { return mod; } +tir::PrimFunc ScheduleToPrimFunc(te::Schedule sch, const Array& args, + const std::string& name, + const std::unordered_map& binds) { + transform::PassContext pass_ctx = transform::PassContext::Current(); + bool debug_keep_trivial_loop = + pass_ctx->GetConfig("tir.debug_keep_trivial_loop", Bool(false)).value(); + + // Before TIR transformation. + tir::Stmt stmt = te::ScheduleOps(sch, te::InferBound(sch), debug_keep_trivial_loop); + bool compact = te::VerifyCompactBuffer(stmt); + + Map out_binds; + Array out_arg_list; + GetBinds(args, compact, binds, &out_binds, &out_arg_list); + + // Build the function, converting from te::Tensor to tir::Buffer + tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, std::move(stmt), out_binds); + f = WithAttr(std::move(f), "global_symbol", runtime::String(name)); + + return f; +} + // Convert te schedule to IRModule IRModule ScheduleToModule(te::Schedule sch, const Array& args, const std::string& name, const std::unordered_map& binds) { @@ -324,42 +344,6 @@ TVM_REGISTER_GLOBAL("driver.schedule_to_module") return mod; }); -tir::PrimFunc ScheduleToPrimFunc(te::Schedule sch, const Array& args, - const std::string& name, - const std::unordered_map& binds) { - transform::PassContext pass_ctx = transform::PassContext::Current(); - bool debug_keep_trivial_loop = - pass_ctx->GetConfig("tir.debug_keep_trivial_loop", Bool(false)).value(); - - // Before TIR transformation. 
- tir::Stmt stmt = te::ScheduleOps(sch, te::InferBound(sch), debug_keep_trivial_loop); - bool compact = te::VerifyCompactBuffer(stmt); - - Map out_binds; - Array out_arg_list; - GetBinds(args, compact, binds, &out_binds, &out_arg_list); - - // Build the function, converting from te::Tensor to tir::Buffer - tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, std::move(stmt), out_binds); - f = WithAttr(std::move(f), "global_symbol", runtime::String(name)); - - return f; -} - -TVM_REGISTER_GLOBAL("driver.schedule_to_primfunc") - .set_body_typed([](te::Schedule sch, const Array& args, const String& name, - const Map& binds) { - std::unordered_map c_binds; - // Check to make sure binds is not null before doing the conversion; - if (binds.defined()) { - for (auto kv : binds) { - c_binds.insert({kv.first, kv.second}); - } - } - tir::PrimFunc func = ScheduleToPrimFunc(std::move(sch), args, name, c_binds); - return func; - }); - IRModule LowerModule(IRModule mod, bool simple_mode) { Array pass_list = CreatePassList(simple_mode); return LowerWithPassList(std::move(mod), pass_list); diff --git a/src/driver/utils.h b/src/driver/utils.h deleted file mode 100644 index 0aa86802a092..000000000000 --- a/src/driver/utils.h +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * - * \file utils.h - * \brief Internal utilities for manipulating TE schedules. - */ - -#ifndef TVM_DRIVER_UTILS_H_ -#define TVM_DRIVER_UTILS_H_ - -#include -#include -#include - -#include -#include - -namespace tvm { - -/*! - * \brief Create an PrimFunc out of a TE Schedule. - * - * Generated PrimFunc expresses reads/writes using - * BufferLoad/BufferStore, with all Tensors and - * ProducerLoad/ProducerStore having been replaced. - * - * Assumes that the schedule has already been normalized, either with - * `te::Schedule::normalize` or - * `te::Schedule::normalize_for_feature_extraction`. - * - * Does not apply lowering passes. If you want - * to apply lowering passes as well, use LowerSchedule. - * - * \param sch The schedule - * \param args The arguments to the function. - * \param name The name of the lowered function. - * \param binds Buffer assignments. - * \return The result module. 
- */ -tir::PrimFunc ScheduleToPrimFunc(te::Schedule sch, const Array& args, - const std::string& name, - const std::unordered_map& binds); - -} // namespace tvm - -#endif // TVM_DRIVER_UTILS_H_ diff --git a/tests/python/integration/test_reduce.py b/tests/python/integration/test_reduce.py index dc01ffd81bbf..a40164ded941 100644 --- a/tests/python/integration/test_reduce.py +++ b/tests/python/integration/test_reduce.py @@ -19,7 +19,7 @@ import tvm from tvm import te, topi -from tvm.driver.build_module import schedule_to_primfunc +from tvm.driver.build_module import schedule_to_module import tvm.testing import tvm.topi.testing @@ -533,8 +533,7 @@ def test_reduce_storage_reuse(): target = tvm.target.Target("cuda") def run_passes(sch, args): - func = schedule_to_primfunc(sch, args) - mod = tvm.IRModule.from_expr(func) + mod = schedule_to_module(sch, args) mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))(mod) return tvm.transform.Sequential( [ diff --git a/tests/python/unittest/test_te_schedule_ops.py b/tests/python/unittest/test_te_schedule_ops.py index 31707a46dbde..ca3ab3aade98 100644 --- a/tests/python/unittest/test_te_schedule_ops.py +++ b/tests/python/unittest/test_te_schedule_ops.py @@ -18,7 +18,7 @@ import tvm from tvm import te -from tvm.driver.build_module import schedule_to_primfunc +from tvm.driver.build_module import schedule_to_module def test_schedule0(): @@ -28,8 +28,8 @@ def test_schedule0(): A1 = te.compute((m, l), lambda i, j: A[i, j], name="A1") s = te.create_schedule(A1.op) - func = schedule_to_primfunc(s, [A, A1]) - assert isinstance(func, tvm.tir.PrimFunc) + mod = schedule_to_module(s, [A, A1]) + assert isinstance(mod["main"], tvm.tir.PrimFunc) def test_schedule1(): @@ -42,8 +42,8 @@ def test_schedule1(): xo, xi = s[A1].split(A1.op.axis[0], 8) s[A1].pragma(xo, "auto_unroll_max_step", 10) - func = schedule_to_primfunc(s, [A, A1]) - assert isinstance(func, tvm.tir.PrimFunc) + mod = schedule_to_module(s, [A, A1]) + assert isinstance(mod["main"], tvm.tir.PrimFunc) def test_schedule2(): @@ -57,8 +57,8 @@ def test_schedule2(): xo, xi = s[A2].split(A2.op.axis[0], 8) s[A1].compute_at(s[A2], xo) - func = schedule_to_primfunc(s, [A, A2]) - assert isinstance(func, tvm.tir.PrimFunc) + mod = schedule_to_module(s, [A, A2]) + assert isinstance(mod["main"], tvm.tir.PrimFunc) def test_schedule_scan(): diff --git a/tests/python/unittest/test_tir_transform_inject_copy_intrin.py b/tests/python/unittest/test_tir_transform_inject_copy_intrin.py index ea858e9c8258..aa0448c3c682 100644 --- a/tests/python/unittest/test_tir_transform_inject_copy_intrin.py +++ b/tests/python/unittest/test_tir_transform_inject_copy_intrin.py @@ -17,7 +17,7 @@ import tvm import tvm.testing from tvm import te -from tvm.driver.build_module import schedule_to_primfunc +from tvm.driver.build_module import schedule_to_module def test_copy2d(): @@ -54,8 +54,7 @@ def test_copy_pad(): ) s = te.create_schedule(B.op) s[B].pragma(B.op.axis[0], "memcpy") - func = schedule_to_primfunc(s, [A, B]) - mod = tvm.IRModule.from_expr(func) + mod = schedule_to_module(s, [A, B]) mod = tvm.tir.transform.StorageFlatten(64)(mod) def cb(src, dst, pad_before, pad_after, pad_value): @@ -75,8 +74,7 @@ def test_single_point_test(): B = te.compute((1,), lambda i: A[i], name="B") s = te.create_schedule(B.op) s[B].pragma(B.op.axis[0], "memcpy") - func = schedule_to_primfunc(s, [A, B]) - mod = tvm.IRModule.from_expr(func) + mod = schedule_to_module(s, [A, B]) mod = tvm.tir.transform.StorageFlatten(64)(mod) def cb(src, dst, 
pad_before, pad_after, pad_value): @@ -101,8 +99,7 @@ def test_copy_pad_split(): s[Apad].compute_at(s[B], xo) s[Apad].pragma(s[Apad].op.axis[0], "memcpy") - func = schedule_to_primfunc(s, [A, B]) - mod = tvm.IRModule.from_expr(func) + mod = schedule_to_module(s, [A, B]) mod = tvm.tir.transform.StorageFlatten(64)(mod._move()) mod = tvm.tir.transform.Simplify()(mod._move()) diff --git a/tests/python/unittest/test_tir_transform_make_packed_api.py b/tests/python/unittest/test_tir_transform_make_packed_api.py index 00ce78193dc5..1ab6bdaad90a 100644 --- a/tests/python/unittest/test_tir_transform_make_packed_api.py +++ b/tests/python/unittest/test_tir_transform_make_packed_api.py @@ -18,7 +18,7 @@ import tvm from tvm import te -from tvm.driver.build_module import schedule_to_primfunc +from tvm.driver.build_module import schedule_to_module def test_makeapi(): @@ -29,8 +29,7 @@ def test_makeapi(): C = te.compute(A.shape, lambda *i: A(*i) + B(*i), name="C") s = te.create_schedule(C.op) - func = schedule_to_primfunc(s, [n, A, B, C]) - mod = tvm.IRModule.from_expr(func) + mod = schedule_to_module(s, [n, A, B, C]) mod = tvm.tir.transform.StorageFlatten(64)(mod) mod = tvm.tir.transform.Apply( lambda f: f.with_attr( diff --git a/tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py b/tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py index f070edefe3bb..cc78b84f9b4e 100644 --- a/tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py +++ b/tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py @@ -19,13 +19,12 @@ import tvm import tvm.testing from tvm import te -from tvm.driver.build_module import schedule_to_primfunc +from tvm.driver.build_module import schedule_to_module from tvm.topi.math import cast def run_passes(sch, args): - func = schedule_to_primfunc(sch, args) - mod = tvm.IRModule.from_expr(func) + mod = schedule_to_module(sch, args) return tvm.transform.Sequential( [ tvm.tir.transform.StorageFlatten(64), diff --git a/tests/python/unittest/test_tir_transform_narrow_datatype.py b/tests/python/unittest/test_tir_transform_narrow_datatype.py index 0dfa25611453..b5620d748d8a 100644 --- a/tests/python/unittest/test_tir_transform_narrow_datatype.py +++ b/tests/python/unittest/test_tir_transform_narrow_datatype.py @@ -16,7 +16,7 @@ # under the License. import tvm from tvm import te, relay -from tvm.driver.build_module import schedule_to_primfunc +from tvm.driver.build_module import schedule_to_module from tvm.tir import const @@ -40,8 +40,7 @@ def lower_sch(sch, args, target_bits): raise ValueError("args must be Tensor, Buffer or Var") sch = sch.normalize() - func = schedule_to_primfunc(sch, args) - mod = tvm.IRModule.from_expr(func) + mod = schedule_to_module(sch, args) mod = tvm.tir.transform.StorageFlatten(64)(mod) return tvm.tir.transform.NarrowDataType(target_bits)(mod)["main"].body diff --git a/tests/python/unittest/test_tir_transform_storage_flatten.py b/tests/python/unittest/test_tir_transform_storage_flatten.py index c08a21fee6ed..a51e926155d3 100644 --- a/tests/python/unittest/test_tir_transform_storage_flatten.py +++ b/tests/python/unittest/test_tir_transform_storage_flatten.py @@ -16,7 +16,7 @@ # under the License. 
import tvm from tvm import te -from tvm.driver.build_module import schedule_to_primfunc +from tvm.driver.build_module import schedule_to_module from tvm.script import tir as T from tvm.relay import GlobalVar @@ -34,8 +34,7 @@ def test_flatten2(): Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name="A") A2b = tvm.tir.decl_buffer(A2.shape, A2.dtype, name="A2") - func = schedule_to_primfunc(s, [Ab, A2b], binds={A: Ab, A2: A2b}) - mod = tvm.IRModule.from_expr(func) + mod = schedule_to_module(s, [Ab, A2b], binds={A: Ab, A2: A2b}) mod = tvm.tir.transform.StorageFlatten(64)(mod) @@ -69,8 +68,7 @@ def test_flatten_storage_align(): s = te.create_schedule(A2.op) s[A1].storage_align(A1.op.axis[0], 2, 1) - func = schedule_to_primfunc(s, [A, A2]) - mod = tvm.IRModule.from_expr(func) + mod = schedule_to_module(s, [A, A2]) mod = tvm.transform.Sequential( [tvm.tir.transform.StorageFlatten(64), tvm.tir.transform.Simplify()] )(mod) diff --git a/tests/python/unittest/test_tir_transform_storage_rewrite.py b/tests/python/unittest/test_tir_transform_storage_rewrite.py index 22222405f009..5a91788283d6 100644 --- a/tests/python/unittest/test_tir_transform_storage_rewrite.py +++ b/tests/python/unittest/test_tir_transform_storage_rewrite.py @@ -16,7 +16,7 @@ # under the License. import tvm from tvm import te -from tvm.driver.build_module import schedule_to_primfunc +from tvm.driver.build_module import schedule_to_module def test_storage_share(): @@ -29,8 +29,7 @@ def test_storage_share(): B = te.compute((m, l), lambda i, j: B[i, j] + (t + 1), name="A%d" % t) s = te.create_schedule(B.op) - func = schedule_to_primfunc(s, [A, B]) - mod = tvm.IRModule.from_expr(func) + mod = schedule_to_module(s, [A, B]) mod = tvm.tir.transform.StorageFlatten(64)(mod) mod = tvm.tir.transform.Simplify()(mod) @@ -166,8 +165,7 @@ def test_inplace_rule(): AA = te.compute((m,), lambda i: A0[i] + A1[i] + A1[0], name="AA") B = te.compute((m,), lambda i: AA[i] + 1, name="B") s = te.create_schedule(B.op) - func = schedule_to_primfunc(s, [A, B]) - mod = tvm.IRModule.from_expr(func) + mod = schedule_to_module(s, [A, B]) mod = tvm.tir.transform.StorageFlatten(64)(mod) mod = tvm.tir.transform.Simplify()(mod) @@ -200,8 +198,7 @@ def test_storage_combine(): for S in stages[:-1]: s[S].set_scope("global:tag") - func = schedule_to_primfunc(s, [A, B]) - mod = tvm.IRModule.from_expr(func) + mod = schedule_to_module(s, [A, B]) mod = tvm.tir.transform.StorageFlatten(64)(mod) mod = tvm.tir.transform.Simplify()(mod) @@ -229,8 +226,7 @@ def test_storage_combine_with_vectorization(): BB = s.cache_read(B, "global:tag", readers=[C]) CC = s.cache_write(C, "global:tag") s[CC].vectorize(s[CC].op.axis[0]) - func = schedule_to_primfunc(s, [A, B, C]) - mod = tvm.IRModule.from_expr(func) + mod = schedule_to_module(s, [A, B, C]) mod = tvm.tir.transform.StorageFlatten(64)(mod) mod = tvm.tir.transform.VectorizeLoop()(mod) mod = tvm.tir.transform.StorageRewrite()(mod) @@ -274,8 +270,7 @@ def test_storage_share_gpu(): s[A[2 * t + 1]].compute_at(s[A[2 * t + 2]], tx) s[A[2 * t + 1]].set_scope("shared") - func = schedule_to_primfunc(s, [A[0], A[-1]]) - mod = tvm.IRModule.from_expr(func) + mod = schedule_to_module(s, [A[0], A[-1]]) mod = tvm.tir.transform.StorageFlatten(64)(mod) mod = tvm.tir.transform.Simplify()(mod) mod = tvm.tir.transform.StorageRewrite()(mod) @@ -404,8 +399,7 @@ def test_inplace_rule2(scope_tb="local_TB2", max_bits=1024 * 1024 * 1024): A0L = s.cache_read(A0, scope_tb, [A2]) A1L = s.cache_read(A1, scope_tb, [A2]) A2L = s.cache_read(A2, scope_tb, [B]) - func = 
schedule_to_primfunc(s, [A, B, C, D])
-    mod = tvm.IRModule.from_expr(func)
+    mod = schedule_to_module(s, [A, B, C, D])

     mod = tvm.tir.transform.StorageFlatten(64)(mod)
     mod = tvm.tir.transform.Simplify()(mod)
@@ -493,8 +487,7 @@ def test_inplace_rule3():
     s[B10].compute_inline()
     s = s.normalize()

-    func = schedule_to_primfunc(s, [B0, B1, B2, B3, B4, B5, B])
-    mod = tvm.IRModule.from_expr(func)
+    mod = schedule_to_module(s, [B0, B1, B2, B3, B4, B5, B])

     mod = tvm.tir.transform.StorageFlatten(64)(mod)
     mod = tvm.tir.transform.Simplify()(mod)

From d71d00b44a1ca73c0b82745e7d16376a5e57a30e Mon Sep 17 00:00:00 2001
From: Eric Lunderberg
Date: Fri, 15 Oct 2021 16:43:16 -0500
Subject: [PATCH 3/3] Moved the C++ function ScheduleToPrimFunc back inside
 ScheduleToModule.

---
 src/driver/driver_api.cc | 19 +++++-------------
 1 file changed, 5 insertions(+), 14 deletions(-)

diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc
index aa5ad6c767b0..2d57d6e30b45 100644
--- a/src/driver/driver_api.cc
+++ b/src/driver/driver_api.cc
@@ -288,9 +288,11 @@ IRModule ApplyPasses(IRModule mod, transform::Sequential seq) {
   return mod;
 }

-tir::PrimFunc ScheduleToPrimFunc(te::Schedule sch, const Array<ObjectRef>& args,
-                                 const std::string& name,
-                                 const std::unordered_map<te::Tensor, tir::Buffer>& binds) {
+// Convert te schedule to IRModule
+IRModule ScheduleToModule(te::Schedule sch, const Array<ObjectRef>& args, const std::string& name,
+                          const std::unordered_map<te::Tensor, tir::Buffer>& binds) {
+  sch = sch.normalize();
+
   transform::PassContext pass_ctx = transform::PassContext::Current();
   bool debug_keep_trivial_loop =
       pass_ctx->GetConfig<Bool>("tir.debug_keep_trivial_loop", Bool(false)).value();
@@ -307,21 +309,10 @@ tir::PrimFunc ScheduleToPrimFunc(te::Schedule sch, const Array<ObjectRef>& args,
   tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, std::move(stmt), out_binds);
   f = WithAttr(std::move(f), "global_symbol", runtime::String(name));

-  return f;
-}
-
-// Convert te schedule to IRModule
-IRModule ScheduleToModule(te::Schedule sch, const Array<ObjectRef>& args, const std::string& name,
-                          const std::unordered_map<te::Tensor, tir::Buffer>& binds) {
-  sch = sch.normalize();
-
-  tir::PrimFunc f = ScheduleToPrimFunc(sch, args, name, binds);
-
   // Mark this schedule as being converted from an TE schedule. Makes sure that
   // the correct TE passes are run.
   f = WithAttr(std::move(f), "from_legacy_te_schedule", Bool(true));

-  transform::PassContext pass_ctx = transform::PassContext::Current();
   bool noalias = pass_ctx->GetConfig<Bool>("tir.noalias", Bool(true)).value();

   if (noalias) {
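
A short note on the `tir.debug_keep_trivial_loop` pass config registered in
this series: it is read inside ScheduleToModule, so callers opt in through
the surrounding PassContext.  A minimal sketch (the reduction below is
assumed for illustration, mirroring autotvm's ana_lower):

    import tvm
    from tvm import te
    from tvm.driver.build_module import schedule_to_module

    # A reduction whose output loop has extent 1; the default lowering
    # path simplifies such trivial loops away.
    A = te.placeholder((1, 16), name="A")
    k = te.reduce_axis((0, 16), name="k")
    B = te.compute((1,), lambda i: te.sum(A[i, k], axis=k), name="B")
    s = te.create_schedule(B.op)

    # With the config set, ScheduleOps keeps extent-1 loops, matching the
    # old explicit ScheduleOps(sch, bounds, True) call that ana_lower and
    # the ethos-u compiler previously used.
    with tvm.transform.PassContext(config={"tir.debug_keep_trivial_loop": True}):
        mod = schedule_to_module(s, [A, B])
    print(mod["main"])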