2 changes: 1 addition & 1 deletion dlpack
Submodule dlpack updated 1 file
+19 −2 include/dlpack/dlpack.h
33 changes: 33 additions & 0 deletions include/tvm/runtime/device_api.h
@@ -36,6 +36,39 @@ constexpr int kTempAllocaAlignment = 64;
/*! \brief Maximum size that can be allocated on stack */
constexpr int kMaxStackAlloca = 1024;

/*! \brief The default device allocated to an operator */
constexpr DLDeviceType kDLDefaultDevice = kDLOpenCL;

struct DLDeviceTypeHash {
template <typename T> int operator()(T dev) const {
return static_cast<int>(dev);
}
};

/*!
 * \brief The name of the Device API factory for a given device type.
 * \param type The device type.
 * \return The device name, e.g. "cpu" or "opencl".
 */
inline std::string DeviceName(int type) {
switch (type) {
case kDLCPU: return "cpu";
case kDLGPU: return "gpu";
case kDLOpenCL: return "opencl";
case kDLSDAccel: return "sdaccel";
case kDLAOCL: return "aocl";
case kDLVulkan: return "vulkan";
case kDLMetal: return "metal";
case kDLVPI: return "vpi";
case kDLROCM: return "rocm";
case kOpenGL: return "opengl";
case kExtDev: return "ext_dev";
default: {
LOG(FATAL) << "unknown type =" << type;
return "Unknown";
}
}
}

/*!
* \brief TVM Runtime Device API, abstracts the device
* specific interface for memory management.
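As a quick illustration of how the two new helpers fit together, here is a minimal sketch (not part of the patch) that keys an unordered_map on DLDeviceType via DLDeviceTypeHash and prints each entry with DeviceName; the byte budgets are made-up numbers.

#include <tvm/runtime/device_api.h>
#include <cstddef>
#include <iostream>
#include <unordered_map>

int main() {
  // Hypothetical per-device byte budgets, keyed by DLDeviceType. The hash
  // functor added above makes DLDeviceType usable as an unordered_map key.
  std::unordered_map<DLDeviceType, std::size_t, tvm::runtime::DLDeviceTypeHash>
      budget{{kDLCPU, 1 << 30}, {kDLOpenCL, 1 << 28}};
  for (const auto& kv : budget) {
    std::cout << tvm::runtime::DeviceName(kv.first) << ": " << kv.second
              << " bytes\n";
  }
  return 0;
}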
18 changes: 18 additions & 0 deletions nnvm/include/nnvm/c_api.h
@@ -381,6 +381,24 @@ NNVM_DLL int NNGraphApplyPasses(GraphHandle src,
const char** pass_names,
GraphHandle *dst);

/*!
* \brief Apply graph annotation pass to the graph.
* \param src The source graph handle.
* \param num_ops The number of operators.
* \param op_names The name of each operator.
* \param out Graph handle of the updated graph.
* \return 0 when success, -1 when failure happens
*/
NNVM_DLL int NNAnnotateGraph(GraphHandle src, nn_uint num_ops,
const char** op_names, GraphHandle* out);
/*!
* \brief Check if a graph has a certain attribute.
* \param handle The source graph handle.
* \param key The name of the attribute to check.
 * \param has Set to 1 when the graph has the given attribute, 0 otherwise.
* \return 0 when success, -1 when failure happens
*/
NNVM_DLL int NNGraphHasJSONAttr(GraphHandle handle, const char *key, int *has);
#ifdef __cplusplus
} /* end extern "C" */
#endif
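A hedged usage sketch for the two new C API entry points (not part of the patch): `graph` is assumed to be a valid GraphHandle obtained from the existing NNVM C API, and the "device" attribute key is only an illustrative guess.

#include <nnvm/c_api.h>
#include <cstdio>

// Sketch: run the annotation pass for the listed operators, then check the
// result for a (hypothetical) "device" JSON attribute.
int AnnotateExample(GraphHandle graph) {
  const char* op_names[] = {"conv2d", "dense"};
  GraphHandle annotated = nullptr;
  if (NNAnnotateGraph(graph, 2, op_names, &annotated) != 0) {
    std::printf("annotation failed: %s\n", NNGetLastError());
    return -1;
  }
  int has = 0;
  if (NNGraphHasJSONAttr(annotated, "device", &has) != 0) {
    return -1;
  }
  std::printf("graph has 'device' attribute: %d\n", has);
  return 0;
}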
124 changes: 124 additions & 0 deletions nnvm/include/nnvm/graph_annotate.h
@@ -0,0 +1,124 @@
/*!
* Copyright (c) 2018 by Contributors
* \file graph_annotate.h
 * \brief Define rules to annotate a graph. The annotation rules/properties are
 * implemented similarly to the selection of subgraph nodes in MXNet.
*/
#ifndef NNVM_GRAPH_ANNOTATE_H_
#define NNVM_GRAPH_ANNOTATE_H_

#include <dlpack/dlpack.h>
#include <nnvm/graph.h>
#include <nnvm/op_attr_types.h>
#include <tvm/runtime/device_api.h>

#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

namespace nnvm {
namespace op {

// TODO(chzhi) Use a config file to store the operator whitelist.
// static constexpr const char* kOpDeviceConfigFile = "op_device_config.json";

/*
 * This class provides criteria for annotating nodes in a graph. It is an
 * abstract class that other classes can derive from to implement different
 * annotation rules. Rules could be based simply on a vendor-provided
 * whitelist (e.g. which node is best fitted to which device), or on a more
 * intelligent scheme where a dedicated algorithm guides the annotation
 * (future work).
 */
class AnnotationOpSelector {
public:
AnnotationOpSelector() = default;
virtual ~AnnotationOpSelector() {}
// Determine the device that the node should be scheduled to. This is a pure
// virtual function that is implemented by derived classes to realize
// different annotation strategies.
virtual DLDeviceType Select(const nnvm::Node* n) const = 0;
};

using AnnotationOpSelectorPtr = std::shared_ptr<AnnotationOpSelector>;

/*!
* \brief This provides a set of properties to annotate nodes.
*/
class AnnotationOpProperty {
using op_device_map_t_ =
std::unordered_map<DLDeviceType, std::unordered_set<std::string>,
tvm::runtime::DLDeviceTypeHash>;

public:
AnnotationOpProperty() = default;

// Create the rule to annotate a graph.
virtual AnnotationOpSelectorPtr CreateAnnotationOpSelector() const = 0;

private:
op_device_map_t_ op_device_map_;
};

using AnnotationOpPropertyPtr = std::shared_ptr<AnnotationOpProperty>;

/*
 * This selector returns a suitable device for a node depending on whether the
 * node's operator is in the given set.
 */
class ContainOpSelector : public AnnotationOpSelector {
public:
explicit ContainOpSelector(
std::shared_ptr<const std::unordered_set<std::string>> op_names) {
op_names_ = op_names;
}

// TODO(chzhi) Make a config file containing <op name, device name> pairs.
// Currently, we assume the default device is OpenCL when heterogeneous
// execution is invoked. Users can assign some operators to the CPU at the
// Python frontend for annotation. In the future, make the fallback device
// (CPU) the default and annotate according to the whitelist.
DLDeviceType Select(const nnvm::Node* n) const final {
if (n->is_variable()) return tvm::runtime::kDLDefaultDevice;

if (op_names_->count(n->op()->name)) return kDLCPU;

// Inference simplification will unpack batch_norm into an array of ops
// starting with "batch_norm". All these operators should be annotated with
// the same device as the bn operator. For example, all unpacked nodes are
// annotated with CPU if batch_norm is specified to be scheduled to CPU.
if (n->attrs.name.rfind("batch_norm") == 0 &&
op_names_->count("batch_norm")) {
return kDLCPU;
}

return tvm::runtime::kDLDefaultDevice;
}

private:
std::shared_ptr<const std::unordered_set<std::string>> op_names_;
};

/*
 * This default property creates a selector that matches nodes whose operators
 * are in a given set.
 */
class DefaultAnnotationOpProperty : public AnnotationOpProperty {
public:
explicit DefaultAnnotationOpProperty(
const std::unordered_set<std::string>& op_names)
: op_names_(std::make_shared<std::unordered_set<std::string>>(op_names)) {
}

virtual AnnotationOpSelectorPtr CreateAnnotationOpSelector() const {
return std::make_shared<ContainOpSelector>(op_names_);
}

private:
std::shared_ptr<const std::unordered_set<std::string>> op_names_;
};

} // namespace op
} // namespace nnvm

#endif // NNVM_GRAPH_ANNOTATE_H_
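For reference, a small sketch (assumptions, not part of the patch) showing how these classes compose: a DefaultAnnotationOpProperty built from a CPU whitelist yields a selector whose Select() decides the device for a node.

#include <nnvm/graph_annotate.h>

// Sketch: decide where a single node should run. Operators in the whitelist
// fall back to the CPU; everything else keeps kDLDefaultDevice.
DLDeviceType PickDevice(const nnvm::Node* n) {
  std::unordered_set<std::string> cpu_ops{"softmax", "batch_norm"};
  nnvm::op::DefaultAnnotationOpProperty property(cpu_ops);
  nnvm::op::AnnotationOpSelectorPtr selector =
      property.CreateAnnotationOpSelector();
  return selector->Select(n);
}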
1 change: 1 addition & 0 deletions nnvm/include/nnvm/layout.h
@@ -441,6 +441,7 @@ class Layout {
LOG(FATAL) << "Invalid layout " << layout;
}
}
layout_simplified_.size();
CHECK(!layout_simplified_.empty()) << "Invalid layout " << layout;
for (LayoutDim dim : layout_simplified_) {
CHECK(is_superdim(dim) || superdim_pos_[dim-'a'] >= 0)
13 changes: 11 additions & 2 deletions nnvm/include/nnvm/node.h
@@ -6,13 +6,16 @@
#ifndef NNVM_NODE_H_
#define NNVM_NODE_H_

#include <tvm/runtime/device_api.h>

#include <memory>
#include <string>
#include <vector>
#include <unordered_map>
#include <vector>

#include "base.h"
#include "op.h"
#include "c_api.h"
#include "op.h"

namespace nnvm {

@@ -106,6 +109,12 @@ struct NodeAttrs {
* stateful operators.
*/
std::vector<std::shared_ptr<Symbol> > subgraphs;
/*!
 * \brief Device information of the node. It indicates the device on which
 * this node should be executed. By default, a node is assigned the default
 * device, kDLDefaultDevice.
 */
DLDeviceType device{tvm::runtime::kDLDefaultDevice};
};

/*!
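A brief sketch (not part of the patch) of how the new `device` field might be consumed after annotation, counting nodes per device with the existing DFSVisit traversal.

#include <nnvm/graph.h>
#include <tvm/runtime/device_api.h>
#include <iostream>
#include <unordered_map>

// Sketch: report how many nodes were assigned to each device.
void ReportDevicePlacement(const nnvm::Graph& g) {
  std::unordered_map<DLDeviceType, int, tvm::runtime::DLDeviceTypeHash> counts;
  nnvm::DFSVisit(g.outputs, [&counts](const nnvm::NodePtr& n) {
    ++counts[n->attrs.device];
  });
  for (const auto& kv : counts) {
    std::cout << tvm::runtime::DeviceName(kv.first) << ": " << kv.second
              << " node(s)\n";
  }
}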
2 changes: 1 addition & 1 deletion nnvm/python/nnvm/compiler/__init__.py
@@ -9,7 +9,7 @@
import tvm

from . import build_module
from . build_module import build, optimize, build_config
from . build_module import build, optimize, build_config, build_heterogeneous
from . compile_engine import engine, graph_key
from . param_dict import save_param_dict, load_param_dict
