diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h
index 536e4145db29..bdb75c8e9f34 100644
--- a/include/tvm/relay/attrs/nn.h
+++ b/include/tvm/relay/attrs/nn.h
@@ -442,6 +442,16 @@ struct Conv2DTransposeAttrs : public tvm::AttrsNode<Conv2DTransposeAttrs> {
   }
 };
 
+/*! \brief Attributes used in dilate operator */
+struct DilateAttrs : public tvm::AttrsNode<DilateAttrs> {
+  Array<IndexExpr> strides;
+
+  TVM_DECLARE_ATTRS(DilateAttrs, "relay.attrs.DilateAttrs") {
+    TVM_ATTR_FIELD(strides).set_default(Array<IndexExpr>({1, 1}))
+        .describe("Dilation stride on each dimension, 1 means no dilation.");
+  }
+};
+
 /*! \brief Attributes used in 1D transposed convolution operator */
 struct Conv1DTransposeAttrs : public tvm::AttrsNode<Conv1DTransposeAttrs> {
   IndexExpr channels;
diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py
index 51e712869cf4..e222acbd2b1a 100644
--- a/python/tvm/relay/op/nn/_nn.py
+++ b/python/tvm/relay/op/nn/_nn.py
@@ -458,6 +458,15 @@ def compute_cross_entropy(attrs, inputs, out_dtype):
 reg.register_pattern("nn.cross_entropy", OpPattern.OPAQUE)
 
 
+# dilate
+@reg.register_compute("nn.dilate")
+def compute_dilate(attrs, inputs, out_dtype):
+    return [topi.nn.dilate(inputs[0], attrs.strides)]
+
+reg.register_broadcast_schedule("nn.dilate")
+reg.register_pattern("nn.dilate", OpPattern.INJECTIVE)
+
+
 # cross_entropy_with_logits
 @reg.register_compute("nn.cross_entropy_with_logits")
 def compute_cross_entropy_with_logits(attrs, inputs, out_dtype):
@@ -653,6 +662,21 @@ def pad_shape_func(attrs, inputs, _):
         pad_width.append(get_const_tuple(pair))
     return [_pad_shape_func(inputs[0], convert(pad_width))]
 
+@script
+def _dilate_shape_func(data_shape, strides):
+    out = output_tensor((data_shape.shape[0],), "int64")
+    for i in const_range(out.shape[0]):
+        out[i] = (data_shape[i] - 1) * strides[i] + 1
+
+    return out
+
+@reg.register_shape_func("nn.dilate", False)
+def dilate_shape_func(attrs, inputs, _):
+    """
+    Shape function for dilate op.
+    """
+    return [_dilate_shape_func(inputs[0], convert(attrs.strides))]
+
 reg.register_shape_func("nn.bias_add", False, elemwise_shape_func)
 reg.register_shape_func("nn.softmax", False, elemwise_shape_func)
 reg.register_shape_func("nn.relu", False, elemwise_shape_func)
diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py
index a126e8dcba94..745ab3bbf1b4 100644
--- a/python/tvm/relay/op/nn/nn.py
+++ b/python/tvm/relay/op/nn/nn.py
@@ -1347,6 +1347,25 @@ def pad(data,
     return _make.pad(data, pad_width, pad_value, pad_mode)
 
 
+def dilate(data, strides):
+    """Dilate data with zeros.
+
+    Parameters
+    ----------
+    data : tvm.relay.Expr
+        n-D, can be any layout.
+
+    strides : tuple of int
+        Dilation stride on each dimension, 1 means no dilation.
+
+    Returns
+    -------
+    Output : tvm.relay.Expr
+        The computed result
+    """
+    return _make.dilate(data, strides)
+
+
 def mirror_pad(data,
                pad_width,
                mode="SYMMETRIC"):
diff --git a/python/tvm/relay/op/op_attrs.py b/python/tvm/relay/op/op_attrs.py
index a47be7673830..a1c73ef41ba5 100644
--- a/python/tvm/relay/op/op_attrs.py
+++ b/python/tvm/relay/op/op_attrs.py
@@ -350,6 +350,11 @@ class Conv2DTransposeAttrs(Attrs):
     """Attributes used in Transposed Conv2D operators"""
 
 
+@tvm._ffi.register_object("relay.attrs.DilateAttrs")
+class DilateAttrs(Attrs):
+    """Attributes used in dilate operators"""
+
+
 @tvm._ffi.register_object("relay.attrs.SubPixelAttrs")
 class SubPixelAttrs(Attrs):
     """Attributes used in depth to space and space to depth operators"""
diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc
index b9ba74f9e95d..fb1351aa19c2 100644
--- a/src/relay/op/nn/nn.cc
+++ b/src/relay/op/nn/nn.cc
@@ -961,6 +961,54 @@ Do log on the data - do not accept logits.
 .add_type_rel("CrossEntropy", CrossEntropyRel);
 
 
+// relay.nn.dilate
+TVM_REGISTER_NODE_TYPE(DilateAttrs);
+
+bool DilateRel(const Array<Type>& types,
+               int num_inputs,
+               const Attrs& attrs,
+               const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 2);
+  const auto* x = types[0].as<TensorTypeNode>();
+  const DilateAttrs* param = attrs.as<DilateAttrs>();
+  if (x == nullptr) return false;
+  CHECK_EQ(x->shape.size(), param->strides.size());
+
+  std::vector<IndexExpr> oshape;
+  for (size_t i = 0; i < param->strides.size(); ++i) {
+    if (!x->shape[i].as<tir::AnyNode>()) {
+      oshape.push_back((x->shape[i] - 1) * param->strides[i] + 1);
+    } else {
+      oshape.push_back(x->shape[i]);
+    }
+  }
+
+  reporter->Assign(types[1], TensorType(Array<IndexExpr>(oshape), x->dtype));
+  return true;
+}
+
+// Positional relay function to create dilate operator used by frontend FFI.
+Expr MakeDilate(Expr data, Array<IndexExpr> strides) {
+  auto attrs = make_object<DilateAttrs>();
+  attrs->strides = std::move(strides);
+  static const Op& op = Op::Get("nn.dilate");
+  return Call(op, {data}, Attrs(attrs), {});
+}
+
+
+TVM_REGISTER_GLOBAL("relay.op.nn._make.dilate")
+.set_body_typed(MakeDilate);
+
+
+RELAY_REGISTER_OP("nn.dilate")
+.describe(R"code(
+Dilate data with zeros.
+)code" TVM_ADD_FILELINE)
+.set_num_inputs(1)
+.add_argument("x", "nD Tensor", "Data to dilate.")
+.set_support_level(10)
+.add_type_rel("Dilate", DilateRel);
+
+
 // Positional relay function to create cross_entropy_with_logits operator used by frontend FFI.
 Expr MakeCrossEntropyWithLogits(Expr predictions, Expr targets) {
   static const Op& op = Op::Get("nn.cross_entropy_with_logits");
diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py
index aa81e3113b7f..6ce59bbf1c36 100644
--- a/tests/python/relay/test_any.py
+++ b/tests/python/relay/test_any.py
@@ -508,6 +508,34 @@ def test_any_pad():
     verify_any_pad(any_dims(3), ((0, 0), (1, 1), (2, 2)), (1, 2, 3))
     verify_any_pad(any_dims(4), ((1, 0), (1, 3), (0, 2), (9, 0)), (13, 11, 3, 1))
 
+def verify_any_dilate(data_shape, strides, static_data_shape):
+    assert len(data_shape) == len(strides)
+    mod = tvm.IRModule()
+    dtype = "float32"
+    data = relay.var('data', shape=data_shape, dtype=dtype)
+    y = relay.nn.dilate(data, strides)
+    mod["main"] = relay.Function([data], y)
+    data_np = np.random.uniform(size=static_data_shape).astype(dtype)
+    ref_shape = tuple((static_data_shape[i] - 1) * strides[i] + 1
+                      for i in range(len(static_data_shape)))
+    ref_out = np.zeros(shape=ref_shape, dtype=dtype)
+    ref_out[tuple(slice(None, None, strides[i]) for i in range(len(data_shape)))] = data_np
+
+    for kind in ["debug", "vm"]:
+        ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
+        result = ex.evaluate()(data_np)
+        tvm.testing.assert_allclose(result.asnumpy(), ref_out)
+
+def test_any_dilate():
+    verify_any_dilate(any_dims(1), (1,), (1,))
+    verify_any_dilate(any_dims(1), (1,), (5,))
+    verify_any_dilate(any_dims(1), (5,), (5,))
+    verify_any_dilate(any_dims(3), (1, 1, 1), (1, 2, 3))
+    verify_any_dilate(any_dims(3), (1, 1, 2), (1, 2, 3))
+    verify_any_dilate(any_dims(3), (1, 1, 5), (1, 2, 3))
+    verify_any_dilate(any_dims(3), (3, 7, 5), (1, 2, 3))
+    verify_any_dilate(any_dims(4), (3, 7, 1, 5), (1, 2, 3, 4))
+
 def verify_any_softmax(data_shape, axis, static_data_shape, ref_out_shape):
     mod = tvm.IRModule()
     dtype = "float32"
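
Usage sketch (not part of the patch): a minimal static-shape example of the new op, written against the same executor API the tests above use (`relay.create_executor(..., ctx=tvm.cpu(), target="llvm")`). The shapes and input values here are illustrative only.

```python
import numpy as np
import tvm
from tvm import relay

dtype = "float32"
data = relay.var("data", shape=(2, 3), dtype=dtype)
y = relay.nn.dilate(data, (2, 2))

mod = tvm.IRModule()
mod["main"] = relay.Function([data], y)

# Input values land at every stride-th index; all other positions are zero.
data_np = np.arange(6, dtype=dtype).reshape(2, 3)
ex = relay.create_executor("debug", mod=mod, ctx=tvm.cpu(), target="llvm")
result = ex.evaluate()(data_np)

# Each output dimension is (d - 1) * stride + 1, so (2, 3) -> (3, 5),
# matching both DilateRel and _dilate_shape_func above.
print(result.asnumpy().shape)  # (3, 5)
```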