From 8cc9567f4dd3b838ddc4bbe89e31e14014761861 Mon Sep 17 00:00:00 2001
From: tqchen
Date: Mon, 25 Sep 2017 15:42:58 -0700
Subject: [PATCH] [CUDA] auto detect compute capability when arch is not
 passed

---
 include/tvm/runtime/device_api.h            |  3 +-
 python/tvm/_ffi/runtime_ctypes.py           | 14 ++++++++
 python/tvm/contrib/cc.py                    |  5 +--
 python/tvm/contrib/nvcc.py                  | 40 +++++++++++----------
 src/runtime/cuda/cuda_device_api.cc         | 11 ++++++
 src/runtime/metal/metal_device_api.mm       |  1 +
 src/runtime/opencl/opencl_device_api.cc     |  1 +
 src/runtime/rocm/rocm_device_api.cc         |  1 +
 topi/python/topi/transform.py               |  2 +-
 topi/recipe/broadcast/test_broadcast_map.py |  2 +-
 topi/recipe/conv/depthwise_conv2d_test.py   |  2 +-
 topi/recipe/conv/test_conv2d_hwcn_map.py    |  2 +-
 topi/recipe/gemm/cuda_gemm_square.py        |  2 +-
 topi/recipe/reduce/test_reduce_map.py       |  2 +-
 topi/recipe/rnn/lstm.py                     |  2 +-
 topi/recipe/rnn/matexp.py                   |  2 +-
 16 files changed, 61 insertions(+), 31 deletions(-)

diff --git a/include/tvm/runtime/device_api.h b/include/tvm/runtime/device_api.h
index ccf363c54457..d50a23c80df7 100644
--- a/include/tvm/runtime/device_api.h
+++ b/include/tvm/runtime/device_api.h
@@ -18,7 +18,8 @@ namespace runtime {
 enum DeviceAttrKind : int {
   kExist = 0,
   kMaxThreadsPerBlock = 1,
-  kWarpSize = 2
+  kWarpSize = 2,
+  kComputeVersion = 3
 };
 
 /*! \brief Number of bytes each allocation must align to */
diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py
index c5eaa43208e0..9ea8ef579e10 100644
--- a/python/tvm/_ffi/runtime_ctypes.py
+++ b/python/tvm/_ffi/runtime_ctypes.py
@@ -131,6 +131,20 @@ def warp_size(self):
         return _api_internal._GetDeviceAttr(
             self.device_type, self.device_id, 2)
 
+    @property
+    def compute_version(self):
+        """Get compute version number as a string.
+
+        Currently used to get the compute capability of a CUDA device.
+
+        Returns
+        -------
+        version : str
+            The version string in `major.minor` format.
+        """
+        return _api_internal._GetDeviceAttr(
+            self.device_type, self.device_id, 3)
+
     def sync(self):
         """Synchronize until jobs finished at the context."""
         check_call(_LIB.TVMSynchronize(self.device_type, self.device_id, None))
diff --git a/python/tvm/contrib/cc.py b/python/tvm/contrib/cc.py
index f342b81f5bbf..d9379bbd7efd 100644
--- a/python/tvm/contrib/cc.py
+++ b/python/tvm/contrib/cc.py
@@ -39,11 +39,8 @@ def create_shared(output,
     if options:
         cmd += options
 
-    args = ' '.join(cmd)
     proc = subprocess.Popen(
-        args, shell=True,
-        stdout=subprocess.PIPE,
-        stderr=subprocess.STDOUT)
+        cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
     (out, _) = proc.communicate()
 
     if proc.returncode != 0:
diff --git a/python/tvm/contrib/nvcc.py b/python/tvm/contrib/nvcc.py
index a5bf6254fd32..9651466b723d 100644
--- a/python/tvm/contrib/nvcc.py
+++ b/python/tvm/contrib/nvcc.py
@@ -1,12 +1,16 @@
 # pylint: disable=invalid-name
 """Utility to invoke nvcc compiler in the system"""
 from __future__ import absolute_import as _abs
-import sys
+
 import subprocess
 
 from . import util
+from .. import ndarray as nd
 
-def compile_cuda(code, target="ptx", arch=None,
-                 options=None, path_target=None):
+def compile_cuda(code,
+                 target="ptx",
+                 arch=None,
+                 options=None,
+                 path_target=None):
     """Compile cuda code with NVCC from env.
 
     Parameters
@@ -39,32 +43,32 @@ def compile_cuda(code, target="ptx", arch=None,
     with open(temp_code, "w") as out_file:
         out_file.write(code)
 
-    if target == "cubin" and arch is None:
-        raise ValueError("arch(sm_xy) must be passed for generating cubin")
+
+    if arch is None:
+        if nd.gpu(0).exist:
+            # auto detect the compute arch argument
+            arch = "sm_" + "".join(nd.gpu(0).compute_version.split('.'))
+        else:
+            raise ValueError("arch(sm_xy) is not passed, and we cannot detect it from env")
 
     file_target = path_target if path_target else temp_target
     cmd = ["nvcc"]
     cmd += ["--%s" % target, "-O3"]
-    if arch:
-        cmd += ["-arch", arch]
+    cmd += ["-arch", arch]
     cmd += ["-o", file_target]
 
     if options:
         cmd += options
     cmd += [temp_code]
-    args = ' '.join(cmd)
+
     proc = subprocess.Popen(
-        args, shell=True,
-        stdout=subprocess.PIPE,
-        stderr=subprocess.STDOUT)
+        cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
+
     (out, _) = proc.communicate()
 
     if proc.returncode != 0:
-        sys.stderr.write("Compilation error:\n")
-        sys.stderr.write(str(out))
-        sys.stderr.flush()
-        cubin = None
-    else:
-        cubin = bytearray(open(file_target, "rb").read())
-    return cubin
+        msg = "Compilation error:\n"
+        msg += str(out)
+        raise RuntimeError(msg)
+
+    return bytearray(open(file_target, "rb").read())
diff --git a/src/runtime/cuda/cuda_device_api.cc b/src/runtime/cuda/cuda_device_api.cc
index fa214c6780cc..340b286d87ca 100644
--- a/src/runtime/cuda/cuda_device_api.cc
+++ b/src/runtime/cuda/cuda_device_api.cc
@@ -40,6 +40,17 @@ class CUDADeviceAPI final : public DeviceAPI {
             &value, cudaDevAttrWarpSize, ctx.device_id));
         break;
       }
+      case kComputeVersion: {
+        std::ostringstream os;
+        CUDA_CALL(cudaDeviceGetAttribute(
+            &value, cudaDevAttrComputeCapabilityMajor, ctx.device_id));
+        os << value << ".";
+        CUDA_CALL(cudaDeviceGetAttribute(
+            &value, cudaDevAttrComputeCapabilityMinor, ctx.device_id));
+        os << value;
+        *rv = os.str();
+        return;
+      }
     }
     *rv = value;
   }
diff --git a/src/runtime/metal/metal_device_api.mm b/src/runtime/metal/metal_device_api.mm
index 0d1a0c666fe1..4af274da98a3 100644
--- a/src/runtime/metal/metal_device_api.mm
+++ b/src/runtime/metal/metal_device_api.mm
@@ -39,6 +39,7 @@
       *rv = 1;
       break;
     }
+    case kComputeVersion: return;
     case kExist: break;
   }
 }
diff --git a/src/runtime/opencl/opencl_device_api.cc b/src/runtime/opencl/opencl_device_api.cc
index abb0a500b77d..f70207ebe881 100644
--- a/src/runtime/opencl/opencl_device_api.cc
+++ b/src/runtime/opencl/opencl_device_api.cc
@@ -45,6 +45,7 @@ void OpenCLWorkspace::GetAttr(
       *rv = 1;
       break;
     }
+    case kComputeVersion: return;
     case kExist: break;
   }
 }
diff --git a/src/runtime/rocm/rocm_device_api.cc b/src/runtime/rocm/rocm_device_api.cc
index c8a8bfba8d27..3edb1c67c4d4 100644
--- a/src/runtime/rocm/rocm_device_api.cc
+++ b/src/runtime/rocm/rocm_device_api.cc
@@ -44,6 +44,7 @@ class ROCMDeviceAPI final : public DeviceAPI {
       value = 64;
       break;
     }
+    case kComputeVersion: return;
   }
   *rv = value;
 }
diff --git a/topi/python/topi/transform.py b/topi/python/topi/transform.py
index 7aec532c45af..4a9b7346c459 100644
--- a/topi/python/topi/transform.py
+++ b/topi/python/topi/transform.py
@@ -143,7 +143,7 @@ def _compute(begin, *indices):
         begin_ids = [seg_size * i for i in range(indices_or_sections)]
     elif isinstance(indices_or_sections, (tuple, list)):
         assert tuple(indices_or_sections) == tuple(sorted(indices_or_sections)),\
-            "Should be sorted, recieved %s" %str(indices_or_sections)
+            "Should be sorted, received %s" % str(indices_or_sections)
         begin_ids = [0] + list(indices_or_sections)
     else:
         raise NotImplementedError
diff --git a/topi/recipe/broadcast/test_broadcast_map.py b/topi/recipe/broadcast/test_broadcast_map.py
index 6d5c3f1de3e8..9c4e521ddd0d 100644
--- a/topi/recipe/broadcast/test_broadcast_map.py
+++ b/topi/recipe/broadcast/test_broadcast_map.py
@@ -12,7 +12,7 @@
 
 @tvm.register_func
 def tvm_callback_cuda_compile(code):
-    ptx = nvcc.compile_cuda(code, target="ptx", options=["-arch=sm_52"])
+    ptx = nvcc.compile_cuda(code, target="ptx")
     return ptx
 
 
diff --git a/topi/recipe/conv/depthwise_conv2d_test.py b/topi/recipe/conv/depthwise_conv2d_test.py
index 64dc10e11158..4d3a20f5c2ed 100644
--- a/topi/recipe/conv/depthwise_conv2d_test.py
+++ b/topi/recipe/conv/depthwise_conv2d_test.py
@@ -13,7 +13,7 @@
 
 @tvm.register_func
 def tvm_callback_cuda_compile(code):
-    ptx = nvcc.compile_cuda(code, target="ptx", options=["-arch=sm_37"])  # 37 for k80(ec2 instance)
+    ptx = nvcc.compile_cuda(code, target="ptx")
     return ptx
 
 def write_code(code, fname):
diff --git a/topi/recipe/conv/test_conv2d_hwcn_map.py b/topi/recipe/conv/test_conv2d_hwcn_map.py
index a6b9017a74eb..f7cba0934627 100644
--- a/topi/recipe/conv/test_conv2d_hwcn_map.py
+++ b/topi/recipe/conv/test_conv2d_hwcn_map.py
@@ -12,7 +12,7 @@
 
 @tvm.register_func
 def tvm_callback_cuda_compile(code):
-    ptx = nvcc.compile_cuda(code, target="ptx", options=["-arch=sm_37"])
+    ptx = nvcc.compile_cuda(code, target="ptx")
     return ptx
 
 def write_code(code, fname):
diff --git a/topi/recipe/gemm/cuda_gemm_square.py b/topi/recipe/gemm/cuda_gemm_square.py
index f27d6a74d883..0c7ba71a86f6 100644
--- a/topi/recipe/gemm/cuda_gemm_square.py
+++ b/topi/recipe/gemm/cuda_gemm_square.py
@@ -9,7 +9,7 @@
 
 @tvm.register_func
 def tvm_callback_cuda_compile(code):
-    ptx = nvcc.compile_cuda(code, target="ptx", options=["-arch=sm_52"])
+    ptx = nvcc.compile_cuda(code, target="ptx")
     return ptx
 
 def write_code(code, fname):
diff --git a/topi/recipe/reduce/test_reduce_map.py b/topi/recipe/reduce/test_reduce_map.py
index 7cdf2ac61970..6e9befaff2ec 100644
--- a/topi/recipe/reduce/test_reduce_map.py
+++ b/topi/recipe/reduce/test_reduce_map.py
@@ -12,7 +12,7 @@
 
 @tvm.register_func
 def tvm_callback_cuda_compile(code):
-    ptx = nvcc.compile_cuda(code, target="ptx", options=["-arch=sm_52"])
+    ptx = nvcc.compile_cuda(code, target="ptx")
     return ptx
 
 
diff --git a/topi/recipe/rnn/lstm.py b/topi/recipe/rnn/lstm.py
index 7830493eeed1..53ccbe598c3d 100644
--- a/topi/recipe/rnn/lstm.py
+++ b/topi/recipe/rnn/lstm.py
@@ -17,7 +17,7 @@
 @tvm.register_func
 def tvm_callback_cuda_compile(code):
     """Use nvcc compiler for better perf."""
-    ptx = nvcc.compile_cuda(code, target="ptx", options=["-arch=sm_52"])
+    ptx = nvcc.compile_cuda(code, target="ptx")
     return ptx
 
 def write_code(code, fname):
diff --git a/topi/recipe/rnn/matexp.py b/topi/recipe/rnn/matexp.py
index 8712dea697b8..128dbef9ab13 100644
--- a/topi/recipe/rnn/matexp.py
+++ b/topi/recipe/rnn/matexp.py
@@ -24,7 +24,7 @@
 @tvm.register_func
 def tvm_callback_cuda_compile(code):
     """Use nvcc compiler for better perf."""
-    ptx = nvcc.compile_cuda(code, target="ptx", options=["-arch=sm_52"])
+    ptx = nvcc.compile_cuda(code, target="ptx")
     return ptx
 
 def write_code(code, fname):
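
Note: after this patch none of the recipe callbacks needs a hard-coded
`-arch=sm_xy` option; `compile_cuda()` queries the device itself. Below is a
minimal sketch of the detection path for reference only. It assumes nvcc is on
PATH and a CUDA-capable GPU sits at index 0; `compute_version` is the property
this patch adds, and the kernel string is a hypothetical placeholder.

    import tvm
    from tvm.contrib import nvcc

    ctx = tvm.gpu(0)
    if ctx.exist:
        # compute_version returns e.g. "5.2"; nvcc expects "sm_52"
        arch = "sm_" + "".join(ctx.compute_version.split('.'))
        print("detected arch:", arch)
        # arch is now auto-detected inside compile_cuda, so the old
        # options=["-arch=sm_xy"] argument can simply be dropped:
        ptx = nvcc.compile_cuda(
            'extern "C" __global__ void dummy() {}', target="ptx")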