Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,7 @@ tvm_file_glob(GLOB_RECURSE COMPILER_SRCS
src/relax/analysis/*.cc
src/relax/transform/*.cc
src/relax/backend/vm/*.cc
src/relax/backend/task_extraction.cc
src/relax/utils.cc
)

Expand Down
54 changes: 44 additions & 10 deletions include/tvm/ir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,18 +32,18 @@
* - Reducing the effort required to implement new passes for compiler
* developers, etc.
*
* Similar to LLVM's pass manager, we designed the Relay pass manager to work
* Similar to LLVM's pass manager, we designed the Relay/Relax pass manager to work
* different granularity, i.e. module level, function level, and even sequential
* passe that contains a host of passes.
*
* However, we also extend the functionality of the traditional pass manager
* with the consideration of requirements/convention from deep learning
* frameworks, such as Pytorch and Gluon, etc. Each pass in the Relay pass
* frameworks, such as Pytorch and Gluon, etc. Each pass in the Relay/Relax pass
* manager performs the IRModule -> IRModule transformation. All
* different types of passes, including the sequential-level pass object, are
* essentially pass objects. This design, therefore, effectively provides users
* a consistent and convenient interface, i.e. Pass, to play with. It offers a
* means to ease the development and testing of Relay passes. For example, with
* means to ease the development and testing of Relay/Relax passes. For example, with
* the pass manager, external users will be able to have custom passes correctly
* scheduled without having to modify a single handcrafted pass order.
*
Expand Down Expand Up @@ -90,7 +90,16 @@ class PassContextNode : public Object {

/*! \brief A list of pass instrument implementations. */
Array<instrument::PassInstrument> instruments;

// TODO(@sunggg): Fix dependency issue in the header file and correct the types
// e.g., relax::trace, relax::database in tvm/relax/tuning_api.h
/*! \brief Trace stack for relax pass infra. */
mutable Array<ObjectRef> trace_stack;
/*! \brief List of passes to be traced. If not defined, make every pass traceable. */
Optional<Map<String, Bool>> make_traceable;
/*! \brief Number of evaluations conducted in the pass pipeline. */
mutable int num_evals{0};
/*! \brief Database for tuning API. */
Optional<ObjectRef> tuning_api_database;
PassContextNode() = default;

/*!
Expand Down Expand Up @@ -130,7 +139,27 @@ class PassContextNode : public Object {
v->Visit("instruments", &instruments);
v->Visit("config", &config);
v->Visit("diag_ctx", &diag_ctx);
v->Visit("trace_stack", &trace_stack);
v->Visit("make_traceable", &make_traceable);
v->Visit("num_evals", &num_evals);
v->Visit("tuning_api_daatabase", &tuning_api_database);
}

Array<ObjectRef> GetTraceStack() { return trace_stack; }
void PushTrace(ObjectRef new_trace) { trace_stack.push_back(new_trace); }
void PopTrace() {
ICHECK(GetTraceStackSize()) << "Trace stack is currently empty. Please double check.";
trace_stack.pop_back();
}
int GetTraceStackSize() { return trace_stack.size(); }
ObjectRef GetCurrentTrace() {
ICHECK(GetTraceStackSize()) << "Trace stack is currently empty. Please double check.";
return trace_stack.back();
}
void SetNumEvals(int _num_evals) { num_evals = _num_evals; }
void IncNumEvals(int _num_evals) { num_evals += _num_evals; }

Optional<ObjectRef> GetTuningAPIDatabase() { return tuning_api_database; }

static constexpr const char* _type_key = "transform.PassContext";
static constexpr bool _type_has_method_sequal_reduce = false;
Expand Down Expand Up @@ -287,6 +316,9 @@ class PassInfoNode : public Object {
/*! \brief The name of an optimization/analysis pass. */
String name;

/*! \brief Boolean that tells whether this pass will be traced or not. */
bool traceable;

/*! \brief The passes that are required to perform the current pass. */
Array<String> required;

Expand All @@ -296,6 +328,7 @@ class PassInfoNode : public Object {
v->Visit("opt_level", &opt_level);
v->Visit("name", &name);
v->Visit("required", &required);
v->Visit("traceable", &traceable);
}

static constexpr const char* _type_key = "transform.PassInfo";
Expand All @@ -314,16 +347,17 @@ class PassInfo : public ObjectRef {
* \param opt_level The optimization level
* \param name Name of the pass.
* \param required The passes that are required to perform the current pass.
* \param traceable Boolean that tells whether the pass is traceable.
*/
TVM_DLL PassInfo(int opt_level, String name, Array<runtime::String> required);
TVM_DLL PassInfo(int opt_level, String name, Array<runtime::String> required, bool traceable);

TVM_DEFINE_OBJECT_REF_METHODS(PassInfo, ObjectRef, PassInfoNode);
};

/*!
* \brief PassNode is the base type of differnt types of optimization passes.
* It is designed as a pure class and implemented by different pass subclasses
* at different granularity of Relay nodes.
* at different granularity of Relay/Relax nodes.
*/
class PassNode : public Object {
public:
Expand Down Expand Up @@ -396,7 +430,7 @@ class Pass : public ObjectRef {
};

/*!
* \brief The SequentialNode contains a set of passes that transform Relay
* \brief The SequentialNode contains a set of passes that transform Relay/Relax
* programs from one AST to another semantically equivalent one.
*
* One example of this level of pass is that the pass manager needs to correctly
Expand Down Expand Up @@ -489,9 +523,9 @@ class Sequential : public Pass {
*
* \return The created module pass.
*/
TVM_DLL Pass
CreateModulePass(const runtime::TypedPackedFunc<IRModule(IRModule, PassContext)>& pass_func,
int opt_level, String name, Array<runtime::String> required);
TVM_DLL Pass CreateModulePass(
const runtime::TypedPackedFunc<IRModule(IRModule, PassContext)>& pass_func, int opt_level,
String name, Array<runtime::String> required, bool traceable = false);

/*!
* \brief A special trace pass that prints the header and IR to LOG(INFO).
Expand Down
22 changes: 20 additions & 2 deletions include/tvm/relax/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,13 @@ using DataflowBlock = tvm::relax::DataflowBlock;
* \param opt_level The optimization level of the function pass.
* \param name The name of the function pass.
* \param required The list of the passes that the function pass is dependent on.
* \param traceable Boolean variable whether the dataflowblock pass is traceable.
*
* \return The created function pass.
*/
TVM_DLL Pass CreateFunctionPass(
const runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)>& pass_func,
int opt_level, String name, tvm::Array<String> required);
int opt_level, String name, tvm::Array<String> required, bool traceable = false);

/*!
* \brief Create a dataflowblock pass.
Expand All @@ -58,12 +59,13 @@ TVM_DLL Pass CreateFunctionPass(
* \param opt_level The optimization level of the dataflowblock pass.
* \param name The name of the dataflowblock pass.
* \param required The list of the passes that the dataflowblock pass is dependent on.
* \param traceable Boolean variable whether the dataflowblock pass is traceable.
*
* \return The created dataflowblock pass.
*/
TVM_DLL Pass CreateDataflowBlockPass(
const runtime::TypedPackedFunc<DataflowBlock(DataflowBlock, IRModule, PassContext)>& pass_func,
int opt_level, String name, tvm::Array<String> required);
int opt_level, String name, tvm::Array<String> required, bool traceable = false);

/*!
* \brief Transform all dataflow structure to non-dataflow version.
Expand Down Expand Up @@ -93,6 +95,22 @@ TVM_DLL Pass CallTIRRewrite();
*/
TVM_DLL Pass RewriteDataflowReshape();

/*!
* \brief Bind params of function of the module to constant tensors.
*
* \param func_name The name of the function to bind parameters.
* \param params The parameters to bind.
*
* \return The Pass.
*/
TVM_DLL Pass BindParams(String func_name, Map<String, runtime::NDArray> params);

/*!
* \brief Fold constant expressions.
*
* \return The Pass.
*/
TVM_DLL Pass FoldConstant();
/*!
* \brief Attach global_symbol to Relax functions and TIR Primfuncs for codegen.
*
Expand Down
Loading