diff --git a/docs/how_to/tutorials/e2e_opt_model.py b/docs/how_to/tutorials/e2e_opt_model.py
index 0053d309d5a9..5c11439e1635 100644
--- a/docs/how_to/tutorials/e2e_opt_model.py
+++ b/docs/how_to/tutorials/e2e_opt_model.py
@@ -101,21 +101,7 @@
 # Skip running in CI environment
 IS_IN_CI = os.getenv("CI", "") == "true"
 if not IS_IN_CI:
-    with target:
-        mod = tvm.ir.transform.Sequential(
-            [
-                # Convert BatchNorm into a sequence of simpler ops for fusion
-                relax.transform.DecomposeOpsForInference(),
-                # Canonicalize the bindings
-                relax.transform.CanonicalizeBindings(),
-                # Run default optimization pipeline
-                relax.get_pipeline("zero"),
-                # Tune the model and store the log to database
-                relax.transform.MetaScheduleTuneIRMod({}, work_dir, TOTAL_TRIALS),
-                # Apply the database
-                relax.transform.MetaScheduleApplyDatabase(work_dir),
-            ]
-        )(mod)
+    mod = relax.get_pipeline("static_shape_tuning", target=target, total_trials=TOTAL_TRIALS)(mod)
 
     # Only show the main function
     mod["main"].show()
diff --git a/python/tvm/relax/pipeline.py b/python/tvm/relax/pipeline.py
index 38242ff4d2d3..582f5111aaf5 100644
--- a/python/tvm/relax/pipeline.py
+++ b/python/tvm/relax/pipeline.py
@@ -21,6 +21,7 @@
 as it is or serves as a basis to do further composition.
 """
 # pylint: disable=unused-argument
+from typing import Union
 import tvm
 from tvm import meta_schedule as ms
 
@@ -104,10 +105,48 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I
     return _pipeline
 
 
+def static_shape_tuning_pipeline(
+    total_trials: int,
+    target: Union[str, tvm.target.Target],
+    work_dir: str = "tuning_logs",
+):
+    """Tune the static shape model and store the log to database.
+
+    Parameters
+    ----------
+    total_trials : int
+        Total number of trials to run.
+
+    target : Union[str, tvm.target.Target]
+        The target device to tune the model.
+
+    work_dir : str
+        The directory to store the tuning logs.
+    """
+
+    @tvm.transform.module_pass(opt_level=0)
+    def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.IRModule:
+        with tvm.target.Target(target):
+            mod = tvm.transform.Sequential(
+                [
+                    transform.DecomposeOpsForInference(),
+                    transform.CanonicalizeBindings(),
+                    zero_pipeline(),
+                    transform.MetaScheduleTuneIRMod({}, work_dir, total_trials),
+                    transform.MetaScheduleApplyDatabase(work_dir),
+                ]
+            )(mod)
+
+        return mod
+
+    return _pipeline
+
+
 # global map of pre-built pipelines
 PIPELINE_MAP = {
     "zero": zero_pipeline,
     "default_build": default_build_pipeline,
+    "static_shape_tuning": static_shape_tuning_pipeline,
 }
 
 
diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py
index 95649f331f33..3330d4098734 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -1020,14 +1020,13 @@ def BundleModelParams(param_tuple_name: Optional[str] = None) -> tvm.ir.transfor
 
     ----------
     param_tuple_name: Optional[str]
 
-        The name of the tuple parameter.  If unspecified, defaults to
+        The name of the tuple parameter.  If unspecified, defaults to "model_params".
 
     Returns
     -------
     ret : tvm.transform.Pass
-        The registered pass for lifting transformation of parameters.
-
+        The registered pass for bundling model parameters.
     """
     return _ffi_api.BundleModelParams(param_tuple_name)  # type: ignore