Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 157 additions & 0 deletions include/tvm/ir/instrument.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tvm/ir/instrument.h
*
* This file introduces a pass instrument infrastructure, inspired by LLVM and MLIR.
* It inserts instrumentation points around passes.
*/
#ifndef TVM_IR_INSTRUMENT_H_
#define TVM_IR_INSTRUMENT_H_

#include <tvm/node/reflection.h>
#include <tvm/runtime/container.h>

#include <utility>
#include <vector>

namespace tvm {

class IRModule;

// Forward class for PassInstrumentNode methods
namespace transform {
class PassInfo;
} // namespace transform

namespace instrument {

/*!
* \brief PassInstrumentNode forms an instrument implementation.
* It provides API for users to register callbacks at different instrumentation points.
*
* Within a PassContext, call sequence of a PassInstrument implementation is like:
*
* with PassContext(instruments=[pi]): # pi = a PassInstrument implementation
* pi.EnterPassContext()
*
* if pi.ShouldRun(Pass1):
* pi.RunBeforePass()
* Pass1()
* pi.RunAfterPass()
*
* if pi.ShouldRun(Pass2):
* pi.RunBeforePass()
* Pass2()
* pi.RunAfterPass()
*
* pi.ExitPassContext()
*
* `EnterPassContext` and `ExitPassContext` are only called once when entering/exiting a
* PassContext. `ShouldRun`, `RunBeforePass` and `RunAfterPass` are called multiple times depending
* on how many passes.
*
* If there are multiple pass instrumentations provided, the instrument points are the same.
* PassInstrument implementations' callbacks are called in order:
*
* with PassContext(instruments=[pi1, pi2]): # pi1, pi2 = two distinct PassInstrument impls
* pi.EnterPassContext() for pi in instruments
*
* should_run = all([pi.ShoudRun(Pass1) for pi in instruments)])
* if (should_run)
* pi.RunBeforePass() for pi in instruments
* Pass1()
* pi.RunAfterPass() for pi in instruments
*
* should_run = all([pi.ShouldRun(Pass2) for pi in instruments)])
* if (should_run)
* pi.RunBeforePass() for pi in instruments
* Pass2()
* pi.RunAfterPass() for pi in instruments
*
* pi.ExitPassContext() for pi in instruments
*
* Note:
* 1. Assume there is no dependency between PassInstrument implementations in `instruments` .
* 2. `EnterPassContext` and `ExitPassContext` have `with` behavior (see PassContext and its FFI):
* If there is any exception raised in `ShouldRun()`, `RunBeforePass()`, `RunAfterPass()` and
* `Pass()`, `ExitPassContext()` is still called.
* 3. In mutiple PassInstrument instances scenario, callbacks are called in order:
* If one throws exceptions, remainings will not be called.
*
* \sa PassInstrument
* \sa src/ir/transform.cc
*/
class PassInstrumentNode : public Object {
public:
/*! \brief Name of this pass instrument object. */
String name;

virtual ~PassInstrumentNode() {}

/*! \brief Instrument when entering PassContext. Called once within a PassContext. */
virtual void EnterPassContext() const = 0;

/*! \brief Instrument when exiting PassContext. Called once within a PassContext. */
virtual void ExitPassContext() const = 0;

/*!
* \brief Determine whether to run the pass or not. Called multiple times depend on number of
* passes.
* \param mod The module that an optimization pass runs on.
* \param info The pass information.
*
* \return true to run the pass; false to skip the pass.
*/
virtual bool ShouldRun(const IRModule& mod, const transform::PassInfo& info) const = 0;

/*!
* \brief Instrument before pass run. Called multiple times depend on number of passes.
* \param mod The module that an optimization pass runs on.
* \param info The pass information.
*/
virtual void RunBeforePass(const IRModule& mod, const transform::PassInfo& info) const = 0;

/*!
* \brief Instrument after pass run. Called multiple time depend on number of passes.
* \param mod The module that an optimization pass runs on.
* \param info The pass information.
*/
virtual void RunAfterPass(const IRModule& mod, const transform::PassInfo& info) const = 0;

void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); }

static constexpr const char* _type_key = "instrument.PassInstrument";
TVM_DECLARE_BASE_OBJECT_INFO(PassInstrumentNode, Object);
};

/*!
* \brief Managed reference class for PassInstrumentNode
* \sa PassInstrumentNode
*/
class PassInstrument : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(PassInstrument, ObjectRef, PassInstrumentNode);
};

} // namespace instrument
} // namespace tvm

#endif // TVM_IR_INSTRUMENT_H_
54 changes: 38 additions & 16 deletions include/tvm/ir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@

#include <tvm/ir/diagnostic.h>
#include <tvm/ir/error.h>
#include <tvm/ir/instrument.h>
#include <tvm/ir/module.h>
#include <tvm/runtime/container.h>
#include <tvm/support/with.h>
Expand All @@ -68,15 +69,6 @@
namespace tvm {
namespace transform {

// Forward declare for TraceFunc.
class PassInfo;

/*!
* \brief A callback for tracing passes, useful for debugging and logging.
*/
using TraceFunc =
runtime::TypedPackedFunc<void(const IRModule& ir_module, const PassInfo& ctx, bool is_before)>;

/*!
* \brief PassContextNode contains the information that a pass can rely on,
* such as analysis results.
Expand All @@ -95,8 +87,9 @@ class PassContextNode : public Object {
mutable Optional<DiagnosticContext> diag_ctx;
/*! \brief Pass specific configurations. */
Map<String, ObjectRef> config;
/*! \brief Trace function to be invoked before and after each pass. */
TraceFunc trace_func;

/*! \brief A list of pass instrument implementations. */
Array<instrument::PassInstrument> instruments;

PassContextNode() = default;

Expand Down Expand Up @@ -134,6 +127,7 @@ class PassContextNode : public Object {
v->Visit("opt_level", &opt_level);
v->Visit("required_pass", &required_pass);
v->Visit("disabled_pass", &disabled_pass);
v->Visit("instruments", &instruments);
v->Visit("config", &config);
v->Visit("diag_ctx", &diag_ctx);
}
Expand Down Expand Up @@ -189,12 +183,40 @@ class PassContext : public ObjectRef {
TVM_DLL static PassContext Current();

/*!
* \brief Apply the tracing functions of the context to the module, with the info.
* \param module The IRModule to trace.
* \brief Call instrument implementations' callbacks when entering PassContext.
* The callbacks are called in order, and if one raises an exception, the rest will not be
* called.
*/
TVM_DLL void InstrumentEnterPassContext();

/*!
* \brief Call instrument implementations' callbacks when exiting PassContext.
* The callbacks are called in order, and if one raises an exception, the rest will not be
* called.
*/
TVM_DLL void InstrumentExitPassContext();

/*!
* \brief Call instrument implementations' callbacks before a pass run.
* The callbacks are called in order, and if one raises an exception, the rest will not be
* called.
*
* \param mod The module that an optimization pass runs on.
* \param info The pass information.
* \param is_before Indicated whether the tracing is before or after a pass.
*
* \return false: the pass is skipped; true: the pass runs.
*/
TVM_DLL void Trace(const IRModule& module, const PassInfo& info, bool is_before) const;
TVM_DLL bool InstrumentBeforePass(const IRModule& mod, const PassInfo& info) const;

/*!
* \brief Call instrument implementations callbacks after a pass run.
* The callbacks are called in order, and if one raises an exception, the rest will not be
* called.
*
* \param mod The module that an optimization pass runs on.
* \param info The pass information.
*/
TVM_DLL void InstrumentAfterPass(const IRModule& mod, const PassInfo& info) const;

/*!
* \brief Check whether a pass is enabled.
Expand Down Expand Up @@ -275,7 +297,7 @@ class PassInfoNode : public Object {
TVM_DECLARE_FINAL_OBJECT_INFO(PassInfoNode, Object);
};

/*
/*!
* \brief Managed reference class for PassInfoNode
* \sa PassInfoNode
*/
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
# tvm.ir
from .ir import IRModule
from .ir import transform
from .ir import instrument
from .ir import container
from . import ir

Expand Down Expand Up @@ -67,7 +68,6 @@
# Contrib initializers
from .contrib import rocm as _rocm, nvcc as _nvcc, sdaccel as _sdaccel


# NOTE: This file should be python2 compatible so we can
# raise proper error message when user run the package using
# an older version of the python
Expand Down
1 change: 1 addition & 0 deletions python/tvm/ir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,5 @@
from .container import Array, Map

from . import transform
from . import instrument
from . import diagnostics
Original file line number Diff line number Diff line change
Expand Up @@ -14,28 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
import tvm.relay
from tvm.relay import op
"""FFI APIs for tvm.instrument"""
import tvm._ffi


def test_pass_profiler():
x, y, z = [tvm.relay.var(c, shape=(3, 4), dtype="float32") for c in "xyz"]
e1 = op.add(x, y)
e2 = op.subtract(x, z)
e3 = op.multiply(e1, e1 / e2)
mod = tvm.IRModule.from_expr(e3 + e2)

tvm.transform.enable_pass_profiling()

mod = tvm.relay.transform.AnnotateSpans()(mod)
mod = tvm.relay.transform.ToANormalForm()(mod)
mod = tvm.relay.transform.InferType()(mod)

profiles = tvm.transform.render_pass_profiles()
assert "AnnotateSpans" in profiles
assert "ToANormalForm" in profiles
assert "InferType" in profiles

tvm.transform.clear_pass_profiles()
tvm.transform.disable_pass_profiling()
tvm._ffi._init_api("instrument", __name__)
Loading