[TE] Light refactoring of TE -> TIR paths. #9263
```diff
@@ -19,7 +19,7 @@
 import tvm
 from tvm import relay
 from tvm.relay.expr_functor import ExprMutator
-from tvm.driver.build_module import get_binds
+from tvm.driver.build_module import schedule_to_module
 from .passes import ReplaceOperators, RemoveZeroStores, EncodeConstants
 from .scheduler import schedule
@@ -64,22 +64,17 @@ def lower_ethosu(sch, args, const_dict, name="main"):
             "no_unroll_loop_with_extent_one": True,
         },
         "tir.UnrollLoop": {"auto_max_depth": -1},
+        "tir.noalias": True,
+        "tir.debug_keep_trivial_loop": True,
     }
     # Merge two configs
     curr_cfg = {**curr_cfg, **tir_compiler_cfg}

     sch = sch.normalize()
-    bounds = tvm.te.schedule.InferBound(sch)
-    stmt = tvm.te.schedule.ScheduleOps(sch, bounds, True)
-    compact = tvm.te.schedule.VerifyCompactBuffer(stmt)
-    binds, arg_list = get_binds(args, compact, None)
-    func = tvm.te.schedule.SchedulePostProcToPrimFunc(arg_list, stmt, binds)
-    func = func.with_attr("global_symbol", name)
-    func = func.with_attr("tir.noalias", True)
-    mod = tvm.IRModule({name: func})
     with tvm.transform.PassContext(config=curr_cfg):
+        mod = schedule_to_module(sch, args, name)
```
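The `{**curr_cfg, **tir_compiler_cfg}` merge above gives the right-hand dict precedence on key collisions, and it is a shallow merge. A minimal sketch with made-up config values:

```python
# Plain-dict merge as used in lower_ethosu: keys in the right-hand dict win.
# The values below are made up for illustration.
curr_cfg = {"tir.UnrollLoop": {"auto_max_depth": 8}, "tir.disable_vectorize": True}
tir_compiler_cfg = {"tir.UnrollLoop": {"auto_max_depth": -1}, "tir.noalias": True}

merged = {**curr_cfg, **tir_compiler_cfg}

# "tir.UnrollLoop" is replaced wholesale (no deep merge of the nested dict),
# while keys unique to either side are preserved.
print(merged["tir.UnrollLoop"])       # {'auto_max_depth': -1}
print(merged["tir.disable_vectorize"])  # True
```

Note that nested option dicts are not merged recursively; an inner key set only in `curr_cfg` would be dropped when the whole entry is overridden.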
|
Contributor:
I like that you've removed this Python logic, but why call `schedule_to_primfunc` and then wrap the func in an IRModule? Why not use `schedule_to_module`?

Contributor (author):
Due to silliness and lack of pattern recognition on my part, since I was only looking for cases that could be replaced with …
```diff
         mod = tvm.tir.transform.Simplify()(mod)
         mod = tvm.tir.transform.StorageFlatten(64)(mod)
         mod = tvm.tir.transform.UnrollLoop()(mod)
```
```diff
@@ -44,6 +44,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_assert", Bool);
 TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_vectorize", Bool);
 TVM_REGISTER_PASS_CONFIG_OPTION("tir.is_entry_func", Bool);
 TVM_REGISTER_PASS_CONFIG_OPTION("tir.add_lower_pass", Array<Array<ObjectRef>>);
+TVM_REGISTER_PASS_CONFIG_OPTION("tir.debug_keep_trivial_loop", Bool);

 using runtime::PackedFunc;
 using runtime::TVMArgs;
@@ -287,24 +288,24 @@ IRModule ApplyPasses(IRModule mod, transform::Sequential seq) {
   return mod;
 }

+// Convert te schedule to IRModule
 IRModule ScheduleToModule(te::Schedule sch, const Array<ObjectRef>& args, const std::string& name,
                           const std::unordered_map<te::Tensor, tir::Buffer>& binds) {
-  // Convert te schedule to IRModule
-  Array<ObjectRef> out_arg_list;
-  transform::PassContext pass_ctx = transform::PassContext::Current();
   sch = sch.normalize();

+  transform::PassContext pass_ctx = transform::PassContext::Current();
+  bool debug_keep_trivial_loop =
+      pass_ctx->GetConfig<Bool>("tir.debug_keep_trivial_loop", Bool(false)).value();
```
Contributor:
What does `tir.debug_keep_trivial_loop` do?

Contributor (author):
This tells …

Edit: And also used in the autotvm feature extraction.
```diff
   // Before TIR transformation.
-  Map<tir::IterVar, Range> bounds = te::InferBound(sch);
-  tir::Stmt stmt = te::ScheduleOps(sch, std::move(bounds), false);
+  tir::Stmt stmt = te::ScheduleOps(sch, te::InferBound(sch), debug_keep_trivial_loop);
   bool compact = te::VerifyCompactBuffer(stmt);

   Map<te::Tensor, tir::Buffer> out_binds;
+  Array<ObjectRef> out_arg_list;
   GetBinds(args, compact, binds, &out_binds, &out_arg_list);

-  // Build the function
-  // At this point binds is only te::Tensors
+  // Build the function, converting from te::Tensor to tir::Buffer
   tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, std::move(stmt), out_binds);
   f = WithAttr(std::move(f), "global_symbol", runtime::String(name));
@@ -325,7 +326,7 @@ TVM_REGISTER_GLOBAL("driver.schedule_to_module")
                        const Map<te::Tensor, tir::Buffer>& binds) {
       std::unordered_map<te::Tensor, tir::Buffer> c_binds;
       // Check to make sure binds is not null before doing the conversion;
-      if (binds.get() != nullptr) {
+      if (binds.defined()) {
        for (auto kv : binds) {
          c_binds.insert({kv.first, kv.second});
        }
```
Contributor:
Why do we need `tir.debug_keep_trivial_loop`?

Contributor (author):
This deliberately keeps the loop iterators in place, even if they have extent=1, rather than the default behavior of replacing trivial iterators with a Let statement. As a result, the itervars can be examined for optimization parameters (e.g. in xgboost).

Longer term, I'd prefer having it always generate the loops, with a lowering pass to identify/simplify the trivial loops, but that's a later item.
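The difference the author describes can be sketched in TIR-flavored pseudocode (illustrative only, not actual lowered output):

```
# extent-1 axis, default lowering: the trivial iterator is substituted away
B[0] = A[0] + 1.0

# with tir.debug_keep_trivial_loop = True: the For node is preserved,
# so its itervar stays visible to feature extraction (e.g. autotvm/xgboost)
for i in range(1):
    B[i] = A[i] + 1.0
```

Both forms compute the same thing; keeping the loop only changes what structure downstream analyses can observe.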