Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,19 @@ include $(config)
# specify tensor path
.PHONY: clean all test doc

all: lib/libtvm.a lib/libtvm.so
all: lib/libtvm.so lib/libtvm_runtime.so lib/libtvm.a

LIB_HALIDE_IR = HalideIR/lib/libHalideIR.a

SRC = $(wildcard src/*.cc src/*/*.cc src/*/*/*.cc)
ALL_OBJ = $(patsubst src/%.cc, build/%.o, $(SRC))
ALL_DEP = $(ALL_OBJ) $(LIB_HALIDE_IR)

RUNTIME_SRC = $(wildcard src/runtime/*.cc src/runtime/*/*.cc)
RUNTIME_DEP = $(patsubst src/%.cc, build/%.o, $(RUNTIME_SRC))

ALL_DEP = $(ALL_OBJ) $(LIB_HALIDE_IR)

export LDFLAGS = -pthread -lm
export CFLAGS = -std=c++11 -Wall -O2 -fno-rtti\
-Iinclude -Idmlc-core/include -IHalideIR/src -fPIC -DDMLC_ENABLE_RTTI=0
Expand Down Expand Up @@ -77,15 +82,18 @@ build/%.o: src/%.cc
$(CXX) $(CFLAGS) -MM -MT build/$*.o $< >build/$*.d
$(CXX) -c $(CFLAGS) -c $< -o $@


lib/libtvm.a: $(ALL_DEP)
lib/libtvm.so: $(ALL_DEP)
@mkdir -p $(@D)
ar crv $@ $(filter %.o, $?)
$(CXX) $(CFLAGS) $(FRAMEWORKS) -shared -o $@ $(filter %.o %.a, $^) $(LDFLAGS)

lib/libtvm.so: $(ALL_DEP)
lib/libtvm_runtime.so: $(RUNTIME_DEP)
@mkdir -p $(@D)
$(CXX) $(CFLAGS) $(FRAMEWORKS) -shared -o $@ $(filter %.o %.a, $^) $(LDFLAGS)

lib/libtvm.a: $(ALL_DEP)
@mkdir -p $(@D)
ar crv $@ $(filter %.o, $?)

$(LIB_HALIDE_IR): LIBHALIDEIR

LIBHALIDEIR:
Expand Down
74 changes: 8 additions & 66 deletions include/tvm/api_registry.h
Original file line number Diff line number Diff line change
@@ -1,85 +1,27 @@
/*!
* Copyright (c) 2016 by Contributors
* Copyright (c) 2017 by Contributors
* \file api_registry.h
* \brief This file defines the TVM API registry.
*
* The API registry stores type-erased functions.
* Each registered function is automatically exposed
* to front-end language(e.g. python).
* Front-end can also pass callbacks as PackedFunc, or register
* then into the same global registry in C++.
* The goal is to mix the front-end language and the TVM back-end.
*
* \code
* // register the function as MyAPIFuncName
* TVM_REGISTER_API(MyAPIFuncName)
* .set_body([](TVMArgs args, TVMRetValue* rv) {
* // my code.
* });
* \endcode
* \brief This files include necessary headers to
* be used to register an global API function.
*/
#ifndef TVM_API_REGISTRY_H_
#define TVM_API_REGISTRY_H_

#include <dmlc/base.h>
#include <string>
#include "./base.h"
#include "./runtime/packed_func.h"
#include "./packed_func_ext.h"

namespace tvm {

/*! \brief Utility to register API. */
class APIRegistry {
public:
/*!
* \brief set the body of the function to be f
* \param f The body of the function.
*/
APIRegistry& set_body(PackedFunc f); // NOLINT(*)
/*!
* \brief set the body of the function to be f
* \param f The body of the function.
*/
APIRegistry& set_body(PackedFunc::FType f) { // NOLINT(*)
return set_body(PackedFunc(f));
}
/*!
* \brief Register a function with given name
* \param name The name of the function.
*/
static APIRegistry& __REGISTER__(const std::string& name); // NOLINT(*)

private:
/*! \brief name of the function */
std::string name_;
};
#include "./runtime/registry.h"

/*!
* \brief Get API function by name.
* \brief Register an API function globally.
* It simply redirects to TVM_REGISTER_GLOBAL
*
* \param name The name of the function.
* \return the corresponding API function.
* \note It is really PackedFunc::GetGlobal under the hood.
*/
inline PackedFunc GetAPIFunc(const std::string& name) {
return PackedFunc::GetGlobal(name);
}

#define _TVM_REGISTER_VAR_DEF_ \
static DMLC_ATTRIBUTE_UNUSED ::tvm::APIRegistry& __make_TVMRegistry_

/*!
* \brief Register API function globally.
* \code
* TVM_REGISTER_API(MyPrint)
* .set_body([](TVMArgs args, TVMRetValue* rv) {
* // my code.
* });
* \endcode
*/
#define TVM_REGISTER_API(OpName) \
DMLC_STR_CONCAT(_TVM_REGISTER_VAR_DEF_, __COUNTER__) = \
::tvm::APIRegistry::__REGISTER__(#OpName)
} // namespace tvm
#define TVM_REGISTER_API(OpName) TVM_REGISTER_GLOBAL(OpName)

#endif // TVM_API_REGISTRY_H_
42 changes: 11 additions & 31 deletions include/tvm/codegen.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
#include "./base.h"
#include "./expr.h"
#include "./lowered_func.h"
#include "./api_registry.h"
#include "./runtime/packed_func.h"


namespace tvm {
/*! \brief namespace for lowlevel IR pass and codegen */
namespace codegen {
Expand All @@ -22,41 +22,21 @@ using runtime::TVMArgs;
using runtime::TVMRetValue;

/*!
* \brief Build a stack VM function.
* \param func The LoweredFunc to be build
* \param device_funcs The additional device functions
* \return A packed function representing the func.
*/
PackedFunc BuildStackVM(
LoweredFunc func,
const std::unordered_map<LoweredFunc, PackedFunc>& device_funcs);

/*!
* \brief Build a LLVM VM function, this is still beta
* \param func The LoweredFunc to be build
* \return A packed function representing the func.
*/
PackedFunc BuildLLVM(LoweredFunc func);

/*!
* \brief Build a CUDA function with NVRTC
* \brief Build a module from array of lowered function.
* \param funcs The functions to be built.
* \param target The target to be built.
* \return The builded module.
*
* \param fsplits The LoweredFuncs to be build (after SplitHostDevice)
* The first element is the host function, followed by device functions.
* \param host_mode The host side compilation mode:
* - "stackvm": use stack vm to interpret host side code.
* \note Calls global API function "_codegen_build_" + target
*/
PackedFunc BuildNVRTC(Array<LoweredFunc> fsplits, std::string host_mode);
runtime::Module Build(const Array<LoweredFunc>& funcs,
const std::string& target);

/*!
* \brief Build a OpenCL function.
*
* \param fsplits The LoweredFuncs to be build (after SplitHostDevice)
* The first element is the host function, followed by device functions.
* \param host_mode The host side compilation mode:
* - "stackvm": use stack vm to interpret host side code.
* \param target The target to be queried.
* \return Whether target is enabled.
*/
PackedFunc BuildOpenCL(Array<LoweredFunc> fsplits, std::string host_mode);
bool TargetEnabled(const std::string& target);

} // namespace codegen
} // namespace tvm
Expand Down
19 changes: 5 additions & 14 deletions include/tvm/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,23 +120,14 @@ constexpr const char* tvm_handle_is_null = "tvm_handle_is_null";
/*!
* \brief See pesudo code
*
* int tvm_call_global(name, TVMValue* args) {
* PackedFunc f = PackedFunc::GetGlobal(name);
* f (args, type_code_of(args), len(args));
* int tvm_call_packed(name, TVMValue* args) {
* ModuleNode* env = GetCurrentEnv();
* const PackedFunc* f = env->GetFuncFromEnv(name);
* (*f)(args, type_code_of(args), len(args));
* return 0;
* }
*/
constexpr const char* tvm_call_global = "tvm_call_global";
/*!
* \brief See pesudo code
*
* int tvm_call_device(name, TVMValue* args) {
* PackedFunc df = CodeGenEnv->GetDevice(name);
* f (args, type_code_of(args), len(args));
* return 0;
* }
*/
constexpr const char* tvm_call_device = "tvm_call_device";
constexpr const char* tvm_call_packed = "tvm_call_packed";
/*!
* \brief See pesudo code
*
Expand Down
7 changes: 5 additions & 2 deletions include/tvm/ir_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -147,12 +147,15 @@ Stmt LiftAllocate(Stmt stmt);
* \param body The body of the function.
* \param name The name of the function.
* \param api_args Arguments to the function, can be either Var, or Buffer
* \param num_packed_args Number of arguments that are processed in packed form.
* \param num_unpacked_args Number of arguments that
* are processed in plain form instead of packed form.
* \return a LoweredFunc with the specified signiture.
*
* \note
* The function signiture have two cases
*
* let num_packed_args = len(api_args) - num_unpacked_args;
*
* if num_packed_args is zero:
* f(api_arg_0, api_arg_1, .., api_arg_n) where n == len(api_args)
*
Expand All @@ -167,7 +170,7 @@ Stmt LiftAllocate(Stmt stmt);
LoweredFunc MakeAPI(Stmt body,
std::string name,
Array<NodeRef> api_args,
int num_packed_args);
int num_unpacked_args);

/*!
* \brief Count number of undefined vars in f.
Expand Down
3 changes: 3 additions & 0 deletions include/tvm/lowered_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ class LoweredFuncNode : public FunctionBaseNode {
* constant Expr of given type is used.
*/
Map<Var, Expr> handle_data_type;
/*! \brief Whether this function is packed function */
bool is_packed_func{true};
/*! \brief The body statment of the function */
Stmt body;
/*! \return name of the operation */
Expand All @@ -88,6 +90,7 @@ class LoweredFuncNode : public FunctionBaseNode {
v->Visit("args", &args);
v->Visit("thread_axis", &thread_axis);
v->Visit("handle_data_type", &handle_data_type);
v->Visit("is_packed_func", &is_packed_func);
v->Visit("body", &body);
}

Expand Down
Loading