From 64631dd2ba6cf0657d6be4093d85d28ccfaafadb Mon Sep 17 00:00:00 2001 From: tqchen Date: Thu, 5 Jun 2025 12:38:14 -0400 Subject: [PATCH] [REFACTOR] Phase out the relax tuning_api This PR phases out the relax tuning API, since we no longer need its logic for per-function tuning, which was its main use case. --- include/tvm/ir/transform.h | 31 +- include/tvm/relax/tuning_api.h | 394 --------- python/tvm/ir/transform.py | 71 -- python/tvm/relax/transform/__init__.py | 4 +- .../relax/transform/tuning_api/__init__.py | 22 - .../relax/transform/tuning_api/_ffi_api.py | 19 - .../relax/transform/tuning_api/database.py | 273 ------ .../transform/tuning_api/default_functions.py | 307 ------- .../relax/transform/tuning_api/primitives.py | 419 ---------- src/ir/transform.cc | 57 +- src/relax/transform/meta_schedule.cc | 44 +- src/relax/transform/tuning_api/database.cc | 344 -------- src/relax/transform/tuning_api/primitives.cc | 272 ------ tests/python/relax/conftest.py | 4 - .../relax/test_transform_codegen_pass.py | 3 +- .../test_transform_meta_schedule_tuning.py | 3 - tests/python/relax/test_tuning_api.py | 781 ------------------ 17 files changed, 15 insertions(+), 3033 deletions(-) delete mode 100644 include/tvm/relax/tuning_api.h delete mode 100644 python/tvm/relax/transform/tuning_api/__init__.py delete mode 100644 python/tvm/relax/transform/tuning_api/_ffi_api.py delete mode 100644 python/tvm/relax/transform/tuning_api/database.py delete mode 100644 python/tvm/relax/transform/tuning_api/default_functions.py delete mode 100644 python/tvm/relax/transform/tuning_api/primitives.py delete mode 100644 src/relax/transform/tuning_api/database.cc delete mode 100644 src/relax/transform/tuning_api/primitives.cc delete mode 100644 tests/python/relax/test_tuning_api.py diff --git a/include/tvm/ir/transform.h b/include/tvm/ir/transform.h index 8562fbaa8ff4..7d9ff940a816 100644 --- a/include/tvm/ir/transform.h +++ b/include/tvm/ir/transform.h @@ -90,16 +90,7 @@ class PassContextNode : public Object { 
/*! \brief A list of pass instrument implementations. */ Array instruments; - // TODO(@sunggg): Fix dependency issue in the header file and correct the types - // e.g., relax::trace, relax::database in tvm/relax/tuning_api.h - /*! \brief Trace stack for relax pass infra. */ - mutable Array trace_stack; - /*! \brief List of passes to be traced. If not defined, make every pass traceable. */ - Optional> make_traceable; - /*! \brief Number of evaluations conducted in the pass pipeline. */ - mutable int num_evals{0}; - /*! \brief Database for tuning API. */ - Optional tuning_api_database; + PassContextNode() = default; /*! @@ -138,27 +129,7 @@ class PassContextNode : public Object { v->Visit("instruments", &instruments); v->Visit("config", &config); v->Visit("diag_ctx", &diag_ctx); - v->Visit("trace_stack", &trace_stack); - v->Visit("make_traceable", &make_traceable); - v->Visit("num_evals", &num_evals); - v->Visit("tuning_api_daatabase", &tuning_api_database); - } - - Array GetTraceStack() { return trace_stack; } - void PushTrace(ObjectRef new_trace) { trace_stack.push_back(new_trace); } - void PopTrace() { - ICHECK(GetTraceStackSize()) << "Trace stack is currently empty. Please double check."; - trace_stack.pop_back(); } - int GetTraceStackSize() { return trace_stack.size(); } - ObjectRef GetCurrentTrace() { - ICHECK(GetTraceStackSize()) << "Trace stack is currently empty. 
Please double check."; - return trace_stack.back(); - } - void SetNumEvals(int _num_evals) { num_evals = _num_evals; } - void IncNumEvals(int _num_evals) { num_evals += _num_evals; } - - Optional GetTuningAPIDatabase() { return tuning_api_database; } static constexpr const char* _type_key = "transform.PassContext"; static constexpr bool _type_has_method_sequal_reduce = false; diff --git a/include/tvm/relax/tuning_api.h b/include/tvm/relax/tuning_api.h deleted file mode 100644 index c18a8cfb54a7..000000000000 --- a/include/tvm/relax/tuning_api.h +++ /dev/null @@ -1,394 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tvm/relax/tuning_api.h - * \brief Relax Tuning Pass APIs. - */ -#ifndef TVM_RELAX_TUNING_API_H_ -#define TVM_RELAX_TUNING_API_H_ -#include -#include -#include - -#include -#include - -namespace tvm { -namespace relax { - -/*! \brief Helper function to unpack arguments in the array as parameters for the given packed - * function. 
*/ -TVM_ALWAYS_INLINE ffi::Any CallPackedWithArgsInArray(const ffi::Function f, - const Array& args) { - size_t num_args = args.size(); - std::vector packed_args(num_args); - for (size_t i = 0; i < num_args; ++i) { - packed_args[i] = args[i]; - } - Any rv; - f.CallPacked(ffi::PackedArgs(packed_args.data(), packed_args.size()), &rv); - return rv; -} - -/*! \brief Choice manages a set of keys for transformation and constraint functions. */ -class ChoiceNode : public runtime::Object { - public: - /*! \brief ffi key for transformation function. */ - String transform_func_key; - /*! \brief ffi key for constraint function. */ - String constr_func_key; - Array transform_func_args; - Array constr_func_args; - - /*! \brief The default destructor. */ - virtual ~ChoiceNode() = default; - - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("transform_func_key", &transform_func_key); - v->Visit("transform_func_args", &transform_func_args); - v->Visit("constr_func_key", &constr_func_key); - v->Visit("constr_func_args", &constr_func_args); - } - - /*! \brief Getter for constr_func. */ - const ffi::Function GetConstrFunc() { - const auto constr_func = tvm::ffi::Function::GetGlobal(constr_func_key); - ICHECK(constr_func.has_value()) << "constr_func_key is not registered: " << constr_func_key; - return *std::move(constr_func); - } - - /*! \brief Getter for transform_func. */ - const ffi::Function GetTransformFunc() { - auto transform_func = tvm::ffi::Function::GetGlobal(transform_func_key); - ICHECK(transform_func.has_value()) - << "transform_func_key is not registered: " << transform_func_key; - return *std::move(transform_func); - } - - /*! \brief Perform constr_func. */ - bool CheckConstr(IRModule mod) { - Array args(constr_func_args); - args.insert(args.begin(), GetRef(mod.CopyOnWrite())); - return CallPackedWithArgsInArray(GetConstrFunc(), args).cast(); - } - - /*! \brief Perform transform_func. 
*/ - IRModule ApplyTransformFunc(IRModule mod) { - // Apply transformation when constraint is satisfied. - if (CheckConstr(mod)) { - Array args(transform_func_args); - args.insert(args.begin(), GetRef(mod.CopyOnWrite())); - return CallPackedWithArgsInArray(GetTransformFunc(), args).cast(); - } - return mod; - } - - /*! - * \brief Serialize Choice as a JSON-style object - * \return The JSON-style object - */ - ObjectRef AsJSON() const; - - static constexpr const char* _type_key = "relax.tuning_api.Choice"; - TVM_DECLARE_BASE_OBJECT_INFO(ChoiceNode, Object); -}; - -/*! \brief Managed reference to ChoiceNode */ -class Choice : public runtime::ObjectRef { - public: - TVM_DLL explicit Choice(String transform_func_key, Array transform_func_args, - String constr_func_key, Array constr_func_args); - /*! \brief Deserialize JSON-style object into Choice */ - TVM_DLL static Choice FromJSON(const ObjectRef& json_obj); - TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Choice, ObjectRef, ChoiceNode); -}; - -/*! \brief Knob manages a set of valid choices for an optimization. */ -class KnobNode : public runtime::Object { - public: - /*! \brief Name of the knob. */ - String name; - /*! \brief Decision space. */ - Map choices; - - /*! \brief The default destructor. */ - virtual ~KnobNode() = default; - - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("name", &name); - v->Visit("choices", &choices); - } - - /*! \brief Check if a decision is valid. */ - bool IsValidDecision(String decision) { return choices.count(decision) > 0; } - - /*! \brief Apply decision if the constraint is satisfied. - Otherwise, return the original IRModule. - */ - IRModule Apply(IRModule mod, String decision) { - ICHECK(IsValidDecision(decision)) << "Invalid choice for this knob: " << decision; - return choices[decision]->ApplyTransformFunc(mod); - } - - /*! 
- * \brief Serialize Knob as a JSON-style object - * \return The JSON-style object - */ - ObjectRef AsJSON() const; - - static constexpr const char* _type_key = "relax.tuning_api.Knob"; - TVM_DECLARE_BASE_OBJECT_INFO(KnobNode, Object); -}; - -/*! \brief Managed reference to KnobNode */ -class Knob : public runtime::ObjectRef { - public: - TVM_DLL explicit Knob(String name, Map choices); - /*! \brief Deserialize JSON-style object into Knob */ - TVM_DLL static Knob FromJSON(const ObjectRef& json_obj); - TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Knob, ObjectRef, KnobNode); -}; - -/*! \brief Trace manages history of optimization decisions. */ -class TraceNode : public runtime::Object { - public: - /*! \brief Input IRModule. */ - IRModule in_mod; - /*! \brief Output IRModule. */ - mutable IRModule out_mod; - // TODO(sunggg): can we move knobs and decisions into private? - /*! \brief Knobs that are applied so far. */ - Array knobs; - /*! \brief Decisions made for the knobs. */ - Array decisions; - /*! \brief Performance of out_mod. */ - mutable double perf = -1; - /*! \brief Length of the decision history. */ - mutable int size = 0; - /*! \brief The default destructor. */ - virtual ~TraceNode() = default; - - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("in_mod", &in_mod); - v->Visit("out_mod", &out_mod); - v->Visit("knobs", &knobs); - v->Visit("decisions", &decisions); - v->Visit("perf", &perf); - v->Visit("size", &size); - } - - /*! \brief Verify current decision history. */ - bool Verify() const { - if (knobs.size() != decisions.size()) return false; - int n = knobs.size(); - for (int i = 0; i < n; i++) { - if (!knobs[i]->IsValidDecision(decisions[i])) return false; - } - return true; - } - - /*! \brief Add a knob and its decision to the current trace. 
*/ - IRModule Add(Knob knob, String decision) { - out_mod = knob->Apply(out_mod, decision); - knobs.push_back(knob); - decisions.push_back(decision); - // perf number should be initialized after new decision is applied. - perf = -1; - // increment history size. - size++; - return out_mod; - } - - /*! - * \brief Serialize Trace as a JSON-style object - * \param include_in_mod Boolean config to include input IRModule in the output. - * \return The JSON-style object - */ - ObjectRef AsJSON(bool include_in_mod = true) const; - - /*! \brief Set the performance. */ - void SetPerf(double _perf) { perf = _perf; } - /*! \brief Set output module. */ - void SetOutMod(IRModule mod_) { out_mod = mod_; } - - static constexpr const char* _type_key = "relax.tuning_api.Trace"; - TVM_DECLARE_BASE_OBJECT_INFO(TraceNode, Object); -}; - -/*! \brief Managed reference to TraceNode */ -class Trace : public runtime::ObjectRef { - public: - /*! \brief Default constructor. Creating an empty trace. */ - Trace(); - /*! - * \brief Constructor. Creating a trace from existing knobs and their decisions - * \param in_mod Input IRModule - * \param knobs The knobs used - * \param decisions The decisions made in sampling - */ - TVM_DLL explicit Trace(IRModule in_mod, Array knobs, Array decisions); - /*! \brief Deserialize JSON-style object into Trace */ - TVM_DLL static Trace FromJSON(const ObjectRef& json_obj); - TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Trace, ObjectRef, TraceNode); -}; - -/*! \brief The class of tuning records. */ -class TuningRecordNode : public runtime::Object { - public: - /*! \brief The trace tuned. */ - Trace trace; - /*! \brief The measurement record in seconds. */ - Optional> run_secs; - - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("trace", &trace); - v->Visit("run_secs", &run_secs); - } - - static constexpr const char* _type_key = "relax.tuning_api.TuningRecord"; - TVM_DECLARE_FINAL_OBJECT_INFO(TuningRecordNode, runtime::Object); - - /*! 
- * \brief Export the tuning record to a JSON string. - * \param include_irmod Boolean config to include IRModules in the output. - * \return JSON object - */ - ObjectRef AsJSON(bool include_irmod = false) const; -}; - -/*! - * \brief The managed reference of TuningRecordNode. - * \sa TuningRecordNode - */ -class TuningRecord : public runtime::ObjectRef { - public: - /*! - \brief Constructor of a tuning record. - \param trace The trace of the tuning record. - \param run_secs The running time of the tuning record. - */ - TVM_DLL explicit TuningRecord(Trace trace, Optional> run_secs); - /*! - * \brief Create a tuning record from a json object. - * \param json_obj The json object. - * \return The tuning record created. - */ - TVM_DLL static TuningRecord FromJSON(const ObjectRef& json_obj); - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TuningRecord, runtime::ObjectRef, TuningRecordNode); -}; - -/*! \brief The equality check for Workload */ -struct WorkloadEqual { - bool operator()(const meta_schedule::Workload& a, const meta_schedule::Workload& b) const { - return a->shash == b->shash && tvm::StructuralEqual()(a->mod, b->mod); - } -}; - -/* \brief The abstract interface of database. */ -class DatabaseNode : public runtime::Object { - public: - /*! \brief Default destructor */ - virtual ~DatabaseNode() = default; - /*! - * \brief Check if the database has the given workload. - * \param mod The IRModule to be searched for. - * \return Whether the database has the given workload. - */ - virtual bool HasWorkload(const IRModule& mod) = 0; - /*! - * \brief Check if the database has a measurement record for the given workload and target pair. - * \param workload The workload to be searched for. - * \param target The target to be searched for. - * \return Whether the database has the measurement record for given workload and target pair. - */ - virtual bool HasMeasurementRecord(const meta_schedule::Workload& workload, - const Target& target) = 0; - /*! 
- * \brief Check if the database has a tuning record for the given workload and target pair. - * \param workload The workload to be searched for. - * \param target The target to be searched for. - * \return Whether the database has the tuning record for the given workload and target pair. - */ - virtual bool HasTuningRecord(const meta_schedule::Workload& workload, const Target& target) = 0; - /*! - * \brief Look up or add workload to the database if missing. - * \param mod The IRModule to be searched for or added. - * \return The workload corresponding to the given IRModule. - */ - virtual meta_schedule::Workload CommitWorkload(const IRModule& mod) = 0; - /*! - * \brief Add a measurement record for a given pair of target and workload to the database. - * \param workload Workload to be searched for. - * \param target Target to be searched for. - * \param record Measurement record to be added. - */ - virtual void CommitMeasurementRecord(const meta_schedule::Workload& workload, - const Target& target, const Array& record) = 0; - /*! - * \brief Add a tuning record for a given pair of target and workload to the database. - * \param workload Workload to be searched for. - * \param target Target to be searched for. - * \param record Tuning record to be added. - */ - virtual void CommitTuningRecord(const meta_schedule::Workload& workload, const Target& target, - const TuningRecord& record) = 0; - /*! - * \brief Get the top K tuning records of given workload and target from the database. - * \param workload The workload to be searched for. - * \param target Target to be searched for. - * \param top_k The number of top records to be returned. - * \return An array of top K tuning records for the given workload. - */ - virtual Array GetTopK(const meta_schedule::Workload& workload, const Target& target, - int top_k) = 0; - /*! - * \brief Get the measurement record of given workload and target from the database. - * \param workload The workload to be searched for. 
- * \param target Target to be searched for. - * \return Measurement. - */ - virtual Array GetMeasurementRecord(const meta_schedule::Workload& workload, - const Target target) = 0; - - static constexpr const char* _type_key = "relax.tuning_api.Database"; - TVM_DECLARE_BASE_OBJECT_INFO(DatabaseNode, runtime::Object); -}; - -/*! - * \brief Managed reference to DatabaseNode. - * \sa DatabaseNode - */ -class Database : public runtime::ObjectRef { - public: - /*! - * \brief Create a default database that uses JSON file for tuning records. - * \param path_workload The path to the workload table. - * \param path_tuning_record The path to the tuning record table. - * \param path_measurement_record The path to the measurement_record table. - * \param allow_missing Whether to create new file when the given path is not found. - */ - TVM_DLL static Database JSONDatabase(String path_workload, String path_tuning_record, - String path_measurement_record, bool allow_missing); - TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Database, runtime::ObjectRef, DatabaseNode); -}; - -} // namespace relax -} // namespace tvm -#endif // TVM_RELAX_TUNING_API_H_ diff --git a/python/tvm/ir/transform.py b/python/tvm/ir/transform.py index 45050d44af0b..b8f4c36c30c7 100644 --- a/python/tvm/ir/transform.py +++ b/python/tvm/ir/transform.py @@ -71,20 +71,6 @@ class PassContext(tvm.runtime.Object): config : Optional[Dict[str, Object]] Additional configurations for specific passes. - - trace: Optional[relax.tuning.Trace] - Initial trace for trace mode. - - trace_stack: Optional[List[relax.tuning_api.Trace]] - Initial trace stack for trace mode. - - make_traceable: Optional[List[str]] - List of passes to make traceable. - - num_evals: int - initial number of evaluations conducted in the pipeline. 
- - tuning_api_database: Optional[relax.tuning_api.JSONDatabase] """ def __init__( @@ -94,11 +80,6 @@ def __init__( disabled_pass=None, instruments=None, config=None, - trace=None, - trace_stack=None, - make_traceable=None, - num_evals=0, - tuning_api_database=None, ): required = list(required_pass) if required_pass else [] if not isinstance(required, (list, tuple)): @@ -112,13 +93,6 @@ def __init__( if not isinstance(instruments, (list, tuple)): raise TypeError("instruments is expected to be the type of " + "list/tuple/set.") - # Convert to Map - # TODO(sunggg): Replace this to Set equivalent if exists - make_traceable = {name: True for name in make_traceable} if make_traceable else None - - if not trace_stack: - trace_stack = [trace] if trace else [] - config = config if config else None self.__init_handle_by_constructor__( _ffi_transform_api.PassContext, @@ -127,10 +101,6 @@ def __init__( disabled, instruments, config, - trace_stack, - make_traceable, - num_evals, - tuning_api_database, ) def __enter__(self): @@ -167,47 +137,6 @@ def list_configs(): """ return _ffi_transform_api.ListConfigs() - def push_trace(self, trace): - """Push a trace into the stack.""" - return _ffi_transform_api.PushTrace(self, trace) - - def pop_trace(self, return_current=True): - """Pop a topmost trace from the stack. 
- Returns - ------- - Trace : Optional[relax.tuning.Trace] - """ - if return_current: - cur_trace = self.get_current_trace() - _ffi_transform_api.PopTrace(self) - return cur_trace - - return _ffi_transform_api.PopTrace(self) - - def get_trace_stack(self): - """Get the current trace stack.""" - return _ffi_transform_api.GetTraceStack(self) - - def get_trace_stack_size(self): - """Get the size of current stack.""" - return _ffi_transform_api.GetTraceStackSize(self) - - def get_current_trace(self): - """Get the trace on the top of the stack.""" - return _ffi_transform_api.GetCurrentTrace(self) - - def set_num_evals(self, num: int): - """Set the number of evaluations conducted in the pipeline.""" - return _ffi_transform_api.SetNumEvals(self, num) - - def inc_num_evals(self, num: int): - """Increment the number of evaluations conducted in the pipeline.""" - return _ffi_transform_api.IncNumEvals(self, num) - - def get_tuning_api_database(self): - """Get tuning api database.""" - return _ffi_transform_api.GetTuningAPIDatabase(self) - @tvm.ffi.register_object("transform.Pass") class Pass(tvm.runtime.Object): diff --git a/python/tvm/relax/transform/__init__.py b/python/tvm/relax/transform/__init__.py index ffdf31975a70..724921e5fee7 100644 --- a/python/tvm/relax/transform/__init__.py +++ b/python/tvm/relax/transform/__init__.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""Relax transformations. """ +"""Relax transformations.""" from .transform import ( AdjustMatmulOrder, @@ -98,4 +98,4 @@ from .remove_redundant_reshape import RemoveRedundantReshape # Import to register the legalization functions. -from . import legalize_ops, tuning_api +from . 
import legalize_ops diff --git a/python/tvm/relax/transform/tuning_api/__init__.py b/python/tvm/relax/transform/tuning_api/__init__.py deleted file mode 100644 index 6c39d5c5359e..000000000000 --- a/python/tvm/relax/transform/tuning_api/__init__.py +++ /dev/null @@ -1,22 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=wildcard-import, redefined-builtin -"""Relax Tunign Pass API""" - -from .primitives import * -from .default_functions import * -from .database import * diff --git a/python/tvm/relax/transform/tuning_api/_ffi_api.py b/python/tvm/relax/transform/tuning_api/_ffi_api.py deleted file mode 100644 index 54caece700ef..000000000000 --- a/python/tvm/relax/transform/tuning_api/_ffi_api.py +++ /dev/null @@ -1,19 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -"""FFI APIs for relax.tuning_api""" -import tvm.ffi - -tvm.ffi._init_api("relax.tuning_api", __name__) diff --git a/python/tvm/relax/transform/tuning_api/database.py b/python/tvm/relax/transform/tuning_api/database.py deleted file mode 100644 index cbc103423b0f..000000000000 --- a/python/tvm/relax/transform/tuning_api/database.py +++ /dev/null @@ -1,273 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Relax Tuning Pass API default functions""" -from typing import List, Optional -import logging - -from tvm.runtime import Object -from tvm.ir.module import IRModule -from tvm.meta_schedule.utils import _json_de_tvm -from tvm.meta_schedule.database import Workload -from tvm.tir.schedule.trace import JSON_TYPE -from tvm.target import Target -from tvm.ffi import register_object -from .primitives import Trace -from . 
import _ffi_api - -logger = logging.getLogger("TuningAPI") # pylint: disable=invalid-name - - -@register_object("relax.tuning_api.TuningRecord") -class TuningRecord(Object): - """The class of tuning records. - - Parameters - ---------- - trace : tvm.relax.transform.tuning_api.Trace - The trace of the tuning record. - run_secs : Optional[List[float]] - The run-time of the tuning record. - """ - - trace: Trace - run_secs: Optional[List[float]] - - def __init__( # type: ignore # pylint: disable=too-many-arguments - self, - trace: Trace, - run_secs: Optional[List[float]] = None, - ) -> None: - self.__init_handle_by_constructor__( - _ffi_api.TuningRecord, # type: ignore # pylint: disable=no-member - trace, - run_secs, - ) - - def as_json(self, include_irmod: bool = False) -> JSON_TYPE: - """Export the tuning record to a JSON string. - Parameters - ---------- - include_irmod: bool - Decides whether to serialize in_mod as well. - - Returns - ------- - json_str : str - The JSON string exported. - """ - return _json_de_tvm(_ffi_api.TuningRecordAsJSON(self, include_irmod)) # type: ignore # pylint: disable=no-member - - @staticmethod - def from_json(json_obj: JSON_TYPE) -> "TuningRecord": - """Create a tuning record from a json object. - - Parameters - ---------- - json_obj : JSON_TYPE - The json object to parse. - - Returns - ------- - tuning_record : TuningRecord - The parsed tuning record. - """ - return _ffi_api.TuningRecordFromJSON(json_obj) # type: ignore # pylint: disable=no-member - - -@register_object("relax.tuning_api.Database") -class Database(Object): - """The abstract database interface.""" - - def has_workload(self, mod: IRModule) -> bool: - """Check if the database has the given workload. - Parameters - ---------- - mod : IRModule - The IRModule to be searched for. - - Returns - ------- - result : bool - Whether the given workload is committed. 
- """ - return _ffi_api.DatabaseHasWorkload(self, mod) # type: ignore # pylint: disable=no-member - - def has_measurement_record(self, workload: Workload, target: Target) -> bool: - """Check if the database has a measurement record for the given workload and target pair. - Parameters - ---------- - workload: Workload - The workload to be searched for. - target: Target - The target to be searched for. - - Returns - ------- - result : bool - Whether the given workload and target pair is committed for the measurement record. - """ - return _ffi_api.DatabaseHasMeasurementRecord(self, workload, target) # type: ignore # pylint: disable=no-member - - def has_tuning_record(self, workload: Workload, target: Target) -> bool: - """Check if the database has a tuning record for the given workload and target pair. - Parameters - ---------- - workload: Workload - The workload to be searched for. - target: Target - The target to be searched for. - - Returns - ------- - result : bool - Whether the given workload and target pair is committed for the tuning record. - """ - return _ffi_api.DatabaseHasTuningRecord(self, workload, target) # type: ignore # pylint: disable=no-member - - def commit_workload(self, mod: IRModule) -> Workload: - """Commit a workload to the database if missing. - - Parameters - ---------- - mod : IRModule - The IRModule to be searched for or added. - - Returns - ------- - workload : Workload - The workload corresponding to the given IRModule. - """ - return _ffi_api.DatabaseCommitWorkload(self, mod) # type: ignore # pylint: disable=no-member - - def commit_measurement_record( - self, workload: Workload, target: Target, run_secs: List[float] - ) -> None: - """Commit a measurement record to the database. - A pair of workload and target will be used as a key. - - Parameters - ---------- - workload: Workload - The workload to be searched for. - target: Target - The target to be searched for. - run_secs : Optional[List[float]] - The measurement record to add. 
- """ - _ffi_api.DatabaseCommitMeasurementRecord(self, workload, target, run_secs) # type: ignore # pylint: disable=no-member - - def commit_tuning_record( - self, workload: Workload, target: Target, record: TuningRecord - ) -> None: - """Commit a tuning record to the database. - A pair of workload and target will be used as a key. - - Parameters - ---------- - workload: Workload - The workload to be searched for. - target: Target - The target to be searched for. - record : TuningRecord - The tuning record to add. - """ - _ffi_api.DatabaseCommitTuningRecord(self, workload, target, record) # type: ignore # pylint: disable=no-member - - def get_measurement_record(self, workload: Workload, target: Target) -> Optional[List[float]]: - """Get the measurement record of given workload and target from the database. - - Parameters - ---------- - workload : Workload - The workload to be searched for. - target: Target - The target to be searched for. - - Returns - ------- - measurement_record : Optional[List[float]] - Measurement record if exists. - """ - return _ffi_api.DatabaseGetMeasurementRecord(self, workload, target) # type: ignore # pylint: disable=no-member - - def get_top_k(self, workload: Workload, target: Target, top_k: int) -> List[TuningRecord]: - """Get the top K tuning records of given workload and target from the database. - - Parameters - ---------- - workload : Workload - The workload to be searched for. - target: Target - The target to be searched for. - top_k : int - The number of top records to get. - - Returns - ------- - top_k_records : List[TuningRecord] - The top K records. - """ - return _ffi_api.DatabaseGetTopK(self, workload, target, top_k) # type: ignore # pylint: disable=no-member - - -@register_object("relax.tuning_api.JSONDatabase") -class JSONDatabase(Database): - """The class of JSON database. - - Parameters - ---------- - path_workload : str - The path to the workload table. - path_tuning_record : str - The path to the tuning record table. 
- Manages pairs of - path_measurement_record : str - The path to the path_measurement_record table. - Manages pairs of - """ - - path_workload: str - path_tuning_record: str - path_measurement_record: str - - def __init__( - self, - path_workload: str, - path_tuning_record: str, - path_measurement_record: str, - allow_missing: bool = True, - ) -> None: - """Constructor. - - Parameters - ---------- - path_workload : str - The path to the workload table. - path_tuning_record : str - The path to the tuning record table. - path_measurement_record : str - The path to the path_measurement_record table. - allow_missing : bool - Whether to create new file when the given path is not found. - """ - self.__init_handle_by_constructor__( - _ffi_api.DatabaseJSONDatabase, # type: ignore # pylint: disable=no-member - path_workload, - path_tuning_record, - path_measurement_record, - allow_missing, - ) diff --git a/python/tvm/relax/transform/tuning_api/default_functions.py b/python/tvm/relax/transform/tuning_api/default_functions.py deleted file mode 100644 index cbd71a06e608..000000000000 --- a/python/tvm/relax/transform/tuning_api/default_functions.py +++ /dev/null @@ -1,307 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-"""Relax Tuning Pass API default functions""" -from typing import Dict, List, Optional -import sys -import itertools -import logging -import numpy as np # type: ignore - -import tvm -from tvm.ir.module import IRModule -from tvm.ir.transform import PassContext, Pass -from tvm import meta_schedule -from tvm.meta_schedule.arg_info import TensorInfo -from tvm.meta_schedule.builder import BuilderInput, LocalBuilder -from tvm.meta_schedule.utils import get_global_func_with_default_on_worker -from tvm.meta_schedule.runner import ( - EvaluatorConfig, - LocalRunner, - RunnerInput, -) -from tvm.ffi.registry import register_func -from .primitives import Knob, Trace - -logger = logging.getLogger("TuningAPI") # pylint: disable=invalid-name - - -# Default transform func that returns original IRModule. -@tvm.register_func("relax.tuning_api.Choice.default_transform_func") -def default_transform_func(mod): - return mod - - -# Default constraint func that always returns true. -@tvm.register_func("relax.tuning_api.Choice.default_constr_func") -def default_constr_func(mod: IRModule) -> bool: # pylint: disable=unused-argument - return True - - -@register_func("relax.tuning_api.default_generate_candidate") -def default_generate_candidate( - knobs: List[Knob], trace: Trace, eval_passes: Optional[List[Pass]] = None -) -> List[Trace]: - """ - Default function to generate the search space for a given trace by using registered choices. - This function simply expands candidate space as long as the knob's constraint satisfies. - To reduce the search space, a developer may expand each choice with smart search method. - (e.g., genetic search, multi-armed bandit) - Note, each pass generates candidates without worrying about the interaction with other passes. - i.e., it only uses its incoming trace/IRModule and Choices for candidate generation. - This will help alleviating the complexity of joint-optimization significantly. 
- - consideration of interaction between optimizations has known to be extremely difficult. - - Parameters - ---------- - knobs : List[Knob] - List of Knobs to consider to generate candidate for input trace. - trace: Trace - Input trace. - eval_passes: Optional[List[Pass]] - List of passes to consider to evaluate each candidate. - This will enable joint-optimization. - - Return - ---------- - candidates: List[Trace] - List of candidate traces - """ - - candidates = [trace] - # Iterate over every decision - for knob in knobs: - num = len(candidates) - for _ in range(num): - cur_trace = candidates.pop(0) - for decision in knob.choices.keys(): - choice = knob.choices[decision] - # Generate new candidate when this condition satisfies. - if choice.check_constr(cur_trace.out_mod): - new_trace = cur_trace.deepcopy() - new_trace.add(knob, decision) - candidates.append(new_trace) - - # Expand candidates by using eval passes if provided. This will enable joint-optimization. - if eval_passes: - candidates = default_consider_eval_passes(candidates, eval_passes) - return candidates - - -@register_func("relax.tuning_api.default_consider_eval_passes") -def default_consider_eval_passes( - init_candidates: List[Trace], eval_passes: Optional[List[Pass]] = None -) -> List[Trace]: - """ - Default function to update traces with eval passes. - It visits each eval_pass in dfs order in transform.Sequential() and - returns the best possible candidate trace for each candidate. - - Parameters - ---------- - init_candidates: List[Trace] - Initial candidates - eval_passes: Optional[List[Pass]] - List of passes to consider to evaluate each candidate. - This will enable joint-optimization. 
- Return - ---------- - candidates: List[Trace] - List of candidate traces - """ - if not eval_passes: - return init_candidates - - eval_passes = list(eval_passes) if not isinstance(eval_passes, list) else eval_passes - ctx = PassContext.current() - candidates = [] - - for trace in init_candidates: - ctx.push_trace(trace) - tvm.transform.Sequential(eval_passes)(trace.out_mod) - new_trace = ctx.pop_trace() - # A new trace contains the best decisions in eval_passes - candidates.append(new_trace) - - return candidates - - -@register_func("relax.tuning_api.default_evaluate") -def default_evaluate( - candidates: List[Trace], - target_str: str, - params: Optional[Dict[str, np.ndarray]] = None, - builder: Optional[meta_schedule.builder.Builder] = None, - runner: Optional[meta_schedule.runner.Runner] = None, -) -> None: - """ - Default function to evaluate a set of candidate traces by using MetaSchedule builder/runner. - - Parameters - ---------- - candidates: List[Trace] - List of traces to evaluate. - target_str: str, - Compilation target (e.g., llvm, cuda). - params: Optional[Dict[str, np.ndarray]] - Params to bind. - builder: Optional[meta_schedule.builder.Builder] - builder function. If not provided, default local builder will be used. - runner: Optional[meta_schedule.runner.Runner] - runner function. If not provided, default local runner will be used. 
- """ - - ctx = PassContext.current() - target = tvm.target.Target(target_str) - database = PassContext.current().get_tuning_api_database() - # Setup default local builder if not provided - if builder is None: - - def relax_build( - mod: IRModule, - target: tvm.target.Target, - params: Optional[Dict[str, np.ndarray]], - ): - if params: - mod = tvm.relax.transform.BindParams("main", params)(mod) - relax_exec = tvm.compile(mod, target) - return relax_exec.mod - - builder = LocalBuilder(f_build=relax_build) - - # Setup default local runner if not provided - if runner is None: - - def relax_eval_func(rt_mod, device, evaluator_config, repeated_args): - relax_exec = tvm.relax.VMExecutable(rt_mod) - relax_vm = tvm.relax.VirtualMachine(relax_exec, device=device) - - evaluator = relax_vm.module.time_evaluator( - func_name="main", - dev=device, - number=evaluator_config.number, - repeat=evaluator_config.repeat, - min_repeat_ms=evaluator_config.min_repeat_ms, - ) - repeated_costs: List[List[float]] = [] - for args in repeated_args: - profile_result = evaluator(*args) - repeated_costs.append(profile_result.results) - - costs = [float(cost) for cost in itertools.chain.from_iterable(repeated_costs)] - - return costs - - runner = LocalRunner( - evaluator_config=EvaluatorConfig( - number=3, repeat=5, min_repeat_ms=100, enable_cpu_cache_flush=False - ), - f_run_evaluator=relax_eval_func, - ) - - # set up clean up function - f_clean_build = get_global_func_with_default_on_worker("meta_schedule.remove_build_dir", None) - assert f_clean_build - - # Keep track of number of evaluations (mostly for the debugging purpose) - num_evals = 0 - # Evaluation - for candidate in candidates: - # If this candidate is already evaluated, skip the measurement - if candidate.perf != -1: - continue - - # Evaluate candidates - num_evals += 1 - mod = candidate.out_mod - workload = database.commit_workload(mod) - - # If this workload and target pair has measured before, fetch its data. 
- if database.has_measurement_record(workload, target): - run_secs = database.get_measurement_record(workload, target) - # Otherwise, measure it. - else: - # Build candidate - (builder_result,) = builder.build([BuilderInput(mod, target, params)]) - - if builder_result.artifact_path is None: - # Build error - # Assign the worst performance and move on to the next candidate. - logger.warning(builder_result.error_msg) - run_secs = [1e100] - else: - # If build passes, set up runner input and measure the performance. - args_info = [ - TensorInfo( - shape=[int(i) for i in p.struct_info.shape], dtype=p.struct_info.dtype - ) - for p in mod["main"].params - ] # convert list[Var] to list[TensorInfo] - runner_input = RunnerInput( - builder_result.artifact_path, target_str, args_info=args_info - ) - (runner_future,) = runner.run([runner_input]) - runner_result = runner_future.result() - - run_secs = runner_result.run_secs - # Runtime error - # Assign the worst performance and move on to the next candidate. - if runner_result.error_msg is not None: - logger.warning(runner_result.error_msg) - run_secs = [1e100] - - database.commit_measurement_record(workload, target, run_secs) - - # Clean up the artifact - f_clean_build(builder_result.artifact_path) - - # For valid measurments, compute the average and update the trace performance. - perfs = [] - for result in run_secs: - if isinstance(result, tvm.tir.FloatImm): - result = result.value - assert isinstance(result, float) - assert result >= 0.0 - perfs.append(result) - - # Store the evaluation result - candidate.set_perf(np.mean(perfs)) - - ctx.inc_num_evals(num_evals) - - -def select_best_candidate(candidates: List[Trace]) -> Trace: - """ - Select the best trace. 
- - Parameters - ---------- - candidates: List[Trace] - Candidate traces - - Return - ---------- - best_trace: Trace - Trace with the best performance - """ - best_perf, best_trace = sys.maxsize, None - for candidate in candidates: - avg = candidate.perf - # Select best one - if best_perf > avg: - best_perf = avg - best_trace = candidate - return best_trace diff --git a/python/tvm/relax/transform/tuning_api/primitives.py b/python/tvm/relax/transform/tuning_api/primitives.py deleted file mode 100644 index fdc3769f3e5a..000000000000 --- a/python/tvm/relax/transform/tuning_api/primitives.py +++ /dev/null @@ -1,419 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Relax Tuning Pass API primitives""" - -from typing import Callable, Union, Dict, List, Optional, Sequence -import logging -import tvm -from tvm.runtime import Object -from tvm.ir.module import IRModule -from tvm.relax import Expr -from tvm.tir.schedule.trace import JSON_TYPE, _json_from_tvm -from tvm.ffi import register_object -from . 
import _ffi_api - -logger = logging.getLogger("TuningAPI") # pylint: disable=invalid-name - - -@register_object("relax.tuning_api.Choice") -class Choice(Object): - """ - A TVM object Choice that maintains a set of transformation and constraint function keys. - Corresponding functions should be registered as PackedFunc with these keys. - Transformation function will be applied when constraint function returns true. - Parameters - ---------- - transform_func_key : Optional[str] - Key for transformation function. - transform_func_args : Optional[List] - Arguments for transformation function. - constr_func_key : Optional[str] - Key for constraint function. - constr_func_args : Optional[List] - Arguments for constraint function. - - Examples - -------- - The following code block defines a Choice. - - .. code-block:: python - @tvm.register_func("relax.tuning_api.test.transform_func") - def apply(mod): - return relax.tuning_api.FoldConstant()(mod) - @tvm.register_func("relax.tuning_api.test.constr_func") - def constr(mod): - return len(mod.functions) == 3 - # Define a choice to apply constant folding only when IRModule has three functions. - choice = Choice( - transform_func_key = "relax.tuning_api.test.transform_func", - constr_func_key = "relax.tuning_api.test.constr_func" - ) - """ - - def __init__( - self, - transform_func_key: Optional[str] = None, - transform_func_args: Optional[List] = None, - constr_func_key: Optional[str] = None, - constr_func_args: Optional[List] = None, - ): - """Constructor - Parameters - ---------- - transform_func_key : Optional[str] - Key for transformation function. - - f_tramsform_args: Optional[List] - Arguments for transformation function. - - constr_func_key : Optional[str] - Key for constraint function. - - constr_func_args: Optional[List] - Arguments for constraint function. 
- """ - - if transform_func_key is None: - transform_func_key = "relax.tuning_api.Choice.default_transform_func" - - if transform_func_args is None: - transform_func_args = [] - - if constr_func_key is None: - constr_func_key = "relax.tuning_api.Choice.default_constr_func" - - if constr_func_args is None: - constr_func_args = [] - - self.__init_handle_by_constructor__( - _ffi_api.Choice, # type: ignore - transform_func_key, - transform_func_args, - constr_func_key, - constr_func_args, # type: ignore # pylint: disable=no-member - ) - - def get_transform_func(self) -> Callable: - """Getter for transform_func - Returns - ------- - ret: Callable - registered transformation function - """ - return _ffi_api.ChoiceGetTransformFunc(self) # type: ignore - - def get_constr_func(self) -> Callable: - """Getter for constr_func - Returns - ------- - ret: Callable - registered constraint function - """ - return _ffi_api.ChoiceGetConstrFunc(self) # type: ignore - - def apply_transform_func(self, mod: IRModule) -> IRModule: - """Perform transform_func with its arguments - Returns - ------- - ret: IRModule - Transformed IRModule - """ - return _ffi_api.ChoiceApplyTransformFunc(self, mod) # type: ignore - - def check_constr(self, mod: IRModule) -> bool: - """Perform constr_func with its arguments - Returns - ------- - ret: bool - Returns whether the IRModule satisfies the constraint or not - """ - return _ffi_api.ChoiceCheckConstr(self, mod) # type: ignore - - def as_json(self) -> JSON_TYPE: - """Serialize the trace as a JSON-style object - Returns - ------- - json: JSON_TYPE - The JSON-style object - """ - return _ffi_api.ChoiceAsJSON(self) # type: ignore # pylint: disable=no-member - - @staticmethod - def from_json(json_obj: JSON_TYPE) -> "Choice": - """Create Choice from JSON obj - - Parameters - ---------- - json_obj: JSON_TYPE - Choice serialized with JSON - - Return - ---------- - choice: Choice - Deserialized choice - """ - return _ffi_api.ChoiceFromJSON(json_obj) # type: 
ignore - - def deepcopy(self): - return Choice.from_json(self.as_json()) - - -@register_object("relax.tuning_api.Knob") -class Knob(Object): - """ - A TVM object Knob that maintains a set of valid Choices. - By using Knobs, a tuning pass can generate candidates and define the search space. - Parameters - ---------- - name : str - Name of the knob. - - choices: Union[List[Choice], Dict[str, Choice]] - A list of valid choices - - Examples - -------- - The following code block defines a Knob. - - .. code-block:: python - @tvm.register_func("relax.tuning_api.test.transform_func") - def apply(mod): - return relax.tuning_api.FoldConstant()(mod) - choices = {"apply": Choice("relax.tuning_api.test.transform_func"), "noapply": Choice()} - # A knob manages a set of its valid choices - knob = Knob("MockTuningKnob", choices) - """ - - def __init__(self, name: str, choices: Union[List[Choice], Dict[str, Choice]]): - """Constructor.""" - if isinstance(choices, list): - choices = {str(idx): val for idx, val in enumerate(choices)} - - self.__init_handle_by_constructor__( - _ffi_api.Knob, name, choices # type: ignore # pylint: disable=no-member - ) - - def verify(self, decision: Union[str, int]) -> bool: - """Verify if the decision is valid.""" - if isinstance(decision, int): - decision = str(decision) - return _ffi_api.KnobIsValidDecision(self, decision) # type: ignore - - def apply(self, mod: IRModule, decision: Union[str, int]) -> IRModule: - """Get choice if a decision is valid.""" - if isinstance(decision, int): - decision = str(decision) - return _ffi_api.KnobApply(self, mod, decision) # type: ignore - - def as_json(self) -> JSON_TYPE: - """Serialize the trace as a JSON-style object - Returns - ------- - json: JSON_TYPE - The JSON-style object - """ - return _ffi_api.KnobAsJSON(self) # type: ignore - - @staticmethod - def from_json(json_obj: JSON_TYPE) -> "Knob": - """Create Knob from JSON obj - - Parameters - ---------- - json_obj: JSON_TYPE - Knob serialized with JSON - - 
Return - ---------- - knob: Knob - Deserialized knob - """ - return _ffi_api.KnobFromJSON(json_obj) # type: ignore - - def __str__(self) -> str: - msg = f"{self.name} (# of choices: {len(self.choices)})\n" - for name, choice in self.choices.items(): - msg += f" - {name}: {choice}\n" - return msg - - def deepcopy(self): - return Knob.from_json(self.as_json()) - - -@register_object("relax.tuning_api.Trace") -class Trace(Object): - """ - A TVM object Trace logs the history of transformations (decisions). - Parameters - ---------- - in_mod : IRModule - Input IRModule. - knobs: Optional[List[Knob]] - A list of knobs applied in the trace. - decisions: Optional[Sequence[Union[str, int]]] - A list of decisions made for each knob - - Examples - -------- - The following code block defines a Trace. - - .. code-block:: python - - trace = Trace(mod, [knob1, knob2, knob3], ["c1", "c0", "c3"]) - assert trace.size == 3 # Length of history. - # 'out' contains IRModule that applies transformations in the trace. - out: IRModule = trace.add(knob4, "c2") - assert trace.size == 4 # Length of history. - trace.set_perf(0.03) # Set the performance number of the trace. 
- """ - - def __init__( - self, - in_mod: IRModule, - knobs: Optional[List[Knob]] = None, - decisions: Optional[Sequence[Union[str, int]]] = None, - ): - """Constructor.""" - knobs = knobs if knobs else list() - decisions = ( - [str(v) if isinstance(v, int) else v for v in decisions] if decisions else list() - ) - self.__init_handle_by_constructor__( - _ffi_api.Trace, in_mod, knobs, decisions # type: ignore # pylint: disable=no-member - ) - - def verify(self) -> bool: - """Verify if current history is valid.""" - return _ffi_api.TraceVerify() # type: ignore - - def add(self, knob: Knob, decision: Union[str, int]) -> IRModule: - """Add & Apply new decision (with knob).""" - if isinstance(decision, int): - decision = str(decision) - return _ffi_api.TraceAdd(self, knob, decision) # type: ignore - - def set_perf(self, perf: float) -> None: - """Set performance number for the trace.""" - return _ffi_api.TraceSetPerf(self, perf) # type: ignore - - def set_out_mod(self, mod: IRModule) -> None: - """Set out_mod for the trace.""" - return _ffi_api.TraceSetOutMod(self, mod) # type: ignore - - def as_json(self, include_irmod: bool = True) -> JSON_TYPE: - """Serialize the trace as a JSON-style object. - Parameters - ---------- - include_irmod: bool - Decides whether to serialize in_mod as well. - - Returns - ------- - json: JSON_TYPE - The JSON-style object. - """ - obj = _ffi_api.TraceAsJSON(self, include_irmod) # type: ignore - return _json_from_tvm(obj) - - @staticmethod - def from_json(json_obj: JSON_TYPE) -> "Trace": - """Create Trace from JSON obj. - - Parameters - ---------- - json_obj: JSON_TYPE - Trace serialized with JSON. - - Return - ---------- - trace: Trace - Deserialized trace. 
- """ - return _ffi_api.TraceFromJSON(json_obj) # type: ignore - - def __str__(self) -> str: - n = len(self.knobs) - msg = f"Trace length: {n}\n" - for idx in range(n): - msg += f"[{idx+1}] {self.knobs[idx].name}: {self.decisions[idx]}\n" - return msg - - def deepcopy(self) -> "Trace": - new_in_mod = deepcopy_irmodule(self.in_mod) - new_knobs = [knob.deepcopy() for knob in self.knobs] - new_decisions = [str(decision) for decision in self.decisions] - new_trace = Trace(new_in_mod, new_knobs, new_decisions) - new_out_mod = deepcopy_irmodule(self.out_mod) - new_trace.set_out_mod(new_out_mod) - return new_trace - - -def get_trace(in_: Union[Trace, IRModule, Expr]) -> Trace: - """ - Getter for a trace wrapper. - - Parameters - ---------- - in_: Union[Trace, IRModule, Expr] - Input entity - Return - ---------- - wrapped: Trace - Traced entity - """ - if isinstance(in_, Trace): - return in_ - if isinstance(in_, IRModule): - return Trace(in_) - if isinstance(in_, Expr): # type: ignore - return Trace(tvm.IRModule.from_expr(in_)) - - raise Exception(f"Invalid input type for trace: {type(in_)}") - - -@tvm.register_func("relax.tuning_api.deepcopy_irmodule") -def deepcopy_irmodule(mod: IRModule) -> IRModule: - """ - Deepcopy for an IRModule. - Parameters - ---------- - mod: IRModule - input IRModule - Return - ---------- - copied_mod: IRModule - deep-copied IRModule - """ - func_save_json = tvm.get_global_func("node.SaveJSON") - func_load_json = tvm.get_global_func("node.LoadJSON") - new_mod = None - # Handle external modules separately if exist - # TODO(tvm-team): - # Serialization of IRModule with external mods is tricky. - # (1) External mod is runtime module. - # (2) Currently, `export_library` does not support serialization of - # runtime module without the host module - # Therefore, we simply pass around the compiled external modules without copy for now. - # Revisit later when we have a better solution. 
- if mod.attrs and "external_mods" in mod.attrs: - tmp_mod = mod.without_attr("external_mods") - new_mod = func_load_json(func_save_json(tmp_mod)) - new_mod = new_mod.with_attr("external_mods", mod.attrs["external_mods"]) - else: - new_mod = func_load_json(func_save_json(mod)) - - return new_mod diff --git a/src/ir/transform.cc b/src/ir/transform.cc index db4e47ca0d1a..3e8a9aa6ee51 100644 --- a/src/ir/transform.cc +++ b/src/ir/transform.cc @@ -28,7 +28,6 @@ #include #include #include -#include #include #include @@ -485,40 +484,7 @@ IRModule SequentialNode::operator()(IRModule mod, const PassContext& pass_ctx) c mod = GetPass(it)(std::move(mod), pass_ctx); } - // This handles passes that does not use Relax tuning API (untraceable passes). - // We make untraceable passes trackable when pass context has a trace (trace mode). - // When passes to trace (make_traceable) is provided from users, we only make them trackable. - if (pass_ctx->trace_stack.size() && !pass_info->traceable && - (!pass_ctx->make_traceable.defined() || - pass_ctx->make_traceable.value().count(pass_info->name))) { - // TODO(tvm-team): Currently, there are some inconsistency in the pass registration. - // 1. Some passes are not registered in ffi registry. - // 2. Some passes do not follow the name convention. (e.g., = + ) - - // Due to these problems, serialization with non-traceable passes is handled in a hacky way - // now. Find a systematic way to identify such inconsistencies and fix them. - - // In the future, we should pass the ffi key for a pass by deducing from its name. - String transform_func_key = "relax.tuning_api.Choice.default_transform_func"; - String constr_func_key = "relax.tuning_api.Choice.default_constr_func"; - - relax::Knob knob = relax::Knob( - pass_info->name, {{"Applied", relax::Choice(transform_func_key, Array(), - constr_func_key, Array())}}); - - // Add new decision to the trace at the top of the stack. 
- auto trace = Downcast(pass_ctx->trace_stack.back()); - trace->Add(knob, "Applied"); - // In the future, we should just have - // mod = trace->Add(knob, "enabled"); - // instead of the two lines below. - mod = pass(std::move(mod), pass_ctx); - trace->SetOutMod(mod); - - } else { - mod = pass(std::move(mod), pass_ctx); - } + mod = pass(std::move(mod), pass_ctx); } return mod; } @@ -614,9 +580,7 @@ TVM_REGISTER_NODE_TYPE(PassContextNode); TVM_FFI_REGISTER_GLOBAL("transform.PassContext") .set_body_typed([](int opt_level, Array required, Array disabled, Array instruments, - Optional> config, Array trace_stack, - Optional> make_traceable, int num_evals, - Optional tuning_api_database) { + Optional> config) { auto pctx = PassContext::Create(); pctx->opt_level = opt_level; @@ -627,10 +591,6 @@ TVM_FFI_REGISTER_GLOBAL("transform.PassContext") if (config.defined()) { pctx->config = config.value(); } - pctx->trace_stack = std::move(trace_stack); - pctx->make_traceable = std::move(make_traceable); - pctx->num_evals = std::move(num_evals); - pctx->tuning_api_database = std::move(tuning_api_database); PassConfigManager::Global()->Legalize(&(pctx->config)); return pctx; }); @@ -647,7 +607,6 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "\tinstruments: " << node->instruments << "\n"; p->stream << "\tconfig: " << node->config << "\n"; - p->stream << "\ttrace stack: " << node->trace_stack; }); class PassContext::Internal { @@ -657,18 +616,6 @@ class PassContext::Internal { static void ExitScope(PassContext pass_ctx) { pass_ctx.ExitWithScope(); } }; -TVM_FFI_REGISTER_GLOBAL("transform.GetTraceStack").set_body_method(&PassContextNode::GetTraceStack); -TVM_FFI_REGISTER_GLOBAL("transform.PushTrace").set_body_method(&PassContextNode::PushTrace); -TVM_FFI_REGISTER_GLOBAL("transform.PopTrace").set_body_method(&PassContextNode::PopTrace); -TVM_FFI_REGISTER_GLOBAL("transform.GetTraceStackSize") - .set_body_method(&PassContextNode::GetTraceStackSize); 
-TVM_FFI_REGISTER_GLOBAL("transform.GetCurrentTrace") - .set_body_method(&PassContextNode::GetCurrentTrace); -TVM_FFI_REGISTER_GLOBAL("transform.SetNumEvals").set_body_method(&PassContextNode::SetNumEvals); -TVM_FFI_REGISTER_GLOBAL("transform.IncNumEvals").set_body_method(&PassContextNode::IncNumEvals); -TVM_FFI_REGISTER_GLOBAL("transform.GetTuningAPIDatabase") - .set_body_method(&PassContextNode::GetTuningAPIDatabase); - TVM_FFI_REGISTER_GLOBAL("transform.GetCurrentPassContext").set_body_typed(PassContext::Current); TVM_FFI_REGISTER_GLOBAL("transform.EnterPassContext") diff --git a/src/relax/transform/meta_schedule.cc b/src/relax/transform/meta_schedule.cc index cf7b9fc03a50..08e5a100ab22 100644 --- a/src/relax/transform/meta_schedule.cc +++ b/src/relax/transform/meta_schedule.cc @@ -23,7 +23,6 @@ */ #include #include -#include #include #include "../src/meta_schedule/module_equality.h" @@ -44,46 +43,21 @@ class MetaScheduleTuner { max_trials_per_task_(max_trials_per_task), op_names_(op_names), params_(params) { - normalize_mod_func_ = tvm::ffi::Function::GetGlobal("tvm.meta_schedule.normalize_mod"); - ICHECK(normalize_mod_func_.has_value()) << "Normalization function is not found."; + normalize_mod_func_ = tvm::ffi::Function::GetGlobalRequired("tvm.meta_schedule.normalize_mod"); } - // TODO(@sunggg): Currently, only supports basic arguments. 
IRModule TuneIRMod(IRModule mod, transform::PassContext ctx) { - Choice choice( - "tvm.meta_schedule.tune_relax", - {params_, target_, work_dir_, max_trials_global_, max_trials_per_task_, op_names_}, - "relax.tuning_api.Choice.default_constr_func", {}); - Knob knob("meta_schedule.tune_irmod", {{"0", choice}}); - knob->Apply(mod, "0"); - /* - // TODO(@sunggg): revisit when we have a solution for large params - Trace trace = Downcast(ctx->GetCurrentTrace()); - ctx->PopTrace(); - Array candidates = (*candgen_func_)(Array({knob}), trace); - ICHECK(candidates.size() == 1); - Trace best_trace = candidates[0]; - ctx->PushTrace(best_trace); - */ - // since we separate tuning from application, return original IRModule + static ffi::Function tune_relax_func = + tvm::ffi::Function::GetGlobalRequired("tvm.meta_schedule.tune_relax"); + tune_relax_func(mod, params_, target_, work_dir_, max_trials_global_, max_trials_per_task_, + op_names_); return mod; } - // TODO(@sunggg): Currently, only supports basic arguments. tir::PrimFunc TuneTIR(tir::PrimFunc f, transform::PassContext ctx) { - // TODO(@sunggg): Whenever we tune tir, assume we start a new trace w/o pushing to the trace - // stack. Revisit later when we collect more usecases. 
- Choice choice("tvm.meta_schedule.tune_tir", {target_, work_dir_, max_trials_global_}, - "relax.tuning_api.Choice.default_constr_func", {}); - Knob knob("meta_schedule.tune_primfunc", {{"0", choice}}); - knob->Apply((*normalize_mod_func_)(f).cast(), "0"); - /* - // TODO(@sunggg): revisit when we have a solution for large params - Trace trace = Trace((*normalize_mod_func_)(f), {}, {}); - Array candidates = (*candgen_func_)(Array({knob}), trace); - ICHECK(candidates.size() == 1); - */ - // since we separate tuning from application, return original IRModule + static ffi::Function tune_tir_func = + tvm::ffi::Function::GetGlobalRequired("tvm.meta_schedule.tune_tir"); + tune_tir_func(normalize_mod_func_(f), target_, work_dir_, max_trials_global_); return f; } @@ -94,7 +68,7 @@ class MetaScheduleTuner { Integer max_trials_per_task_; Optional> op_names_; Map params_; - std::optional normalize_mod_func_; + tvm::ffi::Function normalize_mod_func_; }; Pass MetaScheduleApplyDatabase(Optional work_dir, bool enable_warning = false) { diff --git a/src/relax/transform/tuning_api/database.cc b/src/relax/transform/tuning_api/database.cc deleted file mode 100644 index fedc61019b06..000000000000 --- a/src/relax/transform/tuning_api/database.cc +++ /dev/null @@ -1,344 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. 
See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/relax/transform/tuning_api/database.cc - * \brief Database of tuning APIs. - */ -#include - -#include -#include -#include - -#include "../../../meta_schedule/utils.h" - -namespace tvm { -namespace relax { - -TuningRecord::TuningRecord(Trace trace, Optional> run_secs) { - ObjectPtr n = make_object(); - n->trace = trace; - n->run_secs = run_secs; - this->data_ = n; -} - -ObjectRef TuningRecordNode::AsJSON(bool include_irmod) const { - LOG(INFO) << "TuningRecordNode::AsJSON " << AsLegacyRepr(trace->AsJSON(include_irmod)); - return Array{trace->AsJSON(include_irmod), // - run_secs}; -} - -TuningRecord TuningRecord::FromJSON(const ObjectRef& json_obj) { - Trace trace{nullptr}; - Optional> run_secs{nullptr}; - try { - const ffi::ArrayObj* json_array = json_obj.as(); - CHECK(json_array && json_array->size() == 2); - // Load json[0] => trace - { - const ObjectRef& json_trace = json_array->at(0).cast(); - trace = Trace::FromJSON(json_trace); - } - - // Load json[1] => run_secs - if (json_array->at(1) != nullptr) { - run_secs = meta_schedule::AsFloatArray(json_array->at(1).cast()); - } - } catch (const std::runtime_error& e) { // includes tvm::Error and dmlc::Error - LOG(FATAL) << "ValueError: Unable to parse the JSON object: " << json_obj - << "\nThe error is: " << e.what(); - } - return TuningRecord(trace, run_secs); -} - -/*! \brief The struct defining comparison function of sorting by mean run seconds. 
*/ -struct SortTuningRecordByMeanRunSecs { - static const constexpr double kMaxMeanTime = 1e10; - - static double Mean(const Array& a) { - if (a.empty()) { - return kMaxMeanTime; - } - double sum = 0.0; - for (const FloatImm& i : a) { - sum += i->value; - } - return sum / a.size(); - } - - bool operator()(const TuningRecord& a, const TuningRecord& b) const { - double a_time = Mean(a->run_secs.value_or({})); - double b_time = Mean(b->run_secs.value_or({})); - return a_time < b_time; - } -}; - -// TODO(tvm-team): Currently, we strictly treat each target separately. -// Since not every option in the target matters, this might be the overkill. -// Revisit this when we have better approach with target equality check. -inline std::string get_database_key(int workload_idx, Target target) { - return std::to_string(workload_idx) + "/" + target->str(); -} - -/*! \brief The default database implementation, which mimics two database tables with two files. - */ -class JSONDatabaseNode : public DatabaseNode { - public: - /*! \brief The path to the workload table */ - String path_workload; - /*! \brief The path to the tuning record table */ - String path_tuning_record; - /*! \brief The path to the measurement table */ - String path_measurement_record; - /*! \brief All the workloads in the database */ - std::unordered_map - workloads2idx_; - /*! \brief All the tuning records in the database */ - std::unordered_map> - tuning_records_; - - /*! 
\brief Measurement logs in the database */ - std::unordered_map> measurement_records_; - - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("path_workload", &path_workload); - v->Visit("path_tuning_record", &path_tuning_record); - v->Visit("path_measurement_record", &path_measurement_record); - // `workloads2idx_` is not visited - // `tuning_records_` is not visited - // `measurement_records_` is not visited - } - - static constexpr const char* _type_key = "relax.tuning_api.JSONDatabase"; - TVM_DECLARE_FINAL_OBJECT_INFO(JSONDatabaseNode, DatabaseNode); - - public: - bool HasWorkload(const IRModule& mod) { - return workloads2idx_.find(meta_schedule::Workload(mod, tvm::StructuralHash()(mod))) != - workloads2idx_.end(); - } - - bool HasMeasurementRecord(const meta_schedule::Workload& workload, const Target& target) { - int workload_idx = this->workloads2idx_.at(workload); - std::string key = get_database_key(workload_idx, target); - return measurement_records_.count(key) > 0; - } - - bool HasTuningRecord(const meta_schedule::Workload& workload, const Target& target) { - int workload_idx = this->workloads2idx_.at(workload); - std::string key = get_database_key(workload_idx, target); - return tuning_records_.count(key) > 0; - } - - meta_schedule::Workload CommitWorkload(const IRModule& mod) { - // Try to insert `mod` into `workloads_` - decltype(this->workloads2idx_)::iterator it; - bool inserted = false; - std::tie(it, inserted) = - this->workloads2idx_.emplace(meta_schedule::Workload(mod, tvm::StructuralHash()(mod)), -1); - meta_schedule::Workload workload = it->first; - // If `mod` is new in `workloads2idx_`, append it to the workload file - if (inserted) { - it->second = static_cast(this->workloads2idx_.size()) - 1; - meta_schedule::JSONFileAppendLine(this->path_workload, - meta_schedule::JSONDumps(workload->AsJSON())); - } - return it->first; - } - - void CommitMeasurementRecord(const meta_schedule::Workload& workload, const Target& target, - const Array& 
run_secs) { - int workload_idx = this->workloads2idx_.at(workload); - std::string key = get_database_key(workload_idx, target); - - if (measurement_records_[key].size() == 0) { - measurement_records_[key] = run_secs; - meta_schedule::JSONFileAppendLine(this->path_measurement_record, - meta_schedule::JSONDumps(Array{ - Integer(workload_idx), target->Export(), - run_secs // - })); - } else { - LOG(WARNING) << "Measurement record for " << key - << " already exists. Use the existing one instead."; - } - } - - void CommitTuningRecord(const meta_schedule::Workload& workload, const Target& target, - const TuningRecord& record) { - int workload_idx = this->workloads2idx_.at(workload); - // There may exist multiple tuning records (with different traces) for a single key pair. - std::string key = get_database_key(workload_idx, target); - this->tuning_records_[key].insert(record); - - meta_schedule::JSONFileAppendLine( - this->path_tuning_record, meta_schedule::JSONDumps(Array{ - Integer(workload_idx), target->Export(), record->AsJSON()})); - } - - Array GetTopK(const meta_schedule::Workload& workload, const Target& target, - int top_k) { - CHECK_GE(top_k, 0) << "ValueError: top_k must be non-negative"; - if (top_k == 0) { - return {}; - } - Array results; - results.reserve(top_k); - int counter = 0; - int idx = this->workloads2idx_.at(workload); - std::string key = get_database_key(idx, target); - for (const TuningRecord& record : this->tuning_records_[key]) { - results.push_back(record); - if (++counter == top_k) { - break; - } - } - - return results; - } - - Array GetMeasurementRecord(const meta_schedule::Workload& workload, - const Target target) { - int workload_idx = this->workloads2idx_.at(workload); - return this->measurement_records_[get_database_key(workload_idx, target)]; - } -}; - -Database Database::JSONDatabase(String path_workload, String path_tuning_record, - String path_measurement_record, bool allow_missing) { - int num_threads = 
std::thread::hardware_concurrency(); - ObjectPtr n = make_object(); - // Load `n->workloads2idx_` from `path_workload` - std::vector workloads; - { - std::vector json_objs = - meta_schedule::JSONFileReadLines(path_workload, num_threads, allow_missing); - int n_objs = json_objs.size(); - n->workloads2idx_.reserve(n_objs); - workloads.reserve(n_objs); - for (int i = 0; i < n_objs; ++i) { - meta_schedule::Workload workload = - meta_schedule::Workload::FromJSON(json_objs[i].cast()); - n->workloads2idx_.emplace(workload, i); - workloads.push_back(workload); - } - } - // Load `n->tuning_records_` from `path_tuning_record` - { - std::vector json_objs = - meta_schedule::JSONFileReadLines(path_tuning_record, num_threads, allow_missing); - - std::vector workload_idxs; - std::vector targets; - std::vector records; - int size = json_objs.size(); - workload_idxs.resize(size, -1); - targets.resize(size, Target{nullptr}); - records.resize(size, TuningRecord{nullptr}); - support::parallel_for_dynamic( - 0, json_objs.size(), num_threads, [&](int thread_id, int task_id) { - const ObjectRef& json_obj = json_objs[task_id].cast(); - try { - const ffi::ArrayObj* arr = json_obj.as(); - ICHECK_EQ(arr->size(), 3); - workload_idxs[task_id] = Downcast(arr->at(0)).IntValue(); - targets[task_id] = Target(Downcast>(arr->at(1))); - records[task_id] = TuningRecord::FromJSON(arr->at(2).cast()); - } catch (std::runtime_error& e) { - LOG(FATAL) << "ValueError: Unable to parse the JSON object: " << json_obj - << "\nThe error is: " << e.what(); - } - }); - - for (int i = 0; i < size; i++) { - std::string key = get_database_key(workload_idxs[i], targets[i]); - n->tuning_records_[key].insert(records[i]); - } - } - - // Load `n->measuremet_log` from `path_measurement_record` - { - std::vector json_objs = - meta_schedule::JSONFileReadLines(path_measurement_record, num_threads, allow_missing); - std::vector workload_idxs; - std::vector targets; - std::vector> measurements; - int size = json_objs.size(); - 
workload_idxs.resize(size, -1); - targets.resize(size, Target{nullptr}); - measurements.resize(size, Array({})); - support::parallel_for_dynamic( - 0, json_objs.size(), num_threads, [&](int thread_id, int task_id) { - const ObjectRef& json_obj = json_objs[task_id].cast(); - try { - const ffi::ArrayObj* arr = json_obj.as(); - ICHECK_EQ(arr->size(), 3); - workload_idxs[task_id] = Downcast(arr->at(0)).IntValue(); - targets[task_id] = Target(Downcast>(arr->at(1))); - measurements[task_id] = meta_schedule::AsFloatArray(arr->at(2).cast()); - } catch (std::runtime_error& e) { - LOG(FATAL) << "ValueError: Unable to parse the JSON object: " << json_obj - << "\nThe error is: " << e.what(); - } - }); - for (int i = 0; i < size; i++) { - n->measurement_records_[get_database_key(workload_idxs[i], targets[i])] = measurements[i]; - } - } - - n->path_workload = path_workload; - n->path_tuning_record = path_tuning_record; - n->path_measurement_record = path_measurement_record; - return Database(n); -} - -/**************** FFI ****************/ -TVM_REGISTER_NODE_TYPE(TuningRecordNode); -TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.TuningRecord") - .set_body_typed([](Trace trace, Optional> run_secs) { - return TuningRecord(trace, run_secs); - }); -TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.TuningRecordAsJSON") - .set_body_method(&TuningRecordNode::AsJSON); -TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.TuningRecordFromJSON") - .set_body_typed(TuningRecord::FromJSON); - -TVM_REGISTER_OBJECT_TYPE(DatabaseNode); -TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.DatabaseHasWorkload") - .set_body_method(&DatabaseNode::HasWorkload); -TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.DatabaseHasMeasurementRecord") - .set_body_method(&DatabaseNode::HasMeasurementRecord); -TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.DatabaseHasTuningRecord") - .set_body_method(&DatabaseNode::HasTuningRecord); -TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.DatabaseCommitMeasurementRecord") - 
.set_body_method(&DatabaseNode::CommitMeasurementRecord); -TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.DatabaseCommitWorkload") - .set_body_method(&DatabaseNode::CommitWorkload); -TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.DatabaseCommitTuningRecord") - .set_body_method(&DatabaseNode::CommitTuningRecord); -TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.DatabaseGetTopK").set_body_method(&DatabaseNode::GetTopK); -TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.DatabaseGetMeasurementRecord") - .set_body_method(&DatabaseNode::GetMeasurementRecord); - -TVM_REGISTER_NODE_TYPE(JSONDatabaseNode); -TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.DatabaseJSONDatabase") - .set_body_typed(Database::JSONDatabase); -} // namespace relax -} // namespace tvm diff --git a/src/relax/transform/tuning_api/primitives.cc b/src/relax/transform/tuning_api/primitives.cc deleted file mode 100644 index 5f53b5166725..000000000000 --- a/src/relax/transform/tuning_api/primitives.cc +++ /dev/null @@ -1,272 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/relax/transform/tuning_api/primitives.cc - * \brief Primitives of tuning APIs. 
- */ - -#include - -#include "../../../meta_schedule/utils.h" -namespace tvm { -namespace relax { - -Choice::Choice(String transform_func_key, Array transform_func_args, String constr_func_key, - Array constr_func_args) { - ObjectPtr n = make_object(); - n->transform_func_key = std::move(transform_func_key); - n->transform_func_args = std::move(transform_func_args); - n->constr_func_key = std::move(constr_func_key); - n->constr_func_args = std::move(constr_func_args); - data_ = std::move(n); -} - -// TODO(sunggg): Currently, it only supports an array of primitive data types. -ObjectRef ChoiceNode::AsJSON() const { - Array json_transfrom_args, json_constr_args; - for (Any arg : this->transform_func_args) { - std::string json_arg = tvm::SaveJSON(arg); - std::string b64_arg = meta_schedule::Base64Encode(json_arg); - json_transfrom_args.push_back(String(b64_arg)); - } - for (Any arg : this->constr_func_args) { - std::string json_arg = tvm::SaveJSON(arg); - std::string b64_arg = meta_schedule::Base64Encode(json_arg); - json_constr_args.push_back(String(b64_arg)); - } - return Array{ - this->transform_func_key, - json_transfrom_args, - this->constr_func_key, - json_constr_args, - }; -} - -Choice Choice::FromJSON(const ObjectRef& json) { - // Parse `json` into `choice` - String transform_func_key, constr_func_key; - Array transform_func_args, constr_func_args; - try { - const ffi::ArrayObj* arr = json.as(); - ICHECK(arr && arr->size() == 4); - const auto* arr0 = arr->at(0).as(); - const auto* arr1 = arr->at(1).as(); - const auto* arr2 = arr->at(2).as(); - const auto* arr3 = arr->at(3).as(); - ICHECK(arr0 && arr1 && arr2 && arr3); - transform_func_key = GetRef(arr0); - { - transform_func_args.reserve(arr1->size()); - for (const Any& elem : *arr1) { - String b64_arg = Downcast(elem); - std::string json_arg = meta_schedule::Base64Decode(b64_arg); - Any arg = LoadJSON(json_arg); - transform_func_args.push_back(arg); - } - } - constr_func_key = GetRef(arr2); - { - 
constr_func_args.reserve(arr3->size()); - for (const Any& elem : *arr3) { - String b64_arg = Downcast(elem); - std::string json_arg = meta_schedule::Base64Decode(b64_arg); - Any arg = LoadJSON(json_arg); - constr_func_args.push_back(arg); - } - } - } catch (const tvm::Error& e) { - LOG(FATAL) - << "ValueError: The json entry of a choice should contain a set of two strings, but gets: " - << json; - throw; - } - return Choice(transform_func_key, transform_func_args, constr_func_key, constr_func_args); -} - -Knob::Knob(String name, Map choices) { - ObjectPtr n = make_object(); - n->name = std::move(name); - n->choices = std::move(choices); - data_ = std::move(n); -} - -ObjectRef KnobNode::AsJSON() const { - Map json_choices; - for (auto const& x : choices) { - json_choices.Set(x.first, x.second->AsJSON()); - } - return Array{ - /* 0: name */ std::move(name), - /* 1: choices */ std::move(json_choices), - }; -} - -Knob Knob::FromJSON(const ObjectRef& json) { - // Parse `json` into `name` and `choices` - String name; - Map choices; - try { - const ffi::ArrayObj* arr = json.as(); - ICHECK(arr && arr->size() == 2); - const auto* arr0 = arr->at(0).as(); - const auto* arr1 = arr->at(1).as(); - ICHECK(arr0 && arr1); - name = GetRef(arr0); - for (auto const& x : GetRef>(arr1)) { - String decision = x.first; - Choice choice = Choice::FromJSON(x.second.cast()); - choices.Set(decision, choice); - } - } catch (const tvm::Error& e) { - LOG(FATAL) - << "ValueError: The json entry of a choice should contain a set of two strings, but gets: " - << json; - throw; - } - return Knob(name, choices); -} - -Trace::Trace() { data_ = make_object(); } - -Trace::Trace(IRModule in_mod, Array knobs, Array decisions) { - ICHECK(knobs.size() == decisions.size()) << "Size of knobs and decisions should match"; - // Deep-copy IRModule - const auto func_deepcopy = - tvm::ffi::Function::GetGlobalRequired("relax.tuning_api.deepcopy_irmodule"); - IRModule out_mod = func_deepcopy(in_mod).cast(); - // Apply 
the decision history if provided - int size = knobs.size(); - for (int i = 0; i < size; i++) { - out_mod = knobs[i]->Apply(out_mod, decisions[i]); - } - - ObjectPtr n = make_object(); - n->in_mod = std::move(in_mod); - n->out_mod = std::move(out_mod); - n->knobs = std::move(knobs); - n->decisions = std::move(decisions); - n->size = std::move(size); - data_ = std::move(n); -} - -ObjectRef TraceNode::AsJSON(bool include_in_mod) const { - ICHECK(this->Verify()) << "Trace should be valid"; - - Array json_knobs; - Array json_decisions; - - int size = this->size; - json_knobs.reserve(size); - json_decisions.reserve(size); - - for (int i = 0; i < size; i++) { - const Knob& knob = this->knobs[i]; - const String& decision = this->decisions[i]; - - json_knobs.push_back(knob->AsJSON()); - json_decisions.push_back(decision); - } - if (include_in_mod) { - std::string json_mod = tvm::SaveJSON(this->in_mod); - std::string b64_mod = meta_schedule::Base64Encode(json_mod); - return Array{json_knobs, json_decisions, String(b64_mod)}; - } else { - return Array{json_knobs, json_decisions}; - } -} - -Trace Trace::FromJSON(const ObjectRef& json) { - // Parse `json` into `trace` - IRModule in_mod; - Array knobs; - Array decisions; - try { - const ffi::ArrayObj* arr = json.as(); - // A trace will have 2 or 3 entries depending on `include_irmod` parameter. 
- ICHECK(arr && (arr->size() == 2 || arr->size() == 3)); - - const auto* arr0 = arr->at(0).as(); - const auto* arr1 = arr->at(1).as(); - ICHECK(arr0 && arr1); - - for (const Any& elem : *arr0) { - knobs.push_back(Knob::FromJSON(elem.cast())); - } - - for (const Any& elem : *arr1) { - decisions.push_back(Downcast(elem)); - } - - // When `include_irmod = true` - if (arr->size() == 3) { - const auto* arr2 = arr->at(2).as(); - String b64_mod = GetRef(arr2); - ICHECK(arr2); - std::string json_mod = meta_schedule::Base64Decode(b64_mod); - in_mod = Downcast(LoadJSON(json_mod)); - } - } catch (const tvm::Error& e) { - LOG(FATAL) << "ValueError: Malformed Trace format - " << json; - throw; - } - return Trace(in_mod, knobs, decisions); -} - -/**************** FFI ****************/ -TVM_REGISTER_NODE_TYPE(ChoiceNode); -TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.Choice") - .set_body_typed([](String transform_func_key, Array transform_func_args, - String constr_func_key, Array constr_func_args) { - return Choice(transform_func_key, transform_func_args, constr_func_key, constr_func_args); - }); -TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.ChoiceAsJSON").set_body_method(&ChoiceNode::AsJSON); -TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.ChoiceFromJSON").set_body_typed(Choice::FromJSON); -TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.ChoiceGetTransformFunc") - .set_body_method(&ChoiceNode::GetTransformFunc); -TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.ChoiceGetConstrFunc") - .set_body_method(&ChoiceNode::GetConstrFunc); -TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.ChoiceApplyTransformFunc") - .set_body_method(&ChoiceNode::ApplyTransformFunc); -TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.ChoiceCheckConstr") - .set_body_method(&ChoiceNode::CheckConstr); - -TVM_REGISTER_NODE_TYPE(KnobNode); -TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.Knob") - .set_body_typed([](String name, Map choices) { return Knob(name, choices); }); 
-TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.KnobAsJSON").set_body_method(&KnobNode::AsJSON); -TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.KnobFromJSON").set_body_typed(Knob::FromJSON); -TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.KnobIsValidDecision") - .set_body_method(&KnobNode::IsValidDecision); -TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.KnobApply").set_body_method(&KnobNode::Apply); - -TVM_REGISTER_NODE_TYPE(TraceNode); -TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.Trace") - .set_body_typed([](IRModule in_mod, Array knobs, Array decisions) { - return Trace(in_mod, knobs, decisions); - }); -TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.TraceVerify").set_body_method(&TraceNode::Verify); -TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.TraceAdd").set_body_method(&TraceNode::Add); -TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.TraceSetPerf").set_body_method(&TraceNode::SetPerf); -TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.TraceSetOutMod").set_body_method(&TraceNode::SetOutMod); - -TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.TraceAsJSON").set_body_method(&TraceNode::AsJSON); -TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.TraceFromJSON").set_body_typed(Trace::FromJSON); -} // namespace relax -} // namespace tvm diff --git a/tests/python/relax/conftest.py b/tests/python/relax/conftest.py index bb5a04ef7679..a8db5d77889f 100644 --- a/tests/python/relax/conftest.py +++ b/tests/python/relax/conftest.py @@ -80,10 +80,6 @@ def apply_instrument_well_formed(unit_test_marks): required_pass=current.required_pass, disabled_pass=current.disabled_pass, config=current.config, - trace_stack=current.trace_stack, - make_traceable=current.make_traceable, - num_evals=current.num_evals, - tuning_api_database=current.get_tuning_api_database(), ) with override: yield diff --git a/tests/python/relax/test_transform_codegen_pass.py b/tests/python/relax/test_transform_codegen_pass.py index e053512439af..7560246b8af4 100644 --- a/tests/python/relax/test_transform_codegen_pass.py +++ 
b/tests/python/relax/test_transform_codegen_pass.py @@ -26,7 +26,6 @@ from tvm import relax, tir from tvm.relax.dpl import is_op, wildcard from tvm.relax.testing import transform -from tvm.relax.transform.tuning_api import Trace from tvm.script import ir as I from tvm.script import relax as R from tvm.script import tir as T @@ -164,7 +163,7 @@ def test_mix_use_tensorrt_and_tvm(): # Run Codegen pass with tempfile.TemporaryDirectory() as work_dir: - with target, tvm.transform.PassContext(trace=Trace(mod), opt_level=0): + with target, tvm.transform.PassContext(opt_level=0): new_mod = tvm.transform.Sequential( [ relax.transform.FuseOpsByPattern(patterns), diff --git a/tests/python/relax/test_transform_meta_schedule_tuning.py b/tests/python/relax/test_transform_meta_schedule_tuning.py index 15657527cdbb..3d290c0ae8c6 100644 --- a/tests/python/relax/test_transform_meta_schedule_tuning.py +++ b/tests/python/relax/test_transform_meta_schedule_tuning.py @@ -25,8 +25,6 @@ from tvm.ir.module import IRModule from tvm.ir.transform import PassContext -# TODO(@sunggg): re-enable Trace when we have a solution for large params -# from tvm.relax.transform.tuning_api import Trace from tvm.script import relax as R from tvm.script import tir as T @@ -211,7 +209,6 @@ def test_ms_database_apply_fallback(): assert isinstance(mod, IRModule) with tempfile.TemporaryDirectory() as work_dir: """ - # TODO(@sunggg): Revisit when ready with target_cuda, PassContext(trace=Trace(mod), opt_level=0): tuning_pass = relax.transform.MetaScheduleTuneTIR( work_dir=work_dir, max_trials_global=0 diff --git a/tests/python/relax/test_tuning_api.py b/tests/python/relax/test_tuning_api.py deleted file mode 100644 index 082c9ce16a30..000000000000 --- a/tests/python/relax/test_tuning_api.py +++ /dev/null @@ -1,781 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. 
See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import pytest -import numpy as np -import os.path as osp -import tempfile -from typing import List -from math import isclose - -import tvm -from tvm import ir -from tvm.ir import transform -from tvm.ir.transform import PassContext -from tvm.ir.module import IRModule -from tvm.script import tir as T, relax as R -from tvm import relax -from tvm.relax.expr import Expr, DataflowBlock, Function -from tvm.relax.transform.tuning_api import ( - Choice, - Knob, - Trace, - TuningRecord, - JSONDatabase, - default_generate_candidate, - default_consider_eval_passes, - default_evaluate, - select_best_candidate, - get_trace, -) - - -@tvm.script.ir_module -class TestModule: - @T.prim_func - def addone(A: T.Buffer((16, 16), "int32"), B: T.Buffer((16, 16), "int32")) -> None: - T.func_attr(({"global_symbol": "addone"})) - for i, j in T.grid(16, 16): - with T.block("addone"): - vi, vj = T.axis.remap("SS", [i, j]) - B[vi, vj] = A[vi, vj] + T.int32(1) - - # Input IRModule. - @R.function - def before(c0: R.Tensor((16, 16), "int32")): - cls = TestModule - lv0 = R.call_tir(cls.addone, (c0,), R.Tensor((16, 16), dtype="int32")) - return lv0 - - # Expected IRModule after transformation. 
- @R.function - def expected(c1: R.Tensor((16, 16), "int32")): - return c1 - - -def gen_mod(mod, name, binding): - funcs = {} - binding = {k: tvm.nd.array(v) for k, v in binding.items()} - - for k, v in mod.functions.items(): - if isinstance(v, tvm.relax.Function): - if k.name_hint == name: - # rename to main. - gv = tvm.ir.GlobalVar("main") - funcs[gv] = tvm.relax.Function(v.params, v.body, v.ret_struct_info).with_attr( - "global_symbol", "main" - ) - else: - funcs[k] = v - mod = tvm.IRModule(funcs) - return relax.transform.BindParams("main", binding)(mod) - - -# Setup for simple testing with IRModule. -def setup_test(): - mod = TestModule - assert isinstance(mod, tvm.IRModule) - return gen_mod(mod, "before", {}) - - -# Setup for testing with constant folding. -def setup_test_const_folding(): - mod = TestModule - assert isinstance(mod, tvm.IRModule) - # Test setup. - c0_np = np.arange((16 * 16)).astype("int32").reshape(16, 16) - c1_np = c0_np + 1 - before = gen_mod(mod, "before", {"c0": c0_np}) - expected = gen_mod(mod, "expected", {"c1": c1_np}) - - return before, expected - - -# Define a choice by using FoldConstant pass. -@tvm.register_func("testing.apply_fold_constant") -def apply_fold_constant(mod): - return relax.transform.FoldConstant()(mod) - - -@tvm.register_func("testing.add_global_symbol") -def add_global_symbol(mod, func_name, global_symbol): - mod[func_name] = mod[func_name].with_attr("global_symbol", global_symbol) - return mod - - -@tvm.register_func("testing.check_num_functions") -def check_num_funcs(mod, N): - # Explicit type specification is necessary. - # Otherwise, PackedFunc cannot derive the return type correctly. - # e.g., Check failed: type_code_ == kDLInt (8 vs. 0) : expected int but got Object - return bool(len(mod.functions) == N) - - -def test_choice(): - # Test setup. - ( - before, - expected, - ) = setup_test_const_folding() - - # Without any argument, default setting will be used for both transformation and constraint functions. 
- # default transformation function will return the original IRModule without any change. - choice = Choice( - # - transform_func_key="relax.tuning_api.Choice.default_transform_func" - # - constr_func_key="relax.tuning_api.Choice.default_constr_func") - ) - # Load transformation function from the choice and apply it. - after = choice.apply_transform_func(before) - tvm.ir.assert_structural_equal(after, before) - - choice = Choice("testing.apply_fold_constant") - # Load transformation function from the choice and apply it. - after = choice.apply_transform_func(before) - tvm.ir.assert_structural_equal(after, expected) - - # Create a choice that tags global symbol onto target function. - choice = Choice("testing.add_global_symbol", ["addone", "test-symbol"]) - after = choice.apply_transform_func(before) - assert after["addone"].attrs["global_symbol"] == "test-symbol" - # The transformation should be applied with Copy-On-Write. - # So, the original module should be unchanged. - assert before["addone"].attrs["global_symbol"] == "addone" - - # Test choice with impossible constraint - choice = Choice( - transform_func_key="testing.add_global_symbol", - transform_func_args=["addone", "test-symbol"], - constr_func_key="testing.check_num_functions", - constr_func_args=[1000], - ) - # Since the constraint is not met, it should return the original function - after = choice.apply_transform_func(before) - assert after["addone"].attrs["global_symbol"] == "addone" - - # Test choice with the proper constraint - choice = Choice( - transform_func_key="testing.add_global_symbol", - transform_func_args=["addone", "test-symbol"], - constr_func_key="testing.check_num_functions", - constr_func_args=[2], - ) - # Since the constraint is not met, it should return the original function - after = choice.apply_transform_func(before) - assert after["addone"].attrs["global_symbol"] == "test-symbol" - # The original module should be unchanged. 
- assert before["addone"].attrs["global_symbol"] == "addone" - - # Test roundtrip. - # Export as JSON. - json_obj = choice.as_json() - # Import JSON. - new_choice = Choice.from_json(json_obj) - # Test imported choice - after = new_choice.apply_transform_func(before) - assert after["addone"].attrs["global_symbol"] == "test-symbol" - # The original module should be unchanged. - assert before["addone"].attrs["global_symbol"] == "addone" - - -def test_knob(): - # Test setup. - before, expected = setup_test_const_folding() - - # Users can define a set of choices with list. - choices = [ - Choice("testing.apply_fold_constant"), - Choice(), - ] - - # Define knob. - knob = Knob("TestKnob", choices) - # Check the sanity of decision space. - assert knob.verify(0) - assert knob.verify(1) - assert not knob.verify(3) - - # Check the sanity of each decision. - after_apply = knob.apply(before, 0) - after_noapply = knob.apply(before, 1) - - tvm.ir.assert_structural_equal(after_apply, expected) - tvm.ir.assert_structural_equal(after_noapply, before) - - # Users can define a set of choices with dict. - choices = { - "apply": Choice("testing.apply_fold_constant"), - "noapply": Choice(), - "apply_with_impossible_constr": Choice( - transform_func_key="testing.apply_fold_constant", - constr_func_key="testing.check_num_functions", - constr_func_args=[1000], - ), - } - # Define knob. 
- knob = Knob("TestKnob", choices) - assert knob.verify("apply") - assert knob.verify("noapply") - assert knob.verify("apply_with_impossible_constr") - assert not knob.verify("INVLAID") - - after_apply = knob.apply(before, "apply") - after_noapply = knob.apply(before, "noapply") - # Because constr was not satisfied, it will return the original IRModule - after_apply_with_constr = knob.apply(before, "apply_with_impossible_constr") - tvm.ir.assert_structural_equal(after_apply, expected) - tvm.ir.assert_structural_equal(after_noapply, before) - tvm.ir.assert_structural_equal(after_apply_with_constr, before) - - # Test roundtrip. - # Export as JSON. - json_obj = knob.as_json() - # Import JSON. - new_knob = Knob.from_json(json_obj) - assert new_knob.name == knob.name - # Test imported knob - assert new_knob.verify("apply") - assert new_knob.verify("noapply") - assert new_knob.verify("apply_with_impossible_constr") - assert not new_knob.verify("INVLAID") - - after_apply = new_knob.apply(before, "apply") - after_noapply = new_knob.apply(before, "noapply") - # Because constr was not satisfied, it will return the original IRModule - after_apply_with_constr = knob.apply(before, "apply_with_impossible_constr") - tvm.ir.assert_structural_equal(after_apply, expected) - tvm.ir.assert_structural_equal(after_noapply, before) - tvm.ir.assert_structural_equal(after_apply_with_constr, before) - - -def test_trace(): - before, expected = setup_test_const_folding() - - # Define choices and its knob. - choices = { - "apply": Choice( - transform_func_key="testing.apply_fold_constant", - transform_func_args=[], - constr_func_key="testing.check_num_functions", - constr_func_args=[2], - ), - "noapply": Choice(), - } - knob = Knob("TestKnob", choices) - - # Define a Trace with empty decision (transformation) history. - trace = Trace(before) - assert trace.size == 0 - - # Define a Trace with single decision (transformation) history. 
- trace = Trace(before, [knob], ["noapply"]) - assert trace.size == 1 - tvm.ir.assert_structural_equal(trace.in_mod, before) - tvm.ir.assert_structural_equal(trace.out_mod, before) - - # Add a new knob and its decision to the trace. - # It will update the current trace and returns its new output IRModule. - out: IRModule = trace.add(knob, "noapply") - assert trace.size == 2 - tvm.ir.assert_structural_equal(trace.in_mod, before) - tvm.ir.assert_structural_equal(trace.out_mod, before) - tvm.ir.assert_structural_equal(out, before) - # Assume we assign arbitrary performance number. - trace.set_perf(100) - assert trace.perf == 100 - - # Add a new knob and its decision to the trace. - out: IRModule = trace.add(knob, "apply") - tvm.ir.assert_structural_equal(trace.in_mod, before) - tvm.ir.assert_structural_equal(trace.out_mod, expected) - tvm.ir.assert_structural_equal(out, expected) - - assert trace.size == 3 - # Should be initalized when new knob is applied. - assert trace.perf == -1 - - # Test roundtrip. - # Export as JSON. - json_obj = trace.as_json() - # Import JSON. 
- new_trace = Trace.from_json(json_obj) - tvm.ir.assert_structural_equal(trace.in_mod, new_trace.in_mod) - assert str(trace) == str(new_trace) - assert new_trace.size == 3 - tvm.ir.assert_structural_equal(trace.out_mod, new_trace.out_mod) - - -def test_trace_wrapper(): - mod = setup_test() - assert isinstance(mod, tvm.IRModule) - assert isinstance(Trace(mod), Trace) - assert isinstance(get_trace(mod), Trace) - assert isinstance(get_trace(mod["main"]), Trace) - assert isinstance(get_trace(mod["addone"]), Trace) - - -def create_tmp_database(tmpdir: str) -> JSONDatabase: - path_workload = osp.join(tmpdir, "workloads.json") - path_tuning_record = osp.join(tmpdir, "tuning_records.json") - path_measurement_record = osp.join(tmpdir, "measurement_records.json") - return JSONDatabase(path_workload, path_tuning_record, path_measurement_record) - - -def test_database(): - def equal_measurement_record(a: List[float], b: List[float]): - assert len(a) == len(b) - for i in range(len(a)): - assert isclose(a[i], b[i], rel_tol=1e-5) - - def equal_tuning_record(a: TuningRecord, b: TuningRecord): - assert str(a.trace) == str(b.trace) - equal_measurement_record(a.run_secs, b.run_secs) - - # Test setup. 
- ( - mod1, - mod2, - ) = setup_test_const_folding() - knob = Knob("test", {"noapply": Choice()}) - trace = Trace(mod1, [knob, knob], ["noapply", "noapply"]) - target = tvm.target.Target("llvm") - - # Test roundtrip - run_secs = [1.0, 0.9, 0.4] - tuning_record = TuningRecord( - trace, - run_secs, - ) - new_tuning_record = TuningRecord.from_json(json_obj=tuning_record.as_json()) - equal_tuning_record(tuning_record, new_tuning_record) - - with tempfile.TemporaryDirectory() as tmpdir: - database = create_tmp_database(tmpdir) - workload1 = database.commit_workload(mod1) - - database.commit_measurement_record(workload1, target, run_secs) - new_run_secs1 = database.get_measurement_record(workload1, target) - equal_measurement_record(run_secs, new_run_secs1) - workload2 = database.commit_workload(mod2) - new_run_secs2 = database.get_measurement_record(workload2, target) - assert len(new_run_secs2) == 0 - - database.commit_tuning_record(workload1, target, tuning_record) - new_tuning_records = database.get_top_k(workload1, target, top_k=1) - assert len(new_tuning_records) == 1 - equal_tuning_record(tuning_record, new_tuning_records[0]) - new_tuning_records = database.get_top_k(workload1, target, top_k=0) - assert len(new_tuning_records) == 0 - - -def test_default_functions(): - mod = setup_test() - assert isinstance(mod, tvm.IRModule) - - # Define choice, knob, trace. - choices = {"apply": Choice("testing.apply_fold_constant"), "noapply": Choice()} - knob = Knob("TestKnob", choices) - trace = Trace(mod) - - # Launch a pass pipeline in trace mode. - with tempfile.TemporaryDirectory() as tmpdir: - database = create_tmp_database(tmpdir) - with transform.PassContext(trace=trace, tuning_api_database=database): - # Default generation function expands every valid choice. - candidates = default_generate_candidate([knob], trace) - assert len(candidates) == 2 - - # Default evaluate function uses MetaSchedule builder/runner. 
- # Since builder/runner are not provided, local builder/runner will be used. - default_evaluate(candidates, "llvm --num-cores=16") - assert PassContext.current().num_evals == 2 - - # Because these candidates are already evaluated, num_evals stays the same. - default_evaluate(candidates, "llvm --num-cores=16") - assert PassContext.current().num_evals == 2 - - # Test with multiple knobs - candidates = default_generate_candidate([knob, knob], trace) - assert len(candidates) == 4 - - # Launch new pass pipeline in trace mode. - with transform.PassContext(trace=trace, tuning_api_database=database): - candidates = default_generate_candidate([knob], trace) - assert len(candidates) == 2 - # Provide tuning pass as an eval pass. - # Note that MockConstFoldingTuningPass() has its own generation function, evaluation function. - # Evaluation would be done in a tornament fashion. - # `default_consider_eval_passes` will convert candidates into the best version by considering eval_passes. - # For example, if we say candidates = [C1, C2] - # `default_consider_eval_passes` will return best form of C1 variant (C11 vs C12) and C2 variant (C21 vs C22) - # that can be generated by eval_passes. - # Assume C11 > C12, C21 < C22, - # new_candidates = [C11, C22] - new_candidates = default_consider_eval_passes( - candidates, [MockConstFoldingTuningPass(eval_passes=[])] - ) - - # len(candidates) == len(new candidates). - assert len(new_candidates) == 2 - # To find the best version of each candidate, it would take 4 evals (C11, C12, C21, C22). - assert PassContext.current().num_evals == 4 - - HeuristicPass = relax.transform.FoldConstant - with transform.PassContext(trace=trace, tuning_api_database=database): - candidates = default_generate_candidate([knob], trace) - assert len(candidates) == 2 - # Provide heuristic pass as an eval pass. - new_candidates = default_consider_eval_passes(candidates, [HeuristicPass()]) - # Since heuristic pass has single decision, it won't need any tornament. 
- # new_candidates = [C11, C21] - assert len(new_candidates) == 2 - # We only conduct evaluation when its necessary (e.g., choose better candidate in tuning pass). - # Heuristic pass won't conduct any evaluation. - assert PassContext.current().num_evals == 0 - - -# TODO(sunggg): Do we need to serialize pass context as well? -def test_pass_context(): - before, expected = setup_test_const_folding() - HeuristicPass = relax.transform.FoldConstant - # FoldConstant implicitly performs TIR passes (prob for constant evaluation). - # If make_traceable is not provided, the pass infra will make every non-traceable pass traceable by default. - seq = transform.Sequential([HeuristicPass()]) - with transform.PassContext( - trace=Trace(before), - ): - after = seq(before) - tvm.ir.assert_structural_equal(after, expected) - assert PassContext.current().get_trace_stack_size() == 1 - # The exact number of implicit passes might change as TVM develops more passes. - # As of today, this size returns 57. - assert PassContext.current().get_current_trace().size > 1 - - # We can explicitly specify which pass we want to keep track of. - with transform.PassContext(trace=Trace(before), make_traceable=["FoldConstant"]): - after = seq(before) - tvm.ir.assert_structural_equal(after, expected) - assert PassContext.current().get_trace_stack_size() == 1 - assert PassContext.current().get_current_trace().size == 1 - - # Check the functionality of trace stack. - with transform.PassContext(trace=Trace(before)): - assert PassContext.current().get_trace_stack_size() == 1 - PassContext.current().push_trace(Trace(before)) - assert PassContext.current().get_trace_stack_size() == 2 - PassContext.current().pop_trace() - assert PassContext.current().get_trace_stack_size() == 1 - PassContext.current().pop_trace() - assert PassContext.current().get_trace_stack_size() == 0 - - -# Mock evaluation pass for testing. -# Assigns arbitrary performance number to each candidate. 
-def mock_evaluate(candidates: List[Trace], target_str: str, ctx: PassContext): - num_evals = 0 - # Evaluation - for candidate in candidates: - # If this candidate is already evaluated, skip the measurement. - if candidate.perf != -1: - continue - - num_evals += 1 - # Assign arbitrary performance. - mock_perf = 100 - (ctx.num_evals + num_evals) - candidate.set_perf(mock_perf) - # Update number of evals for testing. - ctx.inc_num_evals(num_evals) - - -# Mock tuning pass that determines whether to apply relax.transform.FoldConstant(). -# Each pass invocation will generate two candidates for the incoming IRModule. -# In relax pass infra, each pass will define its own way of generating candidates and evaluating them without needing to know how other passes generate its candidate and evaluate them. -# This will significantly alleviate the development process since it is known to be HARD problem to consider the interaction with (potentially hundreds of) other passes. -@ir.transform.module_pass(opt_level=0, traceable=True) -class MockConstFoldingTuningPass(transform.Pass): - def __init__( - self, - f_generate_candidate=None, - f_evaluate=mock_evaluate, - eval_passes: List[transform.Pass] = None, - required: List[transform.Pass] = [], - ): - self.f_generate_candidate = ( - f_generate_candidate if f_generate_candidate else default_generate_candidate - ) - self.f_evaluate = f_evaluate if f_evaluate else default_evaluate - self.eval_passes = eval_passes - self.required = required - - def transform_module(self, mod: IRModule, ctx: PassContext) -> IRModule: - trace = ctx.pop_trace() - - # Create mock choices for testing. - choices = {"apply": Choice("testing.apply_fold_constant"), "noapply": Choice()} - # Tuning pass manages a set of transformation functions registered via knob. 
- knob = Knob("MockTuningKnob", choices) - - candidates = self.f_generate_candidate([knob], trace, self.eval_passes) - self.f_evaluate(candidates, "llvm", ctx) - best_trace = select_best_candidate(candidates) - - ctx.push_trace(best_trace) - return best_trace.out_mod - - -def test_module_pass(): - mod = setup_test() - assert isinstance(mod, tvm.IRModule) - # Test setup - c0 = np.arange((16 * 16)).astype("int32").reshape(16, 16) - mod = relax.transform.BindParams("main", {"c0": tvm.nd.array(c0)})(mod) - HeuristicPass = relax.transform.FoldConstant - - # Tuning pass without any eval_pass. - mock_pass = MockConstFoldingTuningPass(eval_passes=[]) - with transform.PassContext(trace=Trace(mod), make_traceable=["FoldConstant"]): - _ = mock_pass(mod) - assert PassContext.current().num_evals == 2 - assert PassContext.current().get_trace_stack_size() == 1 - assert PassContext.current().get_current_trace().size == 1 - - # Heuristic pass should not affect the number of candidates. - mock_pass = MockConstFoldingTuningPass(eval_passes=[HeuristicPass()]) - with transform.PassContext(trace=Trace(mod), make_traceable=["FoldConstant"]): - _ = mock_pass(mod) - assert PassContext.current().num_evals == 2 - assert PassContext.current().get_trace_stack_size() == 1 - assert PassContext.current().get_current_trace().size == 2 - - # Joint-optimization will increase the search space in the combinatorial way - mock_pass = MockConstFoldingTuningPass(eval_passes=[MockConstFoldingTuningPass(eval_passes=[])]) - with transform.PassContext(trace=Trace(mod), make_traceable=["FoldConstant"]): - _ = mock_pass(mod) - assert PassContext.current().num_evals == 2 * 2 - assert PassContext.current().get_trace_stack_size() == 1 - assert PassContext.current().get_current_trace().size == 2 - - # Joint-optimization can be nested. 
- mock_pass = MockConstFoldingTuningPass( - eval_passes=[ - MockConstFoldingTuningPass(eval_passes=[MockConstFoldingTuningPass(eval_passes=[])]) - ] - ) - with transform.PassContext(trace=Trace(mod), make_traceable=["FoldConstant"]): - _ = mock_pass(mod) - assert PassContext.current().num_evals == 2 * 2 * 2 - assert PassContext.current().get_trace_stack_size() == 1 - assert PassContext.current().get_current_trace().size == 3 - - # Tuning pass and heuritic passes can be used together. - # Note that heuristic pass won't increate the search space (num_evals). - # It only increases the length of the trace. - mock_pass = MockConstFoldingTuningPass( - eval_passes=[ - HeuristicPass(), - MockConstFoldingTuningPass( - eval_passes=[ - MockConstFoldingTuningPass(eval_passes=[HeuristicPass(), HeuristicPass()]) - ] - ), - ] - ) - with transform.PassContext(trace=Trace(mod), make_traceable=["FoldConstant"]): - _ = mock_pass(mod) - assert PassContext.current().num_evals == 2 * 2 * 2 - assert PassContext.current().get_trace_stack_size() == 1 - assert PassContext.current().get_current_trace().size == 6 - - # Users can mix-use sequential application and joint-application. - mock_pass = MockConstFoldingTuningPass( - eval_passes=[ - MockConstFoldingTuningPass(eval_passes=[]), - MockConstFoldingTuningPass(eval_passes=[]), - MockConstFoldingTuningPass(eval_passes=[]), - ] - ) - with transform.PassContext(trace=Trace(mod), make_traceable=["FoldConstant"]): - _ = mock_pass(mod) - assert PassContext.current().num_evals == 2 * (2 + 2 + 2) - assert PassContext.current().get_trace_stack_size() == 1 - assert PassContext.current().get_current_trace().size == 4 - - -def test_sequential(): - mod = setup_test() - assert isinstance(mod, tvm.IRModule) - # Test setup. 
- c0 = np.arange((16 * 16)).astype("int32").reshape(16, 16) - mod = relax.transform.BindParams("main", {"c0": tvm.nd.array(c0)})(mod) - HeuristicPass = relax.transform.FoldConstant - - # Sequential with a single tuning pass should behave same with a single pass. - seq = transform.Sequential([MockConstFoldingTuningPass(eval_passes=[])]) - with transform.PassContext(trace=Trace(mod), make_traceable=["FoldConstant"]): - _ = seq(mod) - assert PassContext.current().num_evals == 2 - assert PassContext.current().get_trace_stack_size() == 1 - assert PassContext.current().get_current_trace().size == 1 - - # Sequential pass should increase search space (num_evals) in additive manner. - seq = transform.Sequential( - [ - MockConstFoldingTuningPass(eval_passes=[]), - MockConstFoldingTuningPass(eval_passes=[]), - MockConstFoldingTuningPass(eval_passes=[]), - ] - ) - with transform.PassContext(trace=Trace(mod), make_traceable=["FoldConstant"]): - _ = seq(mod) - assert PassContext.current().num_evals == 2 + 2 + 2 - assert PassContext.current().get_trace_stack_size() == 1 - assert PassContext.current().get_current_trace().size == 3 - - # Heuristic pass will not increase the search space. Just increase trace length. - seq = transform.Sequential( - [ - MockConstFoldingTuningPass(eval_passes=[]), - HeuristicPass(), - MockConstFoldingTuningPass(eval_passes=[]), - MockConstFoldingTuningPass(eval_passes=[]), - HeuristicPass(), - ] - ) - - with transform.PassContext(trace=Trace(mod), make_traceable=["FoldConstant"]): - _ = seq(mod) - assert PassContext.current().num_evals == 2 + 2 + 2 - assert PassContext.current().get_trace_stack_size() == 1 - assert PassContext.current().get_current_trace().size == 5 - - # Users can mix-use sequential application and joint-application. 
- seq = transform.Sequential( - [ - HeuristicPass(), - MockConstFoldingTuningPass( - eval_passes=[ - MockConstFoldingTuningPass( - eval_passes=[ - MockConstFoldingTuningPass( - eval_passes=[ - HeuristicPass(), - ] - ) - ] - ), - ] - ), - MockConstFoldingTuningPass(eval_passes=[]), - HeuristicPass(), - ] - ) - - with transform.PassContext(trace=Trace(mod), make_traceable=["FoldConstant"]): - _ = seq(mod) - assert PassContext.current().num_evals == (2 * 2 * 2) + 2 - assert PassContext.current().get_trace_stack_size() == 1 - assert PassContext.current().get_current_trace().size == 7 - - -def test_passes_with_mixed_granularities(): - @tvm.script.ir_module - class MockModule: - @R.function - def f1(x: R.Tensor(("m", "n"), "float32")): - with R.dataflow(): - lv0 = R.multiply(x, x) - gv0 = R.add(x, x) - R.output(gv0) - return gv0 - - @R.function - def main(x: R.Tensor(("m", "n"), "float32"), y: R.Tensor(("m", "n"), "float32")): - with R.dataflow(): - lv0 = R.multiply(x, y) - gv0 = R.add(lv0, y) - R.output(gv0) - gv1 = R.multiply(x, y) - gv2 = R.add(gv1, y) - return (gv0, gv1, gv2) - - mod = MockModule - assert isinstance(mod, tvm.IRModule) - - # Helper function for tuning - def pass_func( - mod: IRModule, ctx: PassContext, eval_passes: List[transform.Pass] = None - ) -> IRModule: - trace = ctx.pop_trace() - - # Create mock choices for testing - choices = [Choice(), Choice(), Choice()] - # Tuning pass manages a set of transformation functions registered via knob. - knob = Knob("MockTuningKnob", choices) - - candidates = default_generate_candidate([knob], trace, eval_passes) - mock_evaluate(candidates, "llvm", ctx) - best_trace = select_best_candidate(candidates) - - ctx.push_trace(best_trace) - return best_trace.out_mod - - @ir.transform.module_pass(opt_level=0, traceable=True) - def MockModulePass(mod: IRModule, ctx: PassContext) -> IRModule: - # Input granularity == Candidate granularity. 
- return pass_func(mod, ctx) - - @relax.transform.function_pass(opt_level=0, traceable=True) - def MockFunctionPass(func: Expr, mod: IRModule, ctx: PassContext) -> Function: - # Input granularity > Candidate granularity. - # Start trace with smaller granularity: IRModule->Function. - ctx.push_trace(Trace(IRModule.from_expr(func))) - # Do something. - pass_func(mod, ctx) - # Pop tuned trace and recover the previous trace. - ctx.pop_trace() - return func - - @relax.transform.dataflowblock_pass(opt_level=0, traceable=True) - def MockDataflowBlockPass( - block: DataflowBlock, mod: IRModule, ctx: PassContext - ) -> DataflowBlock: - # TODO(sunggg): figure out how to create IRModule from DataflowBlock - # Provide random binding for now - x = relax.Var("x", R.Tensor([tvm.tir.Var("n", "int64")], "float32")) - seq_expr = relax.SeqExpr([block], x) - func = relax.Function([x], seq_expr, R.Tensor("float32", ndim=-1)) - ctx.push_trace(Trace(IRModule.from_expr(func))) - # Do something - pass_func(mod, ctx) - ctx.pop_trace() - return block - - seq = transform.Sequential( - [ - MockModulePass, - MockFunctionPass, - MockDataflowBlockPass, - ] - ) - - with transform.PassContext(trace=Trace(mod), make_traceable=[]): - _ = seq(mod) - # Trace length and num eval can be different depending on how each function/dataflow block is treated. - assert PassContext.current().get_trace_stack_size() == 1 - - -if __name__ == "__main__": - pytest.main([__file__])