1 change: 1 addition & 0 deletions apps/topi_recipe/gemm/cuda_gemm_square.py
@@ -21,6 +21,7 @@
 from tvm.contrib import nvcc
 from tvm.contrib import spirv
 import numpy as np
+import tvm.testing

 TASK = "gemm"
 USE_MANUAL_CODE = False
45 changes: 45 additions & 0 deletions python/tvm/contrib/thrust.py
@@ -0,0 +1,45 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Utilities for thrust"""
import logging

from tvm._ffi import get_global_func


def maybe_warn(target, func_name):
    """Warn when the TVM build and the target's -libs setting disagree about thrust."""
    if get_global_func(func_name, allow_missing=True) and "thrust" not in target.libs:
        logging.warning("TVM is built with thrust but thrust is not used.")
    if "thrust" in target.libs and get_global_func(func_name, allow_missing=True) is None:
        logging.warning("thrust is requested but TVM is not built with thrust.")


def can_use_thrust(target, func_name):
    """Return a truthy value when `func_name` is registered and the CUDA/NVPTX target requests thrust."""
    maybe_warn(target, func_name)
    return (
        target.kind.name in ["cuda", "nvptx"]
        and "thrust" in target.libs
        and get_global_func(func_name, allow_missing=True)
    )


def can_use_rocthrust(target, func_name):
    """Return a truthy value when `func_name` is registered and the ROCm target requests thrust."""
    maybe_warn(target, func_name)
    return (
        target.kind.name == "rocm"
        and "thrust" in target.libs
        and get_global_func(func_name, allow_missing=True)
    )
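
A minimal usage sketch of the new helpers (an illustration, assuming a TVM build with the Thrust contrib functions registered, e.g. USE_THRUST=ON in config.cmake; the ROCm variant likewise needs the corresponding rocThrust build option):

import tvm
from tvm.contrib.thrust import can_use_thrust, can_use_rocthrust

# -libs=thrust in the target string is standard TVM target syntax.
cuda_target = tvm.target.Target("cuda -libs=thrust")
rocm_target = tvm.target.Target("rocm -libs=thrust")

# Truthy only when the target requests thrust AND the PackedFunc
# "tvm.contrib.thrust.sort" is registered by the build.
if can_use_thrust(cuda_target, "tvm.contrib.thrust.sort"):
    print("Thrust sort available for CUDA")
if can_use_rocthrust(rocm_target, "tvm.contrib.thrust.sort"):
    print("rocThrust sort available for ROCm")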
18 changes: 5 additions & 13 deletions python/tvm/relay/op/strategy/cuda.py
@@ -20,7 +20,7 @@
 from tvm.auto_scheduler import is_auto_scheduler_enabled
 from tvm.te import SpecializedCondition
 from tvm.contrib import nvcc
-from tvm._ffi import get_global_func
+from tvm.contrib.thrust import can_use_thrust
 from .generic import *
 from .. import op as _op

@@ -791,9 +791,7 @@ def scatter_cuda(attrs, inputs, out_type, target):
     rank = len(inputs[0].shape)

     with SpecializedCondition(rank == 1):
-        if target.kind.name == "cuda" and get_global_func(
-            "tvm.contrib.thrust.stable_sort_by_key", allow_missing=True
-        ):
+        if can_use_thrust(target, "tvm.contrib.thrust.stable_sort_by_key"):
             strategy.add_implementation(
                 wrap_compute_scatter(topi.cuda.scatter_via_sort),
                 wrap_topi_schedule(topi.cuda.schedule_scatter_via_sort),
@@ -838,9 +836,7 @@ def sort_strategy_cuda(attrs, inputs, out_type, target):
         wrap_topi_schedule(topi.cuda.schedule_sort),
         name="sort.cuda",
     )
-    if target.kind.name == "cuda" and get_global_func(
-        "tvm.contrib.thrust.sort", allow_missing=True
-    ):
+    if can_use_thrust(target, "tvm.contrib.thrust.sort"):
         strategy.add_implementation(
             wrap_compute_sort(topi.cuda.sort_thrust),
             wrap_topi_schedule(topi.cuda.schedule_sort),
@@ -859,9 +855,7 @@ def argsort_strategy_cuda(attrs, inputs, out_type, target):
         wrap_topi_schedule(topi.cuda.schedule_argsort),
         name="argsort.cuda",
     )
-    if target.kind.name == "cuda" and get_global_func(
-        "tvm.contrib.thrust.sort", allow_missing=True
-    ):
+    if can_use_thrust(target, "tvm.contrib.thrust.sort"):
         strategy.add_implementation(
             wrap_compute_argsort(topi.cuda.argsort_thrust),
             wrap_topi_schedule(topi.cuda.schedule_argsort),
@@ -880,9 +874,7 @@ def topk_strategy_cuda(attrs, inputs, out_type, target):
         wrap_topi_schedule(topi.cuda.schedule_topk),
         name="topk.cuda",
     )
-    if target.kind.name == "cuda" and get_global_func(
-        "tvm.contrib.thrust.sort", allow_missing=True
-    ):
+    if can_use_thrust(target, "tvm.contrib.thrust.sort"):
         strategy.add_implementation(
             wrap_compute_topk(topi.cuda.topk_thrust),
             wrap_topi_schedule(topi.cuda.schedule_topk),
19 changes: 6 additions & 13 deletions python/tvm/relay/op/strategy/rocm.py
@@ -19,7 +19,8 @@
 from tvm import topi
 from tvm.auto_scheduler import is_auto_scheduler_enabled
 from tvm.te import SpecializedCondition
-from tvm._ffi import get_global_func
+from tvm.contrib.thrust import can_use_rocthrust
+
 from .generic import *
 from .. import op as _op
 from .cuda import judge_winograd, naive_schedule
@@ -223,14 +224,6 @@ def batch_matmul_strategy_rocm(attrs, inputs, out_type, target):
     return strategy


-def can_use_thrust(target, func_name):
-    return (
-        target.kind.name == "rocm"
-        and "thrust" in target.libs
-        and get_global_func(func_name, allow_missing=True)
-    )
-
-
 @argsort_strategy.register(["rocm"])
 def argsort_strategy_cuda(attrs, inputs, out_type, target):
     """argsort rocm strategy"""
@@ -240,7 +233,7 @@ def argsort_strategy_cuda(attrs, inputs, out_type, target):
         wrap_topi_schedule(topi.cuda.schedule_argsort),
         name="argsort.rocm",
     )
-    if can_use_thrust(target, "tvm.contrib.thrust.sort"):
+    if can_use_rocthrust(target, "tvm.contrib.thrust.sort"):
         strategy.add_implementation(
             wrap_compute_argsort(topi.cuda.argsort_thrust),
             wrap_topi_schedule(topi.cuda.schedule_argsort),
@@ -264,7 +257,7 @@ def scatter_cuda(attrs, inputs, out_type, target):
     rank = len(inputs[0].shape)

     with SpecializedCondition(rank == 1):
-        if can_use_thrust(target, "tvm.contrib.thrust.stable_sort_by_key"):
+        if can_use_rocthrust(target, "tvm.contrib.thrust.stable_sort_by_key"):
             strategy.add_implementation(
                 wrap_compute_scatter(topi.cuda.scatter_via_sort),
                 wrap_topi_schedule(topi.cuda.schedule_scatter_via_sort),
@@ -283,7 +276,7 @@ def sort_strategy_cuda(attrs, inputs, out_type, target):
         wrap_topi_schedule(topi.cuda.schedule_sort),
         name="sort.rocm",
     )
-    if can_use_thrust(target, "tvm.contrib.thrust.sort"):
+    if can_use_rocthrust(target, "tvm.contrib.thrust.sort"):
         strategy.add_implementation(
             wrap_compute_sort(topi.cuda.sort_thrust),
             wrap_topi_schedule(topi.cuda.schedule_sort),
@@ -303,7 +296,7 @@ def topk_strategy_cuda(attrs, inputs, out_type, target):
name="topk.rocm",
)

if can_use_thrust(target, "tvm.contrib.thrust.sort"):
if can_use_rocthrust(target, "tvm.contrib.thrust.sort"):
strategy.add_implementation(
wrap_compute_topk(topi.cuda.topk_thrust),
wrap_topi_schedule(topi.cuda.schedule_topk),
10 changes: 6 additions & 4 deletions python/tvm/topi/cuda/nms.py
@@ -19,9 +19,9 @@
"""Non-maximum suppression operator"""
import tvm
from tvm import te

from tvm.contrib.thrust import can_use_thrust, can_use_rocthrust
from tvm.tir import if_then_else
from .sort import argsort, argsort_thrust, is_thrust_available
from .sort import argsort, argsort_thrust
from .scan import exclusive_scan
from ..utils import ceil_div

@@ -610,8 +610,10 @@ def _get_sorted_indices(data, data_buf, score_index, score_shape):
     )

     target = tvm.target.Target.current()
-    # TODO(masahi): Check -libs=thrust option
-    if target and target.kind.name in ["cuda", "rocm"] and is_thrust_available():
+    if target and (
+        can_use_thrust(target, "tvm.contrib.thrust.sort")
+        or can_use_rocthrust(target, "tvm.contrib.thrust.sort")
+    ):
         sort_tensor = argsort_thrust(score_tensor, axis=1, is_ascend=False, dtype="int32")
     else:
         sort_tensor = argsort(score_tensor, axis=1, is_ascend=False, dtype="int32")
Review comment (Contributor):

Just a thought and feel free to ignore: since Thrust won't be used without -libs=thrust after this PR, would it be better to add a warning message here if tvm.contrib.thrust.sort is available (i.e., users built TVM with Thrust enabled) but the target doesn't have -libs=thrust? We can then remove the warning in the next release or so.

Reply (masahi, Member and PR author):

Sure, I'll do that after I get more feedback.

Reply (masahi, Member and PR author), Feb 22, 2021:

Warning added in f688f64. Please have a look.
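
For illustration, a short sketch of the warning behavior added in f688f64 (assuming a TVM build with Thrust enabled): a target without -libs=thrust now logs a hint instead of silently skipping the Thrust kernels.

import tvm
from tvm.contrib.thrust import can_use_thrust

# Plain "cuda" target, no -libs=thrust: if "tvm.contrib.thrust.sort" is
# registered, this logs "TVM is built with thrust but thrust is not used."
# and the check evaluates to a falsy value, so the fallback kernel is used.
can_use_thrust(tvm.target.Target("cuda"), "tvm.contrib.thrust.sort")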

13 changes: 5 additions & 8 deletions python/tvm/topi/cuda/scan.py
@@ -18,7 +18,7 @@
"Scan related operators"
import tvm
from tvm import te
from tvm._ffi import get_global_func
from tvm.contrib.thrust import can_use_thrust, can_use_rocthrust
from ..transform import expand_dims, squeeze, transpose, reshape
from ..utils import ceil_div, swap, prod, get_const_int
from ..math import cast
@@ -249,11 +249,6 @@ def ir(data, data_ex_scan, reduction):
     return reduction


-def is_thrust_available():
-    """Test if thrust based scan ops are available."""
-    return get_global_func("tvm.contrib.thrust.sum_scan", allow_missing=True) is not None
-
-
 def scan_thrust(
     data, output_dtype, exclusive=True, return_reduction=False, binop=tvm.tir.generic.add
 ):
@@ -352,8 +347,10 @@ def exclusive_scan(

     def do_scan(data, output_dtype):
         target = tvm.target.Target.current()
-        # TODO(masahi): Check -libs=thrust option
-        if target and target.kind.name in ["cuda", "rocm"] and is_thrust_available():
+        if target and (
+            can_use_thrust(target, "tvm.contrib.thrust.sum_scan")
+            or can_use_rocthrust(target, "tvm.contrib.thrust.sum_scan")
+        ):
             return scan_thrust(
                 data, output_dtype, exclusive=True, return_reduction=return_reduction, binop=binop
             )
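
For context, a sketch of how this dispatch is exercised (shapes, dtype, and the keyword usage are illustrative): do_scan consults tvm.target.Target.current(), so the Thrust branch is taken only inside a target scope whose -libs includes thrust and whose build registers tvm.contrib.thrust.sum_scan; otherwise the IR-based scan is emitted.

import tvm
from tvm import te
from tvm.topi.cuda.scan import exclusive_scan

values = te.placeholder((4, 128), name="values", dtype="int32")
with tvm.target.Target("cuda -libs=thrust"):
    # Uses scan_thrust when tvm.contrib.thrust.sum_scan is available,
    # and falls back to the IR-based scan otherwise.
    scanned = exclusive_scan(values, output_dtype="int32")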
3 changes: 1 addition & 2 deletions python/tvm/topi/cuda/scatter.py
@@ -21,7 +21,7 @@
 from ..scatter import _verify_scatter_nd_inputs
 from ..generic import schedule_extern
 from .nms import atomic_add
-from .sort import stable_sort_by_key_thrust, is_thrust_available
+from .sort import stable_sort_by_key_thrust
 from ..utils import prod, ceil_div


@@ -565,7 +565,6 @@ def scatter_via_sort(cfg, data, indices, updates, axis=0):
     if axis < 0:
         axis += len(data.shape)
     assert axis == 0 and len(data.shape) == 1, "sorting based scatter only supported for 1d input"
-    assert is_thrust_available(), "Thrust is required for this op"

     cfg.add_flop(1)  # A dummy value to satisfy AutoTVM

8 changes: 0 additions & 8 deletions python/tvm/topi/cuda/sort.py
@@ -18,7 +18,6 @@
"""Sort related operators """
import tvm
from tvm import te
from tvm._ffi import get_global_func

from .injective import schedule_injective_from_existing
from ..transform import strided_slice, transpose
@@ -879,10 +878,3 @@ def stable_sort_by_key_thrust(keys, values, for_scatter=False):
tag="stable_sort_by_key",
)
return out[0], out[1]


def is_thrust_available():
"""
Test if thrust based sorting ops are available.
"""
return get_global_func("tvm.contrib.thrust.sort", allow_missing=True) is not None