2 changes: 2 additions & 0 deletions python/tvm/relay/op/_transform.py
@@ -11,6 +11,8 @@
_reg.register_schedule("expand_dims", schedule_broadcast)
_reg.register_schedule("reshape", schedule_injective)
_reg.register_schedule("reshape_like", schedule_injective)
_reg.register_schedule("full", schedule_injective)
_reg.register_schedule("full_like", schedule_injective)
_reg.register_schedule("cast", schedule_broadcast)
_reg.register_schedule("strided_slice", schedule_injective)
_reg.register_schedule("slice_like", schedule_injective)
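For context, `schedule_injective` is TOPI's generic schedule for ops where each output element is computed independently, which is what lets these ops compile on any target once a compute is attached. A minimal sketch of the compute/schedule pair at the TOPI level, assuming the TOPI Python API of this era (`topi.full`, `topi.generic.schedule_injective`):

```python
import tvm
import topi

# Compute definition: a (2, 3) float32 tensor filled with 7.0.
out = topi.full((2, 3), "float32", 7.0)

# The injective schedule registered above, dispatched for the target.
with tvm.target.create("llvm"):
    s = topi.generic.schedule_injective([out])

# full takes no input tensors, so the built function only fills the output.
f = tvm.build(s, [out], "llvm")
buf = tvm.nd.empty((2, 3), "float32")
f(buf)  # buf now holds all 7.0
```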
23 changes: 21 additions & 2 deletions src/relay/op/tensor/transform.cc
@@ -673,6 +673,14 @@ bool FullRel(const Array<Type>& types,
return true;
}

Array<Tensor> FullCompute(const Attrs& attrs,
const Array<Tensor>& inputs,
const Type& out_type,
const Target& target) {
const auto* out_ttype = out_type.as<TensorTypeNode>();
return { topi::full(out_ttype->shape, out_ttype->dtype, inputs[0]()) };
}

Expr MakeFull(Expr fill_value,
Array<IndexExpr> shape,
DataType dtype) {
@@ -696,7 +704,9 @@ RELAY_REGISTER_OP("full")
.set_num_inputs(1)
.add_argument("fill_value", "double", "The value to fill.")
.set_support_level(3)
.add_type_rel("Full", FullRel);
.add_type_rel("Full", FullRel)
.set_attr<FTVMCompute>("FTVMCompute", FullCompute)
.set_attr<TOpPattern>("TOpPattern", kElemWise);

bool InitOpRel(const Array<Type>& types,
int num_inputs,
@@ -777,6 +787,13 @@ bool FullLikeRel(const Array<Type>& types,
return true;
}

Array<Tensor> FullLikeCompute(const Attrs& attrs,
const Array<Tensor>& inputs,
const Type& out_type,
const Target& target) {
return { topi::full_like(inputs[0], inputs[1]()) };
Contributor:

If I were you, I would CHECK the inputs size before returning.
Maybe we should have a convention for how much CHECKing we should do? I'm not sure.

Member Author:

Input size is verified in FullLikeRel, so checking it again would be redundant. Most of the new compute functions don't have these checks.
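// Editor's sketch (hedged) of the arity guard discussed above, were one
// added despite the redundancy; CHECK_EQ is the dmlc logging macro used
// throughout TVM:
//   CHECK_EQ(inputs.size(), 2) << "full_like expects (data, fill_value)";
// FullLikeRel already rejects wrong arity during type inference, so this
// would only catch direct misuse of the compute function itself.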

}

Expr MakeFullLike(Expr data,
Expr fill_value) {
static const Op& op = Op::Get("full_like");
@@ -797,7 +814,9 @@ and type as the input array.
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("fill_value", "double", "Scalar value to fill.")
.set_support_level(3)
.add_type_rel("FullLike", FullLikeRel);
.add_type_rel("FullLike", FullLikeRel)
.set_attr<FTVMCompute>("FTVMCompute", FullLikeCompute)
.set_attr<TOpPattern>("TOpPattern", kElemWise);

// where operator
bool WhereRel(const Array<Type>& types,
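Taken together, the FTVMCompute and TOpPattern registrations make these ops executable rather than type-check-only, and marking them kElemWise lets them fuse with neighboring elementwise ops. A hedged end-to-end sketch, mirroring the new tests below (single target shown; the tests iterate over `ctx_list()`):

```python
import numpy as np
import tvm
from tvm import relay

x = relay.var("x", relay.TensorType((2, 2), "float32"))
v = relay.var("v", relay.scalar_type("float32"))
# full_like feeds an elementwise add; kElemWise allows the two to fuse.
z = relay.add(relay.full_like(x, v), x)
func = relay.Function([x, v], z)

intrp = relay.create_executor("graph", ctx=tvm.cpu(0), target="llvm")
res = intrp.evaluate(func)(np.ones((2, 2), dtype="float32"), 3.0)
print(res.asnumpy())  # every element is 4.0
```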
41 changes: 39 additions & 2 deletions tests/python/relay/test_op_level3.py
@@ -293,7 +293,7 @@ def verify_split(dshape, indices_or_sections, ret_type, axis=None):
relay.ty.TensorType((d1, (d2-7), d3, d4), "float32")])),
axis=1)

def test_full():
def test_full_infer_type():
# default settings: match input dtype
x = relay.var("x", relay.TensorType((), "int8"))
y = relay.full(x, ())
@@ -308,7 +308,22 @@ def test_full():
assert yy.checked_type == relay.TensorType((1, 2), "int8")


def test_full_like():
def test_full():
def verify_full(fill_value, src_shape, dtype):
x = relay.var("x", relay.scalar_type(dtype))
z = relay.full(x, src_shape, dtype)
func = relay.Function([x], z)
ref_res = np.full(src_shape, fill_value)
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(func)(fill_value)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
verify_full(4, (1, 3, 4, 4), "int32")
verify_full(4.0, (1, 4), "float32")


def test_full_like_infer_type():
# concrete shape
base = relay.var("base", relay.TensorType((1, 2, 3), "float32"))
fill = relay.var("fill", relay.TensorType((), "float32"))
@@ -324,6 +339,26 @@ def test_full_like():
yy = relay.ir_pass.infer_type(y)
assert yy.checked_type == relay.TensorType((n, c, h, w), "float32")


def test_full_like():
def verify_full_like(base, fill_value, dtype):
x_data = np.random.uniform(low=-1, high=1, size=base).astype(dtype)
x = relay.var("x", relay.TensorType(base, dtype))
y = relay.var("y", relay.scalar_type(dtype))
z = relay.full_like(x, y)

func = relay.Function([x, y], z)
ref_res = np.full_like(x_data, fill_value)

for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(func)(x_data, fill_value)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
verify_full_like((1, 3, 4, 4), 4, "int32")
verify_full_like((1, 1), 44.0, "float32")


def test_infer_type_leaky_relu():
n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
x = relay.var("x", relay.TensorType((n, c, h, w), "float32"))
@@ -412,7 +447,9 @@ def test_infer_type_prelu():
test_reshape_like()
test_take_infer_type()
test_take()
test_full_infer_type()
test_full()
test_full_like_infer_type()
test_full_like()
test_infer_type_leaky_relu()
test_infer_type_prelu()