diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index 8726db55f8c1..c1624028fe68 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -11,6 +11,8 @@ _reg.register_schedule("expand_dims", schedule_broadcast) _reg.register_schedule("reshape", schedule_injective) _reg.register_schedule("reshape_like", schedule_injective) +_reg.register_schedule("full", schedule_injective) +_reg.register_schedule("full_like", schedule_injective) _reg.register_schedule("cast", schedule_broadcast) _reg.register_schedule("strided_slice", schedule_injective) _reg.register_schedule("slice_like", schedule_injective) diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 689c9c9bb8d7..53741e666f38 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -673,6 +673,14 @@ bool FullRel(const Array<Type>& types, return true; } +Array<Tensor> FullCompute(const Attrs& attrs, + const Array<Tensor>& inputs, + const Type& out_type, + const Target& target) { + const auto* out_ttype = out_type.as<TensorTypeNode>(); + return { topi::full(out_ttype->shape, out_ttype->dtype, inputs[0]()) }; +} + Expr MakeFull(Expr fill_value, Array<IndexExpr> shape, DataType dtype) { @@ -696,7 +704,9 @@ RELAY_REGISTER_OP("full") .set_num_inputs(1) .add_argument("fill_value", "double", "The value to fill.") .set_support_level(3) -.add_type_rel("Full", FullRel); +.add_type_rel("Full", FullRel) +.set_attr<FTVMCompute>("FTVMCompute", FullCompute) +.set_attr<TOpPattern>("TOpPattern", kElemWise); bool InitOpRel(const Array<Type>& types, int num_inputs, @@ -777,6 +787,13 @@ bool FullLikeRel(const Array<Type>& types, return true; } +Array<Tensor> FullLikeCompute(const Attrs& attrs, + const Array<Tensor>& inputs, + const Type& out_type, + const Target& target) { + return { topi::full_like(inputs[0], inputs[1]()) }; +} + Expr MakeFullLike(Expr data, Expr fill_value) { static const Op& op = Op::Get("full_like"); @@ -797,7 +814,9 @@ and type as the input array. 
.add_argument("data", "Tensor", "The input tensor.") .add_argument("fill_value", "double", "Scalar value to fill.") .set_support_level(3) -.add_type_rel("FullLike", FullLikeRel); +.add_type_rel("FullLike", FullLikeRel) +.set_attr<FTVMCompute>("FTVMCompute", FullLikeCompute) +.set_attr<TOpPattern>("TOpPattern", kElemWise); // where operator bool WhereRel(const Array<Type>& types, diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index 99d7b4f95de5..617b532a6a1f 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -293,7 +293,7 @@ def verify_split(dshape, indices_or_sections, ret_type, axis=None): relay.ty.TensorType((d1, (d2-7), d3, d4), "float32")])), axis=1) -def test_full(): +def test_full_infer_type(): # default settings: match input dtype x = relay.var("x", relay.TensorType((), "int8")) y = relay.full(x, ()) @@ -308,7 +308,22 @@ def test_full(): assert yy.checked_type == relay.TensorType((1, 2), "int8") -def test_full_like(): +def test_full(): + def verify_full(fill_value, src_shape, dtype): + x = relay.var("x", relay.scalar_type(dtype)) + z = relay.full(x, src_shape, dtype) + func = relay.Function([x], z) + ref_res = np.full(src_shape, fill_value) + for target, ctx in ctx_list(): + for kind in ["graph", "debug"]: + intrp = relay.create_executor(kind, ctx=ctx, target=target) + op_res = intrp.evaluate(func)(fill_value) + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5) + verify_full(4, (1, 3, 4, 4), "int32") + verify_full(4.0, (1, 4), "float32") + + +def test_full_like_infer_type(): # concrete shape base = relay.var("base", relay.TensorType((1, 2, 3), "float32")) fill = relay.var("fill", relay.TensorType((), "float32")) @@ -324,6 +339,26 @@ def test_full_like(): yy = relay.ir_pass.infer_type(y) assert yy.checked_type == relay.TensorType((n, c, h, w), "float32") + +def test_full_like(): + def verify_full_like(base, fill_value, dtype): + x_data = np.random.uniform(low=-1, high=1, 
size=base).astype(dtype) + x = relay.var("x", relay.TensorType(base, dtype)) + y = relay.var("y", relay.scalar_type(dtype)) + z = relay.full_like(x, y) + + func = relay.Function([x, y], z) + ref_res = np.full_like(x_data, fill_value) + + for target, ctx in ctx_list(): + for kind in ["graph", "debug"]: + intrp = relay.create_executor(kind, ctx=ctx, target=target) + op_res = intrp.evaluate(func)(x_data, fill_value) + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5) + verify_full_like((1, 3, 4, 4), 4, "int32") + verify_full_like((1, 1), 44.0, "float32") + + def test_infer_type_leaky_relu(): n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w") x = relay.var("x", relay.TensorType((n, c, h, w), "float32")) @@ -412,7 +447,9 @@ def test_infer_type_prelu(): test_reshape_like() test_take_infer_type() test_take() + test_full_infer_type() test_full() + test_full_like_infer_type() test_full_like() test_infer_type_leaky_relu() test_infer_type_prelu()