Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions include/tvm/relay/attrs/device_copy.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,11 @@ struct DeviceCopyAttrs : public tvm::AttrsNode<DeviceCopyAttrs> {
TVM_DECLARE_ATTRS(DeviceCopyAttrs, "relay.attrs.DeviceCopyAttrs") {
TVM_ATTR_FIELD(src_dev_type)
.describe(
"The virutal device/context type where the op copies data from.")
"The virtual device/context type where the op copies data from.")
.set_default(0);
TVM_ATTR_FIELD(dst_dev_type)
.describe(
"The virutal device/context type where the op copies data to.")
"The virtual device/context type where the op copies data to.")
.set_default(0);
}
};
Expand Down
27 changes: 27 additions & 0 deletions include/tvm/relay/attrs/memory.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,37 @@
#include <tvm/ir/attrs.h>
#include <tvm/relay/expr.h>
#include <string>
#include <vector>

namespace tvm {
namespace relay {

std::vector<TensorType> FlattenTupleType(const Type& type);
std::vector<Expr> FromTupleType(const Type& type, const Expr& expr);
Expr ToTupleType(const Type& t, const Array<Expr>& exprs);

/*!
* \brief Options for allocating storage.
*/
struct AllocStorageAttrs : public tvm::AttrsNode<AllocStorageAttrs> {
DataType dtype;
int device_id;
int device_type;

TVM_DECLARE_ATTRS(AllocStorageAttrs, "relay.attrs.AllocStorageAttrs") {
TVM_ATTR_FIELD(dtype)
.describe(
"The dtype of the tensor to allocate.")
.set_default(DataType::Float(32, 1));
TVM_ATTR_FIELD(device_id)
.describe(
"The device id on which to allocate memory.");
TVM_ATTR_FIELD(device_type)
.describe(
"The device type on which to allocate memory.");
}
};

/*!
* \brief Options for allocating tensors.
*/
Expand Down
61 changes: 59 additions & 2 deletions python/tvm/relay/op/memory/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=no-else-return,invalid-name,len-as-condition,too-many-nested-blocks
"""Operators for manipulating low-level memory."""
from __future__ import absolute_import as _abs
from . import _make
Expand All @@ -23,6 +24,9 @@ def invoke_tvm_op(func, inputs, outputs):

Parameters
----------
func : tvm.relay.Expr
The input expr.

inputs : tvm.relay.Expr
A tuple of the inputs to pass to the TVM function.

Expand Down Expand Up @@ -59,7 +63,7 @@ def alloc_tensor(storage, shape, dtype='float32', assert_shape=None):
"""
return _make.alloc_tensor(storage, shape, dtype, assert_shape)

def alloc_storage(size, alignment, dtype_hint='float32'):
def alloc_storage(size, alignment, ctx, dtype_hint='float32'):
"""Allocate a piece of tensor storage.

Parameters
Expand All @@ -76,7 +80,7 @@ def alloc_storage(size, alignment, dtype_hint='float32'):
result : tvm.relay.Expr
The alloc_storage expression.
"""
return _make.alloc_storage(size, alignment, dtype_hint)
return _make.alloc_storage(size, alignment, ctx, dtype_hint)

def shape_func(func, inputs, outputs, dependent=False):
"""Invoke the shape function of the passed function.
Expand All @@ -96,3 +100,56 @@ def shape_func(func, inputs, outputs, dependent=False):
The shape function expression.
"""
return _make.shape_func(func, inputs, outputs, dependent)

def flatten_tuple_type(ty):
"""Return a sequence of the types contained in the tuple type in order.

Parameters
----------
ty: tvm.Type
The type to flatten.

Returns
-------
result: List[tvm.Type]
The types in their linear order.
"""
return _make.FlattenTupleType(ty)

def from_tuple_type(ty, expr):
"""Convert an expression with the given type into a sequence of expressions.
Each expression maps to a field of the tuple or nested tuples in linear
order.

Parameters
----------
ty: tvm.Type
The type to unpack.

expr: tvm.relay.Expr
The expression from which to extract each sub-field.

Returns
-------
result: List[tvm.relay.Expr]
The list of sub-expressions.
"""
return _make.FromTupleType(ty, expr)

def to_tuple_type(ty, exprs):
"""Pack the sequence of expressions into the nested tuple type.

Parameters
----------
ty: tvm.Type
The type to pack with.

exprs: tvm.relay.Expr
The expressions to pack back into the nested tuple type.

Returns
-------
result: List[tvm.relay.Expr]
The packed tuple expression.
"""
return _make.ToTupleType(ty, exprs)
Loading