From 7ce8a0304d1b810f05b882ac73e52500b5726058 Mon Sep 17 00:00:00 2001 From: Thomas Viehmann Date: Fri, 4 Dec 2020 12:38:02 +0100 Subject: [PATCH 1/2] Add CUDA 11.1 libdevice Maybe we should have a >= check instead. I also added a fallback to detect the version if version.txt is missing. Calling nvcc for this has been inspired by what PyTorch does when compiling extension modules. --- python/tvm/contrib/nvcc.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/python/tvm/contrib/nvcc.py b/python/tvm/contrib/nvcc.py index 53a507f2d79a..647daa6f1ab5 100644 --- a/python/tvm/contrib/nvcc.py +++ b/python/tvm/contrib/nvcc.py @@ -149,7 +149,19 @@ def get_cuda_version(cuda_path): version_str = f.readline().replace("\n", "").replace("\r", "") return float(version_str.split(" ")[2][:2]) except: - raise RuntimeError("Cannot read cuda version file") + pass + + cmd = [os.path.join(cuda_path, "bin", "nvcc"), "--version"] + proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) + (out, _) = proc.communicate() + out = py_str(out) + if proc.returncode == 0: + release_line = [l for l in out.split("\n") if "release" in l][0] + release_fields = [s.strip() for s in release_line.split(",")] + release_version = [f[1:] for f in release_fields if f.startswith("V")][0] + major_minor = ".".join(release_version.split(".")[:2]) + return float(major_minor) + raise RuntimeError("Cannot read cuda version file") @tvm._ffi.register_func("tvm_callback_libdevice_path") @@ -174,7 +186,7 @@ def find_libdevice_path(arch): selected_ver = 0 selected_path = None cuda_ver = get_cuda_version(cuda_path) - if cuda_ver in (9.0, 9.1, 10.0, 10.1, 10.2, 11.0): + if cuda_ver in (9.0, 9.1, 10.0, 10.1, 10.2, 11.0, 11.1): path = os.path.join(lib_path, "libdevice.10.bc") else: for fn in os.listdir(lib_path): From 4f69b2505d3c393fb98fc8b3a15f72a681e19961 Mon Sep 17 00:00:00 2001 From: Thomas Viehmann Date: Fri, 4 Dec 2020 14:29:50 +0100 Subject: [PATCH 2/2] fix other lint --- python/tvm/contrib/nvcc.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/tvm/contrib/nvcc.py b/python/tvm/contrib/nvcc.py index 647daa6f1ab5..89548b74866b 100644 --- a/python/tvm/contrib/nvcc.py +++ b/python/tvm/contrib/nvcc.py @@ -148,7 +148,7 @@ def get_cuda_version(cuda_path): with open(version_file_path) as f: version_str = f.readline().replace("\n", "").replace("\r", "") return float(version_str.split(" ")[2][:2]) - except: + except FileNotFoundError: pass cmd = [os.path.join(cuda_path, "bin", "nvcc"), "--version"] @@ -231,6 +231,7 @@ def parse_compute_version(compute_version): minor = int(split_ver[1]) return major, minor except (IndexError, ValueError) as err: + # pylint: disable=raise-missing-from raise RuntimeError("Compute version parsing error: " + str(err))