Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions include/tvm/relay/attrs/nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -596,11 +596,13 @@ struct Conv2DTransposeAttrs : public tvm::AttrsNode<Conv2DTransposeAttrs> {
/*! \brief Attributes used in dilate operator */
struct DilateAttrs : public tvm::AttrsNode<DilateAttrs> {
Array<IndexExpr> strides;
double dilation_value;

TVM_DECLARE_ATTRS(DilateAttrs, "relay.attrs.DilateAttrs") {
TVM_ATTR_FIELD(strides)
.set_default(Array<IndexExpr>({1, 1}))
.describe("Dilation stride on each dimension, 1 means no dilation.");
TVM_ATTR_FIELD(dilation_value).set_default(0.0).describe("Value used to dilate the input.");
}
};

Expand Down
10 changes: 6 additions & 4 deletions include/tvm/topi/nn/dilate.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,19 +55,20 @@ PrimExpr all(Array<PrimExpr> args) {
}

/*!
* \brief Dilate data with zeros
* \brief Dilate data with given dilation value (0 by default).
*
* \param x The input tensor, this can have any number of
* dimensions and any layout.
* \param strides Dilation stride for each dimension. Stride 1
* means no dilation.
* \param dilation_value Value used to dilate the input.
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return The output tensor.
*/
inline Tensor dilate(const Tensor& x, Array<PrimExpr> strides, std::string name = "tensor",
std::string tag = kInjective) {
inline Tensor dilate(const Tensor& x, Array<PrimExpr> strides, double dilation_value,
std::string name = "tensor", std::string tag = kInjective) {
auto n = x->shape.size();
CHECK_EQ(n, strides.size()) << "strides size (" << strides.size()
<< ") must match dimension of x (" << n << ")";
Expand All @@ -94,7 +95,8 @@ inline Tensor dilate(const Tensor& x, Array<PrimExpr> strides, std::string name
}
if (not_zero.size() > 0) {
auto all_not_zero = all(not_zero);
return tvm::if_then_else(all_not_zero, x(index_tuple), make_const(x->dtype, 0));
return tvm::if_then_else(all_not_zero, x(index_tuple),
make_const(x->dtype, dilation_value));
}
return x(index_tuple);
},
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/op/nn/_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -668,7 +668,7 @@ def compute_cross_entropy(attrs, inputs, out_dtype):
# dilate
@reg.register_compute("nn.dilate")
def compute_dilate(attrs, inputs, out_dtype):
return [topi.nn.dilate(inputs[0], attrs.strides)]
return [topi.nn.dilate(inputs[0], attrs.strides, attrs.dilation_value)]


reg.register_broadcast_schedule("nn.dilate")
Expand Down
11 changes: 7 additions & 4 deletions python/tvm/relay/op/nn/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1549,23 +1549,26 @@ def pad(data, pad_width, pad_value=0, pad_mode="constant"):
return _make.pad(data, pad_width, pad_value, pad_mode)


def dilate(data, strides):
"""Dilate data with zeros.
def dilate(data, strides, dilation_value=0.0):
"""Dilate data with given dilation value (0 by default).

Parameters
----------
data : tvm.relay.Expr
n-D, can be any layout.

strides : <tuple of <int>
strides : tuple of <int>
Dilation stride on each dimension, 1 means no dilation.

dilation_value : int/float, optional
Value used to dilate the input.

Returns
-------
Output : tvm.relay.Expr
The computed result
"""
return _make.dilate(data, strides)
return _make.dilate(data, strides, dilation_value)


def mirror_pad(data, pad_width, mode="SYMMETRIC"):
Expand Down
9 changes: 6 additions & 3 deletions python/tvm/topi/nn/dilate.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@


@te.tag_scope(tag=tag.INJECTIVE + ",dilate")
def dilate(data, strides, name="DilatedInput"):
"""Dilate data with zeros.
def dilate(data, strides, dilation_value=0.0, name="DilatedInput"):
"""Dilate data with given dilation value (0 by default).

Parameters
----------
Expand All @@ -34,6 +34,9 @@ def dilate(data, strides, name="DilatedInput"):
strides : list / tuple of n ints
Dilation stride on each dimension, 1 means no dilation.

dilation_value : int/float, optional
Value used to dilate the input.

name : str, optional
The name prefix operators generated

Expand Down Expand Up @@ -62,7 +65,7 @@ def _dilate(*indices):
if not_zero:
not_zero = tvm.tir.all(*not_zero)
return tvm.tir.if_then_else(
not_zero, data(*index_tuple), tvm.tir.const(0.0, data.dtype)
not_zero, data(*index_tuple), tvm.tir.const(dilation_value, data.dtype)
)
return data(*index_tuple)

Expand Down
8 changes: 6 additions & 2 deletions python/tvm/topi/testing/dilate_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import numpy as np


def dilate_python(input_np, strides):
def dilate_python(input_np, strides, dilation_value=0.0):
"""Dilate operation.

Parameters
Expand All @@ -30,6 +30,9 @@ def dilate_python(input_np, strides):
strides : list / tuple of n ints
Dilation stride on each dimension, 1 means no dilation.

dilation_value : int/float, optional
Value used to dilate the input.

Returns
-------
output_np : numpy.ndarray
Expand All @@ -45,7 +48,8 @@ def dilate_python(input_np, strides):
for i in range(n):
output_size += ((input_np.shape[i] - 1) * strides[i] + 1,)
no_zero += ((range(0, output_size[i], strides[i])),)
output_np = np.zeros(shape=output_size)
output_np = np.ones(shape=output_size)
output_np = dilation_value * output_np
output_np[np.ix_(*no_zero)] = input_np

return output_np
5 changes: 3 additions & 2 deletions src/relay/op/nn/nn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -961,9 +961,10 @@ bool DilateRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
}

// Positional relay function to create dilate operator used by frontend FFI.
Expr MakeDilate(Expr data, Array<IndexExpr> strides) {
Expr MakeDilate(Expr data, Array<IndexExpr> strides, double dilation_value = 0.0) {
auto attrs = make_object<DilateAttrs>();
attrs->strides = std::move(strides);
attrs->dilation_value = std::move(dilation_value);
static const Op& op = Op::Get("nn.dilate");
return Call(op, {data}, Attrs(attrs), {});
}
Expand All @@ -972,7 +973,7 @@ TVM_REGISTER_GLOBAL("relay.op.nn._make.dilate").set_body_typed(MakeDilate);

RELAY_REGISTER_OP("nn.dilate")
.describe(R"code(
Dilate data with zeros.
Dilate data with given dilation value (0 by default).
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.add_argument("x", "1D Tensor", "Data to dilate.")
Expand Down
2 changes: 1 addition & 1 deletion src/topi/nn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ TVM_REGISTER_GLOBAL("topi.nn.batch_matmul").set_body([](TVMArgs args, TVMRetValu

/* Ops from nn/dilate.h */
TVM_REGISTER_GLOBAL("topi.nn.dilate").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = nn::dilate(args[0], args[1]);
*rv = nn::dilate(args[0], args[1], args[2]);
});

/* Ops from nn/flatten.h */
Expand Down
13 changes: 10 additions & 3 deletions tests/python/relay/test_any.py
Original file line number Diff line number Diff line change
Expand Up @@ -740,18 +740,24 @@ def test_any_pad():
verify_any_pad(any_dims(4), ((1, 0), (1, 3), (0, 2), (9, 0)), (13, 11, 3, 1))


def verify_any_dilate(data_shape, strides, static_data_shape):
def verify_any_dilate(data_shape, strides, static_data_shape, dilation_value=None):
assert len(data_shape) == len(strides)
mod = tvm.IRModule()
dtype = "float32"
data = relay.var("data", shape=data_shape, dtype=dtype)
y = relay.nn.dilate(data, strides)
if dilation_value is None:
y = relay.nn.dilate(data, strides)
else:
y = relay.nn.dilate(data, strides, dilation_value)
mod["main"] = relay.Function([data], y)
data_np = np.random.uniform(size=static_data_shape).astype(dtype)
ref_shape = tuple(
(static_data_shape[i] - 1) * strides[i] + 1 for i in range(len(static_data_shape))
)
ref_out = np.zeros(shape=ref_shape, dtype=dtype)
if dilation_value is None:
dilation_value = 0.0
ref_out = np.ones(shape=ref_shape, dtype=dtype)
ref_out = dilation_value * ref_out
ref_out[tuple(slice(None, None, strides[i]) for i in range(len(data_shape)))] = data_np
check_result([data_np], mod, ref_out)

Expand All @@ -766,6 +772,7 @@ def test_any_dilate():
verify_any_dilate(any_dims(3), (1, 1, 5), (1, 2, 3))
verify_any_dilate(any_dims(3), (3, 7, 5), (1, 2, 3))
verify_any_dilate(any_dims(4), (3, 7, 1, 5), (1, 2, 3, 4))
verify_any_dilate(any_dims(4), (3, 7, 1, 5), (1, 2, 3, 4), 1.0)


def verify_any_softmax(data_shape, axis, static_data_shape, ref_out_shape):
Expand Down
13 changes: 10 additions & 3 deletions tests/python/topi/python/test_topi_dilate.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,18 @@ def test_dilate():
target = "llvm"
ctx = tvm.cpu(0)

def _test_dilate(input_size, strides):
def _test_dilate(input_size, strides, dilation_value=None):
Input = te.placeholder((input_size))
Output = topi.nn.dilate(Input, strides)
if dilation_value is None:
Output = topi.nn.dilate(Input, strides)
else:
Output = topi.nn.dilate(Input, strides, dilation_value)
schedule = te.create_schedule(Output.op)
input_np = np.random.uniform(size=input_size).astype(Input.dtype)
output_np = tvm.topi.testing.dilate_python(input_np, strides)
if dilation_value is None:
output_np = tvm.topi.testing.dilate_python(input_np, strides)
else:
output_np = tvm.topi.testing.dilate_python(input_np, strides, dilation_value)
input_tvm = tvm.nd.array(input_np, ctx=ctx)
output_size = topi.util.get_const_tuple(Output.shape)
output_tvm = tvm.nd.array(np.zeros(shape=output_size).astype(Output.dtype), ctx=ctx)
Expand All @@ -47,6 +53,7 @@ def _test_dilate(input_size, strides):
_test_dilate((1, 32, 32, 3, 3), (2, 2, 2, 2, 2))
_test_dilate((1, 32, 32, 32, 3, 3), (1, 1, 1, 2, 2, 2))
_test_dilate((1, 32, 32, 32, 3, 3), (2, 2, 2, 1, 1, 1))
_test_dilate((1, 32, 32, 32, 3, 3), (2, 2, 2, 1, 1, 1), 1.0)


if __name__ == "__main__":
Expand Down