From 946bfc5c92f35aeee0b8dddc18cdcf1333f0b189 Mon Sep 17 00:00:00 2001 From: Zack Chen Date: Wed, 28 Apr 2021 23:12:43 +0800 Subject: [PATCH 01/24] [IR][Pass][Instrument] Pass instrument framework This commit provides utilies to instrument passes: 1. Add a new namespace tvm.instrument 2. Introduce PassInstrument and PassInstrumentor to PassContext Example ------- passes_mem = #... Impl of memory instrument passes_time = tvm.instrument.PassesTimeInstrument() with tvm.transform.PassContext( pass_instrumentor=PassInstrumentor([passes_mem, passes_time])): tvm.relay.build(mod, 'llvm') passes_mem.rendor() passes_time.rendor() 3. Integrate existing PassContext::Trace() and timing profile --- include/tvm/ir/instrument.h | 244 +++++++++++++++ include/tvm/ir/transform.h | 45 ++- python/tvm/__init__.py | 2 +- python/tvm/ir/__init__.py | 1 + .../tvm/ir/_ffi_instrument_api.py | 27 +- python/tvm/ir/instrument.py | 128 ++++++++ python/tvm/ir/transform.py | 13 +- src/ir/instrument.cc | 295 ++++++++++++++++++ src/ir/transform.cc | 185 +++-------- src/relay/ir/transform.cc | 12 +- src/tir/ir/transform.cc | 8 +- tests/python/relay/test_pass_instrument.py | 190 +++++++++++ 12 files changed, 956 insertions(+), 194 deletions(-) create mode 100644 include/tvm/ir/instrument.h rename tests/python/relay/test_pass_profiler.py => python/tvm/ir/_ffi_instrument_api.py (51%) create mode 100644 python/tvm/ir/instrument.py create mode 100644 src/ir/instrument.cc create mode 100644 tests/python/relay/test_pass_instrument.py diff --git a/include/tvm/ir/instrument.h b/include/tvm/ir/instrument.h new file mode 100644 index 000000000000..8b8f091298e3 --- /dev/null +++ b/include/tvm/ir/instrument.h @@ -0,0 +1,244 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/ir/instrument.h + * + * This file implements a pass instrument infrastructure, inspired from LLVM and MLIR. + * It inserts instrumentation points between passes run. + * + * Within a pass context (tvm::transfom::PassContext), the instrumentation call sequence will like: + * + * Instrument SetUp + * + * if (Instrument Before Pass) + * Pass Run + * Instrument After Pass + * + * if (Instrument Before Pass) + * Pass Run + * Instrument After Pass + * + * Instrument TearDown + * + * + * Instrument point before pass can determine particular pass is disable or not depends on the + * callback registered. + */ +#ifndef TVM_IR_INSTRUMENT_H_ +#define TVM_IR_INSTRUMENT_H_ + +#include +#include + +#include +#include + +namespace tvm { + +class IRModule; + +// Forward class for PassInstrumentNode methods +namespace transform { +class PassInfo; +} // namespace transform + +namespace instrument { + +/*! + * \brief A callback type for set up or clean up instrument environment. + */ +using InstrumentEnvFunc = runtime::TypedPackedFunc; + +/*! + * \brief A callback template for instrumenting before/after environment. + * \tparam RetTy the return type of callback. + */ +template +using PassInstrumentFunc = + runtime::TypedPackedFunc; + +/*! + * \brief PassInstrumentNode forms an instrument implementation. + * It provides API for users to register callbacks at different instrument point. + * \sa PassInstrument + */ +class PassInstrumentNode : public Object { + public: + /*! \brief Name of this pass instrument object. */ + String name; + + /*! \brief Callback for instrumentation environment set up. */ + InstrumentEnvFunc set_up_callback; + /*! \brief Callback for instrumentation environment clean up. */ + InstrumentEnvFunc tear_down_callback; + + /*! \brief Callback to run before a pass. */ + PassInstrumentFunc run_before_pass_callback; + /*! \brief Callback to run after a pass. */ + PassInstrumentFunc<> run_after_pass_callback; + + /*! + * \brief Register a callback to run at set up point. + * + * \param callback The set up function. + */ + void RegisterSetUpCallback(InstrumentEnvFunc callback) { set_up_callback = std::move(callback); } + + /* + * \brief Register a callback to run at clean up point. + * + * \param callback The clean up function. + */ + void RegisterTearDownCallback(InstrumentEnvFunc callback) { + tear_down_callback = std::move(callback); + } + + /*! + * \brief Register a callback to run before pass run. + * + * \param callback The function to run before pass: return false to skip pass; return true to + * run pass. + */ + void RegisterRunBeforePassCallback(PassInstrumentFunc callback) { + run_before_pass_callback = std::move(callback); + } + + /*! + * \brief Register a callback to run after pass run. + * + * \param callback The function to run after pass. + */ + void RegisterRunAfterPassCallback(PassInstrumentFunc<> callback) { + run_after_pass_callback = std::move(callback); + } + + void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); } + + /*! \brief Set up environment for instrumentation. */ + void SetUp() const; + + /*! \brief Clean up instrumentation environment. */ + void TearDown() const; + + /*! + * \brief Instrument before pass run, determine whether to run the pass or not. + * \param mod The module that an optimization pass runs on. + * \param info The pass information. + * + * \return true to run the pass; false to skip the pass. + */ + bool RunBeforePass(const IRModule& mod, const transform::PassInfo& info) const; + + /*! + * \brief Instrument after pass run. + * + * \param mod The module that an optimization pass runs on. + * \param info The pass information. + */ + void RunAfterPass(const IRModule& mod, const transform::PassInfo& info) const; + + static constexpr const char* _type_key = "instrument.PassInstrument"; + TVM_DECLARE_FINAL_OBJECT_INFO(PassInstrumentNode, Object); +}; + +/*! + * \brief Managed reference class for PassInstrumentNode + * \sa PassInstrumentNode + */ +class PassInstrument : public ObjectRef { + public: + /*! + * \brief Constructor + * \param name Name for this instrumentation. + */ + TVM_DLL PassInstrument(String name); + + /*! + * \brief mutable accessor. + * \return mutable access pointer. + */ + PassInstrumentNode* operator->() { + ICHECK(get() != nullptr); + return static_cast(get_mutable()); + } + + TVM_DEFINE_OBJECT_REF_METHODS(PassInstrument, ObjectRef, PassInstrumentNode); +}; + +/*! + * \brief PassInstrumentorNode collects a set of PassInstrument implementations, invokes the + * implementations' methods at different instrument points. + * \sa PassInstrumentor + */ +class PassInstrumentorNode : public Object { + public: + Array pass_instruments; + + void VisitAttrs(AttrVisitor* v) { v->Visit("pass_instruments", &pass_instruments); } + + /*! \brief Set up environment for instrument implementations. */ + void SetUp() const; + + /*! \brief Clean up environment for instrument implementations. */ + void TearDown() const; + + /*! + * \brief Instrument before pass run, determine whether to run the pass or not. + * + * \param mod The module that an optimization pass runs on. + * \param info The pass information. + * + * \return true to run the pass; false to skip the pass. + */ + bool RunBeforePass(const IRModule& mod, const transform::PassInfo& info) const; + + /*! + * \brief Instrument after pass run. + * + * \param mod The module that an optimization pass runs on. + * \param info The pass information. + * + * \return true to run the pass; false to skip the pass. + */ + void RunAfterPass(const IRModule& mod, const transform::PassInfo& info) const; + + static constexpr const char* _type_key = "instrument.PassInstrumentor"; + TVM_DECLARE_FINAL_OBJECT_INFO(PassInstrumentorNode, Object); +}; + +/*! + * \brief Managed reference class for PassInstrumentorNode + * \sa PassInstrumentorNode + */ +class PassInstrumentor : public ObjectRef { + public: + /*! + * \brief Constructor + * \param pass_instruments A set of instrument implementations. + */ + TVM_DLL PassInstrumentor(Array pass_instruments); + + TVM_DEFINE_OBJECT_REF_METHODS(PassInstrumentor, ObjectRef, PassInstrumentorNode); +}; + +} // namespace instrument +} // namespace tvm + +#endif // TVM_IR_INSTRUMENT_H_ diff --git a/include/tvm/ir/transform.h b/include/tvm/ir/transform.h index 50c6f8dd8c3a..32c5dea9b84a 100644 --- a/include/tvm/ir/transform.h +++ b/include/tvm/ir/transform.h @@ -58,6 +58,7 @@ #include #include +#include #include #include #include @@ -68,15 +69,6 @@ namespace tvm { namespace transform { -// Forward declare for TraceFunc. -class PassInfo; - -/*! - * \brief A callback for tracing passes, useful for debugging and logging. - */ -using TraceFunc = - runtime::TypedPackedFunc; - /*! * \brief PassContextNode contains the information that a pass can rely on, * such as analysis results. @@ -95,8 +87,9 @@ class PassContextNode : public Object { mutable Optional diag_ctx; /*! \brief Pass specific configurations. */ Map config; - /*! \brief Trace function to be invoked before and after each pass. */ - TraceFunc trace_func; + + /*! \brief Instrumentor contains a list of instrument implementations. */ + instrument::PassInstrumentor pass_instrumentor; PassContextNode() = default; @@ -189,12 +182,32 @@ class PassContext : public ObjectRef { TVM_DLL static PassContext Current(); /*! - * \brief Apply the tracing functions of the context to the module, with the info. - * \param module The IRModule to trace. + * \brief Set up for all the instrument implementations. + */ + TVM_DLL void InstrumentSetUp() const; + + /*! + * \brief Clean up for all the instrument implementations. + */ + TVM_DLL void InstrumentTearDown() const; + + /*! + * \brief Call intrument implementatations before a pass run. + * + * \param mod The module that an optimization pass runs on. * \param info The pass information. - * \param is_before Indicated whether the tracing is before or after a pass. + * + * \return false: the pass is skipped; true: the pass runs. */ - TVM_DLL void Trace(const IRModule& module, const PassInfo& info, bool is_before) const; + TVM_DLL bool InstrumentBeforePass(const IRModule& mod, const PassInfo& info) const; + + /*! + * \brief Call instrument implementations after a pass run. + * + * \param mod The module that an optimization pass runs on. + * \param info The pass information. + */ + TVM_DLL void InstrumentAfterPass(const IRModule& mod, const PassInfo& info) const; /*! * \brief Check whether a pass is enabled. @@ -275,7 +288,7 @@ class PassInfoNode : public Object { TVM_DECLARE_FINAL_OBJECT_INFO(PassInfoNode, Object); }; -/* +/*! * \brief Managed reference class for PassInfoNode * \sa PassInfoNode */ diff --git a/python/tvm/__init__.py b/python/tvm/__init__.py index 4643062ea8e8..0e67a52c4a32 100644 --- a/python/tvm/__init__.py +++ b/python/tvm/__init__.py @@ -40,6 +40,7 @@ # tvm.ir from .ir import IRModule from .ir import transform +from .ir import instrument from .ir import container from . import ir @@ -67,7 +68,6 @@ # Contrib initializers from .contrib import rocm as _rocm, nvcc as _nvcc, sdaccel as _sdaccel - # NOTE: This file should be python2 compatible so we can # raise proper error message when user run the package using # an older version of the python diff --git a/python/tvm/ir/__init__.py b/python/tvm/ir/__init__.py index a12d3e9855f0..b4cc4421b169 100644 --- a/python/tvm/ir/__init__.py +++ b/python/tvm/ir/__init__.py @@ -31,4 +31,5 @@ from .container import Array, Map from . import transform +from . import instrument from . import diagnostics diff --git a/tests/python/relay/test_pass_profiler.py b/python/tvm/ir/_ffi_instrument_api.py similarity index 51% rename from tests/python/relay/test_pass_profiler.py rename to python/tvm/ir/_ffi_instrument_api.py index acf6c8c50aff..bf62caf30e5a 100644 --- a/tests/python/relay/test_pass_profiler.py +++ b/python/tvm/ir/_ffi_instrument_api.py @@ -14,28 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import tvm -import tvm.relay -from tvm.relay import op +"""FFI APIs for tvm.instrument""" +import tvm._ffi - -def test_pass_profiler(): - x, y, z = [tvm.relay.var(c, shape=(3, 4), dtype="float32") for c in "xyz"] - e1 = op.add(x, y) - e2 = op.subtract(x, z) - e3 = op.multiply(e1, e1 / e2) - mod = tvm.IRModule.from_expr(e3 + e2) - - tvm.transform.enable_pass_profiling() - - mod = tvm.relay.transform.AnnotateSpans()(mod) - mod = tvm.relay.transform.ToANormalForm()(mod) - mod = tvm.relay.transform.InferType()(mod) - - profiles = tvm.transform.render_pass_profiles() - assert "AnnotateSpans" in profiles - assert "ToANormalForm" in profiles - assert "InferType" in profiles - - tvm.transform.clear_pass_profiles() - tvm.transform.disable_pass_profiling() +tvm._ffi._init_api("instrument", __name__) diff --git a/python/tvm/ir/instrument.py b/python/tvm/ir/instrument.py new file mode 100644 index 000000000000..e8d894c110b8 --- /dev/null +++ b/python/tvm/ir/instrument.py @@ -0,0 +1,128 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name,unused-argument +"""Common pass instrumentation across IR variants.""" +import tvm._ffi +import tvm.runtime + +from . import _ffi_instrument_api + + +@tvm._ffi.register_object("instrument.PassInstrument") +class PassInstrument(tvm.runtime.Object): + """A pass instrument implementation. + + Parameters + ---------- + name : str + The name for this instrument implementation. + + Examples + -------- + + .. code-block:: python + pi = tvm.instrument.PassInstrument("print-before-after") + + @pi.register_set_up + def set_up(): + pass + + @pi.register_tear_down + def tear_down(): + pass + + @pi.register_run_before_pass + def run_before_pass(mod, info): + print("Before pass: " + info.name) + print(mod) + return True + + @pi.register_run_after_pass + def run_after_pass(mod, info): + print("After pass: " + info.name) + print(mod) + + + See Also + -------- + instrument.PassInstrumentor + """ + + def __init__(self, name): + self.__init_handle_by_constructor__(_ffi_instrument_api.PassInstrument, name) + + def register_set_up(self, callback): + _ffi_instrument_api.RegisterSetUpCallback(self, callback) + + def register_tear_down(self, callback): + _ffi_instrument_api.RegisterTearDownCallback(self, callback) + + def register_run_before_pass(self, callback): + _ffi_instrument_api.RegisterRunBeforePassCallback(self, callback) + + def register_run_after_pass(self, callback): + _ffi_instrument_api.RegisterRunAfterPassCallback(self, callback) + + +@tvm._ffi.register_object("instrument.PassInstrumentor") +class PassInstrumentor(tvm.runtime.Object): + """A pass instrumentor collects a set of pass instrument implementations. + + Parameters + ---------- + pass_instruments : List[tvm.instrument.PassInstrument] + List of instrumentation to run within pass context + + Examples + -------- + .. code-block:: python + + passes_mem = #... Impl of memory instrument + passes_time = tvm.instrument.PassesTimeInstrument() + + with tvm.transform.PassContext( + pass_instrumentor=tvm.instrument.PassInstrumentor([passes_mem, passes_time])): + tvm.relay.build(mod, 'llvm') + + print(passes_time.rendor()) + + See Also + ------- + instrument.PassInstrument + instrument.PassesTimeInstrument + """ + + def __init__(self, pass_instruments): + self.__init_handle_by_constructor__(_ffi_instrument_api.PassInstrumentor, pass_instruments) + + +@tvm._ffi.register_object("instrument.PassInstrument") +class PassesTimeInstrument(tvm.runtime.Object): + """A wrapper to create a passes time instrument that implemented in C++""" + + def __init__(self): + self.__init_handle_by_constructor__(_ffi_instrument_api.MakePassesTimeInstrument) + + @staticmethod + def render(): + """Retrieve rendered time profile result + Returns + ------- + string : string + The rendered string result of time profiles + """ + return _ffi_instrument_api.RenderTimePassProfiles() diff --git a/python/tvm/ir/transform.py b/python/tvm/ir/transform.py index 36e06eeb8b23..6fbd9b2da3a2 100644 --- a/python/tvm/ir/transform.py +++ b/python/tvm/ir/transform.py @@ -65,12 +65,20 @@ class PassContext(tvm.runtime.Object): disabled_pass : Optional[Union[List[str], Set[str], Tuple[str]]] The list of passes that are disabled. + pass_instrumentor : Optional[tvm.instrument.PassInstrumentor] + The pass instrumentor that collects pass instrument implementations + config : Optional[Dict[str, Object]] Additional configurations for specific passes. """ def __init__( - self, opt_level=2, required_pass=None, disabled_pass=None, trace=None, config=None + self, + opt_level=2, + required_pass=None, + disabled_pass=None, + pass_instrumentor=None, + config=None, ): required = list(required_pass) if required_pass else [] if not isinstance(required, (list, tuple)): @@ -82,7 +90,7 @@ def __init__( config = config if config else None self.__init_handle_by_constructor__( - _ffi_transform_api.PassContext, opt_level, required, disabled, trace, config + _ffi_transform_api.PassContext, opt_level, required, disabled, pass_instrumentor, config ) def __enter__(self): @@ -189,6 +197,7 @@ def __init__(self, *args, **kwargs): # initialize handle in cass pass_cls creation failed.fg self.handle = None inst = pass_cls(*args, **kwargs) + # it is important not to capture self to # avoid a cyclic dependency def _pass_func(mod, ctx): diff --git a/src/ir/instrument.cc b/src/ir/instrument.cc new file mode 100644 index 000000000000..254d70a315c7 --- /dev/null +++ b/src/ir/instrument.cc @@ -0,0 +1,295 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/ir/instrument.cc + * \brief Infrastructure for instrumentation. + */ +#include +#include +#include +#include +#include + +#include + +namespace tvm { +namespace instrument { + +PassInstrument::PassInstrument(String name) { + auto pi = make_object(); + pi->name = std::move(name); + data_ = std::move(pi); +} + +void PassInstrumentNode::SetUp() const { + if (set_up_callback != nullptr) { + set_up_callback(); + } +} + +void PassInstrumentNode::TearDown() const { + if (tear_down_callback != nullptr) { + tear_down_callback(); + } +} + +bool PassInstrumentNode::RunBeforePass(const IRModule& ir_module, + const transform::PassInfo& pass_info) const { + if (run_before_pass_callback == nullptr) { + return true; + } + + return run_before_pass_callback(ir_module, pass_info); +} + +void PassInstrumentNode::RunAfterPass(const IRModule& ir_module, + const transform::PassInfo& pass_info) const { + if (run_after_pass_callback != nullptr) { + run_after_pass_callback(ir_module, pass_info); + } +} + +PassInstrumentor::PassInstrumentor(Array pass_instruments) { + auto n = make_object(); + n->pass_instruments = std::move(pass_instruments); + data_ = std::move(n); +} + +void PassInstrumentorNode::SetUp() const { + for (PassInstrument pi : pass_instruments) { + pi->SetUp(); + } +} + +void PassInstrumentorNode::TearDown() const { + for (PassInstrument pi : pass_instruments) { + pi->TearDown(); + } +} + +bool PassInstrumentorNode::RunBeforePass(const IRModule& ir_module, + const transform::PassInfo& pass_info) const { + for (PassInstrument pi : pass_instruments) { + if (!pi->RunBeforePass(ir_module, pass_info)) { + return false; + } + } + + return true; +} + +void PassInstrumentorNode::RunAfterPass(const IRModule& ir_module, + const transform::PassInfo& pass_info) const { + for (PassInstrument pi : pass_instruments) { + pi->RunAfterPass(ir_module, pass_info); + } +} + +TVM_REGISTER_NODE_TYPE(PassInstrumentNode); + +TVM_REGISTER_GLOBAL("instrument.PassInstrument").set_body_typed([](String name) { + return PassInstrument(name); +}); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << node->name; + }); + +TVM_REGISTER_NODE_TYPE(PassInstrumentorNode); + +TVM_REGISTER_GLOBAL("instrument.PassInstrumentor") + .set_body_typed([](Array pass_instruments) { + return PassInstrumentor(pass_instruments); + }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + + p->stream << "\n PassInstrumentor ["; + for (PassInstrument pi : node->pass_instruments) { + p->stream << pi << " "; + } + p->stream << "] \n"; + }); + +TVM_REGISTER_GLOBAL("instrument.RegisterSetUpCallback") + .set_body_typed([](PassInstrument pass_instrument, InstrumentEnvFunc callback) { + pass_instrument->RegisterSetUpCallback(callback); + }); + +TVM_REGISTER_GLOBAL("instrument.RegisterTearDownCallback") + .set_body_typed([](PassInstrument pass_instrument, InstrumentEnvFunc callback) { + pass_instrument->RegisterTearDownCallback(callback); + }); + +TVM_REGISTER_GLOBAL("instrument.RegisterRunBeforePassCallback") + .set_body_typed([](PassInstrument pass_instrument, PassInstrumentFunc callback) { + pass_instrument->RegisterRunBeforePassCallback(callback); + }); + +TVM_REGISTER_GLOBAL("instrument.RegisterRunAfterPassCallback") + .set_body_typed([](PassInstrument pass_instrument, PassInstrumentFunc<> callback) { + pass_instrument->RegisterRunAfterPassCallback(callback); + }); + +/*! \brief PassProfile stores profiling information for a given pass and its sub-passes. */ +struct PassProfile { + // TODO(@altanh): expose PassProfile through TVM Object API + using Clock = std::chrono::steady_clock; + using Duration = std::chrono::duration; + using Time = std::chrono::time_point; + + /*! \brief The name of the pass being profiled. */ + String name; + /*! \brief The time when the pass was entered. */ + Time start; + /*! \brief The time when the pass completed. */ + Time end; + /*! \brief The total duration of the pass, i.e. end - start. */ + Duration duration; + /*! \brief PassProfiles for all sub-passes invoked during the execution of the pass. */ + std::vector children; + + explicit PassProfile(String name) + : name(name), start(Clock::now()), end(Clock::now()), children() {} + + /*! \brief Gets the PassProfile of the currently executing pass. */ + static PassProfile* Current(); + /*! \brief Pushes a new PassProfile with the given pass name. */ + static void EnterPass(String name); + /*! \brief Pops the current PassProfile. */ + static void ExitPass(); +}; + +struct PassProfileThreadLocalEntry { + /*! \brief The placeholder top-level PassProfile. */ + PassProfile root; + /*! \brief The stack of PassProfiles for nested passes currently running. */ + std::stack profile_stack; + + PassProfileThreadLocalEntry() : root("root") {} +}; + +/*! \brief Thread local store to hold the pass profiling data. */ +typedef dmlc::ThreadLocalStore PassProfileThreadLocalStore; + +void PassProfile::EnterPass(String name) { + PassProfile* cur = PassProfile::Current(); + cur->children.emplace_back(name); + PassProfileThreadLocalStore::Get()->profile_stack.push(&cur->children.back()); +} + +void PassProfile::ExitPass() { + PassProfile* cur = PassProfile::Current(); + ICHECK_NE(cur->name, "root") << "mismatched enter/exit for pass profiling"; + cur->end = PassProfile::Clock::now(); + cur->duration = std::chrono::duration_cast(cur->end - cur->start); + PassProfileThreadLocalStore::Get()->profile_stack.pop(); +} + +PassProfile* PassProfile::Current() { + PassProfileThreadLocalEntry* entry = PassProfileThreadLocalStore::Get(); + if (!entry->profile_stack.empty()) { + return entry->profile_stack.top(); + } else { + return &entry->root; + } +} + +String RenderPassProfiles() { + PassProfileThreadLocalEntry* entry = PassProfileThreadLocalStore::Get(); + CHECK(entry->profile_stack.empty()) << "cannot print pass profile while still in a pass!"; + + if (entry->root.children.empty()) { + LOG(WARNING) << "no passes have been profiled, did you enable pass profiling?"; + return String(); + } + + // (depth, parent_duration, pass) + std::stack> profiles; + + // push top level passes + PassProfile::Duration top_dur(0); + for (auto it = entry->root.children.begin(); it != entry->root.children.end(); ++it) { + top_dur += it->duration; + } + for (auto it = entry->root.children.rbegin(); it != entry->root.children.rend(); ++it) { + profiles.push(std::make_tuple(0, top_dur, &*it)); + } + + std::ostringstream os; + os << std::fixed; + + while (profiles.size() > 0) { + size_t depth; + PassProfile::Duration parent_duration; + PassProfile* profile; + std::tie(depth, parent_duration, profile) = profiles.top(); + profiles.pop(); + + // indent depth + for (size_t i = 0; i < depth; ++i) { + os << "\t"; + } + + // calculate time spent in pass itself (excluding sub-passes), and push children + PassProfile::Duration self_duration = profile->duration; + for (auto it = profile->children.rbegin(); it != profile->children.rend(); ++it) { + self_duration -= it->duration; + profiles.push(std::make_tuple(depth + 1, profile->duration, &*it)); + } + + double parent_pct = profile->duration.count() / parent_duration.count() * 100.0; + double total_pct = profile->duration.count() / top_dur.count() * 100.0; + + os << profile->name << ": "; + os << std::setprecision(0); + os << profile->duration.count() << "us [" << self_duration.count() << "us] "; + os << std::setprecision(2) << "(" << total_pct << "%; " << parent_pct << "%)\n"; + } + + return os.str(); +} + +TVM_REGISTER_GLOBAL("instrument.RenderTimePassProfiles").set_body_typed(RenderPassProfiles); + +TVM_REGISTER_GLOBAL("instrument.MakePassesTimeInstrument").set_body_typed([]() { + auto pi = PassInstrument("PassesTimeInstrument"); + + // No set up function for this time instrumentation. + + pi->RegisterTearDownCallback([]() { PassProfileThreadLocalStore::Get()->root.children.clear(); }); + + pi->RegisterRunBeforePassCallback([](const IRModule&, const transform::PassInfo& pass_info) { + PassProfile::EnterPass(pass_info->name); + return true; + }); + + pi->RegisterRunAfterPassCallback( + [](const IRModule&, const transform::PassInfo&) { PassProfile::ExitPass(); }); + + return pi; +}); + +} // namespace instrument +} // namespace tvm diff --git a/src/ir/transform.cc b/src/ir/transform.cc index 48f13bc81df4..f05042b11bc4 100644 --- a/src/ir/transform.cc +++ b/src/ir/transform.cc @@ -58,12 +58,14 @@ typedef dmlc::ThreadLocalStore RelayPassContextThre void PassContext::EnterWithScope() { PassContextThreadLocalEntry* entry = RelayPassContextThreadLocalStore::Get(); entry->context_stack.push(*this); + InstrumentSetUp(); } void PassContext::ExitWithScope() { PassContextThreadLocalEntry* entry = RelayPassContextThreadLocalStore::Get(); ICHECK(!entry->context_stack.empty()); ICHECK(entry->context_stack.top().same_as(*this)); + InstrumentTearDown(); entry->context_stack.pop(); } @@ -162,170 +164,55 @@ void PassContext::RegisterConfigOption(const char* key, uint32_t value_type_inde PassContext PassContext::Create() { return PassContext(make_object()); } -void PassContext::Trace(const IRModule& module, const PassInfo& info, bool is_before) const { +void PassContext::InstrumentSetUp() const { auto pass_ctx_node = this->operator->(); - if (pass_ctx_node->trace_func != nullptr) { - pass_ctx_node->trace_func(module, info, is_before); + if (pass_ctx_node->pass_instrumentor.defined()) { + pass_ctx_node->pass_instrumentor->SetUp(); } } -class ModulePass; - -/*! \brief PassProfile stores profiling information for a given pass and its sub-passes. */ -struct PassProfile { - // TODO(@altanh): expose PassProfile through TVM Object API - using Clock = std::chrono::steady_clock; - using Duration = std::chrono::duration; - using Time = std::chrono::time_point; - - /*! \brief The name of the pass being profiled. */ - String name; - /*! \brief The time when the pass was entered. */ - Time start; - /*! \brief The time when the pass completed. */ - Time end; - /*! \brief The total duration of the pass, i.e. end - start. */ - Duration duration; - /*! \brief PassProfiles for all sub-passes invoked during the execution of the pass. */ - std::vector children; - - explicit PassProfile(String name) - : name(name), start(Clock::now()), end(Clock::now()), children() {} - - /*! \brief Gets the PassProfile of the currently executing pass. */ - static PassProfile* Current(); - /*! \brief Pushes a new PassProfile with the given pass name. */ - static void EnterPass(String name); - /*! \brief Pops the current PassProfile. */ - static void ExitPass(); -}; - -struct PassProfileThreadLocalEntry { - /*! \brief The placeholder top-level PassProfile. */ - PassProfile root; - /*! \brief The stack of PassProfiles for nested passes currently running. */ - std::stack profile_stack; - /*! \brief Whether or not pass profiling is active. */ - bool active; - - PassProfileThreadLocalEntry() : root("root"), active(false) {} -}; - -/*! \brief Thread local store to hold the pass profiling data. */ -typedef dmlc::ThreadLocalStore PassProfileThreadLocalStore; - -void PassProfile::EnterPass(String name) { - if (!PassProfileThreadLocalStore::Get()->active) return; - PassProfile* cur = PassProfile::Current(); - cur->children.emplace_back(name); - PassProfileThreadLocalStore::Get()->profile_stack.push(&cur->children.back()); +void PassContext::InstrumentTearDown() const { + auto pass_ctx_node = this->operator->(); + if (pass_ctx_node->pass_instrumentor.defined()) { + pass_ctx_node->pass_instrumentor->TearDown(); + } } -void PassProfile::ExitPass() { - if (!PassProfileThreadLocalStore::Get()->active) return; - PassProfile* cur = PassProfile::Current(); - ICHECK_NE(cur->name, "root") << "mismatched enter/exit for pass profiling"; - cur->end = std::move(PassProfile::Clock::now()); - cur->duration = std::chrono::duration_cast(cur->end - cur->start); - PassProfileThreadLocalStore::Get()->profile_stack.pop(); +bool PassContext::InstrumentBeforePass(const IRModule& ir_module, const PassInfo& pass_info) const { + auto pass_ctx_node = this->operator->(); + if (pass_ctx_node->pass_instrumentor.defined()) { + if (!pass_ctx_node->pass_instrumentor->RunBeforePass(ir_module, pass_info)) { + return false; + } + } + return true; } -PassProfile* PassProfile::Current() { - PassProfileThreadLocalEntry* entry = PassProfileThreadLocalStore::Get(); - if (!entry->profile_stack.empty()) { - return entry->profile_stack.top(); - } else { - return &entry->root; +void PassContext::InstrumentAfterPass(const IRModule& ir_module, const PassInfo& pass_info) const { + auto pass_ctx_node = this->operator->(); + if (pass_ctx_node->pass_instrumentor.defined()) { + pass_ctx_node->pass_instrumentor->RunAfterPass(ir_module, pass_info); } } IRModule Pass::operator()(IRModule mod) const { const PassNode* node = operator->(); ICHECK(node != nullptr); - PassProfile::EnterPass(node->Info()->name); + // PassProfile::EnterPass(node->Info()->name); auto ret = node->operator()(std::move(mod)); - PassProfile::ExitPass(); + // PassProfile::ExitPass(); return std::move(ret); } IRModule Pass::operator()(IRModule mod, const PassContext& pass_ctx) const { const PassNode* node = operator->(); ICHECK(node != nullptr); - PassProfile::EnterPass(node->Info()->name); + // PassProfile::EnterPass(node->Info()->name); auto ret = node->operator()(std::move(mod), pass_ctx); - PassProfile::ExitPass(); + // PassProfile::ExitPass(); return std::move(ret); } -String RenderPassProfiles() { - PassProfileThreadLocalEntry* entry = PassProfileThreadLocalStore::Get(); - CHECK(entry->profile_stack.empty()) << "cannot print pass profile while still in a pass!"; - - if (entry->root.children.empty()) { - LOG(WARNING) << "no passes have been profiled, did you enable pass profiling?"; - return String(); - } - - // (depth, parent_duration, pass) - std::stack> profiles; - - // push top level passes - PassProfile::Duration top_dur(0); - for (auto it = entry->root.children.begin(); it != entry->root.children.end(); ++it) { - top_dur += it->duration; - } - for (auto it = entry->root.children.rbegin(); it != entry->root.children.rend(); ++it) { - profiles.push(std::make_tuple(0, top_dur, &*it)); - } - - std::ostringstream os; - os << std::fixed; - - while (profiles.size() > 0) { - size_t depth; - PassProfile::Duration parent_duration; - PassProfile* profile; - std::tie(depth, parent_duration, profile) = profiles.top(); - profiles.pop(); - - // indent depth - for (size_t i = 0; i < depth; ++i) { - os << "\t"; - } - - // calculate time spent in pass itself (excluding sub-passes), and push children - PassProfile::Duration self_duration = profile->duration; - for (auto it = profile->children.rbegin(); it != profile->children.rend(); ++it) { - self_duration -= it->duration; - profiles.push(std::make_tuple(depth + 1, profile->duration, &*it)); - } - - double parent_pct = profile->duration.count() / parent_duration.count() * 100.0; - double total_pct = profile->duration.count() / top_dur.count() * 100.0; - - os << profile->name << ": "; - os << std::setprecision(0); - os << profile->duration.count() << "us [" << self_duration.count() << "us] "; - os << std::setprecision(2) << "(" << total_pct << "%; " << parent_pct << "%)\n"; - } - - return os.str(); -} - -TVM_REGISTER_GLOBAL("transform.render_pass_profiles").set_body_typed(RenderPassProfiles); - -TVM_REGISTER_GLOBAL("transform.clear_pass_profiles").set_body_typed([]() { - PassProfileThreadLocalStore::Get()->root.children.clear(); -}); - -TVM_REGISTER_GLOBAL("transform.enable_pass_profiling").set_body_typed([]() { - PassProfileThreadLocalStore::Get()->active = true; -}); - -TVM_REGISTER_GLOBAL("transform.disable_pass_profiling").set_body_typed([]() { - PassProfileThreadLocalStore::Get()->active = false; -}); - /*! * \brief Module-level passes are designed to implement global * analysis/optimizations, i.e. interprocedural optimizations (IPO), etc. Passes @@ -464,12 +351,19 @@ IRModule ModulePassNode::operator()(IRModule mod, const PassContext& pass_ctx) c << "The diagnostic context was set at the top of this block this is a bug."; const PassInfo& pass_info = Info(); + ICHECK(mod.defined()) << "The input module must be set."; + + if (!pass_ctx.InstrumentBeforePass(mod, pass_info)) { + DLOG(INFO) << "Skipping function pass : " << pass_info->name + << " with opt level: " << pass_info->opt_level; + + pass_ctx->diag_ctx = previous; + return mod; + } + DLOG(INFO) << "Executing module pass : " << pass_info->name << " with opt level: " << pass_info->opt_level; - ICHECK(mod.defined()) << "The input module must be set."; - - pass_ctx.Trace(mod, pass_info, true); mod = pass_func(std::move(mod), pass_ctx); ICHECK(mod.defined()) << "The return value of a module pass must be set."; @@ -480,7 +374,7 @@ IRModule ModulePassNode::operator()(IRModule mod, const PassContext& pass_ctx) c pass_ctx->diag_ctx.value().Render(); pass_ctx->diag_ctx = previous; - pass_ctx.Trace(mod, pass_info, false); + pass_ctx.InstrumentAfterPass(mod, pass_info); return mod; } @@ -621,13 +515,14 @@ TVM_REGISTER_NODE_TYPE(PassContextNode); TVM_REGISTER_GLOBAL("transform.PassContext") .set_body_typed([](int opt_level, Array required, Array disabled, - TraceFunc trace_func, Optional> config) { + instrument::PassInstrumentor pass_instrumentor, + Optional> config) { auto pctx = PassContext::Create(); pctx->opt_level = opt_level; pctx->required_pass = std::move(required); pctx->disabled_pass = std::move(disabled); - pctx->trace_func = std::move(trace_func); + pctx->pass_instrumentor = std::move(pass_instrumentor); if (config.defined()) { pctx->config = config.value(); } diff --git a/src/relay/ir/transform.cc b/src/relay/ir/transform.cc index 596f812e25af..999c9c4fe39e 100644 --- a/src/relay/ir/transform.cc +++ b/src/relay/ir/transform.cc @@ -130,11 +130,17 @@ IRModule FunctionPassNode::operator()(IRModule mod, const PassContext& pass_ctx) ICHECK(mod.defined()); + if (!pass_ctx.InstrumentBeforePass(mod, pass_info)) { + DLOG(INFO) << "Skipping function pass : " << pass_info->name + << " with opt level: " << pass_info->opt_level; + + pass_ctx->diag_ctx = previous; + return mod; + } + DLOG(INFO) << "Executing function pass : " << pass_info->name << " with opt level: " << pass_info->opt_level; - pass_ctx.Trace(mod, pass_info, true); - // Execute the pass function and return a new module. IRModule updated_mod = IRModule(mod->functions, mod->type_definitions, mod->Imports(), mod->source_map); @@ -159,7 +165,7 @@ IRModule FunctionPassNode::operator()(IRModule mod, const PassContext& pass_ctx) pass_ctx->diag_ctx.value().Render(); pass_ctx->diag_ctx = previous; - pass_ctx.Trace(updated_mod, pass_info, false); + pass_ctx.InstrumentAfterPass(updated_mod, pass_info); // TODO(@jroesch): move away from eager type checking for performance reasons // make issue. diff --git a/src/tir/ir/transform.cc b/src/tir/ir/transform.cc index 95c40f9a3c8e..5fafc7abc863 100644 --- a/src/tir/ir/transform.cc +++ b/src/tir/ir/transform.cc @@ -87,9 +87,11 @@ PrimFuncPass::PrimFuncPass( // Perform Module -> Module optimizations at the PrimFunc level. IRModule PrimFuncPassNode::operator()(IRModule mod, const PassContext& pass_ctx) const { - const PassInfo& pass_info = Info(); + // const PassInfo& pass_info = Info(); ICHECK(mod.defined()); - pass_ctx.Trace(mod, pass_info, true); + if (!pass_ctx.InstrumentBeforePass(mod, pass_info)) { + return mod; + } std::vector deleted_list; IRModuleNode* mod_ptr = mod.CopyOnWrite(); auto* func_dict = mod_ptr->functions.CopyOnWrite(); @@ -112,7 +114,7 @@ IRModule PrimFuncPassNode::operator()(IRModule mod, const PassContext& pass_ctx) for (const auto& gv : deleted_list) { func_dict->erase(gv); } - pass_ctx.Trace(mod, pass_info, false); + pass_ctx.InstrumentAfterPass(mod, pass_info); return mod; } diff --git a/tests/python/relay/test_pass_instrument.py b/tests/python/relay/test_pass_instrument.py new file mode 100644 index 000000000000..ab2e10929129 --- /dev/null +++ b/tests/python/relay/test_pass_instrument.py @@ -0,0 +1,190 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +import tvm.relay +from tvm.relay import op +from tvm.ir.instrument import PassesTimeInstrument, PassInstrument, PassInstrumentor + + +def test_pass_time_instrument(): + x, y, z = [tvm.relay.var(c, shape=(3, 4), dtype="float32") for c in "xyz"] + e1 = op.add(x, y) + e2 = op.subtract(x, z) + e3 = op.multiply(e1, e1 / e2) + mod = tvm.IRModule.from_expr(e3 + e2) + + time_instrument = PassesTimeInstrument() + with tvm.transform.PassContext(pass_instrumentor=PassInstrumentor([time_instrument])): + mod = tvm.relay.transform.AnnotateSpans()(mod) + mod = tvm.relay.transform.ToANormalForm()(mod) + mod = tvm.relay.transform.InferType()(mod) + + profiles = time_instrument.render() + assert "AnnotateSpans" in profiles + assert "ToANormalForm" in profiles + assert "InferType" in profiles + + +def test_custom_instrument(capsys): + x, y, z = [tvm.relay.var(c, shape=(3, 4), dtype="float32") for c in "xyz"] + e1 = op.add(x, y) + e2 = op.subtract(x, z) + e3 = op.multiply(e1, e1 / e2) + mod = tvm.IRModule.from_expr(e3 + e2) + + def custom_pi(): + pi = PassInstrument("MyTest") + + @pi.register_set_up + def set_up(): + print("set up") + + @pi.register_tear_down + def tear_down(): + print("tear down") + + @pi.register_run_before_pass + def run_before_pass(mod, info): + print("run before " + info.name) + return True + + @pi.register_run_after_pass + def run_after_pass(mod, info): + print("run after " + info.name) + + return pi + + with tvm.transform.PassContext(pass_instrumentor=PassInstrumentor([custom_pi()])): + mod = tvm.relay.transform.InferType()(mod) + + output = "set up\n" "run before InferType\n" "run after InferType\n" "tear down\n" + assert capsys.readouterr().out == output + + +def test_disable_pass(capsys): + x, y, z = [tvm.relay.var(c, shape=(3, 4), dtype="float32") for c in "xyz"] + e1 = op.add(x, y) + e2 = op.subtract(x, z) + e3 = op.multiply(e1, e1 / e2) + mod = tvm.IRModule.from_expr(e3 + e2) + + def custom_pi(): + pi = PassInstrument("MyTest") + + @pi.register_run_before_pass + def run_before_pass(mod, info): + # Only run pass name contains "InferType" + if "InferType" not in info.name: + return False + + print(info.name) + return True + + return pi + + with tvm.transform.PassContext(pass_instrumentor=PassInstrumentor([custom_pi()])): + mod = tvm.relay.transform.AnnotateSpans()(mod) + mod = tvm.relay.transform.ToANormalForm()(mod) + mod = tvm.relay.transform.InferType()(mod) + + assert capsys.readouterr().out == "InferType\n" + + +def test_multiple_instrument(capsys): + x, y, z = [tvm.relay.var(c, shape=(3, 4), dtype="float32") for c in "xyz"] + e1 = op.add(x, y) + e2 = op.subtract(x, z) + e3 = op.multiply(e1, e1 / e2) + mod = tvm.IRModule.from_expr(e3 + e2) + + def custom_pi(skip_pass_name): + def create_custom_pi(): + pi = PassInstrument("Don't care") + + @pi.register_run_before_pass + def run_before_pass(mod, info): + if skip_pass_name in info.name: + return False + + return True + + return pi + + return create_custom_pi() + + skip_annotate = custom_pi("AnnotateSpans") + skip_anf = custom_pi("ToANormalForm") + + print_pass_name = PassInstrument("PrintPassName") + + @print_pass_name.register_run_before_pass + def run_before_pass(mod, info): + print(info.name) + return True + + with tvm.transform.PassContext( + pass_instrumentor=PassInstrumentor([skip_annotate, skip_anf, print_pass_name]) + ): + mod = tvm.relay.transform.AnnotateSpans()(mod) + mod = tvm.relay.transform.ToANormalForm()(mod) + mod = tvm.relay.transform.InferType()(mod) + + assert capsys.readouterr().out == "InferType\n" + + +def test_instrument_pass_counts(capsys): + x, y, z = [tvm.relay.var(c, shape=(3, 4), dtype="float32") for c in "xyz"] + e1 = op.add(x, y) + e2 = op.subtract(x, z) + e3 = op.multiply(e1, e1 / e2) + mod = tvm.IRModule.from_expr(e3 + e2) + + class PassesCounter(PassInstrument): + def __init__(self): + super().__init__("PassesCounter") + super().register_set_up(self.__set_up) + super().register_tear_down(self.__tear_down) + super().register_run_before_pass(self.__run_before_pass) + super().register_run_after_pass(self.__run_after_pass) + self.__clear() + + def __clear(self): + self.run_before_count = 0 + self.run_after_count = 0 + + def __set_up(self): + self.__clear() + + def __tear_down(self): + self.__clear() + + def __run_before_pass(self, mod, info): + self.run_before_count = self.run_before_count + 1 + return True + + def __run_after_pass(self, mod, info): + self.run_after_count = self.run_after_count + 1 + + passes_counter = PassesCounter() + with tvm.transform.PassContext(pass_instrumentor=PassInstrumentor([passes_counter])): + tvm.relay.build(mod, "llvm") + assert passes_counter.run_after_count != 0 + assert passes_counter.run_after_count == passes_counter.run_before_count + + # Out of pass context scope, should be reset + assert passes_counter.run_before_count == 0 + assert passes_counter.run_after_count == 0 From 0ad48d4f1e04f048bbae12daaec1cc100db458cd Mon Sep 17 00:00:00 2001 From: Zack Chen Date: Sat, 1 May 2021 00:14:34 +0800 Subject: [PATCH 02/24] [IR][Pass][Instrument] Fix python test_pass_manager.py --- tests/python/relay/test_pass_manager.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/tests/python/relay/test_pass_manager.py b/tests/python/relay/test_pass_manager.py index 5a29d1acd171..b404331d54fa 100644 --- a/tests/python/relay/test_pass_manager.py +++ b/tests/python/relay/test_pass_manager.py @@ -25,6 +25,7 @@ from tvm.relay import Function, Call from tvm.relay import analysis from tvm.relay import transform as _transform +from tvm.ir import instrument as _instrument from tvm.relay.testing import run_infer_type import tvm.testing @@ -535,11 +536,14 @@ def test_print_ir(capfd): __TRACE_COUNTER__ = 0 +pi = _instrument.PassInstrument("my_instrument") -def _tracer(module, info, is_before): + +@pi.register_run_before_pass +def _tracer(module, info): global __TRACE_COUNTER__ - if bool(is_before): - __TRACE_COUNTER__ += 1 + __TRACE_COUNTER__ += 1 + return True def test_print_debug_callback(): @@ -562,7 +566,9 @@ def test_print_debug_callback(): assert __TRACE_COUNTER__ == 0 mod = tvm.IRModule({"main": func}) - with tvm.transform.PassContext(opt_level=3, trace=_tracer): + with tvm.transform.PassContext( + opt_level=3, pass_instrumentor=_instrument.PassInstrumentor([pi]) + ): mod = seq(mod) # TODO(@jroesch): when we remove new fn pass behavior we need to remove From dc33594ff0a1eb55d90c498a55b8948abc9286d4 Mon Sep 17 00:00:00 2001 From: Zack Chen Date: Sun, 16 May 2021 17:24:28 +0800 Subject: [PATCH 03/24] Fix comment --- include/tvm/ir/instrument.h | 118 ++--------------- include/tvm/ir/transform.h | 4 +- python/tvm/ir/instrument.py | 141 ++++++++++++--------- python/tvm/ir/transform.py | 12 +- src/ir/instrument.cc | 115 ++++------------- src/ir/transform.cc | 31 +++-- tests/python/relay/test_pass_instrument.py | 98 ++++++-------- tests/python/relay/test_pass_manager.py | 17 +-- 8 files changed, 196 insertions(+), 340 deletions(-) diff --git a/include/tvm/ir/instrument.h b/include/tvm/ir/instrument.h index 8b8f091298e3..f92cb5730ecb 100644 --- a/include/tvm/ir/instrument.h +++ b/include/tvm/ir/instrument.h @@ -20,7 +20,7 @@ /*! * \file tvm/ir/instrument.h * - * This file implements a pass instrument infrastructure, inspired from LLVM and MLIR. + * This file introduces a pass instrument infrastructure, inspired from LLVM and MLIR. * It inserts instrumentation points between passes run. * * Within a pass context (tvm::transfom::PassContext), the instrumentation call sequence will like: @@ -61,19 +61,6 @@ class PassInfo; namespace instrument { -/*! - * \brief A callback type for set up or clean up instrument environment. - */ -using InstrumentEnvFunc = runtime::TypedPackedFunc; - -/*! - * \brief A callback template for instrumenting before/after environment. - * \tparam RetTy the return type of callback. - */ -template -using PassInstrumentFunc = - runtime::TypedPackedFunc; - /*! * \brief PassInstrumentNode forms an instrument implementation. * It provides API for users to register callbacks at different instrument point. @@ -85,49 +72,16 @@ class PassInstrumentNode : public Object { String name; /*! \brief Callback for instrumentation environment set up. */ - InstrumentEnvFunc set_up_callback; + runtime::TypedPackedFunc set_up_callback; /*! \brief Callback for instrumentation environment clean up. */ - InstrumentEnvFunc tear_down_callback; + runtime::TypedPackedFunc tear_down_callback; /*! \brief Callback to run before a pass. */ - PassInstrumentFunc run_before_pass_callback; + runtime::TypedPackedFunc + run_before_pass_callback; /*! \brief Callback to run after a pass. */ - PassInstrumentFunc<> run_after_pass_callback; - - /*! - * \brief Register a callback to run at set up point. - * - * \param callback The set up function. - */ - void RegisterSetUpCallback(InstrumentEnvFunc callback) { set_up_callback = std::move(callback); } - - /* - * \brief Register a callback to run at clean up point. - * - * \param callback The clean up function. - */ - void RegisterTearDownCallback(InstrumentEnvFunc callback) { - tear_down_callback = std::move(callback); - } - - /*! - * \brief Register a callback to run before pass run. - * - * \param callback The function to run before pass: return false to skip pass; return true to - * run pass. - */ - void RegisterRunBeforePassCallback(PassInstrumentFunc callback) { - run_before_pass_callback = std::move(callback); - } - - /*! - * \brief Register a callback to run after pass run. - * - * \param callback The function to run after pass. - */ - void RegisterRunAfterPassCallback(PassInstrumentFunc<> callback) { - run_after_pass_callback = std::move(callback); - } + runtime::TypedPackedFunc + run_after_pass_callback; void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); } @@ -155,7 +109,7 @@ class PassInstrumentNode : public Object { void RunAfterPass(const IRModule& mod, const transform::PassInfo& info) const; static constexpr const char* _type_key = "instrument.PassInstrument"; - TVM_DECLARE_FINAL_OBJECT_INFO(PassInstrumentNode, Object); + TVM_DECLARE_BASE_OBJECT_INFO(PassInstrumentNode, Object); }; /*! @@ -182,62 +136,6 @@ class PassInstrument : public ObjectRef { TVM_DEFINE_OBJECT_REF_METHODS(PassInstrument, ObjectRef, PassInstrumentNode); }; -/*! - * \brief PassInstrumentorNode collects a set of PassInstrument implementations, invokes the - * implementations' methods at different instrument points. - * \sa PassInstrumentor - */ -class PassInstrumentorNode : public Object { - public: - Array pass_instruments; - - void VisitAttrs(AttrVisitor* v) { v->Visit("pass_instruments", &pass_instruments); } - - /*! \brief Set up environment for instrument implementations. */ - void SetUp() const; - - /*! \brief Clean up environment for instrument implementations. */ - void TearDown() const; - - /*! - * \brief Instrument before pass run, determine whether to run the pass or not. - * - * \param mod The module that an optimization pass runs on. - * \param info The pass information. - * - * \return true to run the pass; false to skip the pass. - */ - bool RunBeforePass(const IRModule& mod, const transform::PassInfo& info) const; - - /*! - * \brief Instrument after pass run. - * - * \param mod The module that an optimization pass runs on. - * \param info The pass information. - * - * \return true to run the pass; false to skip the pass. - */ - void RunAfterPass(const IRModule& mod, const transform::PassInfo& info) const; - - static constexpr const char* _type_key = "instrument.PassInstrumentor"; - TVM_DECLARE_FINAL_OBJECT_INFO(PassInstrumentorNode, Object); -}; - -/*! - * \brief Managed reference class for PassInstrumentorNode - * \sa PassInstrumentorNode - */ -class PassInstrumentor : public ObjectRef { - public: - /*! - * \brief Constructor - * \param pass_instruments A set of instrument implementations. - */ - TVM_DLL PassInstrumentor(Array pass_instruments); - - TVM_DEFINE_OBJECT_REF_METHODS(PassInstrumentor, ObjectRef, PassInstrumentorNode); -}; - } // namespace instrument } // namespace tvm diff --git a/include/tvm/ir/transform.h b/include/tvm/ir/transform.h index 32c5dea9b84a..611cba2f41ce 100644 --- a/include/tvm/ir/transform.h +++ b/include/tvm/ir/transform.h @@ -88,8 +88,8 @@ class PassContextNode : public Object { /*! \brief Pass specific configurations. */ Map config; - /*! \brief Instrumentor contains a list of instrument implementations. */ - instrument::PassInstrumentor pass_instrumentor; + /*! \brief A list of pass instrument implementations. */ + Array instruments; PassContextNode() = default; diff --git a/python/tvm/ir/instrument.py b/python/tvm/ir/instrument.py index e8d894c110b8..4d7b3f290bcc 100644 --- a/python/tvm/ir/instrument.py +++ b/python/tvm/ir/instrument.py @@ -16,6 +16,10 @@ # under the License. # pylint: disable=invalid-name,unused-argument """Common pass instrumentation across IR variants.""" +import types +import inspect +import functools + import tvm._ffi import tvm.runtime @@ -26,88 +30,111 @@ class PassInstrument(tvm.runtime.Object): """A pass instrument implementation. - Parameters - ---------- - name : str - The name for this instrument implementation. + Users don't need to interact with this class directly. + Instead, a `PassInstrument` instance should be created through `pass_instrument`. - Examples + See Also -------- + `pass_instrument` + """ - .. code-block:: python - pi = tvm.instrument.PassInstrument("print-before-after") - @pi.register_set_up - def set_up(): - pass +def _wrap_class_pass_instrument(pi_cls): + """Wrap a python class as pass instrument""" - @pi.register_tear_down - def tear_down(): - pass + class PyPassInstrument(PassInstrument): + """Internal wrapper class to create a class instance.""" - @pi.register_run_before_pass - def run_before_pass(mod, info): - print("Before pass: " + info.name) - print(mod) - return True + def __init__(self, *args, **kwargs): + # initialize handle in cass pi_cls creation failed.fg + self.handle = None + inst = pi_cls(*args, **kwargs) - @pi.register_run_after_pass - def run_after_pass(mod, info): - print("After pass: " + info.name) - print(mod) + # check method declartion within class, if found, wrap it. + def create_method(method): + if hasattr(inst, method) and inspect.ismethod(getattr(inst, method)): + def func(*args): + return getattr(inst, method)(*args) - See Also - -------- - instrument.PassInstrumentor - """ - - def __init__(self, name): - self.__init_handle_by_constructor__(_ffi_instrument_api.PassInstrument, name) + func.__name__ = "_" + method + return func + return None - def register_set_up(self, callback): - _ffi_instrument_api.RegisterSetUpCallback(self, callback) + # create runtime pass instrument object + # reister instance's run_before_pass, run_after_pass, set_up and tear_down method to it if present. + self.__init_handle_by_constructor__( + _ffi_instrument_api.PassInstrument, + pi_cls.__name__, + create_method("run_before_pass"), + create_method("run_after_pass"), + create_method("set_up"), + create_method("tear_down"), + ) - def register_tear_down(self, callback): - _ffi_instrument_api.RegisterTearDownCallback(self, callback) + self._inst = inst - def register_run_before_pass(self, callback): - _ffi_instrument_api.RegisterRunBeforePassCallback(self, callback) + def __getattr__(self, name): + # fall back to instance attribute if there is not any + return self._inst.__getattribute__(name) - def register_run_after_pass(self, callback): - _ffi_instrument_api.RegisterRunAfterPassCallback(self, callback) + functools.update_wrapper(PyPassInstrument.__init__, pi_cls.__init__) + PyPassInstrument.__name__ = pi_cls.__name__ + PyPassInstrument.__doc__ = pi_cls.__doc__ + PyPassInstrument.__module__ = pi_cls.__module__ + return PyPassInstrument -@tvm._ffi.register_object("instrument.PassInstrumentor") -class PassInstrumentor(tvm.runtime.Object): - """A pass instrumentor collects a set of pass instrument implementations. +def pass_instrument(pi_cls=None): + """Decorate a pass instrument. Parameters ---------- - pass_instruments : List[tvm.instrument.PassInstrument] - List of instrumentation to run within pass context + pi_class : Examples -------- - .. code-block:: python - - passes_mem = #... Impl of memory instrument - passes_time = tvm.instrument.PassesTimeInstrument() + The following code block decorates a pass instrument class. - with tvm.transform.PassContext( - pass_instrumentor=tvm.instrument.PassInstrumentor([passes_mem, passes_time])): - tvm.relay.build(mod, 'llvm') + .. code-block:: python + @tvm.instrument.pass_instrument + class SkipPass: + def __init__(self, skip_pass_name): + self.skip_pass_name = skip_pass_name + + # Uncomment to customize + # def set_up(): + # pass + + # Uncomment to customize + # def tear_down(): + # pass + + # If pass name contains keyword, skip it by return False. (return True: not skip) + def run_before_pass(mod, pass_info): + if self.skip_pass_name in pass_info.name: + return False + return True + + # Uncomment to customize + # def run_after_pass(mod, pass_info): + # pass + + skip_annotate = SkipPass("AnnotateSpans") + with tvm.transform.PassContext(instruments=[skip_annotate]): + tvm.relay.build(mod, "llvm") + """ - print(passes_time.rendor()) + def create_pass_instrument(pi_cls): + if not inspect.isclass(pi_cls): + raise TypeError("pi_cls must be a class") - See Also - ------- - instrument.PassInstrument - instrument.PassesTimeInstrument - """ + name = pi_cls.__name__ + return _wrap_class_pass_instrument(pi_cls) - def __init__(self, pass_instruments): - self.__init_handle_by_constructor__(_ffi_instrument_api.PassInstrumentor, pass_instruments) + if pi_cls: + return create_pass_instrument(pi_cls) + return create_pass_instrument @tvm._ffi.register_object("instrument.PassInstrument") diff --git a/python/tvm/ir/transform.py b/python/tvm/ir/transform.py index 6fbd9b2da3a2..c131303011b1 100644 --- a/python/tvm/ir/transform.py +++ b/python/tvm/ir/transform.py @@ -65,8 +65,8 @@ class PassContext(tvm.runtime.Object): disabled_pass : Optional[Union[List[str], Set[str], Tuple[str]]] The list of passes that are disabled. - pass_instrumentor : Optional[tvm.instrument.PassInstrumentor] - The pass instrumentor that collects pass instrument implementations + instruments : Optional[Union[List[PassInstrument], Set[PassInstrument], Tuple[PassInstrument]]] + The list of pass instrument implementations. config : Optional[Dict[str, Object]] Additional configurations for specific passes. @@ -77,7 +77,7 @@ def __init__( opt_level=2, required_pass=None, disabled_pass=None, - pass_instrumentor=None, + instruments=None, config=None, ): required = list(required_pass) if required_pass else [] @@ -88,9 +88,13 @@ def __init__( if not isinstance(disabled, (list, tuple)): raise TypeError("disabled_pass is expected to be the type of " + "list/tuple/set.") + instruments = list(instruments) if instruments else [] + if not isinstance(instruments, (list, tuple)): + raise TypeError("disabled_pass is expected to be the type of " + "list/tuple/set.") + config = config if config else None self.__init_handle_by_constructor__( - _ffi_transform_api.PassContext, opt_level, required, disabled, pass_instrumentor, config + _ffi_transform_api.PassContext, opt_level, required, disabled, instruments, config ) def __enter__(self): diff --git a/src/ir/instrument.cc b/src/ir/instrument.cc index 254d70a315c7..ee321e04be89 100644 --- a/src/ir/instrument.cc +++ b/src/ir/instrument.cc @@ -66,47 +66,24 @@ void PassInstrumentNode::RunAfterPass(const IRModule& ir_module, } } -PassInstrumentor::PassInstrumentor(Array pass_instruments) { - auto n = make_object(); - n->pass_instruments = std::move(pass_instruments); - data_ = std::move(n); -} - -void PassInstrumentorNode::SetUp() const { - for (PassInstrument pi : pass_instruments) { - pi->SetUp(); - } -} - -void PassInstrumentorNode::TearDown() const { - for (PassInstrument pi : pass_instruments) { - pi->TearDown(); - } -} - -bool PassInstrumentorNode::RunBeforePass(const IRModule& ir_module, - const transform::PassInfo& pass_info) const { - for (PassInstrument pi : pass_instruments) { - if (!pi->RunBeforePass(ir_module, pass_info)) { - return false; - } - } - - return true; -} - -void PassInstrumentorNode::RunAfterPass(const IRModule& ir_module, - const transform::PassInfo& pass_info) const { - for (PassInstrument pi : pass_instruments) { - pi->RunAfterPass(ir_module, pass_info); - } -} - TVM_REGISTER_NODE_TYPE(PassInstrumentNode); -TVM_REGISTER_GLOBAL("instrument.PassInstrument").set_body_typed([](String name) { - return PassInstrument(name); -}); +TVM_REGISTER_GLOBAL("instrument.PassInstrument") + .set_body_typed([](String name, + runtime::TypedPackedFunc + run_before_pass, + runtime::TypedPackedFunc + run_after_pass, + runtime::TypedPackedFunc set_up, + runtime::TypedPackedFunc tear_down) { + auto pi = PassInstrument(name); + pi->run_before_pass_callback = std::move(run_before_pass); + pi->run_after_pass_callback = std::move(run_after_pass); + + pi->set_up_callback = std::move(set_up); + pi->tear_down_callback = std::move(tear_down); + return pi; + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { @@ -114,44 +91,6 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << node->name; }); -TVM_REGISTER_NODE_TYPE(PassInstrumentorNode); - -TVM_REGISTER_GLOBAL("instrument.PassInstrumentor") - .set_body_typed([](Array pass_instruments) { - return PassInstrumentor(pass_instruments); - }); - -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - - p->stream << "\n PassInstrumentor ["; - for (PassInstrument pi : node->pass_instruments) { - p->stream << pi << " "; - } - p->stream << "] \n"; - }); - -TVM_REGISTER_GLOBAL("instrument.RegisterSetUpCallback") - .set_body_typed([](PassInstrument pass_instrument, InstrumentEnvFunc callback) { - pass_instrument->RegisterSetUpCallback(callback); - }); - -TVM_REGISTER_GLOBAL("instrument.RegisterTearDownCallback") - .set_body_typed([](PassInstrument pass_instrument, InstrumentEnvFunc callback) { - pass_instrument->RegisterTearDownCallback(callback); - }); - -TVM_REGISTER_GLOBAL("instrument.RegisterRunBeforePassCallback") - .set_body_typed([](PassInstrument pass_instrument, PassInstrumentFunc callback) { - pass_instrument->RegisterRunBeforePassCallback(callback); - }); - -TVM_REGISTER_GLOBAL("instrument.RegisterRunAfterPassCallback") - .set_body_typed([](PassInstrument pass_instrument, PassInstrumentFunc<> callback) { - pass_instrument->RegisterRunAfterPassCallback(callback); - }); - /*! \brief PassProfile stores profiling information for a given pass and its sub-passes. */ struct PassProfile { // TODO(@altanh): expose PassProfile through TVM Object API @@ -274,20 +213,22 @@ String RenderPassProfiles() { TVM_REGISTER_GLOBAL("instrument.RenderTimePassProfiles").set_body_typed(RenderPassProfiles); TVM_REGISTER_GLOBAL("instrument.MakePassesTimeInstrument").set_body_typed([]() { - auto pi = PassInstrument("PassesTimeInstrument"); - - // No set up function for this time instrumentation. - - pi->RegisterTearDownCallback([]() { PassProfileThreadLocalStore::Get()->root.children.clear(); }); - - pi->RegisterRunBeforePassCallback([](const IRModule&, const transform::PassInfo& pass_info) { + auto run_before_pass = [](const IRModule&, const transform::PassInfo& pass_info) { PassProfile::EnterPass(pass_info->name); return true; - }); + }; + + auto run_after_pass = [](const IRModule&, const transform::PassInfo& pass_info) { + PassProfile::ExitPass(); + }; - pi->RegisterRunAfterPassCallback( - [](const IRModule&, const transform::PassInfo&) { PassProfile::ExitPass(); }); + auto tear_down = []() { PassProfileThreadLocalStore::Get()->root.children.clear(); }; + + auto pi = PassInstrument("PassesTimeInstrument"); + pi->run_before_pass_callback = run_before_pass; + pi->run_after_pass_callback = run_after_pass; + pi->tear_down_callback = tear_down; return pi; }); diff --git a/src/ir/transform.cc b/src/ir/transform.cc index f05042b11bc4..145d468930b0 100644 --- a/src/ir/transform.cc +++ b/src/ir/transform.cc @@ -166,32 +166,41 @@ PassContext PassContext::Create() { return PassContext(make_objectoperator->(); - if (pass_ctx_node->pass_instrumentor.defined()) { - pass_ctx_node->pass_instrumentor->SetUp(); + if (pass_ctx_node->instruments.defined()) { + for (instrument::PassInstrument pi : pass_ctx_node->instruments) { + pi->SetUp(); + } } } void PassContext::InstrumentTearDown() const { auto pass_ctx_node = this->operator->(); - if (pass_ctx_node->pass_instrumentor.defined()) { - pass_ctx_node->pass_instrumentor->TearDown(); + if (pass_ctx_node->instruments.defined()) { + for (instrument::PassInstrument pi : pass_ctx_node->instruments) { + pi->TearDown(); + } } } bool PassContext::InstrumentBeforePass(const IRModule& ir_module, const PassInfo& pass_info) const { auto pass_ctx_node = this->operator->(); - if (pass_ctx_node->pass_instrumentor.defined()) { - if (!pass_ctx_node->pass_instrumentor->RunBeforePass(ir_module, pass_info)) { - return false; + if (pass_ctx_node->instruments.defined()) { + for (instrument::PassInstrument pi : pass_ctx_node->instruments) { + if (!pi->RunBeforePass(ir_module, pass_info)) { + return false; + } } + return true; } return true; } void PassContext::InstrumentAfterPass(const IRModule& ir_module, const PassInfo& pass_info) const { auto pass_ctx_node = this->operator->(); - if (pass_ctx_node->pass_instrumentor.defined()) { - pass_ctx_node->pass_instrumentor->RunAfterPass(ir_module, pass_info); + if (pass_ctx_node->instruments.defined()) { + for (instrument::PassInstrument pi : pass_ctx_node->instruments) { + pi->RunAfterPass(ir_module, pass_info); + } } } @@ -515,14 +524,14 @@ TVM_REGISTER_NODE_TYPE(PassContextNode); TVM_REGISTER_GLOBAL("transform.PassContext") .set_body_typed([](int opt_level, Array required, Array disabled, - instrument::PassInstrumentor pass_instrumentor, + Array instruments, Optional> config) { auto pctx = PassContext::Create(); pctx->opt_level = opt_level; pctx->required_pass = std::move(required); pctx->disabled_pass = std::move(disabled); - pctx->pass_instrumentor = std::move(pass_instrumentor); + pctx->instruments = std::move(instruments); if (config.defined()) { pctx->config = config.value(); } diff --git a/tests/python/relay/test_pass_instrument.py b/tests/python/relay/test_pass_instrument.py index ab2e10929129..b2e6bbe23eba 100644 --- a/tests/python/relay/test_pass_instrument.py +++ b/tests/python/relay/test_pass_instrument.py @@ -17,7 +17,7 @@ import tvm import tvm.relay from tvm.relay import op -from tvm.ir.instrument import PassesTimeInstrument, PassInstrument, PassInstrumentor +from tvm.ir.instrument import PassesTimeInstrument, PassInstrument, pass_instrument def test_pass_time_instrument(): @@ -28,7 +28,7 @@ def test_pass_time_instrument(): mod = tvm.IRModule.from_expr(e3 + e2) time_instrument = PassesTimeInstrument() - with tvm.transform.PassContext(pass_instrumentor=PassInstrumentor([time_instrument])): + with tvm.transform.PassContext(instruments=[time_instrument]): mod = tvm.relay.transform.AnnotateSpans()(mod) mod = tvm.relay.transform.ToANormalForm()(mod) mod = tvm.relay.transform.InferType()(mod) @@ -46,29 +46,22 @@ def test_custom_instrument(capsys): e3 = op.multiply(e1, e1 / e2) mod = tvm.IRModule.from_expr(e3 + e2) - def custom_pi(): - pi = PassInstrument("MyTest") - - @pi.register_set_up - def set_up(): + @pass_instrument + class MyTest: + def set_up(self): print("set up") - @pi.register_tear_down - def tear_down(): + def tear_down(self): print("tear down") - @pi.register_run_before_pass - def run_before_pass(mod, info): + def run_before_pass(self, mod, info): print("run before " + info.name) return True - @pi.register_run_after_pass - def run_after_pass(mod, info): + def run_after_pass(self, mod, info): print("run after " + info.name) - return pi - - with tvm.transform.PassContext(pass_instrumentor=PassInstrumentor([custom_pi()])): + with tvm.transform.PassContext(instruments=[MyTest()]): mod = tvm.relay.transform.InferType()(mod) output = "set up\n" "run before InferType\n" "run after InferType\n" "tear down\n" @@ -82,11 +75,9 @@ def test_disable_pass(capsys): e3 = op.multiply(e1, e1 / e2) mod = tvm.IRModule.from_expr(e3 + e2) - def custom_pi(): - pi = PassInstrument("MyTest") - - @pi.register_run_before_pass - def run_before_pass(mod, info): + @pass_instrument + class CustomPI: + def run_before_pass(self, mod, info): # Only run pass name contains "InferType" if "InferType" not in info.name: return False @@ -94,9 +85,7 @@ def run_before_pass(mod, info): print(info.name) return True - return pi - - with tvm.transform.PassContext(pass_instrumentor=PassInstrumentor([custom_pi()])): + with tvm.transform.PassContext(instruments=[CustomPI()]): mod = tvm.relay.transform.AnnotateSpans()(mod) mod = tvm.relay.transform.ToANormalForm()(mod) mod = tvm.relay.transform.InferType()(mod) @@ -111,34 +100,28 @@ def test_multiple_instrument(capsys): e3 = op.multiply(e1, e1 / e2) mod = tvm.IRModule.from_expr(e3 + e2) - def custom_pi(skip_pass_name): - def create_custom_pi(): - pi = PassInstrument("Don't care") - - @pi.register_run_before_pass - def run_before_pass(mod, info): - if skip_pass_name in info.name: - return False - - return True - - return pi + @pass_instrument + class SkipPass: + def __init__(self, skip_pass_name): + self.skip_pass_name = skip_pass_name - return create_custom_pi() + def run_before_pass(self, mod, info): + if self.skip_pass_name in info.name: + return False + return True - skip_annotate = custom_pi("AnnotateSpans") - skip_anf = custom_pi("ToANormalForm") + skip_annotate = SkipPass("AnnotateSpans") + skip_anf = SkipPass("ToANormalForm") - print_pass_name = PassInstrument("PrintPassName") + @pass_instrument + class PrintPassName: + def run_before_pass(self, mod, info): + print(info.name) + return True - @print_pass_name.register_run_before_pass - def run_before_pass(mod, info): - print(info.name) - return True + print_pass_name = PrintPassName() - with tvm.transform.PassContext( - pass_instrumentor=PassInstrumentor([skip_annotate, skip_anf, print_pass_name]) - ): + with tvm.transform.PassContext(instruments=[skip_annotate, skip_anf, print_pass_name]): mod = tvm.relay.transform.AnnotateSpans()(mod) mod = tvm.relay.transform.ToANormalForm()(mod) mod = tvm.relay.transform.InferType()(mod) @@ -153,34 +136,31 @@ def test_instrument_pass_counts(capsys): e3 = op.multiply(e1, e1 / e2) mod = tvm.IRModule.from_expr(e3 + e2) - class PassesCounter(PassInstrument): + @pass_instrument + class PassesCounter: def __init__(self): - super().__init__("PassesCounter") - super().register_set_up(self.__set_up) - super().register_tear_down(self.__tear_down) - super().register_run_before_pass(self.__run_before_pass) - super().register_run_after_pass(self.__run_after_pass) - self.__clear() + self.run_before_count = 0 + self.run_after_count = 0 def __clear(self): self.run_before_count = 0 self.run_after_count = 0 - def __set_up(self): + def set_up(self): self.__clear() - def __tear_down(self): + def tear_down(self): self.__clear() - def __run_before_pass(self, mod, info): + def run_before_pass(self, mod, info): self.run_before_count = self.run_before_count + 1 return True - def __run_after_pass(self, mod, info): + def run_after_pass(self, mod, info): self.run_after_count = self.run_after_count + 1 passes_counter = PassesCounter() - with tvm.transform.PassContext(pass_instrumentor=PassInstrumentor([passes_counter])): + with tvm.transform.PassContext(instruments=[passes_counter]): tvm.relay.build(mod, "llvm") assert passes_counter.run_after_count != 0 assert passes_counter.run_after_count == passes_counter.run_before_count diff --git a/tests/python/relay/test_pass_manager.py b/tests/python/relay/test_pass_manager.py index b404331d54fa..36b63fc8c010 100644 --- a/tests/python/relay/test_pass_manager.py +++ b/tests/python/relay/test_pass_manager.py @@ -536,14 +536,13 @@ def test_print_ir(capfd): __TRACE_COUNTER__ = 0 -pi = _instrument.PassInstrument("my_instrument") - -@pi.register_run_before_pass -def _tracer(module, info): - global __TRACE_COUNTER__ - __TRACE_COUNTER__ += 1 - return True +@tvm.instrument.pass_instrument +class MyInstrument: + def run_before_pass(self, module, info): + global __TRACE_COUNTER__ + __TRACE_COUNTER__ += 1 + return True def test_print_debug_callback(): @@ -566,9 +565,7 @@ def test_print_debug_callback(): assert __TRACE_COUNTER__ == 0 mod = tvm.IRModule({"main": func}) - with tvm.transform.PassContext( - opt_level=3, pass_instrumentor=_instrument.PassInstrumentor([pi]) - ): + with tvm.transform.PassContext(opt_level=3, instruments=[MyInstrument()]): mod = seq(mod) # TODO(@jroesch): when we remove new fn pass behavior we need to remove From 62ffb722139463f393c0f612035979657850d243 Mon Sep 17 00:00:00 2001 From: Zack Chen Date: Sun, 16 May 2021 17:55:00 +0800 Subject: [PATCH 04/24] Fix lint --- python/tvm/ir/instrument.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/python/tvm/ir/instrument.py b/python/tvm/ir/instrument.py index 4d7b3f290bcc..c032b858545b 100644 --- a/python/tvm/ir/instrument.py +++ b/python/tvm/ir/instrument.py @@ -16,7 +16,6 @@ # under the License. # pylint: disable=invalid-name,unused-argument """Common pass instrumentation across IR variants.""" -import types import inspect import functools @@ -62,7 +61,8 @@ def func(*args): return None # create runtime pass instrument object - # reister instance's run_before_pass, run_after_pass, set_up and tear_down method to it if present. + # reister instance's run_before_pass, run_after_pass, set_up and tear_down method + # to it if present. self.__init_handle_by_constructor__( _ffi_instrument_api.PassInstrument, pi_cls.__name__, @@ -129,7 +129,6 @@ def create_pass_instrument(pi_cls): if not inspect.isclass(pi_cls): raise TypeError("pi_cls must be a class") - name = pi_cls.__name__ return _wrap_class_pass_instrument(pi_cls) if pi_cls: From d358e991fdaac13fc9999aa32d70f97b30e24484 Mon Sep 17 00:00:00 2001 From: Zack Chen Date: Sun, 16 May 2021 21:39:47 +0800 Subject: [PATCH 05/24] Fix test_pass_annotation --- tests/python/relay/test_pass_annotation.py | 21 ++++++++++++++------- tests/python/relay/test_pass_instrument.py | 2 +- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/tests/python/relay/test_pass_annotation.py b/tests/python/relay/test_pass_annotation.py index a9c31f5ccedd..4f8d22fd50bb 100644 --- a/tests/python/relay/test_pass_annotation.py +++ b/tests/python/relay/test_pass_annotation.py @@ -23,12 +23,19 @@ from tvm.contrib import graph_executor from tvm.relay.expr_functor import ExprMutator from tvm.relay import transform +from tvm.ir.instrument import pass_instrument import tvm.testing -def _trace(module, metadata, _): - if metadata.name == "ManifestAlloc": - pass # import pdb; pdb.set_trace() +@tvm.instrument.pass_instrument +class Trace: + def run_before_pass(module, pass_info): + if pass_info.name == "ManifestAlloc": + pass # import pdb; pdb.set_trace() + + def run_after_pass(module, pass_info): + if pass_info.name == "ManifestAlloc": + pass # import pdb; pdb.set_trace() def check_graph_executor( @@ -49,7 +56,7 @@ def check_graph_executor( def check_vm_runtime(target, ref_res, device, func, params, config, opt_level, expected_index=None): - with tvm.transform.PassContext(opt_level=opt_level, trace=_trace, config=config): + with tvm.transform.PassContext(opt_level=opt_level, instruments=[Trace()], config=config): mod = tvm.IRModule() mod["main"] = func exe = relay.vm.compile(mod, target) @@ -186,7 +193,7 @@ def check_annotated_graph(annotated_func, expected_func): def test_conv_network(): - R"""The network is as following: + r"""The network is as following: data1 data2 | | conv2d conv2d @@ -389,7 +396,7 @@ def get_func(): return func def test_fuse_log_add(device, tgt): - """ Only log and add are fused.""" + """Only log and add are fused.""" fallback_device = tvm.device("cpu") target = {"cpu": "llvm", device: tgt} cpu_dev = fallback_device @@ -530,7 +537,7 @@ def test_fallback_all_operators(device, tgt): def run_unpropagatable_graph(dev, tgt): - R"""The network is as following: + r"""The network is as following: a b c d \ / \ / add mul diff --git a/tests/python/relay/test_pass_instrument.py b/tests/python/relay/test_pass_instrument.py index b2e6bbe23eba..84b16e73fea5 100644 --- a/tests/python/relay/test_pass_instrument.py +++ b/tests/python/relay/test_pass_instrument.py @@ -17,7 +17,7 @@ import tvm import tvm.relay from tvm.relay import op -from tvm.ir.instrument import PassesTimeInstrument, PassInstrument, pass_instrument +from tvm.ir.instrument import PassesTimeInstrument, pass_instrument def test_pass_time_instrument(): From ffcfef09b70f153852c2bf6c343e920345938760 Mon Sep 17 00:00:00 2001 From: Zack Chen Date: Sun, 16 May 2021 21:42:54 +0800 Subject: [PATCH 06/24] Fix test_pass_annotation.py --- tests/python/relay/test_pass_annotation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/relay/test_pass_annotation.py b/tests/python/relay/test_pass_annotation.py index 4f8d22fd50bb..e3df2aabf0cc 100644 --- a/tests/python/relay/test_pass_annotation.py +++ b/tests/python/relay/test_pass_annotation.py @@ -28,7 +28,7 @@ @tvm.instrument.pass_instrument -class Trace: +class Trace(): def run_before_pass(module, pass_info): if pass_info.name == "ManifestAlloc": pass # import pdb; pdb.set_trace() From 0cce2733bff7ccb8f7151a0d5de7ed7335d0693c Mon Sep 17 00:00:00 2001 From: Zack Chen Date: Sun, 16 May 2021 21:46:39 +0800 Subject: [PATCH 07/24] Fix lint --- tests/python/relay/test_pass_annotation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/relay/test_pass_annotation.py b/tests/python/relay/test_pass_annotation.py index e3df2aabf0cc..4f8d22fd50bb 100644 --- a/tests/python/relay/test_pass_annotation.py +++ b/tests/python/relay/test_pass_annotation.py @@ -28,7 +28,7 @@ @tvm.instrument.pass_instrument -class Trace(): +class Trace: def run_before_pass(module, pass_info): if pass_info.name == "ManifestAlloc": pass # import pdb; pdb.set_trace() From 9ed716c8359ab613937fa448d739bd1d8f4b4ec5 Mon Sep 17 00:00:00 2001 From: Zack Chen Date: Sun, 16 May 2021 23:58:46 +0800 Subject: [PATCH 08/24] Fix test_pass_annotation.py --- python/tvm/ir/instrument.py | 8 ++++---- tests/python/relay/test_pass_annotation.py | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/python/tvm/ir/instrument.py b/python/tvm/ir/instrument.py index c032b858545b..852ac4e499e6 100644 --- a/python/tvm/ir/instrument.py +++ b/python/tvm/ir/instrument.py @@ -103,21 +103,21 @@ def __init__(self, skip_pass_name): self.skip_pass_name = skip_pass_name # Uncomment to customize - # def set_up(): + # def set_up(self): # pass # Uncomment to customize - # def tear_down(): + # def tear_down(self): # pass # If pass name contains keyword, skip it by return False. (return True: not skip) - def run_before_pass(mod, pass_info): + def run_before_pass(self, mod, pass_info): if self.skip_pass_name in pass_info.name: return False return True # Uncomment to customize - # def run_after_pass(mod, pass_info): + # def run_after_pass(self, mod, pass_info): # pass skip_annotate = SkipPass("AnnotateSpans") diff --git a/tests/python/relay/test_pass_annotation.py b/tests/python/relay/test_pass_annotation.py index 4f8d22fd50bb..72f32a6bbe08 100644 --- a/tests/python/relay/test_pass_annotation.py +++ b/tests/python/relay/test_pass_annotation.py @@ -29,11 +29,11 @@ @tvm.instrument.pass_instrument class Trace: - def run_before_pass(module, pass_info): + def run_before_pass(self, module, pass_info): if pass_info.name == "ManifestAlloc": pass # import pdb; pdb.set_trace() - def run_after_pass(module, pass_info): + def run_after_pass(self, module, pass_info): if pass_info.name == "ManifestAlloc": pass # import pdb; pdb.set_trace() From 9f34b7cdc988020d9cecff1ef5ca88f39af87a0b Mon Sep 17 00:00:00 2001 From: Zack Chen Date: Mon, 17 May 2021 00:03:40 +0800 Subject: [PATCH 09/24] Fix test_pass_annotation.py --- tests/python/relay/test_pass_annotation.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/python/relay/test_pass_annotation.py b/tests/python/relay/test_pass_annotation.py index 72f32a6bbe08..cbad1c2d5776 100644 --- a/tests/python/relay/test_pass_annotation.py +++ b/tests/python/relay/test_pass_annotation.py @@ -32,6 +32,7 @@ class Trace: def run_before_pass(self, module, pass_info): if pass_info.name == "ManifestAlloc": pass # import pdb; pdb.set_trace() + return True def run_after_pass(self, module, pass_info): if pass_info.name == "ManifestAlloc": From 3e407c0673a7a471b6c52f0fce8a07832673f103 Mon Sep 17 00:00:00 2001 From: Zack Chen Date: Tue, 18 May 2021 12:50:12 +0800 Subject: [PATCH 10/24] Fix review comments --- include/tvm/ir/instrument.h | 42 +++------------ python/tvm/ir/instrument.py | 2 +- src/ir/instrument.cc | 102 +++++++++++++++++++++++++++++++----- 3 files changed, 96 insertions(+), 50 deletions(-) diff --git a/include/tvm/ir/instrument.h b/include/tvm/ir/instrument.h index f92cb5730ecb..20bbf369dfd5 100644 --- a/include/tvm/ir/instrument.h +++ b/include/tvm/ir/instrument.h @@ -68,28 +68,13 @@ namespace instrument { */ class PassInstrumentNode : public Object { public: - /*! \brief Name of this pass instrument object. */ - String name; - - /*! \brief Callback for instrumentation environment set up. */ - runtime::TypedPackedFunc set_up_callback; - /*! \brief Callback for instrumentation environment clean up. */ - runtime::TypedPackedFunc tear_down_callback; - - /*! \brief Callback to run before a pass. */ - runtime::TypedPackedFunc - run_before_pass_callback; - /*! \brief Callback to run after a pass. */ - runtime::TypedPackedFunc - run_after_pass_callback; - - void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); } + virtual ~PassInstrumentNode() {} /*! \brief Set up environment for instrumentation. */ - void SetUp() const; + virtual void SetUp() const = 0; /*! \brief Clean up instrumentation environment. */ - void TearDown() const; + virtual void TearDown() const = 0; /*! * \brief Instrument before pass run, determine whether to run the pass or not. @@ -98,7 +83,7 @@ class PassInstrumentNode : public Object { * * \return true to run the pass; false to skip the pass. */ - bool RunBeforePass(const IRModule& mod, const transform::PassInfo& info) const; + virtual bool RunBeforePass(const IRModule& mod, const transform::PassInfo& info) const = 0; /*! * \brief Instrument after pass run. @@ -106,7 +91,9 @@ class PassInstrumentNode : public Object { * \param mod The module that an optimization pass runs on. * \param info The pass information. */ - void RunAfterPass(const IRModule& mod, const transform::PassInfo& info) const; + virtual void RunAfterPass(const IRModule& mod, const transform::PassInfo& info) const = 0; + + void VisitAttrs(AttrVisitor* v) {} static constexpr const char* _type_key = "instrument.PassInstrument"; TVM_DECLARE_BASE_OBJECT_INFO(PassInstrumentNode, Object); @@ -118,21 +105,6 @@ class PassInstrumentNode : public Object { */ class PassInstrument : public ObjectRef { public: - /*! - * \brief Constructor - * \param name Name for this instrumentation. - */ - TVM_DLL PassInstrument(String name); - - /*! - * \brief mutable accessor. - * \return mutable access pointer. - */ - PassInstrumentNode* operator->() { - ICHECK(get() != nullptr); - return static_cast(get_mutable()); - } - TVM_DEFINE_OBJECT_REF_METHODS(PassInstrument, ObjectRef, PassInstrumentNode); }; diff --git a/python/tvm/ir/instrument.py b/python/tvm/ir/instrument.py index 852ac4e499e6..9b1993a48303 100644 --- a/python/tvm/ir/instrument.py +++ b/python/tvm/ir/instrument.py @@ -64,7 +64,7 @@ def func(*args): # reister instance's run_before_pass, run_after_pass, set_up and tear_down method # to it if present. self.__init_handle_by_constructor__( - _ffi_instrument_api.PassInstrument, + _ffi_instrument_api.NamedPassInstrument, pi_cls.__name__, create_method("run_before_pass"), create_method("run_after_pass"), diff --git a/src/ir/instrument.cc b/src/ir/instrument.cc index ee321e04be89..3a3f38982e6c 100644 --- a/src/ir/instrument.cc +++ b/src/ir/instrument.cc @@ -32,26 +32,100 @@ namespace tvm { namespace instrument { -PassInstrument::PassInstrument(String name) { - auto pi = make_object(); +/*! + * \brief A named PassInstrument implementation + * \sa NamedPassInstrument + */ +class NamedPassInstrumentNode : public PassInstrumentNode { + public: + /*! \brief Name of this pass instrument object. */ + String name; + + /*! \brief Callback for instrumentation environment set up. */ + runtime::TypedPackedFunc set_up_callback; + /*! \brief Callback for instrumentation environment clean up. */ + runtime::TypedPackedFunc tear_down_callback; + + /*! \brief Callback to run before a pass. */ + runtime::TypedPackedFunc + run_before_pass_callback; + /*! \brief Callback to run after a pass. */ + runtime::TypedPackedFunc + run_after_pass_callback; + + void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); } + + /*! \brief Set up environment for instrumentation. */ + void SetUp() const final; + + /*! \brief Clean up instrumentation environment. */ + void TearDown() const final; + + /*! + * \brief Instrument before pass run, determine whether to run the pass or not. + * \param mod The module that an optimization pass runs on. + * \param info The pass information. + * + * \return true to run the pass; false to skip the pass. + */ + bool RunBeforePass(const IRModule& mod, const transform::PassInfo& info) const final; + + /*! + * \brief Instrument after pass run. + * + * \param mod The module that an optimization pass runs on. + * \param info The pass information. + */ + void RunAfterPass(const IRModule& mod, const transform::PassInfo& info) const final; + + static constexpr const char* _type_key = "instrument.NamedPassInstrument"; + TVM_DECLARE_FINAL_OBJECT_INFO(NamedPassInstrumentNode, PassInstrumentNode); +}; + +/*! + * \brief Managed reference class for NamedPassInstrumentNode + * \sa NamedPassInstrumentNode + */ +class NamedPassInstrument : public PassInstrument { + public: + /*! + * \brief Constructor + * \param name Name for this instrumentation. + */ + TVM_DLL NamedPassInstrument(String name); + + /*! + * \brief mutable accessor. + * \return mutable access pointer. + */ + NamedPassInstrumentNode* operator->() { + ICHECK(get() != nullptr); + return static_cast(get_mutable()); + } + + TVM_DEFINE_OBJECT_REF_METHODS(NamedPassInstrument, PassInstrument, NamedPassInstrumentNode); +}; + +NamedPassInstrument::NamedPassInstrument(String name) { + auto pi = make_object(); pi->name = std::move(name); data_ = std::move(pi); } -void PassInstrumentNode::SetUp() const { +void NamedPassInstrumentNode::SetUp() const { if (set_up_callback != nullptr) { set_up_callback(); } } -void PassInstrumentNode::TearDown() const { +void NamedPassInstrumentNode::TearDown() const { if (tear_down_callback != nullptr) { tear_down_callback(); } } -bool PassInstrumentNode::RunBeforePass(const IRModule& ir_module, - const transform::PassInfo& pass_info) const { +bool NamedPassInstrumentNode::RunBeforePass(const IRModule& ir_module, + const transform::PassInfo& pass_info) const { if (run_before_pass_callback == nullptr) { return true; } @@ -59,16 +133,16 @@ bool PassInstrumentNode::RunBeforePass(const IRModule& ir_module, return run_before_pass_callback(ir_module, pass_info); } -void PassInstrumentNode::RunAfterPass(const IRModule& ir_module, - const transform::PassInfo& pass_info) const { +void NamedPassInstrumentNode::RunAfterPass(const IRModule& ir_module, + const transform::PassInfo& pass_info) const { if (run_after_pass_callback != nullptr) { run_after_pass_callback(ir_module, pass_info); } } -TVM_REGISTER_NODE_TYPE(PassInstrumentNode); +TVM_REGISTER_NODE_TYPE(NamedPassInstrumentNode); -TVM_REGISTER_GLOBAL("instrument.PassInstrument") +TVM_REGISTER_GLOBAL("instrument.NamedPassInstrument") .set_body_typed([](String name, runtime::TypedPackedFunc run_before_pass, @@ -76,7 +150,7 @@ TVM_REGISTER_GLOBAL("instrument.PassInstrument") run_after_pass, runtime::TypedPackedFunc set_up, runtime::TypedPackedFunc tear_down) { - auto pi = PassInstrument(name); + auto pi = NamedPassInstrument(name); pi->run_before_pass_callback = std::move(run_before_pass); pi->run_after_pass_callback = std::move(run_after_pass); @@ -86,8 +160,8 @@ TVM_REGISTER_GLOBAL("instrument.PassInstrument") }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); p->stream << node->name; }); @@ -224,7 +298,7 @@ TVM_REGISTER_GLOBAL("instrument.MakePassesTimeInstrument").set_body_typed([]() { auto tear_down = []() { PassProfileThreadLocalStore::Get()->root.children.clear(); }; - auto pi = PassInstrument("PassesTimeInstrument"); + auto pi = NamedPassInstrument("PassesTimeInstrument"); pi->run_before_pass_callback = run_before_pass; pi->run_after_pass_callback = run_after_pass; From 5b578b1285862ed301a69405fce116cbfd431b64 Mon Sep 17 00:00:00 2001 From: Zack Chen Date: Tue, 18 May 2021 18:32:41 +0800 Subject: [PATCH 11/24] Fix tutorial use_pass_infra.py --- tutorials/dev/use_pass_infra.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tutorials/dev/use_pass_infra.py b/tutorials/dev/use_pass_infra.py index 6a33d14e38c8..7c7c4af23bfa 100644 --- a/tutorials/dev/use_pass_infra.py +++ b/tutorials/dev/use_pass_infra.py @@ -273,14 +273,17 @@ def visit_constant(self, c): # An example is below. -def print_ir(mod, info, is_before): +@tvm.instrument.pass_instrument +class PrintIR: """Print the name of the pass, the IR, only before passes execute.""" - if is_before: + + def run_before_pass(self, mod, info): print("Running pass: {}", info) print(mod) + return True -with tvm.transform.PassContext(opt_level=3, trace=print_ir): +with tvm.transform.PassContext(opt_level=3, instruments=[PrintIR()]): with tvm.target.Target("llvm"): # Perform the optimizations. mod = seq(mod) From e48a13db1efc391f0227d2d7ef3664e63f37aa2a Mon Sep 17 00:00:00 2001 From: Zack Chen Date: Wed, 19 May 2021 07:26:56 +0800 Subject: [PATCH 12/24] Fix review comments --- include/tvm/ir/instrument.h | 42 ++++++++++---------- python/tvm/ir/transform.py | 4 +- src/ir/instrument.cc | 52 +++++++++++++------------ src/ir/transform.cc | 4 -- tests/python/relay/test_pass_manager.py | 35 +++++++++++------ 5 files changed, 75 insertions(+), 62 deletions(-) diff --git a/include/tvm/ir/instrument.h b/include/tvm/ir/instrument.h index 20bbf369dfd5..c4122dcf4dff 100644 --- a/include/tvm/ir/instrument.h +++ b/include/tvm/ir/instrument.h @@ -20,26 +20,8 @@ /*! * \file tvm/ir/instrument.h * - * This file introduces a pass instrument infrastructure, inspired from LLVM and MLIR. - * It inserts instrumentation points between passes run. - * - * Within a pass context (tvm::transfom::PassContext), the instrumentation call sequence will like: - * - * Instrument SetUp - * - * if (Instrument Before Pass) - * Pass Run - * Instrument After Pass - * - * if (Instrument Before Pass) - * Pass Run - * Instrument After Pass - * - * Instrument TearDown - * - * - * Instrument point before pass can determine particular pass is disable or not depends on the - * callback registered. + * This file introduces a pass instrument infrastructure, inspired by LLVM and MLIR. + * It inserts instrumentation points around passes. */ #ifndef TVM_IR_INSTRUMENT_H_ #define TVM_IR_INSTRUMENT_H_ @@ -63,7 +45,25 @@ namespace instrument { /*! * \brief PassInstrumentNode forms an instrument implementation. - * It provides API for users to register callbacks at different instrument point. + * It provides API for users to register callbacks at different instrumentation points. + * + * Within a pass context (tvm::transfom::PassContext), the instrumentation call sequence will like: + * + * Instrument SetUp + * + * if (Instrument Before Pass1()) + * Pass1() + * Instrument After Pass1() + * + * if (Instrument Before Pass2()) + * Pass2() + * Instrument After Pass2() + * + * Instrument TearDown + * + * The `Before Pass` instrumentation point can selectively disable passes by returning true (to + * enable) or false (to disable). + * * \sa PassInstrument */ class PassInstrumentNode : public Object { diff --git a/python/tvm/ir/transform.py b/python/tvm/ir/transform.py index c131303011b1..c3490cc77108 100644 --- a/python/tvm/ir/transform.py +++ b/python/tvm/ir/transform.py @@ -65,7 +65,7 @@ class PassContext(tvm.runtime.Object): disabled_pass : Optional[Union[List[str], Set[str], Tuple[str]]] The list of passes that are disabled. - instruments : Optional[Union[List[PassInstrument], Set[PassInstrument], Tuple[PassInstrument]]] + instruments : Optional[Sequence[PassInstrument]] The list of pass instrument implementations. config : Optional[Dict[str, Object]] @@ -90,7 +90,7 @@ def __init__( instruments = list(instruments) if instruments else [] if not isinstance(instruments, (list, tuple)): - raise TypeError("disabled_pass is expected to be the type of " + "list/tuple/set.") + raise TypeError("instruments is expected to be the type of " + "list/tuple/set.") config = config if config else None self.__init_handle_by_constructor__( diff --git a/src/ir/instrument.cc b/src/ir/instrument.cc index 3a3f38982e6c..a22fb2a179e7 100644 --- a/src/ir/instrument.cc +++ b/src/ir/instrument.cc @@ -91,24 +91,38 @@ class NamedPassInstrument : public PassInstrument { /*! * \brief Constructor * \param name Name for this instrumentation. + * \param run_before_pass_callback Callback to call before a pass run. + * \param run_after_pass_callback Callback to call after a pass run. + * \param set_up_callback Callback to call when entering pass context. + * \param tear_down_callback Callback to call when exiting pass context. */ - TVM_DLL NamedPassInstrument(String name); - - /*! - * \brief mutable accessor. - * \return mutable access pointer. - */ - NamedPassInstrumentNode* operator->() { - ICHECK(get() != nullptr); - return static_cast(get_mutable()); - } + TVM_DLL NamedPassInstrument( + String name, + runtime::TypedPackedFunc + run_before_pass_callback, + runtime::TypedPackedFunc + run_after_pass_callback, + runtime::TypedPackedFunc set_up_callback, + runtime::TypedPackedFunc tear_down_callback); TVM_DEFINE_OBJECT_REF_METHODS(NamedPassInstrument, PassInstrument, NamedPassInstrumentNode); }; -NamedPassInstrument::NamedPassInstrument(String name) { +NamedPassInstrument::NamedPassInstrument( + String name, + runtime::TypedPackedFunc + run_before_pass_callback, + runtime::TypedPackedFunc + run_after_pass_callback, + runtime::TypedPackedFunc set_up_callback, + runtime::TypedPackedFunc tear_down_callback) { auto pi = make_object(); pi->name = std::move(name); + pi->run_before_pass_callback = std::move(run_before_pass_callback); + pi->run_after_pass_callback = std::move(run_after_pass_callback); + + pi->set_up_callback = std::move(set_up_callback); + pi->tear_down_callback = std::move(tear_down_callback); data_ = std::move(pi); } @@ -150,13 +164,7 @@ TVM_REGISTER_GLOBAL("instrument.NamedPassInstrument") run_after_pass, runtime::TypedPackedFunc set_up, runtime::TypedPackedFunc tear_down) { - auto pi = NamedPassInstrument(name); - pi->run_before_pass_callback = std::move(run_before_pass); - pi->run_after_pass_callback = std::move(run_after_pass); - - pi->set_up_callback = std::move(set_up); - pi->tear_down_callback = std::move(tear_down); - return pi; + return NamedPassInstrument(name, run_before_pass, run_after_pass, set_up, tear_down); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) @@ -298,12 +306,8 @@ TVM_REGISTER_GLOBAL("instrument.MakePassesTimeInstrument").set_body_typed([]() { auto tear_down = []() { PassProfileThreadLocalStore::Get()->root.children.clear(); }; - auto pi = NamedPassInstrument("PassesTimeInstrument"); - pi->run_before_pass_callback = run_before_pass; - pi->run_after_pass_callback = run_after_pass; - - pi->tear_down_callback = tear_down; - return pi; + return NamedPassInstrument("PassesTimeInstrument", run_before_pass, run_after_pass, + /* set_up */ nullptr, tear_down); }); } // namespace instrument diff --git a/src/ir/transform.cc b/src/ir/transform.cc index 145d468930b0..8fb6e1254b7d 100644 --- a/src/ir/transform.cc +++ b/src/ir/transform.cc @@ -207,18 +207,14 @@ void PassContext::InstrumentAfterPass(const IRModule& ir_module, const PassInfo& IRModule Pass::operator()(IRModule mod) const { const PassNode* node = operator->(); ICHECK(node != nullptr); - // PassProfile::EnterPass(node->Info()->name); auto ret = node->operator()(std::move(mod)); - // PassProfile::ExitPass(); return std::move(ret); } IRModule Pass::operator()(IRModule mod, const PassContext& pass_ctx) const { const PassNode* node = operator->(); ICHECK(node != nullptr); - // PassProfile::EnterPass(node->Info()->name); auto ret = node->operator()(std::move(mod), pass_ctx); - // PassProfile::ExitPass(); return std::move(ret); } diff --git a/tests/python/relay/test_pass_manager.py b/tests/python/relay/test_pass_manager.py index 36b63fc8c010..1fc722aab6da 100644 --- a/tests/python/relay/test_pass_manager.py +++ b/tests/python/relay/test_pass_manager.py @@ -534,19 +534,27 @@ def test_print_ir(capfd): assert "multiply" in out -__TRACE_COUNTER__ = 0 +@tvm.instrument.pass_instrument +class PassCounter: + def __init__(self): + # Just setting a garbage value to test set_up callback + self.counts = 1234 + def set_up(self): + self.counts = 0 + + def tear_down(self): + self.counts = 0 -@tvm.instrument.pass_instrument -class MyInstrument: def run_before_pass(self, module, info): - global __TRACE_COUNTER__ - __TRACE_COUNTER__ += 1 + self.counts += 1 return True + def get_counts(self): + return self.counts + def test_print_debug_callback(): - global __TRACE_COUNTER__ shape = (1, 2, 3) tp = relay.TensorType(shape, "float32") x = relay.var("x", tp) @@ -562,15 +570,20 @@ def test_print_debug_callback(): ] ) - assert __TRACE_COUNTER__ == 0 mod = tvm.IRModule({"main": func}) - with tvm.transform.PassContext(opt_level=3, instruments=[MyInstrument()]): + pass_counter = PassCounter() + with tvm.transform.PassContext(opt_level=3, instruments=[pass_counter]): + # Should be reseted when entering pass context + assert pass_counter.get_counts() == 0 mod = seq(mod) - # TODO(@jroesch): when we remove new fn pass behavior we need to remove - # change this back to 3 - assert __TRACE_COUNTER__ == 5 + # TODO(@jroesch): when we remove new fn pass behavior we need to remove + # change this back to 3 + assert pass_counter.get_counts() == 5 + + # Should be cleanned up after exiting pass context + assert pass_counter.get_counts() == 0 if __name__ == "__main__": From 09c062aa031a37690cfdbfb69e688ac41ec2a468 Mon Sep 17 00:00:00 2001 From: Zack Chen Date: Wed, 19 May 2021 12:08:53 +0800 Subject: [PATCH 13/24] Fix review comments --- include/tvm/ir/instrument.h | 25 +++++++++++++++++++------ src/ir/transform.cc | 3 +-- src/tir/ir/transform.cc | 2 +- 3 files changed, 21 insertions(+), 9 deletions(-) diff --git a/include/tvm/ir/instrument.h b/include/tvm/ir/instrument.h index c4122dcf4dff..9f84dc4e33d1 100644 --- a/include/tvm/ir/instrument.h +++ b/include/tvm/ir/instrument.h @@ -49,22 +49,35 @@ namespace instrument { * * Within a pass context (tvm::transfom::PassContext), the instrumentation call sequence will like: * - * Instrument SetUp + * Instrument SetUp() * - * if (Instrument Before Pass1()) + * if (Instrument Before Pass()) * Pass1() - * Instrument After Pass1() + * Instrument After Pass() * - * if (Instrument Before Pass2()) + * if (Instrument Before Pass()) * Pass2() - * Instrument After Pass2() + * Instrument After Pass() + * + * Instrument TearDown() * - * Instrument TearDown * * The `Before Pass` instrumentation point can selectively disable passes by returning true (to * enable) or false (to disable). * + * If there are multiple pass instrumentations provided, `Before Pass` callbacks are applied in + * order. If one return false, then the pass will be skipped: + * + * for (auto pi : PassInstruments) + * if (pi->BeforePass()) + * return False // Disable pass + * + * return True // All ok, enable pass + * + * * \sa PassInstrument + * \sa PassContextNode::InstrumentBeforePass() + * \sa src/ir/transform.cc */ class PassInstrumentNode : public Object { public: diff --git a/src/ir/transform.cc b/src/ir/transform.cc index 8fb6e1254b7d..886c91ad2d6b 100644 --- a/src/ir/transform.cc +++ b/src/ir/transform.cc @@ -190,7 +190,6 @@ bool PassContext::InstrumentBeforePass(const IRModule& ir_module, const PassInfo return false; } } - return true; } return true; } @@ -359,7 +358,7 @@ IRModule ModulePassNode::operator()(IRModule mod, const PassContext& pass_ctx) c ICHECK(mod.defined()) << "The input module must be set."; if (!pass_ctx.InstrumentBeforePass(mod, pass_info)) { - DLOG(INFO) << "Skipping function pass : " << pass_info->name + DLOG(INFO) << "Skipping module pass : " << pass_info->name << " with opt level: " << pass_info->opt_level; pass_ctx->diag_ctx = previous; diff --git a/src/tir/ir/transform.cc b/src/tir/ir/transform.cc index 5fafc7abc863..7f1deae60925 100644 --- a/src/tir/ir/transform.cc +++ b/src/tir/ir/transform.cc @@ -87,7 +87,7 @@ PrimFuncPass::PrimFuncPass( // Perform Module -> Module optimizations at the PrimFunc level. IRModule PrimFuncPassNode::operator()(IRModule mod, const PassContext& pass_ctx) const { - // const PassInfo& pass_info = Info(); + const PassInfo& pass_info = Info(); ICHECK(mod.defined()); if (!pass_ctx.InstrumentBeforePass(mod, pass_info)) { return mod; From 6dcd55913b63e6104329fedf2e0d4e3bfa93992e Mon Sep 17 00:00:00 2001 From: Zack Chen Date: Wed, 19 May 2021 12:19:01 +0800 Subject: [PATCH 14/24] Fix typo --- include/tvm/ir/transform.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/tvm/ir/transform.h b/include/tvm/ir/transform.h index 611cba2f41ce..ca25c18ed3f7 100644 --- a/include/tvm/ir/transform.h +++ b/include/tvm/ir/transform.h @@ -192,7 +192,7 @@ class PassContext : public ObjectRef { TVM_DLL void InstrumentTearDown() const; /*! - * \brief Call intrument implementatations before a pass run. + * \brief Call intrument implementations before a pass run. * * \param mod The module that an optimization pass runs on. * \param info The pass information. From 5a16a7cf3484f80dececb48cc44ee6132b79dc98 Mon Sep 17 00:00:00 2001 From: Zack Chen Date: Wed, 19 May 2021 12:30:42 +0800 Subject: [PATCH 15/24] Fix review comments --- python/tvm/ir/instrument.py | 4 ++-- src/ir/instrument.cc | 2 +- tests/python/relay/test_pass_instrument.py | 10 +++++----- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/python/tvm/ir/instrument.py b/python/tvm/ir/instrument.py index 9b1993a48303..069bf4ebf5ad 100644 --- a/python/tvm/ir/instrument.py +++ b/python/tvm/ir/instrument.py @@ -137,11 +137,11 @@ def create_pass_instrument(pi_cls): @tvm._ffi.register_object("instrument.PassInstrument") -class PassesTimeInstrument(tvm.runtime.Object): +class PassTimingInstrument(tvm.runtime.Object): """A wrapper to create a passes time instrument that implemented in C++""" def __init__(self): - self.__init_handle_by_constructor__(_ffi_instrument_api.MakePassesTimeInstrument) + self.__init_handle_by_constructor__(_ffi_instrument_api.MakePassTimingInstrument) @staticmethod def render(): diff --git a/src/ir/instrument.cc b/src/ir/instrument.cc index a22fb2a179e7..5594e7f86060 100644 --- a/src/ir/instrument.cc +++ b/src/ir/instrument.cc @@ -294,7 +294,7 @@ String RenderPassProfiles() { TVM_REGISTER_GLOBAL("instrument.RenderTimePassProfiles").set_body_typed(RenderPassProfiles); -TVM_REGISTER_GLOBAL("instrument.MakePassesTimeInstrument").set_body_typed([]() { +TVM_REGISTER_GLOBAL("instrument.MakePassTimingInstrument").set_body_typed([]() { auto run_before_pass = [](const IRModule&, const transform::PassInfo& pass_info) { PassProfile::EnterPass(pass_info->name); return true; diff --git a/tests/python/relay/test_pass_instrument.py b/tests/python/relay/test_pass_instrument.py index 84b16e73fea5..40092d8205b4 100644 --- a/tests/python/relay/test_pass_instrument.py +++ b/tests/python/relay/test_pass_instrument.py @@ -17,23 +17,23 @@ import tvm import tvm.relay from tvm.relay import op -from tvm.ir.instrument import PassesTimeInstrument, pass_instrument +from tvm.ir.instrument import PassTimingInstrument, pass_instrument -def test_pass_time_instrument(): +def test_pass_timing_instrument(): x, y, z = [tvm.relay.var(c, shape=(3, 4), dtype="float32") for c in "xyz"] e1 = op.add(x, y) e2 = op.subtract(x, z) e3 = op.multiply(e1, e1 / e2) mod = tvm.IRModule.from_expr(e3 + e2) - time_instrument = PassesTimeInstrument() - with tvm.transform.PassContext(instruments=[time_instrument]): + pass_timing = PassTimingInstrument() + with tvm.transform.PassContext(instruments=[pass_timing]): mod = tvm.relay.transform.AnnotateSpans()(mod) mod = tvm.relay.transform.ToANormalForm()(mod) mod = tvm.relay.transform.InferType()(mod) - profiles = time_instrument.render() + profiles = pass_timing.render() assert "AnnotateSpans" in profiles assert "ToANormalForm" in profiles assert "InferType" in profiles From 0f09405f95650736105d983d8fda6698e47eb2a6 Mon Sep 17 00:00:00 2001 From: Zack Chen Date: Sat, 22 May 2021 22:02:58 +0800 Subject: [PATCH 16/24] Fix review comments --- include/tvm/ir/instrument.h | 84 +++-- include/tvm/ir/transform.h | 17 +- python/tvm/ir/instrument.py | 23 +- python/tvm/ir/transform.py | 26 +- src/ir/instrument.cc | 163 ++++++---- src/ir/transform.cc | 74 +++-- src/relay/ir/transform.cc | 10 - src/tir/ir/transform.cc | 5 - tests/python/relay/test_pass_annotation.py | 1 - tests/python/relay/test_pass_instrument.py | 360 ++++++++++++++++++--- tests/python/relay/test_pass_manager.py | 9 +- tutorials/dev/use_pass_infra.py | 1 - 12 files changed, 535 insertions(+), 238 deletions(-) diff --git a/include/tvm/ir/instrument.h b/include/tvm/ir/instrument.h index 9f84dc4e33d1..1b0e9a9ea50e 100644 --- a/include/tvm/ir/instrument.h +++ b/include/tvm/ir/instrument.h @@ -47,66 +47,96 @@ namespace instrument { * \brief PassInstrumentNode forms an instrument implementation. * It provides API for users to register callbacks at different instrumentation points. * - * Within a pass context (tvm::transfom::PassContext), the instrumentation call sequence will like: + * Within a PassContext, call sequence of a PassInstrument implementation is like: * - * Instrument SetUp() + * with PassContext(instruments=[pi]): # pi = a PassInstrument implementation + * pi.EnterPassContext() * - * if (Instrument Before Pass()) - * Pass1() - * Instrument After Pass() + * if pi.ShouldRun(Pass1): + * pi.RunBeforePass() + * Pass1() + * pi.RunAfterPass() * - * if (Instrument Before Pass()) - * Pass2() - * Instrument After Pass() + * if pi.ShouldRun(Pass2): + * pi.RunBeforePass() + * Pass2() + * pi.RunAfterPass() * - * Instrument TearDown() + * pi.ExitPassContext() * + * `EnterPassContext` and `ExitPassContext` are only called once when entering/exiting a + * PassContext. `ShouldRun`, `RunBeforePass` and `RunAfterPass` are called multiple times depending + * on how many passes. * - * The `Before Pass` instrumentation point can selectively disable passes by returning true (to - * enable) or false (to disable). + * If there are multiple pass instrumentations provided, the instrument points are the same. + * PassInstrument implementations' callbacks are called in order: * - * If there are multiple pass instrumentations provided, `Before Pass` callbacks are applied in - * order. If one return false, then the pass will be skipped: + * with PassContext(instruments=[pi1, pi2]): # pi1, pi2 = two distinct PassInstrument impls + * pi.EnterPassContext() for pi in instruments * - * for (auto pi : PassInstruments) - * if (pi->BeforePass()) - * return False // Disable pass + * should_run = all([pi.ShoudRun(Pass1) for pi in instruments)]) + * if (should_run) + * pi.RunBeforePass() for pi in instruments + * Pass1() + * pi.RunAfterPass() for pi in instruments * - * return True // All ok, enable pass + * should_run = all([pi.ShouldRun(Pass2) for pi in instruments)]) + * if (should_run) + * pi.RunBeforePass() for pi in instruments + * Pass2() + * pi.RunAfterPass() for pi in instruments * + * pi.ExitPassContext() for pi in instruments + * + * Note: + * 1. Assume there is no dependency between PassInstrument implementations in `instruments` . + * 2. `EnterPassContext` and `ExitPassContext` have `with` behavior (see PassContext and its FFI): + * If there is any exception raised in `ShouldRun()`, `RunBeforePass()`, `RunAfterPass()` and + * `Pass()`, `ExitPassContext()` is still called. + * 3. In mutiple PassInstrument instances scenario, callbacks are called in order: + * If one throws exceptions, remainings will not be called. * * \sa PassInstrument - * \sa PassContextNode::InstrumentBeforePass() * \sa src/ir/transform.cc */ class PassInstrumentNode : public Object { public: + /*! \brief Name of this pass instrument object. */ + String name; + virtual ~PassInstrumentNode() {} - /*! \brief Set up environment for instrumentation. */ - virtual void SetUp() const = 0; + /*! \brief Instrument when entering PassContext. Called once within a PassContext. */ + virtual void EnterPassContext() const = 0; - /*! \brief Clean up instrumentation environment. */ - virtual void TearDown() const = 0; + /*! \brief Instrument when exiting PassContext. Called once within a PassContext. */ + virtual void ExitPassContext() const = 0; /*! - * \brief Instrument before pass run, determine whether to run the pass or not. + * \brief Determine whether to run the pass or not. Called multiple times depend on number of + * passes. * \param mod The module that an optimization pass runs on. * \param info The pass information. * * \return true to run the pass; false to skip the pass. */ - virtual bool RunBeforePass(const IRModule& mod, const transform::PassInfo& info) const = 0; + virtual bool ShouldRun(const IRModule& mod, const transform::PassInfo& info) const = 0; /*! - * \brief Instrument after pass run. - * + * \brief Instrument before pass run. Called multiple times depend on number of passes. + * \param mod The module that an optimization pass runs on. + * \param info The pass information. + */ + virtual void RunBeforePass(const IRModule& mod, const transform::PassInfo& info) const = 0; + + /*! + * \brief Instrument after pass run. Called multiple time depend on number of passes. * \param mod The module that an optimization pass runs on. * \param info The pass information. */ virtual void RunAfterPass(const IRModule& mod, const transform::PassInfo& info) const = 0; - void VisitAttrs(AttrVisitor* v) {} + void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); } static constexpr const char* _type_key = "instrument.PassInstrument"; TVM_DECLARE_BASE_OBJECT_INFO(PassInstrumentNode, Object); diff --git a/include/tvm/ir/transform.h b/include/tvm/ir/transform.h index ca25c18ed3f7..b773a0672269 100644 --- a/include/tvm/ir/transform.h +++ b/include/tvm/ir/transform.h @@ -127,6 +127,7 @@ class PassContextNode : public Object { v->Visit("opt_level", &opt_level); v->Visit("required_pass", &required_pass); v->Visit("disabled_pass", &disabled_pass); + v->Visit("instruments", &instruments); v->Visit("config", &config); v->Visit("diag_ctx", &diag_ctx); } @@ -182,17 +183,20 @@ class PassContext : public ObjectRef { TVM_DLL static PassContext Current(); /*! - * \brief Set up for all the instrument implementations. + * \brief Call instrument implementations' callbacks when entering PassContex. + * In order, if one raises exceptions, remaings will not be called. */ - TVM_DLL void InstrumentSetUp() const; + TVM_DLL void InstrumentEnterPassContext() const; /*! - * \brief Clean up for all the instrument implementations. + * \brief Call instrument implementations' callback when exiting PassContext. + * In order, if one raises exceptions, remaings will not be called. */ - TVM_DLL void InstrumentTearDown() const; + TVM_DLL void InstrumentExitPassContext() const; /*! - * \brief Call intrument implementations before a pass run. + * \brief Call intrument implementations' callbacks before a pass run. + * In order, if one raises exceptions, remaings will not be called. * * \param mod The module that an optimization pass runs on. * \param info The pass information. @@ -202,7 +206,8 @@ class PassContext : public ObjectRef { TVM_DLL bool InstrumentBeforePass(const IRModule& mod, const PassInfo& info) const; /*! - * \brief Call instrument implementations after a pass run. + * \brief Call instrument implementations callbacks after a pass run. + * In order, if one raises exceptions, remaings will not be called. * * \param mod The module that an optimization pass runs on. * \param info The pass information. diff --git a/python/tvm/ir/instrument.py b/python/tvm/ir/instrument.py index 069bf4ebf5ad..c322f2bef3fc 100644 --- a/python/tvm/ir/instrument.py +++ b/python/tvm/ir/instrument.py @@ -45,7 +45,7 @@ class PyPassInstrument(PassInstrument): """Internal wrapper class to create a class instance.""" def __init__(self, *args, **kwargs): - # initialize handle in cass pi_cls creation failed.fg + # initialize handle in case pi_cls creation failed. self.handle = None inst = pi_cls(*args, **kwargs) @@ -61,15 +61,16 @@ def func(*args): return None # create runtime pass instrument object - # reister instance's run_before_pass, run_after_pass, set_up and tear_down method - # to it if present. + # reister instance's enter_pass_ctx,exit_pass_ctx, should_run, run_before_pass and + # run_after_pass methods to it if present. self.__init_handle_by_constructor__( - _ffi_instrument_api.NamedPassInstrument, + _ffi_instrument_api.PassInstrument, pi_cls.__name__, + create_method("enter_pass_ctx"), + create_method("exit_pass_ctx"), + create_method("should_run"), create_method("run_before_pass"), create_method("run_after_pass"), - create_method("set_up"), - create_method("tear_down"), ) self._inst = inst @@ -103,19 +104,23 @@ def __init__(self, skip_pass_name): self.skip_pass_name = skip_pass_name # Uncomment to customize - # def set_up(self): + # def enter_pass_ctx(self): # pass # Uncomment to customize - # def tear_down(self): + # def exit_pass_ctx(self): # pass # If pass name contains keyword, skip it by return False. (return True: not skip) - def run_before_pass(self, mod, pass_info): + def should_run(self, mod, pass_info) if self.skip_pass_name in pass_info.name: return False return True + # Uncomment to customize + # def run_before_pass(self, mod, pass_info): + # pass + # Uncomment to customize # def run_after_pass(self, mod, pass_info): # pass diff --git a/python/tvm/ir/transform.py b/python/tvm/ir/transform.py index c3490cc77108..f2fce8675b48 100644 --- a/python/tvm/ir/transform.py +++ b/python/tvm/ir/transform.py @@ -104,6 +104,9 @@ def __enter__(self): def __exit__(self, ptype, value, trace): _ffi_transform_api.ExitPassContext(self) + def override_instruments(self, instruments): + _ffi_transform_api.OverrideInstruments(self, instruments) + @staticmethod def current(): """Return the current pass context.""" @@ -343,26 +346,3 @@ def PrintIR(header="", show_meta_data=False): The pass """ return _ffi_transform_api.PrintIR(header, show_meta_data) - - -def render_pass_profiles(): - """Returns a string render of the pass profiling data. The format of each output line is - `{name}: {time} [{time excluding sub-passes}] ({% of total}; {% of parent})`. - The indentation of each line corresponds to nesting of passes. - """ - return _ffi_transform_api.render_pass_profiles() - - -def clear_pass_profiles(): - """Clears all stored pass profiling data.""" - _ffi_transform_api.clear_pass_profiles() - - -def enable_pass_profiling(): - """Enables pass profiling.""" - _ffi_transform_api.enable_pass_profiling() - - -def disable_pass_profiling(): - """Disables pass profiling.""" - _ffi_transform_api.disable_pass_profiling() diff --git a/src/ir/instrument.cc b/src/ir/instrument.cc index 5594e7f86060..795e5b8cb542 100644 --- a/src/ir/instrument.cc +++ b/src/ir/instrument.cc @@ -33,42 +33,47 @@ namespace tvm { namespace instrument { /*! - * \brief A named PassInstrument implementation - * \sa NamedPassInstrument + * \brief Base PassInstrument implementation + * \sa BasePassInstrument */ -class NamedPassInstrumentNode : public PassInstrumentNode { +class BasePassInstrumentNode : public PassInstrumentNode { public: - /*! \brief Name of this pass instrument object. */ - String name; + /*! \brief Callback to run when entering PassContext. */ + runtime::TypedPackedFunc enter_pass_ctx_callback; + /*! \brief Callback to run when exiting PassContext. */ + runtime::TypedPackedFunc exit_pass_ctx_callback; - /*! \brief Callback for instrumentation environment set up. */ - runtime::TypedPackedFunc set_up_callback; - /*! \brief Callback for instrumentation environment clean up. */ - runtime::TypedPackedFunc tear_down_callback; + /*! \brief Callback determines whether to run a pass or not. */ + runtime::TypedPackedFunc should_run_callback; /*! \brief Callback to run before a pass. */ - runtime::TypedPackedFunc + runtime::TypedPackedFunc run_before_pass_callback; /*! \brief Callback to run after a pass. */ runtime::TypedPackedFunc run_after_pass_callback; - void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); } - - /*! \brief Set up environment for instrumentation. */ - void SetUp() const final; + /*! \brief Instrument when entering PassContext. */ + void EnterPassContext() const final; - /*! \brief Clean up instrumentation environment. */ - void TearDown() const final; + /*! \brief Instrument when exiting PassContext. */ + void ExitPassContext() const final; /*! - * \brief Instrument before pass run, determine whether to run the pass or not. + * \brief Determine whether to run the pass or not. * \param mod The module that an optimization pass runs on. * \param info The pass information. * * \return true to run the pass; false to skip the pass. */ - bool RunBeforePass(const IRModule& mod, const transform::PassInfo& info) const final; + bool ShouldRun(const IRModule&, const transform::PassInfo& info) const final; + + /*! + * \brief Instrument before pass run. + * \param mod The module that an optimization pass runs on. + * \param info The pass information. + */ + void RunBeforePass(const IRModule& mod, const transform::PassInfo& info) const final; /*! * \brief Instrument after pass run. @@ -78,98 +83,119 @@ class NamedPassInstrumentNode : public PassInstrumentNode { */ void RunAfterPass(const IRModule& mod, const transform::PassInfo& info) const final; - static constexpr const char* _type_key = "instrument.NamedPassInstrument"; - TVM_DECLARE_FINAL_OBJECT_INFO(NamedPassInstrumentNode, PassInstrumentNode); + static constexpr const char* _type_key = "instrument.PassInstrument"; + TVM_DECLARE_FINAL_OBJECT_INFO(BasePassInstrumentNode, PassInstrumentNode); }; /*! - * \brief Managed reference class for NamedPassInstrumentNode - * \sa NamedPassInstrumentNode + * \brief Managed reference class for BasePassInstrumentNode + * \sa BasePassInstrumentNode */ -class NamedPassInstrument : public PassInstrument { +class BasePassInstrument : public PassInstrument { public: /*! * \brief Constructor + * * \param name Name for this instrumentation. + * + * + * \param enter_pass_ctx_callback Callback to call when entering pass context. + * \param exit_pass_ctx_callback Callback to call when exiting pass context. + * + * \param should_run_callback Callback to determine whether pass should run. (return true: enable; + * return false: disable) + * * \param run_before_pass_callback Callback to call before a pass run. * \param run_after_pass_callback Callback to call after a pass run. - * \param set_up_callback Callback to call when entering pass context. - * \param tear_down_callback Callback to call when exiting pass context. */ - TVM_DLL NamedPassInstrument( - String name, + TVM_DLL BasePassInstrument( + String name, runtime::TypedPackedFunc enter_pass_ctx_callback, + runtime::TypedPackedFunc exit_pass_ctx_callback, runtime::TypedPackedFunc + should_run_callback, + runtime::TypedPackedFunc run_before_pass_callback, runtime::TypedPackedFunc - run_after_pass_callback, - runtime::TypedPackedFunc set_up_callback, - runtime::TypedPackedFunc tear_down_callback); + run_after_pass_callback); - TVM_DEFINE_OBJECT_REF_METHODS(NamedPassInstrument, PassInstrument, NamedPassInstrumentNode); + TVM_DEFINE_OBJECT_REF_METHODS(BasePassInstrument, PassInstrument, BasePassInstrumentNode); }; -NamedPassInstrument::NamedPassInstrument( - String name, - runtime::TypedPackedFunc +BasePassInstrument::BasePassInstrument( + String name, runtime::TypedPackedFunc enter_pass_ctx_callback, + runtime::TypedPackedFunc exit_pass_ctx_callback, + runtime::TypedPackedFunc should_run_callback, + runtime::TypedPackedFunc run_before_pass_callback, runtime::TypedPackedFunc - run_after_pass_callback, - runtime::TypedPackedFunc set_up_callback, - runtime::TypedPackedFunc tear_down_callback) { - auto pi = make_object(); + run_after_pass_callback) { + auto pi = make_object(); pi->name = std::move(name); + + pi->enter_pass_ctx_callback = std::move(enter_pass_ctx_callback); + pi->exit_pass_ctx_callback = std::move(exit_pass_ctx_callback); + + pi->should_run_callback = std::move(should_run_callback); + pi->run_before_pass_callback = std::move(run_before_pass_callback); pi->run_after_pass_callback = std::move(run_after_pass_callback); - pi->set_up_callback = std::move(set_up_callback); - pi->tear_down_callback = std::move(tear_down_callback); data_ = std::move(pi); } -void NamedPassInstrumentNode::SetUp() const { - if (set_up_callback != nullptr) { - set_up_callback(); +void BasePassInstrumentNode::EnterPassContext() const { + if (enter_pass_ctx_callback != nullptr) { + enter_pass_ctx_callback(); } } -void NamedPassInstrumentNode::TearDown() const { - if (tear_down_callback != nullptr) { - tear_down_callback(); +void BasePassInstrumentNode::ExitPassContext() const { + if (exit_pass_ctx_callback != nullptr) { + exit_pass_ctx_callback(); } } -bool NamedPassInstrumentNode::RunBeforePass(const IRModule& ir_module, - const transform::PassInfo& pass_info) const { - if (run_before_pass_callback == nullptr) { +bool BasePassInstrumentNode::ShouldRun(const IRModule& ir_module, + const transform::PassInfo& pass_info) const { + if (should_run_callback == nullptr) { return true; } - return run_before_pass_callback(ir_module, pass_info); + return should_run_callback(ir_module, pass_info); } -void NamedPassInstrumentNode::RunAfterPass(const IRModule& ir_module, +void BasePassInstrumentNode::RunBeforePass(const IRModule& ir_module, const transform::PassInfo& pass_info) const { + if (run_before_pass_callback != nullptr) { + run_before_pass_callback(ir_module, pass_info); + } +} + +void BasePassInstrumentNode::RunAfterPass(const IRModule& ir_module, + const transform::PassInfo& pass_info) const { if (run_after_pass_callback != nullptr) { run_after_pass_callback(ir_module, pass_info); } } -TVM_REGISTER_NODE_TYPE(NamedPassInstrumentNode); - -TVM_REGISTER_GLOBAL("instrument.NamedPassInstrument") - .set_body_typed([](String name, - runtime::TypedPackedFunc - run_before_pass, - runtime::TypedPackedFunc - run_after_pass, - runtime::TypedPackedFunc set_up, - runtime::TypedPackedFunc tear_down) { - return NamedPassInstrument(name, run_before_pass, run_after_pass, set_up, tear_down); - }); +TVM_REGISTER_NODE_TYPE(BasePassInstrumentNode); + +TVM_REGISTER_GLOBAL("instrument.PassInstrument") + .set_body_typed( + [](String name, runtime::TypedPackedFunc enter_pass_ctx, + runtime::TypedPackedFunc exit_pass_ctx, + runtime::TypedPackedFunc should_run, + runtime::TypedPackedFunc + run_before_pass, + runtime::TypedPackedFunc + run_after_pass) { + return BasePassInstrument(name, enter_pass_ctx, exit_pass_ctx, should_run, + run_before_pass, run_after_pass); + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); p->stream << node->name; }); @@ -304,10 +330,11 @@ TVM_REGISTER_GLOBAL("instrument.MakePassTimingInstrument").set_body_typed([]() { PassProfile::ExitPass(); }; - auto tear_down = []() { PassProfileThreadLocalStore::Get()->root.children.clear(); }; + auto exit_pass_ctx = []() { PassProfileThreadLocalStore::Get()->root.children.clear(); }; - return NamedPassInstrument("PassesTimeInstrument", run_before_pass, run_after_pass, - /* set_up */ nullptr, tear_down); + return BasePassInstrument("PassTimingInstrument", + /* enter_pass_ctx */ nullptr, exit_pass_ctx, /* should_run */ nullptr, + run_before_pass, run_after_pass); }); } // namespace instrument diff --git a/src/ir/transform.cc b/src/ir/transform.cc index 886c91ad2d6b..11b7d8de7ee5 100644 --- a/src/ir/transform.cc +++ b/src/ir/transform.cc @@ -28,6 +28,7 @@ #include #include +#include #include #include #include @@ -58,14 +59,14 @@ typedef dmlc::ThreadLocalStore RelayPassContextThre void PassContext::EnterWithScope() { PassContextThreadLocalEntry* entry = RelayPassContextThreadLocalStore::Get(); entry->context_stack.push(*this); - InstrumentSetUp(); + InstrumentEnterPassContext(); } void PassContext::ExitWithScope() { PassContextThreadLocalEntry* entry = RelayPassContextThreadLocalStore::Get(); ICHECK(!entry->context_stack.empty()); ICHECK(entry->context_stack.top().same_as(*this)); - InstrumentTearDown(); + InstrumentExitPassContext(); entry->context_stack.pop(); } @@ -164,34 +165,44 @@ void PassContext::RegisterConfigOption(const char* key, uint32_t value_type_inde PassContext PassContext::Create() { return PassContext(make_object()); } -void PassContext::InstrumentSetUp() const { +void PassContext::InstrumentEnterPassContext() const { auto pass_ctx_node = this->operator->(); if (pass_ctx_node->instruments.defined()) { for (instrument::PassInstrument pi : pass_ctx_node->instruments) { - pi->SetUp(); + pi->EnterPassContext(); } } } -void PassContext::InstrumentTearDown() const { +void PassContext::InstrumentExitPassContext() const { auto pass_ctx_node = this->operator->(); if (pass_ctx_node->instruments.defined()) { for (instrument::PassInstrument pi : pass_ctx_node->instruments) { - pi->TearDown(); + pi->ExitPassContext(); } } } bool PassContext::InstrumentBeforePass(const IRModule& ir_module, const PassInfo& pass_info) const { auto pass_ctx_node = this->operator->(); - if (pass_ctx_node->instruments.defined()) { + if (!pass_ctx_node->instruments.defined()) { + return true; + } + + const bool pass_required = PassArrayContains(pass_ctx_node->required_pass, pass_info->name); + bool should_run = true; + if (!pass_required) { + const Array& instruments = pass_ctx_node->instruments; + should_run &= std::all_of(instruments.begin(), instruments.end(), + [&](auto pi) { return pi->ShouldRun(ir_module, pass_info); }); + } + + if (should_run) { for (instrument::PassInstrument pi : pass_ctx_node->instruments) { - if (!pi->RunBeforePass(ir_module, pass_info)) { - return false; - } + pi->RunBeforePass(ir_module, pass_info); } } - return true; + return should_run; } void PassContext::InstrumentAfterPass(const IRModule& ir_module, const PassInfo& pass_info) const { @@ -204,16 +215,20 @@ void PassContext::InstrumentAfterPass(const IRModule& ir_module, const PassInfo& } IRModule Pass::operator()(IRModule mod) const { - const PassNode* node = operator->(); - ICHECK(node != nullptr); - auto ret = node->operator()(std::move(mod)); - return std::move(ret); + return this->operator()(mod, PassContext::Current()); } IRModule Pass::operator()(IRModule mod, const PassContext& pass_ctx) const { const PassNode* node = operator->(); ICHECK(node != nullptr); + const PassInfo& pass_info = node->Info(); + if (!pass_ctx.InstrumentBeforePass(mod, pass_info)) { + DLOG(INFO) << "Skipping pass : " << pass_info->name + << " with opt level: " << pass_info->opt_level; + return mod; + } auto ret = node->operator()(std::move(mod), pass_ctx); + pass_ctx.InstrumentAfterPass(ret, pass_info); return std::move(ret); } @@ -357,14 +372,6 @@ IRModule ModulePassNode::operator()(IRModule mod, const PassContext& pass_ctx) c const PassInfo& pass_info = Info(); ICHECK(mod.defined()) << "The input module must be set."; - if (!pass_ctx.InstrumentBeforePass(mod, pass_info)) { - DLOG(INFO) << "Skipping module pass : " << pass_info->name - << " with opt level: " << pass_info->opt_level; - - pass_ctx->diag_ctx = previous; - return mod; - } - DLOG(INFO) << "Executing module pass : " << pass_info->name << " with opt level: " << pass_info->opt_level; @@ -378,7 +385,6 @@ IRModule ModulePassNode::operator()(IRModule mod, const PassContext& pass_ctx) c pass_ctx->diag_ctx.value().Render(); pass_ctx->diag_ctx = previous; - pass_ctx.InstrumentAfterPass(mod, pass_info); return mod; } @@ -541,17 +547,10 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) << "\n"; p->stream << "\topt_level: " << node->opt_level << "\n"; - p->stream << "\trequired passes: ["; - for (const auto& it : node->required_pass) { - p->stream << it << " "; - } - p->stream << "]\n"; + p->stream << "\trequired passes: " << node->required_pass << "\n"; + p->stream << "\tdisabled passes: " << node->disabled_pass << "\n"; + p->stream << "\tinstruments: " << node->instruments << "\n"; - p->stream << "\tdisabled passes: ["; - for (const auto& it : node->disabled_pass) { - p->stream << it << " "; - } - p->stream << "]\n"; p->stream << "\tconfig: " << node->config; }); @@ -568,6 +567,13 @@ TVM_REGISTER_GLOBAL("transform.EnterPassContext").set_body_typed(PassContext::In TVM_REGISTER_GLOBAL("transform.ExitPassContext").set_body_typed(PassContext::Internal::ExitScope); +TVM_REGISTER_GLOBAL("transform.OverrideInstruments") + .set_body_typed([](PassContext pass_ctx, Array instruments) { + pass_ctx.InstrumentExitPassContext(); + pass_ctx->instruments = instruments; + pass_ctx.InstrumentEnterPassContext(); + }); + Pass PrintIR(String header, bool show_meta_data) { auto pass_func = [header, show_meta_data](IRModule mod, const PassContext& ctx) { LOG(INFO) << "PrintIR(" << header << "):\n" << AsText(mod, show_meta_data); diff --git a/src/relay/ir/transform.cc b/src/relay/ir/transform.cc index 999c9c4fe39e..4a7974cae5ae 100644 --- a/src/relay/ir/transform.cc +++ b/src/relay/ir/transform.cc @@ -130,14 +130,6 @@ IRModule FunctionPassNode::operator()(IRModule mod, const PassContext& pass_ctx) ICHECK(mod.defined()); - if (!pass_ctx.InstrumentBeforePass(mod, pass_info)) { - DLOG(INFO) << "Skipping function pass : " << pass_info->name - << " with opt level: " << pass_info->opt_level; - - pass_ctx->diag_ctx = previous; - return mod; - } - DLOG(INFO) << "Executing function pass : " << pass_info->name << " with opt level: " << pass_info->opt_level; @@ -165,8 +157,6 @@ IRModule FunctionPassNode::operator()(IRModule mod, const PassContext& pass_ctx) pass_ctx->diag_ctx.value().Render(); pass_ctx->diag_ctx = previous; - pass_ctx.InstrumentAfterPass(updated_mod, pass_info); - // TODO(@jroesch): move away from eager type checking for performance reasons // make issue. return transform::InferType()(updated_mod); diff --git a/src/tir/ir/transform.cc b/src/tir/ir/transform.cc index 7f1deae60925..4c59a1767372 100644 --- a/src/tir/ir/transform.cc +++ b/src/tir/ir/transform.cc @@ -87,11 +87,7 @@ PrimFuncPass::PrimFuncPass( // Perform Module -> Module optimizations at the PrimFunc level. IRModule PrimFuncPassNode::operator()(IRModule mod, const PassContext& pass_ctx) const { - const PassInfo& pass_info = Info(); ICHECK(mod.defined()); - if (!pass_ctx.InstrumentBeforePass(mod, pass_info)) { - return mod; - } std::vector deleted_list; IRModuleNode* mod_ptr = mod.CopyOnWrite(); auto* func_dict = mod_ptr->functions.CopyOnWrite(); @@ -114,7 +110,6 @@ IRModule PrimFuncPassNode::operator()(IRModule mod, const PassContext& pass_ctx) for (const auto& gv : deleted_list) { func_dict->erase(gv); } - pass_ctx.InstrumentAfterPass(mod, pass_info); return mod; } diff --git a/tests/python/relay/test_pass_annotation.py b/tests/python/relay/test_pass_annotation.py index 8306c70a66bd..6f68cba268cf 100644 --- a/tests/python/relay/test_pass_annotation.py +++ b/tests/python/relay/test_pass_annotation.py @@ -32,7 +32,6 @@ class Trace: def run_before_pass(self, module, pass_info): if pass_info.name == "ManifestAlloc": pass # import pdb; pdb.set_trace() - return True def run_after_pass(self, module, pass_info): if pass_info.name == "ManifestAlloc": diff --git a/tests/python/relay/test_pass_instrument.py b/tests/python/relay/test_pass_instrument.py index 40092d8205b4..9c0345057e04 100644 --- a/tests/python/relay/test_pass_instrument.py +++ b/tests/python/relay/test_pass_instrument.py @@ -14,77 +14,91 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +""" Instrument test cases. +""" +import pytest import tvm import tvm.relay from tvm.relay import op from tvm.ir.instrument import PassTimingInstrument, pass_instrument -def test_pass_timing_instrument(): +def get_test_model(): x, y, z = [tvm.relay.var(c, shape=(3, 4), dtype="float32") for c in "xyz"] e1 = op.add(x, y) e2 = op.subtract(x, z) e3 = op.multiply(e1, e1 / e2) - mod = tvm.IRModule.from_expr(e3 + e2) + return tvm.IRModule.from_expr(e3 + e2) + +def test_pass_timing_instrument(): pass_timing = PassTimingInstrument() - with tvm.transform.PassContext(instruments=[pass_timing]): - mod = tvm.relay.transform.AnnotateSpans()(mod) - mod = tvm.relay.transform.ToANormalForm()(mod) - mod = tvm.relay.transform.InferType()(mod) - profiles = pass_timing.render() - assert "AnnotateSpans" in profiles - assert "ToANormalForm" in profiles - assert "InferType" in profiles + # Override current PassContext's instruments + tvm.transform.PassContext.current().override_instruments([pass_timing]) + mod = get_test_model() + mod = tvm.relay.transform.AnnotateSpans()(mod) + mod = tvm.relay.transform.ToANormalForm()(mod) + mod = tvm.relay.transform.InferType()(mod) -def test_custom_instrument(capsys): - x, y, z = [tvm.relay.var(c, shape=(3, 4), dtype="float32") for c in "xyz"] - e1 = op.add(x, y) - e2 = op.subtract(x, z) - e3 = op.multiply(e1, e1 / e2) - mod = tvm.IRModule.from_expr(e3 + e2) + profiles = pass_timing.render() + assert "AnnotateSpans" in profiles + assert "ToANormalForm" in profiles + assert "InferType" in profiles + + # Reset current PassContext's instruments to None + tvm.transform.PassContext.current().override_instruments(None) + + mod = get_test_model() + mod = tvm.relay.transform.AnnotateSpans()(mod) + mod = tvm.relay.transform.ToANormalForm()(mod) + mod = tvm.relay.transform.InferType()(mod) + + profiles = pass_timing.render() + assert profiles == "" + +def test_custom_instrument(capsys): @pass_instrument class MyTest: - def set_up(self): - print("set up") + def enter_pass_ctx(self): + print("enter ctx") - def tear_down(self): - print("tear down") + def exit_pass_ctx(self): + print("exit ctx") def run_before_pass(self, mod, info): print("run before " + info.name) - return True def run_after_pass(self, mod, info): print("run after " + info.name) + mod = get_test_model() with tvm.transform.PassContext(instruments=[MyTest()]): mod = tvm.relay.transform.InferType()(mod) - output = "set up\n" "run before InferType\n" "run after InferType\n" "tear down\n" - assert capsys.readouterr().out == output + assert ( + "enter ctx\n" + "run before InferType\n" + "run after InferType\n" + "exit ctx\n" == capsys.readouterr().out + ) def test_disable_pass(capsys): - x, y, z = [tvm.relay.var(c, shape=(3, 4), dtype="float32") for c in "xyz"] - e1 = op.add(x, y) - e2 = op.subtract(x, z) - e3 = op.multiply(e1, e1 / e2) - mod = tvm.IRModule.from_expr(e3 + e2) - @pass_instrument class CustomPI: - def run_before_pass(self, mod, info): + def should_run(self, mod, info): # Only run pass name contains "InferType" if "InferType" not in info.name: return False + return True + def run_before_pass(self, mod, info): print(info.name) - return True + mod = get_test_model() with tvm.transform.PassContext(instruments=[CustomPI()]): mod = tvm.relay.transform.AnnotateSpans()(mod) mod = tvm.relay.transform.ToANormalForm()(mod) @@ -94,18 +108,12 @@ def run_before_pass(self, mod, info): def test_multiple_instrument(capsys): - x, y, z = [tvm.relay.var(c, shape=(3, 4), dtype="float32") for c in "xyz"] - e1 = op.add(x, y) - e2 = op.subtract(x, z) - e3 = op.multiply(e1, e1 / e2) - mod = tvm.IRModule.from_expr(e3 + e2) - @pass_instrument class SkipPass: def __init__(self, skip_pass_name): self.skip_pass_name = skip_pass_name - def run_before_pass(self, mod, info): + def should_run(self, mod, info): if self.skip_pass_name in info.name: return False return True @@ -117,10 +125,9 @@ def run_before_pass(self, mod, info): class PrintPassName: def run_before_pass(self, mod, info): print(info.name) - return True + mod = get_test_model() print_pass_name = PrintPassName() - with tvm.transform.PassContext(instruments=[skip_annotate, skip_anf, print_pass_name]): mod = tvm.relay.transform.AnnotateSpans()(mod) mod = tvm.relay.transform.ToANormalForm()(mod) @@ -130,12 +137,6 @@ def run_before_pass(self, mod, info): def test_instrument_pass_counts(capsys): - x, y, z = [tvm.relay.var(c, shape=(3, 4), dtype="float32") for c in "xyz"] - e1 = op.add(x, y) - e2 = op.subtract(x, z) - e3 = op.multiply(e1, e1 / e2) - mod = tvm.IRModule.from_expr(e3 + e2) - @pass_instrument class PassesCounter: def __init__(self): @@ -146,19 +147,19 @@ def __clear(self): self.run_before_count = 0 self.run_after_count = 0 - def set_up(self): + def enter_pass_ctx(self): self.__clear() - def tear_down(self): + def exit_pass_ctx(self): self.__clear() def run_before_pass(self, mod, info): self.run_before_count = self.run_before_count + 1 - return True def run_after_pass(self, mod, info): self.run_after_count = self.run_after_count + 1 + mod = get_test_model() passes_counter = PassesCounter() with tvm.transform.PassContext(instruments=[passes_counter]): tvm.relay.build(mod, "llvm") @@ -168,3 +169,264 @@ def run_after_pass(self, mod, info): # Out of pass context scope, should be reset assert passes_counter.run_before_count == 0 assert passes_counter.run_after_count == 0 + + +def test_enter_pass_ctx_expection(capsys): + @pass_instrument + class PI: + def __init__(self, id): + self.id = id + + def enter_pass_ctx(self): + print(self.id + " enter ctx") + + def exit_pass_ctx(self): + print(self.id + " exit ctx") + + @pass_instrument + class PIBroken(PI): + def __init__(self, id): + super().__init__(id) + + def enter_pass_ctx(self): + print(self.id + " enter ctx") + raise RuntimeError("Just a dummy error") + + with pytest.raises(tvm.error.TVMError): + with tvm.transform.PassContext(instruments=[PI("%1"), PIBroken("%2"), PI("%3")]): + pass + + assert "%1 enter ctx\n" "%2 enter ctx\n" == capsys.readouterr().out + + +def test_pass_exception(capsys): + @pass_instrument + class PI: + def enter_pass_ctx(self): + print("enter_pass_ctx") + + def exit_pass_ctx(self): + print("exit_pass_ctx") + + def should_run(self, mod, info): + print("should_run") + return True + + def run_before_pass(self, mod, info): + print("run_before_pass") + + def run_after_pass(self, mod, info): + print("run_after_pass") + + @tvm.transform.module_pass(opt_level=2) + def transform(mod, ctx): + print("transform pass") + raise RuntimeError("Just a dummy error") + return mod + + mod = get_test_model() + with pytest.raises(tvm.error.TVMError): + with tvm.transform.PassContext(instruments=[PI()]): + mod = transform(mod) + + assert ( + "enter_pass_ctx\n" + "should_run\n" + "run_before_pass\n" + "transform pass\n" + "exit_pass_ctx\n" == capsys.readouterr().out + ) + + +def test_should_run_exception(capsys): + @pass_instrument + class PI: + def __init__(self, id): + self.id = id + + def enter_pass_ctx(self): + print(self.id + " enter_pass_ctx") + + def exit_pass_ctx(self): + print(self.id + " exit_pass_ctx") + + def should_run(self, mod, info): + print(self.id + " should_run") + raise RuntimeError("Just a dummy error") + return True + + def run_before_pass(self, mod, info): + print(self.id + " run_before_pass") + + def run_after_pass(self, mod, info): + print(self.id + " run_after_pass") + + @tvm.transform.module_pass(opt_level=2) + def transform(mod, ctx): + print("transform pass") + return mod + + mod = get_test_model() + with pytest.raises(tvm.error.TVMError): + with tvm.transform.PassContext(instruments=[PI("%1"), PI("%2")]): + mod = transform(mod) + + assert ( + "%1 enter_pass_ctx\n" + "%2 enter_pass_ctx\n" + "%1 should_run\n" + "%1 exit_pass_ctx\n" + "%2 exit_pass_ctx\n" == capsys.readouterr().out + ) + + +def test_run_before_exception(capsys): + @pass_instrument + class PI: + def __init__(self, id): + self.id = id + + def enter_pass_ctx(self): + print(self.id + " enter_pass_ctx") + + def exit_pass_ctx(self): + print(self.id + " exit_pass_ctx") + + def should_run(self, mod, info): + print(self.id + " should_run") + return True + + def run_before_pass(self, mod, info): + print(self.id + " run_before_pass") + raise RuntimeError("Just a dummy error") + + def run_after_pass(self, mod, info): + print(self.id + " run_after_pass") + + @tvm.transform.module_pass(opt_level=2) + def transform(mod, ctx): + print("transform pass") + return mod + + mod = get_test_model() + with pytest.raises(tvm.error.TVMError): + with tvm.transform.PassContext(instruments=[PI("%1"), PI("%2")]): + mod = transform(mod) + + assert ( + "%1 enter_pass_ctx\n" + "%2 enter_pass_ctx\n" + "%1 should_run\n" + "%2 should_run\n" + "%1 run_before_pass\n" + "%1 exit_pass_ctx\n" + "%2 exit_pass_ctx\n" == capsys.readouterr().out + ) + + +def test_run_after_exception(capsys): + @pass_instrument + class PI: + def __init__(self, id): + self.id = id + + def enter_pass_ctx(self): + print(self.id + " enter_pass_ctx") + + def exit_pass_ctx(self): + print(self.id + " exit_pass_ctx") + + def should_run(self, mod, info): + print(self.id + " should_run") + return True + + def run_before_pass(self, mod, info): + print(self.id + " run_before_pass") + + def run_after_pass(self, mod, info): + print(self.id + " run_after_pass") + raise RuntimeError("Just a dummy error") + + @tvm.transform.module_pass(opt_level=2) + def transform(mod, ctx): + print("transform pass") + return mod + + x, y = [tvm.relay.var(c, shape=(3, 4), dtype="float32") for c in "xy"] + mod = tvm.IRModule.from_expr(tvm.relay.add(x, y)) + + with pytest.raises(tvm.error.TVMError): + with tvm.transform.PassContext(instruments=[PI("%1"), PI("%2")]): + mod = transform(mod) + + assert ( + "%1 enter_pass_ctx\n" + "%2 enter_pass_ctx\n" + "%1 should_run\n" + "%2 should_run\n" + "%1 run_before_pass\n" + "%2 run_before_pass\n" + "transform pass\n" + "%1 run_after_pass\n" + "%1 exit_pass_ctx\n" + "%2 exit_pass_ctx\n" == capsys.readouterr().out + ) + + +def test_instrument_call_sequence(capsys): + @pass_instrument + class PI: + def __init__(self, id): + self.id = id + + def enter_pass_ctx(self): + print(self.id + " enter_pass_ctx") + + def exit_pass_ctx(self): + print(self.id + " exit_pass_ctx") + + def should_run(self, mod, info): + print(" " + self.id + " should_run") + return True + + def run_before_pass(self, mod, info): + print(" " + self.id + " run_before_pass") + + def run_after_pass(self, mod, info): + print(" " + self.id + " run_after_pass") + + @tvm.transform.module_pass(opt_level=2) + def transform1(mod, ctx): + print(" transform1 pass") + return mod + + @tvm.transform.module_pass(opt_level=2) + def transform2(mod, ctx): + print(" transform2 pass") + return mod + + mod = get_test_model() + with tvm.transform.PassContext(instruments=[PI("%1"), PI("%2")]): + mod = transform1(mod) + mod = transform2(mod) + + assert ( + "%1 enter_pass_ctx\n" + "%2 enter_pass_ctx\n" + " %1 should_run\n" + " %2 should_run\n" + " %1 run_before_pass\n" + " %2 run_before_pass\n" + " transform1 pass\n" + " %1 run_after_pass\n" + " %2 run_after_pass\n" + " %1 should_run\n" + " %2 should_run\n" + " %1 run_before_pass\n" + " %2 run_before_pass\n" + " transform2 pass\n" + " %1 run_after_pass\n" + " %2 run_after_pass\n" + "%1 exit_pass_ctx\n" + "%2 exit_pass_ctx\n" == capsys.readouterr().out + ) diff --git a/tests/python/relay/test_pass_manager.py b/tests/python/relay/test_pass_manager.py index 1fc722aab6da..ee889d857cee 100644 --- a/tests/python/relay/test_pass_manager.py +++ b/tests/python/relay/test_pass_manager.py @@ -540,15 +540,14 @@ def __init__(self): # Just setting a garbage value to test set_up callback self.counts = 1234 - def set_up(self): + def enter_pass_ctx(self): self.counts = 0 - def tear_down(self): + def exit_pass_ctx(self): self.counts = 0 def run_before_pass(self, module, info): self.counts += 1 - return True def get_counts(self): return self.counts @@ -579,8 +578,8 @@ def test_print_debug_callback(): mod = seq(mod) # TODO(@jroesch): when we remove new fn pass behavior we need to remove - # change this back to 3 - assert pass_counter.get_counts() == 5 + # change this back to match correct behavior + assert pass_counter.get_counts() == 6 # Should be cleanned up after exiting pass context assert pass_counter.get_counts() == 0 diff --git a/tutorials/dev/use_pass_infra.py b/tutorials/dev/use_pass_infra.py index 7c7c4af23bfa..3804b1496d05 100644 --- a/tutorials/dev/use_pass_infra.py +++ b/tutorials/dev/use_pass_infra.py @@ -280,7 +280,6 @@ class PrintIR: def run_before_pass(self, mod, info): print("Running pass: {}", info) print(mod) - return True with tvm.transform.PassContext(opt_level=3, instruments=[PrintIR()]): From 2162911f9fcf0549467ef0939ab6f2a2bdb267cb Mon Sep 17 00:00:00 2001 From: Zack Chen Date: Sun, 23 May 2021 01:05:00 +0800 Subject: [PATCH 17/24] Fix unittest error: test_cow_pass --- src/ir/transform.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ir/transform.cc b/src/ir/transform.cc index 11b7d8de7ee5..f6c776656b5b 100644 --- a/src/ir/transform.cc +++ b/src/ir/transform.cc @@ -215,7 +215,7 @@ void PassContext::InstrumentAfterPass(const IRModule& ir_module, const PassInfo& } IRModule Pass::operator()(IRModule mod) const { - return this->operator()(mod, PassContext::Current()); + return this->operator()(std::move(mod), PassContext::Current()); } IRModule Pass::operator()(IRModule mod, const PassContext& pass_ctx) const { From b2de46c458338b2dc3105b71f0c5431651e36966 Mon Sep 17 00:00:00 2001 From: Zack Chen Date: Sun, 23 May 2021 12:15:05 +0800 Subject: [PATCH 18/24] Fix unittest error --- include/tvm/ir/transform.h | 4 +- src/ir/transform.cc | 29 ++++++++++---- tests/python/relay/test_pass_instrument.py | 44 +++++++++++++++++++++- 3 files changed, 65 insertions(+), 12 deletions(-) diff --git a/include/tvm/ir/transform.h b/include/tvm/ir/transform.h index b773a0672269..52523e84e6e8 100644 --- a/include/tvm/ir/transform.h +++ b/include/tvm/ir/transform.h @@ -186,13 +186,13 @@ class PassContext : public ObjectRef { * \brief Call instrument implementations' callbacks when entering PassContex. * In order, if one raises exceptions, remaings will not be called. */ - TVM_DLL void InstrumentEnterPassContext() const; + TVM_DLL void InstrumentEnterPassContext(); /*! * \brief Call instrument implementations' callback when exiting PassContext. * In order, if one raises exceptions, remaings will not be called. */ - TVM_DLL void InstrumentExitPassContext() const; + TVM_DLL void InstrumentExitPassContext(); /*! * \brief Call intrument implementations' callbacks before a pass run. diff --git a/src/ir/transform.cc b/src/ir/transform.cc index f6c776656b5b..b107b7d247c1 100644 --- a/src/ir/transform.cc +++ b/src/ir/transform.cc @@ -57,17 +57,19 @@ struct PassContextThreadLocalEntry { typedef dmlc::ThreadLocalStore RelayPassContextThreadLocalStore; void PassContext::EnterWithScope() { + InstrumentEnterPassContext(); + PassContextThreadLocalEntry* entry = RelayPassContextThreadLocalStore::Get(); entry->context_stack.push(*this); - InstrumentEnterPassContext(); } void PassContext::ExitWithScope() { PassContextThreadLocalEntry* entry = RelayPassContextThreadLocalStore::Get(); ICHECK(!entry->context_stack.empty()); ICHECK(entry->context_stack.top().same_as(*this)); - InstrumentExitPassContext(); entry->context_stack.pop(); + + InstrumentExitPassContext(); } PassContext PassContext::Current() { @@ -165,20 +167,31 @@ void PassContext::RegisterConfigOption(const char* key, uint32_t value_type_inde PassContext PassContext::Create() { return PassContext(make_object()); } -void PassContext::InstrumentEnterPassContext() const { +void PassContext::InstrumentEnterPassContext() { auto pass_ctx_node = this->operator->(); if (pass_ctx_node->instruments.defined()) { - for (instrument::PassInstrument pi : pass_ctx_node->instruments) { - pi->EnterPassContext(); + try { + for (instrument::PassInstrument pi : pass_ctx_node->instruments) { + pi->EnterPassContext(); + } + } catch (const Error& e) { + LOG(INFO) << "Pass instrumentation entering pass context failed."; + LOG(INFO) << "Disable pass instrumentation."; + throw e; } } } -void PassContext::InstrumentExitPassContext() const { +void PassContext::InstrumentExitPassContext() { auto pass_ctx_node = this->operator->(); if (pass_ctx_node->instruments.defined()) { - for (instrument::PassInstrument pi : pass_ctx_node->instruments) { - pi->ExitPassContext(); + try { + for (instrument::PassInstrument pi : pass_ctx_node->instruments) { + pi->ExitPassContext(); + } + } catch (const Error& e) { + LOG(INFO) << "Pass instrumentation exiting pass context failed."; + throw e; } } } diff --git a/tests/python/relay/test_pass_instrument.py b/tests/python/relay/test_pass_instrument.py index 9c0345057e04..e7e473fec7df 100644 --- a/tests/python/relay/test_pass_instrument.py +++ b/tests/python/relay/test_pass_instrument.py @@ -171,7 +171,7 @@ def run_after_pass(self, mod, info): assert passes_counter.run_after_count == 0 -def test_enter_pass_ctx_expection(capsys): +def test_enter_pass_ctx_exception(capsys): @pass_instrument class PI: def __init__(self, id): @@ -192,12 +192,52 @@ def enter_pass_ctx(self): print(self.id + " enter ctx") raise RuntimeError("Just a dummy error") + pass_ctx = tvm.transform.PassContext(instruments=[PI("%1"), PIBroken("%2"), PI("%3")]) with pytest.raises(tvm.error.TVMError): - with tvm.transform.PassContext(instruments=[PI("%1"), PIBroken("%2"), PI("%3")]): + with pass_ctx: pass assert "%1 enter ctx\n" "%2 enter ctx\n" == capsys.readouterr().out + # Make sure we get correct PassContext + cur_pass_ctx = tvm.transform.PassContext.current() + assert pass_ctx != cur_pass_ctx + assert cur_pass_ctx.instruments == None + + +def test_exit_pass_ctx_exception(capsys): + @pass_instrument + class PI: + def __init__(self, id): + self.id = id + + def exit_pass_ctx(self): + print(self.id + " exit ctx") + + def exit_pass_ctx(self): + print(self.id + " exit ctx") + + @pass_instrument + class PIBroken(PI): + def __init__(self, id): + super().__init__(id) + + def exit_pass_ctx(self): + print(self.id + " exit ctx") + raise RuntimeError("Just a dummy error") + + pass_ctx = tvm.transform.PassContext(instruments=[PI("%1"), PIBroken("%2"), PI("%3")]) + with pytest.raises(tvm.error.TVMError): + with pass_ctx: + pass + + assert "%1 exit ctx\n" "%2 exit ctx\n" == capsys.readouterr().out + + # Make sure we get correct PassContext + cur_pass_ctx = tvm.transform.PassContext.current() + assert pass_ctx != cur_pass_ctx + assert cur_pass_ctx.instruments == None + def test_pass_exception(capsys): @pass_instrument From 06ee84d8260bf9f9677755ceaae0892da93628bd Mon Sep 17 00:00:00 2001 From: Zack Chen Date: Mon, 24 May 2021 00:05:32 +0800 Subject: [PATCH 19/24] Add more test cases for exceptions --- src/ir/transform.cc | 2 ++ tests/python/relay/test_pass_instrument.py | 27 +++++++++++++++++++++- 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/src/ir/transform.cc b/src/ir/transform.cc index b107b7d247c1..396d36dd9d2e 100644 --- a/src/ir/transform.cc +++ b/src/ir/transform.cc @@ -177,6 +177,7 @@ void PassContext::InstrumentEnterPassContext() { } catch (const Error& e) { LOG(INFO) << "Pass instrumentation entering pass context failed."; LOG(INFO) << "Disable pass instrumentation."; + pass_ctx_node->instruments.clear(); throw e; } } @@ -191,6 +192,7 @@ void PassContext::InstrumentExitPassContext() { } } catch (const Error& e) { LOG(INFO) << "Pass instrumentation exiting pass context failed."; + pass_ctx_node->instruments.clear(); throw e; } } diff --git a/tests/python/relay/test_pass_instrument.py b/tests/python/relay/test_pass_instrument.py index e7e473fec7df..12287fcc5a60 100644 --- a/tests/python/relay/test_pass_instrument.py +++ b/tests/python/relay/test_pass_instrument.py @@ -205,6 +205,18 @@ def enter_pass_ctx(self): assert cur_pass_ctx.instruments == None +def test_enter_pass_ctx_exception_global(capsys): + @pass_instrument + class PIBroken: + def enter_pass_ctx(self): + raise RuntimeError("Just a dummy error") + + cur_pass_ctx = tvm.transform.PassContext.current() + with pytest.raises(tvm.error.TVMError): + cur_pass_ctx.override_instruments([PIBroken()]) + assert not cur_pass_ctx.instruments + + def test_exit_pass_ctx_exception(capsys): @pass_instrument class PI: @@ -236,7 +248,20 @@ def exit_pass_ctx(self): # Make sure we get correct PassContext cur_pass_ctx = tvm.transform.PassContext.current() assert pass_ctx != cur_pass_ctx - assert cur_pass_ctx.instruments == None + assert not cur_pass_ctx.instruments + + +def test_exit_pass_ctx_exception_global(capsys): + @pass_instrument + class PIBroken: + def exit_pass_ctx(self): + raise RuntimeError("Just a dummy error") + + cur_pass_ctx = tvm.transform.PassContext.current() + with pytest.raises(tvm.error.TVMError): + cur_pass_ctx.override_instruments([PIBroken()]) + cur_pass_ctx.override_instruments([PIBroken()]) + assert not cur_pass_ctx.instruments def test_pass_exception(capsys): From 58c456d579bf24a11e478d5f7c23799e4aa3b5d5 Mon Sep 17 00:00:00 2001 From: Zack Chen Date: Tue, 25 May 2021 09:31:36 +0800 Subject: [PATCH 20/24] Fix nit --- include/tvm/ir/transform.h | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/include/tvm/ir/transform.h b/include/tvm/ir/transform.h index 52523e84e6e8..849eda6cd248 100644 --- a/include/tvm/ir/transform.h +++ b/include/tvm/ir/transform.h @@ -183,20 +183,23 @@ class PassContext : public ObjectRef { TVM_DLL static PassContext Current(); /*! - * \brief Call instrument implementations' callbacks when entering PassContex. - * In order, if one raises exceptions, remaings will not be called. + * \brief Call instrument implementations' callbacks when entering PassContext. + * The callbacks are called in order, and if one raises an exception, the rest will not be + * called. */ TVM_DLL void InstrumentEnterPassContext(); /*! - * \brief Call instrument implementations' callback when exiting PassContext. - * In order, if one raises exceptions, remaings will not be called. + * \brief Call instrument implementations' callbacks when exiting PassContext. + * The callbacks are called in order, and if one raises an exception, the rest will not be + * called. */ TVM_DLL void InstrumentExitPassContext(); /*! - * \brief Call intrument implementations' callbacks before a pass run. - * In order, if one raises exceptions, remaings will not be called. + * \brief Call instrument implementations' callbacks before a pass run. + * The callbacks are called in order, and if one raises an exception, the rest will not be + * called. * * \param mod The module that an optimization pass runs on. * \param info The pass information. @@ -207,7 +210,8 @@ class PassContext : public ObjectRef { /*! * \brief Call instrument implementations callbacks after a pass run. - * In order, if one raises exceptions, remaings will not be called. + * The callbacks are called in order, and if one raises an exception, the rest will not be + * called. * * \param mod The module that an optimization pass runs on. * \param info The pass information. From 0dd603f2caef98c548332258a4b9ad2177c97cdf Mon Sep 17 00:00:00 2001 From: Zack Chen Date: Tue, 25 May 2021 09:56:40 +0800 Subject: [PATCH 21/24] Doc override_instruments() --- python/tvm/ir/transform.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/python/tvm/ir/transform.py b/python/tvm/ir/transform.py index f2fce8675b48..3a3ac16be677 100644 --- a/python/tvm/ir/transform.py +++ b/python/tvm/ir/transform.py @@ -105,6 +105,14 @@ def __exit__(self, ptype, value, trace): _ffi_transform_api.ExitPassContext(self) def override_instruments(self, instruments): + """Override instruments within this PassContext. + + If there are existing instruments, their exit_pass_ctx callbacks are called. + Then switching to new instruments and calling new enter_pass_ctx callbacks. + + instruments : Sequence[PassInstrument] + The list of pass instrument implementations. + """ _ffi_transform_api.OverrideInstruments(self, instruments) @staticmethod From 7c504d82ac6ed8a89198815abd5fefda6babb72d Mon Sep 17 00:00:00 2001 From: Zack Chen Date: Wed, 26 May 2021 01:41:19 +0800 Subject: [PATCH 22/24] Fix review comments --- src/ir/transform.cc | 7 +- tests/python/relay/test_pass_instrument.py | 271 ++++++++++++--------- 2 files changed, 155 insertions(+), 123 deletions(-) diff --git a/src/ir/transform.cc b/src/ir/transform.cc index 396d36dd9d2e..a8f6ced5bdd6 100644 --- a/src/ir/transform.cc +++ b/src/ir/transform.cc @@ -28,7 +28,6 @@ #include #include -#include #include #include #include @@ -207,9 +206,9 @@ bool PassContext::InstrumentBeforePass(const IRModule& ir_module, const PassInfo const bool pass_required = PassArrayContains(pass_ctx_node->required_pass, pass_info->name); bool should_run = true; if (!pass_required) { - const Array& instruments = pass_ctx_node->instruments; - should_run &= std::all_of(instruments.begin(), instruments.end(), - [&](auto pi) { return pi->ShouldRun(ir_module, pass_info); }); + for (instrument::PassInstrument pi : pass_ctx_node->instruments) { + should_run &= pi->ShouldRun(ir_module, pass_info); + } } if (should_run) { diff --git a/tests/python/relay/test_pass_instrument.py b/tests/python/relay/test_pass_instrument.py index 12287fcc5a60..740075d7a789 100644 --- a/tests/python/relay/test_pass_instrument.py +++ b/tests/python/relay/test_pass_instrument.py @@ -59,36 +59,43 @@ def test_pass_timing_instrument(): assert profiles == "" -def test_custom_instrument(capsys): +def test_custom_instrument(): @pass_instrument class MyTest: + def __init__(self): + self.events = [] + def enter_pass_ctx(self): - print("enter ctx") + self.events.append("enter ctx") def exit_pass_ctx(self): - print("exit ctx") + self.events.append("exit ctx") def run_before_pass(self, mod, info): - print("run before " + info.name) + self.events.append("run before " + info.name) def run_after_pass(self, mod, info): - print("run after " + info.name) + self.events.append("run after " + info.name) mod = get_test_model() - with tvm.transform.PassContext(instruments=[MyTest()]): + my_test = MyTest() + with tvm.transform.PassContext(instruments=[my_test]): mod = tvm.relay.transform.InferType()(mod) assert ( - "enter ctx\n" - "run before InferType\n" - "run after InferType\n" - "exit ctx\n" == capsys.readouterr().out + "enter ctx" + "run before InferType" + "run after InferType" + "exit ctx" == "".join(my_test.events) ) -def test_disable_pass(capsys): +def test_disable_pass(): @pass_instrument class CustomPI: + def __init__(self): + self.events = [] + def should_run(self, mod, info): # Only run pass name contains "InferType" if "InferType" not in info.name: @@ -96,18 +103,19 @@ def should_run(self, mod, info): return True def run_before_pass(self, mod, info): - print(info.name) + self.events.append(info.name) mod = get_test_model() - with tvm.transform.PassContext(instruments=[CustomPI()]): + custom_pi = CustomPI() + with tvm.transform.PassContext(instruments=[custom_pi]): mod = tvm.relay.transform.AnnotateSpans()(mod) mod = tvm.relay.transform.ToANormalForm()(mod) mod = tvm.relay.transform.InferType()(mod) - assert capsys.readouterr().out == "InferType\n" + assert "InferType" == "".join(custom_pi.events) -def test_multiple_instrument(capsys): +def test_multiple_instrument(): @pass_instrument class SkipPass: def __init__(self, skip_pass_name): @@ -123,8 +131,11 @@ def should_run(self, mod, info): @pass_instrument class PrintPassName: + def __init__(self): + self.events = [] + def run_before_pass(self, mod, info): - print(info.name) + self.events.append(info.name) mod = get_test_model() print_pass_name = PrintPassName() @@ -133,10 +144,10 @@ def run_before_pass(self, mod, info): mod = tvm.relay.transform.ToANormalForm()(mod) mod = tvm.relay.transform.InferType()(mod) - assert capsys.readouterr().out == "InferType\n" + assert "InferType" == "".join(print_pass_name.events) -def test_instrument_pass_counts(capsys): +def test_instrument_pass_counts(): @pass_instrument class PassesCounter: def __init__(self): @@ -171,17 +182,19 @@ def run_after_pass(self, mod, info): assert passes_counter.run_after_count == 0 -def test_enter_pass_ctx_exception(capsys): +def test_enter_pass_ctx_exception(): + events = [] + @pass_instrument class PI: def __init__(self, id): self.id = id def enter_pass_ctx(self): - print(self.id + " enter ctx") + events.append(self.id + " enter ctx") def exit_pass_ctx(self): - print(self.id + " exit ctx") + events.append(self.id + " exit ctx") @pass_instrument class PIBroken(PI): @@ -189,15 +202,16 @@ def __init__(self, id): super().__init__(id) def enter_pass_ctx(self): - print(self.id + " enter ctx") + events.append(self.id + " enter ctx") raise RuntimeError("Just a dummy error") pass_ctx = tvm.transform.PassContext(instruments=[PI("%1"), PIBroken("%2"), PI("%3")]) - with pytest.raises(tvm.error.TVMError): + with pytest.raises(tvm.error.TVMError) as cm: with pass_ctx: pass + assert "Just a dummy error" in str(cm.execption) - assert "%1 enter ctx\n" "%2 enter ctx\n" == capsys.readouterr().out + assert "%1 enter ctx" "%2 enter ctx" == "".join(events) # Make sure we get correct PassContext cur_pass_ctx = tvm.transform.PassContext.current() @@ -205,29 +219,32 @@ def enter_pass_ctx(self): assert cur_pass_ctx.instruments == None -def test_enter_pass_ctx_exception_global(capsys): +def test_enter_pass_ctx_exception_global(): @pass_instrument class PIBroken: def enter_pass_ctx(self): raise RuntimeError("Just a dummy error") cur_pass_ctx = tvm.transform.PassContext.current() - with pytest.raises(tvm.error.TVMError): + with pytest.raises(tvm.error.TVMError) as cm: cur_pass_ctx.override_instruments([PIBroken()]) + assert "Just a dummy error" in str(cm.exception) assert not cur_pass_ctx.instruments -def test_exit_pass_ctx_exception(capsys): +def test_exit_pass_ctx_exception(): + events = [] + @pass_instrument class PI: def __init__(self, id): self.id = id def exit_pass_ctx(self): - print(self.id + " exit ctx") + events.append(self.id + " exit ctx") def exit_pass_ctx(self): - print(self.id + " exit ctx") + events.append(self.id + " exit ctx") @pass_instrument class PIBroken(PI): @@ -235,15 +252,16 @@ def __init__(self, id): super().__init__(id) def exit_pass_ctx(self): - print(self.id + " exit ctx") + events.append(self.id + " exit ctx") raise RuntimeError("Just a dummy error") pass_ctx = tvm.transform.PassContext(instruments=[PI("%1"), PIBroken("%2"), PI("%3")]) - with pytest.raises(tvm.error.TVMError): + with pytest.raises(tvm.error.TVMError) as cm: with pass_ctx: pass + assert "Just a dummy error" in str(cm.exception) - assert "%1 exit ctx\n" "%2 exit ctx\n" == capsys.readouterr().out + assert "%1 exit ctx" "%2 exit ctx" == "".join(events) # Make sure we get correct PassContext cur_pass_ctx = tvm.transform.PassContext.current() @@ -251,223 +269,238 @@ def exit_pass_ctx(self): assert not cur_pass_ctx.instruments -def test_exit_pass_ctx_exception_global(capsys): +def test_exit_pass_ctx_exception_global(): @pass_instrument class PIBroken: def exit_pass_ctx(self): raise RuntimeError("Just a dummy error") cur_pass_ctx = tvm.transform.PassContext.current() - with pytest.raises(tvm.error.TVMError): + with pytest.raises(tvm.error.TVMError) as cm: cur_pass_ctx.override_instruments([PIBroken()]) cur_pass_ctx.override_instruments([PIBroken()]) + assert "Just a dummy error" in str(cm.exception) assert not cur_pass_ctx.instruments -def test_pass_exception(capsys): +def test_pass_exception(): + events = [] + @pass_instrument class PI: def enter_pass_ctx(self): - print("enter_pass_ctx") + events.append("enter_pass_ctx") def exit_pass_ctx(self): - print("exit_pass_ctx") + events.append("exit_pass_ctx") def should_run(self, mod, info): - print("should_run") + events.append("should_run") return True def run_before_pass(self, mod, info): - print("run_before_pass") + events.append("run_before_pass") def run_after_pass(self, mod, info): - print("run_after_pass") + events.append("run_after_pass") @tvm.transform.module_pass(opt_level=2) def transform(mod, ctx): - print("transform pass") + events.append("transform pass") raise RuntimeError("Just a dummy error") return mod mod = get_test_model() - with pytest.raises(tvm.error.TVMError): + with pytest.raises(tvm.error.TVMError) as cm: with tvm.transform.PassContext(instruments=[PI()]): mod = transform(mod) + assert "Just a dummy error" in str(cm.exception) assert ( - "enter_pass_ctx\n" - "should_run\n" - "run_before_pass\n" - "transform pass\n" - "exit_pass_ctx\n" == capsys.readouterr().out + "enter_pass_ctx" + "should_run" + "run_before_pass" + "transform pass" + "exit_pass_ctx" == "".join(events) ) -def test_should_run_exception(capsys): +def test_should_run_exception(): + events = [] + @pass_instrument class PI: def __init__(self, id): self.id = id def enter_pass_ctx(self): - print(self.id + " enter_pass_ctx") + events.append(self.id + " enter_pass_ctx") def exit_pass_ctx(self): - print(self.id + " exit_pass_ctx") + events.append(self.id + " exit_pass_ctx") def should_run(self, mod, info): - print(self.id + " should_run") + events.append(self.id + " should_run") raise RuntimeError("Just a dummy error") return True def run_before_pass(self, mod, info): - print(self.id + " run_before_pass") + events.append(self.id + " run_before_pass") def run_after_pass(self, mod, info): - print(self.id + " run_after_pass") + events.append(self.id + " run_after_pass") @tvm.transform.module_pass(opt_level=2) def transform(mod, ctx): - print("transform pass") + events.append("transform pass") return mod mod = get_test_model() - with pytest.raises(tvm.error.TVMError): + with pytest.raises(tvm.error.TVMError) as cm: with tvm.transform.PassContext(instruments=[PI("%1"), PI("%2")]): mod = transform(mod) + assert "Just a dummy error" in str(cm.exception) assert ( - "%1 enter_pass_ctx\n" - "%2 enter_pass_ctx\n" - "%1 should_run\n" - "%1 exit_pass_ctx\n" - "%2 exit_pass_ctx\n" == capsys.readouterr().out + "%1 enter_pass_ctx" + "%2 enter_pass_ctx" + "%1 should_run" + "%1 exit_pass_ctx" + "%2 exit_pass_ctx" == "".join(events) ) -def test_run_before_exception(capsys): +def test_run_before_exception(): + events = [] + @pass_instrument class PI: def __init__(self, id): self.id = id def enter_pass_ctx(self): - print(self.id + " enter_pass_ctx") + events.append(self.id + " enter_pass_ctx") def exit_pass_ctx(self): - print(self.id + " exit_pass_ctx") + events.append(self.id + " exit_pass_ctx") def should_run(self, mod, info): - print(self.id + " should_run") + events.append(self.id + " should_run") return True def run_before_pass(self, mod, info): - print(self.id + " run_before_pass") + events.append(self.id + " run_before_pass") raise RuntimeError("Just a dummy error") def run_after_pass(self, mod, info): - print(self.id + " run_after_pass") + events.append(self.id + " run_after_pass") @tvm.transform.module_pass(opt_level=2) def transform(mod, ctx): - print("transform pass") + events.append("transform pass") return mod mod = get_test_model() - with pytest.raises(tvm.error.TVMError): + with pytest.raises(tvm.error.TVMError) as cm: with tvm.transform.PassContext(instruments=[PI("%1"), PI("%2")]): mod = transform(mod) + assert "Just a dummy error" in str(cm.exception) assert ( - "%1 enter_pass_ctx\n" - "%2 enter_pass_ctx\n" - "%1 should_run\n" - "%2 should_run\n" - "%1 run_before_pass\n" - "%1 exit_pass_ctx\n" - "%2 exit_pass_ctx\n" == capsys.readouterr().out + "%1 enter_pass_ctx" + "%2 enter_pass_ctx" + "%1 should_run" + "%2 should_run" + "%1 run_before_pass" + "%1 exit_pass_ctx" + "%2 exit_pass_ctx" == "".join(events) ) -def test_run_after_exception(capsys): +def test_run_after_exception(): + events = [] + @pass_instrument class PI: def __init__(self, id): self.id = id def enter_pass_ctx(self): - print(self.id + " enter_pass_ctx") + events.append(self.id + " enter_pass_ctx") def exit_pass_ctx(self): - print(self.id + " exit_pass_ctx") + events.append(self.id + " exit_pass_ctx") def should_run(self, mod, info): - print(self.id + " should_run") + events.append(self.id + " should_run") return True def run_before_pass(self, mod, info): - print(self.id + " run_before_pass") + events.append(self.id + " run_before_pass") def run_after_pass(self, mod, info): - print(self.id + " run_after_pass") + events.append(self.id + " run_after_pass") raise RuntimeError("Just a dummy error") @tvm.transform.module_pass(opt_level=2) def transform(mod, ctx): - print("transform pass") + events.append("transform pass") return mod x, y = [tvm.relay.var(c, shape=(3, 4), dtype="float32") for c in "xy"] mod = tvm.IRModule.from_expr(tvm.relay.add(x, y)) - with pytest.raises(tvm.error.TVMError): + with pytest.raises(tvm.error.TVMError) as cm: with tvm.transform.PassContext(instruments=[PI("%1"), PI("%2")]): mod = transform(mod) + assert "Just a dummy error" in str(cm.exception) assert ( - "%1 enter_pass_ctx\n" - "%2 enter_pass_ctx\n" - "%1 should_run\n" - "%2 should_run\n" - "%1 run_before_pass\n" - "%2 run_before_pass\n" - "transform pass\n" - "%1 run_after_pass\n" - "%1 exit_pass_ctx\n" - "%2 exit_pass_ctx\n" == capsys.readouterr().out + "%1 enter_pass_ctx" + "%2 enter_pass_ctx" + "%1 should_run" + "%2 should_run" + "%1 run_before_pass" + "%2 run_before_pass" + "transform pass" + "%1 run_after_pass" + "%1 exit_pass_ctx" + "%2 exit_pass_ctx" == "".join(events) ) -def test_instrument_call_sequence(capsys): +def test_instrument_call_sequence(): + events = [] + @pass_instrument class PI: def __init__(self, id): self.id = id def enter_pass_ctx(self): - print(self.id + " enter_pass_ctx") + events.append(self.id + " enter_pass_ctx") def exit_pass_ctx(self): - print(self.id + " exit_pass_ctx") + events.append(self.id + " exit_pass_ctx") def should_run(self, mod, info): - print(" " + self.id + " should_run") + events.append(" " + self.id + " should_run") return True def run_before_pass(self, mod, info): - print(" " + self.id + " run_before_pass") + events.append(" " + self.id + " run_before_pass") def run_after_pass(self, mod, info): - print(" " + self.id + " run_after_pass") + events.append(" " + self.id + " run_after_pass") @tvm.transform.module_pass(opt_level=2) def transform1(mod, ctx): - print(" transform1 pass") + events.append(" transform1 pass") return mod @tvm.transform.module_pass(opt_level=2) def transform2(mod, ctx): - print(" transform2 pass") + events.append(" transform2 pass") return mod mod = get_test_model() @@ -476,22 +509,22 @@ def transform2(mod, ctx): mod = transform2(mod) assert ( - "%1 enter_pass_ctx\n" - "%2 enter_pass_ctx\n" - " %1 should_run\n" - " %2 should_run\n" - " %1 run_before_pass\n" - " %2 run_before_pass\n" - " transform1 pass\n" - " %1 run_after_pass\n" - " %2 run_after_pass\n" - " %1 should_run\n" - " %2 should_run\n" - " %1 run_before_pass\n" - " %2 run_before_pass\n" - " transform2 pass\n" - " %1 run_after_pass\n" - " %2 run_after_pass\n" - "%1 exit_pass_ctx\n" - "%2 exit_pass_ctx\n" == capsys.readouterr().out + "%1 enter_pass_ctx" + "%2 enter_pass_ctx" + " %1 should_run" + " %2 should_run" + " %1 run_before_pass" + " %2 run_before_pass" + " transform1 pass" + " %1 run_after_pass" + " %2 run_after_pass" + " %1 should_run" + " %2 should_run" + " %1 run_before_pass" + " %2 run_before_pass" + " transform2 pass" + " %1 run_after_pass" + " %2 run_after_pass" + "%1 exit_pass_ctx" + "%2 exit_pass_ctx" == "".join(events) ) From 3cbdcb053d810616cda403f3adbd63c736e375cd Mon Sep 17 00:00:00 2001 From: Zack Chen Date: Wed, 26 May 2021 01:46:56 +0800 Subject: [PATCH 23/24] Fix lint --- src/ir/transform.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ir/transform.cc b/src/ir/transform.cc index a8f6ced5bdd6..98355cbc0c0f 100644 --- a/src/ir/transform.cc +++ b/src/ir/transform.cc @@ -207,7 +207,7 @@ bool PassContext::InstrumentBeforePass(const IRModule& ir_module, const PassInfo bool should_run = true; if (!pass_required) { for (instrument::PassInstrument pi : pass_ctx_node->instruments) { - should_run &= pi->ShouldRun(ir_module, pass_info); + should_run &= pi->ShouldRun(ir_module, pass_info); } } From 291c03cea659631283fa872d3aed38000889dc97 Mon Sep 17 00:00:00 2001 From: Zack Chen Date: Thu, 27 May 2021 06:48:27 +0800 Subject: [PATCH 24/24] Fix EnterContext exception behavior --- src/ir/transform.cc | 10 ++++++++++ tests/python/relay/test_pass_instrument.py | 2 +- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/src/ir/transform.cc b/src/ir/transform.cc index 98355cbc0c0f..7760334af44c 100644 --- a/src/ir/transform.cc +++ b/src/ir/transform.cc @@ -169,14 +169,24 @@ PassContext PassContext::Create() { return PassContext(make_objectoperator->(); if (pass_ctx_node->instruments.defined()) { + Array enter_successes; try { for (instrument::PassInstrument pi : pass_ctx_node->instruments) { pi->EnterPassContext(); + enter_successes.push_back(pi); } } catch (const Error& e) { LOG(INFO) << "Pass instrumentation entering pass context failed."; LOG(INFO) << "Disable pass instrumentation."; pass_ctx_node->instruments.clear(); + + for (instrument::PassInstrument pi : enter_successes) { + LOG(INFO) << pi->name << " exiting PassContext ..."; + pi->ExitPassContext(); + LOG(INFO) << pi->name << " exited PassContext."; + } + enter_successes.clear(); + throw e; } } diff --git a/tests/python/relay/test_pass_instrument.py b/tests/python/relay/test_pass_instrument.py index 740075d7a789..86283fd31819 100644 --- a/tests/python/relay/test_pass_instrument.py +++ b/tests/python/relay/test_pass_instrument.py @@ -211,7 +211,7 @@ def enter_pass_ctx(self): pass assert "Just a dummy error" in str(cm.execption) - assert "%1 enter ctx" "%2 enter ctx" == "".join(events) + assert "%1 enter ctx" "%2 enter ctx" "%1 exit ctx" == "".join(events) # Make sure we get correct PassContext cur_pass_ctx = tvm.transform.PassContext.current()