diff --git a/include/tvm/ir/instrument.h b/include/tvm/ir/instrument.h new file mode 100644 index 000000000000..1b0e9a9ea50e --- /dev/null +++ b/include/tvm/ir/instrument.h @@ -0,0 +1,157 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/ir/instrument.h + * + * This file introduces a pass instrument infrastructure, inspired by LLVM and MLIR. + * It inserts instrumentation points around passes. + */ +#ifndef TVM_IR_INSTRUMENT_H_ +#define TVM_IR_INSTRUMENT_H_ + +#include +#include + +#include +#include + +namespace tvm { + +class IRModule; + +// Forward class for PassInstrumentNode methods +namespace transform { +class PassInfo; +} // namespace transform + +namespace instrument { + +/*! + * \brief PassInstrumentNode forms an instrument implementation. + * It provides API for users to register callbacks at different instrumentation points. + * + * Within a PassContext, call sequence of a PassInstrument implementation is like: + * + * with PassContext(instruments=[pi]): # pi = a PassInstrument implementation + * pi.EnterPassContext() + * + * if pi.ShouldRun(Pass1): + * pi.RunBeforePass() + * Pass1() + * pi.RunAfterPass() + * + * if pi.ShouldRun(Pass2): + * pi.RunBeforePass() + * Pass2() + * pi.RunAfterPass() + * + * pi.ExitPassContext() + * + * `EnterPassContext` and `ExitPassContext` are only called once when entering/exiting a + * PassContext. `ShouldRun`, `RunBeforePass` and `RunAfterPass` are called multiple times depending + * on how many passes. + * + * If there are multiple pass instrumentations provided, the instrument points are the same. + * PassInstrument implementations' callbacks are called in order: + * + * with PassContext(instruments=[pi1, pi2]): # pi1, pi2 = two distinct PassInstrument impls + * pi.EnterPassContext() for pi in instruments + * + * should_run = all([pi.ShoudRun(Pass1) for pi in instruments)]) + * if (should_run) + * pi.RunBeforePass() for pi in instruments + * Pass1() + * pi.RunAfterPass() for pi in instruments + * + * should_run = all([pi.ShouldRun(Pass2) for pi in instruments)]) + * if (should_run) + * pi.RunBeforePass() for pi in instruments + * Pass2() + * pi.RunAfterPass() for pi in instruments + * + * pi.ExitPassContext() for pi in instruments + * + * Note: + * 1. Assume there is no dependency between PassInstrument implementations in `instruments` . + * 2. `EnterPassContext` and `ExitPassContext` have `with` behavior (see PassContext and its FFI): + * If there is any exception raised in `ShouldRun()`, `RunBeforePass()`, `RunAfterPass()` and + * `Pass()`, `ExitPassContext()` is still called. + * 3. In mutiple PassInstrument instances scenario, callbacks are called in order: + * If one throws exceptions, remainings will not be called. + * + * \sa PassInstrument + * \sa src/ir/transform.cc + */ +class PassInstrumentNode : public Object { + public: + /*! \brief Name of this pass instrument object. */ + String name; + + virtual ~PassInstrumentNode() {} + + /*! \brief Instrument when entering PassContext. Called once within a PassContext. */ + virtual void EnterPassContext() const = 0; + + /*! \brief Instrument when exiting PassContext. Called once within a PassContext. */ + virtual void ExitPassContext() const = 0; + + /*! + * \brief Determine whether to run the pass or not. Called multiple times depend on number of + * passes. + * \param mod The module that an optimization pass runs on. + * \param info The pass information. + * + * \return true to run the pass; false to skip the pass. + */ + virtual bool ShouldRun(const IRModule& mod, const transform::PassInfo& info) const = 0; + + /*! + * \brief Instrument before pass run. Called multiple times depend on number of passes. + * \param mod The module that an optimization pass runs on. + * \param info The pass information. + */ + virtual void RunBeforePass(const IRModule& mod, const transform::PassInfo& info) const = 0; + + /*! + * \brief Instrument after pass run. Called multiple time depend on number of passes. + * \param mod The module that an optimization pass runs on. + * \param info The pass information. + */ + virtual void RunAfterPass(const IRModule& mod, const transform::PassInfo& info) const = 0; + + void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); } + + static constexpr const char* _type_key = "instrument.PassInstrument"; + TVM_DECLARE_BASE_OBJECT_INFO(PassInstrumentNode, Object); +}; + +/*! + * \brief Managed reference class for PassInstrumentNode + * \sa PassInstrumentNode + */ +class PassInstrument : public ObjectRef { + public: + TVM_DEFINE_OBJECT_REF_METHODS(PassInstrument, ObjectRef, PassInstrumentNode); +}; + +} // namespace instrument +} // namespace tvm + +#endif // TVM_IR_INSTRUMENT_H_ diff --git a/include/tvm/ir/transform.h b/include/tvm/ir/transform.h index 50c6f8dd8c3a..849eda6cd248 100644 --- a/include/tvm/ir/transform.h +++ b/include/tvm/ir/transform.h @@ -58,6 +58,7 @@ #include #include +#include #include #include #include @@ -68,15 +69,6 @@ namespace tvm { namespace transform { -// Forward declare for TraceFunc. -class PassInfo; - -/*! - * \brief A callback for tracing passes, useful for debugging and logging. - */ -using TraceFunc = - runtime::TypedPackedFunc; - /*! * \brief PassContextNode contains the information that a pass can rely on, * such as analysis results. @@ -95,8 +87,9 @@ class PassContextNode : public Object { mutable Optional diag_ctx; /*! \brief Pass specific configurations. */ Map config; - /*! \brief Trace function to be invoked before and after each pass. */ - TraceFunc trace_func; + + /*! \brief A list of pass instrument implementations. */ + Array instruments; PassContextNode() = default; @@ -134,6 +127,7 @@ class PassContextNode : public Object { v->Visit("opt_level", &opt_level); v->Visit("required_pass", &required_pass); v->Visit("disabled_pass", &disabled_pass); + v->Visit("instruments", &instruments); v->Visit("config", &config); v->Visit("diag_ctx", &diag_ctx); } @@ -189,12 +183,40 @@ class PassContext : public ObjectRef { TVM_DLL static PassContext Current(); /*! - * \brief Apply the tracing functions of the context to the module, with the info. - * \param module The IRModule to trace. + * \brief Call instrument implementations' callbacks when entering PassContext. + * The callbacks are called in order, and if one raises an exception, the rest will not be + * called. + */ + TVM_DLL void InstrumentEnterPassContext(); + + /*! + * \brief Call instrument implementations' callbacks when exiting PassContext. + * The callbacks are called in order, and if one raises an exception, the rest will not be + * called. + */ + TVM_DLL void InstrumentExitPassContext(); + + /*! + * \brief Call instrument implementations' callbacks before a pass run. + * The callbacks are called in order, and if one raises an exception, the rest will not be + * called. + * + * \param mod The module that an optimization pass runs on. * \param info The pass information. - * \param is_before Indicated whether the tracing is before or after a pass. + * + * \return false: the pass is skipped; true: the pass runs. */ - TVM_DLL void Trace(const IRModule& module, const PassInfo& info, bool is_before) const; + TVM_DLL bool InstrumentBeforePass(const IRModule& mod, const PassInfo& info) const; + + /*! + * \brief Call instrument implementations callbacks after a pass run. + * The callbacks are called in order, and if one raises an exception, the rest will not be + * called. + * + * \param mod The module that an optimization pass runs on. + * \param info The pass information. + */ + TVM_DLL void InstrumentAfterPass(const IRModule& mod, const PassInfo& info) const; /*! * \brief Check whether a pass is enabled. @@ -275,7 +297,7 @@ class PassInfoNode : public Object { TVM_DECLARE_FINAL_OBJECT_INFO(PassInfoNode, Object); }; -/* +/*! * \brief Managed reference class for PassInfoNode * \sa PassInfoNode */ diff --git a/python/tvm/__init__.py b/python/tvm/__init__.py index 0adad82d9bec..77630730f03a 100644 --- a/python/tvm/__init__.py +++ b/python/tvm/__init__.py @@ -40,6 +40,7 @@ # tvm.ir from .ir import IRModule from .ir import transform +from .ir import instrument from .ir import container from . import ir @@ -67,7 +68,6 @@ # Contrib initializers from .contrib import rocm as _rocm, nvcc as _nvcc, sdaccel as _sdaccel - # NOTE: This file should be python2 compatible so we can # raise proper error message when user run the package using # an older version of the python diff --git a/python/tvm/ir/__init__.py b/python/tvm/ir/__init__.py index 4bc7f1ae4468..70c5988d6316 100644 --- a/python/tvm/ir/__init__.py +++ b/python/tvm/ir/__init__.py @@ -31,4 +31,5 @@ from .container import Array, Map from . import transform +from . import instrument from . import diagnostics diff --git a/tests/python/relay/test_pass_profiler.py b/python/tvm/ir/_ffi_instrument_api.py similarity index 51% rename from tests/python/relay/test_pass_profiler.py rename to python/tvm/ir/_ffi_instrument_api.py index acf6c8c50aff..bf62caf30e5a 100644 --- a/tests/python/relay/test_pass_profiler.py +++ b/python/tvm/ir/_ffi_instrument_api.py @@ -14,28 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import tvm -import tvm.relay -from tvm.relay import op +"""FFI APIs for tvm.instrument""" +import tvm._ffi - -def test_pass_profiler(): - x, y, z = [tvm.relay.var(c, shape=(3, 4), dtype="float32") for c in "xyz"] - e1 = op.add(x, y) - e2 = op.subtract(x, z) - e3 = op.multiply(e1, e1 / e2) - mod = tvm.IRModule.from_expr(e3 + e2) - - tvm.transform.enable_pass_profiling() - - mod = tvm.relay.transform.AnnotateSpans()(mod) - mod = tvm.relay.transform.ToANormalForm()(mod) - mod = tvm.relay.transform.InferType()(mod) - - profiles = tvm.transform.render_pass_profiles() - assert "AnnotateSpans" in profiles - assert "ToANormalForm" in profiles - assert "InferType" in profiles - - tvm.transform.clear_pass_profiles() - tvm.transform.disable_pass_profiling() +tvm._ffi._init_api("instrument", __name__) diff --git a/python/tvm/ir/instrument.py b/python/tvm/ir/instrument.py new file mode 100644 index 000000000000..c322f2bef3fc --- /dev/null +++ b/python/tvm/ir/instrument.py @@ -0,0 +1,159 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name,unused-argument +"""Common pass instrumentation across IR variants.""" +import inspect +import functools + +import tvm._ffi +import tvm.runtime + +from . import _ffi_instrument_api + + +@tvm._ffi.register_object("instrument.PassInstrument") +class PassInstrument(tvm.runtime.Object): + """A pass instrument implementation. + + Users don't need to interact with this class directly. + Instead, a `PassInstrument` instance should be created through `pass_instrument`. + + See Also + -------- + `pass_instrument` + """ + + +def _wrap_class_pass_instrument(pi_cls): + """Wrap a python class as pass instrument""" + + class PyPassInstrument(PassInstrument): + """Internal wrapper class to create a class instance.""" + + def __init__(self, *args, **kwargs): + # initialize handle in case pi_cls creation failed. + self.handle = None + inst = pi_cls(*args, **kwargs) + + # check method declartion within class, if found, wrap it. + def create_method(method): + if hasattr(inst, method) and inspect.ismethod(getattr(inst, method)): + + def func(*args): + return getattr(inst, method)(*args) + + func.__name__ = "_" + method + return func + return None + + # create runtime pass instrument object + # reister instance's enter_pass_ctx,exit_pass_ctx, should_run, run_before_pass and + # run_after_pass methods to it if present. + self.__init_handle_by_constructor__( + _ffi_instrument_api.PassInstrument, + pi_cls.__name__, + create_method("enter_pass_ctx"), + create_method("exit_pass_ctx"), + create_method("should_run"), + create_method("run_before_pass"), + create_method("run_after_pass"), + ) + + self._inst = inst + + def __getattr__(self, name): + # fall back to instance attribute if there is not any + return self._inst.__getattribute__(name) + + functools.update_wrapper(PyPassInstrument.__init__, pi_cls.__init__) + PyPassInstrument.__name__ = pi_cls.__name__ + PyPassInstrument.__doc__ = pi_cls.__doc__ + PyPassInstrument.__module__ = pi_cls.__module__ + return PyPassInstrument + + +def pass_instrument(pi_cls=None): + """Decorate a pass instrument. + + Parameters + ---------- + pi_class : + + Examples + -------- + The following code block decorates a pass instrument class. + + .. code-block:: python + @tvm.instrument.pass_instrument + class SkipPass: + def __init__(self, skip_pass_name): + self.skip_pass_name = skip_pass_name + + # Uncomment to customize + # def enter_pass_ctx(self): + # pass + + # Uncomment to customize + # def exit_pass_ctx(self): + # pass + + # If pass name contains keyword, skip it by return False. (return True: not skip) + def should_run(self, mod, pass_info) + if self.skip_pass_name in pass_info.name: + return False + return True + + # Uncomment to customize + # def run_before_pass(self, mod, pass_info): + # pass + + # Uncomment to customize + # def run_after_pass(self, mod, pass_info): + # pass + + skip_annotate = SkipPass("AnnotateSpans") + with tvm.transform.PassContext(instruments=[skip_annotate]): + tvm.relay.build(mod, "llvm") + """ + + def create_pass_instrument(pi_cls): + if not inspect.isclass(pi_cls): + raise TypeError("pi_cls must be a class") + + return _wrap_class_pass_instrument(pi_cls) + + if pi_cls: + return create_pass_instrument(pi_cls) + return create_pass_instrument + + +@tvm._ffi.register_object("instrument.PassInstrument") +class PassTimingInstrument(tvm.runtime.Object): + """A wrapper to create a passes time instrument that implemented in C++""" + + def __init__(self): + self.__init_handle_by_constructor__(_ffi_instrument_api.MakePassTimingInstrument) + + @staticmethod + def render(): + """Retrieve rendered time profile result + Returns + ------- + string : string + The rendered string result of time profiles + """ + return _ffi_instrument_api.RenderTimePassProfiles() diff --git a/python/tvm/ir/transform.py b/python/tvm/ir/transform.py index 36e06eeb8b23..3a3ac16be677 100644 --- a/python/tvm/ir/transform.py +++ b/python/tvm/ir/transform.py @@ -65,12 +65,20 @@ class PassContext(tvm.runtime.Object): disabled_pass : Optional[Union[List[str], Set[str], Tuple[str]]] The list of passes that are disabled. + instruments : Optional[Sequence[PassInstrument]] + The list of pass instrument implementations. + config : Optional[Dict[str, Object]] Additional configurations for specific passes. """ def __init__( - self, opt_level=2, required_pass=None, disabled_pass=None, trace=None, config=None + self, + opt_level=2, + required_pass=None, + disabled_pass=None, + instruments=None, + config=None, ): required = list(required_pass) if required_pass else [] if not isinstance(required, (list, tuple)): @@ -80,9 +88,13 @@ def __init__( if not isinstance(disabled, (list, tuple)): raise TypeError("disabled_pass is expected to be the type of " + "list/tuple/set.") + instruments = list(instruments) if instruments else [] + if not isinstance(instruments, (list, tuple)): + raise TypeError("instruments is expected to be the type of " + "list/tuple/set.") + config = config if config else None self.__init_handle_by_constructor__( - _ffi_transform_api.PassContext, opt_level, required, disabled, trace, config + _ffi_transform_api.PassContext, opt_level, required, disabled, instruments, config ) def __enter__(self): @@ -92,6 +104,17 @@ def __enter__(self): def __exit__(self, ptype, value, trace): _ffi_transform_api.ExitPassContext(self) + def override_instruments(self, instruments): + """Override instruments within this PassContext. + + If there are existing instruments, their exit_pass_ctx callbacks are called. + Then switching to new instruments and calling new enter_pass_ctx callbacks. + + instruments : Sequence[PassInstrument] + The list of pass instrument implementations. + """ + _ffi_transform_api.OverrideInstruments(self, instruments) + @staticmethod def current(): """Return the current pass context.""" @@ -189,6 +212,7 @@ def __init__(self, *args, **kwargs): # initialize handle in cass pass_cls creation failed.fg self.handle = None inst = pass_cls(*args, **kwargs) + # it is important not to capture self to # avoid a cyclic dependency def _pass_func(mod, ctx): @@ -330,26 +354,3 @@ def PrintIR(header="", show_meta_data=False): The pass """ return _ffi_transform_api.PrintIR(header, show_meta_data) - - -def render_pass_profiles(): - """Returns a string render of the pass profiling data. The format of each output line is - `{name}: {time} [{time excluding sub-passes}] ({% of total}; {% of parent})`. - The indentation of each line corresponds to nesting of passes. - """ - return _ffi_transform_api.render_pass_profiles() - - -def clear_pass_profiles(): - """Clears all stored pass profiling data.""" - _ffi_transform_api.clear_pass_profiles() - - -def enable_pass_profiling(): - """Enables pass profiling.""" - _ffi_transform_api.enable_pass_profiling() - - -def disable_pass_profiling(): - """Disables pass profiling.""" - _ffi_transform_api.disable_pass_profiling() diff --git a/src/ir/instrument.cc b/src/ir/instrument.cc new file mode 100644 index 000000000000..795e5b8cb542 --- /dev/null +++ b/src/ir/instrument.cc @@ -0,0 +1,341 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/ir/instrument.cc + * \brief Infrastructure for instrumentation. + */ +#include +#include +#include +#include +#include + +#include + +namespace tvm { +namespace instrument { + +/*! + * \brief Base PassInstrument implementation + * \sa BasePassInstrument + */ +class BasePassInstrumentNode : public PassInstrumentNode { + public: + /*! \brief Callback to run when entering PassContext. */ + runtime::TypedPackedFunc enter_pass_ctx_callback; + /*! \brief Callback to run when exiting PassContext. */ + runtime::TypedPackedFunc exit_pass_ctx_callback; + + /*! \brief Callback determines whether to run a pass or not. */ + runtime::TypedPackedFunc should_run_callback; + + /*! \brief Callback to run before a pass. */ + runtime::TypedPackedFunc + run_before_pass_callback; + /*! \brief Callback to run after a pass. */ + runtime::TypedPackedFunc + run_after_pass_callback; + + /*! \brief Instrument when entering PassContext. */ + void EnterPassContext() const final; + + /*! \brief Instrument when exiting PassContext. */ + void ExitPassContext() const final; + + /*! + * \brief Determine whether to run the pass or not. + * \param mod The module that an optimization pass runs on. + * \param info The pass information. + * + * \return true to run the pass; false to skip the pass. + */ + bool ShouldRun(const IRModule&, const transform::PassInfo& info) const final; + + /*! + * \brief Instrument before pass run. + * \param mod The module that an optimization pass runs on. + * \param info The pass information. + */ + void RunBeforePass(const IRModule& mod, const transform::PassInfo& info) const final; + + /*! + * \brief Instrument after pass run. + * + * \param mod The module that an optimization pass runs on. + * \param info The pass information. + */ + void RunAfterPass(const IRModule& mod, const transform::PassInfo& info) const final; + + static constexpr const char* _type_key = "instrument.PassInstrument"; + TVM_DECLARE_FINAL_OBJECT_INFO(BasePassInstrumentNode, PassInstrumentNode); +}; + +/*! + * \brief Managed reference class for BasePassInstrumentNode + * \sa BasePassInstrumentNode + */ +class BasePassInstrument : public PassInstrument { + public: + /*! + * \brief Constructor + * + * \param name Name for this instrumentation. + * + * + * \param enter_pass_ctx_callback Callback to call when entering pass context. + * \param exit_pass_ctx_callback Callback to call when exiting pass context. + * + * \param should_run_callback Callback to determine whether pass should run. (return true: enable; + * return false: disable) + * + * \param run_before_pass_callback Callback to call before a pass run. + * \param run_after_pass_callback Callback to call after a pass run. + */ + TVM_DLL BasePassInstrument( + String name, runtime::TypedPackedFunc enter_pass_ctx_callback, + runtime::TypedPackedFunc exit_pass_ctx_callback, + runtime::TypedPackedFunc + should_run_callback, + runtime::TypedPackedFunc + run_before_pass_callback, + runtime::TypedPackedFunc + run_after_pass_callback); + + TVM_DEFINE_OBJECT_REF_METHODS(BasePassInstrument, PassInstrument, BasePassInstrumentNode); +}; + +BasePassInstrument::BasePassInstrument( + String name, runtime::TypedPackedFunc enter_pass_ctx_callback, + runtime::TypedPackedFunc exit_pass_ctx_callback, + runtime::TypedPackedFunc should_run_callback, + runtime::TypedPackedFunc + run_before_pass_callback, + runtime::TypedPackedFunc + run_after_pass_callback) { + auto pi = make_object(); + pi->name = std::move(name); + + pi->enter_pass_ctx_callback = std::move(enter_pass_ctx_callback); + pi->exit_pass_ctx_callback = std::move(exit_pass_ctx_callback); + + pi->should_run_callback = std::move(should_run_callback); + + pi->run_before_pass_callback = std::move(run_before_pass_callback); + pi->run_after_pass_callback = std::move(run_after_pass_callback); + + data_ = std::move(pi); +} + +void BasePassInstrumentNode::EnterPassContext() const { + if (enter_pass_ctx_callback != nullptr) { + enter_pass_ctx_callback(); + } +} + +void BasePassInstrumentNode::ExitPassContext() const { + if (exit_pass_ctx_callback != nullptr) { + exit_pass_ctx_callback(); + } +} + +bool BasePassInstrumentNode::ShouldRun(const IRModule& ir_module, + const transform::PassInfo& pass_info) const { + if (should_run_callback == nullptr) { + return true; + } + + return should_run_callback(ir_module, pass_info); +} + +void BasePassInstrumentNode::RunBeforePass(const IRModule& ir_module, + const transform::PassInfo& pass_info) const { + if (run_before_pass_callback != nullptr) { + run_before_pass_callback(ir_module, pass_info); + } +} + +void BasePassInstrumentNode::RunAfterPass(const IRModule& ir_module, + const transform::PassInfo& pass_info) const { + if (run_after_pass_callback != nullptr) { + run_after_pass_callback(ir_module, pass_info); + } +} + +TVM_REGISTER_NODE_TYPE(BasePassInstrumentNode); + +TVM_REGISTER_GLOBAL("instrument.PassInstrument") + .set_body_typed( + [](String name, runtime::TypedPackedFunc enter_pass_ctx, + runtime::TypedPackedFunc exit_pass_ctx, + runtime::TypedPackedFunc should_run, + runtime::TypedPackedFunc + run_before_pass, + runtime::TypedPackedFunc + run_after_pass) { + return BasePassInstrument(name, enter_pass_ctx, exit_pass_ctx, should_run, + run_before_pass, run_after_pass); + }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << node->name; + }); + +/*! \brief PassProfile stores profiling information for a given pass and its sub-passes. */ +struct PassProfile { + // TODO(@altanh): expose PassProfile through TVM Object API + using Clock = std::chrono::steady_clock; + using Duration = std::chrono::duration; + using Time = std::chrono::time_point; + + /*! \brief The name of the pass being profiled. */ + String name; + /*! \brief The time when the pass was entered. */ + Time start; + /*! \brief The time when the pass completed. */ + Time end; + /*! \brief The total duration of the pass, i.e. end - start. */ + Duration duration; + /*! \brief PassProfiles for all sub-passes invoked during the execution of the pass. */ + std::vector children; + + explicit PassProfile(String name) + : name(name), start(Clock::now()), end(Clock::now()), children() {} + + /*! \brief Gets the PassProfile of the currently executing pass. */ + static PassProfile* Current(); + /*! \brief Pushes a new PassProfile with the given pass name. */ + static void EnterPass(String name); + /*! \brief Pops the current PassProfile. */ + static void ExitPass(); +}; + +struct PassProfileThreadLocalEntry { + /*! \brief The placeholder top-level PassProfile. */ + PassProfile root; + /*! \brief The stack of PassProfiles for nested passes currently running. */ + std::stack profile_stack; + + PassProfileThreadLocalEntry() : root("root") {} +}; + +/*! \brief Thread local store to hold the pass profiling data. */ +typedef dmlc::ThreadLocalStore PassProfileThreadLocalStore; + +void PassProfile::EnterPass(String name) { + PassProfile* cur = PassProfile::Current(); + cur->children.emplace_back(name); + PassProfileThreadLocalStore::Get()->profile_stack.push(&cur->children.back()); +} + +void PassProfile::ExitPass() { + PassProfile* cur = PassProfile::Current(); + ICHECK_NE(cur->name, "root") << "mismatched enter/exit for pass profiling"; + cur->end = PassProfile::Clock::now(); + cur->duration = std::chrono::duration_cast(cur->end - cur->start); + PassProfileThreadLocalStore::Get()->profile_stack.pop(); +} + +PassProfile* PassProfile::Current() { + PassProfileThreadLocalEntry* entry = PassProfileThreadLocalStore::Get(); + if (!entry->profile_stack.empty()) { + return entry->profile_stack.top(); + } else { + return &entry->root; + } +} + +String RenderPassProfiles() { + PassProfileThreadLocalEntry* entry = PassProfileThreadLocalStore::Get(); + CHECK(entry->profile_stack.empty()) << "cannot print pass profile while still in a pass!"; + + if (entry->root.children.empty()) { + LOG(WARNING) << "no passes have been profiled, did you enable pass profiling?"; + return String(); + } + + // (depth, parent_duration, pass) + std::stack> profiles; + + // push top level passes + PassProfile::Duration top_dur(0); + for (auto it = entry->root.children.begin(); it != entry->root.children.end(); ++it) { + top_dur += it->duration; + } + for (auto it = entry->root.children.rbegin(); it != entry->root.children.rend(); ++it) { + profiles.push(std::make_tuple(0, top_dur, &*it)); + } + + std::ostringstream os; + os << std::fixed; + + while (profiles.size() > 0) { + size_t depth; + PassProfile::Duration parent_duration; + PassProfile* profile; + std::tie(depth, parent_duration, profile) = profiles.top(); + profiles.pop(); + + // indent depth + for (size_t i = 0; i < depth; ++i) { + os << "\t"; + } + + // calculate time spent in pass itself (excluding sub-passes), and push children + PassProfile::Duration self_duration = profile->duration; + for (auto it = profile->children.rbegin(); it != profile->children.rend(); ++it) { + self_duration -= it->duration; + profiles.push(std::make_tuple(depth + 1, profile->duration, &*it)); + } + + double parent_pct = profile->duration.count() / parent_duration.count() * 100.0; + double total_pct = profile->duration.count() / top_dur.count() * 100.0; + + os << profile->name << ": "; + os << std::setprecision(0); + os << profile->duration.count() << "us [" << self_duration.count() << "us] "; + os << std::setprecision(2) << "(" << total_pct << "%; " << parent_pct << "%)\n"; + } + + return os.str(); +} + +TVM_REGISTER_GLOBAL("instrument.RenderTimePassProfiles").set_body_typed(RenderPassProfiles); + +TVM_REGISTER_GLOBAL("instrument.MakePassTimingInstrument").set_body_typed([]() { + auto run_before_pass = [](const IRModule&, const transform::PassInfo& pass_info) { + PassProfile::EnterPass(pass_info->name); + return true; + }; + + auto run_after_pass = [](const IRModule&, const transform::PassInfo& pass_info) { + PassProfile::ExitPass(); + }; + + auto exit_pass_ctx = []() { PassProfileThreadLocalStore::Get()->root.children.clear(); }; + + return BasePassInstrument("PassTimingInstrument", + /* enter_pass_ctx */ nullptr, exit_pass_ctx, /* should_run */ nullptr, + run_before_pass, run_after_pass); +}); + +} // namespace instrument +} // namespace tvm diff --git a/src/ir/transform.cc b/src/ir/transform.cc index 48f13bc81df4..7760334af44c 100644 --- a/src/ir/transform.cc +++ b/src/ir/transform.cc @@ -56,6 +56,8 @@ struct PassContextThreadLocalEntry { typedef dmlc::ThreadLocalStore RelayPassContextThreadLocalStore; void PassContext::EnterWithScope() { + InstrumentEnterPassContext(); + PassContextThreadLocalEntry* entry = RelayPassContextThreadLocalStore::Get(); entry->context_stack.push(*this); } @@ -65,6 +67,8 @@ void PassContext::ExitWithScope() { ICHECK(!entry->context_stack.empty()); ICHECK(entry->context_stack.top().same_as(*this)); entry->context_stack.pop(); + + InstrumentExitPassContext(); } PassContext PassContext::Current() { @@ -162,170 +166,96 @@ void PassContext::RegisterConfigOption(const char* key, uint32_t value_type_inde PassContext PassContext::Create() { return PassContext(make_object()); } -void PassContext::Trace(const IRModule& module, const PassInfo& info, bool is_before) const { +void PassContext::InstrumentEnterPassContext() { auto pass_ctx_node = this->operator->(); - if (pass_ctx_node->trace_func != nullptr) { - pass_ctx_node->trace_func(module, info, is_before); + if (pass_ctx_node->instruments.defined()) { + Array enter_successes; + try { + for (instrument::PassInstrument pi : pass_ctx_node->instruments) { + pi->EnterPassContext(); + enter_successes.push_back(pi); + } + } catch (const Error& e) { + LOG(INFO) << "Pass instrumentation entering pass context failed."; + LOG(INFO) << "Disable pass instrumentation."; + pass_ctx_node->instruments.clear(); + + for (instrument::PassInstrument pi : enter_successes) { + LOG(INFO) << pi->name << " exiting PassContext ..."; + pi->ExitPassContext(); + LOG(INFO) << pi->name << " exited PassContext."; + } + enter_successes.clear(); + + throw e; + } } } -class ModulePass; - -/*! \brief PassProfile stores profiling information for a given pass and its sub-passes. */ -struct PassProfile { - // TODO(@altanh): expose PassProfile through TVM Object API - using Clock = std::chrono::steady_clock; - using Duration = std::chrono::duration; - using Time = std::chrono::time_point; - - /*! \brief The name of the pass being profiled. */ - String name; - /*! \brief The time when the pass was entered. */ - Time start; - /*! \brief The time when the pass completed. */ - Time end; - /*! \brief The total duration of the pass, i.e. end - start. */ - Duration duration; - /*! \brief PassProfiles for all sub-passes invoked during the execution of the pass. */ - std::vector children; - - explicit PassProfile(String name) - : name(name), start(Clock::now()), end(Clock::now()), children() {} - - /*! \brief Gets the PassProfile of the currently executing pass. */ - static PassProfile* Current(); - /*! \brief Pushes a new PassProfile with the given pass name. */ - static void EnterPass(String name); - /*! \brief Pops the current PassProfile. */ - static void ExitPass(); -}; - -struct PassProfileThreadLocalEntry { - /*! \brief The placeholder top-level PassProfile. */ - PassProfile root; - /*! \brief The stack of PassProfiles for nested passes currently running. */ - std::stack profile_stack; - /*! \brief Whether or not pass profiling is active. */ - bool active; - - PassProfileThreadLocalEntry() : root("root"), active(false) {} -}; +void PassContext::InstrumentExitPassContext() { + auto pass_ctx_node = this->operator->(); + if (pass_ctx_node->instruments.defined()) { + try { + for (instrument::PassInstrument pi : pass_ctx_node->instruments) { + pi->ExitPassContext(); + } + } catch (const Error& e) { + LOG(INFO) << "Pass instrumentation exiting pass context failed."; + pass_ctx_node->instruments.clear(); + throw e; + } + } +} -/*! \brief Thread local store to hold the pass profiling data. */ -typedef dmlc::ThreadLocalStore PassProfileThreadLocalStore; +bool PassContext::InstrumentBeforePass(const IRModule& ir_module, const PassInfo& pass_info) const { + auto pass_ctx_node = this->operator->(); + if (!pass_ctx_node->instruments.defined()) { + return true; + } -void PassProfile::EnterPass(String name) { - if (!PassProfileThreadLocalStore::Get()->active) return; - PassProfile* cur = PassProfile::Current(); - cur->children.emplace_back(name); - PassProfileThreadLocalStore::Get()->profile_stack.push(&cur->children.back()); -} + const bool pass_required = PassArrayContains(pass_ctx_node->required_pass, pass_info->name); + bool should_run = true; + if (!pass_required) { + for (instrument::PassInstrument pi : pass_ctx_node->instruments) { + should_run &= pi->ShouldRun(ir_module, pass_info); + } + } -void PassProfile::ExitPass() { - if (!PassProfileThreadLocalStore::Get()->active) return; - PassProfile* cur = PassProfile::Current(); - ICHECK_NE(cur->name, "root") << "mismatched enter/exit for pass profiling"; - cur->end = std::move(PassProfile::Clock::now()); - cur->duration = std::chrono::duration_cast(cur->end - cur->start); - PassProfileThreadLocalStore::Get()->profile_stack.pop(); + if (should_run) { + for (instrument::PassInstrument pi : pass_ctx_node->instruments) { + pi->RunBeforePass(ir_module, pass_info); + } + } + return should_run; } -PassProfile* PassProfile::Current() { - PassProfileThreadLocalEntry* entry = PassProfileThreadLocalStore::Get(); - if (!entry->profile_stack.empty()) { - return entry->profile_stack.top(); - } else { - return &entry->root; +void PassContext::InstrumentAfterPass(const IRModule& ir_module, const PassInfo& pass_info) const { + auto pass_ctx_node = this->operator->(); + if (pass_ctx_node->instruments.defined()) { + for (instrument::PassInstrument pi : pass_ctx_node->instruments) { + pi->RunAfterPass(ir_module, pass_info); + } } } IRModule Pass::operator()(IRModule mod) const { - const PassNode* node = operator->(); - ICHECK(node != nullptr); - PassProfile::EnterPass(node->Info()->name); - auto ret = node->operator()(std::move(mod)); - PassProfile::ExitPass(); - return std::move(ret); + return this->operator()(std::move(mod), PassContext::Current()); } IRModule Pass::operator()(IRModule mod, const PassContext& pass_ctx) const { const PassNode* node = operator->(); ICHECK(node != nullptr); - PassProfile::EnterPass(node->Info()->name); + const PassInfo& pass_info = node->Info(); + if (!pass_ctx.InstrumentBeforePass(mod, pass_info)) { + DLOG(INFO) << "Skipping pass : " << pass_info->name + << " with opt level: " << pass_info->opt_level; + return mod; + } auto ret = node->operator()(std::move(mod), pass_ctx); - PassProfile::ExitPass(); + pass_ctx.InstrumentAfterPass(ret, pass_info); return std::move(ret); } -String RenderPassProfiles() { - PassProfileThreadLocalEntry* entry = PassProfileThreadLocalStore::Get(); - CHECK(entry->profile_stack.empty()) << "cannot print pass profile while still in a pass!"; - - if (entry->root.children.empty()) { - LOG(WARNING) << "no passes have been profiled, did you enable pass profiling?"; - return String(); - } - - // (depth, parent_duration, pass) - std::stack> profiles; - - // push top level passes - PassProfile::Duration top_dur(0); - for (auto it = entry->root.children.begin(); it != entry->root.children.end(); ++it) { - top_dur += it->duration; - } - for (auto it = entry->root.children.rbegin(); it != entry->root.children.rend(); ++it) { - profiles.push(std::make_tuple(0, top_dur, &*it)); - } - - std::ostringstream os; - os << std::fixed; - - while (profiles.size() > 0) { - size_t depth; - PassProfile::Duration parent_duration; - PassProfile* profile; - std::tie(depth, parent_duration, profile) = profiles.top(); - profiles.pop(); - - // indent depth - for (size_t i = 0; i < depth; ++i) { - os << "\t"; - } - - // calculate time spent in pass itself (excluding sub-passes), and push children - PassProfile::Duration self_duration = profile->duration; - for (auto it = profile->children.rbegin(); it != profile->children.rend(); ++it) { - self_duration -= it->duration; - profiles.push(std::make_tuple(depth + 1, profile->duration, &*it)); - } - - double parent_pct = profile->duration.count() / parent_duration.count() * 100.0; - double total_pct = profile->duration.count() / top_dur.count() * 100.0; - - os << profile->name << ": "; - os << std::setprecision(0); - os << profile->duration.count() << "us [" << self_duration.count() << "us] "; - os << std::setprecision(2) << "(" << total_pct << "%; " << parent_pct << "%)\n"; - } - - return os.str(); -} - -TVM_REGISTER_GLOBAL("transform.render_pass_profiles").set_body_typed(RenderPassProfiles); - -TVM_REGISTER_GLOBAL("transform.clear_pass_profiles").set_body_typed([]() { - PassProfileThreadLocalStore::Get()->root.children.clear(); -}); - -TVM_REGISTER_GLOBAL("transform.enable_pass_profiling").set_body_typed([]() { - PassProfileThreadLocalStore::Get()->active = true; -}); - -TVM_REGISTER_GLOBAL("transform.disable_pass_profiling").set_body_typed([]() { - PassProfileThreadLocalStore::Get()->active = false; -}); - /*! * \brief Module-level passes are designed to implement global * analysis/optimizations, i.e. interprocedural optimizations (IPO), etc. Passes @@ -464,12 +394,11 @@ IRModule ModulePassNode::operator()(IRModule mod, const PassContext& pass_ctx) c << "The diagnostic context was set at the top of this block this is a bug."; const PassInfo& pass_info = Info(); + ICHECK(mod.defined()) << "The input module must be set."; + DLOG(INFO) << "Executing module pass : " << pass_info->name << " with opt level: " << pass_info->opt_level; - ICHECK(mod.defined()) << "The input module must be set."; - - pass_ctx.Trace(mod, pass_info, true); mod = pass_func(std::move(mod), pass_ctx); ICHECK(mod.defined()) << "The return value of a module pass must be set."; @@ -480,7 +409,6 @@ IRModule ModulePassNode::operator()(IRModule mod, const PassContext& pass_ctx) c pass_ctx->diag_ctx.value().Render(); pass_ctx->diag_ctx = previous; - pass_ctx.Trace(mod, pass_info, false); return mod; } @@ -621,13 +549,14 @@ TVM_REGISTER_NODE_TYPE(PassContextNode); TVM_REGISTER_GLOBAL("transform.PassContext") .set_body_typed([](int opt_level, Array required, Array disabled, - TraceFunc trace_func, Optional> config) { + Array instruments, + Optional> config) { auto pctx = PassContext::Create(); pctx->opt_level = opt_level; pctx->required_pass = std::move(required); pctx->disabled_pass = std::move(disabled); - pctx->trace_func = std::move(trace_func); + pctx->instruments = std::move(instruments); if (config.defined()) { pctx->config = config.value(); } @@ -642,17 +571,10 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) << "\n"; p->stream << "\topt_level: " << node->opt_level << "\n"; - p->stream << "\trequired passes: ["; - for (const auto& it : node->required_pass) { - p->stream << it << " "; - } - p->stream << "]\n"; + p->stream << "\trequired passes: " << node->required_pass << "\n"; + p->stream << "\tdisabled passes: " << node->disabled_pass << "\n"; + p->stream << "\tinstruments: " << node->instruments << "\n"; - p->stream << "\tdisabled passes: ["; - for (const auto& it : node->disabled_pass) { - p->stream << it << " "; - } - p->stream << "]\n"; p->stream << "\tconfig: " << node->config; }); @@ -669,6 +591,13 @@ TVM_REGISTER_GLOBAL("transform.EnterPassContext").set_body_typed(PassContext::In TVM_REGISTER_GLOBAL("transform.ExitPassContext").set_body_typed(PassContext::Internal::ExitScope); +TVM_REGISTER_GLOBAL("transform.OverrideInstruments") + .set_body_typed([](PassContext pass_ctx, Array instruments) { + pass_ctx.InstrumentExitPassContext(); + pass_ctx->instruments = instruments; + pass_ctx.InstrumentEnterPassContext(); + }); + Pass PrintIR(String header, bool show_meta_data) { auto pass_func = [header, show_meta_data](IRModule mod, const PassContext& ctx) { LOG(INFO) << "PrintIR(" << header << "):\n" << AsText(mod, show_meta_data); diff --git a/src/relay/ir/transform.cc b/src/relay/ir/transform.cc index 596f812e25af..4a7974cae5ae 100644 --- a/src/relay/ir/transform.cc +++ b/src/relay/ir/transform.cc @@ -133,8 +133,6 @@ IRModule FunctionPassNode::operator()(IRModule mod, const PassContext& pass_ctx) DLOG(INFO) << "Executing function pass : " << pass_info->name << " with opt level: " << pass_info->opt_level; - pass_ctx.Trace(mod, pass_info, true); - // Execute the pass function and return a new module. IRModule updated_mod = IRModule(mod->functions, mod->type_definitions, mod->Imports(), mod->source_map); @@ -159,8 +157,6 @@ IRModule FunctionPassNode::operator()(IRModule mod, const PassContext& pass_ctx) pass_ctx->diag_ctx.value().Render(); pass_ctx->diag_ctx = previous; - pass_ctx.Trace(updated_mod, pass_info, false); - // TODO(@jroesch): move away from eager type checking for performance reasons // make issue. return transform::InferType()(updated_mod); diff --git a/src/tir/ir/transform.cc b/src/tir/ir/transform.cc index 95c40f9a3c8e..4c59a1767372 100644 --- a/src/tir/ir/transform.cc +++ b/src/tir/ir/transform.cc @@ -87,9 +87,7 @@ PrimFuncPass::PrimFuncPass( // Perform Module -> Module optimizations at the PrimFunc level. IRModule PrimFuncPassNode::operator()(IRModule mod, const PassContext& pass_ctx) const { - const PassInfo& pass_info = Info(); ICHECK(mod.defined()); - pass_ctx.Trace(mod, pass_info, true); std::vector deleted_list; IRModuleNode* mod_ptr = mod.CopyOnWrite(); auto* func_dict = mod_ptr->functions.CopyOnWrite(); @@ -112,7 +110,6 @@ IRModule PrimFuncPassNode::operator()(IRModule mod, const PassContext& pass_ctx) for (const auto& gv : deleted_list) { func_dict->erase(gv); } - pass_ctx.Trace(mod, pass_info, false); return mod; } diff --git a/tests/python/relay/test_pass_annotation.py b/tests/python/relay/test_pass_annotation.py index abf795cd46cc..6f68cba268cf 100644 --- a/tests/python/relay/test_pass_annotation.py +++ b/tests/python/relay/test_pass_annotation.py @@ -23,12 +23,19 @@ from tvm.contrib import graph_executor from tvm.relay.expr_functor import ExprMutator from tvm.relay import transform +from tvm.ir.instrument import pass_instrument import tvm.testing -def _trace(module, metadata, _): - if metadata.name == "ManifestAlloc": - pass # import pdb; pdb.set_trace() +@tvm.instrument.pass_instrument +class Trace: + def run_before_pass(self, module, pass_info): + if pass_info.name == "ManifestAlloc": + pass # import pdb; pdb.set_trace() + + def run_after_pass(self, module, pass_info): + if pass_info.name == "ManifestAlloc": + pass # import pdb; pdb.set_trace() def check_graph_executor( @@ -49,7 +56,7 @@ def check_graph_executor( def check_vm_runtime(target, ref_res, device, func, params, config, opt_level, expected_index=None): - with tvm.transform.PassContext(opt_level=opt_level, trace=_trace, config=config): + with tvm.transform.PassContext(opt_level=opt_level, instruments=[Trace()], config=config): mod = tvm.IRModule() mod["main"] = func exe = relay.vm.compile(mod, target) @@ -186,7 +193,7 @@ def check_annotated_graph(annotated_func, expected_func): def test_conv_network(): - R"""The network is as following: + r"""The network is as following: data1 data2 | | conv2d conv2d @@ -389,7 +396,7 @@ def get_func(): return func def test_fuse_log_add(device, tgt): - """ Only log and add are fused.""" + """Only log and add are fused.""" fallback_device = tvm.device("cpu") target = {"cpu": "llvm", device: tgt} cpu_dev = fallback_device @@ -530,7 +537,7 @@ def test_fallback_all_operators(device, tgt): def run_unpropagatable_graph(dev, tgt): - R"""The network is as following: + r"""The network is as following: a b c d \ / \ / add mul diff --git a/tests/python/relay/test_pass_instrument.py b/tests/python/relay/test_pass_instrument.py new file mode 100644 index 000000000000..86283fd31819 --- /dev/null +++ b/tests/python/relay/test_pass_instrument.py @@ -0,0 +1,530 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" Instrument test cases. +""" +import pytest +import tvm +import tvm.relay +from tvm.relay import op +from tvm.ir.instrument import PassTimingInstrument, pass_instrument + + +def get_test_model(): + x, y, z = [tvm.relay.var(c, shape=(3, 4), dtype="float32") for c in "xyz"] + e1 = op.add(x, y) + e2 = op.subtract(x, z) + e3 = op.multiply(e1, e1 / e2) + return tvm.IRModule.from_expr(e3 + e2) + + +def test_pass_timing_instrument(): + pass_timing = PassTimingInstrument() + + # Override current PassContext's instruments + tvm.transform.PassContext.current().override_instruments([pass_timing]) + + mod = get_test_model() + mod = tvm.relay.transform.AnnotateSpans()(mod) + mod = tvm.relay.transform.ToANormalForm()(mod) + mod = tvm.relay.transform.InferType()(mod) + + profiles = pass_timing.render() + assert "AnnotateSpans" in profiles + assert "ToANormalForm" in profiles + assert "InferType" in profiles + + # Reset current PassContext's instruments to None + tvm.transform.PassContext.current().override_instruments(None) + + mod = get_test_model() + mod = tvm.relay.transform.AnnotateSpans()(mod) + mod = tvm.relay.transform.ToANormalForm()(mod) + mod = tvm.relay.transform.InferType()(mod) + + profiles = pass_timing.render() + assert profiles == "" + + +def test_custom_instrument(): + @pass_instrument + class MyTest: + def __init__(self): + self.events = [] + + def enter_pass_ctx(self): + self.events.append("enter ctx") + + def exit_pass_ctx(self): + self.events.append("exit ctx") + + def run_before_pass(self, mod, info): + self.events.append("run before " + info.name) + + def run_after_pass(self, mod, info): + self.events.append("run after " + info.name) + + mod = get_test_model() + my_test = MyTest() + with tvm.transform.PassContext(instruments=[my_test]): + mod = tvm.relay.transform.InferType()(mod) + + assert ( + "enter ctx" + "run before InferType" + "run after InferType" + "exit ctx" == "".join(my_test.events) + ) + + +def test_disable_pass(): + @pass_instrument + class CustomPI: + def __init__(self): + self.events = [] + + def should_run(self, mod, info): + # Only run pass name contains "InferType" + if "InferType" not in info.name: + return False + return True + + def run_before_pass(self, mod, info): + self.events.append(info.name) + + mod = get_test_model() + custom_pi = CustomPI() + with tvm.transform.PassContext(instruments=[custom_pi]): + mod = tvm.relay.transform.AnnotateSpans()(mod) + mod = tvm.relay.transform.ToANormalForm()(mod) + mod = tvm.relay.transform.InferType()(mod) + + assert "InferType" == "".join(custom_pi.events) + + +def test_multiple_instrument(): + @pass_instrument + class SkipPass: + def __init__(self, skip_pass_name): + self.skip_pass_name = skip_pass_name + + def should_run(self, mod, info): + if self.skip_pass_name in info.name: + return False + return True + + skip_annotate = SkipPass("AnnotateSpans") + skip_anf = SkipPass("ToANormalForm") + + @pass_instrument + class PrintPassName: + def __init__(self): + self.events = [] + + def run_before_pass(self, mod, info): + self.events.append(info.name) + + mod = get_test_model() + print_pass_name = PrintPassName() + with tvm.transform.PassContext(instruments=[skip_annotate, skip_anf, print_pass_name]): + mod = tvm.relay.transform.AnnotateSpans()(mod) + mod = tvm.relay.transform.ToANormalForm()(mod) + mod = tvm.relay.transform.InferType()(mod) + + assert "InferType" == "".join(print_pass_name.events) + + +def test_instrument_pass_counts(): + @pass_instrument + class PassesCounter: + def __init__(self): + self.run_before_count = 0 + self.run_after_count = 0 + + def __clear(self): + self.run_before_count = 0 + self.run_after_count = 0 + + def enter_pass_ctx(self): + self.__clear() + + def exit_pass_ctx(self): + self.__clear() + + def run_before_pass(self, mod, info): + self.run_before_count = self.run_before_count + 1 + + def run_after_pass(self, mod, info): + self.run_after_count = self.run_after_count + 1 + + mod = get_test_model() + passes_counter = PassesCounter() + with tvm.transform.PassContext(instruments=[passes_counter]): + tvm.relay.build(mod, "llvm") + assert passes_counter.run_after_count != 0 + assert passes_counter.run_after_count == passes_counter.run_before_count + + # Out of pass context scope, should be reset + assert passes_counter.run_before_count == 0 + assert passes_counter.run_after_count == 0 + + +def test_enter_pass_ctx_exception(): + events = [] + + @pass_instrument + class PI: + def __init__(self, id): + self.id = id + + def enter_pass_ctx(self): + events.append(self.id + " enter ctx") + + def exit_pass_ctx(self): + events.append(self.id + " exit ctx") + + @pass_instrument + class PIBroken(PI): + def __init__(self, id): + super().__init__(id) + + def enter_pass_ctx(self): + events.append(self.id + " enter ctx") + raise RuntimeError("Just a dummy error") + + pass_ctx = tvm.transform.PassContext(instruments=[PI("%1"), PIBroken("%2"), PI("%3")]) + with pytest.raises(tvm.error.TVMError) as cm: + with pass_ctx: + pass + assert "Just a dummy error" in str(cm.execption) + + assert "%1 enter ctx" "%2 enter ctx" "%1 exit ctx" == "".join(events) + + # Make sure we get correct PassContext + cur_pass_ctx = tvm.transform.PassContext.current() + assert pass_ctx != cur_pass_ctx + assert cur_pass_ctx.instruments == None + + +def test_enter_pass_ctx_exception_global(): + @pass_instrument + class PIBroken: + def enter_pass_ctx(self): + raise RuntimeError("Just a dummy error") + + cur_pass_ctx = tvm.transform.PassContext.current() + with pytest.raises(tvm.error.TVMError) as cm: + cur_pass_ctx.override_instruments([PIBroken()]) + assert "Just a dummy error" in str(cm.exception) + assert not cur_pass_ctx.instruments + + +def test_exit_pass_ctx_exception(): + events = [] + + @pass_instrument + class PI: + def __init__(self, id): + self.id = id + + def exit_pass_ctx(self): + events.append(self.id + " exit ctx") + + def exit_pass_ctx(self): + events.append(self.id + " exit ctx") + + @pass_instrument + class PIBroken(PI): + def __init__(self, id): + super().__init__(id) + + def exit_pass_ctx(self): + events.append(self.id + " exit ctx") + raise RuntimeError("Just a dummy error") + + pass_ctx = tvm.transform.PassContext(instruments=[PI("%1"), PIBroken("%2"), PI("%3")]) + with pytest.raises(tvm.error.TVMError) as cm: + with pass_ctx: + pass + assert "Just a dummy error" in str(cm.exception) + + assert "%1 exit ctx" "%2 exit ctx" == "".join(events) + + # Make sure we get correct PassContext + cur_pass_ctx = tvm.transform.PassContext.current() + assert pass_ctx != cur_pass_ctx + assert not cur_pass_ctx.instruments + + +def test_exit_pass_ctx_exception_global(): + @pass_instrument + class PIBroken: + def exit_pass_ctx(self): + raise RuntimeError("Just a dummy error") + + cur_pass_ctx = tvm.transform.PassContext.current() + with pytest.raises(tvm.error.TVMError) as cm: + cur_pass_ctx.override_instruments([PIBroken()]) + cur_pass_ctx.override_instruments([PIBroken()]) + assert "Just a dummy error" in str(cm.exception) + assert not cur_pass_ctx.instruments + + +def test_pass_exception(): + events = [] + + @pass_instrument + class PI: + def enter_pass_ctx(self): + events.append("enter_pass_ctx") + + def exit_pass_ctx(self): + events.append("exit_pass_ctx") + + def should_run(self, mod, info): + events.append("should_run") + return True + + def run_before_pass(self, mod, info): + events.append("run_before_pass") + + def run_after_pass(self, mod, info): + events.append("run_after_pass") + + @tvm.transform.module_pass(opt_level=2) + def transform(mod, ctx): + events.append("transform pass") + raise RuntimeError("Just a dummy error") + return mod + + mod = get_test_model() + with pytest.raises(tvm.error.TVMError) as cm: + with tvm.transform.PassContext(instruments=[PI()]): + mod = transform(mod) + assert "Just a dummy error" in str(cm.exception) + + assert ( + "enter_pass_ctx" + "should_run" + "run_before_pass" + "transform pass" + "exit_pass_ctx" == "".join(events) + ) + + +def test_should_run_exception(): + events = [] + + @pass_instrument + class PI: + def __init__(self, id): + self.id = id + + def enter_pass_ctx(self): + events.append(self.id + " enter_pass_ctx") + + def exit_pass_ctx(self): + events.append(self.id + " exit_pass_ctx") + + def should_run(self, mod, info): + events.append(self.id + " should_run") + raise RuntimeError("Just a dummy error") + return True + + def run_before_pass(self, mod, info): + events.append(self.id + " run_before_pass") + + def run_after_pass(self, mod, info): + events.append(self.id + " run_after_pass") + + @tvm.transform.module_pass(opt_level=2) + def transform(mod, ctx): + events.append("transform pass") + return mod + + mod = get_test_model() + with pytest.raises(tvm.error.TVMError) as cm: + with tvm.transform.PassContext(instruments=[PI("%1"), PI("%2")]): + mod = transform(mod) + assert "Just a dummy error" in str(cm.exception) + + assert ( + "%1 enter_pass_ctx" + "%2 enter_pass_ctx" + "%1 should_run" + "%1 exit_pass_ctx" + "%2 exit_pass_ctx" == "".join(events) + ) + + +def test_run_before_exception(): + events = [] + + @pass_instrument + class PI: + def __init__(self, id): + self.id = id + + def enter_pass_ctx(self): + events.append(self.id + " enter_pass_ctx") + + def exit_pass_ctx(self): + events.append(self.id + " exit_pass_ctx") + + def should_run(self, mod, info): + events.append(self.id + " should_run") + return True + + def run_before_pass(self, mod, info): + events.append(self.id + " run_before_pass") + raise RuntimeError("Just a dummy error") + + def run_after_pass(self, mod, info): + events.append(self.id + " run_after_pass") + + @tvm.transform.module_pass(opt_level=2) + def transform(mod, ctx): + events.append("transform pass") + return mod + + mod = get_test_model() + with pytest.raises(tvm.error.TVMError) as cm: + with tvm.transform.PassContext(instruments=[PI("%1"), PI("%2")]): + mod = transform(mod) + assert "Just a dummy error" in str(cm.exception) + + assert ( + "%1 enter_pass_ctx" + "%2 enter_pass_ctx" + "%1 should_run" + "%2 should_run" + "%1 run_before_pass" + "%1 exit_pass_ctx" + "%2 exit_pass_ctx" == "".join(events) + ) + + +def test_run_after_exception(): + events = [] + + @pass_instrument + class PI: + def __init__(self, id): + self.id = id + + def enter_pass_ctx(self): + events.append(self.id + " enter_pass_ctx") + + def exit_pass_ctx(self): + events.append(self.id + " exit_pass_ctx") + + def should_run(self, mod, info): + events.append(self.id + " should_run") + return True + + def run_before_pass(self, mod, info): + events.append(self.id + " run_before_pass") + + def run_after_pass(self, mod, info): + events.append(self.id + " run_after_pass") + raise RuntimeError("Just a dummy error") + + @tvm.transform.module_pass(opt_level=2) + def transform(mod, ctx): + events.append("transform pass") + return mod + + x, y = [tvm.relay.var(c, shape=(3, 4), dtype="float32") for c in "xy"] + mod = tvm.IRModule.from_expr(tvm.relay.add(x, y)) + + with pytest.raises(tvm.error.TVMError) as cm: + with tvm.transform.PassContext(instruments=[PI("%1"), PI("%2")]): + mod = transform(mod) + assert "Just a dummy error" in str(cm.exception) + + assert ( + "%1 enter_pass_ctx" + "%2 enter_pass_ctx" + "%1 should_run" + "%2 should_run" + "%1 run_before_pass" + "%2 run_before_pass" + "transform pass" + "%1 run_after_pass" + "%1 exit_pass_ctx" + "%2 exit_pass_ctx" == "".join(events) + ) + + +def test_instrument_call_sequence(): + events = [] + + @pass_instrument + class PI: + def __init__(self, id): + self.id = id + + def enter_pass_ctx(self): + events.append(self.id + " enter_pass_ctx") + + def exit_pass_ctx(self): + events.append(self.id + " exit_pass_ctx") + + def should_run(self, mod, info): + events.append(" " + self.id + " should_run") + return True + + def run_before_pass(self, mod, info): + events.append(" " + self.id + " run_before_pass") + + def run_after_pass(self, mod, info): + events.append(" " + self.id + " run_after_pass") + + @tvm.transform.module_pass(opt_level=2) + def transform1(mod, ctx): + events.append(" transform1 pass") + return mod + + @tvm.transform.module_pass(opt_level=2) + def transform2(mod, ctx): + events.append(" transform2 pass") + return mod + + mod = get_test_model() + with tvm.transform.PassContext(instruments=[PI("%1"), PI("%2")]): + mod = transform1(mod) + mod = transform2(mod) + + assert ( + "%1 enter_pass_ctx" + "%2 enter_pass_ctx" + " %1 should_run" + " %2 should_run" + " %1 run_before_pass" + " %2 run_before_pass" + " transform1 pass" + " %1 run_after_pass" + " %2 run_after_pass" + " %1 should_run" + " %2 should_run" + " %1 run_before_pass" + " %2 run_before_pass" + " transform2 pass" + " %1 run_after_pass" + " %2 run_after_pass" + "%1 exit_pass_ctx" + "%2 exit_pass_ctx" == "".join(events) + ) diff --git a/tests/python/relay/test_pass_manager.py b/tests/python/relay/test_pass_manager.py index 5a29d1acd171..ee889d857cee 100644 --- a/tests/python/relay/test_pass_manager.py +++ b/tests/python/relay/test_pass_manager.py @@ -25,6 +25,7 @@ from tvm.relay import Function, Call from tvm.relay import analysis from tvm.relay import transform as _transform +from tvm.ir import instrument as _instrument from tvm.relay.testing import run_infer_type import tvm.testing @@ -533,17 +534,26 @@ def test_print_ir(capfd): assert "multiply" in out -__TRACE_COUNTER__ = 0 +@tvm.instrument.pass_instrument +class PassCounter: + def __init__(self): + # Just setting a garbage value to test set_up callback + self.counts = 1234 + def enter_pass_ctx(self): + self.counts = 0 -def _tracer(module, info, is_before): - global __TRACE_COUNTER__ - if bool(is_before): - __TRACE_COUNTER__ += 1 + def exit_pass_ctx(self): + self.counts = 0 + + def run_before_pass(self, module, info): + self.counts += 1 + + def get_counts(self): + return self.counts def test_print_debug_callback(): - global __TRACE_COUNTER__ shape = (1, 2, 3) tp = relay.TensorType(shape, "float32") x = relay.var("x", tp) @@ -559,15 +569,20 @@ def test_print_debug_callback(): ] ) - assert __TRACE_COUNTER__ == 0 mod = tvm.IRModule({"main": func}) - with tvm.transform.PassContext(opt_level=3, trace=_tracer): + pass_counter = PassCounter() + with tvm.transform.PassContext(opt_level=3, instruments=[pass_counter]): + # Should be reseted when entering pass context + assert pass_counter.get_counts() == 0 mod = seq(mod) - # TODO(@jroesch): when we remove new fn pass behavior we need to remove - # change this back to 3 - assert __TRACE_COUNTER__ == 5 + # TODO(@jroesch): when we remove new fn pass behavior we need to remove + # change this back to match correct behavior + assert pass_counter.get_counts() == 6 + + # Should be cleanned up after exiting pass context + assert pass_counter.get_counts() == 0 if __name__ == "__main__": diff --git a/tutorials/dev/use_pass_infra.py b/tutorials/dev/use_pass_infra.py index 6a33d14e38c8..3804b1496d05 100644 --- a/tutorials/dev/use_pass_infra.py +++ b/tutorials/dev/use_pass_infra.py @@ -273,14 +273,16 @@ def visit_constant(self, c): # An example is below. -def print_ir(mod, info, is_before): +@tvm.instrument.pass_instrument +class PrintIR: """Print the name of the pass, the IR, only before passes execute.""" - if is_before: + + def run_before_pass(self, mod, info): print("Running pass: {}", info) print(mod) -with tvm.transform.PassContext(opt_level=3, trace=print_ir): +with tvm.transform.PassContext(opt_level=3, instruments=[PrintIR()]): with tvm.target.Target("llvm"): # Perform the optimizations. mod = seq(mod)