From 323ed082ebdb51ada8041ca6d304679c767fa9e2 Mon Sep 17 00:00:00 2001 From: sung Date: Thu, 16 Feb 2023 16:58:14 -0800 Subject: [PATCH] Add TuningAPI and MetaSchedule tuning pass --- CMakeLists.txt | 1 + include/tvm/ir/transform.h | 54 +- include/tvm/relax/transform.h | 22 +- include/tvm/relax/tuning_api.h | 396 +++++++++ include/tvm/relay/transform.h | 2 +- include/tvm/tir/transform.h | 2 +- python/tvm/ir/transform.py | 95 ++- python/tvm/meta_schedule/__init__.py | 1 + python/tvm/meta_schedule/relax_integration.py | 352 ++++++++ python/tvm/meta_schedule/tir_integration.py | 89 ++ python/tvm/meta_schedule/tune_context.py | 3 +- python/tvm/relax/transform/transform.py | 69 +- .../relax/transform/tuning_api/__init__.py | 22 + .../relax/transform/tuning_api/_ffi_api.py | 19 + .../relax/transform/tuning_api/database.py | 273 ++++++ .../transform/tuning_api/default_functions.py | 306 +++++++ .../relax/transform/tuning_api/primitives.py | 419 ++++++++++ python/tvm/tir/transform/function_pass.py | 3 +- src/ir/transform.cc | 84 +- src/relax/backend/task_extraction.cc | 114 +++ src/relax/ir/transform.cc | 8 +- src/relax/transform/meta_schedule.cc | 171 ++++ src/relax/transform/tuning_api/database.cc | 350 ++++++++ src/relax/transform/tuning_api/primitives.cc | 273 ++++++ src/relay/ir/transform.cc | 4 +- src/relay/transforms/type_infer.cc | 2 +- src/tir/ir/transform.cc | 4 +- .../test_transform_meta_schedule_tuning.py | 115 +++ tests/python/relax/test_tuning_api.py | 781 ++++++++++++++++++ 29 files changed, 3987 insertions(+), 47 deletions(-) create mode 100644 include/tvm/relax/tuning_api.h create mode 100644 python/tvm/meta_schedule/relax_integration.py create mode 100644 python/tvm/relax/transform/tuning_api/__init__.py create mode 100644 python/tvm/relax/transform/tuning_api/_ffi_api.py create mode 100644 python/tvm/relax/transform/tuning_api/database.py create mode 100644 python/tvm/relax/transform/tuning_api/default_functions.py create mode 100644 
python/tvm/relax/transform/tuning_api/primitives.py create mode 100644 src/relax/backend/task_extraction.cc create mode 100644 src/relax/transform/meta_schedule.cc create mode 100644 src/relax/transform/tuning_api/database.cc create mode 100644 src/relax/transform/tuning_api/primitives.cc create mode 100644 tests/python/relax/test_transform_meta_schedule_tuning.py create mode 100644 tests/python/relax/test_tuning_api.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 5a28a9acde9c..0f154ca577c4 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -294,6 +294,7 @@ tvm_file_glob(GLOB_RECURSE COMPILER_SRCS src/relax/analysis/*.cc src/relax/transform/*.cc src/relax/backend/vm/*.cc + src/relax/backend/task_extraction.cc src/relax/utils.cc ) diff --git a/include/tvm/ir/transform.h b/include/tvm/ir/transform.h index 473e6291685d..ff54a6b5eacd 100644 --- a/include/tvm/ir/transform.h +++ b/include/tvm/ir/transform.h @@ -32,18 +32,18 @@ * - Reducing the effort required to implement new passes for compiler * developers, etc. * - * Similar to LLVM's pass manager, we designed the Relay pass manager to work + * Similar to LLVM's pass manager, we designed the Relay/Relax pass manager to work * different granularity, i.e. module level, function level, and even sequential * passe that contains a host of passes. * * However, we also extend the functionality of the traditional pass manager * with the consideration of requirements/convention from deep learning - * frameworks, such as Pytorch and Gluon, etc. Each pass in the Relay pass + * frameworks, such as Pytorch and Gluon, etc. Each pass in the Relay/Relax pass * manager performs the IRModule -> IRModule transformation. All * different types of passes, including the sequential-level pass object, are * essentially pass objects. This design, therefore, effectively provides users * a consistent and convenient interface, i.e. Pass, to play with. It offers a - * means to ease the development and testing of Relay passes. 
For example, with + * means to ease the development and testing of Relay/Relax passes. For example, with * the pass manager, external users will be able to have custom passes correctly * scheduled without having to modify a single handcrafted pass order. * @@ -90,7 +90,16 @@ class PassContextNode : public Object { /*! \brief A list of pass instrument implementations. */ Array instruments; - + // TODO(@sunggg): Fix dependency issue in the header file and correct the types + // e.g., relax::trace, relax::database in tvm/relax/tuning_api.h + /*! \brief Trace stack for relax pass infra. */ + mutable Array trace_stack; + /*! \brief List of passes to be traced. If not defined, make every pass traceable. */ + Optional> make_traceable; + /*! \brief Number of evaluations conducted in the pass pipeline. */ + mutable int num_evals{0}; + /*! \brief Database for tuning API. */ + Optional tuning_api_database; PassContextNode() = default; /*! @@ -130,7 +139,27 @@ class PassContextNode : public Object { v->Visit("instruments", &instruments); v->Visit("config", &config); v->Visit("diag_ctx", &diag_ctx); + v->Visit("trace_stack", &trace_stack); + v->Visit("make_traceable", &make_traceable); + v->Visit("num_evals", &num_evals); + v->Visit("tuning_api_database", &tuning_api_database); + } + + Array GetTraceStack() { return trace_stack; } + void PushTrace(ObjectRef new_trace) { trace_stack.push_back(new_trace); } + void PopTrace() { + ICHECK(GetTraceStackSize()) << "Trace stack is currently empty. Please double check."; + trace_stack.pop_back(); } + int GetTraceStackSize() { return trace_stack.size(); } + ObjectRef GetCurrentTrace() { + ICHECK(GetTraceStackSize()) << "Trace stack is currently empty. 
Please double check."; + return trace_stack.back(); + } + void SetNumEvals(int _num_evals) { num_evals = _num_evals; } + void IncNumEvals(int _num_evals) { num_evals += _num_evals; } + + Optional GetTuningAPIDatabase() { return tuning_api_database; } static constexpr const char* _type_key = "transform.PassContext"; static constexpr bool _type_has_method_sequal_reduce = false; @@ -287,6 +316,9 @@ class PassInfoNode : public Object { /*! \brief The name of an optimization/analysis pass. */ String name; + /*! \brief Boolean that tells whether this pass will be traced or not. */ + bool traceable; + /*! \brief The passes that are required to perform the current pass. */ Array required; @@ -296,6 +328,7 @@ class PassInfoNode : public Object { v->Visit("opt_level", &opt_level); v->Visit("name", &name); v->Visit("required", &required); + v->Visit("traceable", &traceable); } static constexpr const char* _type_key = "transform.PassInfo"; @@ -314,8 +347,9 @@ class PassInfo : public ObjectRef { * \param opt_level The optimization level * \param name Name of the pass. * \param required The passes that are required to perform the current pass. + * \param traceable Boolean that tells whether the pass is traceable. */ - TVM_DLL PassInfo(int opt_level, String name, Array required); + TVM_DLL PassInfo(int opt_level, String name, Array required, bool traceable); TVM_DEFINE_OBJECT_REF_METHODS(PassInfo, ObjectRef, PassInfoNode); }; @@ -323,7 +357,7 @@ class PassInfo : public ObjectRef { /*! * \brief PassNode is the base type of differnt types of optimization passes. * It is designed as a pure class and implemented by different pass subclasses - * at different granularity of Relay nodes. + * at different granularity of Relay/Relax nodes. */ class PassNode : public Object { public: @@ -396,7 +430,7 @@ class Pass : public ObjectRef { }; /*! 
- * \brief The SequentialNode contains a set of passes that transform Relay + * \brief The SequentialNode contains a set of passes that transform Relay/Relax * programs from one AST to another semantically equivalent one. * * One example of this level of pass is that the pass manager needs to correctly @@ -489,9 +523,9 @@ * * \return The created module pass. */ -TVM_DLL Pass -CreateModulePass(const runtime::TypedPackedFunc& pass_func, - int opt_level, String name, Array required); +TVM_DLL Pass CreateModulePass( + const runtime::TypedPackedFunc& pass_func, int opt_level, + String name, Array required, bool traceable = false); /*! * \brief A special trace pass that prints the header and IR to LOG(INFO). diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h index dab062588a82..e9f63ee9dbc9 100644 --- a/include/tvm/relax/transform.h +++ b/include/tvm/relax/transform.h @@ -44,12 +44,13 @@ using DataflowBlock = tvm::relax::DataflowBlock; * \param opt_level The optimization level of the function pass. * \param name The name of the function pass. * \param required The list of the passes that the function pass is dependent on. + * \param traceable Boolean variable whether the function pass is traceable. * * \return The created function pass. */ TVM_DLL Pass CreateFunctionPass( const runtime::TypedPackedFunc& pass_func, - int opt_level, String name, tvm::Array required); + int opt_level, String name, tvm::Array required, bool traceable = false); /*! * \brief Create a dataflowblock pass. * @@ -58,12 +59,13 @@ TVM_DLL Pass CreateFunctionPass( * \param opt_level The optimization level of the dataflowblock pass. * \param name The name of the dataflowblock pass. * \param required The list of the passes that the dataflowblock pass is dependent on. + * \param traceable Boolean variable whether the dataflowblock pass is traceable. * * \return The created dataflowblock pass. 
*/ TVM_DLL Pass CreateDataflowBlockPass( const runtime::TypedPackedFunc& pass_func, - int opt_level, String name, tvm::Array required); + int opt_level, String name, tvm::Array required, bool traceable = false); /*! * \brief Transform all dataflow structure to non-dataflow version. @@ -93,6 +95,22 @@ TVM_DLL Pass CallTIRRewrite(); */ TVM_DLL Pass RewriteDataflowReshape(); +/*! + * \brief Bind params of function of the module to constant tensors. + * + * \param func_name The name of the function to bind parameters. + * \param params The parameters to bind. + * + * \return The Pass. + */ +TVM_DLL Pass BindParams(String func_name, Map params); + +/*! + * \brief Fold constant expressions. + * + * \return The Pass. + */ +TVM_DLL Pass FoldConstant(); /*! * \brief Attach global_symbol to Relax functions and TIR Primfuncs for codegen. * diff --git a/include/tvm/relax/tuning_api.h b/include/tvm/relax/tuning_api.h new file mode 100644 index 000000000000..b6224a6d6d9e --- /dev/null +++ b/include/tvm/relax/tuning_api.h @@ -0,0 +1,396 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/tuning_api.h + * \brief Relax Tuning Pass APIs. 
+ */ +#ifndef TVM_RELAX_TUNING_API_H_ +#define TVM_RELAX_TUNING_API_H_ +#include +#include +#include + +#include +namespace tvm { +namespace relax { + +/*! \brief Helper function to unpack arguments in the array as parameters for the given packed + * function. */ +TVM_ALWAYS_INLINE TVMRetValue CallPackedWithArgsInArray(const runtime::PackedFunc f, + const Array& args) { + size_t num_args = args.size(); + std::vector values(num_args); + std::vector codes(num_args); + runtime::TVMArgsSetter setter(values.data(), codes.data()); + const ObjectRef* ptr = args.template as()->begin(); + for (size_t i = 0; i < num_args; ++i) { + setter(i, *(ptr + i)); + } + + TVMRetValue rv; + f.CallPacked(TVMArgs(values.data(), codes.data(), num_args), &rv); + return rv; +} + +/*! \brief Choice manages a set of keys for transformation and constraint functions. */ +class ChoiceNode : public runtime::Object { + public: + /*! \brief ffi key for transformation function. */ + String transform_func_key; + /*! \brief ffi key for constraint function. */ + String constr_func_key; + Array transform_func_args; + Array constr_func_args; + + /*! \brief The default destructor. */ + virtual ~ChoiceNode() = default; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("transform_func_key", &transform_func_key); + v->Visit("transform_func_args", &transform_func_args); + v->Visit("constr_func_key", &constr_func_key); + v->Visit("constr_func_args", &constr_func_args); + } + + /*! \brief Getter for constr_func. */ + const runtime::PackedFunc GetConstrFunc() { + const auto* constr_func = tvm::runtime::Registry::Get(constr_func_key); + ICHECK(constr_func != nullptr) << "constr_func_key is not registered: " << constr_func_key; + return *constr_func; + } + + /*! \brief Getter for transform_func. 
*/ + const runtime::PackedFunc GetTransformFunc() { + auto* transform_func = tvm::runtime::Registry::Get(transform_func_key); + ICHECK(transform_func != nullptr) + << "transform_func_key is not registered: " << transform_func_key; + return *transform_func; + } + + /*! \brief Perform constr_func. */ + bool CheckConstr(const IRModule& mod) { + Array args(constr_func_args); + args.insert(args.begin(), mod); + return CallPackedWithArgsInArray(GetConstrFunc(), args); + } + + /*! \brief Perform transform_func. */ + IRModule ApplyTransformFunc(IRModule mod) { + // Apply transformation when constraint is satisfied. + if (CheckConstr(mod)) { + Array args(transform_func_args); + args.insert(args.begin(), GetRef(mod.CopyOnWrite())); + return CallPackedWithArgsInArray(GetTransformFunc(), args); + } + return mod; + } + + /*! + * \brief Serialize Choice as a JSON-style object + * \return The JSON-style object + */ + ObjectRef AsJSON() const; + + static constexpr const char* _type_key = "relax.tuning_api.Choice"; + TVM_DECLARE_BASE_OBJECT_INFO(ChoiceNode, Object); +}; + +/*! \brief Managed reference to ChoiceNode */ +class Choice : public runtime::ObjectRef { + public: + TVM_DLL explicit Choice(String transform_func_key, Array transform_func_args, + String constr_func_key, Array constr_func_args); + /*! \brief Deserialize JSON-style object into Choice */ + TVM_DLL static Choice FromJSON(const ObjectRef& json_obj); + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Choice, ObjectRef, ChoiceNode); +}; + +/*! \brief Knob manages a set of valid choices for an optimization. */ +class KnobNode : public runtime::Object { + public: + /*! \brief Name of the knob. */ + String name; + /*! \brief Decision space. */ + Map choices; + + /*! \brief The default destructor. */ + virtual ~KnobNode() = default; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("name", &name); + v->Visit("choices", &choices); + } + + /*! \brief Check if a decision is valid. 
*/ + bool IsValidDecision(String decision) { return choices.count(decision) > 0; } + + /*! \brief Apply decision if the constraint is satisfied. + Otherwise, return the original IRModule. + */ + IRModule Apply(IRModule mod, String decision) { + ICHECK(IsValidDecision(decision)) << "Invalid choice for this knob: " << decision; + return choices[decision]->ApplyTransformFunc(mod); + } + + /*! + * \brief Serialize Knob as a JSON-style object + * \return The JSON-style object + */ + ObjectRef AsJSON() const; + + static constexpr const char* _type_key = "relax.tuning_api.Knob"; + TVM_DECLARE_BASE_OBJECT_INFO(KnobNode, Object); +}; + +/*! \brief Managed reference to KnobNode */ +class Knob : public runtime::ObjectRef { + public: + TVM_DLL explicit Knob(String name, Map choices); + /*! \brief Deserialize JSON-style object into Knob */ + TVM_DLL static Knob FromJSON(const ObjectRef& json_obj); + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Knob, ObjectRef, KnobNode); +}; + +/*! \brief Trace manages history of optimization decisions. */ +class TraceNode : public runtime::Object { + public: + /*! \brief Input IRModule. */ + IRModule in_mod; + /*! \brief Output IRModule. */ + mutable IRModule out_mod; + // TODO(sunggg): can we move knobs and decisions into private? + /*! \brief Knobs that are applied so far. */ + Array knobs; + /*! \brief Decisions made for the knobs. */ + Array decisions; + /*! \brief Performance of out_mod. */ + mutable double perf = -1; + /*! \brief Length of the decision history. */ + mutable int size = 0; + /*! \brief The default destructor. */ + virtual ~TraceNode() = default; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("in_mod", &in_mod); + v->Visit("out_mod", &out_mod); + v->Visit("knobs", &knobs); + v->Visit("decisions", &decisions); + v->Visit("perf", &perf); + v->Visit("size", &size); + } + + /*! \brief Verify current decision history. 
*/ + bool Verify() const { + if (knobs.size() != decisions.size()) return false; + int n = knobs.size(); + for (int i = 0; i < n; i++) { + if (!knobs[i]->IsValidDecision(decisions[i])) return false; + } + return true; + } + + /*! \brief Add a knob and its decision to the current trace. */ + IRModule Add(Knob knob, String decision) { + out_mod = knob->Apply(out_mod, decision); + knobs.push_back(knob); + decisions.push_back(decision); + // perf number should be initialized after new decision is applied. + perf = -1; + // increment history size. + size++; + return out_mod; + } + + /*! + * \brief Serialize Trace as a JSON-style object + * \param include_in_mod Boolean config to include input IRModule in the output. + * \return The JSON-style object + */ + ObjectRef AsJSON(bool include_in_mod = true) const; + + /*! \brief Set the performance. */ + void SetPerf(double _perf) { perf = _perf; } + /*! \brief Set output module. */ + void SetOutMod(IRModule mod_) { out_mod = mod_; } + + static constexpr const char* _type_key = "relax.tuning_api.Trace"; + TVM_DECLARE_BASE_OBJECT_INFO(TraceNode, Object); +}; + +/*! \brief Managed reference to TraceNode */ +class Trace : public runtime::ObjectRef { + public: + /*! \brief Default constructor. Creating an empty trace. */ + Trace(); + /*! + * \brief Constructor. Creating a trace from existing knobs and their decisions + * \param in_mod Input IRModule + * \param knobs The knobs used + * \param decisions The decisions made in sampling + */ + TVM_DLL explicit Trace(IRModule in_mod, Array knobs, Array decisions); + /*! \brief Deserialize JSON-style object into Trace */ + TVM_DLL static Trace FromJSON(const ObjectRef& json_obj); + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Trace, ObjectRef, TraceNode); +}; + +/*! \brief The class of tuning records. */ +class TuningRecordNode : public runtime::Object { + public: + /*! \brief The trace tuned. */ + Trace trace; + /*! \brief The measurement record in seconds. 
*/ + Optional> run_secs; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("trace", &trace); + v->Visit("run_secs", &run_secs); + } + + static constexpr const char* _type_key = "relax.tuning_api.TuningRecord"; + TVM_DECLARE_FINAL_OBJECT_INFO(TuningRecordNode, runtime::Object); + + /*! + * \brief Export the tuning record to a JSON string. + * \param include_irmod Boolean config to include IRModules in the output. + * \return JSON object + */ + ObjectRef AsJSON(bool include_irmod = false) const; +}; + +/*! + * \brief The managed reference of TuningRecordNode. + * \sa TuningRecordNode + */ +class TuningRecord : public runtime::ObjectRef { + public: + /*! + \brief Constructor of a tuning record. + \param trace The trace of the tuning record. + \param run_secs The running time of the tuning record. + */ + TVM_DLL explicit TuningRecord(Trace trace, Optional> run_secs); + /*! + * \brief Create a tuning record from a json object. + * \param json_obj The json object. + * \return The tuning record created. + */ + TVM_DLL static TuningRecord FromJSON(const ObjectRef& json_obj); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TuningRecord, runtime::ObjectRef, TuningRecordNode); +}; + +/*! \brief The equality check for Workload */ +struct WorkloadEqual { + bool operator()(const meta_schedule::Workload& a, const meta_schedule::Workload& b) const { + return a->shash == b->shash && tvm::StructuralEqual()(a->mod, b->mod); + } +}; + +/* \brief The abstract interface of database. */ +class DatabaseNode : public runtime::Object { + public: + /*! \brief Default destructor */ + virtual ~DatabaseNode() = default; + /*! + * \brief Check if the database has the given workload. + * \param mod The IRModule to be searched for. + * \return Whether the database has the given workload. + */ + virtual bool HasWorkload(const IRModule& mod) = 0; + /*! + * \brief Check if the database has a measurement record for the given workload and target pair. 
+ * \param workload The workload to be searched for. + * \param target The target to be searched for. + * \return Whether the database has the measurement record for given workload and target pair. + */ + virtual bool HasMeasurementRecord(const meta_schedule::Workload& workload, + const Target& target) = 0; + /*! + * \brief Check if the database has a tuning record for the given workload and target pair. + * \param workload The workload to be searched for. + * \param target The target to be searched for. + * \return Whether the database has the tuning record for the given workload and target pair. + */ + virtual bool HasTuningRecord(const meta_schedule::Workload& workload, const Target& target) = 0; + /*! + * \brief Look up or add workload to the database if missing. + * \param mod The IRModule to be searched for or added. + * \return The workload corresponding to the given IRModule. + */ + virtual meta_schedule::Workload CommitWorkload(const IRModule& mod) = 0; + /*! + * \brief Add a measurement record for a given pair of target and workload to the database. + * \param workload Workload to be searched for. + * \param target Target to be searched for. + * \param record Measurement record to be added. + */ + virtual void CommitMeasurementRecord(const meta_schedule::Workload& workload, + const Target& target, const Array& record) = 0; + /*! + * \brief Add a tuning record for a given pair of target and workload to the database. + * \param workload Workload to be searched for. + * \param target Target to be searched for. + * \param record Tuning record to be added. + */ + virtual void CommitTuningRecord(const meta_schedule::Workload& workload, const Target& target, + const TuningRecord& record) = 0; + /*! + * \brief Get the top K tuning records of given workload and target from the database. + * \param workload The workload to be searched for. + * \param target Target to be searched for. + * \param top_k The number of top records to be returned. 
+ * \return An array of top K tuning records for the given workload. + */ + virtual Array GetTopK(const meta_schedule::Workload& workload, const Target& target, + int top_k) = 0; + /*! + * \brief Get the measurement record of given workload and target from the database. + * \param workload The workload to be searched for. + * \param target Target to be searched for. + * \return Measurement. + */ + virtual Array GetMeasurementRecord(const meta_schedule::Workload& workload, + const Target target) = 0; + + static constexpr const char* _type_key = "relax.tuning_api.Database"; + TVM_DECLARE_BASE_OBJECT_INFO(DatabaseNode, runtime::Object); +}; + +/*! + * \brief Managed reference to DatabaseNode. + * \sa DatabaseNode + */ +class Database : public runtime::ObjectRef { + public: + /*! + * \brief Create a default database that uses JSON file for tuning records. + * \param path_workload The path to the workload table. + * \param path_tuning_record The path to the tuning record table. + * \param path_measurement_record The path to the measurement_record table. + * \param allow_missing Whether to create new file when the given path is not found. + */ + TVM_DLL static Database JSONDatabase(String path_workload, String path_tuning_record, + String path_measurement_record, bool allow_missing); + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Database, runtime::ObjectRef, DatabaseNode); +}; + +} // namespace relax +} // namespace tvm +#endif // TVM_RELAX_TUNING_API_H_ diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index 43a0f89d95c1..256f1a64dd87 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -60,7 +60,7 @@ using Sequential = tvm::transform::Sequential; */ TVM_DLL Pass CreateFunctionPass( const runtime::TypedPackedFunc& pass_func, - int opt_level, String name, tvm::Array required); + int opt_level, String name, tvm::Array required, bool traceable = false); /*! 
\brief Remove let-bound expressions which do not effect the program result. * diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index be7589b04bf5..cbfb2b1ade19 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -56,7 +56,7 @@ using tvm::transform::Sequential; */ TVM_DLL Pass CreatePrimFuncPass( const runtime::TypedPackedFunc& pass_func, - int opt_level, String name, tvm::Array required); + int opt_level, String name, tvm::Array required, bool traceable = false); /*! * \brief Inject prefetch instructions into stmt. diff --git a/python/tvm/ir/transform.py b/python/tvm/ir/transform.py index 17995bfa7850..21f5d41d862a 100644 --- a/python/tvm/ir/transform.py +++ b/python/tvm/ir/transform.py @@ -45,8 +45,10 @@ class PassInfo(tvm.runtime.Object): The list of passes that are required by a certain pass. """ - def __init__(self, opt_level, name, required=None): - self.__init_handle_by_constructor__(_ffi_transform_api.PassInfo, opt_level, name, required) + def __init__(self, opt_level, name, required=None, traceable=False): + self.__init_handle_by_constructor__( + _ffi_transform_api.PassInfo, opt_level, name, required, traceable + ) @tvm._ffi.register_object("transform.PassContext") @@ -70,6 +72,20 @@ class PassContext(tvm.runtime.Object): config : Optional[Dict[str, Object]] Additional configurations for specific passes. + + trace: Optional[relax.tuning.Trace] + Initial trace for trace mode. + + trace_stack: Optional[List[relax.tuning_api.Trace]] + Initial trace stack for trace mode. + + make_traceable: Optional[List[str]] + List of passes to make traceable. + + num_evals: int + initial number of evaluations conducted in the pipeline. 
+ + tuning_api_database: Optional[relax.tuning_api.JSONDatabase] """ def __init__( @@ -79,6 +95,11 @@ def __init__( disabled_pass=None, instruments=None, config=None, + trace=None, + trace_stack=None, + make_traceable=None, + num_evals=0, + tuning_api_database=None, ): required = list(required_pass) if required_pass else [] if not isinstance(required, (list, tuple)): @@ -92,9 +113,25 @@ def __init__( if not isinstance(instruments, (list, tuple)): raise TypeError("instruments is expected to be the type of " + "list/tuple/set.") + # Convert to Map + # TODO(sunggg): Replace this to Set equivalent if exists + make_traceable = {name: True for name in make_traceable} if make_traceable else None + + if not trace_stack: + trace_stack = [trace] if trace else [] + config = config if config else None self.__init_handle_by_constructor__( - _ffi_transform_api.PassContext, opt_level, required, disabled, instruments, config + _ffi_transform_api.PassContext, + opt_level, + required, + disabled, + instruments, + config, + trace_stack, + make_traceable, + num_evals, + tuning_api_database, ) def __enter__(self): @@ -131,6 +168,47 @@ def list_configs(): """ return _ffi_transform_api.ListConfigs() + def push_trace(self, trace): + """Push a trace into the stack.""" + return _ffi_transform_api.PushTrace(self, trace) + + def pop_trace(self, return_current=True): + """Pop a topmost trace from the stack. 
+ Returns + ------- + Trace : Optional[relax.tuning.Trace] + """ + if return_current: + cur_trace = self.get_current_trace() + _ffi_transform_api.PopTrace(self) + return cur_trace + + return _ffi_transform_api.PopTrace(self) + + def get_trace_stack(self): + """Get the current trace stack.""" + return _ffi_transform_api.GetTraceStack(self) + + def get_trace_stack_size(self): + """Get the size of current stack.""" + return _ffi_transform_api.GetTraceStackSize(self) + + def get_current_trace(self): + """Get the trace on the top of the stack.""" + return _ffi_transform_api.GetCurrentTrace(self) + + def set_num_evals(self, num: int): + """Set the number of evaluations conducted in the pipeline.""" + return _ffi_transform_api.SetNumEvals(self, num) + + def inc_num_evals(self, num: int): + """Increment the number of evaluations conducted in the pipeline.""" + return _ffi_transform_api.IncNumEvals(self, num) + + def get_tuning_api_database(self): + """Get tuning api database.""" + return _ffi_transform_api.GetTuningAPIDatabase(self) + @tvm._ffi.register_object("transform.Pass") class Pass(tvm.runtime.Object): @@ -199,7 +277,7 @@ class Sequential(Pass): The list of passes that the sequential pass is dependent on. 
""" - def __init__(self, passes=None, opt_level=0, name="sequential", required=None): + def __init__(self, passes=None, opt_level=0, name="sequential", required=None, traceable=False): passes = passes if passes else [] if not isinstance(passes, (list, tuple)): raise TypeError("passes must be a list of Pass objects.") @@ -209,7 +287,7 @@ def __init__(self, passes=None, opt_level=0, name="sequential", required=None): raise TypeError("Required is expected to be the type of list/tuple.") self.__init_handle_by_constructor__( - _ffi_transform_api.Sequential, passes, opt_level, name, required + _ffi_transform_api.Sequential, passes, opt_level, name, required, traceable ) @@ -245,7 +323,7 @@ def __getattr__(self, name): return PyModulePass -def module_pass(pass_func=None, opt_level=None, name=None, required=None): +def module_pass(pass_func=None, opt_level=None, name=None, required=None, traceable=False): """Decorate a module pass. This function returns a callback when pass_func is provided. @@ -270,6 +348,9 @@ def module_pass(pass_func=None, opt_level=None, name=None, required=None): required : Optional[List[str]] The list of passes that the module pass is dependent on. 
+ traceable: Boolean + Boolean variable whether the module pass is traceable + Returns ------- create_module_pass : Union[Callable, ModulePass] @@ -337,7 +418,7 @@ def transform(mod, ctx): def create_module_pass(pass_arg): """Internal function that creates a module pass""" fname = name if name else pass_arg.__name__ - info = PassInfo(opt_level, fname, required) + info = PassInfo(opt_level, fname, required, traceable) if inspect.isclass(pass_arg): return _wrap_class_module_pass(pass_arg, info) if not isinstance(pass_arg, (types.FunctionType, types.LambdaType)): diff --git a/python/tvm/meta_schedule/__init__.py b/python/tvm/meta_schedule/__init__.py index 30a4fc6d9467..21a11ff9e84d 100644 --- a/python/tvm/meta_schedule/__init__.py +++ b/python/tvm/meta_schedule/__init__.py @@ -25,6 +25,7 @@ mutator, postproc, relay_integration, + relax_integration, runner, schedule, schedule_rule, diff --git a/python/tvm/meta_schedule/relax_integration.py b/python/tvm/meta_schedule/relax_integration.py new file mode 100644 index 000000000000..a82d8996858b --- /dev/null +++ b/python/tvm/meta_schedule/relax_integration.py @@ -0,0 +1,352 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""Meta schedule integration with high-level IR""" +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union + +# isort: off +from typing_extensions import Literal + +# isort: on + +from tvm._ffi import get_global_func, register_func +from tvm.ir import IRModule +from tvm.ir.transform import PassContext +from tvm.runtime import NDArray +from tvm.target import Target +from tvm.tir.expr import IntImm + +from .builder import Builder +from .cost_model import CostModel +from .database import Database +from .extracted_task import ExtractedTask +from .logging import get_loggers_from_work_dir +from .measure_callback import MeasureCallback +from .runner import Runner +from .search_strategy import SearchStrategy +from .space_generator import SpaceGenerator +from .task_scheduler import TaskScheduler +from .tune import tune_tasks +from .tune_context import TuneContext +from .utils import fork_seed + +if TYPE_CHECKING: + from tvm import relax + +_extract_task_func = get_global_func( # pylint: disable=invalid-name + "relax.backend.MetaScheduleExtractTask", + allow_missing=False, +) + + +def extract_tasks( + mod: Union[IRModule, "relax.Function"], + target: Target, + params: Optional[Dict[str, NDArray]] = None, +) -> List[ExtractedTask]: + """Extract tuning tasks from a relax program. 
+ + Parameters + ---------- + mod : Union[IRModule, relax.Function] + The module or function to tune + target : tvm.target.Target + The compilation target + + Returns + ------- + tasks: List[ExtractedTask] + The tasks extracted from this module + """ + # pylint: disable=import-outside-toplevel + from tvm.relax.expr import Function as RelaxFunc + from tvm.relax.transform import BindParams + + # pylint: enable=import-outside-toplevel + if isinstance(mod, RelaxFunc): + mod = IRModule({"main": mod}) + if not isinstance(target, Target): + target = Target(target) + if params: + mod = BindParams("main", params)(mod) + return list(_extract_task_func(mod, target)) + + +def extracted_tasks_to_tune_contexts( + extracted_tasks: List[ExtractedTask], + work_dir: str, + space: SpaceGenerator.SpaceGeneratorType = "post-order-apply", + strategy: SearchStrategy.SearchStrategyType = "evolutionary", + num_threads: Union[Literal["physical", "logical"], int] = "physical", + seed: Optional[int] = None, +) -> Tuple[List[TuneContext], List[float]]: + """Convert ExtractedTask to TuneContext. + + Parameters + ---------- + tasks : List[ExtractedTask] + The tasks to be converted + work_dir : str + The working directory to store logs and databases + space : SpaceGenerator.SpaceGeneratorType + The space generator to use. + strategy : SearchStrategy.SearchStrategyType + The search strategy to use. + num_threads : Union[Literal["physical", "logical"], int] + The number of threads to use in multi-threaded search algorithm. + seed : Optional[int] + The random seed to use. 
+ + Returns + ------- + tasks : List[TuneContext] + The converted tasks + task_weights : List[float] + The weights of the tasks + """ + tasks: List[TuneContext] = [] + task_weights: List[float] = [] + for task, logger, rand_state in zip( + extracted_tasks, + get_loggers_from_work_dir(work_dir, [t.task_name for t in extracted_tasks]), + fork_seed(seed, n=len(extracted_tasks)), + ): + tasks.append( + TuneContext( + mod=task.dispatched[0], + target=task.target, + space_generator=space, + search_strategy=strategy, + task_name=task.task_name, + logger=logger, + rand_state=rand_state, + num_threads=num_threads, + ).clone() + ) + task_weights.append(task.weight) + return tasks, task_weights + + +def tune_relax( + mod: Union[IRModule, "relax.Function"], + params: Dict[str, NDArray], + target: Union[str, Target], + work_dir: str, + max_trials_global: int, + *, + max_trials_per_task: Optional[int] = None, + num_trials_per_iter: int = 64, + builder: Builder.BuilderType = "local", + runner: Runner.RunnerType = "local", + database: Database.DatabaseType = "json", + cost_model: CostModel.CostModelType = "xgb", + measure_callbacks: MeasureCallback.CallbackListType = "default", + task_scheduler: TaskScheduler.TaskSchedulerType = "gradient", + space: SpaceGenerator.SpaceGeneratorType = "post-order-apply", + strategy: SearchStrategy.SearchStrategyType = "evolutionary", + seed: Optional[int] = None, +) -> Database: + """Tune a Relax program. 
+ + Parameters + ---------- + mod : Union[IRModule, relax.Function] + The module or function to tune + params : Optional[Dict[str, tvm.runtime.NDArray]] + The associated parameters of the program + target : Union[Target, str] + The compilation target + work_dir : str + The working directory to store the tuning records + max_trials_global : int + The maximum number of trials to run + max_trials_per_task : Optional[int] + The maximum number of trials to run for each task + num_trials_per_iter : int + The number of trials to run per iteration + builder : BuilderType + The builder to use + runner : RunnerType + The runner to use + database : DatabaseType + The database to use + cost_model : CostModelType + The cost model to use + measure_callbacks : CallbackListType + The measure callbacks to use + task_scheduler : TaskSchedulerType + The task scheduler to use + space : SpaceGeneratorType + The space generator to use + strategy : SearchStrategyType + The search strategy to use + seed : Optional[int] + The random seed + + Returns + ------- + database : Database + The database that contains the tuning records + """ + tasks, task_weights = extracted_tasks_to_tune_contexts( + extracted_tasks=extract_tasks(mod, target, params), + work_dir=work_dir, + space=space, + strategy=strategy, + seed=seed, + ) + return tune_tasks( + tasks=tasks, + task_weights=task_weights, + work_dir=work_dir, + max_trials_global=max_trials_global, + max_trials_per_task=max_trials_per_task, + num_trials_per_iter=num_trials_per_iter, + builder=builder, + runner=runner, + database=database, + cost_model=cost_model, + measure_callbacks=measure_callbacks, + task_scheduler=task_scheduler, + ) + + +@register_func("tvm.meta_schedule.tune_relax") +def _tune_relax( + mod: Union[IRModule, "relax.Function"], + params: Dict[str, NDArray], + target: Union[str, Target], + work_dir: str, + max_trials_global: int, + *, + max_trials_per_task: Optional[int] = None, + num_trials_per_iter: int = 64, + builder: 
Builder.BuilderType = "local", + runner: Runner.RunnerType = "local", + database: Database.DatabaseType = "json", + cost_model: CostModel.CostModelType = "xgb", + measure_callbacks: MeasureCallback.CallbackListType = "default", + task_scheduler: TaskScheduler.TaskSchedulerType = "gradient", + space: SpaceGenerator.SpaceGeneratorType = "post-order-apply", + strategy: SearchStrategy.SearchStrategyType = "evolutionary", + seed: Optional[int] = None, +) -> Database: + """Interface with tuning api to tune a Relax program. + + Parameters + ---------- + mod : Union[IRModule, relax.Function] + The module or function to tune + params : Optional[Dict[str, tvm.runtime.NDArray]] + The associated parameters of the program + target : Union[Target, str] + The compilation target + work_dir : str + The working directory to store the tuning records + max_trials_global : int + The maximum number of trials to run + max_trials_per_task : Optional[int] + The maximum number of trials to run for each task + num_trials_per_iter : int + The number of trials to run per iteration + builder : BuilderType + The builder to use + runner : RunnerType + The runner to use + database : DatabaseType + The database to use + cost_model : CostModelType + The cost model to use + measure_callbacks : CallbackListType + The measure callbacks to use + task_scheduler : TaskSchedulerType + The task scheduler to use + space : SpaceGeneratorType + The space generator to use + strategy : SearchStrategyType + The search strategy to use + seed : Optional[int] + The random seed + + Returns + ------- + ret_mod : IRModule + IRModule + """ + if isinstance(max_trials_global, IntImm): + max_trials_global = int(max_trials_global) + + tune_relax( + mod, + params, + target, + work_dir, + max_trials_global, + max_trials_per_task=max_trials_per_task, + num_trials_per_iter=num_trials_per_iter, + builder=builder, + runner=runner, + database=database, + cost_model=cost_model, + measure_callbacks=measure_callbacks, + 
task_scheduler=task_scheduler, + space=space, + strategy=strategy, + seed=seed, + ) + # Return original IRModule + # This pass only makes optimization decision + return mod + + +def compile_relax( + database: Database, + mod: IRModule, + target: Union[Target, str], + params: Optional[Dict[str, NDArray]], +) -> "relax.vm.Executable": + """Compile a relax program with a MetaSchedule database. + + Parameters + ---------- + database : Database + The database to use + mod : IRModule + The Relax program to be compiled + target : tvm.target.Target + The compilation target + params : Optional[Dict[str, tvm.runtime.NDArray]] + The associated parameters of the program + + Returns + ------- + lib : relax.vm.Executable + The built runtime module or vm Executable for the given relax workload. + """ + # pylint: disable=import-outside-toplevel + from tvm.relax.transform import BindParams, MetaScheduleApplyDatabase + from tvm.relax.vm import build as relax_build + + # pylint: enable=import-outside-toplevel + if not isinstance(target, Target): + target = Target(target) + if params: + mod = BindParams("main", params)(mod) + + with target, database, PassContext(opt_level=3): + relax_mod = MetaScheduleApplyDatabase()(mod) + relax_ex = relax_build(relax_mod, target=target) + return relax_ex diff --git a/python/tvm/meta_schedule/tir_integration.py b/python/tvm/meta_schedule/tir_integration.py index f3d505c28b0e..d5f5ee86e0b8 100644 --- a/python/tvm/meta_schedule/tir_integration.py +++ b/python/tvm/meta_schedule/tir_integration.py @@ -22,7 +22,9 @@ # isort: on from tvm import ir, tir +from tvm._ffi import register_func from tvm.target import Target +from tvm.tir.expr import IntImm from .builder import Builder from .cost_model import CostModel @@ -128,6 +130,93 @@ def tune_tir( ) +@register_func("tvm.meta_schedule.tune_tir") +def _tune_tir( + mod: Union[ir.IRModule, tir.PrimFunc], + target: Union[str, Target], + work_dir: str, + max_trials_global: int, + *, + num_trials_per_iter: int = 
64, + builder: Builder.BuilderType = "local", + runner: Runner.RunnerType = "local", + database: Database.DatabaseType = "json", + cost_model: CostModel.CostModelType = "xgb", + measure_callbacks: MeasureCallback.CallbackListType = "default", + task_scheduler: TaskScheduler.TaskSchedulerType = "round-robin", + space: SpaceGenerator.SpaceGeneratorType = "post-order-apply", + strategy: SearchStrategy.SearchStrategyType = "evolutionary", + task_name: str = "main", + num_tuning_cores: Union[Literal["physical", "logical"], int] = "physical", + seed: Optional[int] = None, +) -> Database: + """Interface with tuning api to tune a TIR program. + + Parameters + ---------- + mod : Union[ir.IRModule, tir.PrimFunc] + The TIR function to tune. + target : Union[str, Target] + The target to tune for. + work_dir : str + The working directory. + max_trials_global : int + The maximum number of trials to run globally. + num_trials_per_iter : int + The number of trials to run per iteration + builder : Builder.BuilderType + The builder. + runner : Runner.RunnerType + The runner. + database : Database.DatabaseType + The database. + cost_model : CostModel.CostModelType + The cost model. + measure_callbacks : MeasureCallback.CallbackListType + The measure callbacks. + task_scheduler : TaskScheduler.TaskSchedulerType + The task scheduler. + space : SpaceGenerator.SpaceGeneratorType + The space generator. + strategy : SearchStrategy.SearchStrategyType + The search strategy. + task_name : str + The name of the task. + num_tuning_cores : Union[Literal["physical", "logical"], int] + The number of CPU cores to use during tuning. + seed : Optional[int] + The seed for the random number generator. 
+ + Returns + ------- + ret_mod : IRModule + IRModule + """ + if isinstance(max_trials_global, IntImm): + max_trials_global = int(max_trials_global) + tune_tir( + mod, + target, + work_dir, + max_trials_global, + num_trials_per_iter=num_trials_per_iter, + builder=builder, + runner=runner, + database=database, + cost_model=cost_model, + measure_callbacks=measure_callbacks, + task_scheduler=task_scheduler, + space=space, + strategy=strategy, + task_name=task_name, + num_tuning_cores=num_tuning_cores, + seed=seed, + ) + # Return original IRModule + # This pass only makes optimization decision + return mod + + def compile_tir( database: Database, mod: Union[ir.IRModule, tir.PrimFunc], diff --git a/python/tvm/meta_schedule/tune_context.py b/python/tvm/meta_schedule/tune_context.py index 38a46ebe757e..8c4f4ce864ab 100644 --- a/python/tvm/meta_schedule/tune_context.py +++ b/python/tvm/meta_schedule/tune_context.py @@ -24,7 +24,7 @@ # isort: on from tvm import IRModule -from tvm._ffi import register_object +from tvm._ffi import register_object, register_func from tvm.runtime import Object from tvm.target import Target from tvm.tir import PrimFunc, Schedule @@ -41,6 +41,7 @@ from .space_generator import SpaceGenerator +@register_func("tvm.meta_schedule.normalize_mod") def _normalize_mod(mod: Union[PrimFunc, IRModule]) -> IRModule: """Normalize the input to an IRModule""" if isinstance(mod, PrimFunc): diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index 745a26a4dac4..c0ac180ff165 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -21,8 +21,8 @@ import types from typing import Callable, Dict, Union, Optional, List import numpy as np # type: ignore - import tvm.ir +from tvm.runtime import NDArray from . 
import _ffi_api @@ -218,6 +218,60 @@ def FuseTIR() -> tvm.ir.transform.Pass: return _ffi_api.FuseTIR() # type: ignore +def MetaScheduleApplyDatabase( + work_dir: Optional[str] = None, +) -> tvm.ir.transform.Pass: + """Apply the best schedule from tuning database. + work_dir : Optional[str] + work directory to deduce default database if database is not provided + (it will be ignored when a user passes database) + Returns + ------- + ret : tvm.transform.Pass + The registered pass + """ + return _ffi_api.MetaScheduleApplyDatabase(work_dir) # type: ignore + + +def MetaScheduleTuneTIR( + work_dir: str, + max_trials_global: int, +) -> tvm.ir.transform.Pass: + """Tune TIR with MetaSchedule. + Parameters + ---------- + work_dir: str + work directory + max_trials_global: int + maximum number of total trials allowed for tuning + Returns + ------- + ret: tvm.ir.transform.Pass + """ + return _ffi_api.MetaScheduleTuneTIR(work_dir, max_trials_global) # type: ignore + + +def MetaScheduleTuneIRMod( + params: Dict[str, NDArray], + work_dir: str, + max_trials_global: int, +) -> tvm.ir.transform.Pass: + """Tune Relax IRModule with MetaSchedule. + Parameters + ---------- + params: Dict[str, NDArray] + model params + work_dir: str + work directory + max_trials_global: int + maximum number of total trials allowed for tuning + Returns + ------- + ret: tvm.ir.transform.Pass + """ + return _ffi_api.MetaScheduleTuneIRMod(params, work_dir, max_trials_global) # type: ignore + + +def _wrap_class_function_pass(pass_cls, pass_info): """Wrap a python class as function pass.""" @@ -255,6 +309,7 @@ def function_pass( opt_level=None, name=None, required=None, + traceable=False, ) -> Union[Callable, FunctionPass]: """Decorate a function pass. @@ -277,6 +332,9 @@ def function_pass( required : Optional[List[str]] The list of passes that the function pass is dependent on. 
+ traceable: Boolean + Boolean variable whether the function pass is traceable + Returns ------- create_function_pass : Union[Callable, FunctionPass] @@ -350,7 +408,7 @@ def transform(func, mod, ctx): def create_function_pass(pass_arg): """Internal function that creates a function pass""" fname = name if name else pass_arg.__name__ - info = tvm.transform.PassInfo(opt_level, fname, required) + info = tvm.transform.PassInfo(opt_level, fname, required, traceable) if inspect.isclass(pass_arg): return _wrap_class_function_pass(pass_arg, info) if not isinstance(pass_arg, (types.FunctionType, types.LambdaType)): @@ -395,7 +453,7 @@ def __getattr__(self, name): def dataflowblock_pass( - pass_func=None, opt_level=None, name=None, required=None + pass_func=None, opt_level=None, name=None, required=None, traceable=False ) -> Union[Callable, DataflowBlockPass]: """Decorate a dataflowblock pass. @@ -418,6 +476,9 @@ def dataflowblock_pass( required : Optional[List[str]] The list of passes that the dataflowblock pass is dependent on. 
+ traceable: Boolean + Boolean variable whether the dataflowblock pass is traceable + Returns ------- create_dataflowblock_pass : Union[Callable, DataflowBlockPass] @@ -499,7 +560,7 @@ def transform(block, mod, ctx): def create_dataflowblock_pass(pass_arg): """Internal function that creates a dataflowblock pass""" fname = name if name else pass_arg.__name__ - info = tvm.transform.PassInfo(opt_level, fname, required) + info = tvm.transform.PassInfo(opt_level, fname, required, traceable) if inspect.isclass(pass_arg): return _wrap_class_dataflowblock_pass(pass_arg, info) if not isinstance(pass_arg, (types.FunctionType, types.LambdaType)): diff --git a/python/tvm/relax/transform/tuning_api/__init__.py b/python/tvm/relax/transform/tuning_api/__init__.py new file mode 100644 index 000000000000..6c39d5c5359e --- /dev/null +++ b/python/tvm/relax/transform/tuning_api/__init__.py @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=wildcard-import, redefined-builtin +"""Relax Tuning Pass API""" + +from .primitives import * +from .default_functions import * +from .database import * diff --git a/python/tvm/relax/transform/tuning_api/_ffi_api.py b/python/tvm/relax/transform/tuning_api/_ffi_api.py new file mode 100644 index 000000000000..f31522d02595 --- /dev/null +++ b/python/tvm/relax/transform/tuning_api/_ffi_api.py @@ -0,0 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +"""FFI APIs for relax.tuning_api""" +import tvm._ffi + +tvm._ffi._init_api("relax.tuning_api", __name__) diff --git a/python/tvm/relax/transform/tuning_api/database.py b/python/tvm/relax/transform/tuning_api/database.py new file mode 100644 index 000000000000..9477e142bad4 --- /dev/null +++ b/python/tvm/relax/transform/tuning_api/database.py @@ -0,0 +1,273 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Relax Tuning Pass API default functions""" +from typing import List, Optional +import logging + +from tvm.runtime import Object +from tvm.ir.module import IRModule +from tvm.meta_schedule.utils import _json_de_tvm +from tvm.meta_schedule.database import Workload +from tvm.tir.schedule.trace import JSON_TYPE +from tvm.target import Target +from tvm._ffi import register_object +from .primitives import Trace +from . import _ffi_api + +logger = logging.getLogger("TuningAPI") # pylint: disable=invalid-name + + +@register_object("relax.tuning_api.TuningRecord") +class TuningRecord(Object): + """The class of tuning records. + + Parameters + ---------- + trace : tvm.relax.transform.tuning_api.Trace + The trace of the tuning record. + run_secs : Optional[List[float]] + The run-time of the tuning record. + """ + + trace: Trace + run_secs: Optional[List[float]] + + def __init__( # type: ignore # pylint: disable=too-many-arguments + self, + trace: Trace, + run_secs: Optional[List[float]] = None, + ) -> None: + self.__init_handle_by_constructor__( + _ffi_api.TuningRecord, # type: ignore # pylint: disable=no-member + trace, + run_secs, + ) + + def as_json(self, include_irmod: bool = False) -> JSON_TYPE: + """Export the tuning record to a JSON string. + Parameters + ---------- + include_irmod: bool + Decides whether to serialize in_mod as well. + + Returns + ------- + json_str : str + The JSON string exported. 
+ """ + return _json_de_tvm(_ffi_api.TuningRecordAsJSON(self, include_irmod)) # type: ignore # pylint: disable=no-member + + @staticmethod + def from_json(json_obj: JSON_TYPE) -> "TuningRecord": + """Create a tuning record from a json object. + + Parameters + ---------- + json_obj : JSON_TYPE + The json object to parse. + + Returns + ------- + tuning_record : TuningRecord + The parsed tuning record. + """ + return _ffi_api.TuningRecordFromJSON(json_obj) # type: ignore # pylint: disable=no-member + + +@register_object("relax.tuning_api.Database") +class Database(Object): + """The abstract database interface.""" + + def has_workload(self, mod: IRModule) -> bool: + """Check if the database has the given workload. + Parameters + ---------- + mod : IRModule + The IRModule to be searched for. + + Returns + ------- + result : bool + Whether the given workload is committed. + """ + return _ffi_api.DatabaseHasWorkload(self, mod) # type: ignore # pylint: disable=no-member + + def has_measurement_record(self, workload: Workload, target: Target) -> bool: + """Check if the database has a measurement record for the given workload and target pair. + Parameters + ---------- + workload: Workload + The workload to be searched for. + target: Target + The target to be searched for. + + Returns + ------- + result : bool + Whether the given workload and target pair is committed for the measurement record. + """ + return _ffi_api.DatabaseHasMeasurementRecord(self, workload, target) # type: ignore # pylint: disable=no-member + + def has_tuning_record(self, workload: Workload, target: Target) -> bool: + """Check if the database has a tuning record for the given workload and target pair. + Parameters + ---------- + workload: Workload + The workload to be searched for. + target: Target + The target to be searched for. + + Returns + ------- + result : bool + Whether the given workload and target pair is committed for the tuning record. 
+ """ + return _ffi_api.DatabaseHasTuningRecord(self, workload, target) # type: ignore # pylint: disable=no-member + + def commit_workload(self, mod: IRModule) -> Workload: + """Commit a workload to the database if missing. + + Parameters + ---------- + mod : IRModule + The IRModule to be searched for or added. + + Returns + ------- + workload : Workload + The workload corresponding to the given IRModule. + """ + return _ffi_api.DatabaseCommitWorkload(self, mod) # type: ignore # pylint: disable=no-member + + def commit_measurement_record( + self, workload: Workload, target: Target, run_secs: List[float] + ) -> None: + """Commit a measurement record to the database. + A pair of workload and target will be used as a key. + + Parameters + ---------- + workload: Workload + The workload to be searched for. + target: Target + The target to be searched for. + run_secs : Optional[List[float]] + The measurement record to add. + """ + _ffi_api.DatabaseCommitMeasurementRecord(self, workload, target, run_secs) # type: ignore # pylint: disable=no-member + + def commit_tuning_record( + self, workload: Workload, target: Target, record: TuningRecord + ) -> None: + """Commit a tuning record to the database. + A pair of workload and target will be used as a key. + + Parameters + ---------- + workload: Workload + The workload to be searched for. + target: Target + The target to be searched for. + record : TuningRecord + The tuning record to add. + """ + _ffi_api.DatabaseCommitTuningRecord(self, workload, target, record) # type: ignore # pylint: disable=no-member + + def get_measurement_record(self, workload: Workload, target: Target) -> Optional[List[float]]: + """Get the measurement record of given workload and target from the database. + + Parameters + ---------- + workload : Workload + The workload to be searched for. + target: Target + The target to be searched for. + + Returns + ------- + measurement_record : Optional[List[float]] + Measurement record if exists. 
+ """ + return _ffi_api.DatabaseGetMeasurementRecord(self, workload, target) # type: ignore # pylint: disable=no-member + + def get_top_k(self, workload: Workload, target: Target, top_k: int) -> List[TuningRecord]: + """Get the top K tuning records of given workload and target from the database. + + Parameters + ---------- + workload : Workload + The workload to be searched for. + target: Target + The target to be searched for. + top_k : int + The number of top records to get. + + Returns + ------- + top_k_records : List[TuningRecord] + The top K records. + """ + return _ffi_api.DatabaseGetTopK(self, workload, target, top_k) # type: ignore # pylint: disable=no-member + + +@register_object("relax.tuning_api.JSONDatabase") +class JSONDatabase(Database): + """The class of JSON database. + + Parameters + ---------- + path_workload : str + The path to the workload table. + path_tuning_record : str + The path to the tuning record table. + Manages pairs of + path_measurement_record : str + The path to the path_measurement_record table. + Manages pairs of + """ + + path_workload: str + path_tuning_record: str + path_measurement_record: str + + def __init__( + self, + path_workload: str, + path_tuning_record: str, + path_measurement_record: str, + allow_missing: bool = True, + ) -> None: + """Constructor. + + Parameters + ---------- + path_workload : str + The path to the workload table. + path_tuning_record : str + The path to the tuning record table. + path_measurement_record : str + The path to the path_measurement_record table. + allow_missing : bool + Whether to create new file when the given path is not found. 
+ """ + self.__init_handle_by_constructor__( + _ffi_api.DatabaseJSONDatabase, # type: ignore # pylint: disable=no-member + path_workload, + path_tuning_record, + path_measurement_record, + allow_missing, + ) diff --git a/python/tvm/relax/transform/tuning_api/default_functions.py b/python/tvm/relax/transform/tuning_api/default_functions.py new file mode 100644 index 000000000000..b72b2f30ee2b --- /dev/null +++ b/python/tvm/relax/transform/tuning_api/default_functions.py @@ -0,0 +1,306 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""Relax Tuning Pass API default functions""" +from typing import Dict, List, Optional +import sys +import itertools +import logging +import numpy as np # type: ignore + +import tvm +from tvm.ir.module import IRModule +from tvm.ir.transform import PassContext, Pass +from tvm import meta_schedule +from tvm.meta_schedule.arg_info import TensorInfo +from tvm.meta_schedule.builder import BuilderInput, LocalBuilder +from tvm.meta_schedule.utils import get_global_func_with_default_on_worker +from tvm.meta_schedule.runner import ( + EvaluatorConfig, + LocalRunner, + RunnerInput, +) +from tvm._ffi.registry import register_func +from .primitives import Knob, Trace + +logger = logging.getLogger("TuningAPI") # pylint: disable=invalid-name + +# Default transform func that returns original IRModule. +@tvm.register_func("relax.tuning_api.Choice.default_transform_func") +def default_transform_func(mod): + return mod + + +# Default constraint func that always returns true. +@tvm.register_func("relax.tuning_api.Choice.default_constr_func") +def default_constr_func(mod: IRModule) -> bool: # pylint: disable=unused-argument + return True + + +@register_func("relax.tuning_api.default_generate_candidate") +def default_generate_candidate( + knobs: List[Knob], trace: Trace, eval_passes: Optional[List[Pass]] = None +) -> List[Trace]: + """ + Default function to generate the search space for a given trace by using registered choices. + This function simply expands candidate space as long as the knob's constraint satisfies. + To reduce the search space, a developer may expand each choice with smart search method. + (e.g., genetic search, multi-armed bandit) + Note, each pass generates candidates without worrying about the interaction with other passes. + i.e., it only uses its incoming trace/IRModule and Choices for candidate generation. + This will help alleviating the complexity of joint-optimization significantly. 
+ - consideration of interaction between optimizations has known to be extremely difficult. + + Parameters + ---------- + knobs : List[Knob] + List of Knobs to consider to generate candidate for input trace. + trace: Trace + Input trace. + eval_passes: Optional[List[Pass]] + List of passes to consider to evaluate each candidate. + This will enable joint-optimization. + + Return + ---------- + candidates: List[Trace] + List of candidate traces + """ + + candidates = [trace] + # Iterate over every decision + for knob in knobs: + num = len(candidates) + for _ in range(num): + cur_trace = candidates.pop(0) + for decision in knob.choices.keys(): + choice = knob.choices[decision] + # Generate new candidate when this condition satisfies. + if choice.check_constr(cur_trace.out_mod): + new_trace = cur_trace.deepcopy() + new_trace.add(knob, decision) + candidates.append(new_trace) + + # Expand candidates by using eval passes if provided. This will enable joint-optimization. + if eval_passes: + candidates = default_consider_eval_passes(candidates, eval_passes) + return candidates + + +@register_func("relax.tuning_api.default_consider_eval_passes") +def default_consider_eval_passes( + init_candidates: List[Trace], eval_passes: Optional[List[Pass]] = None +) -> List[Trace]: + """ + Default function to update traces with eval passes. + It visits each eval_pass in dfs order in transform.Sequential() and + returns the best possible candidate trace for each candidate. + + Parameters + ---------- + init_candidates: List[Trace] + Initial candidates + eval_passes: Optional[List[Pass]] + List of passes to consider to evaluate each candidate. + This will enable joint-optimization. 
+ Return + ---------- + candidates: List[Trace] + List of candidate traces + """ + if not eval_passes: + return init_candidates + + eval_passes = list(eval_passes) if not isinstance(eval_passes, list) else eval_passes + ctx = PassContext.current() + candidates = [] + + for trace in init_candidates: + ctx.push_trace(trace) + tvm.transform.Sequential(eval_passes)(trace.out_mod) + new_trace = ctx.pop_trace() + # A new trace contains the best decisions in eval_passes + candidates.append(new_trace) + + return candidates + + +@register_func("relax.tuning_api.default_evaluate") +def default_evaluate( + candidates: List[Trace], + target_str: str, + params: Optional[Dict[str, np.ndarray]] = None, + builder: Optional[meta_schedule.builder.Builder] = None, + runner: Optional[meta_schedule.runner.Runner] = None, +) -> None: + """ + Default function to evaluate a set of candidate traces by using MetaSchedule builder/runner. + + Parameters + ---------- + candidates: List[Trace] + List of traces to evaluate. + target_str: str, + Compilation target (e.g., llvm, cuda). + params: Optional[Dict[str, np.ndarray]] + Params to bind. + builder: Optional[meta_schedule.builder.Builder] + builder function. If not provided, default local builder will be used. + runner: Optional[meta_schedule.runner.Runner] + runner function. If not provided, default local runner will be used. 
+ """ + + ctx = PassContext.current() + target = tvm.target.Target(target_str) + database = PassContext.current().get_tuning_api_database() + # Setup default local builder if not provided + if builder is None: + + def relax_build( + mod: IRModule, + target: tvm.target.Target, + params: Optional[Dict[str, np.ndarray]], + ): + if params: + mod = tvm.relax.transform.BindParams("main", params)(mod) + relax_exec = tvm.relax.vm.build(mod, target) + return relax_exec.mod + + builder = LocalBuilder(f_build=relax_build) + + # Setup default local runner if not provided + if runner is None: + + def relax_eval_func(rt_mod, device, evaluator_config, repeated_args): + relax_exec = tvm.relax.vm.Executable(rt_mod) + relax_vm = tvm.relax.VirtualMachine(exec=relax_exec, device=device) + + evaluator = relax_vm.module.time_evaluator( + func_name="main", + dev=device, + number=evaluator_config.number, + repeat=evaluator_config.repeat, + min_repeat_ms=evaluator_config.min_repeat_ms, + ) + repeated_costs: List[List[float]] = [] + for args in repeated_args: + profile_result = evaluator(*args) + repeated_costs.append(profile_result.results) + + costs = [float(cost) for cost in itertools.chain.from_iterable(repeated_costs)] + + return costs + + runner = LocalRunner( + evaluator_config=EvaluatorConfig( + number=3, repeat=5, min_repeat_ms=100, enable_cpu_cache_flush=False + ), + f_run_evaluator=relax_eval_func, + ) + + # set up clean up function + f_clean_build = get_global_func_with_default_on_worker("meta_schedule.remove_build_dir", None) + assert f_clean_build + + # Keep track of number of evaluations (mostly for the debugging purpose) + num_evals = 0 + # Evaluation + for candidate in candidates: + # If this candidate is already evaluated, skip the measurement + if candidate.perf != -1: + continue + + # Evaluate candidates + num_evals += 1 + mod = candidate.out_mod + workload = database.commit_workload(mod) + + # If this workload and target pair has measured before, fetch its data. 
+        if database.has_measurement_record(workload, target):
+            run_secs = database.get_measurement_record(workload, target)
+        # Otherwise, measure it.
+        else:
+            # Build candidate
+            (builder_result,) = builder.build([BuilderInput(mod, target, params)])
+
+            if builder_result.artifact_path is None:
+                # Build error
+                # Assign the worst performance and move on to the next candidate.
+                logger.warning(builder_result.error_msg)
+                run_secs = [1e100]
+            else:
+                # If build passes, set up runner input and measure the performance.
+                args_info = [
+                    TensorInfo(
+                        shape=[int(i) for i in p.struct_info.shape], dtype=p.struct_info.dtype
+                    )
+                    for p in mod["main"].params
+                ]  # convert list[Var] to list[TensorInfo]
+                runner_input = RunnerInput(
+                    builder_result.artifact_path, target_str, args_info=args_info
+                )
+                (runner_future,) = runner.run([runner_input])
+                runner_result = runner_future.result()
+
+                run_secs = runner_result.run_secs
+                # Runtime error
+                # Assign the worst performance and move on to the next candidate.
+                if runner_result.error_msg is not None:
+                    logger.warning(runner_result.error_msg)
+                    run_secs = [1e100]
+
+            database.commit_measurement_record(workload, target, run_secs)
+
+            # Clean up the artifact
+            f_clean_build(builder_result.artifact_path)
+
+        # For valid measurements, compute the average and update the trace performance.
+        perfs = []
+        for result in run_secs:
+            if isinstance(result, tvm.tir.FloatImm):
+                result = result.value
+            assert isinstance(result, float)
+            assert result >= 0.0
+            perfs.append(result)
+
+        # Store the evaluation result
+        candidate.set_perf(np.mean(perfs))
+
+    ctx.inc_num_evals(num_evals)
+
+
+def select_best_candidate(candidates: List[Trace]) -> Trace:
+    """
+    Select the best trace.
+ + Parameters + ---------- + candidates: List[Trace] + Candidate traces + + Return + ---------- + best_trace: Trace + Trace with the best performance + """ + best_perf, best_trace = sys.maxsize, None + for candidate in candidates: + avg = candidate.perf + # Select best one + if best_perf > avg: + best_perf = avg + best_trace = candidate + return best_trace diff --git a/python/tvm/relax/transform/tuning_api/primitives.py b/python/tvm/relax/transform/tuning_api/primitives.py new file mode 100644 index 000000000000..67b81ba7e99c --- /dev/null +++ b/python/tvm/relax/transform/tuning_api/primitives.py @@ -0,0 +1,419 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Relax Tuning Pass API primitives""" + +from typing import Callable, Union, Dict, List, Optional, Sequence +import logging +import tvm +from tvm.runtime import Object +from tvm.ir.module import IRModule +from tvm.relax import Expr +from tvm.tir.schedule.trace import JSON_TYPE, _json_from_tvm +from tvm._ffi import register_object +from . 
import _ffi_api + +logger = logging.getLogger("TuningAPI") # pylint: disable=invalid-name + + +@register_object("relax.tuning_api.Choice") +class Choice(Object): + """ + A TVM object Choice that maintains a set of transformation and constraint function keys. + Corresponding functions should be registered as PackedFunc with these keys. + Transformation function will be applied when constraint function returns true. + Parameters + ---------- + transform_func_key : Optional[str] + Key for transformation function. + transform_func_args : Optional[List] + Arguments for transformation function. + constr_func_key : Optional[str] + Key for constraint function. + constr_func_args : Optional[List] + Arguments for constraint function. + + Examples + -------- + The following code block defines a Choice. + + .. code-block:: python + @tvm.register_func("relax.tuning_api.test.transform_func") + def apply(mod): + return relax.tuning_api.FoldConstant()(mod) + @tvm.register_func("relax.tuning_api.test.constr_func") + def constr(mod): + return len(mod.functions) == 3 + # Define a choice to apply constant folding only when IRModule has three functions. + choice = Choice( + transform_func_key = "relax.tuning_api.test.transform_func", + constr_func_key = "relax.tuning_api.test.constr_func" + ) + """ + + def __init__( + self, + transform_func_key: Optional[str] = None, + transform_func_args: Optional[List] = None, + constr_func_key: Optional[str] = None, + constr_func_args: Optional[List] = None, + ): + """Constructor + Parameters + ---------- + transform_func_key : Optional[str] + Key for transformation function. + + f_tramsform_args: Optional[List] + Arguments for transformation function. + + constr_func_key : Optional[str] + Key for constraint function. + + constr_func_args: Optional[List] + Arguments for constraint function. 
+ """ + + if transform_func_key is None: + transform_func_key = "relax.tuning_api.Choice.default_transform_func" + + if transform_func_args is None: + transform_func_args = [] + + if constr_func_key is None: + constr_func_key = "relax.tuning_api.Choice.default_constr_func" + + if constr_func_args is None: + constr_func_args = [] + + self.__init_handle_by_constructor__( + _ffi_api.Choice, # type: ignore + transform_func_key, + transform_func_args, + constr_func_key, + constr_func_args, # type: ignore # pylint: disable=no-member + ) + + def get_transform_func(self) -> Callable: + """Getter for transform_func + Returns + ------- + ret: Callable + registered transformation function + """ + return _ffi_api.ChoiceGetTransformFunc(self) # type: ignore + + def get_constr_func(self) -> Callable: + """Getter for constr_func + Returns + ------- + ret: Callable + registered constraint function + """ + return _ffi_api.ChoiceGetConstrFunc(self) # type: ignore + + def apply_transform_func(self, mod: IRModule) -> IRModule: + """Perform transform_func with its arguments + Returns + ------- + ret: IRModule + Transformed IRModule + """ + return _ffi_api.ChoiceApplyTransformFunc(self, mod) # type: ignore + + def check_constr(self, mod: IRModule) -> bool: + """Perform constr_func with its arguments + Returns + ------- + ret: bool + Returns whether the IRModule satisfies the constraint or not + """ + return _ffi_api.ChoiceCheckConstr(self, mod) # type: ignore + + def as_json(self) -> JSON_TYPE: + """Serialize the trace as a JSON-style object + Returns + ------- + json: JSON_TYPE + The JSON-style object + """ + return _ffi_api.ChoiceAsJSON(self) # type: ignore # pylint: disable=no-member + + @staticmethod + def from_json(json_obj: JSON_TYPE) -> "Choice": + """Create Choice from JSON obj + + Parameters + ---------- + json_obj: JSON_TYPE + Choice serialized with JSON + + Return + ---------- + choice: Choice + Deserialized choice + """ + return _ffi_api.ChoiceFromJSON(json_obj) # type: 
ignore + + def deepcopy(self): + return Choice.from_json(self.as_json()) + + +@register_object("relax.tuning_api.Knob") +class Knob(Object): + """ + A TVM object Knob that maintains a set of valid Choices. + By using Knobs, a tuning pass can generate candidates and define the search space. + Parameters + ---------- + name : str + Name of the knob. + + choices: Union[List[Choice], Dict[str, Choice]] + A list of valid choices + + Examples + -------- + The following code block defines a Knob. + + .. code-block:: python + @tvm.register_func("relax.tuning_api.test.transform_func") + def apply(mod): + return relax.tuning_api.FoldConstant()(mod) + choices = {"apply": Choice("relax.tuning_api.test.transform_func"), "noapply": Choice()} + # A knob manages a set of its valid choices + knob = Knob("MockTuningKnob", choices) + """ + + def __init__(self, name: str, choices: Union[List[Choice], Dict[str, Choice]]): + """Constructor.""" + if isinstance(choices, list): + choices = {str(idx): val for idx, val in enumerate(choices)} + + self.__init_handle_by_constructor__( + _ffi_api.Knob, name, choices # type: ignore # pylint: disable=no-member + ) + + def verify(self, decision: Union[str, int]) -> bool: + """Verify if the decision is valid.""" + if isinstance(decision, int): + decision = str(decision) + return _ffi_api.KnobIsValidDecision(self, decision) # type: ignore + + def apply(self, mod: IRModule, decision: Union[str, int]) -> IRModule: + """Get choice if a decision is valid.""" + if isinstance(decision, int): + decision = str(decision) + return _ffi_api.KnobApply(self, mod, decision) # type: ignore + + def as_json(self) -> JSON_TYPE: + """Serialize the trace as a JSON-style object + Returns + ------- + json: JSON_TYPE + The JSON-style object + """ + return _ffi_api.KnobAsJSON(self) # type: ignore + + @staticmethod + def from_json(json_obj: JSON_TYPE) -> "Knob": + """Create Knob from JSON obj + + Parameters + ---------- + json_obj: JSON_TYPE + Knob serialized with JSON + + 
Return + ---------- + knob: Knob + Deserialized knob + """ + return _ffi_api.KnobFromJSON(json_obj) # type: ignore + + def __str__(self) -> str: + msg = f"{self.name} (# of choices: {len(self.choices)})\n" + for name, choice in self.choices.items(): + msg += f" - {name}: {choice}\n" + return msg + + def deepcopy(self): + return Knob.from_json(self.as_json()) + + +@register_object("relax.tuning_api.Trace") +class Trace(Object): + """ + A TVM object Trace logs the history of transformations (decisions). + Parameters + ---------- + in_mod : IRModule + Input IRModule. + knobs: Optional[List[Knob]] + A list of knobs applied in the trace. + decisions: Optional[Sequence[Union[str, int]]] + A list of decisions made for each knob + + Examples + -------- + The following code block defines a Trace. + + .. code-block:: python + + trace = Trace(mod, [knob1, knob2, knob3], ["c1", "c0", "c3"]) + assert trace.size == 3 # Length of history. + # 'out' contains IRModule that applies transformations in the trace. + out: IRModule = trace.add(knob4, "c2") + assert trace.size == 4 # Length of history. + trace.set_perf(0.03) # Set the performance number of the trace. 
+ """ + + def __init__( + self, + in_mod: IRModule, + knobs: Optional[List[Knob]] = None, + decisions: Optional[Sequence[Union[str, int]]] = None, + ): + """Constructor.""" + knobs = knobs if knobs else list() + decisions = ( + [str(v) if isinstance(v, int) else v for v in decisions] if decisions else list() + ) + self.__init_handle_by_constructor__( + _ffi_api.Trace, in_mod, knobs, decisions # type: ignore # pylint: disable=no-member + ) + + def verify(self) -> bool: + """Verify if current history is valid.""" + return _ffi_api.TraceVerify() # type: ignore + + def add(self, knob: Knob, decision: Union[str, int]) -> IRModule: + """Add & Apply new decision (with knob).""" + if isinstance(decision, int): + decision = str(decision) + return _ffi_api.TraceAdd(self, knob, decision) # type: ignore + + def set_perf(self, perf: float) -> None: + """Set performance number for the trace.""" + return _ffi_api.TraceSetPerf(self, perf) # type: ignore + + def set_out_mod(self, mod: IRModule) -> None: + """Set out_mod for the trace.""" + return _ffi_api.TraceSetOutMod(self, mod) # type: ignore + + def as_json(self, include_irmod: bool = True) -> JSON_TYPE: + """Serialize the trace as a JSON-style object. + Parameters + ---------- + include_irmod: bool + Decides whether to serialize in_mod as well. + + Returns + ------- + json: JSON_TYPE + The JSON-style object. + """ + obj = _ffi_api.TraceAsJSON(self, include_irmod) # type: ignore + return _json_from_tvm(obj) + + @staticmethod + def from_json(json_obj: JSON_TYPE) -> "Trace": + """Create Trace from JSON obj. + + Parameters + ---------- + json_obj: JSON_TYPE + Trace serialized with JSON. + + Return + ---------- + trace: Trace + Deserialized trace. 
+ """ + return _ffi_api.TraceFromJSON(json_obj) # type: ignore + + def __str__(self) -> str: + n = len(self.knobs) + msg = f"Trace length: {n}\n" + for idx in range(n): + msg += f"[{idx+1}] {self.knobs[idx].name}: {self.decisions[idx]}\n" + return msg + + def deepcopy(self) -> "Trace": + new_in_mod = deepcopy_irmodule(self.in_mod) + new_knobs = [knob.deepcopy() for knob in self.knobs] + new_decisions = [str(decision) for decision in self.decisions] + new_trace = Trace(new_in_mod, new_knobs, new_decisions) + new_out_mod = deepcopy_irmodule(self.out_mod) + new_trace.set_out_mod(new_out_mod) + return new_trace + + +def get_trace(in_: Union[Trace, IRModule, Expr]) -> Trace: + """ + Getter for a trace wrapper. + + Parameters + ---------- + in_: Union[Trace, IRModule, Expr] + Input entity + Return + ---------- + wrapped: Trace + Traced entity + """ + if isinstance(in_, Trace): + return in_ + if isinstance(in_, IRModule): + return Trace(in_) + if isinstance(in_, Expr): # type: ignore + return Trace(tvm.IRModule.from_expr(in_)) + + raise Exception(f"Invalid input type for trace: {type(in_)}") + + +@tvm.register_func("relax.tuning_api.deepcopy_irmodule") +def deepcopy_irmodule(mod: IRModule) -> IRModule: + """ + Deepcopy for an IRModule. + Parameters + ---------- + mod: IRModule + input IRModule + Return + ---------- + copied_mod: IRModule + deep-copied IRModule + """ + func_save_json = tvm.get_global_func("node.SaveJSON") + func_load_json = tvm.get_global_func("node.LoadJSON") + new_mod = None + # Handle external modules separately if exist + # TODO(tvm-team): + # Serialization of IRModule with external mods is tricky. + # (1) External mod is runtime module. + # (2) Currently, `export_library` does not support serialization of + # runtime module without the host module + # Therefore, we simply pass around the compiled external modules without copy for now. + # Revisit later when we have a better solution. 
+ if mod.attrs and "external_mods" in mod.attrs: + tmp_mod = mod.without_attr("external_mods") + new_mod = func_load_json(func_save_json(tmp_mod)) + new_mod = new_mod.with_attr("external_mods", mod.attrs["external_mods"]) + else: + new_mod = func_load_json(func_save_json(mod)) + + return new_mod diff --git a/python/tvm/tir/transform/function_pass.py b/python/tvm/tir/transform/function_pass.py index 9450ade34e67..94d211a7fb4c 100644 --- a/python/tvm/tir/transform/function_pass.py +++ b/python/tvm/tir/transform/function_pass.py @@ -70,6 +70,7 @@ def prim_func_pass( opt_level: int = None, name: Optional[str] = None, required: Optional[List[str]] = None, + traceable=False, ) -> Union[Callable, PrimFuncPass]: """Decorate a function pass. @@ -148,7 +149,7 @@ def transform(func, mod, ctx): def create_function_pass(pass_arg): """Internal function that creates a function pass""" fname = name if name else pass_arg.__name__ - info = PassInfo(opt_level, fname, required) + info = PassInfo(opt_level, fname, required, traceable) if inspect.isclass(pass_arg): return _wrap_class_function_pass(pass_arg, info) if not isinstance(pass_arg, (types.FunctionType, types.LambdaType)): diff --git a/src/ir/transform.cc b/src/ir/transform.cc index 66b06e6b505d..619526d0b56b 100644 --- a/src/ir/transform.cc +++ b/src/ir/transform.cc @@ -25,6 +25,7 @@ #include #include #include +#include #include #include @@ -341,11 +342,13 @@ class ModulePass : public Pass { TVM_DEFINE_OBJECT_REF_METHODS(ModulePass, Pass, ModulePassNode); }; -PassInfo::PassInfo(int opt_level, String name, tvm::Array required) { +PassInfo::PassInfo(int opt_level, String name, tvm::Array required, + bool traceable) { auto pass_info = make_object(); pass_info->opt_level = opt_level; pass_info->name = std::move(name); pass_info->required = std::move(required); + pass_info->traceable = std::move(traceable); data_ = std::move(pass_info); } @@ -401,7 +404,7 @@ Sequential::Sequential(tvm::Array passes, PassInfo pass_info) { 
Sequential::Sequential(tvm::Array passes, String name) { auto n = make_object(); n->passes = std::move(passes); - PassInfo pass_info = PassInfo(0, std::move(name), {}); + PassInfo pass_info = PassInfo(0, std::move(name), {}, /* traceable */ false); n->pass_info = std::move(pass_info); data_ = std::move(n); } @@ -444,26 +447,61 @@ IRModule SequentialNode::operator()(IRModule mod, const PassContext& pass_ctx) c VLOG(0) << "skipping disabled pass '" << pass_info->name << "'"; continue; } + // resolve dependencies for (const auto& it : pass_info->required) { mod = GetPass(it)(std::move(mod), pass_ctx); } - mod = pass(std::move(mod), pass_ctx); + + // This handles passes that does not use Relax tuning API (untraceable passes). + // We make untraceable passes trackable when pass context has a trace (trace mode). + // When passes to trace (make_traceable) is provided from users, we only make them trackable. + if (pass_ctx->trace_stack.size() && !pass_info->traceable && + (!pass_ctx->make_traceable.defined() || + pass_ctx->make_traceable.value().count(pass_info->name))) { + // TODO(tvm-team): Currently, there are some inconsistency in the pass registration. + // 1. Some passes are not registered in ffi registry. + // 2. Some passes do not follow the name convention. (e.g., = + ) + + // Due to these problems, serialization with non-traceable passes is handled in a hacky way + // now. Find a systematic way to identify such inconsistencies and fix them. + + // In the future, we should pass the ffi key for a pass by deducing from its name. + String transform_func_key = "relax.tuning_api.Choice.default_transform_func"; + String constr_func_key = "relax.tuning_api.Choice.default_constr_func"; + + relax::Knob knob = relax::Knob( + pass_info->name, {{"Applied", relax::Choice(transform_func_key, Array(), + constr_func_key, Array())}}); + + // Add new decision to the trace at the top of the stack. 
+ auto trace = Downcast(pass_ctx->trace_stack.back()); + trace->Add(knob, "Applied"); + // In the future, we should just have + // mod = trace->Add(knob, "enabled"); + // instead of the two lines below. + mod = pass(std::move(mod), pass_ctx); + trace->SetOutMod(mod); + + } else { + mod = pass(std::move(mod), pass_ctx); + } } return mod; } Pass CreateModulePass(const runtime::TypedPackedFunc& pass_func, - int opt_level, String name, tvm::Array required) { - PassInfo pass_info = PassInfo(opt_level, name, required); + int opt_level, String name, tvm::Array required, bool traceable) { + PassInfo pass_info = PassInfo(opt_level, name, required, traceable); return ModulePass(pass_func, pass_info); } TVM_REGISTER_NODE_TYPE(PassInfoNode); TVM_REGISTER_GLOBAL("transform.PassInfo") - .set_body_typed([](int opt_level, String name, tvm::Array required) { - return PassInfo(opt_level, name, required); + .set_body_typed([](int opt_level, String name, tvm::Array required, bool traceable) { + return PassInfo(opt_level, name, required, traceable); }); TVM_REGISTER_GLOBAL("transform.Info").set_body([](TVMArgs args, TVMRetValue* ret) { @@ -514,7 +552,8 @@ TVM_REGISTER_GLOBAL("transform.Sequential").set_body([](TVMArgs args, TVMRetValu int opt_level = args[1]; std::string name = args[2]; tvm::Array required = args[3]; - PassInfo pass_info = PassInfo(opt_level, name, required); + bool traceable = args[4]; + PassInfo pass_info = PassInfo(opt_level, name, required, /* traceable */ traceable); *ret = Sequential(passes, pass_info); }); @@ -537,7 +576,9 @@ TVM_REGISTER_NODE_TYPE(PassContextNode); TVM_REGISTER_GLOBAL("transform.PassContext") .set_body_typed([](int opt_level, Array required, Array disabled, Array instruments, - Optional> config) { + Optional> config, Array trace_stack, + Optional> make_traceable, int num_evals, + Optional tuning_api_database) { auto pctx = PassContext::Create(); pctx->opt_level = opt_level; @@ -547,6 +588,10 @@ TVM_REGISTER_GLOBAL("transform.PassContext") if 
(config.defined()) { pctx->config = config.value(); } + pctx->trace_stack = std::move(trace_stack); + pctx->make_traceable = std::move(make_traceable); + pctx->num_evals = std::move(num_evals); + pctx->tuning_api_database = std::move(tuning_api_database); PassConfigManager::Global()->Legalize(&(pctx->config)); return pctx; }); @@ -562,7 +607,8 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "\tdisabled passes: " << node->disabled_pass << "\n"; p->stream << "\tinstruments: " << node->instruments << "\n"; - p->stream << "\tconfig: " << node->config; + p->stream << "\tconfig: " << node->config << "\n"; + p->stream << "\ttrace stack: " << node->trace_stack; }); class PassContext::Internal { @@ -572,6 +618,22 @@ class PassContext::Internal { static void ExitScope(PassContext pass_ctx) { pass_ctx.ExitWithScope(); } }; +TVM_REGISTER_GLOBAL("transform.GetTraceStack") + .set_body_method(&PassContextNode::GetTraceStack); +TVM_REGISTER_GLOBAL("transform.PushTrace") + .set_body_method(&PassContextNode::PushTrace); +TVM_REGISTER_GLOBAL("transform.PopTrace").set_body_method(&PassContextNode::PopTrace); +TVM_REGISTER_GLOBAL("transform.GetTraceStackSize") + .set_body_method(&PassContextNode::GetTraceStackSize); +TVM_REGISTER_GLOBAL("transform.GetCurrentTrace") + .set_body_method(&PassContextNode::GetCurrentTrace); +TVM_REGISTER_GLOBAL("transform.SetNumEvals") + .set_body_method(&PassContextNode::SetNumEvals); +TVM_REGISTER_GLOBAL("transform.IncNumEvals") + .set_body_method(&PassContextNode::IncNumEvals); +TVM_REGISTER_GLOBAL("transform.GetTuningAPIDatabase") + .set_body_method(&PassContextNode::GetTuningAPIDatabase); + TVM_REGISTER_GLOBAL("transform.GetCurrentPassContext").set_body_typed(PassContext::Current); TVM_REGISTER_GLOBAL("transform.EnterPassContext").set_body_typed(PassContext::Internal::EnterScope); @@ -595,7 +657,7 @@ Pass PrintIR(String header, bool show_meta_data) { LOG(INFO) << "PrintIR(" << header << "):\n" << mod; return mod; }; - return 
CreateModulePass(pass_func, 0, "PrintIR", {}); + return CreateModulePass(pass_func, 0, "PrintIR", {}, /* traceable */ false); } TVM_REGISTER_GLOBAL("transform.PrintIR").set_body_typed(PrintIR); diff --git a/src/relax/backend/task_extraction.cc b/src/relax/backend/task_extraction.cc new file mode 100644 index 000000000000..beb3950af1d1 --- /dev/null +++ b/src/relax/backend/task_extraction.cc @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include + +namespace tvm { +namespace relax { +namespace backend { + +using tvm::meta_schedule::ExtractedTask; + +/*! + * \brief Extract the Meta-Schedule tuning task from a given IRModule. + * \note + * 1. The task extractor is responsible for task deduplication. The + * deduplication is achieved by comparing structural hashes of PrimFuncs. + * 2. For a PrimFunc, the weight of its corresponding task is the number + * of times it called by op Call-TIR. Say in an IRModule there are three + * PrimFuncs `fn1`, `fn2` and `fn3` sharing the same structural hash. + * Suppose `fn1` is called by 5 Call-TIR ops among all Relax function, + * `fn2` is called by 3 Call-TIR and `fn3` is called by 5 Call-TIR. 
+ * Then we will have a ExtractedTask for all three functions, whose weight + * is 5 + 3 + 2 = 10. + */ +class TaskExtractor : public ExprVisitor { + public: + static Array ExtractTask(IRModule mod, Target target) { + TaskExtractor extractor(mod, target); + // We go through each Relax function in the module. + for (const auto& kv : mod->functions) { + if (const auto* func = kv.second.as()) { + extractor(GetRef(func)); + } + } + return std::move(extractor.tasks_); + } + + private: + explicit TaskExtractor(IRModule mod, Target target) + : mod_(std::move(mod)), target_(std::move(target)) { + normalize_mod_func_ = runtime::Registry::Get("tvm.meta_schedule.normalize_mod"); + ICHECK(normalize_mod_func_) << "Normalization function is not found."; + } + + void VisitExpr_(const CallNode* call) final { + static const Op& call_tir_op = Op::Get("relax.call_tir"); + + // TODO(@tvm-team): When we differentiate the call for tir function and packed function, + // this logic should be changed accordingly. + if (!call->op.same_as(call_tir_op)) { + // Since the Relax function is of A-normal form, the arguments of this call cannot be another + // Calls. And hence we do not need to recurse into this Call. 
+ return; + } + + // Do not extract external function + if (call->args[0].as()) { + return; + } + + const GlobalVar& global_var = Downcast(call->args[0]); + const tir::PrimFunc& func = Downcast(mod_->Lookup(global_var)); + + auto it = func2task_.find(func); + if (it != func2task_.end()) { + it->second->weight += 1; + return; + } + + IRModule tir_mod = (*normalize_mod_func_)(func); + ExtractedTask task(/*task_name=*/global_var->name_hint, // + /*mod=*/tir_mod, // + /*target=*/target_, // + /*dispatched=*/{tir_mod}, // + /*weight=*/1); + tasks_.push_back(task); + func2task_.emplace(func, task); + } + + IRModule mod_; + Target target_; + Array tasks_; + std::unordered_map func2task_; + const runtime::PackedFunc* normalize_mod_func_; +}; + +TVM_REGISTER_GLOBAL("relax.backend.MetaScheduleExtractTask") + .set_body_typed([](IRModule mod, Target target) { + return TaskExtractor::ExtractTask(std::move(mod), std::move(target)); + }); + +} // namespace backend +} // namespace relax +} // namespace tvm diff --git a/src/relax/ir/transform.cc b/src/relax/ir/transform.cc index 1b077d8b887a..9f418bff5c6d 100644 --- a/src/relax/ir/transform.cc +++ b/src/relax/ir/transform.cc @@ -173,8 +173,8 @@ bool FunctionPassNode::SkipFunction(const Function& func) const { Pass CreateFunctionPass( const runtime::TypedPackedFunc& pass_func, - int opt_level, String name, tvm::Array required) { - PassInfo pass_info = PassInfo(opt_level, name, required); + int opt_level, String name, tvm::Array required, bool traceable) { + PassInfo pass_info = PassInfo(opt_level, name, required, traceable); return FunctionPass(pass_func, pass_info); } @@ -389,8 +389,8 @@ IRModule DataflowBlockPassNode::operator()(IRModule mod, const PassContext& pass Pass CreateDataflowBlockPass( const runtime::TypedPackedFunc& pass_func, - int opt_level, String name, tvm::Array required) { - PassInfo pass_info = PassInfo(opt_level, name, required); + int opt_level, String name, tvm::Array required, bool traceable) { + PassInfo 
pass_info = PassInfo(opt_level, name, required, traceable); return DataflowBlockPass(pass_func, pass_info); } diff --git a/src/relax/transform/meta_schedule.cc b/src/relax/transform/meta_schedule.cc new file mode 100644 index 000000000000..d444ba16654f --- /dev/null +++ b/src/relax/transform/meta_schedule.cc @@ -0,0 +1,171 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/transform/meta_schedule.cc + * \brief Pass for meta_schedule tuning + */ +#include +#include +#include +#include + +namespace tvm { +namespace relax { +namespace transform { + +class MetaScheduleTuner { + public: + explicit MetaScheduleTuner(Target target, String work_dir, Integer max_trials_global, + Map params = {}) + : target_(target), + work_dir_(work_dir), + max_trials_global_(max_trials_global), + params_(params) { + candgen_func_ = runtime::Registry::Get("relax.tuning_api.default_generate_candidate"); + ICHECK(candgen_func_) << "Default candidate generation function is not found."; + normalize_mod_func_ = runtime::Registry::Get("tvm.meta_schedule.normalize_mod"); + ICHECK(normalize_mod_func_) << "Normalization function is not found."; + } + + // TODO(@sunggg): Currently, only supports basic arguments. 
+ IRModule TuneIRMod(IRModule mod, transform::PassContext ctx) { + Trace trace = Downcast(ctx->GetCurrentTrace()); + ctx->PopTrace(); + Choice choice("tvm.meta_schedule.tune_relax", {params_, target_, work_dir_, max_trials_global_}, + "relax.tuning_api.Choice.default_constr_func", {}); + Knob knob("meta_schedule.tune_irmod", {{"0", choice}}); + Array candidates = (*candgen_func_)(Array({knob}), trace); + ICHECK(candidates.size() == 1); + Trace best_trace = candidates[0]; + ctx->PushTrace(best_trace); + // since we separate tuning from application, return original IRModule + return mod; + } + + // TODO(@sunggg): Currently, only supports basic arguments. + tir::PrimFunc TuneTIR(tir::PrimFunc f, transform::PassContext ctx) { + // TODO(@sunggg): Whenever we tune tir, assume we start a new trace w/o pushing to the trace + // stack. Revisit later when we collect more usecases. + Trace trace = Trace((*normalize_mod_func_)(f), {}, {}); + + Choice choice("tvm.meta_schedule.tune_tir", {target_, work_dir_, max_trials_global_}, + "relax.tuning_api.Choice.default_constr_func", {}); + Knob knob("meta_schedule.tune_primfunc", {{"0", choice}}); + Array candidates = (*candgen_func_)(Array({knob}), trace); + ICHECK(candidates.size() == 1); + // since we separate tuning from application, return original IRModule + return f; + } + + private: + Target target_; + String work_dir_; + Integer max_trials_global_; + Map params_; + const runtime::PackedFunc* candgen_func_; + const runtime::PackedFunc* normalize_mod_func_; +}; + +Pass MetaScheduleApplyDatabase(Optional work_dir) { + using tvm::meta_schedule::Database; + Target target = Target::Current(false); + const runtime::PackedFunc* normalize_mod_func_ = + runtime::Registry::Get("tvm.meta_schedule.normalize_mod"); + ICHECK(normalize_mod_func_) << "Normalization function is not found."; + + runtime::TypedPackedFunc pass_func = [=](IRModule mod, + PassContext ctx) { + Database database{nullptr}; + if (Database::Current().defined()) { + 
database = Database::Current().value(); + } else { + ICHECK(work_dir.defined()); + String path_workload = work_dir.value() + "/database_workload.json"; + String path_tuning_record = work_dir.value() + "/database_tuning_record.json"; + LOG(WARNING) << "Creating JSONDatabase. Workload at: " << path_workload + << ", Tuning records at: " << path_tuning_record; + database = meta_schedule::Database::JSONDatabase(path_workload, path_tuning_record, true); + } + + Map result; + for (const auto& iter : mod->functions) { + GlobalVar gv = iter.first; + BaseFunc base_func = iter.second; + if (const auto* prim_func_node = base_func.as()) { + tir::PrimFunc prim_func = GetRef(prim_func_node); + + IRModule tir_mod = (*normalize_mod_func_)(prim_func); + if (Optional sch = database->QuerySchedule(tir_mod, target, gv->name_hint)) { + IRModule new_mod = sch.value()->mod(); + ICHECK_EQ(new_mod->functions.size(), 1); + BaseFunc new_base_func = (*new_mod->functions.begin()).second; + ICHECK(new_base_func->IsInstance()); + tir::PrimFunc new_prim_func = Downcast(new_base_func); + // copy the original attrs + new_prim_func = WithAttrs(std::move(new_prim_func), {prim_func->attrs->dict}); + result.Set(gv, new_prim_func); + continue; + } else { + LOG(WARNING) << "Tuning record is not found for primfunc: " << gv->name_hint; + } + } + result.Set(gv, base_func); + } + return IRModule(result, // functions + {}, // type_definitions + {}, // import_set + {}, // map + mod->attrs); // attrs); + }; + return CreateModulePass(pass_func, 0, "MetaScheduleApplyDatabase", {}); +} + +Pass MetaScheduleTuneIRMod(Map params, String work_dir, + Integer max_trials_global) { + Target target = Target::Current(false); + runtime::TypedPackedFunc pass_func = [=](IRModule m, + PassContext ctx) { + return MetaScheduleTuner(target, work_dir, max_trials_global, params).TuneIRMod(m, ctx); + }; + return CreateModulePass(/*pass function*/ pass_func, /*opt level*/ 0, + /*pass name*/ "MetaScheduleTuneIRModule", + /*required*/ 
{}, + /*traceable*/ true); +} + +Pass MetaScheduleTuneTIR(String work_dir, Integer max_trials_global) { + Target target = Target::Current(false); + runtime::TypedPackedFunc pass_func = + [=](tir::PrimFunc f, IRModule mod, PassContext ctx) { + return MetaScheduleTuner(target, work_dir, max_trials_global).TuneTIR(f, ctx); + }; + return tir::transform::CreatePrimFuncPass(/*pass function*/ pass_func, /*opt level*/ 0, + /*pass name*/ "MetaScheduleTuneTIR", + /*required*/ {}, + /*traceable*/ true); +} + +TVM_REGISTER_GLOBAL("relax.transform.MetaScheduleApplyDatabase") + .set_body_typed(MetaScheduleApplyDatabase); +TVM_REGISTER_GLOBAL("relax.transform.MetaScheduleTuneIRMod").set_body_typed(MetaScheduleTuneIRMod); +TVM_REGISTER_GLOBAL("relax.transform.MetaScheduleTuneTIR").set_body_typed(MetaScheduleTuneTIR); +} // namespace transform +} // namespace relax +} // namespace tvm diff --git a/src/relax/transform/tuning_api/database.cc b/src/relax/transform/tuning_api/database.cc new file mode 100644 index 000000000000..0d239e5fbf81 --- /dev/null +++ b/src/relax/transform/tuning_api/database.cc @@ -0,0 +1,350 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file src/relax/transform/tuning_api/database.cc + * \brief Database of tuning APIs. + */ +#include + +#include +#include +#include + +#include "../../../meta_schedule/utils.h" + +namespace tvm { +namespace meta_schedule { + +void JSONFileAppendLine(const String& path, const std::string& line); +std::vector JSONFileReadLines(const String& path, int num_threads, bool allow_missing); + +} // namespace meta_schedule +} // namespace tvm + +namespace tvm { +namespace relax { + +TuningRecord::TuningRecord(Trace trace, Optional> run_secs) { + ObjectPtr n = make_object(); + n->trace = trace; + n->run_secs = run_secs; + this->data_ = n; +} + +ObjectRef TuningRecordNode::AsJSON(bool include_irmod) const { + return Array{trace->AsJSON(include_irmod), // + run_secs}; +} + +TuningRecord TuningRecord::FromJSON(const ObjectRef& json_obj) { + Trace trace{nullptr}; + Optional> run_secs{nullptr}; + try { + const ArrayNode* json_array = json_obj.as(); + CHECK(json_array && json_array->size() == 2); + // Load json[0] => trace + { + const ObjectRef& json_trace = json_array->at(0); + trace = Trace::FromJSON(json_trace); + } + + // Load json[1] => run_secs + if (json_array->at(1).defined()) { + run_secs = meta_schedule::AsFloatArray(json_array->at(1)); + } + } catch (const std::runtime_error& e) { // includes tvm::Error and dmlc::Error + LOG(FATAL) << "ValueError: Unable to parse the JSON object: " << json_obj + << "\nThe error is: " << e.what(); + } + return TuningRecord(trace, run_secs); +} + +/*! \brief The struct defining comparison function of sorting by mean run seconds. 
*/ +struct SortTuningRecordByMeanRunSecs { + static const constexpr double kMaxMeanTime = 1e10; + + static double Mean(const Array& a) { + if (a.empty()) { + return kMaxMeanTime; + } + double sum = 0.0; + for (const FloatImm& i : a) { + sum += i->value; + } + return sum / a.size(); + } + + bool operator()(const TuningRecord& a, const TuningRecord& b) const { + double a_time = Mean(a->run_secs.value_or({})); + double b_time = Mean(b->run_secs.value_or({})); + return a_time < b_time; + } +}; + +// TODO(tvm-team): Currently, we strictly treat each target separately. +// Since not every option in the target matters, this might be the overkill. +// Revisit this when we have better approach with target equality check. +inline std::string get_database_key(int workload_idx, Target target) { + return std::to_string(workload_idx) + "/" + target->str(); +} + +/*! \brief The default database implementation, which mimics two database tables with two files. + */ +class JSONDatabaseNode : public DatabaseNode { + public: + /*! \brief The path to the workload table */ + String path_workload; + /*! \brief The path to the tuning record table */ + String path_tuning_record; + /*! \brief The path to the measurement table */ + String path_measurement_record; + /*! \brief All the workloads in the database */ + std::unordered_map + workloads2idx_; + /*! \brief All the tuning records in the database */ + std::unordered_map> + tuning_records_; + + /*! 
\brief Measurement logs in the database */ + std::unordered_map> measurement_records_; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("path_workload", &path_workload); + v->Visit("path_tuning_record", &path_tuning_record); + v->Visit("path_measurement_record", &path_measurement_record); + // `workloads2idx_` is not visited + // `tuning_records_` is not visited + // `measurement_records_` is not visited + } + + static constexpr const char* _type_key = "relax.tuning_api.JSONDatabase"; + TVM_DECLARE_FINAL_OBJECT_INFO(JSONDatabaseNode, DatabaseNode); + + public: + bool HasWorkload(const IRModule& mod) { + return workloads2idx_.find(meta_schedule::Workload(mod, tvm::StructuralHash()(mod))) != + workloads2idx_.end(); + } + + bool HasMeasurementRecord(const meta_schedule::Workload& workload, const Target& target) { + int workload_idx = this->workloads2idx_.at(workload); + std::string key = get_database_key(workload_idx, target); + return measurement_records_.count(key) > 0; + } + + bool HasTuningRecord(const meta_schedule::Workload& workload, const Target& target) { + int workload_idx = this->workloads2idx_.at(workload); + std::string key = get_database_key(workload_idx, target); + return tuning_records_.count(key) > 0; + } + + meta_schedule::Workload CommitWorkload(const IRModule& mod) { + // Try to insert `mod` into `workloads_` + decltype(this->workloads2idx_)::iterator it; + bool inserted = false; + std::tie(it, inserted) = + this->workloads2idx_.emplace(meta_schedule::Workload(mod, tvm::StructuralHash()(mod)), -1); + meta_schedule::Workload workload = it->first; + // If `mod` is new in `workloads2idx_`, append it to the workload file + if (inserted) { + it->second = static_cast(this->workloads2idx_.size()) - 1; + meta_schedule::JSONFileAppendLine(this->path_workload, + meta_schedule::JSONDumps(workload->AsJSON())); + } + return it->first; + } + + void CommitMeasurementRecord(const meta_schedule::Workload& workload, const Target& target, + const Array& 
run_secs) { + int workload_idx = this->workloads2idx_.at(workload); + std::string key = get_database_key(workload_idx, target); + + if (measurement_records_[key].size() == 0) { + measurement_records_[key] = run_secs; + meta_schedule::JSONFileAppendLine(this->path_measurement_record, + meta_schedule::JSONDumps(Array{ + Integer(workload_idx), target->Export(), + run_secs // + })); + } else { + LOG(WARNING) << "Measurement record for " << key + << " already exists. Use the existing one instead."; + } + } + + void CommitTuningRecord(const meta_schedule::Workload& workload, const Target& target, + const TuningRecord& record) { + int workload_idx = this->workloads2idx_.at(workload); + // There may exist multiple tuning records (with different traces) for a single key pair. + std::string key = get_database_key(workload_idx, target); + this->tuning_records_[key].insert(record); + + meta_schedule::JSONFileAppendLine( + this->path_tuning_record, meta_schedule::JSONDumps(Array{ + Integer(workload_idx), target->Export(), record->AsJSON()})); + } + + Array GetTopK(const meta_schedule::Workload& workload, const Target& target, + int top_k) { + CHECK_GE(top_k, 0) << "ValueError: top_k must be non-negative"; + if (top_k == 0) { + return {}; + } + Array results; + results.reserve(top_k); + int counter = 0; + int idx = this->workloads2idx_.at(workload); + std::string key = get_database_key(idx, target); + for (const TuningRecord& record : this->tuning_records_[key]) { + results.push_back(record); + if (++counter == top_k) { + break; + } + } + + return results; + } + + Array GetMeasurementRecord(const meta_schedule::Workload& workload, + const Target target) { + int workload_idx = this->workloads2idx_.at(workload); + return this->measurement_records_[get_database_key(workload_idx, target)]; + } +}; + +Database Database::JSONDatabase(String path_workload, String path_tuning_record, + String path_measurement_record, bool allow_missing) { + int num_threads = 
std::thread::hardware_concurrency(); + ObjectPtr n = make_object(); + // Load `n->workloads2idx_` from `path_workload` + std::vector workloads; + { + std::vector json_objs = + meta_schedule::JSONFileReadLines(path_workload, num_threads, allow_missing); + int n_objs = json_objs.size(); + n->workloads2idx_.reserve(n_objs); + workloads.reserve(n_objs); + for (int i = 0; i < n_objs; ++i) { + meta_schedule::Workload workload = meta_schedule::Workload::FromJSON(json_objs[i]); + n->workloads2idx_.emplace(workload, i); + workloads.push_back(workload); + } + } + // Load `n->tuning_records_` from `path_tuning_record` + { + std::vector json_objs = + meta_schedule::JSONFileReadLines(path_tuning_record, num_threads, allow_missing); + + std::vector workload_idxs; + std::vector targets; + std::vector records; + int size = json_objs.size(); + workload_idxs.resize(size, -1); + targets.resize(size, Target{nullptr}); + records.resize(size, TuningRecord{nullptr}); + support::parallel_for_dynamic( + 0, json_objs.size(), num_threads, [&](int thread_id, int task_id) { + const ObjectRef& json_obj = json_objs[task_id]; + try { + const ArrayNode* arr = json_obj.as(); + ICHECK_EQ(arr->size(), 3); + workload_idxs[task_id] = Downcast(arr->at(0)).IntValue(); + targets[task_id] = Target(Downcast>(arr->at(1))); + records[task_id] = TuningRecord::FromJSON(arr->at(2)); + } catch (std::runtime_error& e) { + LOG(FATAL) << "ValueError: Unable to parse the JSON object: " << json_obj + << "\nThe error is: " << e.what(); + } + }); + + for (int i = 0; i < size; i++) { + std::string key = get_database_key(workload_idxs[i], targets[i]); + n->tuning_records_[key].insert(records[i]); + } + } + + // Load `n->measuremet_log` from `path_measurement_record` + { + std::vector json_objs = + meta_schedule::JSONFileReadLines(path_measurement_record, num_threads, allow_missing); + std::vector workload_idxs; + std::vector targets; + std::vector> measurements; + int size = json_objs.size(); + workload_idxs.resize(size, 
-1); + targets.resize(size, Target{nullptr}); + measurements.resize(size, Array({})); + support::parallel_for_dynamic( + 0, json_objs.size(), num_threads, [&](int thread_id, int task_id) { + const ObjectRef& json_obj = json_objs[task_id]; + try { + const ArrayNode* arr = json_obj.as(); + ICHECK_EQ(arr->size(), 3); + workload_idxs[task_id] = Downcast(arr->at(0)).IntValue(); + targets[task_id] = Target(Downcast>(arr->at(1))); + measurements[task_id] = meta_schedule::AsFloatArray(arr->at(2)); + } catch (std::runtime_error& e) { + LOG(FATAL) << "ValueError: Unable to parse the JSON object: " << json_obj + << "\nThe error is: " << e.what(); + } + }); + for (int i = 0; i < size; i++) { + n->measurement_records_[get_database_key(workload_idxs[i], targets[i])] = measurements[i]; + } + } + + n->path_workload = path_workload; + n->path_tuning_record = path_tuning_record; + n->path_measurement_record = path_measurement_record; + return Database(n); +} + +/**************** FFI ****************/ +TVM_REGISTER_NODE_TYPE(TuningRecordNode); +TVM_REGISTER_GLOBAL("relax.tuning_api.TuningRecord") + .set_body_typed([](Trace trace, Optional> run_secs) { + return TuningRecord(trace, run_secs); + }); +TVM_REGISTER_GLOBAL("relax.tuning_api.TuningRecordAsJSON") + .set_body_method(&TuningRecordNode::AsJSON); +TVM_REGISTER_GLOBAL("relax.tuning_api.TuningRecordFromJSON").set_body_typed(TuningRecord::FromJSON); + +TVM_REGISTER_OBJECT_TYPE(DatabaseNode); +TVM_REGISTER_GLOBAL("relax.tuning_api.DatabaseHasWorkload") + .set_body_method(&DatabaseNode::HasWorkload); +TVM_REGISTER_GLOBAL("relax.tuning_api.DatabaseHasMeasurementRecord") + .set_body_method(&DatabaseNode::HasMeasurementRecord); +TVM_REGISTER_GLOBAL("relax.tuning_api.DatabaseHasTuningRecord") + .set_body_method(&DatabaseNode::HasTuningRecord); +TVM_REGISTER_GLOBAL("relax.tuning_api.DatabaseCommitMeasurementRecord") + .set_body_method(&DatabaseNode::CommitMeasurementRecord); +TVM_REGISTER_GLOBAL("relax.tuning_api.DatabaseCommitWorkload") 
+ .set_body_method(&DatabaseNode::CommitWorkload); +TVM_REGISTER_GLOBAL("relax.tuning_api.DatabaseCommitTuningRecord") + .set_body_method(&DatabaseNode::CommitTuningRecord); +TVM_REGISTER_GLOBAL("relax.tuning_api.DatabaseGetTopK") + .set_body_method(&DatabaseNode::GetTopK); +TVM_REGISTER_GLOBAL("relax.tuning_api.DatabaseGetMeasurementRecord") + .set_body_method(&DatabaseNode::GetMeasurementRecord); + +TVM_REGISTER_NODE_TYPE(JSONDatabaseNode); +TVM_REGISTER_GLOBAL("relax.tuning_api.DatabaseJSONDatabase").set_body_typed(Database::JSONDatabase); +} // namespace relax +} // namespace tvm diff --git a/src/relax/transform/tuning_api/primitives.cc b/src/relax/transform/tuning_api/primitives.cc new file mode 100644 index 000000000000..ef4a3d41bdf0 --- /dev/null +++ b/src/relax/transform/tuning_api/primitives.cc @@ -0,0 +1,273 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/transform/tuning_api/primitives.cc + * \brief Primitives of tuning APIs. 
+ */ + +#include + +#include "../../../meta_schedule/utils.h" +namespace tvm { +namespace relax { + +Choice::Choice(String transform_func_key, Array transform_func_args, + String constr_func_key, Array constr_func_args) { + ObjectPtr n = make_object(); + n->transform_func_key = std::move(transform_func_key); + n->transform_func_args = std::move(transform_func_args); + n->constr_func_key = std::move(constr_func_key); + n->constr_func_args = std::move(constr_func_args); + data_ = std::move(n); +} + +// TODO(sunggg): Currently, it only supports an array of primitive data types. +ObjectRef ChoiceNode::AsJSON() const { + Array json_transfrom_args, json_constr_args; + for (ObjectRef arg : this->transform_func_args) { + std::string json_arg = tvm::SaveJSON(arg); + std::string b64_arg = meta_schedule::Base64Encode(json_arg); + json_transfrom_args.push_back(String(b64_arg)); + } + for (ObjectRef arg : this->constr_func_args) { + std::string json_arg = tvm::SaveJSON(arg); + std::string b64_arg = meta_schedule::Base64Encode(json_arg); + json_constr_args.push_back(String(b64_arg)); + } + return Array{ + this->transform_func_key, + json_transfrom_args, + this->constr_func_key, + json_constr_args, + }; +} + +Choice Choice::FromJSON(const ObjectRef& json) { + // Parse `json` into `choice` + String transform_func_key, constr_func_key; + Array transform_func_args, constr_func_args; + try { + const ArrayNode* arr = json.as(); + ICHECK(arr && arr->size() == 4); + const auto* arr0 = arr->at(0).as(); + const auto* arr1 = arr->at(1).as(); + const auto* arr2 = arr->at(2).as(); + const auto* arr3 = arr->at(3).as(); + ICHECK(arr0 && arr1 && arr2 && arr3); + transform_func_key = GetRef(arr0); + { + transform_func_args.reserve(arr1->size()); + for (const ObjectRef& elem : *arr1) { + String b64_arg = Downcast(elem); + std::string json_arg = meta_schedule::Base64Decode(b64_arg); + ObjectRef arg = LoadJSON(json_arg); + transform_func_args.push_back(arg); + } + } + constr_func_key = 
GetRef(arr2); + { + constr_func_args.reserve(arr3->size()); + for (const ObjectRef& elem : *arr3) { + String b64_arg = Downcast(elem); + std::string json_arg = meta_schedule::Base64Decode(b64_arg); + ObjectRef arg = LoadJSON(json_arg); + constr_func_args.push_back(arg); + } + } + } catch (const tvm::Error& e) { + LOG(FATAL) + << "ValueError: The json entry of a choice should contain a set of two strings, but gets: " + << json; + throw; + } + return Choice(transform_func_key, transform_func_args, constr_func_key, constr_func_args); +} + +Knob::Knob(String name, Map choices) { + ObjectPtr n = make_object(); + n->name = std::move(name); + n->choices = std::move(choices); + data_ = std::move(n); +} + +ObjectRef KnobNode::AsJSON() const { + Map json_choices; + for (auto const& x : choices) { + json_choices.Set(x.first, x.second->AsJSON()); + } + return Array{ + /* 0: name */ std::move(name), + /* 1: choices */ std::move(json_choices), + }; +} + +Knob Knob::FromJSON(const ObjectRef& json) { + // Parse `json` into `name` and `choices` + String name; + Map choices; + try { + const ArrayNode* arr = json.as(); + ICHECK(arr && arr->size() == 2); + const auto* arr0 = arr->at(0).as(); + const auto* arr1 = arr->at(1).as(); + ICHECK(arr0 && arr1); + name = GetRef(arr0); + for (auto const& x : GetRef>(arr1)) { + String decision = x.first; + Choice choice = Choice::FromJSON(x.second); + choices.Set(decision, choice); + } + } catch (const tvm::Error& e) { + LOG(FATAL) + << "ValueError: The json entry of a choice should contain a set of two strings, but gets: " + << json; + throw; + } + return Knob(name, choices); +} + +Trace::Trace() { data_ = make_object(); } + +Trace::Trace(IRModule in_mod, Array knobs, Array decisions) { + ICHECK(knobs.size() == decisions.size()) << "Size of knobs and decisions should match"; + // Deep-copy IRModule + auto func_deepcopy = runtime::Registry::Get("relax.tuning_api.deepcopy_irmodule"); + ICHECK(func_deepcopy); + IRModule out_mod = 
(*func_deepcopy)(in_mod); + // Apply the decision history if provided + int size = knobs.size(); + for (int i = 0; i < size; i++) { + out_mod = knobs[i]->Apply(out_mod, decisions[i]); + } + + ObjectPtr n = make_object(); + n->in_mod = std::move(in_mod); + n->out_mod = std::move(out_mod); + n->knobs = std::move(knobs); + n->decisions = std::move(decisions); + n->size = std::move(size); + data_ = std::move(n); +} + +ObjectRef TraceNode::AsJSON(bool include_in_mod) const { + ICHECK(this->Verify()) << "Trace should be valid"; + + Array json_knobs; + Array json_decisions; + + int size = this->size; + json_knobs.reserve(size); + json_decisions.reserve(size); + + for (int i = 0; i < size; i++) { + const Knob& knob = this->knobs[i]; + const String& decision = this->decisions[i]; + + json_knobs.push_back(knob->AsJSON()); + json_decisions.push_back(decision); + } + if (include_in_mod) { + std::string json_mod = tvm::SaveJSON(this->in_mod); + std::string b64_mod = meta_schedule::Base64Encode(json_mod); + return Array{json_knobs, json_decisions, String(b64_mod)}; + } else { + return Array{json_knobs, json_decisions}; + } +} + +Trace Trace::FromJSON(const ObjectRef& json) { + // Parse `json` into `trace` + IRModule in_mod; + Array knobs; + Array decisions; + try { + const ArrayNode* arr = json.as(); + // A trace will have 2 or 3 entries depending on `include_irmod` parameter. 
+ ICHECK(arr && (arr->size() == 2 || arr->size() == 3)); + + const auto* arr0 = arr->at(0).as(); + const auto* arr1 = arr->at(1).as(); + ICHECK(arr0 && arr1); + + for (const ObjectRef& elem : *arr0) { + knobs.push_back(Knob::FromJSON(elem)); + } + + for (const ObjectRef& elem : *arr1) { + decisions.push_back(Downcast(elem)); + } + + // When `include_irmod = true` + if (arr->size() == 3) { + const auto* arr2 = arr->at(2).as(); + String b64_mod = GetRef(arr2); + ICHECK(arr2); + std::string json_mod = meta_schedule::Base64Decode(b64_mod); + in_mod = Downcast(LoadJSON(json_mod)); + } + } catch (const tvm::Error& e) { + LOG(FATAL) << "ValueError: Malformed Trace format - " << json; + throw; + } + return Trace(in_mod, knobs, decisions); +} + +/**************** FFI ****************/ +TVM_REGISTER_NODE_TYPE(ChoiceNode); +TVM_REGISTER_GLOBAL("relax.tuning_api.Choice") + .set_body_typed([](String transform_func_key, Array transform_func_args, + String constr_func_key, Array constr_func_args) { + return Choice(transform_func_key, transform_func_args, constr_func_key, constr_func_args); + }); +TVM_REGISTER_GLOBAL("relax.tuning_api.ChoiceAsJSON").set_body_method(&ChoiceNode::AsJSON); +TVM_REGISTER_GLOBAL("relax.tuning_api.ChoiceFromJSON").set_body_typed(Choice::FromJSON); +TVM_REGISTER_GLOBAL("relax.tuning_api.ChoiceGetTransformFunc") + .set_body_method(&ChoiceNode::GetTransformFunc); +TVM_REGISTER_GLOBAL("relax.tuning_api.ChoiceGetConstrFunc") + .set_body_method(&ChoiceNode::GetConstrFunc); +TVM_REGISTER_GLOBAL("relax.tuning_api.ChoiceApplyTransformFunc") + .set_body_method(&ChoiceNode::ApplyTransformFunc); +TVM_REGISTER_GLOBAL("relax.tuning_api.ChoiceCheckConstr") + .set_body_method(&ChoiceNode::CheckConstr); + +TVM_REGISTER_NODE_TYPE(KnobNode); +TVM_REGISTER_GLOBAL("relax.tuning_api.Knob") + .set_body_typed([](String name, Map choices) { return Knob(name, choices); }); +TVM_REGISTER_GLOBAL("relax.tuning_api.KnobAsJSON").set_body_method(&KnobNode::AsJSON); 
+TVM_REGISTER_GLOBAL("relax.tuning_api.KnobFromJSON").set_body_typed(Knob::FromJSON); +TVM_REGISTER_GLOBAL("relax.tuning_api.KnobIsValidDecision") + .set_body_method(&KnobNode::IsValidDecision); +TVM_REGISTER_GLOBAL("relax.tuning_api.KnobApply").set_body_method(&KnobNode::Apply); + +TVM_REGISTER_NODE_TYPE(TraceNode); +TVM_REGISTER_GLOBAL("relax.tuning_api.Trace") + .set_body_typed([](IRModule in_mod, Array knobs, Array decisions) { + return Trace(in_mod, knobs, decisions); + }); +TVM_REGISTER_GLOBAL("relax.tuning_api.TraceVerify").set_body_method(&TraceNode::Verify); +TVM_REGISTER_GLOBAL("relax.tuning_api.TraceAdd").set_body_method(&TraceNode::Add); +TVM_REGISTER_GLOBAL("relax.tuning_api.TraceSetPerf").set_body_method(&TraceNode::SetPerf); +TVM_REGISTER_GLOBAL("relax.tuning_api.TraceSetOutMod") + .set_body_method(&TraceNode::SetOutMod); + +TVM_REGISTER_GLOBAL("relax.tuning_api.TraceAsJSON").set_body_method(&TraceNode::AsJSON); +TVM_REGISTER_GLOBAL("relax.tuning_api.TraceFromJSON").set_body_typed(Trace::FromJSON); +} // namespace relax +} // namespace tvm diff --git a/src/relay/ir/transform.cc b/src/relay/ir/transform.cc index fc1f3a15077e..dd31a1f7367d 100644 --- a/src/relay/ir/transform.cc +++ b/src/relay/ir/transform.cc @@ -154,8 +154,8 @@ IRModule FunctionPassNode::operator()(IRModule mod, const PassContext& pass_ctx) Pass CreateFunctionPass( const runtime::TypedPackedFunc& pass_func, - int opt_level, String name, tvm::Array required) { - PassInfo pass_info = PassInfo(opt_level, name, required); + int opt_level, String name, tvm::Array required, bool traceable) { + PassInfo pass_info = PassInfo(opt_level, name, required, traceable); return FunctionPass(pass_func, pass_info); } diff --git a/src/relay/transforms/type_infer.cc b/src/relay/transforms/type_infer.cc index d2eb48073f7d..a152bbe9c3cb 100644 --- a/src/relay/transforms/type_infer.cc +++ b/src/relay/transforms/type_infer.cc @@ -950,7 +950,7 @@ 
TVM_REGISTER_GLOBAL("relay._transform.InferTypeLocal").set_body_typed([](const E }); Pass InferType() { - auto pass_info = PassInfo(0, "InferType", {}); + auto pass_info = PassInfo(0, "InferType", {}, /* trace */ false); return tvm::transform::CreateModulePass( [=](IRModule mod, const PassContext& pass_ctx) { // Execute the pass function and return a new module. diff --git a/src/tir/ir/transform.cc b/src/tir/ir/transform.cc index 4c59a1767372..781a0ecd7c3d 100644 --- a/src/tir/ir/transform.cc +++ b/src/tir/ir/transform.cc @@ -115,8 +115,8 @@ IRModule PrimFuncPassNode::operator()(IRModule mod, const PassContext& pass_ctx) Pass CreatePrimFuncPass( const runtime::TypedPackedFunc& pass_func, - int opt_level, String name, tvm::Array required) { - PassInfo pass_info = PassInfo(opt_level, name, required); + int opt_level, String name, tvm::Array required, bool traceable) { + PassInfo pass_info = PassInfo(opt_level, name, required, traceable); return PrimFuncPass(pass_func, pass_info); } diff --git a/tests/python/relax/test_transform_meta_schedule_tuning.py b/tests/python/relax/test_transform_meta_schedule_tuning.py new file mode 100644 index 000000000000..ff695b9436a3 --- /dev/null +++ b/tests/python/relax/test_transform_meta_schedule_tuning.py @@ -0,0 +1,115 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +import tempfile + +import tvm +import tvm.testing +import tvm.meta_schedule as ms +from tvm import relax +from tvm.ir import transform +from tvm.ir.module import IRModule +from tvm.ir.transform import PassContext +from tvm.relax.transform.tuning_api import Trace +from tvm.script import relax as R +from tvm.script import tir as T + +target = tvm.target.Target("llvm --num-cores=16") + + +@tvm.script.ir_module +class InputModule: + @T.prim_func + def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None: + T.func_attr({"global_symbol": "tir_matmul"}) + k = T.var("int32") + A = T.match_buffer(x, (32, 32)) + B = T.match_buffer(y, (32, 32)) + C = T.match_buffer(z, (32, 32)) + + for (i0, j0, k0) in T.grid(32, 32, 32): + with T.block(): + i, j, k = T.axis.remap("SSR", [i0, j0, k0]) + with T.init(): + C[i, j] = 0.0 + C[i, j] += A[i, k] * B[j, k] + + @T.prim_func + def tir_relu(x: T.handle, y: T.handle): + T.func_attr({"global_symbol": "tir_relu"}) + A = T.match_buffer(x, (32, 32)) + B = T.match_buffer(y, (32, 32)) + for (i, j) in T.grid(32, 32): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = T.max(A[vi, vj], 0.0) + + @R.function + def main(x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32")) -> R.Tensor: + with R.dataflow(): + lv0 = R.call_tir(tir_matmul, (x, w), R.Tensor((32, 32), dtype="float32")) + lv1 = R.call_tir(tir_relu, (lv0), R.Tensor((32, 32), dtype="float32")) + R.output(lv1) + return lv1 + + +# TODO(@sunggg): determine how to pass MS database object across different passes. +# PassContext might be an option, but we already have TuningAPI database. +# (MS database and TuningAPI database will be unified in the future) +# For now, we only support default JSON database config. 
+def test_ms_tuning_irmodule(): + + mod = InputModule + assert isinstance(mod, IRModule) + + with tempfile.TemporaryDirectory() as work_dir: + with target, transform.PassContext(trace=Trace(mod), opt_level=0): + tuning_pass = relax.transform.MetaScheduleTuneIRMod( + params={}, work_dir=work_dir, max_trials_global=4 + ) + out_mod = tuning_pass(mod) + assert PassContext.current().get_trace_stack_size() == 1 + assert PassContext.current().get_current_trace().size == 1 + tvm.ir.assert_structural_equal(mod, out_mod) + + application_pass = relax.transform.MetaScheduleApplyDatabase(work_dir) + + out_mod = application_pass(mod) + assert not tvm.ir.structural_equal(mod, out_mod) + + +def test_ms_tuning_primfunc(): + mod = InputModule + assert isinstance(mod, IRModule) + with tempfile.TemporaryDirectory() as work_dir: + with target, transform.PassContext(trace=Trace(mod), opt_level=0): + tuning_pass = relax.transform.MetaScheduleTuneTIR( + work_dir=work_dir, max_trials_global=4 + ) + out_mod = tuning_pass(mod) + assert PassContext.current().get_trace_stack_size() == 1 + # TODO (@sunggg): Need to determine how to track subgraph-level tuning traces. + # Currently, we don't track this so the trace size. Revisit this later. + tvm.ir.assert_structural_equal(mod, out_mod) + + application_pass = relax.transform.MetaScheduleApplyDatabase(work_dir) + out_mod = application_pass(mod) + assert not tvm.ir.structural_equal(mod, out_mod) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_tuning_api.py b/tests/python/relax/test_tuning_api.py new file mode 100644 index 000000000000..b12ff016705d --- /dev/null +++ b/tests/python/relax/test_tuning_api.py @@ -0,0 +1,781 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
# The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Unit tests for the Relax tuning API primitives (Choice/Knob/Trace) and database."""

import pytest
import numpy as np
import os.path as osp
import tempfile
from typing import List
from math import isclose

import tvm
from tvm import ir
from tvm.ir import transform
from tvm.ir.transform import PassContext
from tvm.ir.module import IRModule
from tvm.script import tir as T, relax as R
from tvm import relax
from tvm.relax.expr import Expr, DataflowBlock, Function
from tvm.relax.transform.tuning_api import (
    Choice,
    Knob,
    Trace,
    TuningRecord,
    JSONDatabase,
    default_generate_candidate,
    default_consider_eval_passes,
    default_evaluate,
    select_best_candidate,
    get_trace,
)


# Shared fixture module: `before` is the input function, `expected` is the
# function we expect after constant folding (addone folded into the bound constant).
@tvm.script.ir_module
class TestModule:
    @T.prim_func
    def addone(A: T.Buffer[(16, 16), "int32"], B: T.Buffer[(16, 16), "int32"]) -> None:
        # Elementwise B = A + 1.
        T.func_attr(({"global_symbol": "addone"}))
        for i, j in T.grid(16, 16):
            with T.block("addone"):
                vi, vj = T.axis.remap("SS", [i, j])
                B[vi, vj] = A[vi, vj] + T.int32(1)

    # Input IRModule.
    @R.function
    def before(c0: R.Tensor((16, 16), "int32")):
        lv0 = R.call_tir(addone, (c0,), R.Tensor((16, 16), dtype="int32"))
        return lv0

    # Expected IRModule after transformation.
    @R.function
    def expected(c1: R.Tensor((16, 16), "int32")):
        # NOTE(review): `lv0` is bound but unused — presumably kept to mirror the
        # binding structure of `before`; confirm before removing.
        lv0 = c1
        return c1


def gen_mod(mod, name, binding):
    """Select the Relax function `name`, rename it to `main`, and bind `binding` constants."""
    funcs = {}
    binding = {k: tvm.nd.array(v) for k, v in binding.items()}

    for k, v in mod.functions.items():
        if isinstance(v, tvm.relax.Function):
            if k.name_hint == name:
                # rename to main.
                gv = tvm.ir.GlobalVar("main")
                funcs[gv] = tvm.relax.Function(v.params, v.body, v.ret_struct_info).with_attr(
                    "global_symbol", "main"
                )
        else:
            # Non-Relax functions (e.g., PrimFuncs) are kept as-is.
            funcs[k] = v
    mod = tvm.IRModule(funcs)
    return relax.transform.BindParams("main", binding)(mod)


# Setup for simple testing with IRModule.
def setup_test():
    mod = TestModule
    assert isinstance(mod, tvm.IRModule)
    return gen_mod(mod, "before", {})


# Setup for testing with constant folding.
def setup_test_const_folding():
    mod = TestModule
    assert isinstance(mod, tvm.IRModule)
    # Test setup.
    c0_np = np.arange((16 * 16)).astype("int32").reshape(16, 16)
    c1_np = c0_np + 1
    before = gen_mod(mod, "before", {"c0": c0_np})
    expected = gen_mod(mod, "expected", {"c1": c1_np})

    return before, expected


# Define a choice by using FoldConstant pass.
@tvm.register_func("testing.apply_fold_constant")
def apply_fold_constant(mod):
    return relax.transform.FoldConstant()(mod)


@tvm.register_func("testing.add_global_symbol")
def add_global_symbol(mod, func_name, global_symbol):
    # Attach `global_symbol` attr to the named function; used as a visible marker
    # that a transformation ran.
    mod[func_name] = mod[func_name].with_attr("global_symbol", global_symbol)
    return mod


@tvm.register_func("testing.check_num_functions")
def check_num_funcs(mod, N):
    # Explicit type specification is necessary.
    # Otherwise, PackedFunc cannot derive the return type correctly.
    # e.g., Check failed: type_code_ == kDLInt (8 vs. 0) : expected int but got Object
    return bool(len(mod.functions) == N)


def test_choice():
    """A Choice pairs a transformation function with an optional constraint function."""
    # Test setup.
    (
        before,
        expected,
    ) = setup_test_const_folding()

    # Without any argument, the default settings are used for both the transformation
    # and constraint functions.
    # The default transformation function returns the original IRModule without any change.
    choice = Choice(
        # - transform_func_key="relax.tuning_api.Choice.default_transform_func"
        # - constr_func_key="relax.tuning_api.Choice.default_constr_func")
    )
    # Load transformation function from the choice and apply it.
    after = choice.apply_transform_func(before)
    tvm.ir.assert_structural_equal(after, before)

    choice = Choice("testing.apply_fold_constant")
    # Load transformation function from the choice and apply it.
    after = choice.apply_transform_func(before)
    tvm.ir.assert_structural_equal(after, expected)

    # Create a choice that tags a global symbol onto the target function.
    choice = Choice("testing.add_global_symbol", ["addone", "test-symbol"])
    after = choice.apply_transform_func(before)
    assert after["addone"].attrs["global_symbol"] == "test-symbol"
    # The transformation should be applied with Copy-On-Write.
    # So, the original module should be unchanged.
    assert before["addone"].attrs["global_symbol"] == "addone"

    # Test choice with an impossible constraint.
    choice = Choice(
        transform_func_key="testing.add_global_symbol",
        transform_func_args=["addone", "test-symbol"],
        constr_func_key="testing.check_num_functions",
        constr_func_args=[1000],
    )
    # Since the constraint is not met, it should return the original function.
    after = choice.apply_transform_func(before)
    assert after["addone"].attrs["global_symbol"] == "addone"

    # Test choice with a satisfiable constraint (the module has exactly 2 functions).
    choice = Choice(
        transform_func_key="testing.add_global_symbol",
        transform_func_args=["addone", "test-symbol"],
        constr_func_key="testing.check_num_functions",
        constr_func_args=[2],
    )
    # Since the constraint is met, the transformation should be applied.
    after = choice.apply_transform_func(before)
    assert after["addone"].attrs["global_symbol"] == "test-symbol"
    # The original module should be unchanged.
    assert before["addone"].attrs["global_symbol"] == "addone"

    # Test roundtrip.
    # Export as JSON.
    json_obj = choice.as_json()
    # Import JSON.
    new_choice = Choice.from_json(json_obj)
    # Test the imported choice.
    after = new_choice.apply_transform_func(before)
    assert after["addone"].attrs["global_symbol"] == "test-symbol"
    # The original module should be unchanged.
    assert before["addone"].attrs["global_symbol"] == "addone"


def test_knob():
    """A Knob names a set of Choices; `apply` runs the choice selected by a decision."""
    # Test setup.
    before, expected = setup_test_const_folding()

    # Users can define a set of choices with a list (decisions are indices).
    choices = [
        Choice("testing.apply_fold_constant"),
        Choice(),
    ]

    # Define knob.
    knob = Knob("TestKnob", choices)
    # Check the sanity of decision space.
    assert knob.verify(0)
    assert knob.verify(1)
    assert not knob.verify(3)

    # Check the sanity of each decision.
    after_apply = knob.apply(before, 0)
    after_noapply = knob.apply(before, 1)

    tvm.ir.assert_structural_equal(after_apply, expected)
    tvm.ir.assert_structural_equal(after_noapply, before)

    # Users can define a set of choices with a dict (decisions are keys).
    choices = {
        "apply": Choice("testing.apply_fold_constant"),
        "noapply": Choice(),
        "apply_with_impossible_constr": Choice(
            transform_func_key="testing.apply_fold_constant",
            constr_func_key="testing.check_num_functions",
            constr_func_args=[1000],
        ),
    }
    # Define knob.
    knob = Knob("TestKnob", choices)
    assert knob.verify("apply")
    assert knob.verify("noapply")
    assert knob.verify("apply_with_impossible_constr")
    # "INVLAID" is (deliberately) not a registered decision key.
    assert not knob.verify("INVLAID")

    after_apply = knob.apply(before, "apply")
    after_noapply = knob.apply(before, "noapply")
    # Because the constraint was not satisfied, it will return the original IRModule.
    after_apply_with_constr = knob.apply(before, "apply_with_impossible_constr")
    tvm.ir.assert_structural_equal(after_apply, expected)
    tvm.ir.assert_structural_equal(after_noapply, before)
    tvm.ir.assert_structural_equal(after_apply_with_constr, before)

    # Test roundtrip.
    # Export as JSON.
    json_obj = knob.as_json()
    # Import JSON.
    new_knob = Knob.from_json(json_obj)
    assert new_knob.name == knob.name
    # Test the imported knob.
    assert new_knob.verify("apply")
    assert new_knob.verify("noapply")
    assert new_knob.verify("apply_with_impossible_constr")
    assert not new_knob.verify("INVLAID")

    after_apply = new_knob.apply(before, "apply")
    after_noapply = new_knob.apply(before, "noapply")
    # Because the constraint was not satisfied, it will return the original IRModule.
    after_apply_with_constr = knob.apply(before, "apply_with_impossible_constr")
    tvm.ir.assert_structural_equal(after_apply, expected)
    tvm.ir.assert_structural_equal(after_noapply, before)
    tvm.ir.assert_structural_equal(after_apply_with_constr, before)


def test_trace():
    """A Trace records a sequence of (knob, decision) pairs applied to an input module."""
    before, expected = setup_test_const_folding()

    # Define choices and a knob over them.
    choices = {
        "apply": Choice(
            transform_func_key="testing.apply_fold_constant",
            transform_func_args=[],
            constr_func_key="testing.check_num_functions",
            constr_func_args=[2],
        ),
        "noapply": Choice(),
    }
    knob = Knob("TestKnob", choices)

    # Define a Trace with an empty decision (transformation) history.
    trace = Trace(before)
    assert trace.size == 0

    # Define a Trace with a single decision (transformation) history.
    trace = Trace(before, [knob], ["noapply"])
    assert trace.size == 1
    tvm.ir.assert_structural_equal(trace.in_mod, before)
    tvm.ir.assert_structural_equal(trace.out_mod, before)

    # Add a new knob and its decision to the trace.
    # It will update the current trace and return its new output IRModule.
    out: IRModule = trace.add(knob, "noapply")
    assert trace.size == 2
    tvm.ir.assert_structural_equal(trace.in_mod, before)
    tvm.ir.assert_structural_equal(trace.out_mod, before)
    tvm.ir.assert_structural_equal(out, before)
    # Assume we assign an arbitrary performance number.
    trace.set_perf(100)
    assert trace.perf == 100

    # Add a new knob and its decision to the trace.
    out: IRModule = trace.add(knob, "apply")
    tvm.ir.assert_structural_equal(trace.in_mod, before)
    tvm.ir.assert_structural_equal(trace.out_mod, expected)
    tvm.ir.assert_structural_equal(out, expected)

    assert trace.size == 3
    # Perf should be re-initialized (to -1) when a new knob is applied.
    assert trace.perf == -1

    # Test roundtrip.
    # Export as JSON.
    json_obj = trace.as_json()
    # Import JSON.
    new_trace = Trace.from_json(json_obj)
    tvm.ir.assert_structural_equal(trace.in_mod, new_trace.in_mod)
    assert str(trace) == str(new_trace)
    assert new_trace.size == 3
    tvm.ir.assert_structural_equal(trace.out_mod, new_trace.out_mod)


def test_trace_wrapper():
    """`get_trace` should wrap an IRModule, a Relax function, or a PrimFunc in a Trace."""
    mod = setup_test()
    assert isinstance(mod, tvm.IRModule)
    assert isinstance(Trace(mod), Trace)
    assert isinstance(get_trace(mod), Trace)
    assert isinstance(get_trace(mod["main"]), Trace)
    assert isinstance(get_trace(mod["addone"]), Trace)


def create_tmp_database(tmpdir: str) -> JSONDatabase:
    """Create a JSONDatabase backed by files under `tmpdir`."""
    path_workload = osp.join(tmpdir, "workloads.json")
    path_tuning_record = osp.join(tmpdir, "tuning_records.json")
    path_measurement_record = osp.join(tmpdir, "measurement_records.json")
    return JSONDatabase(path_workload, path_tuning_record, path_measurement_record)


def test_database():
    """Roundtrip tuning/measurement records through the JSON database."""

    def equal_measurement_record(a: List[float], b: List[float]):
        assert len(a) == len(b)
        for i in range(len(a)):
            assert isclose(a[i], b[i], rel_tol=1e-5)

    def equal_tuning_record(a: TuningRecord, b: TuningRecord):
        assert str(a.trace) == str(b.trace)
        equal_measurement_record(a.run_secs, b.run_secs)

    # Test setup.
    (
        mod1,
        mod2,
    ) = setup_test_const_folding()
    knob = Knob("test", {"noapply": Choice()})
    trace = Trace(mod1, [knob, knob], ["noapply", "noapply"])
    target = tvm.target.Target("llvm")

    # Test roundtrip.
    run_secs = [1.0, 0.9, 0.4]
    tuning_record = TuningRecord(
        trace,
        run_secs,
    )
    new_tuning_record = TuningRecord.from_json(json_obj=tuning_record.as_json())
    equal_tuning_record(tuning_record, new_tuning_record)

    with tempfile.TemporaryDirectory() as tmpdir:
        database = create_tmp_database(tmpdir)
        workload1 = database.commit_workload(mod1)

        database.commit_measurement_record(workload1, target, run_secs)
        new_run_secs1 = database.get_measurement_record(workload1, target)
        equal_measurement_record(run_secs, new_run_secs1)
        # A workload that was never measured yields an empty record list.
        workload2 = database.commit_workload(mod2)
        new_run_secs2 = database.get_measurement_record(workload2, target)
        assert len(new_run_secs2) == 0

        database.commit_tuning_record(workload1, target, tuning_record)
        new_tuning_records = database.get_top_k(workload1, target, top_k=1)
        assert len(new_tuning_records) == 1
        equal_tuning_record(tuning_record, new_tuning_records[0])
        new_tuning_records = database.get_top_k(workload1, target, top_k=0)
        assert len(new_tuning_records) == 0


def test_default_functions():
    """Exercise the default candidate-generation / evaluation helpers."""
    mod = setup_test()
    assert isinstance(mod, tvm.IRModule)

    # Define choice, knob, trace.
    choices = {"apply": Choice("testing.apply_fold_constant"), "noapply": Choice()}
    knob = Knob("TestKnob", choices)
    trace = Trace(mod)

    # Launch a pass pipeline in trace mode.
    with tempfile.TemporaryDirectory() as tmpdir:
        database = create_tmp_database(tmpdir)
        with transform.PassContext(trace=trace, tuning_api_database=database):
            # Default generation function expands every valid choice.
            candidates = default_generate_candidate([knob], trace)
            assert len(candidates) == 2

            # Default evaluate function uses MetaSchedule builder/runner.
            # Since builder/runner are not provided, local builder/runner will be used.
            default_evaluate(candidates, "llvm --num-cores=16")
            assert PassContext.current().num_evals == 2

            # Because these candidates are already evaluated, num_evals stays the same.
            default_evaluate(candidates, "llvm --num-cores=16")
            assert PassContext.current().num_evals == 2

            # Test with multiple knobs: the decision space is the cross product.
            candidates = default_generate_candidate([knob, knob], trace)
            assert len(candidates) == 4

        # Launch a new pass pipeline in trace mode.
        with transform.PassContext(trace=trace, tuning_api_database=database):
            candidates = default_generate_candidate([knob], trace)
            assert len(candidates) == 2
            # Provide a tuning pass as an eval pass.
            # Note that MockConstFoldingTuningPass() has its own generation function and
            # evaluation function. Evaluation is done in a tournament fashion.
            # `default_consider_eval_passes` will convert candidates into their best
            # versions by considering eval_passes.
            # For example, if we say candidates = [C1, C2],
            # `default_consider_eval_passes` will return the best form of the C1 variants
            # (C11 vs C12) and of the C2 variants (C21 vs C22) that can be generated by
            # eval_passes.
            # Assume C11 > C12, C21 < C22,
            # new_candidates = [C11, C22]
            new_candidates = default_consider_eval_passes(
                candidates, [MockConstFoldingTuningPass(eval_passes=[])]
            )

            # len(candidates) == len(new candidates).
            assert len(new_candidates) == 2
            # To find the best version of each candidate, it would take 4 evals
            # (C11, C12, C21, C22).
            assert PassContext.current().num_evals == 4

        HeuristicPass = relax.transform.FoldConstant
        with transform.PassContext(trace=trace, tuning_api_database=database):
            candidates = default_generate_candidate([knob], trace)
            assert len(candidates) == 2
            # Provide a heuristic pass as an eval pass.
            new_candidates = default_consider_eval_passes(candidates, [HeuristicPass()])
            # Since a heuristic pass has a single decision, it won't need any tournament.
            # new_candidates = [C11, C21]
            assert len(new_candidates) == 2
            # We only conduct evaluation when it's necessary
            # (e.g., choosing the better candidate in a tuning pass).
            # A heuristic pass won't conduct any evaluation.
            assert PassContext.current().num_evals == 0


# TODO(sunggg): Do we need to serialize pass context as well?
def test_pass_context():
    """Trace bookkeeping in PassContext: implicit traceability and the trace stack."""
    before, expected = setup_test_const_folding()
    HeuristicPass = relax.transform.FoldConstant
    # FoldConstant implicitly performs TIR passes (probably for constant evaluation).
    # If make_traceable is not provided, the pass infra will make every non-traceable
    # pass traceable by default.
    seq = transform.Sequential([HeuristicPass()])
    with transform.PassContext(
        trace=Trace(before),
    ):
        after = seq(before)
        tvm.ir.assert_structural_equal(after, expected)
        assert PassContext.current().get_trace_stack_size() == 1
        # The exact number of implicit passes might change as TVM develops more passes.
        # As of today, this size returns 57.
        assert PassContext.current().get_current_trace().size > 1

    # We can explicitly specify which pass we want to keep track of.
    with transform.PassContext(trace=Trace(before), make_traceable=["FoldConstant"]):
        after = seq(before)
        tvm.ir.assert_structural_equal(after, expected)
        assert PassContext.current().get_trace_stack_size() == 1
        assert PassContext.current().get_current_trace().size == 1

    # Check the functionality of the trace stack.
    with transform.PassContext(trace=Trace(before)):
        assert PassContext.current().get_trace_stack_size() == 1
        PassContext.current().push_trace(Trace(before))
        assert PassContext.current().get_trace_stack_size() == 2
        PassContext.current().pop_trace()
        assert PassContext.current().get_trace_stack_size() == 1
        PassContext.current().pop_trace()
        assert PassContext.current().get_trace_stack_size() == 0


# Mock evaluation pass for testing.
# Assigns an arbitrary performance number to each candidate.
def mock_evaluate(candidates: List[Trace], target_str: str, ctx: PassContext):
    """Assign a deterministic fake perf number to each not-yet-evaluated candidate."""
    num_evals = 0
    # Evaluation
    for candidate in candidates:
        # If this candidate is already evaluated, skip the measurement.
        if candidate.perf != -1:
            continue

        num_evals += 1
        # Assign arbitrary performance (monotonically decreasing so it is deterministic).
        mock_perf = 100 - (ctx.num_evals + num_evals)
        candidate.set_perf(mock_perf)
    # Update number of evals for testing.
    ctx.inc_num_evals(num_evals)


# Mock tuning pass that determines whether to apply relax.transform.FoldConstant().
# Each pass invocation will generate two candidates for the incoming IRModule.
# In the Relax pass infra, each pass defines its own way of generating candidates and
# evaluating them without needing to know how other passes generate and evaluate theirs.
# This will significantly alleviate the development process since it is known to be a
# HARD problem to consider the interaction with (potentially hundreds of) other passes.
@ir.transform.module_pass(opt_level=0, traceable=True)
class MockConstFoldingTuningPass(transform.Pass):
    def __init__(
        self,
        f_generate_candidate=None,
        f_evaluate=mock_evaluate,
        eval_passes: List[transform.Pass] = None,
        # NOTE(review): mutable default argument ([]) is shared across instances —
        # harmless here since it is never mutated, but worth confirming upstream.
        required: List[transform.Pass] = [],
    ):
        self.f_generate_candidate = (
            f_generate_candidate if f_generate_candidate else default_generate_candidate
        )
        self.f_evaluate = f_evaluate if f_evaluate else default_evaluate
        self.eval_passes = eval_passes
        self.required = required

    def transform_module(self, mod: IRModule, ctx: PassContext) -> IRModule:
        """Generate two candidates (fold / no fold), pick the best, push its trace."""
        trace = ctx.pop_trace()

        # Create mock choices for testing.
        choices = {"apply": Choice("testing.apply_fold_constant"), "noapply": Choice()}
        # Tuning pass manages a set of transformation functions registered via knob.
        knob = Knob("MockTuningKnob", choices)

        candidates = self.f_generate_candidate([knob], trace, self.eval_passes)
        self.f_evaluate(candidates, "llvm", ctx)
        best_trace = select_best_candidate(candidates)

        ctx.push_trace(best_trace)
        return best_trace.out_mod


def test_module_pass():
    """Search-space size of a tuning module pass, alone and nested/joint with others."""
    mod = setup_test()
    assert isinstance(mod, tvm.IRModule)
    # Test setup.
    c0 = np.arange((16 * 16)).astype("int32").reshape(16, 16)
    mod = relax.transform.BindParams("main", {"c0": tvm.nd.array(c0)})(mod)
    HeuristicPass = relax.transform.FoldConstant

    # Tuning pass without any eval_pass.
    mock_pass = MockConstFoldingTuningPass(eval_passes=[])
    with transform.PassContext(trace=Trace(mod), make_traceable=["FoldConstant"]):
        _ = mock_pass(mod)
        assert PassContext.current().num_evals == 2
        assert PassContext.current().get_trace_stack_size() == 1
        assert PassContext.current().get_current_trace().size == 1

    # A heuristic pass should not affect the number of candidates.
    mock_pass = MockConstFoldingTuningPass(eval_passes=[HeuristicPass()])
    with transform.PassContext(trace=Trace(mod), make_traceable=["FoldConstant"]):
        _ = mock_pass(mod)
        assert PassContext.current().num_evals == 2
        assert PassContext.current().get_trace_stack_size() == 1
        assert PassContext.current().get_current_trace().size == 2

    # Joint-optimization will increase the search space in a combinatorial way.
    mock_pass = MockConstFoldingTuningPass(eval_passes=[MockConstFoldingTuningPass(eval_passes=[])])
    with transform.PassContext(trace=Trace(mod), make_traceable=["FoldConstant"]):
        _ = mock_pass(mod)
        assert PassContext.current().num_evals == 2 * 2
        assert PassContext.current().get_trace_stack_size() == 1
        assert PassContext.current().get_current_trace().size == 2

    # Joint-optimization can be nested.
    mock_pass = MockConstFoldingTuningPass(
        eval_passes=[
            MockConstFoldingTuningPass(eval_passes=[MockConstFoldingTuningPass(eval_passes=[])])
        ]
    )
    with transform.PassContext(trace=Trace(mod), make_traceable=["FoldConstant"]):
        _ = mock_pass(mod)
        assert PassContext.current().num_evals == 2 * 2 * 2
        assert PassContext.current().get_trace_stack_size() == 1
        assert PassContext.current().get_current_trace().size == 3

    # Tuning passes and heuristic passes can be used together.
    # Note that a heuristic pass won't increase the search space (num_evals).
    # It only increases the length of the trace.
    mock_pass = MockConstFoldingTuningPass(
        eval_passes=[
            HeuristicPass(),
            MockConstFoldingTuningPass(
                eval_passes=[
                    MockConstFoldingTuningPass(eval_passes=[HeuristicPass(), HeuristicPass()])
                ]
            ),
        ]
    )
    with transform.PassContext(trace=Trace(mod), make_traceable=["FoldConstant"]):
        _ = mock_pass(mod)
        assert PassContext.current().num_evals == 2 * 2 * 2
        assert PassContext.current().get_trace_stack_size() == 1
        assert PassContext.current().get_current_trace().size == 6

    # Users can mix-use sequential application and joint-application.
    mock_pass = MockConstFoldingTuningPass(
        eval_passes=[
            MockConstFoldingTuningPass(eval_passes=[]),
            MockConstFoldingTuningPass(eval_passes=[]),
            MockConstFoldingTuningPass(eval_passes=[]),
        ]
    )
    with transform.PassContext(trace=Trace(mod), make_traceable=["FoldConstant"]):
        _ = mock_pass(mod)
        assert PassContext.current().num_evals == 2 * (2 + 2 + 2)
        assert PassContext.current().get_trace_stack_size() == 1
        assert PassContext.current().get_current_trace().size == 4


def test_sequential():
    """Sequential tuning passes grow the search space additively, not multiplicatively."""
    mod = setup_test()
    assert isinstance(mod, tvm.IRModule)
    # Test setup.
    c0 = np.arange((16 * 16)).astype("int32").reshape(16, 16)
    mod = relax.transform.BindParams("main", {"c0": tvm.nd.array(c0)})(mod)
    HeuristicPass = relax.transform.FoldConstant

    # Sequential with a single tuning pass should behave the same as a single pass.
    seq = transform.Sequential([MockConstFoldingTuningPass(eval_passes=[])])
    with transform.PassContext(trace=Trace(mod), make_traceable=["FoldConstant"]):
        _ = seq(mod)
        assert PassContext.current().num_evals == 2
        assert PassContext.current().get_trace_stack_size() == 1
        assert PassContext.current().get_current_trace().size == 1

    # Sequential passes should increase the search space (num_evals) in an additive manner.
    seq = transform.Sequential(
        [
            MockConstFoldingTuningPass(eval_passes=[]),
            MockConstFoldingTuningPass(eval_passes=[]),
            MockConstFoldingTuningPass(eval_passes=[]),
        ]
    )
    with transform.PassContext(trace=Trace(mod), make_traceable=["FoldConstant"]):
        _ = seq(mod)
        assert PassContext.current().num_evals == 2 + 2 + 2
        assert PassContext.current().get_trace_stack_size() == 1
        assert PassContext.current().get_current_trace().size == 3

    # A heuristic pass will not increase the search space. It just increases trace length.
    seq = transform.Sequential(
        [
            MockConstFoldingTuningPass(eval_passes=[]),
            HeuristicPass(),
            MockConstFoldingTuningPass(eval_passes=[]),
            MockConstFoldingTuningPass(eval_passes=[]),
            HeuristicPass(),
        ]
    )

    with transform.PassContext(trace=Trace(mod), make_traceable=["FoldConstant"]):
        _ = seq(mod)
        assert PassContext.current().num_evals == 2 + 2 + 2
        assert PassContext.current().get_trace_stack_size() == 1
        assert PassContext.current().get_current_trace().size == 5

    # Users can mix-use sequential application and joint-application.
    seq = transform.Sequential(
        [
            HeuristicPass(),
            MockConstFoldingTuningPass(
                eval_passes=[
                    MockConstFoldingTuningPass(
                        eval_passes=[
                            MockConstFoldingTuningPass(
                                eval_passes=[
                                    HeuristicPass(),
                                ]
                            )
                        ]
                    ),
                ]
            ),
            MockConstFoldingTuningPass(eval_passes=[]),
            HeuristicPass(),
        ]
    )

    with transform.PassContext(trace=Trace(mod), make_traceable=["FoldConstant"]):
        _ = seq(mod)
        assert PassContext.current().num_evals == (2 * 2 * 2) + 2
        assert PassContext.current().get_trace_stack_size() == 1
        assert PassContext.current().get_current_trace().size == 7


def test_passes_with_mixed_granularities():
    """Module-, function-, and dataflow-block-level passes can each push their own trace."""

    @tvm.script.ir_module
    class MockModule:
        @R.function
        def f1(x: R.Tensor(("m", "n"), "float32")):
            with R.dataflow():
                lv0 = R.multiply(x, x)
                gv0 = R.add(x, x)
                R.output(gv0)
            return gv0

        @R.function
        def main(x: R.Tensor(("m", "n"), "float32"), y: R.Tensor(("m", "n"), "float32")):
            with R.dataflow():
                lv0 = R.multiply(x, y)
                gv0 = R.add(lv0, y)
                R.output(gv0)
            gv1 = R.multiply(x, y)
            gv2 = R.add(gv1, y)
            return (gv0, gv1, gv2)

    mod = MockModule
    assert isinstance(mod, tvm.IRModule)

    # Helper function for tuning.
    def pass_func(
        mod: IRModule, ctx: PassContext, eval_passes: List[transform.Pass] = None
    ) -> IRModule:
        trace = ctx.pop_trace()

        # Create mock choices for testing.
        choices = [Choice(), Choice(), Choice()]
        # Tuning pass manages a set of transformation functions registered via knob.
        knob = Knob("MockTuningKnob", choices)

        candidates = default_generate_candidate([knob], trace, eval_passes)
        mock_evaluate(candidates, "llvm", ctx)
        best_trace = select_best_candidate(candidates)

        ctx.push_trace(best_trace)
        return best_trace.out_mod

    @ir.transform.module_pass(opt_level=0, traceable=True)
    def MockModulePass(mod: IRModule, ctx: PassContext) -> IRModule:
        # Input granularity == Candidate granularity.
        return pass_func(mod, ctx)

    @relax.transform.function_pass(opt_level=0, traceable=True)
    def MockFunctionPass(func: Expr, mod: IRModule, ctx: PassContext) -> Function:
        # Input granularity > Candidate granularity.
        # Start a trace with smaller granularity: IRModule->Function.
        ctx.push_trace(Trace(IRModule.from_expr(func)))
        # Do something.
        pass_func(mod, ctx)
        # Pop the tuned trace and recover the previous trace.
        ctx.pop_trace()
        return func

    @relax.transform.dataflowblock_pass(opt_level=0, traceable=True)
    def MockDataflowBlockPass(
        block: DataflowBlock, mod: IRModule, ctx: PassContext
    ) -> DataflowBlock:
        # TODO(sunggg): figure out how to create an IRModule from a DataflowBlock.
        # Provide a random binding for now.
        x = relax.Var("x", R.Tensor([tvm.tir.Var("n", "int64")], "float32"))
        seq_expr = relax.SeqExpr([block], x)
        func = relax.Function([x], seq_expr, R.Tensor("float32", ndim=-1))
        ctx.push_trace(Trace(IRModule.from_expr(func)))
        # Do something.
        pass_func(mod, ctx)
        ctx.pop_trace()
        return block

    seq = transform.Sequential(
        [
            MockModulePass,
            MockFunctionPass,
            MockDataflowBlockPass,
        ]
    )

    with transform.PassContext(trace=Trace(mod), make_traceable=[]):
        _ = seq(mod)
        # Trace length and num evals can differ depending on how each function/dataflow
        # block is treated.
        assert PassContext.current().get_trace_stack_size() == 1


if __name__ == "__main__":
    pytest.main([__file__])