Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 25 additions & 3 deletions include/tvm/relay/attrs/annotation.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,37 @@ namespace tvm {
namespace relay {

/*!
* \brief Options for the device annotation operators.
* \brief Attributes for the "on_device" operator.
*
* The relay call
* \code
* on_device(expr, device_type=2)
* \endcode
* denotes that the result of \p expr should be stored on the device with \p DLDeviceType 2
* (i.e. \p kDLCuda). Semantically the operator is the identity function.
*
* See also FunctionOnDeviceAttrs in include/tvm/relay/attrs/function.h for the function-level
* companion.
*/
struct OnDeviceAttrs : public tvm::AttrsNode<OnDeviceAttrs> {
int device_type;
// TODO(mbs): Replace device types with TargetDevice.
/*! \brief Device type on which argument expression should be evaluated. */
int device_type = kInvalidDeviceType;
/*!
* \brief If true, the result device must also be \p device_type and device planning should
* not insert any "device_copy" calls to respect this annotation.
*
* This is used by the device planning pass itself when annotating the planned program.
*/
bool is_fixed = false;

TVM_DECLARE_ATTRS(OnDeviceAttrs, "relay.attrs.OnDeviceAttrs") {
TVM_ATTR_FIELD(device_type)
.describe("The virutal device/context type that an expression is annotated with.")
.describe("The type of the virtual device which should hold the expression result.")
.set_default(0);
TVM_ATTR_FIELD(is_fixed)
.describe("If true, do not insert a \"device_copy\" call to respect this annotation.")
.set_default(false);
}
};

Expand Down
1 change: 1 addition & 0 deletions include/tvm/relay/attrs/device_copy.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ namespace relay {
* \brief Options for the device copy operators.
*/
struct DeviceCopyAttrs : public tvm::AttrsNode<DeviceCopyAttrs> {
// TODO(mbs): Should be TargetDevice.
int dst_dev_type;
int src_dev_type;

Expand Down
66 changes: 66 additions & 0 deletions include/tvm/relay/attrs/function.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tvm/relay/attrs/function.h
* \brief Attributes for Relay Functions which don't make sense on PrimFuncs.
*/
#ifndef TVM_RELAY_ATTRS_FUNCTION_H_
#define TVM_RELAY_ATTRS_FUNCTION_H_

namespace tvm {
namespace relay {
/*!
* \brief Attributes for Relay function definitions which capture the devices for the
* function parameters and result.
*
* Both the parameter devices and the result device are recorded as integer device types.
*
* See also OnDeviceAttrs in include/tvm/relay/attrs/annotation.h for the companion "on_device"
* call attributes.
*/
struct FunctionOnDeviceAttrs : public tvm::AttrsNode<FunctionOnDeviceAttrs> {
/*! \brief Device type on which each of the function's arguments already resides. */
Array<Integer> param_device_types;
// TODO(mbs): Replace device types with TargetDevice.
/*! \brief Device type on which function body should be evaluated. */
int result_device_type = kInvalidDeviceType;

TVM_DECLARE_ATTRS(FunctionOnDeviceAttrs, "relay.attrs.FunctionOnDeviceAttrs") {
// No default is declared for param_device_types.
TVM_ATTR_FIELD(param_device_types)
.describe("The type of the virtual device which holds each function parameters.");
// NOTE(review): the default declared here is 0, whereas the in-class initializer of
// result_device_type above is kInvalidDeviceType -- confirm which value is intended
// when the attribute is left unset.
TVM_ATTR_FIELD(result_device_type)
.describe("The type of the virtual device which will hold the function's result.")
.set_default(0);
}
};

namespace attr {

/*!
 * \brief Attribute key for the device annotations of a function's parameters and result.
 *
 * Type: FunctionOnDeviceAttrs
 */
static constexpr const char* kFunctionAttrsKey = "on_device";

}  // namespace attr

} // namespace relay
} // namespace tvm

#endif // TVM_RELAY_ATTRS_FUNCTION_H_
3 changes: 2 additions & 1 deletion include/tvm/relay/expr_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
#include <unordered_map>
#include <utility>
#include <vector>

namespace tvm {
namespace relay {

Expand Down Expand Up @@ -227,7 +228,7 @@ class ExprMutator : public ::tvm::relay::ExprFunctor<Expr(const Expr&)> {
*
* MixedModeVisitor provides the same recursive API as ExprVisitor, and uses
* recursion to traverse most forms of the IR, but under the hood it expands nested dataflow regions
* of the graph and processes them iteratatively to prevent stack overflows
* of the graph and processes them iteratively to prevent stack overflows
*/
class MixedModeVisitor : public ::tvm::relay::ExprVisitor {
public:
Expand Down
4 changes: 2 additions & 2 deletions include/tvm/relay/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -437,8 +437,8 @@ TVM_DLL Pass RelayToTIRTargetHook();
* \brief A pass for manifesting explicit memory allocations and rewriting
* specific dialects.
*
* \param target_host The target used by the host for compliation.
* \param targets The device type and target pairs for compliation.
* \param target_host The target used by the host for compilation.
* \param targets The device type and target pairs for compilation.
*
* \return The pass.
*/
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/runtime/container/array.h
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ class ArrayNode : public Object, public InplaceArrayBase<ArrayNode, ObjectRef> {
};

/*!
* \brief Array, container representing a contigious sequence of ObjectRefs.
* \brief Array, container representing a contiguous sequence of ObjectRefs.
*
* Array implements in-place copy-on-write semantics.
*
Expand Down
26 changes: 16 additions & 10 deletions include/tvm/runtime/ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,19 @@
#include <vector>

namespace tvm {
namespace runtime {

typedef DLDevice Device;
// alias DLDevice
using Device = DLDevice;

// A 'null' device type, does not correspond to any DLDeviceType enum.
// TODO(mbs): This is to help us as we transition away from representing the 'homogeneous' case
// as a singleton target map indexed by the invalid DLDeviceType '0'.
constexpr DLDeviceType kNullDeviceType = static_cast<DLDeviceType>(0);

// An 'invalid' device type, does not correspond to any DLDeviceType enum.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should these not be added as valid enum variants to dlpack.h ? There's a risk here that dlpack will change and it'll become incompatible?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know. Note that these special values are endemic in the code (often stored in naked ints) so I'm at least not introducing a new convention. I think the argument for this approach is the enum values should always be valid at runtime and these invalid-but-distinguished values are just a compile time concept.

I see this as a stepping stone to planning with hybrid targets/devices rather than just device or device types, in which case these distinguished values can go away.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Roger roger.

constexpr DLDeviceType kInvalidDeviceType = static_cast<DLDeviceType>(-1);

namespace runtime {

/*!
* \brief Managed NDArray.
Expand Down Expand Up @@ -481,23 +491,19 @@ inline bool NDArray::Load(dmlc::Stream* strm) {
}

} // namespace runtime

// alias Device
using tvm::runtime::Device;

} // namespace tvm

namespace std {
template <>
struct hash<tvm::runtime::Device> {
std::size_t operator()(const tvm::runtime::Device& dev) const {
struct hash<tvm::Device> {
std::size_t operator()(const tvm::Device& dev) const {
return ((dev.device_id << 8) | dev.device_type);
}
};

template <>
struct equal_to<tvm::runtime::Device> {
bool operator()(const tvm::runtime::Device& lhs, const tvm::runtime::Device& rhs) const {
struct equal_to<tvm::Device> {
bool operator()(const tvm::Device& lhs, const tvm::Device& rhs) const {
return (lhs.device_type == rhs.device_type && lhs.device_id == rhs.device_id);
}
};
Expand Down
56 changes: 43 additions & 13 deletions python/tvm/relay/op/annotation/annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,32 +22,62 @@
from .. import op as reg


def on_device(data, device):
"""Annotate an expression with a certain device type.
def _device_to_int(device):
    """Converts a device, or the name of a device, to its integer device type.

    Parameters
    ----------
    device : Union[:py:class:`Device`, str]
        Either a device object, whose ``device_type`` is returned directly, or a device
        name, which is first resolved with ``_nd.device``.

    Returns
    -------
    result : int
        The ``device_type`` of the given device.

    Raises
    ------
    ValueError
        If ``device`` is neither a device object nor a string.
    """
    if isinstance(device, _Device):
        device_type = device.device_type
    elif isinstance(device, str):
        device_type = _nd.device(device).device_type
    else:
        raise ValueError("expecting a Device or device name, but received a %s" % (type(device)))
    return device_type


def on_device(data, device, is_fixed=False):
"""Annotates an expression with the device type on which its result should be stored.

Parameters
----------
data : tvm.relay.Expr
The expression to be annotated.

device : Union[:py:class:`Device`, str]
The device type to annotate.
The device to annotate with. Only the device's type is significant.

is_fixed : bool
If false (the default), a device_copy may be inserted to reconcile the device of the
data argument with the device for the context of the annotated expression.
If true, no device_copy will be inserted to respect this annotation.

Returns
-------
result : tvm.relay.Expr
The annotated expression.
"""
if isinstance(device, _Device):
device = device.device_type
elif isinstance(device, str):
device = _nd.device(device).device_type
else:
raise ValueError(
"device is expected to be the type of Device or "
"str, but received %s" % (type(device))
)
return _make.on_device(data, device)
return _make.on_device(data, _device_to_int(device), is_fixed)


def function_on_device(function, param_devices, result_device):
    """Annotates a Relay function with the device types on which its parameters and result
    should be stored.

    Parameters
    ----------
    function : tvm.relay.Function
        The function to be annotated.

    param_devices : Array[Union[:py:class:`Device`, str]]
        The devices for each parameter. Only the device types are significant.

    result_device : Union[:py:class:`Device`, str]
        The device for the function result. Only the device type is significant.

    Returns
    -------
    result : tvm.relay.Function
        The annotated function.
    """
    # Devices are passed down as plain integer device types.
    param_device_types = [_device_to_int(param_device) for param_device in param_devices]
    result_device_type = _device_to_int(result_device)
    return _make.function_on_device(function, param_device_types, result_device_type)


def stop_fusion(data):
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,7 +546,7 @@ def MergeCompilerRegions():

def RewriteAnnotatedOps(fallback_device):
"""Rewrite the annotated program where annotation operators, e.g.
`on_deivce`, mark which device an expression should be scheduled to.
`on_device`, mark which device an expression should be scheduled to.
This pass helps heterogeneous execution where different operators may need
to be allocated on various devices.

Expand Down
8 changes: 4 additions & 4 deletions python/tvm/target/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,7 +525,7 @@ def hexagon(cpu_ver="v66", **kwargs):

# LLVM target string
def create_llvm_target(cpu_ver, config):
""" Create LLVM target string. """
"""Create LLVM target string."""

target = " -mtriple=hexagon"
mcpu = " -mcpu=hexagon" + cpu_ver
Expand All @@ -547,7 +547,7 @@ def create_target_features(config):

# Simulator options string
def create_sim_options(cpu_ver, config):
""" Create simulator option string. """
"""Create simulator option string."""

def validate_hvx_length(codegen_hvx, sim_options):
if sim_options and "--hvx_length" in sim_options:
Expand Down Expand Up @@ -606,7 +606,7 @@ def validate_hvx_length(codegen_hvx, sim_options):

# LLVM options string
def create_llvm_options(cpu_ver, config): # pylint: disable=unused-argument
""" Create LLVM options string. """
"""Create LLVM options string."""

llvm_options = config["llvm_options"]

Expand All @@ -620,7 +620,7 @@ def create_llvm_options(cpu_ver, config): # pylint: disable=unused-argument

# TVM target attributes string
def create_tvm_options(cpu_ver, config): # pylint: disable=unused-argument
""" Create TVM target features string. """
"""Create TVM target features string."""

features = {
"link_params": "link-params",
Expand Down
7 changes: 5 additions & 2 deletions src/node/structural_equal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
/*!
* \file src/node/structural_equal.cc
*/
#include <tvm/ir/module.h>
#include <tvm/node/functor.h>
#include <tvm/node/node.h>
#include <tvm/node/reflection.h>
Expand Down Expand Up @@ -119,8 +120,10 @@ class RemapVarSEqualHandler : public SEqualReducer::Handler {
// Check the result.
bool CheckResult(bool result, const ObjectRef& lhs, const ObjectRef& rhs) {
if (assert_mode_ && !result) {
LOG(FATAL) << "ValueError: StructuralEqual check failed, caused by\n"
<< "lhs = " << lhs << "\nrhs = " << rhs;
LOG(FATAL) << "ValueError: StructuralEqual check failed, caused by lhs:" << std::endl
<< PrettyPrint(lhs) << std::endl
<< "and rhs:" << std::endl
<< PrettyPrint(rhs);
}
return result;
}
Expand Down
8 changes: 3 additions & 5 deletions src/relay/backend/te_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -439,11 +439,10 @@ class LowerTensorExprMutator : public ExprMutator {
}

// Non-External Relay Function
DLOG(INFO) << "lowering to target '" << target->str() << "' for primitive:\n"
<< PrettyPrint(func);
VLOG(1) << "lowering to target '" << target->str() << "' for primitive:\n" << PrettyPrint(func);
CCacheKey key = CCacheKey(func, target);
CachedFunc lowered_func = compiler_->Lower(key, module_name_);
DLOG(INFO) << "lowered primitive bound to '" << PrettyPrint(lowered_func->prim_fn_var) << "'";
VLOG(1) << "lowered primitive bound to '" << PrettyPrint(lowered_func->prim_fn_var) << "'";

// Collect all the lowered functions produced for this primitive function.
Map<GlobalVar, tir::PrimFunc> prim_fns;
Expand All @@ -452,8 +451,7 @@ class LowerTensorExprMutator : public ExprMutator {
CHECK(prim_fn.second.as<tir::PrimFuncNode>()) << "must be a prim fn";
prim_fns.Set(prim_fn.first, Downcast<tir::PrimFunc>(prim_fn.second));
all_prim_fn_vars.push_back(prim_fn.first);
DLOG(INFO) << "lowered primitive includes bindings for '" << PrettyPrint(prim_fn.first)
<< "'";
VLOG(1) << "lowered primitive includes bindings for '" << PrettyPrint(prim_fn.first) << "'";
}

// TODO(@areusch, @jroesch): this metadata is for AOT, this should be our interface for AOT
Expand Down
4 changes: 2 additions & 2 deletions src/relay/backend/vm/inline_primitives.cc
Original file line number Diff line number Diff line change
Expand Up @@ -136,13 +136,13 @@ struct PrimitiveInliner : ExprMutator {
if (n->GetAttr<String>(attr::kCompiler).defined()) continue;
auto func = GetRef<Function>(n);

DLOG(INFO) << "Before inlining primitives: " << global << std::endl << AsText(func, false);
VLOG(1) << "Before inlining primitives: " << global << std::endl << PrettyPrint(func);

func = Function(func->params, VisitExpr(func->body), func->ret_type, func->type_params,
func->attrs);
module_->Add(global, func, true);

DLOG(INFO) << "After inlining primitives: " << global << std::endl << AsText(func, false);
VLOG(1) << "After inlining primitives: " << global << std::endl << PrettyPrint(func);
}
}
return module_;
Expand Down
Loading