From 10a7a17a10fc20a6ae76becb2f1f1fb453a7d5dc Mon Sep 17 00:00:00 2001
From: ver217
Date: Thu, 11 Nov 2021 13:51:52 +0800
Subject: [PATCH] remove redundancy func in setup

---
 setup.py | 98 +++++++++++++++++---------------------------------
 1 file changed, 29 insertions(+), 69 deletions(-)

diff --git a/setup.py b/setup.py
index e68430a210d8..d71876bb9938 100644
--- a/setup.py
+++ b/setup.py
@@ -1,7 +1,6 @@
 import os
 import subprocess
 import sys
-import warnings
 
 import torch
 from setuptools import setup, find_packages
@@ -23,13 +22,36 @@ def get_cuda_bare_metal_version(cuda_dir):
     return raw_output, bare_metal_major, bare_metal_minor
 
 
+def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
+    raw_output, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(
+        cuda_dir)
+    torch_binary_major = torch.version.cuda.split(".")[0]
+    torch_binary_minor = torch.version.cuda.split(".")[1]
+
+    print("\nCompiling cuda extensions with")
+    print(raw_output + "from " + cuda_dir + "/bin\n")
+
+    if (bare_metal_major != torch_binary_major) or (bare_metal_minor != torch_binary_minor):
+        raise RuntimeError("Cuda extensions are being compiled with a version of Cuda that does " +
+                           "not match the version used to compile Pytorch binaries. " +
+                           "Pytorch binaries were compiled with Cuda {}.\n".format(torch.version.cuda) +
+                           "In some cases, a minor-version mismatch will not cause later errors: " +
+                           "https://github.com/NVIDIA/apex/pull/323#discussion_r287021798. "
+                           "You can try commenting out this check (at your own risk).")
+
+
+def fetch_requirements(path):
+    with open(path, 'r') as fd:
+        return [r.strip() for r in fd.readlines()]
+
+
 if not torch.cuda.is_available():
     # https://github.com/NVIDIA/apex/issues/486
     # Extension builds after https://github.com/pytorch/pytorch/pull/23408 attempt to query torch.cuda.get_device_capability(),
     # which will fail if you are compiling in an environment without visible GPUs (e.g. during an nvidia-docker build command).
     print('\nWarning: Torch did not find available GPUs on this system.\n',
           'If your intention is to cross-compile, this is not an error.\n'
-          'By default, Apex will cross-compile for Pascal (compute capabilities 6.0, 6.1, 6.2),\n'
+          'By default, Colossal-AI will cross-compile for Pascal (compute capabilities 6.0, 6.1, 6.2),\n'
           'Volta (compute capability 7.0), Turing (compute capability 7.5),\n'
           'and, if the CUDA version is >= 11.0, Ampere (compute capability 8.0).\n'
           'If you wish to cross-compile for a single specific architecture,\n'
@@ -46,66 +68,12 @@ def get_cuda_bare_metal_version(cuda_dir):
 TORCH_MINOR = int(torch.__version__.split('.')[1])
 
 if TORCH_MAJOR == 0 and TORCH_MINOR < 4:
-    raise RuntimeError("Apex requires Pytorch 0.4 or newer.\n" +
+    raise RuntimeError("Colossal-AI requires Pytorch 0.4 or newer.\n" +
                        "The latest stable release can be obtained from https://pytorch.org/")
 
 cmdclass = {}
 ext_modules = []
-extras = {}
-if "--pyprof" in sys.argv:
-    string = "\n\nPyprof has been moved to its own dedicated repository and will " + \
-        "soon be removed from Apex. Please visit\n" + \
-        "https://github.com/NVIDIA/PyProf\n" + \
-        "for the latest version."
-    warnings.warn(string, DeprecationWarning)
-    with open('requirements.txt') as f:
-        required_packages = f.read().splitlines()
-    extras['pyprof'] = required_packages
-    try:
-        sys.argv.remove("--pyprof")
-    except:
-        pass
-else:
-    warnings.warn(
-        "Option --pyprof not specified. Not installing PyProf dependencies!")
-
-if "--cuda_ext" in sys.argv:
-    if TORCH_MAJOR == 0:
-        raise RuntimeError("--cuda_ext requires Pytorch 1.0 or later, "
-                           "found torch.__version__ = {}".format(torch.__version__))
-
-
-def get_cuda_bare_metal_version(cuda_dir):
-    raw_output = subprocess.check_output(
-        [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
-    output = raw_output.split()
-    release_idx = output.index("release") + 1
-    release = output[release_idx].split(".")
-    bare_metal_major = release[0]
-    bare_metal_minor = release[1][0]
-
-    return raw_output, bare_metal_major, bare_metal_minor
-
-
-def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
-    raw_output, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(
-        cuda_dir)
-    torch_binary_major = torch.version.cuda.split(".")[0]
-    torch_binary_minor = torch.version.cuda.split(".")[1]
-
-    print("\nCompiling cuda extensions with")
-    print(raw_output + "from " + cuda_dir + "/bin\n")
-
-    if (bare_metal_major != torch_binary_major) or (bare_metal_minor != torch_binary_minor):
-        raise RuntimeError("Cuda extensions are being compiled with a version of Cuda that does " +
-                           "not match the version used to compile Pytorch binaries. " +
-                           "Pytorch binaries were compiled with Cuda {}.\n".format(torch.version.cuda) +
-                           "In some cases, a minor-version mismatch will not cause later errors: " +
-                           "https://github.com/NVIDIA/apex/pull/323#discussion_r287021798. "
-                           "You can try commenting out this check (at your own risk).")
-
-
 # Set up macros for forward/backward compatibility hack around
 # https://github.com/pytorch/pytorch/commit/4404762d7dd955383acee92e6f06b48144a0742e
 # and
@@ -123,6 +91,10 @@ def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
 version_dependent_macros = version_ge_1_1 + version_ge_1_3 + version_ge_1_5
 
 if "--cuda_ext" in sys.argv:
+    if TORCH_MAJOR == 0:
+        raise RuntimeError("--cuda_ext requires Pytorch 1.0 or later, "
+                           "found torch.__version__ = {}".format(torch.__version__))
+
     sys.argv.remove("--cuda_ext")
 
     if CUDA_HOME is None:
@@ -145,17 +117,6 @@ def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
 #                                       '--resource-usage',
                                         '--use_fast_math'] + version_dependent_macros}))
 
-# Check, if ATen/CUDAGenerator.h is found, otherwise use the new ATen/CUDAGeneratorImpl.h, due to breaking change in https://github.com/pytorch/pytorch/pull/36026
-generator_flag = []
-torch_dir = torch.__path__[0]
-if os.path.exists(os.path.join(torch_dir, 'include', 'ATen', 'CUDAGenerator.h')):
-    generator_flag = ['-DOLD_GENERATOR']
-
-
-def fetch_requirements(path):
-    with open(path, 'r') as fd:
-        return [r.strip() for r in fd.readlines()]
-
 
 install_requires = fetch_requirements('requirements/requirements.txt')
 
@@ -170,6 +131,5 @@ def fetch_requirements(path):
     description='An integrated large-scale model training system with efficient parallelization techniques',
     ext_modules=ext_modules,
     cmdclass={'build_ext': BuildExtension} if ext_modules else {},
-    extras_require=extras,
     install_requires=install_requires,
 )