diff --git a/docs/api/python/ir.rst b/docs/api/python/ir.rst index c2a1a1e106d5..e7fb3c114689 100644 --- a/docs/api/python/ir.rst +++ b/docs/api/python/ir.rst @@ -23,6 +23,14 @@ tvm.ir :autosummary: +tvm.instrument +-------------- +.. automodule:: tvm.instrument + :members: + :imported-members: + :autosummary: + + tvm.transform ------------- .. automodule:: tvm.transform diff --git a/docs/conf.py b/docs/conf.py index 45f5da670608..83fa8fd37ae9 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -273,7 +273,12 @@ def git_describe_version(original_version): "tune_network_x86.py", "tune_network_cuda.py", ], - "dev": ["low_level_custom_pass.py", "use_pass_infra.py", "bring_your_own_datatypes.py"], + "dev": [ + "low_level_custom_pass.py", + "use_pass_infra.py", + "use_pass_instrument.py", + "bring_your_own_datatypes.py", + ], } diff --git a/docs/dev/pass_infra.rst b/docs/dev/pass_infra.rst index 67ef30a29504..8973679c3c55 100644 --- a/docs/dev/pass_infra.rst +++ b/docs/dev/pass_infra.rst @@ -109,7 +109,8 @@ configure the compilation options, including optimization level and required/disabled passes, etc. For instance, we may have a configuration which performs all passes at ``opt_level=3`` with some disabled passes using ``disabled_pass=xx`` provided by ``PassContext``. Now we could glob all passes -at ``opt_level=3`` and exclude those in the disabled pass list. +at ``opt_level=3`` and exclude those in the disabled pass list. ``PassContext`` +also provides a way to instrument all passes. See section :ref:`pass_instrument_cpp_backend`. This class is designed for users to conveniently write the Python ``with`` syntax to perform optimizations under a certain configuration. In addition, the @@ -123,16 +124,22 @@ Python APIs to create a compilation pipeline using pass context. class PassContextNode : public Object { public: - ErrorReporter err_reporter; int opt_level{2}; tvm::Array required_pass; tvm::Array disabled_pass; + mutable Optional diag_ctx; + Map config; + Array instruments; }; class PassContext : public NodeRef { public: TVM_DLL static PassContext Create(); TVM_DLL static PassContext Current(); + TVM_DLL void InstrumentEnterPassContext(); + TVM_DLL void InstrumentExitPassContext(); + TVM_DLL bool InstrumentBeforePass(const IRModule& mod, const PassInfo& info) const; + TVM_DLL void InstrumentAfterPass(const IRModule& mod, const PassInfo& info) const; /* Other fields are omitted. */ private: @@ -338,7 +345,7 @@ favorably use Python APIs to create a specific pass object. Pass Sequential(tvm::Array passes, PassInfo pass_info); Pass Registration -~~~~~~~~~~~~~~~~~ +^^^^^^^^^^^^^^^^^ We've covered the concept of different level of passes and the context used for compilation. It would be interesting to see how easily users can register @@ -389,6 +396,148 @@ To allow other C++ modules to apply this pass, we declare a free function in TVM_DLL Pass FoldConstant(); +.. _pass_instrument_cpp_backend: + +Pass Instrument +^^^^^^^^^^^^^^^ + +Pass Instrument is a mechanism to analyze the pass itself. For example, +we can use the infrastructure to know how much time and memory a pass requires +or how a pass can transform the IR module. + +We introduce four instrument points in the life-cycle of ``PassContext``. + +.. code:: c++ + + TVM_DLL void InstrumentEnterPassContext(); + TVM_DLL void InstrumentExitPassContext(); + TVM_DLL bool InstrumentBeforePass(const IRModule& mod, const PassInfo& info) const; + TVM_DLL void InstrumentAfterPass(const IRModule& mod, const PassInfo& info) const; + +``InstrumentEnterPassContext`` is called immediately when entering the scope +of the ``PassContext`` instance. + +``InstrumentExitPassContext`` is called when leaving the scope of ``PassContext``, +or exceptions occur during the execution of passes. +This method is also called when instruments is being overriden by ``override_instruments`` in :py:class:`tvm.transform.PassContext`. +See :ref:`pass_instrument_overriden`. + +``InstrumentBeforePass`` is called before execution. +``InstrumentAfterPass`` is called after execution if the pass should be run. The behavior is like: + +.. code:: c++ + + if (pass_ctx.InstrumentBeforePass(ir_module, pass_info)) { + new_ir_module = run_pass(ir_module, pass_ctx); + pass_ctx.InstrumentAfterPass(new_ir_module, pass_info); + return new_ir_module; + } + +The ``PassInstrument`` interface allow you to run arbitrary code inside above four methods. +Multiple ``PassInstrument`` instances can be registed into a single +``PassContext``. ``PassInstrument`` instances are called sequentially in the order of +``instruments`` argument passed to ``PassContext``. + +``PassInstrument`` provides following interfaces: + +.. code:: c++ + + namespace instrument { + + class PassInstrumentNode : public Object { + public: + String name; + virtual void EnterPassContext() const = 0; + virtual void ExitPassContext() const = 0; + virtual bool ShouldRun(const IRModule& mod, const transform::PassInfo& info) const = 0; + virtual void RunBeforePass(const IRModule& mod, const transform::PassInfo& info) const = 0; + virtual void RunAfterPass(const IRModule& mod, const transform::PassInfo& info) const = 0; + /* Other fields are omitted. */ + }; + + class PassInstrument : public ObjectRef { + public: + TVM_DEFINE_OBJECT_REF_METHODS(PassInstrument, ObjectRef, PassInstrumentNode); + }; + + } // namespace instrument + +Python frontend are provided to implement ``PassInstrument`` quickly. See :ref:`pass_instrument_py_frontend`. + +Within a ``PassContext``, the call sequence of a ``PassInstrument`` instance is like: + +:: + + with PassContext(instruments=[pi]) # pi = a PassInstrument implementation. + pi.EnterPassContext() + + if pi.ShouldRun(Pass1): + pi.RunBeforePass() + Pass1() + pi.RunAfterPass() + + if pi.ShouldRun(Pass2): + pi.RunBeforePass() + Pass2() + pi.RunAfterPass() + + pi.ExitPassContext() + +Here is a brief introduction of relations between ``PassInstrument`` interfaces +and ``PassContext`` methods. See (`src/ir/transform.cc`_) for more details. + +- ``InstrumentEnterPassContext`` + + * ``EnterPassContext()`` is executed in the order of ``instruments`` passed to the ``PassContext``. + * When an exception raises, ``PassContext`` disable the pass instrumentation + by clearing all registered ``PassInstrument`` instances. + * Then ``PassContext`` execute ``ExitPassContext()`` method of each ``PassInstrument`` + instances which successfully finished ``EnterPassContext()`` + * For example, if ``PassInstrument`` A, B, and C are registered to a ``PassContext`` + and A finished ``EnterPassContext()`` while B throws an exception, then C + is never executed; ``ExitPassContext()`` of A is executed. + +- ``InstrumentExitPassContext`` + + * ``ExitPassContext()`` of each ``PassInstrument`` instances are executed in + the order of ``instruments`` passed to the ``PassContext``. + * While an exception occurs, ``instruments`` is cleared. + * ``PassInstrument`` Instances registered after the one throwing exceptions do not execute ``ExitPassContext``. + +- ``InstrumentBeforePass`` + + * ``ShouldRun`` is executed if the pass is not listed as a required pass. + * ``RunBeforePass`` is executed in the order of ``instruments`` if the pass is not blocked by ``ShouldRun``. + * Note that ``InstrumentBeforePass`` returns a boolean indicating whether or not the pass should be run. + * When an exception occur, it is thrown immediately. + We rely on Python Context Manager to exit ``PassContext`` safely + (meaning ``ExitPassContext`` of each instruments will be run. For C++, please refer to `include/tvm/support/with.h`_.) + +- ``InstrumentAfterPass`` + + * ``RunAfterPass`` is executed in the order of ``instruments`` passed to the ``PassContext``. + * When an exception occur, it is thrown immediately. + We rely on Python Context Manager or ``With`` class(`include/tvm/support/with.h`_) to exit ``PassContext`` safely + +Built-in Instrument +^^^^^^^^^^^^^^^^^^^ + +There are several built-in instruments. Those marked with *TODO* are not implemented yet. + +- PassTimingInstrument (see `src/ir/instrument.cc`_) + + * Profile the execution time of passes. + +- PrintIRBefore(TODO) + + * Print the IR module before the pass transforms it. :py:func:`tvm.transform.PrintIR` + can also serve this purpose if we insert it around passes. However, + with the ``PassInstrument``, we don't need to modify the sequence of passes. + +- PrintAfter(TODO) + + * Print the IR module after the pass transforms it. + Python Frontend ~~~~~~~~~~~~~~~ @@ -526,16 +675,78 @@ decorators and then invoke it. For more examples about how to customize your own optimization pipeline and debug Relay and tir passes, please refer to the `use pass infra`_ tutorial. + +.. _pass_instrument_py_frontend: + +Pass Instrument +^^^^^^^^^^^^^^^ + +One can implement a ``PassInstrument`` by using the ``pass_instrument`` +decorator(`python/tvm/ir/instrument.py`_) on a class implementing following methods. +Note that it is recommended to use the ``pass_instrument`` decorator to implement +``PassInstrument``, instead of overriding or subclassing. + +- ``enter_pass_ctx`` + + * This method is run when entering ``PassContext``. + +- ``exit_pass_ctx`` + + * This method is run when exiting ``PassContext``. + +- ``should_run`` + + * This method is run before a pass is executed, returning a boolean + indicating whether or not the pass should be run. + +- ``run_before_pass`` + + * If a pass should be run, this method is run just before pass execution. + +- ``run_after_pass`` + + * This method is run right after a pass has been executed. + +``PassInstrument`` instances can be registered through ``instruments`` argument in +:py:class:`tvm.transform.PassContext`. + +`use pass instrument`_ tutorial provides examples for how to implement ``PassInstrument`` with Python APIs. + +.. _pass_instrument_overriden: + +Override Instruments in Current PassContext +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +``override_instruments`` method is provided to override the ``instruments`` of current ``PassContext``. +For example, if passes are run without explicitly creating a new ``PassContext``, +one can still register ``PassInstrument`` into the global ``PassContext`` by: + +.. code:: python + + cur_pass_ctx = tvm.transform.PassContext.current() + # override PassInstrument instances + cur_pass_ctx.override_instruments([pass_inst]) + mod = pass_seq(mod) + result = pass_inst.get_result() + +Note that when ``override_instruments`` is called, the ``exit_pass_ctx`` method of +old ``PassInstrument`` instances are called. Then the ``enter_pass_ctx`` method of +new ``PassInstrument`` are called. + .. _Sequential: https://pytorch.org/docs/stable/nn.html?highlight=sequential#torch.nn.Sequential .. _Block: https://mxnet.apache.org/api/python/docs/api/gluon/block.html#gluon-block .. _include/tvm/ir/transform.h: https://github.com/apache/tvm/blob/main/include/tvm/ir/transform.h +.. _include/tvm/support/with.h: https://github.com/apache/tvm/blob/main/include/tvm/support/with.h + .. _src/relay/ir/transform.cc: https://github.com/apache/tvm/blob/main/src/relay/ir/transform.cc .. _src/ir/transform.cc: https://github.com/apache/tvm/blob/main/src/ir/transform.cc +.. _src/ir/instrument.cc: https://github.com/apache/tvm/blob/main/src/ir/instrument.cc + .. _src/relay/transforms/fold_constant.cc: https://github.com/apache/tvm/blob/main/src/relay/transforms/fold_constant.cc .. _python/tvm/relay/transform/transform.py: https://github.com/apache/tvm/blob/main/python/tvm/relay/transform/transform.py @@ -544,6 +755,10 @@ optimization pipeline and debug Relay and tir passes, please refer to the .. _python/tvm/ir/transform.py: https://github.com/apache/tvm/blob/main/python/tvm/ir/transform.py +.. _python/tvm/ir/instrument.py: https://github.com/apache/tvm/blob/main/python/tvm/ir/instrument.py + .. _src/tir/transforms/unroll_loop.cc: https://github.com/apache/tvm/blob/main/src/tir/transforms/unroll_loop.cc .. _use pass infra: https://github.com/apache/tvm/blob/main/tutorials/dev/use_pass_infra.py + +.. _use pass instrument: https://github.com/apache/tvm/blob/main/tutorials/dev/use_pass_instrument.py diff --git a/python/tvm/ir/instrument.py b/python/tvm/ir/instrument.py index c322f2bef3fc..1948a6787eac 100644 --- a/python/tvm/ir/instrument.py +++ b/python/tvm/ir/instrument.py @@ -30,11 +30,8 @@ class PassInstrument(tvm.runtime.Object): """A pass instrument implementation. Users don't need to interact with this class directly. - Instead, a `PassInstrument` instance should be created through `pass_instrument`. - - See Also - -------- - `pass_instrument` + Instead, a `PassInstrument` instance should be created through + :py:func:`pass_instrument` """ @@ -91,13 +88,14 @@ def pass_instrument(pi_cls=None): Parameters ---------- - pi_class : + pi_class : class + Instrument class. See example below. Examples -------- - The following code block decorates a pass instrument class. .. code-block:: python + @tvm.instrument.pass_instrument class SkipPass: def __init__(self, skip_pass_name): @@ -155,5 +153,17 @@ def render(): ------- string : string The rendered string result of time profiles + + Examples + -------- + + .. code-block:: python + + timing_inst = PassTimingInstrument() + with tvm.transform.PassContext(instruments=[timing_inst]): + relay_mod = relay.transform.InferType()(relay_mod) + relay_mod = relay.transform.FoldScaleAxis()(relay_mod) + # before exiting the context, get profile results. + profiles = timing_inst.render() """ return _ffi_instrument_api.RenderTimePassProfiles() diff --git a/python/tvm/ir/transform.py b/python/tvm/ir/transform.py index 3a3ac16be677..eb31d58b4428 100644 --- a/python/tvm/ir/transform.py +++ b/python/tvm/ir/transform.py @@ -107,8 +107,8 @@ def __exit__(self, ptype, value, trace): def override_instruments(self, instruments): """Override instruments within this PassContext. - If there are existing instruments, their exit_pass_ctx callbacks are called. - Then switching to new instruments and calling new enter_pass_ctx callbacks. + If there are existing instruments, their ``exit_pass_ctx`` callbacks are called. + Then switching to new instruments and calling new ``enter_pass_ctx`` callbacks. instruments : Sequence[PassInstrument] The list of pass instrument implementations. diff --git a/tests/python/relay/test_pass_instrument.py b/tests/python/relay/test_pass_instrument.py index 86283fd31819..4122d4fad9e1 100644 --- a/tests/python/relay/test_pass_instrument.py +++ b/tests/python/relay/test_pass_instrument.py @@ -243,9 +243,6 @@ def __init__(self, id): def exit_pass_ctx(self): events.append(self.id + " exit ctx") - def exit_pass_ctx(self): - events.append(self.id + " exit ctx") - @pass_instrument class PIBroken(PI): def __init__(self, id): diff --git a/tutorials/dev/use_pass_infra.py b/tutorials/dev/use_pass_infra.py index 3804b1496d05..468c4d40b942 100644 --- a/tutorials/dev/use_pass_infra.py +++ b/tutorials/dev/use_pass_infra.py @@ -261,16 +261,18 @@ def visit_constant(self, c): ] ) +############################################################################### # By inserting the ``PrintIR`` pass after ``FoldConstant``, the pass infra will # dump out the module IR when ``FoldConstant`` is done. Users can plug in this # pass after any pass they want to debug for viewing the optimization effect. # -# There is a more flexible debugging mechanism also exposed by the build configuration -# object. One can pass a tracing function which can be used to execute arbitrary code -# before and/or after each pass. A tracing function will receive a :py::class:`tvm.IRModule`, -# a :py:class:`tvm.transform.PassInfo` object, -# and a boolean indicating whether you are executing before, or after a pass. -# An example is below. +# There is a more flexible debugging mechanism. One can implement a ``PassInstrument`` +# class to execute arbitrary code not only before and/or after each pass but also +# at entering/exiting ``PassContext``. See :ref:`pass_instrument_cpp_backend` +# for more details. +# +# Here we use :py::func`tvm.instrument.pass_instrument` decorator to implement +# a PassInsturment class printing IR before execution of each passes: @tvm.instrument.pass_instrument diff --git a/tutorials/dev/use_pass_instrument.py b/tutorials/dev/use_pass_instrument.py new file mode 100644 index 000000000000..3369304a651d --- /dev/null +++ b/tutorials/dev/use_pass_instrument.py @@ -0,0 +1,372 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=line-too-long +""" +.. _tutorial-use-pass-instrument: + +How to Use TVM Pass Instrument +============================== +**Author**: `Chi-Wei Wang `_ + +As more and more passes are implemented, it becomes useful to instrument +pass execution, analyze per-pass effects, and observe various events. + +We can instrument passes by providing a list of :py:class:`tvm.ir.instrument.PassInstrument` +instances to :py:class:`tvm.transform.PassContext`. We provide a pass instrument +for collecting timing information (:py:class:`tvm.ir.instrument.PassTimingInstrument`), +but an extension mechanism is available via the :py:func:`tvm.instrument.pass_instrument` decorator. + +This tutorial demostrates how developers can use ``PassContext`` to instrument +passes. Please also refer to the :ref:`pass-infra`. +""" +import tvm +import tvm.relay as relay +from tvm.relay.testing import resnet +from tvm.contrib.download import download_testdata +from tvm.relay.build_module import bind_params_by_name +from tvm.ir.instrument import ( + PassTimingInstrument, + pass_instrument, +) + + +############################################################################### +# Create An Example Relay Program +# ------------------------------- +# We use pre-defined resnet-18 network in Relay. +batch_size = 1 +num_of_image_class = 1000 +image_shape = (3, 224, 224) +output_shape = (batch_size, num_of_image_class) +relay_mod, relay_params = resnet.get_workload(num_layers=18, batch_size=1, image_shape=image_shape) +print("Printing the IR module...") +print(relay_mod.astext(show_meta_data=False)) + + +############################################################################### +# Create PassContext With Instruments +# ----------------------------------- +# To run all passes with an instrument, pass it via the ``instruments`` argument to +# the ``PassContext`` constructor. A built-in ``PassTimingInstrument`` is used to +# profile the execution time of each passes. +timing_inst = PassTimingInstrument() +with tvm.transform.PassContext(instruments=[timing_inst]): + relay_mod = relay.transform.InferType()(relay_mod) + relay_mod = relay.transform.FoldScaleAxis()(relay_mod) + # before exiting the context, get profile results. + profiles = timing_inst.render() +print("Printing results of timing profile...") +print(profiles) + + +############################################################################### +# Use Current PassContext With Instruments +# ---------------------------------------- +# One can also use the current ``PassContext`` and register +# ``PassInstrument`` instances by ``override_instruments`` method. +# Note that ``override_instruments`` executes ``exit_pass_ctx`` method +# if any instrument already exists. Then it switches to new instruments +# and calls ``enter_pass_ctx`` method of new instruments. +# Refer to following sections and :py:func:`tvm.instrument.pass_instrument` for these methods. +cur_pass_ctx = tvm.transform.PassContext.current() +cur_pass_ctx.override_instruments([timing_inst]) +relay_mod = relay.transform.InferType()(relay_mod) +relay_mod = relay.transform.FoldScaleAxis()(relay_mod) +profiles = timing_inst.render() +print("Printing results of timing profile...") +print(profiles) + + +############################################################################### +# Register empty list to clear existing instruments. +# +# Note that ``exit_pass_ctx`` of ``PassTimingInstrument`` is called. +# Profiles are cleared so nothing is printed. +cur_pass_ctx.override_instruments([]) +# Uncomment the call to .render() to see a warning like: +# Warning: no passes have been profiled, did you enable pass profiling? +# profiles = timing_inst.render() + + +############################################################################### +# Create Customized Instrument Class +# ---------------------------------- +# A customized instrument class can be created using the +# :py:func:`tvm.instrument.pass_instrument` decorator. +# +# Let's create an instrument class which calculates the change in number of +# occurrences of each operator caused by each pass. We can look at ``op.name`` to +# find the name of each operator. And we do this before and after passes to calculate the difference. + + +@pass_instrument +class RelayCallNodeDiffer: + def __init__(self): + self._op_diff = [] + # Passes can be nested. + # Use stack to make sure we get correct before/after pairs. + self._op_cnt_before_stack = [] + + def enter_pass_ctx(self): + self._op_diff = [] + self._op_cnt_before_stack = [] + + def exit_pass_ctx(self): + assert len(self._op_cnt_before_stack) == 0, "The stack is not empty. Something wrong." + + def run_before_pass(self, mod, info): + self._op_cnt_before_stack.append((info.name, self._count_nodes(mod))) + + def run_after_pass(self, mod, info): + # Pop out the latest recorded pass. + name_before, op_to_cnt_before = self._op_cnt_before_stack.pop() + assert name_before == info.name, "name_before: {}, info.name: {} doesn't match".format( + name_before, info.name + ) + cur_depth = len(self._op_cnt_before_stack) + op_to_cnt_after = self._count_nodes(mod) + op_diff = self._diff(op_to_cnt_after, op_to_cnt_before) + # only record passes causing differences. + if op_diff: + self._op_diff.append((cur_depth, info.name, op_diff)) + + def get_pass_to_op_diff(self): + """ + return [ + (depth, pass_name, {op_name: diff_num, ...}), ... + ] + """ + return self._op_diff + + @staticmethod + def _count_nodes(mod): + """Count the number of occurrences of each operator in the module""" + ret = {} + + def visit(node): + if isinstance(node, relay.expr.Call): + if hasattr(node.op, "name"): + op_name = node.op.name + else: + # Some CallNode may not have 'name' such as relay.Function + return + ret[op_name] = ret.get(op_name, 0) + 1 + + relay.analysis.post_order_visit(mod["main"], visit) + return ret + + @staticmethod + def _diff(d_after, d_before): + """Calculate the difference of two dictionary along their keys. + The result is values in d_after minus values in d_before. + """ + ret = {} + key_after, key_before = set(d_after), set(d_before) + for k in key_before & key_after: + tmp = d_after[k] - d_before[k] + if tmp: + ret[k] = d_after[k] - d_before[k] + for k in key_after - key_before: + ret[k] = d_after[k] + for k in key_before - key_after: + ret[k] = -d_before[k] + return ret + + +############################################################################### +# Apply Passes and Multiple Instrument Classes +# -------------------------------------------- +# We can use multiple instrument classes in a ``PassContext``. +# However, it should be noted that instrument methods are executed sequentially, +# obeying the order of ``instruments`` argument. +# So for instrument classes like ``PassTimingInstrument``, it is inevitable to +# count-up the execution time of other instrument classes to the final +# profile result. +call_node_inst = RelayCallNodeDiffer() +desired_layouts = { + "nn.conv2d": ["NHWC", "HWIO"], +} +pass_seq = tvm.transform.Sequential( + [ + relay.transform.FoldConstant(), + relay.transform.ConvertLayout(desired_layouts), + relay.transform.FoldConstant(), + ] +) +relay_mod["main"] = bind_params_by_name(relay_mod["main"], relay_params) +# timing_inst is put after call_node_inst. +# So the execution time of ``call_node.inst.run_after_pass()`` is also counted. +with tvm.transform.PassContext(opt_level=3, instruments=[call_node_inst, timing_inst]): + relay_mod = pass_seq(relay_mod) + profiles = timing_inst.render() +# Uncomment the next line to see timing-profile results. +# print(profiles) + + +############################################################################### +# We can see how many CallNode increase/decrease per op type. +from pprint import pprint + +print("Printing the change in number of occurrences of each operator caused by each pass...") +pprint(call_node_inst.get_pass_to_op_diff()) + + +############################################################################### +# Exception Handling +# ------------------ +# Let's see what happens if an exception occurs in a method of a ``PassInstrument``. +# +# Define ``PassInstrument`` classes which raise exceptions in enter/exit ``PassContext``: +class PassExampleBase: + def __init__(self, name): + self._name = name + + def enter_pass_ctx(self): + print(self._name, "enter_pass_ctx") + + def exit_pass_ctx(self): + print(self._name, "exit_pass_ctx") + + def should_run(self, mod, info): + print(self._name, "should_run") + return True + + def run_before_pass(self, mod, pass_info): + print(self._name, "run_before_pass") + + def run_after_pass(self, mod, pass_info): + print(self._name, "run_after_pass") + + +@pass_instrument +class PassFine(PassExampleBase): + pass + + +@pass_instrument +class PassBadEnterCtx(PassExampleBase): + def enter_pass_ctx(self): + print(self._name, "bad enter_pass_ctx!!!") + raise ValueError("{} bad enter_pass_ctx".format(self._name)) + + +@pass_instrument +class PassBadExitCtx(PassExampleBase): + def exit_pass_ctx(self): + print(self._name, "bad exit_pass_ctx!!!") + raise ValueError("{} bad exit_pass_ctx".format(self._name)) + + +############################################################################### +# If an exception occurs in ``enter_pass_ctx``, ``PassContext`` will disable the pass +# instrumentation. And it will run the ``exit_pass_ctx`` of each ``PassInstrument`` +# which successfully finished ``enter_pass_ctx``. +# +# In following example, we can see ``exit_pass_ctx`` of `PassFine_0` is executed after exception. +demo_ctx = tvm.transform.PassContext( + instruments=[ + PassFine("PassFine_0"), + PassBadEnterCtx("PassBadEnterCtx"), + PassFine("PassFine_1"), + ] +) +try: + with demo_ctx: + relay_mod = relay.transform.InferType()(relay_mod) +except ValueError as ex: + print("Catching", str(ex).split("\n")[-1]) + +############################################################################### +# Exceptions in ``PassInstrument`` instances cause all instruments of the current ``PassContext`` +# to be cleared, so nothing is printed when ``override_instruments`` is called. +demo_ctx.override_instruments([]) # no PassFine_0 exit_pass_ctx printed....etc + +############################################################################### +# If an exception occurs in ``exit_pass_ctx``, then the pass instrument is disabled. +# Then exception is propagated. That means ``PassInstrument`` instances registered +# after the one throwing the exception do not execute ``exit_pass_ctx``. +demo_ctx = tvm.transform.PassContext( + instruments=[ + PassFine("PassFine_0"), + PassBadExitCtx("PassBadExitCtx"), + PassFine("PassFine_1"), + ] +) +try: + # PassFine_1 execute enter_pass_ctx, but not exit_pass_ctx. + with demo_ctx: + relay_mod = relay.transform.InferType()(relay_mod) +except ValueError as ex: + print("Catching", str(ex).split("\n")[-1]) + +############################################################################### +# Exceptions occured in ``should_run``, ``run_before_pass``, ``run_after_pass`` +# are not handled explicitly -- we rely on the context manager (the ``with`` syntax) +# to exit ``PassContext`` safely. +# +# We use ``run_before_pass`` as an example: +@pass_instrument +class PassBadRunBefore(PassExampleBase): + def run_before_pass(self, mod, pass_info): + print(self._name, "bad run_before_pass!!!") + raise ValueError("{} bad run_before_pass".format(self._name)) + + +demo_ctx = tvm.transform.PassContext( + instruments=[ + PassFine("PassFine_0"), + PassBadRunBefore("PassBadRunBefore"), + PassFine("PassFine_1"), + ] +) +try: + # All exit_pass_ctx are called. + with demo_ctx: + relay_mod = relay.transform.InferType()(relay_mod) +except ValueError as ex: + print("Catching", str(ex).split("\n")[-1]) + +############################################################################### +# Also note that pass instrumentation is not disable. So if we call +# ``override_instruments``, the ``exit_pass_ctx`` of old registered ``PassInstrument`` +# is called. +demo_ctx.override_instruments([]) + +############################################################################### +# If we don't wrap pass execution with ``with`` syntax, ``exit_pass_ctx`` is not +# called. Let try this with current ``PassContext``: +cur_pass_ctx = tvm.transform.PassContext.current() +cur_pass_ctx.override_instruments( + [ + PassFine("PassFine_0"), + PassBadRunBefore("PassBadRunBefore"), + PassFine("PassFine_1"), + ] +) + +############################################################################### +# Then call passes. ``exit_pass_ctx`` is not executed after the exception, +# as expectation. +try: + # No ``exit_pass_ctx`` got executed. + relay_mod = relay.transform.InferType()(relay_mod) +except ValueError as ex: + print("Catching", str(ex).split("\n")[-1]) + +############################################################################### +# Clear instruments. +cur_pass_ctx.override_instruments([])