Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 97 additions & 28 deletions python/tvm/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,11 +375,12 @@ def _check_forward(constraints1, constraints2, varmap, backvarmap):
def _get_targets(target_str=None):
if target_str is None:
target_str = os.environ.get("TVM_TEST_TARGETS", "")
# Use dict instead of set for de-duplication so that the
# targets stay in the order specified.
target_names = list({t.strip(): None for t in target_str.split(";") if t.strip()})

if len(target_str) == 0:
target_str = DEFAULT_TEST_TARGETS

target_names = set(t.strip() for t in target_str.split(";") if t.strip())
if not target_names:
target_names = DEFAULT_TEST_TARGETS

targets = []
for target in target_names:
Expand Down Expand Up @@ -413,10 +414,18 @@ def _get_targets(target_str=None):
return targets


DEFAULT_TEST_TARGETS = (
"llvm;cuda;opencl;metal;rocm;vulkan -from_device=0;nvptx;"
"llvm -device=arm_cpu;opencl -device=mali,aocl_sw_emu"
)
DEFAULT_TEST_TARGETS = [
"llvm",
"llvm -device=arm_cpu",
"cuda",
"nvptx",
"vulkan -from_device=0",
"opencl",
"opencl -device=mali,aocl_sw_emu",
"opencl -device=intel_graphics",
"metal",
"rocm",
]


def device_enabled(target):
Expand Down Expand Up @@ -730,20 +739,25 @@ def requires_rpc(*args):


def _target_to_requirement(target):
if isinstance(target, str):
target_kind = target.split()[0]
else:
target_kind = target.kind.name

# mapping from target to decorator
if target.startswith("cuda"):
if target_kind == "cuda":
return requires_cuda()
if target.startswith("rocm"):
if target_kind == "rocm":
return requires_rocm()
if target.startswith("vulkan"):
if target_kind == "vulkan":
return requires_vulkan()
if target.startswith("nvptx"):
if target_kind == "nvptx":
return requires_nvptx()
if target.startswith("metal"):
if target_kind == "metal":
return requires_metal()
if target.startswith("opencl"):
if target_kind == "opencl":
return requires_opencl()
if target.startswith("llvm"):
if target_kind == "llvm":
return requires_llvm()
return []

Expand Down Expand Up @@ -794,16 +808,74 @@ def _auto_parametrize_target(metafunc):
file.

"""

def update_parametrize_target_arg(
argnames,
argvalues,
*args,
**kwargs,
):
args = [arg.strip() for arg in argnames.split(",") if arg.strip()]
if "target" in args:
target_i = args.index("target")

new_argvalues = []
for argvalue in argvalues:

if isinstance(argvalue, _pytest.mark.structures.ParameterSet):
# The parametrized value is already a
# pytest.param, so track any marks already
# defined.
param_set = argvalue.values
target = param_set[target_i]
additional_marks = argvalue.marks
elif len(args) == 1:
# Single value parametrization, argvalue is a list of values.
target = argvalue
param_set = (target,)
additional_marks = []
else:
# Multiple correlated parameters, argvalue is a list of tuple of values.
param_set = argvalue
target = param_set[target_i]
additional_marks = []

new_argvalues.append(
pytest.param(
*param_set, marks=_target_to_requirement(target) + additional_marks
)
)

try:
argvalues[:] = new_argvalues
except TypeError as e:
pyfunc = metafunc.definition.function
filename = pyfunc.__code__.co_filename
line_number = pyfunc.__code__.co_firstlineno
msg = (
f"Unit test {metafunc.function.__name__} ({filename}:{line_number}) "
"is parametrized using a tuple of parameters instead of a list "
"of parameters."
)
raise TypeError(msg) from e

if "target" in metafunc.fixturenames:
# Update any explicit use of @pytest.mark.parmaetrize to
# parametrize over targets. This adds the appropriate
# @tvm.testing.requires_* markers for each target.
for mark in metafunc.definition.iter_markers("parametrize"):
update_parametrize_target_arg(*mark.args, **mark.kwargs)

# Check if any explicit parametrizations exist, and apply one
# if they do not. If the function is marked with either
# excluded or known failing targets, use these to determine
# the targets to be used.
parametrized_args = [
arg.strip()
for mark in metafunc.definition.iter_markers("parametrize")
for arg in mark.args[0].split(",")
]

if "target" not in parametrized_args:
# Check if the function is marked with either excluded or
# known failing targets.
excluded_targets = getattr(metafunc.function, "tvm_excluded_targets", [])
xfail_targets = getattr(metafunc.function, "tvm_known_failing_targets", [])
metafunc.parametrize(
Expand Down Expand Up @@ -849,17 +921,14 @@ def parametrize_targets(*args):
>>> ... # do something
"""

def wrap(targets):
def func(f):
return pytest.mark.parametrize(
"target", _pytest_target_params(targets), scope="session"
)(f)

return func

# Backwards compatibility, when used as a decorator with no
# arguments implicitly parametrizes over "target". The
# parametrization is now handled by _auto_parametrize_target, so
# this use case can just return the decorated function.
if len(args) == 1 and callable(args[0]):
return wrap(None)(args[0])
return wrap(args)
return args[0]

return pytest.mark.parametrize("target", list(args), scope="session")


def exclude_targets(*args):
Expand Down
Loading