Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions python/tvm/relay/op/_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,28 @@ def topk_shape_func(attrs, inputs, _):
ret = [indices_out]

return ret


@script
def _searchsorted_shape(sorted_sequence_shape, values_shape):
out_shape = output_tensor((values_shape.shape[0],), "int64")
if sorted_sequence_shape.shape[0] > 1:
assert (
sorted_sequence_shape.shape[0] == values_shape.shape[0]
), "Ranks of `sorted_sequence` and values must be the same if `sorted_sequence` is not 1-D."
for i in range(values_shape.shape[0]):
if sorted_sequence_shape.shape[0] > 1 and i < values_shape.shape[0] - 1:
assert (
sorted_sequence_shape[i] == values_shape[i]
), "`sorted_sequence and `values` do not have the same shape along outer axes."

out_shape[i] = values_shape[i]
return out_shape


@_reg.register_shape_func("searchsorted", False)
def searchsorted_shape_func(attrs, inputs, _):
"""
Shape func for searchsorted operator.
"""
return [_searchsorted_shape(inputs[0], inputs[1])]
31 changes: 31 additions & 0 deletions tests/python/relay/test_any.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from tvm import relay, te
from tvm.relay.loops import while_loop
from tvm.relay.testing import run_infer_type as infer_type
from tvm.topi.testing import searchsorted_ref

from utils import ref_funcs
from utils.assert_diagnostic import DiagnosticTesting
Expand Down Expand Up @@ -2086,5 +2087,35 @@ def verify_gather(data_shape, indices_shape, data_shape_np, indices_shape_np, ax
verify_gather((relay.Any(), relay.Any()), (relay.Any(), relay.Any()), (2, 3), (1, 3), 0)


@tvm.testing.uses_gpu
def test_searchsorted():
def verify_searchsorted(
sorted_sequence_shape, values_shape, sorted_sequence_shape_np, values_shape_np
):
x = relay.var("x", relay.TensorType(sorted_sequence_shape, "float32"))
y = relay.var("y", relay.TensorType(values_shape, "float32"))
z = relay.searchsorted(x, y)

mod = tvm.IRModule()
mod["main"] = relay.Function([x, y], z)

x_np = np.sort(np.random.uniform(size=sorted_sequence_shape_np).astype("float32"), axis=-1)
y_np = np.random.uniform(size=values_shape_np).astype("float32")

ref_res = searchsorted_ref(x_np, y_np, False, "int32")
check_result([x_np, y_np], mod, [ref_res])

for shape_np, values_shape_np in zip([(8, 9, 10), (10,), (11,)], [(8, 9, 20), (5,), (8, 9, 7)]):
sorted_sequence_shape = (relay.Any(),) * len(shape_np)
values_shape = (relay.Any(),) * len(values_shape_np)

verify_searchsorted(
sorted_sequence_shape,
values_shape,
shape_np,
values_shape_np,
)


if __name__ == "__main__":
pytest.main([__file__])