diff --git a/python/tvm/autotvm/feature.py b/python/tvm/autotvm/feature.py index 8d2591dce50b..f73c65fbd1d8 100644 --- a/python/tvm/autotvm/feature.py +++ b/python/tvm/autotvm/feature.py @@ -31,7 +31,6 @@ import tvm._ffi from tvm.target import Target -from tvm.te import schedule from tvm.driver import build_module @@ -39,13 +38,12 @@ def ana_lower(sch, args, binds=None, simple_mode=True): """Do lower while keeping all axes in IR i.e. Do not eliminate loop with extent of 1, do not vectorize, unroll or inject virtual threads """ - binds, _ = build_module.get_binds(args, compact=False, binds=binds) sch = sch.normalize() # Phase 0 - bounds = schedule.InferBound(sch) - stmt = schedule.ScheduleOps(sch, bounds, True) - func = schedule.SchedulePostProcToPrimFunc(args, stmt, None) - mod = tvm.IRModule.from_expr(func._move()) + context = tvm.transform.PassContext(config={"tir.debug_keep_trivial_loop": True}) + with context: + mod = build_module.schedule_to_module(sch, args, binds=binds) + mod = tvm.tir.transform.StorageFlatten(64)(mod._move()) mod = tvm.tir.transform.Simplify()(mod._move()) assert simple_mode diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index 429b3e1727cc..29fff775150f 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -67,6 +67,11 @@ def schedule_to_module( binds: Optional[Mapping[tensor.Tensor, Buffer]] = None, ) -> IRModule: """According to the given schedule, form a function. + + This is a low-level function intended for testing purposes, and + does not apply any optimization passes. In general, `tvm.lower` + and `tvm.build` should be used instead. + Parameters ---------- sch : tvm.te.schedule.Schedule diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py b/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py index 3283e0515c72..c792ade06643 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py @@ -19,7 +19,7 @@ import tvm from tvm import relay from tvm.relay.expr_functor import ExprMutator -from tvm.driver.build_module import get_binds +from tvm.driver.build_module import schedule_to_module from .passes import ReplaceOperators, RemoveZeroStores, EncodeConstants from .scheduler import schedule @@ -64,22 +64,17 @@ def lower_ethosu(sch, args, const_dict, name="main"): "no_unroll_loop_with_extent_one": True, }, "tir.UnrollLoop": {"auto_max_depth": -1}, + "tir.noalias": True, + "tir.debug_keep_trivial_loop": True, } # Merge two configs curr_cfg = {**curr_cfg, **tir_compiler_cfg} sch = sch.normalize() - bounds = tvm.te.schedule.InferBound(sch) - stmt = tvm.te.schedule.ScheduleOps(sch, bounds, True) - compact = tvm.te.schedule.VerifyCompactBuffer(stmt) - binds, arg_list = get_binds(args, compact, None) - func = tvm.te.schedule.SchedulePostProcToPrimFunc(arg_list, stmt, binds) - - func = func.with_attr("global_symbol", name) - func = func.with_attr("tir.noalias", True) - mod = tvm.IRModule({name: func}) with tvm.transform.PassContext(config=curr_cfg): + mod = schedule_to_module(sch, args, name) + mod = tvm.tir.transform.Simplify()(mod) mod = tvm.tir.transform.StorageFlatten(64)(mod) mod = tvm.tir.transform.UnrollLoop()(mod) diff --git a/src/auto_scheduler/feature.cc b/src/auto_scheduler/feature.cc index be78bc4aa9f9..aaf7d48b10c5 100755 --- a/src/auto_scheduler/feature.cc +++ b/src/auto_scheduler/feature.cc @@ -26,6 +26,7 @@ #include #include #include +#include #include #include #include @@ -44,13 +45,6 @@ #include "search_policy/utils.h" #include "utils.h" -namespace tvm { -// import the function from driver_api.cc -void GetBinds(const Array& args, bool compact, - const std::unordered_map& binds, - Map* out_binds, Array* out_arg_list); -} // namespace tvm - namespace tvm { namespace auto_scheduler { @@ -1268,35 +1262,25 @@ void GetPerStoreFeaturesWorkerFunc(const SearchTask& task, const State& state, i Array tensors; std::tie(sch, tensors) = task->compute_dag.ApplySteps(state->transform_steps); + + // When inlining, replace const matrices with const values. + // Produces wrong IR, but good enough for feature extraction, and + // can improve the speed of feature extraction/search. Must be + // called before ScheduleToModule to have an effect. sch = sch.normalize_for_feature_extraction(); - auto bounds = te::InferBound(sch); try { - auto stmt = te::ScheduleOps(sch, bounds, false); - Map out_binds; - Array out_arg_list; - bool compact = te::VerifyCompactBuffer(stmt); const std::string& name = "main"; - GlobalVar global_var(name); - - // Copied from driver_api.cc::lower auto pass_ctx = tvm::transform::PassContext::Current(); - GetBinds(tensors, compact, std::unordered_map(), &out_binds, - &out_arg_list); - tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, std::move(stmt), out_binds); - f = WithAttr(std::move(f), "global_symbol", runtime::String(name)); - bool noalias = pass_ctx->GetConfig("tir.noalias", Bool(true)).value(); + auto mod = ScheduleToModule(sch, Array{tensors.begin(), tensors.end()}, name, + std::unordered_map()); + bool disable_vectorize = pass_ctx->GetConfig("tir.disable_vectorize", Bool(false)).value(); bool instrument_bound_checkers = pass_ctx->GetConfig("tir.instrument_bound_checkers", Bool(false)).value(); - if (noalias) { - f = WithAttr(std::move(f), "tir.noalias", Bool(true)); - } - auto mod = IRModule(Map({{global_var, f}})); - if (IsGPUTask(task)) { auto pass_list = Array(); // Phase 0 @@ -1323,9 +1307,7 @@ void GetPerStoreFeaturesWorkerFunc(const SearchTask& task, const State& state, i const auto& optimize = tir::transform::Sequential(Array{tir::transform::Simplify()}); mod = optimize(std::move(mod)); - const auto& it = mod->functions.find(global_var); - ICHECK(it != mod->functions.end()); - const auto& prim_func = (*it).second.as(); + PrimFunc prim_func = Downcast(mod->Lookup(name)); GetPerStoreFeature(prim_func->body, task->hardware_params->cache_line_bytes, max_n_bufs, feature); } catch (Error& e) { diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index e659421c23c4..2d57d6e30b45 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -44,6 +44,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_assert", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_vectorize", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.is_entry_func", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.add_lower_pass", Array>); +TVM_REGISTER_PASS_CONFIG_OPTION("tir.debug_keep_trivial_loop", Bool); using runtime::PackedFunc; using runtime::TVMArgs; @@ -287,24 +288,24 @@ IRModule ApplyPasses(IRModule mod, transform::Sequential seq) { return mod; } +// Convert te schedule to IRModule IRModule ScheduleToModule(te::Schedule sch, const Array& args, const std::string& name, const std::unordered_map& binds) { - // Convert te schedule to IRModule - Array out_arg_list; - transform::PassContext pass_ctx = transform::PassContext::Current(); - sch = sch.normalize(); + transform::PassContext pass_ctx = transform::PassContext::Current(); + bool debug_keep_trivial_loop = + pass_ctx->GetConfig("tir.debug_keep_trivial_loop", Bool(false)).value(); + // Before TIR transformation. - Map bounds = te::InferBound(sch); - tir::Stmt stmt = te::ScheduleOps(sch, std::move(bounds), false); + tir::Stmt stmt = te::ScheduleOps(sch, te::InferBound(sch), debug_keep_trivial_loop); bool compact = te::VerifyCompactBuffer(stmt); Map out_binds; + Array out_arg_list; GetBinds(args, compact, binds, &out_binds, &out_arg_list); - // Build the function - // At this point binds is only te::Tensors + // Build the function, converting from te::Tensor to tir::Buffer tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, std::move(stmt), out_binds); f = WithAttr(std::move(f), "global_symbol", runtime::String(name)); @@ -325,7 +326,7 @@ TVM_REGISTER_GLOBAL("driver.schedule_to_module") const Map& binds) { std::unordered_map c_binds; // Check to make sure binds is not null before doing the conversion; - if (binds.get() != nullptr) { + if (binds.defined()) { for (auto kv : binds) { c_binds.insert({kv.first, kv.second}); } diff --git a/tests/python/integration/test_reduce.py b/tests/python/integration/test_reduce.py index ca097734a9eb..a40164ded941 100644 --- a/tests/python/integration/test_reduce.py +++ b/tests/python/integration/test_reduce.py @@ -15,10 +15,11 @@ # specific language governing permissions and limitations # under the License. import pytest +import numpy as np import tvm from tvm import te, topi -import numpy as np +from tvm.driver.build_module import schedule_to_module import tvm.testing import tvm.topi.testing @@ -532,10 +533,7 @@ def test_reduce_storage_reuse(): target = tvm.target.Target("cuda") def run_passes(sch, args): - bounds = tvm.te.schedule.InferBound(sch) - stmt = tvm.te.schedule.ScheduleOps(sch, bounds) - func = tvm.te.schedule.SchedulePostProcToPrimFunc(args, stmt, None) - mod = tvm.IRModule.from_expr(func) + mod = schedule_to_module(sch, args) mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))(mod) return tvm.transform.Sequential( [ diff --git a/tests/python/unittest/test_te_schedule_ops.py b/tests/python/unittest/test_te_schedule_ops.py index bc4bc4f56e19..ca3ab3aade98 100644 --- a/tests/python/unittest/test_te_schedule_ops.py +++ b/tests/python/unittest/test_te_schedule_ops.py @@ -14,9 +14,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import numpy as np + import tvm from tvm import te -import numpy as np +from tvm.driver.build_module import schedule_to_module def test_schedule0(): @@ -26,11 +28,8 @@ def test_schedule0(): A1 = te.compute((m, l), lambda i, j: A[i, j], name="A1") s = te.create_schedule(A1.op) - bounds = tvm.te.schedule.InferBound(s) - assert isinstance(bounds, tvm.container.Map) - stmt = tvm.te.schedule.ScheduleOps(s, bounds) - func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, A1], stmt, None) - assert isinstance(func, tvm.tir.PrimFunc) + mod = schedule_to_module(s, [A, A1]) + assert isinstance(mod["main"], tvm.tir.PrimFunc) def test_schedule1(): @@ -42,12 +41,9 @@ def test_schedule1(): s = te.create_schedule(A1.op) xo, xi = s[A1].split(A1.op.axis[0], 8) s[A1].pragma(xo, "auto_unroll_max_step", 10) - bounds = tvm.te.schedule.InferBound(s) - assert isinstance(bounds, tvm.container.Map) - stmt = tvm.te.schedule.ScheduleOps(s, bounds) - func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, A1], stmt, None) - assert isinstance(func, tvm.tir.PrimFunc) + mod = schedule_to_module(s, [A, A1]) + assert isinstance(mod["main"], tvm.tir.PrimFunc) def test_schedule2(): @@ -60,11 +56,9 @@ def test_schedule2(): s = te.create_schedule(A2.op) xo, xi = s[A2].split(A2.op.axis[0], 8) s[A1].compute_at(s[A2], xo) - bounds = tvm.te.schedule.InferBound(s) - assert isinstance(bounds, tvm.container.Map) - stmt = tvm.te.schedule.ScheduleOps(s, bounds) - func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, A2], stmt, None) - assert isinstance(func, tvm.tir.PrimFunc) + + mod = schedule_to_module(s, [A, A2]) + assert isinstance(mod["main"], tvm.tir.PrimFunc) def test_schedule_scan(): diff --git a/tests/python/unittest/test_tir_transform_inject_copy_intrin.py b/tests/python/unittest/test_tir_transform_inject_copy_intrin.py index 86bf87d5fa85..aa0448c3c682 100644 --- a/tests/python/unittest/test_tir_transform_inject_copy_intrin.py +++ b/tests/python/unittest/test_tir_transform_inject_copy_intrin.py @@ -17,6 +17,7 @@ import tvm import tvm.testing from tvm import te +from tvm.driver.build_module import schedule_to_module def test_copy2d(): @@ -53,11 +54,7 @@ def test_copy_pad(): ) s = te.create_schedule(B.op) s[B].pragma(B.op.axis[0], "memcpy") - bounds = tvm.te.schedule.InferBound(s) - stmt = tvm.te.schedule.ScheduleOps(s, bounds) - - func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None) - mod = tvm.IRModule.from_expr(func) + mod = schedule_to_module(s, [A, B]) mod = tvm.tir.transform.StorageFlatten(64)(mod) def cb(src, dst, pad_before, pad_after, pad_value): @@ -77,11 +74,7 @@ def test_single_point_test(): B = te.compute((1,), lambda i: A[i], name="B") s = te.create_schedule(B.op) s[B].pragma(B.op.axis[0], "memcpy") - bounds = tvm.te.schedule.InferBound(s) - stmt = tvm.te.schedule.ScheduleOps(s, bounds) - - func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None) - mod = tvm.IRModule.from_expr(func) + mod = schedule_to_module(s, [A, B]) mod = tvm.tir.transform.StorageFlatten(64)(mod) def cb(src, dst, pad_before, pad_after, pad_value): @@ -105,11 +98,8 @@ def test_copy_pad_split(): xo, xi = s[B].split(B.op.axis[0], factor=4) s[Apad].compute_at(s[B], xo) s[Apad].pragma(s[Apad].op.axis[0], "memcpy") - bounds = tvm.te.schedule.InferBound(s) - stmt = tvm.te.schedule.ScheduleOps(s, bounds) - func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None) - mod = tvm.IRModule.from_expr(func) + mod = schedule_to_module(s, [A, B]) mod = tvm.tir.transform.StorageFlatten(64)(mod._move()) mod = tvm.tir.transform.Simplify()(mod._move()) diff --git a/tests/python/unittest/test_tir_transform_make_packed_api.py b/tests/python/unittest/test_tir_transform_make_packed_api.py index 15f994069abd..1ab6bdaad90a 100644 --- a/tests/python/unittest/test_tir_transform_make_packed_api.py +++ b/tests/python/unittest/test_tir_transform_make_packed_api.py @@ -14,9 +14,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import numpy + import tvm from tvm import te -import numpy +from tvm.driver.build_module import schedule_to_module def test_makeapi(): @@ -27,10 +29,7 @@ def test_makeapi(): C = te.compute(A.shape, lambda *i: A(*i) + B(*i), name="C") s = te.create_schedule(C.op) - bounds = tvm.te.schedule.InferBound(s) - stmt = tvm.te.schedule.ScheduleOps(s, bounds) - func = tvm.te.schedule.SchedulePostProcToPrimFunc([n, A, B, C], stmt, None) - mod = tvm.IRModule.from_expr(func) + mod = schedule_to_module(s, [n, A, B, C]) mod = tvm.tir.transform.StorageFlatten(64)(mod) mod = tvm.tir.transform.Apply( lambda f: f.with_attr( diff --git a/tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py b/tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py index 9c511f1de6b9..cc78b84f9b4e 100644 --- a/tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py +++ b/tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py @@ -14,20 +14,17 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import tvm -from tvm import te import numpy as np + +import tvm import tvm.testing +from tvm import te +from tvm.driver.build_module import schedule_to_module from tvm.topi.math import cast def run_passes(sch, args): - bounds = tvm.te.schedule.InferBound(sch) - assert isinstance(bounds, tvm.container.Map) - stmt = tvm.te.schedule.ScheduleOps(sch, bounds) - - func = tvm.te.schedule.SchedulePostProcToPrimFunc(args, stmt, None) - mod = tvm.IRModule.from_expr(func) + mod = schedule_to_module(sch, args) return tvm.transform.Sequential( [ tvm.tir.transform.StorageFlatten(64), diff --git a/tests/python/unittest/test_tir_transform_narrow_datatype.py b/tests/python/unittest/test_tir_transform_narrow_datatype.py index cb8968cfc880..b5620d748d8a 100644 --- a/tests/python/unittest/test_tir_transform_narrow_datatype.py +++ b/tests/python/unittest/test_tir_transform_narrow_datatype.py @@ -15,8 +15,8 @@ # specific language governing permissions and limitations # under the License. import tvm -from tvm import te -from tvm import relay +from tvm import te, relay +from tvm.driver.build_module import schedule_to_module from tvm.tir import const @@ -39,11 +39,8 @@ def lower_sch(sch, args, target_bits): else: raise ValueError("args must be Tensor, Buffer or Var") sch = sch.normalize() - bounds = te.schedule.InferBound(sch) - stmt = te.schedule.ScheduleOps(sch, bounds) - func = tvm.te.schedule.SchedulePostProcToPrimFunc(args, stmt, None) - mod = tvm.IRModule.from_expr(func) + mod = schedule_to_module(sch, args) mod = tvm.tir.transform.StorageFlatten(64)(mod) return tvm.tir.transform.NarrowDataType(target_bits)(mod)["main"].body diff --git a/tests/python/unittest/test_tir_transform_storage_flatten.py b/tests/python/unittest/test_tir_transform_storage_flatten.py index 37223493a8b5..a51e926155d3 100644 --- a/tests/python/unittest/test_tir_transform_storage_flatten.py +++ b/tests/python/unittest/test_tir_transform_storage_flatten.py @@ -16,6 +16,7 @@ # under the License. import tvm from tvm import te +from tvm.driver.build_module import schedule_to_module from tvm.script import tir as T from tvm.relay import GlobalVar @@ -30,14 +31,10 @@ def test_flatten2(): s = te.create_schedule(A2.op) xo, xi = s[A2].split(A2.op.axis[0], 8) s[A1].compute_at(s[A2], xo) - bounds = tvm.te.schedule.InferBound(s) - assert isinstance(bounds, tvm.container.Map) - stmt = tvm.te.schedule.ScheduleOps(s, bounds) Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name="A") A2b = tvm.tir.decl_buffer(A2.shape, A2.dtype, name="A2") - func = tvm.te.schedule.SchedulePostProcToPrimFunc([Ab, A2b], stmt, {A: Ab, A2: A2b}) - mod = tvm.IRModule.from_expr(func) + mod = schedule_to_module(s, [Ab, A2b], binds={A: Ab, A2: A2b}) mod = tvm.tir.transform.StorageFlatten(64)(mod) @@ -70,12 +67,8 @@ def test_flatten_storage_align(): s = te.create_schedule(A2.op) s[A1].storage_align(A1.op.axis[0], 2, 1) - bounds = tvm.te.schedule.InferBound(s) - assert isinstance(bounds, tvm.container.Map) - stmt = tvm.te.schedule.ScheduleOps(s, bounds) - func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, A2], stmt, None) - mod = tvm.IRModule.from_expr(func) + mod = schedule_to_module(s, [A, A2]) mod = tvm.transform.Sequential( [tvm.tir.transform.StorageFlatten(64), tvm.tir.transform.Simplify()] )(mod) diff --git a/tests/python/unittest/test_tir_transform_storage_rewrite.py b/tests/python/unittest/test_tir_transform_storage_rewrite.py index 9e738b136b17..5a91788283d6 100644 --- a/tests/python/unittest/test_tir_transform_storage_rewrite.py +++ b/tests/python/unittest/test_tir_transform_storage_rewrite.py @@ -16,6 +16,7 @@ # under the License. import tvm from tvm import te +from tvm.driver.build_module import schedule_to_module def test_storage_share(): @@ -28,12 +29,7 @@ def test_storage_share(): B = te.compute((m, l), lambda i, j: B[i, j] + (t + 1), name="A%d" % t) s = te.create_schedule(B.op) - bounds = tvm.te.schedule.InferBound(s) - assert isinstance(bounds, tvm.container.Map) - stmt = tvm.te.schedule.ScheduleOps(s, bounds) - - func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None) - mod = tvm.IRModule.from_expr(func) + mod = schedule_to_module(s, [A, B]) mod = tvm.tir.transform.StorageFlatten(64)(mod) mod = tvm.tir.transform.Simplify()(mod) @@ -169,12 +165,7 @@ def test_inplace_rule(): AA = te.compute((m,), lambda i: A0[i] + A1[i] + A1[0], name="AA") B = te.compute((m,), lambda i: AA[i] + 1, name="B") s = te.create_schedule(B.op) - bounds = tvm.te.schedule.InferBound(s) - assert isinstance(bounds, tvm.container.Map) - stmt = tvm.te.schedule.ScheduleOps(s, bounds) - - func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None) - mod = tvm.IRModule.from_expr(func) + mod = schedule_to_module(s, [A, B]) mod = tvm.tir.transform.StorageFlatten(64)(mod) mod = tvm.tir.transform.Simplify()(mod) @@ -206,11 +197,8 @@ def test_storage_combine(): s = te.create_schedule(B.op) for S in stages[:-1]: s[S].set_scope("global:tag") - bounds = tvm.te.schedule.InferBound(s) - assert isinstance(bounds, tvm.container.Map) - stmt = tvm.te.schedule.ScheduleOps(s, bounds) - func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None) - mod = tvm.IRModule.from_expr(func) + + mod = schedule_to_module(s, [A, B]) mod = tvm.tir.transform.StorageFlatten(64)(mod) mod = tvm.tir.transform.Simplify()(mod) @@ -238,10 +226,7 @@ def test_storage_combine_with_vectorization(): BB = s.cache_read(B, "global:tag", readers=[C]) CC = s.cache_write(C, "global:tag") s[CC].vectorize(s[CC].op.axis[0]) - bounds = tvm.te.schedule.InferBound(s) - stmt = tvm.te.schedule.ScheduleOps(s, bounds) - func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B, C], stmt, None) - mod = tvm.IRModule.from_expr(func) + mod = schedule_to_module(s, [A, B, C]) mod = tvm.tir.transform.StorageFlatten(64)(mod) mod = tvm.tir.transform.VectorizeLoop()(mod) mod = tvm.tir.transform.StorageRewrite()(mod) @@ -285,11 +270,7 @@ def test_storage_share_gpu(): s[A[2 * t + 1]].compute_at(s[A[2 * t + 2]], tx) s[A[2 * t + 1]].set_scope("shared") - bounds = tvm.te.schedule.InferBound(s) - assert isinstance(bounds, tvm.container.Map) - stmt = tvm.te.schedule.ScheduleOps(s, bounds) - func = tvm.te.schedule.SchedulePostProcToPrimFunc([A[0], A[-1]], stmt, None) - mod = tvm.IRModule.from_expr(func) + mod = schedule_to_module(s, [A[0], A[-1]]) mod = tvm.tir.transform.StorageFlatten(64)(mod) mod = tvm.tir.transform.Simplify()(mod) mod = tvm.tir.transform.StorageRewrite()(mod) @@ -418,12 +399,7 @@ def test_inplace_rule2(scope_tb="local_TB2", max_bits=1024 * 1024 * 1024): A0L = s.cache_read(A0, scope_tb, [A2]) A1L = s.cache_read(A1, scope_tb, [A2]) A2L = s.cache_read(A2, scope_tb, [B]) - bounds = tvm.te.schedule.InferBound(s) - assert isinstance(bounds, tvm.container.Map) - stmt = tvm.te.schedule.ScheduleOps(s, bounds) - - func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B, C, D], stmt, None) - mod = tvm.IRModule.from_expr(func) + mod = schedule_to_module(s, [A, B, C, D]) mod = tvm.tir.transform.StorageFlatten(64)(mod) mod = tvm.tir.transform.Simplify()(mod) @@ -511,12 +487,7 @@ def test_inplace_rule3(): s[B10].compute_inline() s = s.normalize() - bounds = tvm.te.schedule.InferBound(s) - assert isinstance(bounds, tvm.container.Map) - stmt = tvm.te.schedule.ScheduleOps(s, bounds) - - func = tvm.te.schedule.SchedulePostProcToPrimFunc([B0, B1, B2, B3, B4, B5, B], stmt, None) - mod = tvm.IRModule.from_expr(func) + mod = schedule_to_module(s, [B0, B1, B2, B3, B4, B5, B]) mod = tvm.tir.transform.StorageFlatten(64)(mod) mod = tvm.tir.transform.Simplify()(mod)