Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/langref/relay_op.rst
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ This level enables additional math and transform operators.

tvm.relay.zeros
tvm.relay.nn.leaky_relu
tvm.relay.nn.prelu
tvm.relay.zeros_like
tvm.relay.ones
tvm.relay.ones_like
Expand Down Expand Up @@ -183,6 +184,7 @@ Level 2 Definitions
Level 3 Definitions
-------------------
.. autofunction:: tvm.relay.nn.leaky_relu
.. autofunction:: tvm.relay.nn.prelu
.. autofunction:: tvm.relay.floor
.. autofunction:: tvm.relay.ceil
.. autofunction:: tvm.relay.trunc
Expand Down
11 changes: 11 additions & 0 deletions include/tvm/relay/attrs/nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,17 @@ struct LeakyReluAttrs : public tvm::AttrsNode<LeakyReluAttrs> {
};


/*! \brief Attributes for prelu operator */
struct PReluAttrs : public tvm::AttrsNode<PReluAttrs> {
int axis;

TVM_DECLARE_ATTRS(PReluAttrs, "relay.attrs.PReluAttrs") {
TVM_ATTR_FIELD(axis).set_default(1)
.describe("Specify which shape axis the channel is specified.");
}
};


/*! \brief Attributes used in dropout operator */
struct DropoutAttrs : public tvm::AttrsNode<DropoutAttrs> {
double rate;
Expand Down
1 change: 1 addition & 0 deletions include/tvm/relay/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,7 @@ class TypeReporterNode : public Node {
TVM_DLL virtual void Assign(const Type& dst, const Type& src) = 0;
/*!
* \brief assert shape expression comparison.
* \note Use assert only if any of the condition input is symbolic.
* \param cond The condition of operation.
* \return false if assertation can be proven to have failed
* true if solver can still proceed.
Expand Down
27 changes: 27 additions & 0 deletions python/tvm/relay/op/nn/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,6 +528,33 @@ def leaky_relu(data, alpha):
return _make.leaky_relu(data, alpha)


def prelu(data, alpha, axis=1):
"""This operator takes data as input and does Leaky version
of a Rectified Linear Unit.

.. math::

`y = x > 0 ? x : alpha * x`

Parameters
----------
data : tvm.relay.Expr
The input data to the operator.

alpha : tvm.relay.Expr
Slope coefficient for the negative half axis.

axis : int, optional
Specify which shape axis the channel is specified.

Returns
-------
result : tvm.relay.Expr
The computed result.
"""
return _make.prelu(data, alpha, axis)


def pad(data,
pad_width,
pad_value=0.0):
Expand Down
56 changes: 56 additions & 0 deletions src/relay/op/nn/nn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,62 @@ RELAY_REGISTER_OP("nn.leaky_relu")
.add_type_rel("Identity", IdentityRel);


TVM_REGISTER_NODE_TYPE(PReluAttrs);

bool PReluRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 3);
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) return false;

const PReluAttrs* param = attrs.as<PReluAttrs>();
CHECK(param != nullptr);

CHECK(param->axis < static_cast<int>(data->shape.size()))
<< "Wrong axis (" << param->axis << ")value.";

// assign alpha type
Array<IndexExpr> alpha_shape({data->shape[param->axis]});
reporter->Assign(types[1], TensorTypeNode::make(alpha_shape, data->dtype));

// assign output type
reporter->Assign(types[2], TensorTypeNode::make(data->shape, data->dtype));
return true;
}

// Positional relay function to create prelu operator used by frontend FFI.
Expr MakePRelu(Expr data,
Expr alpha,
int axis) {
auto attrs = make_node<PReluAttrs>();
attrs->axis = axis;
static const Op& op = Op::Get("nn.prelu");
return CallNode::make(op, {data, alpha}, Attrs(attrs), {});
}


TVM_REGISTER_API("relay.op.nn._make.prelu")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 3>(MakePRelu, args, rv);
});


RELAY_REGISTER_OP("nn.prelu")
.describe(R"code(Parametric version of a Rectified Linear Unit.
It accepts two arguments: an input ``x`` and a channelwise slope ``alpha``
and computes the output as :math:`PReLU(x) y = x > 0 ? x : alpha * x`,
where :math:`*` is an channelwise multiplication for each sample in the batch.
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.PReluAttrs")
.set_num_inputs(2)
.add_argument("data", "Tensor", "Input data.")
.add_argument("alpha", "Tensor", "Input channelwise alpha.")
.set_support_level(3)
.add_type_rel("PRelu", PReluRel);


TVM_REGISTER_API("relay.op.nn._make.softmax")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
auto make_func = [](Expr data, int axis) {
Expand Down
39 changes: 33 additions & 6 deletions tests/python/relay/test_op_level3.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,13 +188,39 @@ def test_full_like():
assert yy.checked_type == relay.TensorType((n, c, h, w), "float32")

def test_infer_type_leaky_relu():
n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
x = relay.var("x", relay.TensorType((n, c, h, w), "float32"))
y = relay.nn.leaky_relu(x, alpha=0.1)
"alpha=0.1" in y.astext()
yy = relay.ir_pass.infer_type(y)
assert yy.checked_type == relay.TensorType((n, c, h, w), "float32")
n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
x = relay.var("x", relay.TensorType((n, c, h, w), "float32"))
y = relay.nn.leaky_relu(x, alpha=0.1)
"alpha=0.1" in y.astext()
yy = relay.ir_pass.infer_type(y)
assert yy.checked_type == relay.TensorType((n, c, h, w), "float32")

def verify_infer_type_prelu(data, alpha, axis, output, dtype="float32"):
x = relay.var("data", relay.TensorType(data, dtype))
if alpha:
y = relay.var("alpha", relay.TensorType(alpha, dtype))
else:
y = relay.var("alpha", relay.IncompleteType())
z = relay.nn.prelu(x, y, axis=axis)
zz = relay.ir_pass.infer_type(z)
if axis != 1:
assert "axis" in z.astext()
assert zz.checked_type == relay.ty.TensorType(output, dtype)
if not alpha:
axis = axis if axis else 1
alpha_shape = (data[axis],)
assert zz.args[1].checked_type == relay.TensorType(alpha_shape, "float32")

def test_infer_type_prelu():
n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
verify_infer_type_prelu((n, c, h, w), (c,), 1, (n, c, h, w))
verify_infer_type_prelu((n, h, w, c), (c,), 3, (n, h, w, c))
verify_infer_type_prelu((n, c, h, w), None, 1, (n, c, h, w))
verify_infer_type_prelu((n, h, w, c), None, 3, (n, h, w, c))
verify_infer_type_prelu((1, 3, 2, 2), (3,), 1, (1, 3, 2, 2))
verify_infer_type_prelu((1, 2, 2, 3), (3,), 3, (1, 2, 2, 3))
verify_infer_type_prelu((1, 3, 2, 2), None, 1, (1, 3, 2, 2))
verify_infer_type_prelu((1, 2, 2, 3), None, 3, (1, 2, 2, 3))

if __name__ == "__main__":
test_cast()
Expand All @@ -208,6 +234,7 @@ def test_infer_type_leaky_relu():
test_full()
test_full_like()
test_infer_type_leaky_relu()
test_infer_type_prelu()
test_squeeze_infer_type()
test_squeeze_bad_axes_infer_type()
test_split_infer_type()