diff --git a/python/tvm/runtime/ndarray.py b/python/tvm/runtime/ndarray.py
index a78c68ee67c4..6f0d1f440a0f 100644
--- a/python/tvm/runtime/ndarray.py
+++ b/python/tvm/runtime/ndarray.py
@@ -324,11 +324,26 @@ def device(dev_type, dev_id=0):
         assert tvm.device("cpu", 1) == tvm.cpu(1)
         assert tvm.device("cuda", 0) == tvm.cuda(0)
     """
+    if not isinstance(dev_id, int):
+        raise ValueError(f"Invalid device id: {dev_id}")
+
     if isinstance(dev_type, string_types):
         dev_type = dev_type.split()[0]
+        if dev_type.count(":") == 0:
+            pass
+        elif dev_type.count(":") == 1:
+            # It will override the dev_id passed by the user.
+            dev_type, dev_id = dev_type.split(":")
+            if not dev_id.isdigit():
+                raise ValueError(f"Invalid device id: {dev_id}")
+            dev_id = int(dev_id)
+        else:
+            raise ValueError(f"Invalid device string: {dev_type}")
+
         if dev_type not in Device.STR2MASK:
-            raise ValueError(f"Unknown device type {dev_type}")
-        dev_type = Device.STR2MASK[dev_type]
+            raise ValueError(f"Unknown device type: {dev_type}")
+
+        return Device(Device.STR2MASK[dev_type], dev_id)
     return Device(dev_type, dev_id)
diff --git a/python/tvm/target/detect_target.py b/python/tvm/target/detect_target.py
new file mode 100644
index 000000000000..5c139cc949c2
--- /dev/null
+++ b/python/tvm/target/detect_target.py
@@ -0,0 +1,114 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Detect target."""
+from typing import Union
+
+from . import Target
+from .._ffi import get_global_func
+from .._ffi.runtime_ctypes import Device
+from ..runtime.ndarray import device
+
+
+def _detect_metal(dev: Device) -> Target:
+    return Target(
+        {
+            "kind": "metal",
+            "max_shared_memory_per_block": 32768,
+            "max_threads_per_block": dev.max_threads_per_block,
+            "thread_warp_size": dev.warp_size,
+        }
+    )
+
+
+def _detect_cuda(dev: Device) -> Target:
+    return Target(
+        {
+            "kind": "cuda",
+            "max_shared_memory_per_block": dev.max_shared_memory_per_block,
+            "max_threads_per_block": dev.max_threads_per_block,
+            "thread_warp_size": dev.warp_size,
+            "arch": "sm_" + dev.compute_version.replace(".", ""),
+        }
+    )
+
+
+def _detect_rocm(dev: Device) -> Target:
+    return Target(
+        {
+            "kind": "rocm",
+            "mtriple": "amdgcn-amd-amdhsa-hcc",
+            "max_shared_memory_per_block": dev.max_shared_memory_per_block,
+            "max_threads_per_block": dev.max_threads_per_block,
+            "thread_warp_size": dev.warp_size,
+        }
+    )
+
+
+def _detect_vulkan(dev: Device) -> Target:
+    f_get_target_property = get_global_func("device_api.vulkan.get_target_property")
+    return Target(
+        {
+            "kind": "vulkan",
+            "max_threads_per_block": dev.max_threads_per_block,
+            "max_shared_memory_per_block": dev.max_shared_memory_per_block,
+            "thread_warp_size": dev.warp_size,
+            "supports_float16": f_get_target_property(dev, "supports_float16"),
+            "supports_int16": f_get_target_property(dev, "supports_int16"),
+            "supports_int8": f_get_target_property(dev, "supports_int8"),
+            "supports_16bit_buffer": f_get_target_property(dev, "supports_16bit_buffer"),
+        }
+    )
+
+
+def detect_target_from_device(dev: Union[str, Device]) -> Target:
+    """Detects the Target associated with the given device. If the device does not exist,
+    an error will be raised.
+
+    Parameters
+    ----------
+    dev : Union[str, Device]
+        The device to detect the target for.
+        Supported device types: ["cuda", "metal", "rocm", "vulkan"]
+
+    Returns
+    -------
+    target : Target
+        The detected target.
+    """
+    if isinstance(dev, str):
+        dev = device(dev)
+    device_type = Device.MASK2STR[dev.device_type]
+    if device_type not in SUPPORT_DEVICE:
+        raise ValueError(
+            f"Auto detection for device `{device_type}` is not supported. "
+            f"Currently only supports: {SUPPORT_DEVICE.keys()}"
+        )
+    if not dev.exist:
+        raise ValueError(
+            f"Cannot detect device `{dev}`. Please make sure the device and its driver "
+            "are installed properly, and that TVM is compiled with support for it"
+        )
+    target = SUPPORT_DEVICE[device_type](dev)
+    return target
+
+
+SUPPORT_DEVICE = {
+    "cuda": _detect_cuda,
+    "metal": _detect_metal,
+    "vulkan": _detect_vulkan,
+    "rocm": _detect_rocm,
+}
diff --git a/python/tvm/target/target.py b/python/tvm/target/target.py
index b027b99b17eb..ec74cbcdb62a 100644
--- a/python/tvm/target/target.py
+++ b/python/tvm/target/target.py
@@ -18,9 +18,11 @@
 import json
 import re
 import warnings
+from typing import Union
 
 import tvm._ffi
 from tvm._ffi import register_func as _register_func
+from tvm._ffi.runtime_ctypes import Device
 from tvm.runtime import Object, convert
 from tvm.runtime.container import String
 from tvm.ir.container import Map, Array
@@ -148,6 +150,28 @@ def export(self):
     def with_host(self, host=None):
         return _ffi_api.WithHost(self, Target(host))
 
+    @staticmethod
+    def from_device(device: Union[str, Device]) -> "Target":
+        """Detects the Target associated with the given device. If the device does not exist,
+        an error will be raised.
+
+        Parameters
+        ----------
+        device : Union[str, Device]
+            The device to detect the target for.
+            Supported device types: ["cuda", "metal", "rocm", "vulkan"]
+
+        Returns
+        -------
+        target : Target
+            The detected target.
+        """
+        from .detect_target import (  # pylint: disable=import-outside-toplevel
+            detect_target_from_device,
+        )
+
+        return detect_target_from_device(device)
+
     @staticmethod
     def current(allow_none=True):
         """Returns the current target.
diff --git a/tests/python/unittest/test_device.py b/tests/python/unittest/test_device.py
new file mode 100644
index 000000000000..9d10251e1514
--- /dev/null
+++ b/tests/python/unittest/test_device.py
@@ -0,0 +1,71 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+
+import tvm
+import tvm.testing
+from tvm._ffi.runtime_ctypes import Device
+
+
+@pytest.mark.parametrize(
+    "dev_str, expected_device_type, expect_device_id",
+    [
+        ("cpu", Device.kDLCPU, 0),
+        ("cuda", Device.kDLCUDA, 0),
+        ("cuda:0", Device.kDLCUDA, 0),
+        ("cuda:3", Device.kDLCUDA, 3),
+        ("metal:2", Device.kDLMetal, 2),
+    ],
+)
+def test_device(dev_str, expected_device_type, expect_device_id):
+    dev = tvm.device(dev_str)
+    assert dev.device_type == expected_device_type
+    assert dev.device_id == expect_device_id
+
+
+@pytest.mark.parametrize(
+    "dev_type, dev_id, expected_device_type, expect_device_id",
+    [
+        ("cpu", 0, Device.kDLCPU, 0),
+        ("cuda", 0, Device.kDLCUDA, 0),
+        (Device.kDLCUDA, 0, Device.kDLCUDA, 0),
+        ("cuda", 3, Device.kDLCUDA, 3),
+        (Device.kDLMetal, 2, Device.kDLMetal, 2),
+    ],
+)
+def test_device_with_dev_id(dev_type, dev_id, expected_device_type, expect_device_id):
+    dev = tvm.device(dev_type=dev_type, dev_id=dev_id)
+    assert dev.device_type == expected_device_type
+    assert dev.device_id == expect_device_id
+
+
+@pytest.mark.parametrize(
+    "dev_type, dev_id",
+    [
+        ("cpu:0:0", 0),
+        ("cpu:?", 0),
+        ("cpu:", 0),
+        (Device.kDLCUDA, "?"),
+    ],
+)
+def test_device_error(dev_type, dev_id):
+    with pytest.raises(ValueError):
+        tvm.device(dev_type=dev_type, dev_id=dev_id)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
diff --git a/tests/python/unittest/test_target_target.py b/tests/python/unittest/test_target_target.py
index 2b0f1b2dd7a0..da1bbc2c211b 100644
--- a/tests/python/unittest/test_target_target.py
+++ b/tests/python/unittest/test_target_target.py
@@ -488,5 +488,50 @@ def test_target_features():
     assert not target_with_features.features.is_missing
 
 
+@tvm.testing.requires_cuda
+@pytest.mark.parametrize("input_device", ["cuda", tvm.cuda()])
+def test_target_from_device_cuda(input_device):
+    target = Target.from_device(input_device)
+
+    dev = tvm.cuda()
+    assert target.kind.name == "cuda"
+    assert target.attrs["max_threads_per_block"] == dev.max_threads_per_block
+    assert target.max_shared_memory_per_block == dev.max_shared_memory_per_block
+    assert target.thread_warp_size == dev.warp_size
+    assert target.arch == "sm_" + dev.compute_version.replace(".", "")
+
+
+@tvm.testing.requires_rocm
+@pytest.mark.parametrize("input_device", ["rocm", tvm.rocm()])
+def test_target_from_device_rocm(input_device):
+    target = Target.from_device(input_device)
+
+    dev = tvm.rocm()
+    assert target.kind.name == "rocm"
+    assert target.attrs["mtriple"] == "amdgcn-amd-amdhsa-hcc"
+    assert target.attrs["max_threads_per_block"] == dev.max_threads_per_block
+    assert target.max_shared_memory_per_block == dev.max_shared_memory_per_block
+    assert target.thread_warp_size == dev.warp_size
+
+
+@tvm.testing.requires_vulkan
+@pytest.mark.parametrize("input_device", ["vulkan", tvm.vulkan()])
+def test_target_from_device_vulkan(input_device):
+    target = Target.from_device(input_device)
+
+    f_get_target_property = tvm.get_global_func("device_api.vulkan.get_target_property")
+    dev = tvm.vulkan()
+    assert target.kind.name == "vulkan"
+    assert target.attrs["max_threads_per_block"] == dev.max_threads_per_block
+    assert target.max_shared_memory_per_block == dev.max_shared_memory_per_block
+    assert target.thread_warp_size == dev.warp_size
+    assert target.attrs["supports_float16"] == f_get_target_property(dev, "supports_float16")
+    assert target.attrs["supports_int16"] == f_get_target_property(dev, "supports_int16")
+    assert target.attrs["supports_int8"] == f_get_target_property(dev, "supports_int8")
+    assert target.attrs["supports_16bit_buffer"] == f_get_target_property(
+        dev, "supports_16bit_buffer"
+    )
+
+
 if __name__ == "__main__":
     tvm.testing.main()
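
For reference, a minimal usage sketch of the APIs this patch adds (a sketch only, not part of the patch; it assumes a CUDA-capable machine, and the device-string syntax and Target attributes follow the code and tests above):

    import tvm
    from tvm.target import Target

    # "type:index" device strings are now parsed by tvm.device (ndarray.py change above).
    dev = tvm.device("cuda:0")
    assert dev.device_type == tvm.cuda(0).device_type
    assert dev.device_id == 0

    # Target.from_device detects a Target from a device string or Device (target.py change above).
    target = Target.from_device(dev)
    print(target.kind.name, target.attrs["max_threads_per_block"])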