diff --git a/apps/topi_recipe/gemm/cuda_gemm_square.py b/apps/topi_recipe/gemm/cuda_gemm_square.py index 25d14f9abdf3..0d548dc0b554 100644 --- a/apps/topi_recipe/gemm/cuda_gemm_square.py +++ b/apps/topi_recipe/gemm/cuda_gemm_square.py @@ -21,6 +21,7 @@ from tvm.contrib import nvcc from tvm.contrib import spirv import numpy as np +import tvm.testing TASK = "gemm" USE_MANUAL_CODE = False diff --git a/python/tvm/contrib/thrust.py b/python/tvm/contrib/thrust.py new file mode 100644 index 000000000000..7fe0077c2b42 --- /dev/null +++ b/python/tvm/contrib/thrust.py @@ -0,0 +1,45 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Utilities for thrust""" +import logging + +from tvm._ffi import get_global_func + + +def maybe_warn(target, func_name): + if get_global_func(func_name, allow_missing=True) and not "thrust" in target.libs: + logging.warning("TVM is built with thrust but thrust is not used.") + if "thrust" in target.libs and get_global_func(func_name, allow_missing=True) is None: + logging.warning("thrust is requested but TVM is not built with thrust.") + + +def can_use_thrust(target, func_name): + maybe_warn(target, func_name) + return ( + target.kind.name in ["cuda", "nvptx"] + and "thrust" in target.libs + and get_global_func(func_name, allow_missing=True) + ) + + +def can_use_rocthrust(target, func_name): + maybe_warn(target, func_name) + return ( + target.kind.name == "rocm" + and "thrust" in target.libs + and get_global_func(func_name, allow_missing=True) + ) diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py index cb4688c4889e..20c5f03b9b0b 100644 --- a/python/tvm/relay/op/strategy/cuda.py +++ b/python/tvm/relay/op/strategy/cuda.py @@ -20,7 +20,7 @@ from tvm.auto_scheduler import is_auto_scheduler_enabled from tvm.te import SpecializedCondition from tvm.contrib import nvcc -from tvm._ffi import get_global_func +from tvm.contrib.thrust import can_use_thrust from .generic import * from .. import op as _op @@ -791,9 +791,7 @@ def scatter_cuda(attrs, inputs, out_type, target): rank = len(inputs[0].shape) with SpecializedCondition(rank == 1): - if target.kind.name == "cuda" and get_global_func( - "tvm.contrib.thrust.stable_sort_by_key", allow_missing=True - ): + if can_use_thrust(target, "tvm.contrib.thrust.stable_sort_by_key"): strategy.add_implementation( wrap_compute_scatter(topi.cuda.scatter_via_sort), wrap_topi_schedule(topi.cuda.schedule_scatter_via_sort), @@ -838,9 +836,7 @@ def sort_strategy_cuda(attrs, inputs, out_type, target): wrap_topi_schedule(topi.cuda.schedule_sort), name="sort.cuda", ) - if target.kind.name == "cuda" and get_global_func( - "tvm.contrib.thrust.sort", allow_missing=True - ): + if can_use_thrust(target, "tvm.contrib.thrust.sort"): strategy.add_implementation( wrap_compute_sort(topi.cuda.sort_thrust), wrap_topi_schedule(topi.cuda.schedule_sort), @@ -859,9 +855,7 @@ def argsort_strategy_cuda(attrs, inputs, out_type, target): wrap_topi_schedule(topi.cuda.schedule_argsort), name="argsort.cuda", ) - if target.kind.name == "cuda" and get_global_func( - "tvm.contrib.thrust.sort", allow_missing=True - ): + if can_use_thrust(target, "tvm.contrib.thrust.sort"): strategy.add_implementation( wrap_compute_argsort(topi.cuda.argsort_thrust), wrap_topi_schedule(topi.cuda.schedule_argsort), @@ -880,9 +874,7 @@ def topk_strategy_cuda(attrs, inputs, out_type, target): wrap_topi_schedule(topi.cuda.schedule_topk), name="topk.cuda", ) - if target.kind.name == "cuda" and get_global_func( - "tvm.contrib.thrust.sort", allow_missing=True - ): + if can_use_thrust(target, "tvm.contrib.thrust.sort"): strategy.add_implementation( wrap_compute_topk(topi.cuda.topk_thrust), wrap_topi_schedule(topi.cuda.schedule_topk), diff --git a/python/tvm/relay/op/strategy/rocm.py b/python/tvm/relay/op/strategy/rocm.py index 934f38625fd3..f4538071e11e 100644 --- a/python/tvm/relay/op/strategy/rocm.py +++ b/python/tvm/relay/op/strategy/rocm.py @@ -19,7 +19,8 @@ from tvm import topi from tvm.auto_scheduler import is_auto_scheduler_enabled from tvm.te import SpecializedCondition -from tvm._ffi import get_global_func +from tvm.contrib.thrust import can_use_rocthrust + from .generic import * from .. import op as _op from .cuda import judge_winograd, naive_schedule @@ -223,14 +224,6 @@ def batch_matmul_strategy_rocm(attrs, inputs, out_type, target): return strategy -def can_use_thrust(target, func_name): - return ( - target.kind.name == "rocm" - and "thrust" in target.libs - and get_global_func(func_name, allow_missing=True) - ) - - @argsort_strategy.register(["rocm"]) def argsort_strategy_cuda(attrs, inputs, out_type, target): """argsort rocm strategy""" @@ -240,7 +233,7 @@ def argsort_strategy_cuda(attrs, inputs, out_type, target): wrap_topi_schedule(topi.cuda.schedule_argsort), name="argsort.rocm", ) - if can_use_thrust(target, "tvm.contrib.thrust.sort"): + if can_use_rocthrust(target, "tvm.contrib.thrust.sort"): strategy.add_implementation( wrap_compute_argsort(topi.cuda.argsort_thrust), wrap_topi_schedule(topi.cuda.schedule_argsort), @@ -264,7 +257,7 @@ def scatter_cuda(attrs, inputs, out_type, target): rank = len(inputs[0].shape) with SpecializedCondition(rank == 1): - if can_use_thrust(target, "tvm.contrib.thrust.stable_sort_by_key"): + if can_use_rocthrust(target, "tvm.contrib.thrust.stable_sort_by_key"): strategy.add_implementation( wrap_compute_scatter(topi.cuda.scatter_via_sort), wrap_topi_schedule(topi.cuda.schedule_scatter_via_sort), @@ -283,7 +276,7 @@ def sort_strategy_cuda(attrs, inputs, out_type, target): wrap_topi_schedule(topi.cuda.schedule_sort), name="sort.rocm", ) - if can_use_thrust(target, "tvm.contrib.thrust.sort"): + if can_use_rocthrust(target, "tvm.contrib.thrust.sort"): strategy.add_implementation( wrap_compute_sort(topi.cuda.sort_thrust), wrap_topi_schedule(topi.cuda.schedule_sort), @@ -303,7 +296,7 @@ def topk_strategy_cuda(attrs, inputs, out_type, target): name="topk.rocm", ) - if can_use_thrust(target, "tvm.contrib.thrust.sort"): + if can_use_rocthrust(target, "tvm.contrib.thrust.sort"): strategy.add_implementation( wrap_compute_topk(topi.cuda.topk_thrust), wrap_topi_schedule(topi.cuda.schedule_topk), diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index 98cb6750408a..a5a9c4def526 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -19,9 +19,9 @@ """Non-maximum suppression operator""" import tvm from tvm import te - +from tvm.contrib.thrust import can_use_thrust, can_use_rocthrust from tvm.tir import if_then_else -from .sort import argsort, argsort_thrust, is_thrust_available +from .sort import argsort, argsort_thrust from .scan import exclusive_scan from ..utils import ceil_div @@ -610,8 +610,10 @@ def _get_sorted_indices(data, data_buf, score_index, score_shape): ) target = tvm.target.Target.current() - # TODO(masahi): Check -libs=thrust option - if target and target.kind.name in ["cuda", "rocm"] and is_thrust_available(): + if target and ( + can_use_thrust(target, "tvm.contrib.thrust.sort") + or can_use_rocthrust(target, "tvm.contrib.thrust.sort") + ): sort_tensor = argsort_thrust(score_tensor, axis=1, is_ascend=False, dtype="int32") else: sort_tensor = argsort(score_tensor, axis=1, is_ascend=False, dtype="int32") diff --git a/python/tvm/topi/cuda/scan.py b/python/tvm/topi/cuda/scan.py index 65d23365dc15..84ab5dcf9756 100644 --- a/python/tvm/topi/cuda/scan.py +++ b/python/tvm/topi/cuda/scan.py @@ -18,7 +18,7 @@ "Scan related operators" import tvm from tvm import te -from tvm._ffi import get_global_func +from tvm.contrib.thrust import can_use_thrust, can_use_rocthrust from ..transform import expand_dims, squeeze, transpose, reshape from ..utils import ceil_div, swap, prod, get_const_int from ..math import cast @@ -249,11 +249,6 @@ def ir(data, data_ex_scan, reduction): return reduction -def is_thrust_available(): - """Test if thrust based scan ops are available.""" - return get_global_func("tvm.contrib.thrust.sum_scan", allow_missing=True) is not None - - def scan_thrust( data, output_dtype, exclusive=True, return_reduction=False, binop=tvm.tir.generic.add ): @@ -352,8 +347,10 @@ def exclusive_scan( def do_scan(data, output_dtype): target = tvm.target.Target.current() - # TODO(masahi): Check -libs=thrust option - if target and target.kind.name in ["cuda", "rocm"] and is_thrust_available(): + if target and ( + can_use_thrust(target, "tvm.contrib.thrust.sum_scan") + or can_use_rocthrust(target, "tvm.contrib.thrust.sum_scan") + ): return scan_thrust( data, output_dtype, exclusive=True, return_reduction=return_reduction, binop=binop ) diff --git a/python/tvm/topi/cuda/scatter.py b/python/tvm/topi/cuda/scatter.py index 444fb25cc34b..fd05904ba8e7 100644 --- a/python/tvm/topi/cuda/scatter.py +++ b/python/tvm/topi/cuda/scatter.py @@ -21,7 +21,7 @@ from ..scatter import _verify_scatter_nd_inputs from ..generic import schedule_extern from .nms import atomic_add -from .sort import stable_sort_by_key_thrust, is_thrust_available +from .sort import stable_sort_by_key_thrust from ..utils import prod, ceil_div @@ -565,7 +565,6 @@ def scatter_via_sort(cfg, data, indices, updates, axis=0): if axis < 0: axis += len(data.shape) assert axis == 0 and len(data.shape) == 1, "sorting based scatter only supported for 1d input" - assert is_thrust_available(), "Thrust is required for this op" cfg.add_flop(1) # A dummy value to satisfy AutoTVM diff --git a/python/tvm/topi/cuda/sort.py b/python/tvm/topi/cuda/sort.py index c0f076fb6065..ff5cc0681ad2 100644 --- a/python/tvm/topi/cuda/sort.py +++ b/python/tvm/topi/cuda/sort.py @@ -18,7 +18,6 @@ """Sort related operators """ import tvm from tvm import te -from tvm._ffi import get_global_func from .injective import schedule_injective_from_existing from ..transform import strided_slice, transpose @@ -879,10 +878,3 @@ def stable_sort_by_key_thrust(keys, values, for_scatter=False): tag="stable_sort_by_key", ) return out[0], out[1] - - -def is_thrust_available(): - """ - Test if thrust based sorting ops are available. - """ - return get_global_func("tvm.contrib.thrust.sort", allow_missing=True) is not None diff --git a/tests/python/contrib/test_thrust.py b/tests/python/contrib/test_thrust.py index 521c20de6cbd..4edce0d6a642 100644 --- a/tests/python/contrib/test_thrust.py +++ b/tests/python/contrib/test_thrust.py @@ -17,16 +17,16 @@ import tvm import tvm.testing from tvm import te -from tvm.topi.cuda import stable_sort_by_key_thrust, is_thrust_available +from tvm.topi.cuda import stable_sort_by_key_thrust from tvm.topi.cuda.scan import exclusive_scan, scan_thrust, schedule_scan +from tvm.contrib.thrust import can_use_thrust, can_use_rocthrust import numpy as np -def test_stable_sort_by_key(): - if not is_thrust_available(): - print("skip because thrust is not enabled...") - return +thrust_check_func = {"cuda": can_use_thrust, "rocm": can_use_rocthrust} + +def test_stable_sort_by_key(): size = 6 keys = te.placeholder((size,), name="keys", dtype="int32") values = te.placeholder((size,), name="values", dtype="int32") @@ -38,74 +38,73 @@ def test_stable_sort_by_key(): print("Skip because %s is not enabled" % target) continue - target += " -libs=thrust" - ctx = tvm.context(target, 0) - s = te.create_schedule([keys_out.op, values_out.op]) - f = tvm.build(s, [keys, values, keys_out, values_out], target) + with tvm.target.Target(target + " -libs=thrust") as tgt: + if not thrust_check_func[target](tgt, "tvm.contrib.thrust.stable_sort_by_key"): + print("skip because thrust is not enabled...") + return - keys_np = np.array([1, 4, 2, 8, 2, 7], np.int32) - values_np = np.random.randint(0, 10, size=(size,)).astype(np.int32) - keys_np_out = np.zeros(keys_np.shape, np.int32) - values_np_out = np.zeros(values_np.shape, np.int32) - keys_in = tvm.nd.array(keys_np, ctx) - values_in = tvm.nd.array(values_np, ctx) - keys_out = tvm.nd.array(keys_np_out, ctx) - values_out = tvm.nd.array(values_np_out, ctx) - f(keys_in, values_in, keys_out, values_out) + ctx = tvm.context(target, 0) + s = te.create_schedule([keys_out.op, values_out.op]) + f = tvm.build(s, [keys, values, keys_out, values_out], target) - ref_keys_out = np.sort(keys_np) - ref_values_out = np.array([values_np[i] for i in np.argsort(keys_np)]) - tvm.testing.assert_allclose(keys_out.asnumpy(), ref_keys_out, rtol=1e-5) - tvm.testing.assert_allclose(values_out.asnumpy(), ref_values_out, rtol=1e-5) + keys_np = np.array([1, 4, 2, 8, 2, 7], np.int32) + values_np = np.random.randint(0, 10, size=(size,)).astype(np.int32) + keys_np_out = np.zeros(keys_np.shape, np.int32) + values_np_out = np.zeros(values_np.shape, np.int32) + keys_in = tvm.nd.array(keys_np, ctx) + values_in = tvm.nd.array(values_np, ctx) + keys_out = tvm.nd.array(keys_np_out, ctx) + values_out = tvm.nd.array(values_np_out, ctx) + f(keys_in, values_in, keys_out, values_out) + ref_keys_out = np.sort(keys_np) + ref_values_out = np.array([values_np[i] for i in np.argsort(keys_np)]) + tvm.testing.assert_allclose(keys_out.asnumpy(), ref_keys_out, rtol=1e-5) + tvm.testing.assert_allclose(values_out.asnumpy(), ref_values_out, rtol=1e-5) -def test_exclusive_scan(): - if not is_thrust_available(): - print("skip because thrust is not enabled...") - return +def test_exclusive_scan(): for target in ["cuda", "rocm"]: if not tvm.testing.device_enabled(target): print("Skip because %s is not enabled" % target) continue - target += " -libs=thrust" - for ishape in [(10,), (10, 10), (10, 10, 10)]: - values = te.placeholder(ishape, name="values", dtype="int32") + with tvm.target.Target(target + " -libs=thrust") as tgt: + if not thrust_check_func[target](tgt, "tvm.contrib.thrust.sum_scan"): + print("skip because thrust is not enabled...") + return + + for ishape in [(10,), (10, 10), (10, 10, 10)]: + values = te.placeholder(ishape, name="values", dtype="int32") - with tvm.target.Target(target): scan, reduction = exclusive_scan(values, return_reduction=True) s = schedule_scan([scan, reduction]) - ctx = tvm.context(target, 0) - f = tvm.build(s, [values, scan, reduction], target) + ctx = tvm.context(target, 0) + f = tvm.build(s, [values, scan, reduction], target) - values_np = np.random.randint(0, 10, size=ishape).astype(np.int32) - values_np_out = np.zeros(values_np.shape, np.int32) + values_np = np.random.randint(0, 10, size=ishape).astype(np.int32) + values_np_out = np.zeros(values_np.shape, np.int32) - if len(ishape) == 1: - reduction_shape = () - else: - reduction_shape = ishape[:-1] + if len(ishape) == 1: + reduction_shape = () + else: + reduction_shape = ishape[:-1] - reduction_np_out = np.zeros(reduction_shape, np.int32) + reduction_np_out = np.zeros(reduction_shape, np.int32) - values_in = tvm.nd.array(values_np, ctx) - values_out = tvm.nd.array(values_np_out, ctx) - reduction_out = tvm.nd.array(reduction_np_out, ctx) - f(values_in, values_out, reduction_out) + values_in = tvm.nd.array(values_np, ctx) + values_out = tvm.nd.array(values_np_out, ctx) + reduction_out = tvm.nd.array(reduction_np_out, ctx) + f(values_in, values_out, reduction_out) - ref_values_out = np.cumsum(values_np, axis=-1, dtype="int32") - values_np - tvm.testing.assert_allclose(values_out.asnumpy(), ref_values_out, rtol=1e-5) - ref_reduction_out = np.sum(values_np, axis=-1) - tvm.testing.assert_allclose(reduction_out.asnumpy(), ref_reduction_out, rtol=1e-5) + ref_values_out = np.cumsum(values_np, axis=-1, dtype="int32") - values_np + tvm.testing.assert_allclose(values_out.asnumpy(), ref_values_out, rtol=1e-5) + ref_reduction_out = np.sum(values_np, axis=-1) + tvm.testing.assert_allclose(reduction_out.asnumpy(), ref_reduction_out, rtol=1e-5) def test_inclusive_scan(): - if not is_thrust_available(): - print("skip because thrust is not enabled...") - return - out_dtype = "int64" for target in ["cuda", "rocm"]: @@ -113,25 +112,28 @@ def test_inclusive_scan(): print("Skip because %s is not enabled" % target) continue - target += " -libs=thrust" - for ishape in [(10,), (10, 10)]: - values = te.placeholder(ishape, name="values", dtype="int32") + with tvm.target.Target(target + " -libs=thrust") as tgt: + if not thrust_check_func[target](tgt, "tvm.contrib.thrust.sum_scan"): + print("skip because thrust is not enabled...") + return + + for ishape in [(10,), (10, 10)]: + values = te.placeholder(ishape, name="values", dtype="int32") - with tvm.target.Target(target): scan = scan_thrust(values, out_dtype, exclusive=False) s = tvm.te.create_schedule([scan.op]) - ctx = tvm.context(target, 0) - f = tvm.build(s, [values, scan], target) + ctx = tvm.context(target, 0) + f = tvm.build(s, [values, scan], target) - values_np = np.random.randint(0, 10, size=ishape).astype(np.int32) - values_np_out = np.zeros(values_np.shape, out_dtype) - values_in = tvm.nd.array(values_np, ctx) - values_out = tvm.nd.array(values_np_out, ctx) - f(values_in, values_out) + values_np = np.random.randint(0, 10, size=ishape).astype(np.int32) + values_np_out = np.zeros(values_np.shape, out_dtype) + values_in = tvm.nd.array(values_np, ctx) + values_out = tvm.nd.array(values_np_out, ctx) + f(values_in, values_out) - ref_values_out = np.cumsum(values_np, axis=-1, dtype=out_dtype) - tvm.testing.assert_allclose(values_out.asnumpy(), ref_values_out, rtol=1e-5) + ref_values_out = np.cumsum(values_np, axis=-1, dtype=out_dtype) + tvm.testing.assert_allclose(values_out.asnumpy(), ref_values_out, rtol=1e-5) if __name__ == "__main__": diff --git a/tutorials/frontend/deploy_ssd_gluoncv.py b/tutorials/frontend/deploy_ssd_gluoncv.py index f1f1bbb7057e..478aff255e0c 100644 --- a/tutorials/frontend/deploy_ssd_gluoncv.py +++ b/tutorials/frontend/deploy_ssd_gluoncv.py @@ -94,6 +94,10 @@ def build(target): ###################################################################### # Create TVM runtime and do inference +# .. note:: +# +# Use target = "cuda -libs" to enable thrust based sort, if you +# enabled thrust during cmake by -DUSE_THRUST=ON. def run(lib, ctx):