Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion 3rdparty/vta-hw
16 changes: 1 addition & 15 deletions python/tvm/relay/op/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,21 +83,7 @@ def compute_strided_set(attrs, inputs, output_type):
_reg.register_pattern("auto_scheduler_layout_transform", OpPattern.INJECTIVE)

# argwhere
@_reg.register_compute("argwhere")
def compute_argwhere(attrs, inputs, output_type):
"""Compute definition of argwhere"""
output_shape = []
for s in output_type.shape:
if hasattr(s, "value"):
output_shape.append(s)
else:
# see Any, replace it with a var
output_shape.append(te.var("any_dim", "int32"))
new_output_type = tvm.relay.ty.TensorType(output_shape, "int32")
return [topi.argwhere(new_output_type, inputs[0])]


_reg.register_schedule("argwhere", strategy.schedule_argwhere)
_reg.register_strategy("argwhere", strategy.argwhere_strategy)

# scatter
@_reg.register_compute("scatter")
Expand Down
12 changes: 12 additions & 0 deletions python/tvm/relay/op/strategy/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -921,3 +921,15 @@ def correlation_strategy_cuda(attrs, inputs, out_type, target):
name="correlation.cuda",
)
return strategy


@argwhere_strategy.register(["cuda", "gpu"])
def argwhere_strategy_cuda(attrs, inputs, out_type, target):
"""argwhere cuda strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_argwhere(topi.cuda.argwhere),
wrap_topi_schedule(topi.cuda.schedule_argwhere),
name="argwhere.cuda",
)
return strategy
39 changes: 30 additions & 9 deletions python/tvm/relay/op/strategy/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import logging

import re
from tvm import topi, _ffi
from tvm import topi, _ffi, te, ir
from tvm.topi.utils import get_const_int, get_const_float, get_const_tuple, get_float_tuple
from tvm.target import generic_func, override_native_generic_func
from .. import op as _op
Expand Down Expand Up @@ -1034,14 +1034,6 @@ def proposal_strategy(attrs, inputs, out_type, target):
return strategy


# argwhere
@generic_func
def schedule_argwhere(attrs, outs, target):
"""schedule argwhere"""
with target:
return topi.generic.schedule_argwhere(outs)


# scatter
@override_native_generic_func("scatter_strategy")
def scatter_strategy(attrs, outs, out_type, target):
Expand Down Expand Up @@ -1223,3 +1215,32 @@ def correlation_strategy(attrs, inputs, out_type, target):
name="correlation.generic",
)
return strategy


# argwhere
def wrap_compute_argwhere(topi_compute):
"""wrap argwhere topi compute"""

def _compute_argwhere(attrs, inputs, out_type):
output_shape = []
for s in out_type.shape:
if hasattr(s, "value"):
output_shape.append(s)
else:
output_shape.append(te.var("any_dim", "int32"))
new_output_type = ir.TensorType(output_shape, "int32")
return [topi_compute(new_output_type, inputs[0])]

return _compute_argwhere


@override_native_generic_func("argwhere_strategy")
def argwhere_strategy(attrs, inputs, out_type, target):
"""argwhere generic strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_argwhere(topi.argwhere),
wrap_topi_schedule(topi.generic.schedule_argwhere),
name="argwhere.generic",
)
return strategy
2 changes: 2 additions & 0 deletions python/tvm/topi/argwhere.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
# pylint: disable=invalid-name, too-many-arguments, too-many-nested-blocks
"""Argwhere operator"""
import tvm
from tvm.te import hybrid


Expand Down Expand Up @@ -169,6 +170,7 @@ def hybrid_argwhere_5d(output_shape, condition):
return a


@tvm.target.generic_func
def argwhere(output_shape, condition):
"""Find the indices of elements of a tensor that are non-zero.

Expand Down
1 change: 1 addition & 0 deletions python/tvm/topi/cuda/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,4 @@
from .conv2d_hwnc_tensorcore import *
from .correlation import *
from .sparse import *
from .argwhere import *
Loading