4 changes: 2 additions & 2 deletions docs/dev/codebase_walkthrough.rst
@@ -72,7 +72,7 @@ The Node system is the basis of exposing C++ types to frontend languages, includ

::

TVM_REGISTER_API("_ComputeOp")
TVM_REGISTER_GLOBAL("_ComputeOp")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = ComputeOpNode::make(args[0],
args[1],
@@ -174,7 +174,7 @@ The ``Build()`` function looks up the code generator for the given target in the

::

TVM_REGISTER_API("codegen.build_cuda")
TVM_REGISTER_GLOBAL("codegen.build_cuda")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = BuildCUDA(args[0]);
});
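For reference, the registry lookup that ``Build()`` performs can be sketched as follows; the helper function and the composed key are illustrative, and only ``runtime::Registry::Get`` is assumed from the actual API.

::

    #include <tvm/runtime/registry.h>
    #include <string>

    // Compose the registry key from the target name and fetch the function
    // registered via TVM_REGISTER_GLOBAL("codegen.build_cuda") and friends.
    const tvm::runtime::PackedFunc* LookupBuildFunc(const std::string& target) {
      std::string build_f_name = "codegen.build_" + target;
      // Returns nullptr when no code generator is registered for this target.
      return tvm::runtime::Registry::Get(build_f_name);
    }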
4 changes: 2 additions & 2 deletions docs/dev/relay_add_op.rst
@@ -96,7 +96,7 @@ the arguments to the call node, as below.

.. code:: c

TVM_REGISTER_API("relay.op._make.add")
TVM_REGISTER_GLOBAL("relay.op._make.add")
.set_body_typed<Expr(Expr, Expr)>([](Expr lhs, Expr rhs) {
static const Op& op = Op::Get("add");
return CallNode::make(op, {lhs, rhs}, Attrs(), {});
@@ -106,7 +106,7 @@ Including a Python API Hook
---------------------------

It is generally the convention in Relay, that functions exported
through ``TVM_REGISTER_API`` should be wrapped in a separate
through ``TVM_REGISTER_GLOBAL`` should be wrapped in a separate
Python function rather than called directly in Python. In the case
of the functions that produce calls to operators, it may be convenient
to bundle them, as in ``python/tvm/relay/op/tensor.py``, where
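For reference, the thin Python wrapper convention described above sits on top of the same global registry; a minimal C++ sketch of the underlying lookup (the helper name is illustrative, and ``packed_func_ext.h`` is assumed for the NodeRef argument and return conversions).

::

    #include <tvm/runtime/registry.h>
    #include <tvm/packed_func_ext.h>
    #include <tvm/relay/expr.h>

    // The wrapper in python/tvm/relay/op/tensor.py ultimately invokes the maker
    // registered above, which is stored as a PackedFunc under its global name.
    tvm::relay::Expr MakeAddByName(tvm::relay::Expr lhs, tvm::relay::Expr rhs) {
      const tvm::runtime::PackedFunc* f =
          tvm::runtime::Registry::Get("relay.op._make.add");
      if (f == nullptr) return tvm::relay::Expr();  // not registered/linked in
      return (*f)(lhs, rhs);
    }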
32 changes: 16 additions & 16 deletions docs/dev/relay_pass_infra.rst
@@ -131,13 +131,13 @@ Python APIs to create a compilation pipeline using pass context.
TVM_DLL static PassContext Create();
TVM_DLL static PassContext Current();
/* Other fields are omitted. */

private:
// The entry of a pass context scope.
TVM_DLL void EnterWithScope();
// The exit of a pass context scope.
TVM_DLL void ExitWithScope();

// Classes to get the Python `with` like syntax.
friend class tvm::With<PassContext>;
};
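For reference, the ``tvm::With<PassContext>`` friend declaration above is what provides the C++ counterpart of the Python ``with`` block; a minimal sketch (the enclosing scope is illustrative):

::

    // EnterWithScope() runs when `scope` is constructed and ExitWithScope()
    // when it is destroyed, so `ctx` is observed as PassContext::Current()
    // only inside this block (RAII, mirroring the Python `with` syntax).
    {
      auto ctx = relay::transform::PassContext::Create();
      ctx->opt_level = 3;
      tvm::With<relay::transform::PassContext> scope(ctx);
      // ... run passes that should pick up `ctx` here ...
    }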
@@ -225,7 +225,7 @@ cannot add or delete a function through these passes as they are not aware of
the global information.

.. code:: c++

class FunctionPassNode : PassNode {
PassInfo pass_info;
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func;
@@ -319,7 +319,7 @@ favorably use Python APIs to create a specific pass object.
ModulePass CreateModulePass(std::string name,
int opt_level,
PassFunc pass_func);

SequentialPass CreateSequentialPass(std::string name,
int opt_level,
Array<Pass> passes,
@@ -347,14 +347,14 @@ registration.
auto tensor_type = relay::TensorTypeNode::make({}, tvm::Bool());
auto x = relay::VarNode::make("x", relay::Type());
auto f = relay::FunctionNode::make(tvm::Array<relay::Var>{ x }, x, relay::Type(), {});

auto y = relay::VarNode::make("y", tensor_type);
auto call = relay::CallNode::make(f, tvm::Array<relay::Expr>{ y });
auto fx = relay::FunctionNode::make(tvm::Array<relay::Var>{ y }, call, relay::Type(), {});

// Create a module for optimization.
auto mod = relay::ModuleNode::FromExpr(fx);

// Create a sequential pass.
tvm::Array<relay::transform::Pass> pass_seqs{
relay::transform::InferType(),
@@ -363,7 +363,7 @@ registration.
relay::transform::AlterOpLayout()
};
relay::transform::Pass seq = relay::transform::Sequential(pass_seqs);

// Create a pass context for the optimization.
auto ctx = relay::transform::PassContext::Create();
ctx->opt_level = 2;
@@ -421,7 +421,7 @@ Python when needed.
return CreateFunctionPass(pass_func, 2, "FoldConstant", {});
}

TVM_REGISTER_API("relay._transform.FoldConstant")
TVM_REGISTER_GLOBAL("relay._transform.FoldConstant")
.set_body_typed(FoldConstant);

} // namespace transform
@@ -457,10 +457,10 @@ a certain scope.
def __enter__(self):
_transform.EnterPassContext(self)
return self

def __exit__(self, ptype, value, trace):
_transform.ExitPassContext(self)

@staticmethod
def current():
"""Return the current pass context."""
@@ -580,18 +580,18 @@ using ``Sequential`` associated with other types of passes.
z1 = relay.add(y, c)
z2 = relay.add(z, z1)
func = relay.Function([x], z2)
# Customize the optimization pipeline.

# Customize the optimization pipeline.
seq = _transform.Sequential([
relay.transform.InferType(),
relay.transform.FoldConstant(),
relay.transform.EliminateCommonSubexpr(),
relay.transform.AlterOpLayout()
])

# Create a module to perform optimizations.
mod = relay.Module({"main": func})

# Users can disable any passes that they don't want to execute by providing
# a list, e.g. disabled_pass=["EliminateCommonSubexpr"].
with relay.build_config(opt_level=3):
@@ -629,7 +629,7 @@ For more pass infra related examples in Python and C++, please refer to

.. _Block: https://mxnet.incubator.apache.org/api/python/docs/api/gluon/block.html#gluon-block

.. _Relay module: https://docs.tvm.ai/langref/relay_expr.html#module-and-global-functions
.. _Relay module: https://docs.tvm.ai/langref/relay_expr.html#module-and-global-functions

.. _include/tvm/relay/transform.h: https://github.com/apache/incubator-tvm/blob/master/include/tvm/relay/transform.h

1 change: 0 additions & 1 deletion include/tvm/codegen.h
@@ -28,7 +28,6 @@
#include "base.h"
#include "expr.h"
#include "lowered_func.h"
#include "api_registry.h"
#include "runtime/packed_func.h"

namespace tvm {
32 changes: 9 additions & 23 deletions include/tvm/api_registry.h → include/tvm/node/env_func.h
@@ -18,33 +18,19 @@
*/

/*!
* \file tvm/api_registry.h
* \brief This file contains utilities related to
* the TVM's global function registry.
* \file tvm/node/env_func.h
* \brief Serializable global function.
*/
#ifndef TVM_API_REGISTRY_H_
#define TVM_API_REGISTRY_H_
#ifndef TVM_NODE_ENV_FUNC_H_
#define TVM_NODE_ENV_FUNC_H_

#include <tvm/node/reflection.h>

#include <string>
#include <utility>
#include "base.h"
#include "packed_func_ext.h"
#include "runtime/registry.h"

namespace tvm {
/*!
* \brief Register an API function globally.
* It simply redirects to TVM_REGISTER_GLOBAL
*
* \code
* TVM_REGISTER_API(MyPrint)
* .set_body([](TVMArgs args, TVMRetValue* rv) {
* // my code.
* });
* \endcode
*/
#define TVM_REGISTER_API(OpName) TVM_REGISTER_GLOBAL(OpName)

namespace tvm {
/*!
* \brief Node container of EnvFunc
* \sa EnvFunc
@@ -54,7 +40,7 @@ class EnvFuncNode : public Object {
/*! \brief Unique name of the global function */
std::string name;
/*! \brief The internal packed function */
PackedFunc func;
runtime::PackedFunc func;
/*! \brief constructor */
EnvFuncNode() {}

@@ -154,4 +140,4 @@ class TypedEnvFunc<R(Args...)> : public ObjectRef {
};

} // namespace tvm
#endif // TVM_API_REGISTRY_H_
#endif // TVM_NODE_ENV_FUNC_H_
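For reference, ``EnvFunc`` wraps a globally registered function by name so that only the name needs to be serialized and the function can be re-resolved on load; a minimal sketch, assuming ``"example.my_func"`` was previously registered with ``TVM_REGISTER_GLOBAL``:

::

    #include <tvm/node/env_func.h>

    void UseEnvFunc() {
      // EnvFunc::Get resolves the name against the global function registry
      // and keeps both the name and the underlying runtime::PackedFunc.
      tvm::EnvFunc f = tvm::EnvFunc::Get("example.my_func");
      // Calls forward to the PackedFunc stored in EnvFuncNode::func.
      tvm::runtime::TVMRetValue rv = f(1, 2);
      (void)rv;
    }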
2 changes: 1 addition & 1 deletion include/tvm/relay/base.h
@@ -24,7 +24,7 @@
#ifndef TVM_RELAY_BASE_H_
#define TVM_RELAY_BASE_H_

#include <tvm/api_registry.h>

#include <tvm/ir/span.h>
#include <tvm/ir.h>
#include <tvm/node/node.h>
6 changes: 5 additions & 1 deletion include/tvm/relay/type.h
@@ -24,8 +24,12 @@
#ifndef TVM_RELAY_TYPE_H_
#define TVM_RELAY_TYPE_H_

#include <tvm/api_registry.h>

#include <tvm/ir/type.h>
#include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
#include <tvm/node/env_func.h>

#include <tvm/ir.h>
#include <string>

12 changes: 6 additions & 6 deletions include/tvm/runtime/registry.h
@@ -70,7 +70,7 @@ class Registry {
*
* \code
*
* TVM_REGISTER_API("addone")
* TVM_REGISTER_GLOBAL("addone")
* .set_body_typed<int(int)>([](int x) { return x + 1; });
*
* \endcode
@@ -96,7 +96,7 @@ class Registry {
* return x * y;
* }
*
* TVM_REGISTER_API("multiply")
* TVM_REGISTER_GLOBAL("multiply")
* .set_body_typed(multiply); // will have type int(int, int)
*
* \endcode
@@ -120,7 +120,7 @@ class Registry {
* struct Example {
* int doThing(int x);
* }
* TVM_REGISTER_API("Example_doThing")
* TVM_REGISTER_GLOBAL("Example_doThing")
* .set_body_method(&Example::doThing); // will have type int(Example, int)
*
* \endcode
@@ -148,7 +148,7 @@ class Registry {
* struct Example {
* int doThing(int x);
* }
* TVM_REGISTER_API("Example_doThing")
* TVM_REGISTER_GLOBAL("Example_doThing")
* .set_body_method(&Example::doThing); // will have type int(Example, int)
*
* \endcode
@@ -181,7 +181,7 @@ class Registry {
* // noderef subclass
* struct Example;
*
* TVM_REGISTER_API("Example_doThing")
* TVM_REGISTER_GLOBAL("Example_doThing")
* .set_body_method<Example>(&ExampleNode::doThing); // will have type int(Example, int)
*
* // note that just doing:
@@ -221,7 +221,7 @@ class Registry {
* // noderef subclass
* struct Example;
*
* TVM_REGISTER_API("Example_doThing")
* TVM_REGISTER_GLOBAL("Example_doThing")
* .set_body_method<Example>(&ExampleNode::doThing); // will have type int(Example, int)
*
* // note that just doing:
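For reference, ``TVM_REGISTER_GLOBAL`` is a thin macro over ``runtime::Registry::Register``, so the same registration can also be written imperatively; a minimal sketch with an illustrative name:

::

    #include <tvm/runtime/registry.h>

    // Roughly what TVM_REGISTER_GLOBAL("example.addone") expands to: obtain a
    // mutable registry entry for the name and attach a typed body.
    void RegisterAddOne() {
      tvm::runtime::Registry::Register("example.addone")
          .set_body_typed<int(int)>([](int x) { return x + 1; });
    }

The macro form is usually preferable because it performs the registration during static initialization, before user code starts looking functions up.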
32 changes: 17 additions & 15 deletions src/api/api_arith.cc
@@ -23,29 +23,31 @@
*/
#include <tvm/expr.h>
#include <tvm/ir.h>
#include <tvm/api_registry.h>
#include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>

#include <tvm/tensor.h>

namespace tvm {
namespace arith {

TVM_REGISTER_API("arith.intset_single_point")
TVM_REGISTER_GLOBAL("arith.intset_single_point")
.set_body_typed(IntSet::single_point);

TVM_REGISTER_API("arith.intset_vector")
TVM_REGISTER_GLOBAL("arith.intset_vector")
.set_body_typed(IntSet::vector);

TVM_REGISTER_API("arith.intset_interval")
TVM_REGISTER_GLOBAL("arith.intset_interval")
.set_body_typed(IntSet::interval);


TVM_REGISTER_API("arith.DetectLinearEquation")
TVM_REGISTER_GLOBAL("arith.DetectLinearEquation")
.set_body_typed(DetectLinearEquation);

TVM_REGISTER_API("arith.DetectClipBound")
TVM_REGISTER_GLOBAL("arith.DetectClipBound")
.set_body_typed(DetectClipBound);

TVM_REGISTER_API("arith.DeduceBound")
TVM_REGISTER_GLOBAL("arith.DeduceBound")
.set_body_typed<IntSet(Expr, Expr, Map<Var, IntSet>, Map<Var, IntSet>)>([](
Expr v, Expr cond,
const Map<Var, IntSet> hint_map,
@@ -55,36 +57,36 @@ TVM_REGISTER_API("arith.DeduceBound")
});


TVM_REGISTER_API("arith.DomainTouched")
TVM_REGISTER_GLOBAL("arith.DomainTouched")
.set_body_typed(DomainTouched);

TVM_REGISTER_API("_IntervalSetGetMin")
TVM_REGISTER_GLOBAL("_IntervalSetGetMin")
.set_body_method(&IntSet::min);

TVM_REGISTER_API("_IntervalSetGetMax")
TVM_REGISTER_GLOBAL("_IntervalSetGetMax")
.set_body_method(&IntSet::max);

TVM_REGISTER_API("_IntSetIsNothing")
TVM_REGISTER_GLOBAL("_IntSetIsNothing")
.set_body_method(&IntSet::is_nothing);

TVM_REGISTER_API("_IntSetIsEverything")
TVM_REGISTER_GLOBAL("_IntSetIsEverything")
.set_body_method(&IntSet::is_everything);

ConstIntBound MakeConstIntBound(int64_t min_value, int64_t max_value) {
return ConstIntBound(min_value, max_value);
}

TVM_REGISTER_API("arith._make_ConstIntBound")
TVM_REGISTER_GLOBAL("arith._make_ConstIntBound")
.set_body_typed(MakeConstIntBound);

ModularSet MakeModularSet(int64_t coeff, int64_t base) {
return ModularSet(coeff, base);
}

TVM_REGISTER_API("arith._make_ModularSet")
TVM_REGISTER_GLOBAL("arith._make_ModularSet")
.set_body_typed(MakeModularSet);

TVM_REGISTER_API("arith._CreateAnalyzer")
TVM_REGISTER_GLOBAL("arith._CreateAnalyzer")
.set_body([](TVMArgs args, TVMRetValue* ret) {
using runtime::PackedFunc;
using runtime::TypedPackedFunc;