Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions include/tvm/relay/attrs/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,16 @@ struct UniqueAttrs : public tvm::AttrsNode<UniqueAttrs> {
}
}; // struct UniqueAttrs

/*! \brief Attributes used in segment_max, segment_min,
segment_mean, segment_sum, segment_prod operator */
struct SegmentAttrs : public tvm::AttrsNode<SegmentAttrs> {
int num_segments;

TVM_DECLARE_ATTRS(SegmentAttrs, "relay.attrs.SegmentAttrs") {
TVM_ATTR_FIELD(num_segments).set_default(0).describe("The maximum of segment_ids.");
}
}; // struct SegmentAttrs

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_TRANSFORM_H_
43 changes: 43 additions & 0 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -912,6 +912,22 @@ def _impl(inputs, attr, params, mod):
return _impl


def _unsorted_segment(name):
def _impl(inputs, attr, params, mod):
# op description: https://www.tensorflow.org/api_docs/python/tf/math/unsorted_segment_max
try:
num_segments = _infer_value(inputs[2], params).asnumpy().tolist()
except Exception:
raise tvm.error.OpAttributeInvalid("Can't find num_segments.")
return AttrCvt(
op_name="segment_" + name,
ignores=["Tdim", "Tidx", "Tindices", "Tnumsegments"],
extras={"num_segments": num_segments},
)([inputs[0], inputs[1]], attr)

return _impl


def _crop_and_resize():
def _impl(inputs, attr, params, mod):
# input image is a 4-D tensor of shape [batch, image_height, image_width, depth]
Expand Down Expand Up @@ -2617,6 +2633,23 @@ def _impl(inputs, attr, params, mod):
return _impl


def _segment(opname):
def _impl(inputs, attr, params, mod):
# op description: https://www.tensorflow.org/api_docs/python/tf/math/segment_max
try:
segment_ids = _infer_value(inputs[1], params)
except Exception:
raise tvm.error.OpAttributeInvalid("Can't get value of segment_ids.")

num_out = segment_ids.asnumpy().max() + 1
out = AttrCvt(op_name=opname, ignores=["T", "Tindices"], extras={"num_segments": num_out})(
inputs, attr
)
return out

return _impl


def _size():
def _impl(inputs, attr, params, mod):
new_attr = attr
Expand Down Expand Up @@ -2864,6 +2897,11 @@ def _impl(inputs, attr, params, mod):
"SelectV2": _where(),
"Selu": _selu(),
"Shape": _shape(),
"SegmentMax": _segment("segment_max"),
"SegmentMean": _segment("segment_mean"),
"SegmentMin": _segment("segment_min"),
"SegmentProd": _segment("segment_prod"),
"SegmentSum": _segment("segment_sum"),
"Sigmoid": AttrCvt("sigmoid"),
"Sign": AttrCvt("sign"),
"Sin": AttrCvt("sin"),
Expand Down Expand Up @@ -2915,6 +2953,11 @@ def _impl(inputs, attr, params, mod):
"UniqueWithCounts": _unique(True),
"Unpack": _unpack(),
"UnravelIndex": _unravel_index(),
"UnsortedSegmentMax": _unsorted_segment("max"),
"UnsortedSegmentMin": _unsorted_segment("min"),
"UnsortedSegmentMean": _unsorted_segment("mean"),
"UnsortedSegmentProd": _unsorted_segment("prod"),
"UnsortedSegmentSum": _unsorted_segment("sum"),
"Where": _where(),
"ZerosLike": AttrCvt("zeros_like"),
}
Expand Down
23 changes: 22 additions & 1 deletion python/tvm/relay/op/_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
from tvm import topi
from tvm.runtime import convert

from .op import register_compute, register_shape_func
from . import strategy
from .op import register_compute, register_shape_func, register_strategy
from .op import register_broadcast_schedule, register_injective_schedule
from .op import register_pattern, OpPattern

Expand Down Expand Up @@ -283,3 +284,23 @@ def elemwise_shape_func(attrs, inputs, _):
register_shape_func("sigmoid", False, elemwise_shape_func)
register_shape_func("tanh", False, elemwise_shape_func)
register_shape_func("logical_not", False, elemwise_shape_func)

# segment_max
register_strategy("segment_max", strategy.segment_max_strategy)
register_pattern("segment_max", OpPattern.OPAQUE)

# segment_min
register_strategy("segment_min", strategy.segment_min_strategy)
register_pattern("segment_min", OpPattern.OPAQUE)

# segment_mean
register_strategy("segment_mean", strategy.segment_mean_strategy)
register_pattern("segment_mean", OpPattern.OPAQUE)

# segment_sum
register_strategy("segment_sum", strategy.segment_sum_strategy)
register_pattern("segment_sum", OpPattern.OPAQUE)

# segment_prod
register_strategy("segment_prod", strategy.segment_prod_strategy)
register_pattern("segment_prod", OpPattern.OPAQUE)
5 changes: 5 additions & 0 deletions python/tvm/relay/op/op_attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,11 @@ class ProposalAttrs(Attrs):
"""Attributes used in proposal operators"""


@tvm._ffi.register_object("relay.attrs.SegmentAttrs")
class SegmentAttrs(Attrs):
"""Attributes used in segment operators"""


@tvm._ffi.register_object("relay.attrs.MaxPool2DAttrs")
class MaxPool2DAttrs(Attrs):
"""Attributes used in max_pool2d operators"""
Expand Down
115 changes: 115 additions & 0 deletions python/tvm/relay/op/strategy/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1540,6 +1540,29 @@ def uniform_strategy(attrs, inputs, out_type, target):
return strategy


# segment_max
def wrap_compute_segment_max(topi_compute):
"""wrap segment_max topi compute"""

def _compute_segment_max(attrs, inputs, out_type):
num_segments = attrs.num_segments
return [topi_compute(inputs[0], inputs[1], num_segments, "max")]

return _compute_segment_max


@override_native_generic_func("segment_max_strategy")
def segment_max_strategy(attrs, inputs, out_type, target):
"""segment_max generic strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_segment_max(topi.segment_op),
wrap_topi_schedule(topi.generic.schedule_segment_op),
name="segment_max.generic",
)
return strategy


def wrap_compute_scanop(topi_compute):
"""Wrap scanop style topi compute"""

Expand All @@ -1561,6 +1584,29 @@ def cumsum_strategy(attrs, inputs, out_type, target):
return strategy


# segment_min
def wrap_compute_segment_min(topi_compute):
"""wrap segment_min topi compute"""

def _compute_segment_min(attrs, inputs, out_type):
num_segments = attrs.num_segments
return [topi_compute(inputs[0], inputs[1], num_segments, "min")]

return _compute_segment_min


@override_native_generic_func("segment_min_strategy")
def segment_min_strategy(attrs, inputs, out_type, target):
"""segment_min generic strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_segment_min(topi.segment_op),
wrap_topi_schedule(topi.generic.schedule_segment_op),
name="segment_min.generic",
)
return strategy


@override_native_generic_func("cumprod_strategy")
def cumprod_strategy(attrs, inputs, out_type, target):
"""cumprod generic strategy"""
Expand All @@ -1573,6 +1619,29 @@ def cumprod_strategy(attrs, inputs, out_type, target):
return strategy


# segment_mean
def wrap_compute_segment_mean(topi_compute):
"""wrap segment_mean topi compute"""

def _compute_segment_mean(attrs, inputs, out_type):
num_segments = attrs.num_segments
return [topi_compute(inputs[0], inputs[1], num_segments, "mean")]

return _compute_segment_mean


@override_native_generic_func("segment_mean_strategy")
def segment_mean_strategy(attrs, inputs, out_type, target):
"""segment_mean generic strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_segment_mean(topi.segment_op),
wrap_topi_schedule(topi.generic.schedule_segment_op),
name="segment_mean.generic",
)
return strategy


def wrap_compute_unique(topi_compute):
"""Wrap unique topi compute"""

Expand All @@ -1594,8 +1663,54 @@ def unique_strategy(attrs, inputs, out_type, target):
return strategy


# segment_sum
def wrap_compute_segment_sum(topi_compute):
"""wrap segment_sum topi compute"""

def _compute_segment_sum(attrs, inputs, out_type):
num_segments = attrs.num_segments
return [topi_compute(inputs[0], inputs[1], num_segments, "sum")]

return _compute_segment_sum


@override_native_generic_func("segment_sum_strategy")
def segment_sum_strategy(attrs, inputs, out_type, target):
"""segment_sum generic strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_segment_sum(topi.segment_op),
wrap_topi_schedule(topi.generic.schedule_segment_op),
name="segment_sum.generic",
)
return strategy


@generic_func
def schedule_transpose(attrs, outs, target):
"""schedule transpose"""
with target:
return schedule_injective(attrs, outs, target)


# segment_prod
def wrap_compute_segment_prod(topi_compute):
"""wrap segment_prod topi compute"""

def _compute_segment_prod(attrs, inputs, out_type):
num_segments = attrs.num_segments
return [topi_compute(inputs[0], inputs[1], num_segments, "prod")]

return _compute_segment_prod


@override_native_generic_func("segment_prod_strategy")
def segment_prod_strategy(attrs, inputs, out_type, target):
"""segment_prod generic strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_segment_prod(topi.segment_op),
wrap_topi_schedule(topi.generic.schedule_segment_op),
name="segment_prod.generic",
)
return strategy
110 changes: 110 additions & 0 deletions python/tvm/relay/op/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1285,3 +1285,113 @@ def isinf(data):
The computed result.
"""
return _make.isinf(data)


def segment_max(data, segment_ids, num_segments):
"""Computes the maximum along segments of a tensor.

Parameters
----------
data : relay.Expr
The input data

segment_ids : relay.Expr
The segments data

num_segments : int
The maximum of segment_ids.

Returns
-------
result : relay.Expr
The computed result.
"""
return _make.segment_max(data, segment_ids, num_segments)


def segment_min(data, segment_ids, num_segments):
"""Computes the minimum along segments of a tensor.

Parameters
----------
data : relay.Expr
The input data

segment_ids : relay.Expr
The segments data

num_segments : int
The maximum of segment_ids.

Returns
-------
result : relay.Expr
The computed result.
"""
return _make.segment_min(data, segment_ids, num_segments)


def segment_mean(data, segment_ids, num_segments):
"""Computes the mean along segments of a tensor.

Parameters
----------
data : relay.Expr
The input data

segment_ids : relay.Expr
The segments data

num_segments : int
The maximum of segment_ids.

Returns
-------
result : relay.Expr
The computed result.
"""
return _make.segment_mean(data, segment_ids, num_segments)


def segment_sum(data, segment_ids, num_segments):
"""Computes the sum along segments of a tensor.

Parameters
----------
data : relay.Expr
The input data

segment_ids : relay.Expr
The segments data

num_segments : int
The maximum of segment_ids.

Returns
-------
result : relay.Expr
The computed result.
"""
return _make.segment_sum(data, segment_ids, num_segments)


def segment_prod(data, segment_ids, num_segments):
"""Computes the prod along segments of a tensor.

Parameters
----------
data : relay.Expr
The input data

segment_ids : relay.Expr
The segments data

num_segments : int
The maximum of segment_ids.

Returns
-------
result : relay.Expr
The computed result.
"""
return _make.segment_prod(data, segment_ids, num_segments)
Loading