Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
c1ea9b9
Dynamic BroadcastTo
electriclilies Jul 7, 2020
5e1b5c5
fixed lint!
electriclilies Jul 7, 2020
8e115b9
add test_one_hot() back
Jul 7, 2020
988fe4c
add one_hot registration back
Jul 7, 2020
2e35508
Dynamic BroadcastTo
electriclilies Jul 7, 2020
0852f42
fixed lint!
electriclilies Jul 7, 2020
6977134
add one_hot registration back
Jul 7, 2020
9e11e24
fixed lint.. again
electriclilies Jul 7, 2020
40cea20
fixed lint
electriclilies Jul 7, 2020
42831f7
lint
electriclilies Jul 7, 2020
1cc1b8f
responding to comments
electriclilies Jul 7, 2020
81b7d5c
skipping cuda in dynamic test
electriclilies Jul 7, 2020
4fd8a84
skipping cuda in dynamic test
electriclilies Jul 7, 2020
da104ac
fixed i386 test and GPU test
electriclilies Jul 7, 2020
e47bca4
lint
electriclilies Jul 7, 2020
dd1ec39
starting ones and zeros
electriclilies Jul 8, 2020
c1b0303
fixed dynamic ones and zeros, wrote dyn ones and zeros test
electriclilies Jul 8, 2020
2c33802
added static version of zeros, ones and added a check for size of typ…
electriclilies Jul 8, 2020
79d7e8a
added dynamic to static pass for zeros and ones, dynamic test and dyn…
electriclilies Jul 8, 2020
f6d7765
removed op_str in dyn to static pass test
electriclilies Jul 8, 2020
2192ff9
fixed lint
electriclilies Jul 9, 2020
3f71e4d
fix lint hopefully
electriclilies Jul 9, 2020
1272a71
removed import const
electriclilies Jul 9, 2020
7ad825e
removed import that was actually used
Jul 9, 2020
9ae620e
copy all attributes from broadcast_to, ones, zeros, full
Jul 9, 2020
b9c8767
responding to comments
Jul 13, 2020
53bab6c
fixed build error
Jul 13, 2020
4c0129c
finishing rebase
Jul 14, 2020
1fc7f7f
fix lint
Jul 14, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions python/tvm/relay/_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,6 @@ def __call__(self, args, attrs, type_args):
attrs = {}
if self.operator in (op.strided_slice,):
x = self.operator(*args)
elif self.operator in (op.zeros, op.ones, op.full, op.broadcast_to):
x = self.operator(*args, dtype=attrs["dtype"])
else:
x = self.operator(*args, **{k: self.convert(v) for k, v in attrs.items()})
if isinstance(x, expr.TupleWrapper):
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/relay/op/_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@
# zeros
@register_compute("zeros")
def zeros_compute(attrs, inputs, output_type):
assert len(inputs) == 1
assert not inputs
return [topi.full(output_type.shape, output_type.dtype, 0.0)]

register_broadcast_schedule("zeros")
Expand All @@ -109,7 +109,7 @@ def zeros_like_compute(attrs, inputs, output_type):
# ones
@register_compute("ones")
def ones_compute(attrs, inputs, output_type):
assert len(inputs) == 1
assert not inputs
return [topi.full(output_type.shape, output_type.dtype, 1.0)]

register_broadcast_schedule("ones")
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relay/op/dyn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@

from . import _algorithm
from . import _transform
from . import _tensor
46 changes: 46 additions & 0 deletions python/tvm/relay/op/dyn/_tensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#pylint: disable=invalid-name, unused-argument, len-as-condition
"""Backend compiler related feature registration for dynamic ops"""

import topi

from ..op import register_shape_func, register_compute
from ..op import register_broadcast_schedule
from ..op import register_pattern, OpPattern
from .._tensor import full_shape_func, no_data_full_shape_func

# ones
@register_compute("dyn.ones")
def ones_compute(attrs, inputs, output_type):
assert len(inputs) == 1
return [topi.full(output_type.shape, output_type.dtype, 1.0)]

register_broadcast_schedule("dyn.ones")
register_pattern("dyn.ones", OpPattern.ELEMWISE)

@register_compute("dyn.zeros")
def zeros_compute(attrs, inputs, output_type):
assert len(inputs) == 1
return [topi.full(output_type.shape, output_type.dtype, 0.0)]

register_broadcast_schedule("dyn.zeros")
register_pattern("dyn.zeros", OpPattern.ELEMWISE)

register_shape_func("dyn.broadcast_to", True, full_shape_func)
register_shape_func("dyn.ones", True, no_data_full_shape_func)
register_shape_func("dyn.zeros", True, no_data_full_shape_func)
1 change: 1 addition & 0 deletions python/tvm/relay/op/dyn/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from tvm.te.hybrid import script
from .. import op as _reg

_reg.register_broadcast_schedule("dyn.broadcast_to")
_reg.register_injective_schedule("dyn.reshape")
_reg.register_broadcast_schedule("dyn.tile")

Expand Down
15 changes: 12 additions & 3 deletions python/tvm/relay/op/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
from tvm.runtime import TVMContext as _TVMContext

from . import _make
from ..expr import Tuple, const
from .dyn import _make as _dyn_make
from ..expr import Tuple, Expr


# We create a wrapper function for each operator in the
Expand Down Expand Up @@ -939,8 +940,12 @@ def zeros(shape, dtype):
result : relay.Expr
The resulting tensor.
"""
if isinstance(shape, Expr):
return _dyn_make.zeros(shape, dtype)
if isinstance(shape, int):
shape = [shape]
if isinstance(shape, (list, tuple)):
shape = const(list(shape), "int32")
shape = list(shape)
return _make.zeros(shape, dtype)


Expand Down Expand Up @@ -976,8 +981,12 @@ def ones(shape, dtype):
result : relay.Expr
The resulting tensor.
"""
if isinstance(shape, Expr):
return _dyn_make.ones(shape, dtype)
if isinstance(shape, int):
shape = [shape]
if isinstance(shape, (list, tuple)):
shape = const(list(shape), "int32")
shape = list(shape)
return _make.ones(shape, dtype)


Expand Down
6 changes: 5 additions & 1 deletion python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -661,8 +661,12 @@ def broadcast_to(data, shape):
result : relay.Expr
The resulting tensor.
"""
if isinstance(shape, Expr):
return _dyn_make.broadcast_to(data, shape)
if isinstance(shape, int):
shape = [shape]
if isinstance(shape, (list, tuple)):
shape = const(list(shape), "int32")
shape = list(shape)
return _make.broadcast_to(data, shape)

def broadcast_to_like(data, broadcast_type):
Expand Down
109 changes: 109 additions & 0 deletions src/relay/op/dyn/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,22 @@
*/
#include "transform.h"

#include <topi/broadcast.h>
#include <topi/transform.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/runtime/registry.h>

#include <utility>
#include <vector>

namespace tvm {
namespace relay {
namespace dyn {

/* relay.dyn.reshape */

bool ReshapeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
// types: [data, newshape, result]
Expand Down Expand Up @@ -195,6 +198,112 @@ RELAY_REGISTER_OP("dyn.tile")
.set_attr<FTVMCompute>("FTVMCompute", TileCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);

// broadcast_to operator
bool BroadCastToRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
// types = [data_type, broadcast_shape_type, ret_type]
CHECK_EQ(types.size(), 3);

const auto* target_shape = types[1].as<TensorTypeNode>();
DataType out_dtype = types[0].as<TensorTypeNode>()->dtype;
// rank must be static
const IntImmNode* rank = target_shape->shape[0].as<IntImmNode>();
CHECK(rank) << "Target shape must have static rank"; // rank must be static even in dyn pass
// could add support for dyn rank in futures

std::vector<IndexExpr> oshape;
for (int i = 0; i < rank->value; ++i) {
oshape.push_back(Any());
}

reporter->Assign(types[2], TensorType(oshape, out_dtype));
return true;
}

Expr MakeBroadCastTo(Expr data, Expr shape) {
static const Op& op = Op::Get("dyn.broadcast_to");
auto attrs = make_object<InitOpAttrs>();
return Call(op, {data, shape}, Attrs(attrs), {});
}

Array<te::Tensor> BroadCastToCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type) {
const auto* out_ttype = out_type.as<TensorTypeNode>();
return {topi::broadcast_to(inputs[0], out_ttype->shape)};
}

TVM_REGISTER_GLOBAL("relay.op.dyn._make.broadcast_to").set_body_typed(MakeBroadCastTo);

RELAY_REGISTER_OP("dyn.broadcast_to")
.describe(R"code(Broadcast the first input to match the shape argument.
)code" TVM_ADD_FILELINE)
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("shape", "Tensor", "Target shape.")
.set_support_level(4)
.add_type_rel("DynamicBroadCastTo", BroadCastToRel)
.set_attr<FTVMCompute>("FTVMCompute", BroadCastToCompute)
.set_attr<TOpPattern>("TOpPattern", kBroadcast);

// zeros and ones operator
bool InitOpRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
// types = [zeros_shape, ret_type]
CHECK_EQ(types.size(), 2);
const InitOpAttrs* param = attrs.as<InitOpAttrs>();
const auto* fill_shape = types[0].as<TensorTypeNode>();
DataType out_dtype = param->dtype;

const IntImmNode* shape_shape = fill_shape->shape[0].as<IntImmNode>();
CHECK(shape_shape) << "Parameter shape must have static rank";

std::vector<IndexExpr> oshape;
for (int i = 0; i < shape_shape->value; ++i) {
oshape.push_back(Any());
}

reporter->Assign(types[1], TensorType(oshape, out_dtype));
return true;
}

Expr MakeZeros(Expr shape, DataType dtype) {
auto attrs = make_object<InitOpAttrs>();
attrs->dtype = std::move(dtype);
static const Op& op = Op::Get("dyn.zeros");
return Call(op, {shape}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op.dyn._make.zeros").set_body_typed(MakeZeros);

RELAY_REGISTER_OP("dyn.zeros")
.describe(R"code(Fill array with zeros.

)code" TVM_ADD_FILELINE)
.set_attrs_type<InitOpAttrs>()
.set_num_inputs(1)
.add_argument("shape", "Tensor", "Target shape.")
.set_support_level(3)
.add_type_rel("DynamicInitOp", InitOpRel);

Expr MakeOnes(Expr shape, DataType dtype) {
auto attrs = make_object<InitOpAttrs>();
attrs->dtype = std::move(dtype);
static const Op& op = Op::Get("dyn.ones");
return Call(op, {shape}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op.dyn._make.ones").set_body_typed(MakeOnes);

RELAY_REGISTER_OP("dyn.ones")
.describe(R"code(Fill array with ones.

)code" TVM_ADD_FILELINE)
.set_attrs_type<InitOpAttrs>()
.set_num_inputs(1)
.add_argument("shape", "Tensor", "Target shape.")
.set_support_level(3)
.add_type_rel("DynamicInitOp", InitOpRel);

} // namespace dyn
} // namespace relay
} // namespace tvm
6 changes: 3 additions & 3 deletions src/relay/op/make_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
namespace tvm {
namespace relay {

Expr MakeBroadCastTo(Expr data, Expr shape);
Expr MakeBroadCastTo(Expr data, Array<Integer> shape);

Expr MakeCast(Expr data, DataType dtype);

Expand All @@ -52,7 +52,7 @@ Expr MakeFull(Expr fill_value, Expr shape, DataType dtype);

Expr MakeLayoutTransform(Expr data, String src_layout, String dst_layout);

Expr MakeOnes(Expr shape, DataType dtype);
Expr MakeOnes(Array<Integer> shape, DataType dtype);

Expr MakePad(Expr data, Array<Array<IndexExpr>> pad_width, double pad_value, String pad_mode);

Expand All @@ -76,7 +76,7 @@ Expr MakeTopK(Expr data, int k, int axis, String ret_type, bool is_ascend, DataT

Expr MakeVariance(Expr data, Expr mean, Array<Integer> axis, bool keepdims, bool exclude);

Expr MakeZeros(Expr shape, DataType dtype);
Expr MakeZeros(Array<Integer> shape, DataType dtype);

} // namespace relay
} // namespace tvm
Expand Down
Loading