diff --git a/python/tvm/build_module.py b/python/tvm/build_module.py
index 09cc7747bc55..2dfaec00b7ba 100644
--- a/python/tvm/build_module.py
+++ b/python/tvm/build_module.py
@@ -62,6 +62,7 @@ def __exit__(self, ptype, value, trace):
         assert self._old_scope
         BuildConfig.current = self._old_scope
 
 
+BuildConfig.current = BuildConfig()
 
 def build_config(**kwargs):
@@ -102,7 +103,8 @@ def build_config(**kwargs):
         Whether split the loop containing double buffer so that
         the buffer fetching won't contain condition.
 
-    add_lower_pass: list of function(Stmt->Stmt), default=None
+    add_lower_pass: list of tuples (phase, function(Stmt->Stmt)), default=None
+        phase is an integer specifying the lowering phase at which the pass is applied.
         Additional lowering passes to be applied before make_api.
 
     Returns
@@ -193,11 +195,19 @@ def lower(sch,
     """
     binds, arg_list = get_binds(args, binds)
     cfg = BuildConfig.current
+    add_lower_pass = cfg.add_lower_pass if cfg.add_lower_pass else []
+    lower_phase0 = [x[1] for x in add_lower_pass if x[0] == 0]
+    lower_phase1 = [x[1] for x in add_lower_pass if x[0] == 1]
+    lower_phase2 = [x[1] for x in add_lower_pass if x[0] > 1]
     # normalize schedule first
     sch = sch.normalize()
+    # Phase 0
     bounds = schedule.InferBound(sch)
     stmt = schedule.ScheduleOps(sch, bounds)
     stmt = ir_pass.InjectPrefetch(stmt)
+    for f in lower_phase0:
+        stmt = f(stmt)
+    # Phase 1
     stmt = ir_pass.StorageFlatten(stmt, binds, 64)
     stmt = ir_pass.CanonicalSimplify(stmt)
     if not simple_mode:
@@ -211,13 +221,15 @@ def lower(sch,
             cfg.auto_unroll_max_step,
             cfg.auto_unroll_min_depth,
             cfg.unroll_explicit)
-    if cfg.add_lower_pass:
-        for f in cfg.add_lower_pass:
-            stmt = f(stmt)
+    for f in lower_phase1:
+        stmt = f(stmt)
+    # Phase 2
     stmt = ir_pass.Simplify(stmt)
     stmt = ir_pass.LowerStorageAccessInfo(stmt)
     stmt = ir_pass.RemoveNoOp(stmt)
     stmt = ir_pass.RewriteUnsafeSelect(stmt)
+    for f in lower_phase2:
+        stmt = f(stmt)
     if simple_mode:
         return stmt
     return ir_pass.MakeAPI(stmt, name, arg_list, 0, cfg.restricted_func)
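
For reference, a minimal usage sketch of the new tuple form of `add_lower_pass` (assuming the 0.x-era `tvm.build_config`/`tvm.lower` API touched by this patch; `print_pass` and the vector-add schedule are illustrative and not part of the change):

```python
import tvm

# Illustrative pass (not part of this patch): a function(Stmt -> Stmt)
# that prints the IR it receives and returns it unchanged.
def print_pass(stmt):
    print(stmt)
    return stmt

# A tiny schedule to lower.
n = tvm.var("n")
A = tvm.placeholder((n,), name="A")
B = tvm.compute((n,), lambda i: A[i] + 1, name="B")
s = tvm.create_schedule(B.op)

# With this patch, add_lower_pass takes (phase, pass) tuples:
#   phase 0  -> runs right after ScheduleOps/InjectPrefetch,
#   phase 1  -> runs after StorageFlatten and the loop transforms,
#   phase >1 -> runs after Simplify/RemoveNoOp/RewriteUnsafeSelect.
with tvm.build_config(add_lower_pass=[(0, print_pass), (2, print_pass)]):
    stmt = tvm.lower(s, [A, B], simple_mode=True)
```

Passes registered at phase 0 see the IR before StorageFlatten, while passes at phase >1 see the nearly final IR just before MakeAPI.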