From e317ad902e54f7f90b3030e3ea3fea9ceec5088c Mon Sep 17 00:00:00 2001 From: ChaiBapchya Date: Sat, 19 Sep 2020 21:10:44 -0700 Subject: [PATCH 1/2] use runtime Features instead of manual check --- tools/pip/setup.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tools/pip/setup.py b/tools/pip/setup.py index bf007ea1f5a5..026e7515b1ed 100644 --- a/tools/pip/setup.py +++ b/tools/pip/setup.py @@ -24,6 +24,7 @@ import shutil import platform from setuptools import setup, find_packages +from mxnet import runtime if platform.system() == 'Linux': sys.argv.append('--python-tag') @@ -135,14 +136,14 @@ def skip_markdown_comments(md): elif variant.startswith('CU92'): libraries.append('CUDA-9.2') -if variant != 'NATIVE': +if runtime.Features().is_enabled("MKLDNN"): libraries.append('MKLDNN') short_description += ' This version uses {0}.'.format(' and '.join(libraries)) package_data = {'mxnet': [os.path.join('mxnet', os.path.basename(LIB_PATH[0]))], 'dmlc_tracker': []} -if variant != 'NATIVE': +if runtime.Features().is_enabled("MKLDNN"): shutil.copytree(os.path.join(CURRENT_DIR, 'mxnet-build/3rdparty/mkldnn/include'), os.path.join(CURRENT_DIR, 'mxnet/include/mkldnn')) if platform.system() == 'Linux': From 1a6234f245001599f76f30cdca3aa0782f548d6e Mon Sep 17 00:00:00 2001 From: ChaiBapchya Date: Wed, 23 Sep 2020 18:58:20 -0700 Subject: [PATCH 2/2] move mxnet package closer to where its used --- tools/pip/setup.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tools/pip/setup.py b/tools/pip/setup.py index 026e7515b1ed..2aff243d3ed7 100644 --- a/tools/pip/setup.py +++ b/tools/pip/setup.py @@ -24,7 +24,6 @@ import shutil import platform from setuptools import setup, find_packages -from mxnet import runtime if platform.system() == 'Linux': sys.argv.append('--python-tag') @@ -136,14 +135,15 @@ def skip_markdown_comments(md): elif variant.startswith('CU92'): libraries.append('CUDA-9.2') -if runtime.Features().is_enabled("MKLDNN"): +from mxnet.runtime import Features +if Features().is_enabled("MKLDNN"): libraries.append('MKLDNN') short_description += ' This version uses {0}.'.format(' and '.join(libraries)) package_data = {'mxnet': [os.path.join('mxnet', os.path.basename(LIB_PATH[0]))], 'dmlc_tracker': []} -if runtime.Features().is_enabled("MKLDNN"): +if Features().is_enabled("MKLDNN"): shutil.copytree(os.path.join(CURRENT_DIR, 'mxnet-build/3rdparty/mkldnn/include'), os.path.join(CURRENT_DIR, 'mxnet/include/mkldnn')) if platform.system() == 'Linux':