From 02ccb5fdc6f8f03a9513935b8d98fe189a3e8cbd Mon Sep 17 00:00:00 2001 From: Serge Panev Date: Fri, 14 Aug 2020 14:31:42 -0700 Subject: [PATCH 1/3] [1.x] Backporting TensorRT and Gluon changes Signed-off-by: Serge Panev --- 3rdparty/onnx-tensorrt | 2 +- CMakeLists.txt | 8 +- .../Dockerfile.build.ubuntu_gpu_tensorrt | 8 +- ci/docker/install/tensorrt.sh | 15 +- ci/docker/runtime_functions.sh | 29 +- ci/jenkins/Jenkins_steps.groovy | 20 +- example/extensions/lib_pass/test_pass.py | 10 +- .../extensions/lib_subgraph/test_subgraph.py | 25 +- include/mxnet/c_api.h | 30 + perl-package/AI-MXNetCAPI/mxnet.i | 11 + python/mxnet/gluon/block.py | 108 +++- python/mxnet/symbol/symbol.py | 168 ++++-- src/c_api/c_api_symbolic.cc | 120 ++-- src/operator/subgraph/build_subgraph.cc | 20 +- .../subgraph/tensorrt/nnvm_to_onnx.cc | 2 +- .../subgraph/tensorrt/onnx_to_tensorrt.cc | 8 +- src/operator/subgraph/tensorrt/tensorrt-inl.h | 48 +- src/operator/subgraph/tensorrt/tensorrt.cu | 6 +- tests/python/tensorrt/lenet5_train.py | 99 ---- tests/python/tensorrt/test_cvnets.py | 174 ------ tests/python/tensorrt/test_ops.py | 517 ------------------ tests/python/tensorrt/test_resnet18.py | 74 --- tests/python/tensorrt/test_tensorrt_lenet5.py | 121 ---- tests/python/unittest/test_extensions.py | 6 +- tests/python/unittest/test_subgraph_op.py | 14 +- 25 files changed, 447 insertions(+), 1196 deletions(-) delete mode 100755 tests/python/tensorrt/lenet5_train.py delete mode 100644 tests/python/tensorrt/test_cvnets.py delete mode 100644 tests/python/tensorrt/test_ops.py delete mode 100644 tests/python/tensorrt/test_resnet18.py delete mode 100644 tests/python/tensorrt/test_tensorrt_lenet5.py diff --git a/3rdparty/onnx-tensorrt b/3rdparty/onnx-tensorrt index f4745fcaff86..2eb74d933f89 160000 --- a/3rdparty/onnx-tensorrt +++ b/3rdparty/onnx-tensorrt @@ -1 +1 @@ -Subproject commit f4745fcaff868a519834917c657f105a8eef2f53 +Subproject commit 2eb74d933f89e1590fdbfc64971a36e5f72df720 diff --git a/CMakeLists.txt b/CMakeLists.txt index f861686afb49..7e1ef2a00a76 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -239,6 +239,7 @@ if(USE_TENSORRT) include_directories(3rdparty/onnx-tensorrt/third_party/onnx/) add_definitions(-DMXNET_USE_TENSORRT=1) add_definitions(-DONNX_NAMESPACE=onnx) + add_definitions(-DONNX_ML=1) find_package(Protobuf REQUIRED) @@ -248,14 +249,11 @@ if(USE_TENSORRT) find_library(ONNX_PROTO_LIBRARY NAMES libonnx_proto.so REQUIRED PATHS ${ONNX_PATH} DOC "Path to onnx_proto library.") - find_library(ONNX_TRT_RUNTIME_LIBRARY NAMES libnvonnxparser_runtime.so REQUIRED - PATHS ${ONNX_TRT_PATH} - DOC "Path to onnx_proto library.") find_library(ONNX_TRT_PARSER_LIBRARY NAMES libnvonnxparser.so REQUIRED PATHS ${ONNX_TRT_PATH} - DOC "Path to onnx_proto library.") + DOC "Path to onnx_proto parser library.") - list(APPEND mxnet_LINKER_LIBS libnvinfer.so ${ONNX_TRT_PARSER_LIBRARY} ${ONNX_TRT_RUNTIME_LIBRARY} + list(APPEND mxnet_LINKER_LIBS libnvinfer.so ${ONNX_TRT_PARSER_LIBRARY} ${ONNX_PROTO_LIBRARY} ${ONNX_LIBRARY} ${PROTOBUF_LIBRARY}) endif() diff --git a/ci/docker/Dockerfile.build.ubuntu_gpu_tensorrt b/ci/docker/Dockerfile.build.ubuntu_gpu_tensorrt index 90bd772ecb17..9556fee57f03 100644 --- a/ci/docker/Dockerfile.build.ubuntu_gpu_tensorrt +++ b/ci/docker/Dockerfile.build.ubuntu_gpu_tensorrt @@ -18,7 +18,7 @@ # # Dockerfile to run MXNet on Ubuntu 16.04 for CPU -FROM nvidia/cuda:10.0-devel +FROM nvidia/cuda:10.2-cudnn7-devel-ubuntu18.04 WORKDIR /work/deps @@ -36,12 +36,8 @@ ARG USER_ID=0 COPY install/ubuntu_adduser.sh /work/ RUN /work/ubuntu_adduser.sh -ENV CUDNN_VERSION=7.5.0.56 -COPY install/ubuntu_cudnn.sh /work/ -RUN /work/ubuntu_cudnn.sh - COPY runtime_functions.sh /work/ WORKDIR /work/mxnet ENV LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib -ENV CPLUS_INCLUDE_PATH=${CPLUS_INCLUDE_PATH}:/usr/local/cuda-10.0/targets/x86_64-linux/include/ +ENV CPLUS_INCLUDE_PATH=${CPLUS_INCLUDE_PATH}:/usr/local/cuda-10.2/targets/x86_64-linux/include/ diff --git a/ci/docker/install/tensorrt.sh b/ci/docker/install/tensorrt.sh index e98c7643f923..29d8ad1e7c37 100755 --- a/ci/docker/install/tensorrt.sh +++ b/ci/docker/install/tensorrt.sh @@ -18,7 +18,7 @@ # under the License. # Install gluoncv since we're testing Gluon models as well -pip3 install gluoncv==0.2.0 +pip3 install gluoncv==0.4.0 # Install Protobuf # Install protoc 3.5 and build protobuf here (for onnx and onnx-tensorrt) @@ -40,10 +40,11 @@ popd # Install TensorRT echo "TensorRT build enabled. Installing TensorRT." -wget -qO tensorrt.deb https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1604/x86_64/nvidia-machine-learning-repo-ubuntu1604_1.0.0-1_amd64.deb -dpkg -i tensorrt.deb apt-get update -apt-get install -y --allow-downgrades libnvinfer5=5.1.5-1+cuda10.0 -apt-get install -y --allow-downgrades libnvinfer-dev=5.1.5-1+cuda10.0 -apt-mark hold libnvinfer5 libnvinfer-dev -rm tensorrt.deb +TRT_VERSION="7.0.0-1+cuda10.2" +TRT_MAJOR_VERSION=7 +apt-get install -y libnvinfer${TRT_MAJOR_VERSION}=${TRT_VERSION} \ + libnvinfer-dev=${TRT_VERSION} \ + libnvinfer-plugin${TRT_MAJOR_VERSION}=${TRT_VERSION} \ + libnvinfer-plugin-dev=${TRT_VERSION} +apt-mark hold libnvinfer${TRT_MAJOR_VERSION} libnvinfer-dev diff --git a/ci/docker/runtime_functions.sh b/ci/docker/runtime_functions.sh index 4523e1f017f5..b5cbb9a35122 100755 --- a/ci/docker/runtime_functions.sh +++ b/ci/docker/runtime_functions.sh @@ -708,6 +708,8 @@ build_ubuntu_gpu_tensorrt() { build_ccache_wrappers + export ONNX_NAMESPACE=onnx + # Build ONNX pushd . echo "Installing ONNX." @@ -715,14 +717,11 @@ build_ubuntu_gpu_tensorrt() { rm -rf build mkdir -p build cd build - cmake \ - -DCMAKE_CXX_FLAGS=-I/usr/include/python${PYVER}\ - -DBUILD_SHARED_LIBS=ON ..\ - -G Ninja - ninja -j 1 -v onnx/onnx.proto - ninja -j 1 -v + cmake -DCMAKE_CXX_FLAGS=-I/usr/include/python${PYVER} -DBUILD_SHARED_LIBS=ON .. + make -j$(nproc) export LIBRARY_PATH=`pwd`:`pwd`/onnx/:$LIBRARY_PATH export CPLUS_INCLUDE_PATH=`pwd`:$CPLUS_INCLUDE_PATH + export CXXFLAGS=-I`pwd` popd # Build ONNX-TensorRT @@ -730,15 +729,14 @@ build_ubuntu_gpu_tensorrt() { cd 3rdparty/onnx-tensorrt/ mkdir -p build cd build - cmake .. + cmake -DONNX_NAMESPACE=$ONNX_NAMESPACE .. make -j$(nproc) export LIBRARY_PATH=`pwd`:$LIBRARY_PATH popd mkdir -p /work/mxnet/lib/ cp 3rdparty/onnx-tensorrt/third_party/onnx/build/*.so /work/mxnet/lib/ - cp -L 3rdparty/onnx-tensorrt/build/libnvonnxparser_runtime.so.0 /work/mxnet/lib/ - cp -L 3rdparty/onnx-tensorrt/build/libnvonnxparser.so.0 /work/mxnet/lib/ + cp -L 3rdparty/onnx-tensorrt/build/libnvonnxparser.so* /work/mxnet/lib/ cd /work/build cmake -DUSE_CUDA=1 \ @@ -1071,19 +1069,6 @@ unittest_ubuntu_python3_gpu_nocudnn() { nosetests-3.4 $NOSE_COVERAGE_ARGUMENTS $NOSE_TIMER_ARGUMENTS --with-xunit --xunit-file nosetests_gpu.xml --verbose tests/python/gpu } -unittest_ubuntu_tensorrt_gpu() { - set -ex - export PYTHONPATH=./python/ - export MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0 - export MXNET_SUBGRAPH_VERBOSE=0 - export LD_LIBRARY_PATH=/work/mxnet/lib:$LD_LIBRARY_PATH - export CUDNN_VERSION=${CUDNN_VERSION:-7.0.3} - export MXNET_ENABLE_CYTHON=0 - export DMLC_LOG_STACK_TRACE_DEPTH=10 - tests/python/tensorrt/lenet5_train.py - nosetests-3.4 $NOSE_COVERAGE_ARGUMENTS $NOSE_TIMER_ARGUMENTS --with-xunit --xunit-file nosetests_trt_gpu.xml --verbose --nocapture tests/python/tensorrt/ -} - # quantization gpu currently only runs on P3 instances # need to separte it from unittest_ubuntu_python3_gpu() unittest_ubuntu_python3_quantization_gpu() { diff --git a/ci/jenkins/Jenkins_steps.groovy b/ci/jenkins/Jenkins_steps.groovy index c4fd96e65ac0..1cc91e4f4247 100644 --- a/ci/jenkins/Jenkins_steps.groovy +++ b/ci/jenkins/Jenkins_steps.groovy @@ -34,7 +34,7 @@ mx_cmake_lib_cython = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/tvm/l mx_cmake_lib_debug = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/tvmop.conf, build/libcustomop_lib.so, build/libcustomop_gpu_lib.so, build/libsubgraph_lib.so, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests' mx_cmake_mkldnn_lib = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/tvmop.conf, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests, build/3rdparty/openmp/runtime/src/libomp.so' mx_mkldnn_lib = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, lib/libtvmop.so, lib/tvmop.conf, build/libcustomop_lib.so, build/libcustomop_gpu_lib.so, build/libsubgraph_lib.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a' -mx_tensorrt_lib = 'build/libmxnet.so, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/tvmop.conf, lib/libnvonnxparser_runtime.so.0, lib/libnvonnxparser.so.0, lib/libonnx_proto.so, lib/libonnx.so' +mx_tensorrt_lib = 'build/libmxnet.so, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/tvmop.conf, lib/libnvonnxparser.so*, lib/libonnx_proto.so, lib/libonnx.so' mx_lib_cpp_examples = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, lib/libtvmop.so, lib/tvmop.conf, build/libcustomop_lib.so, build/libcustomop_gpu_lib.so, build/libsubgraph_lib.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a, 3rdparty/ps-lite/build/libps.a, deps/lib/libprotobuf-lite.a, deps/lib/libzmq.a, build/cpp-package/example/*, python/mxnet/_cy3/*.so, python/mxnet/_ffi/_cy3/*.so' mx_lib_cpp_capi = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, lib/libtvmop.so, lib/tvmop.conf, libsample_lib.so, lib/libmkldnn.so.1, lib/libmklml_intel.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a, 3rdparty/ps-lite/build/libps.a, deps/lib/libprotobuf-lite.a, deps/lib/libzmq.a, build/cpp-package/example/*, python/mxnet/_cy3/*.so, python/mxnet/_ffi/_cy3/*.so, build/tests/cpp/mxnet_unit_tests' mx_lib_cpp_examples_no_tvm_op = 'lib/libmxnet.so, lib/libmxnet.a, build/libcustomop_lib.so, build/libcustomop_gpu_lib.so, build/libsubgraph_lib.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a, 3rdparty/ps-lite/build/libps.a, deps/lib/libprotobuf-lite.a, deps/lib/libzmq.a, build/cpp-package/example/*, python/mxnet/_cy3/*.so, python/mxnet/_ffi/_cy3/*.so' @@ -853,24 +853,6 @@ def test_unix_python3_mkldnn_nocudnn_gpu() { }] } -def test_unix_python3_tensorrt_gpu() { - return ['Python3: TensorRT GPU': { - node(NODE_LINUX_GPU_P3) { - ws('workspace/build-tensorrt') { - timeout(time: max_time, unit: 'MINUTES') { - try { - utils.unpack_and_init('tensorrt', mx_tensorrt_lib) - utils.docker_run('ubuntu_gpu_tensorrt', 'unittest_ubuntu_tensorrt_gpu', true) - utils.publish_test_coverage() - } finally { - utils.collect_test_results_unix('nosetests_tensorrt.xml', 'nosetests_python3_tensorrt_gpu.xml') - } - } - } - } - }] -} - def test_unix_python3_integration_gpu() { return ['Python Integration GPU': { node(NODE_LINUX_GPU_G4) { diff --git a/example/extensions/lib_pass/test_pass.py b/example/extensions/lib_pass/test_pass.py index 01d6eddac67b..a5a7e8c7ea4d 100644 --- a/example/extensions/lib_pass/test_pass.py +++ b/example/extensions/lib_pass/test_pass.py @@ -48,6 +48,7 @@ sym = mx.sym.log(d) def test_model(pass_name): + args={'a':mx.nd.ones((3,2)), 'b':mx.nd.ones((3,2))} # execute in MXNet print('-------------------------------') print('Testing regular MXNet execution') @@ -60,11 +61,10 @@ def test_model(pass_name): # with propogating shapes/types print('-------------------------------') print('Testing pass "%s" with shapes/types' % pass_name) - arg_array = [mx.nd.ones((3,2),dtype='float32'), mx.nd.ones((3,2),dtype='float32')] - aux = [] - mysym2 = sym.optimize_for(pass_name,arg_array,aux) + aux = {} + mysym2 = sym.optimize_for(pass_name,args,aux) print(mysym2.tojson()) - exe2 = mysym2.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2)), 'b':mx.nd.ones((3,2))}) + exe2 = mysym2.bind(ctx=mx.cpu(), args=args) out2 = exe2.forward() print(out2) @@ -72,7 +72,7 @@ def test_model(pass_name): print('-------------------------------') print('Testing pass "%s" without shapes/types' % pass_name) mysym3 = sym.optimize_for(pass_name, myOpt='yello') - exe3 = mysym3.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2)), 'b':mx.nd.ones((3,2))}) + exe3 = mysym3.bind(ctx=mx.cpu(), args=args) out3 = exe3.forward() print(out3) diff --git a/example/extensions/lib_subgraph/test_subgraph.py b/example/extensions/lib_subgraph/test_subgraph.py index 267a417d92f2..5294e1c209a3 100644 --- a/example/extensions/lib_subgraph/test_subgraph.py +++ b/example/extensions/lib_subgraph/test_subgraph.py @@ -49,32 +49,31 @@ sym2 = mx.sym.log(d2) def test(backend): + args = {'a':mx.nd.ones((3,2)), 'b':mx.nd.ones((3,2))} ############################################### # Test with subgraph not consuming params ############################################### #execute in MXNet print('-------------------------------') print('Testing regular MXNet execution') - exe = sym.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2)), 'b':mx.nd.ones((3,2))}) + exe = sym.bind(ctx=mx.cpu(), args=args) out = exe.forward() print(out) # with propogating shapes/types print('-------------------------------') print('Testing %s partitioning with shapes/types' % backend) - arg_array = [mx.nd.ones((3,2),dtype='float32'), mx.nd.ones((3,2),dtype='float32')] - mysym2 = sym.optimize_for(backend,arg_array) + mysym2 = sym.optimize_for(backend,args) print(mysym2.tojson()) - exe2 = mysym2.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2)), 'b':mx.nd.ones((3,2))}) + exe2 = mysym2.bind(ctx=mx.cpu(), args=args) out2 = exe2.forward() print(out2) # with propogating shapes/types, rejecting subgraph print('-------------------------------') print('Testing %s partitioning with shapes/types - rejecting subgraph' % backend) - arg_array = [mx.nd.ones((3,2),dtype='float32'), mx.nd.ones((3,2),dtype='float32')] - mysym2 = sym.optimize_for(backend, arg_array, reject=True) - exe2 = mysym2.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2)), 'b':mx.nd.ones((3,2))}) + mysym2 = sym.optimize_for(backend, args, reject=True) + exe2 = mysym2.bind(ctx=mx.cpu(), args=args) out2 = exe2.forward() print(out2) @@ -82,7 +81,7 @@ def test(backend): print('-------------------------------') print('Testing %s partitioning without shapes/types' % backend) mysym3 = sym.optimize_for(backend, myOpt='yello') - exe3 = mysym3.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2)), 'b':mx.nd.ones((3,2))}) + exe3 = mysym3.bind(ctx=mx.cpu(), args=args) out3 = exe3.forward() print(out3) @@ -115,20 +114,20 @@ def test(backend): ############################################### # Test with subgraph directly consuming params ############################################### + args = {'a':mx.nd.ones((3,2))} #execute in MXNet print('-------------------------------') print('Testing regular MXNet execution') - exe5 = sym2.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2))}) + exe5 = sym2.bind(ctx=mx.cpu(), args=args) out5 = exe5.forward() print(out5) # with propogating shapes/types print('-------------------------------') print('Testing %s partitioning with shapes/types' % backend) - arg_array = [mx.nd.ones((3,2),dtype='float32')] - mysym6 = sym2.optimize_for(backend, arg_array, reqArgs=True) + mysym6 = sym2.optimize_for(backend, args, reqArgs=True) print(mysym6.tojson()) - exe6 = mysym6.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2))}) + exe6 = mysym6.bind(ctx=mx.cpu(), args=args) out6 = exe6.forward() print(out6) @@ -136,7 +135,7 @@ def test(backend): print('-------------------------------') print('Testing %s partitioning without shapes/types' % backend) mysym7 = sym2.optimize_for(backend, reqArgs=True) - exe7 = mysym7.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2))}) + exe7 = mysym7.bind(ctx=mx.cpu(), args=args) out7 = exe7.forward() print(out7) diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index cfb2400c9290..98a7a7032e5e 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -2166,6 +2166,25 @@ MXNET_DLL int MXGenAtomicSymbolFromSymbol(SymbolHandle sym_handle, SymbolHandle * \param num_options number of key value pairs * \param keys keys for options * \param vals values corresponding to keys + * \param num_input_shapes number of input shapes + * \param input_shape_names names of the input shapes + * \param input_shape_data pointer to the contiguous data shapes + * \param input_shape_idx array of per shape starting idx, the shape length for the i-th input shape + * is calculate as input_shape_idx[i+1] - input_shape_idx[i] + * \param num_input_dtypes number of input data types + * \param input_dtype_names array of names of the input data types + * \param input_dtypes array of values of the input data types + * \param num_input_stypesnumber of input storage types + * \param input_stype_names array of names of the input storage types + * \param input_stypes array of values of input storage types + * \param skip_infer if the optimization should skip the attribute inferences + * (to use if the backend does not require shape inference) + * \param new_args_cnt pointer a number to store the number of new args + * \param new_args_handle pointer on array to store the new args handles + * \param new_arg_names_handle pointer on array to store the new args names + * \param new_aux_cnt pointer a number to store the number of new aux + * \param new_aux_handle pointer on array to store the new aux handles + * \param new_aux_names_handle pointer on array to store the new aux names */ MXNET_DLL int MXOptimizeForBackend(SymbolHandle sym_handle, const char* backend_name, @@ -2178,6 +2197,17 @@ MXNET_DLL int MXOptimizeForBackend(SymbolHandle sym_handle, const mx_uint num_options, const char** keys, const char** vals, + const uint32_t num_input_shapes, + const char** input_shape_names, + const int64_t* input_shape_data, + const uint32_t* input_shape_idx, + const uint32_t num_input_dtypes, + const char** input_dtype_names, + const int* input_dtypes, + const uint32_t num_input_stypes, + const char** input_stype_names, + const int* input_stypes, + bool skip_infer, int* new_args_cnt, NDArrayHandle** new_args_handle, char*** new_arg_names_handle, diff --git a/perl-package/AI-MXNetCAPI/mxnet.i b/perl-package/AI-MXNetCAPI/mxnet.i index 9602b08d675e..59346ef16a95 100644 --- a/perl-package/AI-MXNetCAPI/mxnet.i +++ b/perl-package/AI-MXNetCAPI/mxnet.i @@ -1637,6 +1637,17 @@ int MXOptimizeForBackend(SymbolHandle sym_handle, const mx_uint in, const char** keys, const char** vals, + const uint32_t num_input_shapes, + const char** input_shape_names, + const int64_t* input_shape_data, + const uint32_t* input_shape_idx, + const uint32_t num_input_dtypes, + const char** input_dtype_names, + const int* input_dtypes, + const uint32_t num_input_stypes, + const char** input_stype_names, + const int* input_stypes, + bool skip_infer, int* new_args_cnt, NDArrayHandle** new_args_handle, char*** new_arg_names_handle, diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py index d7afd8a787b4..9772e2394486 100644 --- a/python/mxnet/gluon/block.py +++ b/python/mxnet/gluon/block.py @@ -949,41 +949,70 @@ def _build_cache(self, *args): warnings.warn("Parameter %s is not used by any computation. " "Is this intended?"%unused, stacklevel=4) - data_indices = [] - param_indices = [] - self._cached_op_args = [] - for i, name in enumerate(input_names): - if name in data_names: - data_indices.append(i) - self._cached_op_args.append((True, data_names[name])) - else: - param_indices.append(i) - self._cached_op_args.append((False, params[name])) - flags = [('data_indices', data_indices), ('param_indices', param_indices)] + \ - self._flags - args, _ = _flatten(args, "input") try: - for is_arg, i in self._cached_op_args: - if not is_arg: - i.data() + for name in input_names: + if name in params: + params[name].data() except DeferredInitializationError: self._deferred_infer_shape(*args) - for is_arg, i in self._cached_op_args: - if not is_arg: - i._finish_deferred_init() + for name in input_names: + if name in params: + params[name]._finish_deferred_init() + arg_dict, aux_dict = dict(), dict() if self._backend: ctx = args[0].context # get list of params in the order of out.list_arguments - arg_array = [args[data_names[name]] if name in data_names.keys() else params[name].data() - for name in out.list_arguments()] - aux_array = [args[data_names[name]] if name in data_names.keys() else params[name].data() - for name in out.list_auxiliary_states()] + arg_dict.update({name:args[data_names[name]] if name in data_names.keys() else params[name].data() + for name in out.list_arguments()}) + aux_dict.update({name:args[data_names[name]] if name in data_names.keys() else params[name].data() + for name in out.list_auxiliary_states()}) # Partition the graph. - out = out.optimize_for(self._backend, arg_array, aux_array, ctx, **self._backend_opts) + out = out.optimize_for(self._backend, arg_dict, aux_dict, ctx, **self._backend_opts) + #update cached graph with partitioned graph self._cached_graph = data, out + + input_names = out.list_inputs() + data_indices = [] + param_indices = [] + + # In the default case, _cached_ops_args contains all the parameters from params (the sets are identical) + # In the case of Partition API optimized graph _cached_ops_args might contain some parameters from params, + # might contain some new parameters created during optimization and added to `arg_dict/aux_dict`, + # and might not contain some parameters that were deleted during optimization. + self._cached_op_args = [] + for i, name in enumerate(input_names): + pair = None + if name in data_names: + data_indices.append(i) + pair = (True, data_names[name]) + else: + param_indices.append(i) + if name in params: + param = params[name] + else: + # The param is missing from the original params dictionary, which means the param must have + # been added by the Partition API backend + if name in arg_dict or name: + param_data = arg_dict[name] + elif name in aux_dict: + param_data = aux_dict[name] + else: + raise RuntimeError('A parameter was added to the graph during optimization but it was not ' + 'added to the parameter dicts.\n' + 'Please check the backend.') + + param = Parameter(name) + param._load_init(param_data, args[0].context) + pair = (False, param) + + self._cached_op_args.append(pair) + + flags = [('data_indices', data_indices), ('param_indices', param_indices)] + \ + self._flags + self._cached_op = ndarray.CachedOp(out, flags) @@ -1203,12 +1232,14 @@ def export(self, path, epoch=0, remove_amp_cast=True): arg_names = set(sym.list_arguments()) aux_names = set(sym.list_auxiliary_states()) arg_dict = {} - for name, param in self.collect_params().items(): - if name in arg_names: - arg_dict['arg:%s'%name] = param._reduce() - else: - assert name in aux_names - arg_dict['aux:%s'%name] = param._reduce() + for is_arg, param in self._cached_op_args: + if not is_arg: + name = param.name + if name in arg_names: + arg_dict['arg:{}'.format(name)] = param._reduce() + else: + assert name in aux_names + arg_dict['aux:{}'.format(name)] = param._reduce() save_fn = _mx_npx.save if is_np_array() else ndarray.save save_fn('%s-%04d.params'%(path, epoch), arg_dict) @@ -1479,6 +1510,23 @@ def cast(self, dtype): def hybrid_forward(self, F, x, *args, **kwargs): raise NotImplementedError + def reset_ctx(self, ctx): + """Re-assign all Parameters to other contexts. If the Block is hybridized, it will reset the _cached_op_args. + Parameters + ---------- + ctx : Context or list of Context, default :py:meth:`context.current_context()`. + Assign Parameter to given context. If ctx is a list of Context, a + copy will be made for each context. + """ + params = self.collect_params() + if self._cached_op: + for p in self._cached_op_args: + # resetting parameters creating by the partitioning backend + if p.name not in params: + p.reset_ctx(ctx) + for p in params.values(): + p.reset_ctx(ctx) + def _infer_param_types(in_params, out_params, arg_params, aux_params, default_dtype=mx_real_t): """Utility function that helps in inferring DType of args and auxs params from given input param. diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py index e90fb9b31e36..0f8cccd071cf 100644 --- a/python/mxnet/symbol/symbol.py +++ b/python/mxnet/symbol/symbol.py @@ -1446,7 +1446,8 @@ def _gen_atomic_symbol(self): # pylint: disable=too-many-locals - def optimize_for(self, backend, args=None, aux=None, ctx=None, **kwargs): + def optimize_for(self, backend, args=None, aux=None, ctx=None, + shape_dict=None, type_dict=None, stype_dict=None, skip_infer=False, **kwargs): """Partitions current symbol and optimizes it for a given backend, returns new partitioned symbol. @@ -1455,23 +1456,35 @@ def optimize_for(self, backend, args=None, aux=None, ctx=None, **kwargs): backend : str The name of backend, as registered in `SubgraphBackendRegistry` - args : list of NDArray or dict of str to NDArray, optional + args : dict of str to NDArray, optional Input arguments to the symbol, required to infer shapes/types before partitioning - - - If type is a list of `NDArray`, the order is the same as that of `list_arguments()`. - If type is a dict of str to `NDArray`, then it maps the name of arguments - to the corresponding `NDArray`. + to the corresponding `NDArray`. Non defined arguments' `NDArray`s don't have to be + specified in the dict. - aux : list of NDArray or dict of str to NDArray, optional + aux : dict of str to NDArray, optional Input auxiliary arguments to the symbol - - - If type is a list of `NDArray`, the order is the same as that of `list_arguments()`. - If type is a dict of str to `NDArray`, then it maps the name of arguments to the corresponding `NDArray`. ctx : Context, optional Device context, used to infer stypes + shape_dict : Dict of str->tuple, optional + Input shape dictionary. + Used iff input NDArray is not in `args`. + + type_dict : Dict of str->numpy.dtype, optional + Input type dictionary. + Used iff input NDArray is not in `args`. + + stype_dict : Dict of str->str, optional + Input storage type dictionary. + Used iff input NDArray is not in `args`. + + skip_infer : bool, optional + If True, the optimization skips the shape, type and storage type inference pass. + kwargs : optional arguments Passed on to `PrePartition` and `PostPartition` functions of `SubgraphProperty` @@ -1482,24 +1495,86 @@ def optimize_for(self, backend, args=None, aux=None, ctx=None, **kwargs): """ out = SymbolHandle() assert isinstance(backend, str) + assert isinstance(args, dict) or args is None + assert isinstance(aux, dict) or aux is None if args is None or len(args) == 0: args_ = [] args_handle = c_array(NDArrayHandle, []) else: args_handle, args_ = self._get_ndarray_inputs('args', args, - self.list_arguments(), False) + self.list_arguments(), True) if aux is None or len(aux) == 0: aux_ = [] aux_handle = c_array(NDArrayHandle, []) else: aux_handle, aux_ = self._get_ndarray_inputs('aux_states', aux, - self.list_auxiliary_states(), False) + self.list_auxiliary_states(), True) if ctx is None: ctx = current_context() assert isinstance(ctx, Context) + + # parse input data shape dict + num_input_shapes = 0 + input_shape_names = ctypes.POINTER(ctypes.c_char_p)() + input_shape_data = ctypes.POINTER(mx_int64)() + input_shape_idx = ctypes.POINTER(mx_uint)() + if shape_dict is not None: + input_shape_names = [] + input_shape_data = [] + input_shape_idx = [0] + for k, v in shape_dict.items(): + if isinstance(v, (tuple, list)): + input_shape_names.append(k) + input_shape_data.extend(v) + input_shape_idx.append(len(input_shape_data)) + else: + raise ValueError(str(v) + " has to be a tuple or list.") + num_input_shapes = mx_uint(len(input_shape_names)) + input_shape_names = c_str_array(input_shape_names) + input_shape_data = c_array_buf(mx_int64, array('q', input_shape_data)) + input_shape_idx = c_array_buf(mx_uint, array('i', input_shape_idx)) + + # parse input data types dict + num_input_types = 0 + input_type_names = ctypes.POINTER(ctypes.c_char_p)() # provided type argument names + input_type_data = ctypes.POINTER(mx_uint)() # provided types + if type_dict is not None: + input_type_names = [] + input_type_data = [] + for k, v in type_dict.items(): + v = _numpy.dtype(v).type + if v in _DTYPE_NP_TO_MX: + input_type_names.append(k) + input_type_data.append(_DTYPE_NP_TO_MX[v]) + else: + raise ValueError(str(v) + " is not a MXNet type.") + + num_input_types = mx_uint(len(input_type_names)) + input_type_names = c_str_array(input_type_names) + input_type_data = c_array_buf(ctypes.c_int, array('i', input_type_data)) + + # parse input data storage types dict + num_input_stypes = 0 + # provided storage type argument names + input_stype_names = ctypes.POINTER(ctypes.c_char_p)() + input_stype_data = ctypes.POINTER(mx_uint)() # provided storage types + if stype_dict is not None: + input_stype_names = [] + input_stype_data = [] + for k, v in stype_dict.items(): + if v in _STORAGE_TYPE_STR_TO_ID: + input_stype_names.append(k) + input_stype_data.append(_STORAGE_TYPE_STR_TO_ID[v]) + else: + raise ValueError(str(v) + " is not a MXNet storage type.") + + num_input_stypes = mx_uint(len(input_stype_names)) + input_stype_names = c_str_array(input_stype_names) + input_stype_data = c_array_buf(ctypes.c_int, array('i', input_stype_data)) + new_args_size = ctypes.c_uint() new_arg_names = ctypes.POINTER(ctypes.c_char_p)() new_args_handle = ctypes.POINTER(NDArrayHandle)() @@ -1523,37 +1598,68 @@ def optimize_for(self, backend, args=None, aux=None, ctx=None, **kwargs): mx_uint(len(key_list)), c_str_array(key_list), c_str_array(val_list), + num_input_shapes, + input_shape_names, + input_shape_data, + input_shape_idx, + num_input_types, + input_type_names, + input_type_data, + num_input_stypes, + input_stype_names, + input_stype_data, + ctypes.c_bool(skip_infer), ctypes.byref(new_args_size), ctypes.byref(new_args_handle), ctypes.byref(new_arg_names), ctypes.byref(new_aux_size), ctypes.byref(new_aux_handle), ctypes.byref(new_aux_names))) - arg_names = self.list_arguments() - if isinstance(args, dict): + # add new args/aux + if not args is None: for i in range(new_args_size.value): args[py_str(new_arg_names[i])] = NDArray(NDArrayHandle(new_args_handle[i])) - elif isinstance(args, list): - for i in range(new_args_size.value): - name = py_str(new_arg_names[i]) - if name in arg_names: - idx = arg_names.index(name) - args[idx] = NDArray(NDArrayHandle(new_args_handle[i])) - else: - args.append(NDArray(NDArrayHandle(new_args_handle[i]))) - aux_names = self.list_auxiliary_states() - if isinstance(aux, dict): + elif new_args_size.value > 0: + raise RuntimeError('Cannot add new args in optimize_for since args is None\n' + + 'Provide a dictionary to the args argument to optimize_for') + + if not aux is None: for i in range(new_aux_size.value): aux[py_str(new_aux_names[i])] = NDArray(NDArrayHandle(new_aux_handle[i])) - elif isinstance(aux, list): - for i in range(new_aux_size.value): - name = py_str(new_aux_names[i]) - if name in aux_names: - idx = aux_names.index(name) - aux[idx] = NDArray(NDArrayHandle(new_aux_handle[i])) - else: - aux.append(NDArray(NDArrayHandle(new_aux_handle[i]))) - return Symbol(out) + elif new_aux_size.value > 0: + raise RuntimeError('Cannot add new aux in optimize_for since aux is None\n' + + 'Provide a dictionary to the aux argument to optimize_for') + + new_sym = Symbol(out) + + arg_names = self.list_arguments() + new_arg_names = new_sym.list_arguments() + deleted_arg_names = set([item for item in arg_names + if item not in set(new_arg_names)]) + + if len(deleted_arg_names) > 0: + if args is not None: + for a_n in deleted_arg_names: + if a_n in args: + args.pop(a_n) + else: + warnings.warn('A param was deleted during optimization, but no args dictionary was provided.\n' + + 'Please ensure that your model weights match the newly optimized model.') + + aux_names = self.list_auxiliary_states() + new_aux_names = new_sym.list_auxiliary_states() + deleted_aux_names = set([item for item in aux_names + if item not in set(new_aux_names)]) + if len(deleted_aux_names) > 0: + if aux is not None: + for a_n in deleted_aux_names: + if a_n in aux: + aux.pop(a_n) + else: + warnings.warn('A param was deleted during optimization, but no args dictionary was provided.\n' + + 'Please ensure that your model weights match the newly optimized model.') + + return new_sym # pylint: disable=too-many-locals diff --git a/src/c_api/c_api_symbolic.cc b/src/c_api/c_api_symbolic.cc index 3b3d83cbd6c2..29a773da0ad7 100644 --- a/src/c_api/c_api_symbolic.cc +++ b/src/c_api/c_api_symbolic.cc @@ -1350,6 +1350,17 @@ int MXOptimizeForBackend(SymbolHandle sym_handle, const mx_uint num_options, const char** keys, const char** vals, + const uint32_t num_input_shapes, + const char** input_shape_names, + const int64_t* input_shape_data, + const uint32_t* input_shape_idx, + const uint32_t num_input_dtypes, + const char** input_dtype_names, + const int* input_dtypes, + const uint32_t num_input_stypes, + const char** input_stype_names, + const int* input_stypes, + bool skip_infer, int* new_args_cnt, NDArrayHandle** new_args_handle, char*** new_arg_names_handle, @@ -1373,47 +1384,80 @@ int MXOptimizeForBackend(SymbolHandle sym_handle, if (args_len || aux_len) { NDArray **in_args_ptr = reinterpret_cast(in_args_handle); NDArray **in_aux_ptr = reinterpret_cast(in_aux_handle); - Context default_ctx = Context::Create(static_cast(dev_type), 0); - mxnet::ShapeVector arg_shapes(args_len + aux_len); - nnvm::DTypeVector arg_dtypes(args_len + aux_len); - StorageTypeVector arg_stypes(args_len + aux_len); - size_t args_top = 0, aux_top = 0; - // loop over inputs to symbol in order and add to args/aux if mutable - for (size_t i = 0; i < num_forward_inputs; ++i) { - const uint32_t nid = indexed_graph.input_nodes().at(i); - if (mutable_nodes.count(nid)) { - CHECK_LT(aux_top, aux_len) - << "Cannot find aux '" << input_names[i] << "' in provided aux to optimize_for"; - const auto &in_arg = *(in_aux_ptr[aux_top++]); - arg_shapes[i] = in_arg.shape(); - arg_dtypes[i] = in_arg.dtype(); - arg_stypes[i] = in_arg.storage_type(); - } else { - CHECK_LT(args_top, args_len) - << "Cannot find arg '" << input_names[i] << "' in provided args to optimize_for"; - const auto &in_arg = *(in_args_ptr[args_top++]); - arg_shapes[i] = in_arg.shape(); - arg_dtypes[i] = in_arg.dtype(); - arg_stypes[i] = in_arg.storage_type(); + if (!skip_infer) { + Context default_ctx = Context::Create(static_cast(dev_type), 0); + mxnet::ShapeVector arg_shapes(args_len + aux_len); + nnvm::DTypeVector arg_dtypes(args_len + aux_len); + StorageTypeVector arg_stypes(args_len + aux_len); + + // create the input shape, dtype and stype maps + std::unordered_map input_shape_map(num_input_shapes); + for (uint32_t i = 0; i < num_input_shapes; ++i) { + input_shape_map.emplace(input_shape_names[i], + mxnet::TShape(input_shape_data + input_shape_idx[i], + input_shape_data + input_shape_idx[i+1])); + } + std::unordered_map input_dtype_map(num_input_dtypes); + for (uint32_t i = 0; i < num_input_dtypes; ++i) { + input_dtype_map.emplace(input_dtype_names[i], input_dtypes[i]); + } + std::unordered_map input_stype_map(num_input_stypes); + for (uint32_t i = 0; i < num_input_stypes; ++i) { + input_stype_map.emplace(input_stype_names[i], input_stypes[i]); } - } - g.attrs["context"] = std::make_shared( - exec::ContextVector(indexed_graph.num_nodes(), default_ctx)); + size_t args_top = 0, aux_top = 0; + // loop over inputs to symbol in order and add to args/aux if mutable + for (size_t i = 0; i < num_forward_inputs; ++i) { + const uint32_t nid = indexed_graph.input_nodes().at(i); + if (mutable_nodes.count(nid)) { + CHECK_LT(aux_top, aux_len) + << "Cannot find aux '" << input_names[i] << "' in provided aux to optimize_for"; + if (in_aux_ptr[aux_top] != nullptr) { + const auto &in_arg = *(in_aux_ptr[aux_top]); + arg_shapes[i] = in_arg.shape(); + arg_dtypes[i] = in_arg.dtype(); + arg_stypes[i] = in_arg.storage_type(); + } + aux_top++; + } else { + auto name = input_names[i]; + CHECK_LT(args_top, args_len) + << "Cannot find arg '" << name << "' in provided args to optimize_for"; + if (in_args_ptr[args_top] != nullptr) { + const auto &in_arg = *(in_args_ptr[args_top]); + arg_shapes[i] = in_arg.shape(); + arg_dtypes[i] = in_arg.dtype(); + arg_stypes[i] = in_arg.storage_type(); + } else { + // input_names[i] is not in args but can be in the optional + // shape/type/stype attribute dicts. + auto it_shape = input_shape_map.find(name); + if (it_shape != input_shape_map.end()) { + arg_shapes[i] = it_shape->second; + } + auto it_type = input_dtype_map.find(name); + if (it_type != input_dtype_map.end()) { + arg_dtypes[i] = it_type->second; + } + it_type = input_stype_map.find(name); + if (it_type != input_stype_map.end()) { + arg_stypes[i] = it_type->second; + } + } + args_top++; + } + } - // infer shapes - g = exec::InferShape(std::move(g), std::move(arg_shapes), "__shape__"); - // infer dtypes - g = exec::InferType(std::move(g), std::move(arg_dtypes), "__dtype__"); - if (g.GetAttr("dtype_num_unknown_nodes") != 0U) { - common::HandleInferTypeError(num_forward_inputs, indexed_graph, - g.GetAttr("dtype")); - } - // infer stypes - g = exec::InferStorageType(std::move(g), std::move(arg_stypes), "__storage_type__"); - if (g.GetAttr("storage_type_num_unknown_nodes") != 0U) { - common::HandleInferStorageTypeError(num_forward_inputs, indexed_graph, - g.GetAttr("storage_type")); + g.attrs["context"] = std::make_shared( + exec::ContextVector(indexed_graph.num_nodes(), default_ctx)); + + // infer shapes + g = exec::InferShape(std::move(g), std::move(arg_shapes), "__shape__"); + // infer dtypes + g = exec::InferType(std::move(g), std::move(arg_dtypes), "__dtype__"); + // infer stypes + g = exec::InferStorageType(std::move(g), std::move(arg_stypes), "__storage_type__"); } // set args/aux as attributes on graph so that subgraph property can use them std::vector arg_names = sym->ListInputNames(nnvm::Symbol::kReadOnlyArgs); diff --git a/src/operator/subgraph/build_subgraph.cc b/src/operator/subgraph/build_subgraph.cc index 2d5501d26f86..7cf9671b92b4 100644 --- a/src/operator/subgraph/build_subgraph.cc +++ b/src/operator/subgraph/build_subgraph.cc @@ -226,9 +226,7 @@ bool LabelSubgraph(const nnvm::Graph& g, SubgraphSelectorV2Ptr subgraph_selector std::stack s; s.push(descendant); size_t count = 0; - while (!s.empty()) { - CHECK_LT(count, indexed_graph.num_nodes()) << "Finding ancestor failed. There is probably" - " a loop in the graph"; + while (!s.empty() && count < indexed_graph.num_nodes()) { ++count; const nnvm::Node* top = s.top(); s.pop(); @@ -276,10 +274,6 @@ bool LabelSubgraph(const nnvm::Graph& g, SubgraphSelectorV2Ptr subgraph_selector if (excluded_node_id != -1) { CHECK_LT(excluded_node_id, static_cast(simple_nodes.size())); - CHECK_NE(excluded_node_id, static_cast(snid)) - << "A cycle is found in the computational graph between nodes " - << simple_nodes[excluded_node_id]->node->attrs.name << " and " - << simple_nodes[snid]->node->attrs.name; excluded_nodes->insert(simple_nodes[excluded_node_id].get()); ResetNodeLabels(g, simple_nodes, subgraph_nodes); return false; @@ -306,6 +300,7 @@ void PreSelectSubgraphNodes(const nnvm::Graph& g, SubgraphSelectorV2Ptr subgraph const std::vector& simple_nodes, std::vector* subgraph_nodes) { std::unordered_set excluded_nodes; + size_t n_excluded_nodes = 0; const size_t max_num_retry = simple_nodes.size() * simple_nodes.size(); size_t count = 0; bool success = false; @@ -313,7 +308,14 @@ void PreSelectSubgraphNodes(const nnvm::Graph& g, SubgraphSelectorV2Ptr subgraph success = LabelSubgraph(g, subgraph_selector, label, snid, simple_nodes, subgraph_nodes, &excluded_nodes); if (!success) { - CHECK(!excluded_nodes.empty()); + // Failed to label subgraph due to a cycle + // If the number of excluded_nodes didn't change since the last iteration, + // this means that there is no possible subgraph for the current node snid, we break + // Otherwise, we keep trying (with the excluded nodes tagged) + if (excluded_nodes.size() == n_excluded_nodes) { + break; + } + n_excluded_nodes = excluded_nodes.size(); std::string excluded_node_names; for (auto node : excluded_nodes) { excluded_node_names += node->node->attrs.name + ", "; @@ -428,7 +430,7 @@ void SortEntries(const std::unordered_map& entry } /*! - * \brief Given a subgraph, find the output entries of a subgraph. + * \brief Given a subgraph, find the input entries of a subgraph. * \param g pointer to the whole graph * \param simple_nods vector of simple nodes in top sorted order * \param subgraph_nodes vector of pointers of simples of a subgraph. diff --git a/src/operator/subgraph/tensorrt/nnvm_to_onnx.cc b/src/operator/subgraph/tensorrt/nnvm_to_onnx.cc index 19d0f26eae66..4f80d277cad8 100644 --- a/src/operator/subgraph/tensorrt/nnvm_to_onnx.cc +++ b/src/operator/subgraph/tensorrt/nnvm_to_onnx.cc @@ -31,7 +31,6 @@ #include #include #include -#include #include "../../../common/utils.h" #include "../../../ndarray/ndarray_function.h" @@ -39,6 +38,7 @@ #include "../../nn/activation-inl.h" #include "../../nn/batch_norm-inl.h" #include "../../nn/convolution-inl.h" +#include "../../nn/deconvolution-inl.h" #include "../../nn/fully_connected-inl.h" #include "../../nn/pooling-inl.h" #include "../../nn/concat-inl.h" diff --git a/src/operator/subgraph/tensorrt/onnx_to_tensorrt.cc b/src/operator/subgraph/tensorrt/onnx_to_tensorrt.cc index b02d1094183f..4f5bdcb8561c 100644 --- a/src/operator/subgraph/tensorrt/onnx_to_tensorrt.cc +++ b/src/operator/subgraph/tensorrt/onnx_to_tensorrt.cc @@ -35,13 +35,9 @@ #include #include #include -#include #include #include -#include -#include - using std::cout; using std::cerr; using std::endl; @@ -78,7 +74,9 @@ std::tuple, auto trt_logger = std::unique_ptr(new TRT_Logger(verbosity)); auto trt_builder = InferObject(nvinfer1::createInferBuilder(*trt_logger)); - auto trt_network = InferObject(trt_builder->createNetwork()); + const auto explicitBatch = 1U << static_cast( + nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH); + auto trt_network = InferObject(trt_builder->createNetworkV2(explicitBatch)); auto trt_parser = InferObject(nvonnxparser::createParser(*trt_network, *trt_logger)); ::ONNX_NAMESPACE::ModelProto parsed_model; // We check for a valid parse, but the main effect is the side effect diff --git a/src/operator/subgraph/tensorrt/tensorrt-inl.h b/src/operator/subgraph/tensorrt/tensorrt-inl.h index dcafba55959d..369d7c3d397b 100644 --- a/src/operator/subgraph/tensorrt/tensorrt-inl.h +++ b/src/operator/subgraph/tensorrt/tensorrt-inl.h @@ -268,6 +268,23 @@ class TensorrtProperty : public SubgraphProperty { return std::make_shared(); } + void PrePartition(const nnvm::Graph& g, + const std::vector>& options_map) override { + auto& in_arg_names = g.GetAttr>("in_arg_names"); + auto& in_aux_names = g.GetAttr>("in_aux_names"); + NDArray **in_args_ptr = g.GetAttr("in_args"); + NDArray **in_aux_ptr = g.GetAttr("in_aux"); + in_args_dict.clear(); + in_aux_dict.clear(); + // we trust the Python API, len(in_arg_names) == len(in_args_ptr) + for (unsigned i = 0; i < in_arg_names.size(); ++i) { + in_args_dict[in_arg_names[i]] = in_args_ptr[i]; + } + for (unsigned i = 0; i < in_aux_names.size(); ++i) { + in_aux_dict[in_aux_names[i]] = in_aux_ptr[i]; + } + } + nnvm::ObjectPtr CreateSubgraphNode(const nnvm::Symbol &sym, const int subgraph_id) const override { nnvm::ObjectPtr n = nnvm::Node::Create(); @@ -281,16 +298,33 @@ class TensorrtProperty : public SubgraphProperty { n->attrs.op = Op::Get("_TensorRT"); CHECK(n->attrs.op); n->attrs.subgraphs.emplace_back(std::make_shared(new_sym)); + + // Mapping subgraph params with NDArrays + TRTParam param; std::ostringstream params_oss; - for (auto &e : new_sym.ListInputNames(nnvm::Symbol::kAll)) { - params_oss << e << ";"; + for (auto ¶m_name : new_sym.ListInputNames(nnvm::Symbol::kAll)) { + NDArray *cache = nullptr; + auto it_args = in_args_dict.find(param_name); + if (it_args != in_args_dict.end()) { + cache = it_args->second; + } else { + auto it_aux = in_aux_dict.find(param_name); + if (it_aux != in_aux_dict.end()) { + cache = it_aux->second; + } + } + if (cache != nullptr) { + param.params_map.emplace(param_name, cache->Copy(Context())); + param.params_map[param_name].WaitToRead(); + params_oss << param_name << ";"; + } } auto tensorrt_params_names = params_oss.str(); - tensorrt_params_names.pop_back(); - n->attrs.dict["subgraph_params_names"] = tensorrt_params_names; - TRTParam param; + if (!tensorrt_params_names.empty()) { + tensorrt_params_names.pop_back(); + } n->attrs.parsed = param; - n->op()->attr_parser(&(n->attrs)); + n->attrs.dict["subgraph_params_names"] = tensorrt_params_names; return n; } @@ -329,6 +363,8 @@ class TensorrtProperty : public SubgraphProperty { } subgraph_node->attrs.parsed = std::move(_params); } + + std::unordered_map in_args_dict, in_aux_dict; }; diff --git a/src/operator/subgraph/tensorrt/tensorrt.cu b/src/operator/subgraph/tensorrt/tensorrt.cu index 4a5b23b3a9f7..826f9a5876b6 100644 --- a/src/operator/subgraph/tensorrt/tensorrt.cu +++ b/src/operator/subgraph/tensorrt/tensorrt.cu @@ -56,12 +56,12 @@ void TRTCompute(const OpStatePtr& state, const OpContext& ctx, param.bindings->at(i) = outputs[p.first].dptr_; } } - const int batch_size = static_cast(inputs[0].shape_[0]); - param.trt_executor->enqueue(batch_size, param.bindings->data(), cuda_s, nullptr); + param.trt_executor->enqueueV2(param.bindings->data(), cuda_s, nullptr); } NNVM_REGISTER_OP(_TensorRT) -.set_attr("FStatefulCompute", TRTCompute); +.set_attr("FStatefulCompute", TRTCompute) +.set_attr("FGradient", MakeZeroGradNodes); } // namespace op } // namespace mxnet diff --git a/tests/python/tensorrt/lenet5_train.py b/tests/python/tensorrt/lenet5_train.py deleted file mode 100755 index a0ea447de5a0..000000000000 --- a/tests/python/tensorrt/lenet5_train.py +++ /dev/null @@ -1,99 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import os -import mxnet as mx -import numpy as np - -def get_iters(mnist, batch_size): - """Get MNIST iterators.""" - train_iter = mx.io.NDArrayIter(mnist['train_data'], - mnist['train_label'], - batch_size, - shuffle=True) - val_iter = mx.io.NDArrayIter(mnist['test_data'], mnist['test_label'], batch_size) - test_iter = mx.io.NDArrayIter(mnist['test_data'], mnist['test_label'], batch_size) - all_test_labels = np.array(mnist['test_label']) - return train_iter, val_iter, test_iter, all_test_labels - -def lenet5(): - """LeNet-5 Symbol""" - #pylint: disable=no-member - data = mx.sym.Variable('data') - data = mx.sym.Cast(data, 'float16') - conv1 = mx.sym.Convolution(data=data, kernel=(5, 5), num_filter=20) - tanh1 = mx.sym.Activation(data=conv1, act_type="tanh") - pool1 = mx.sym.Pooling(data=tanh1, pool_type="max", - kernel=(2, 2), stride=(2, 2)) - # second conv - conv2 = mx.sym.Convolution(data=pool1, kernel=(5, 5), num_filter=50) - tanh2 = mx.sym.Activation(data=conv2, act_type="tanh") - pool2 = mx.sym.Pooling(data=tanh2, pool_type="max", - kernel=(2, 2), stride=(2, 2)) - # first fullc - flatten = mx.sym.Flatten(data=pool2) - fc1 = mx.sym.FullyConnected(data=flatten, num_hidden=500) - tanh3 = mx.sym.Activation(data=fc1, act_type="tanh") - # second fullc - fc2 = mx.sym.FullyConnected(data=tanh3, num_hidden=10) - fc2 = mx.sym.Cast(fc2, 'float32') - # loss - lenet = mx.sym.SoftmaxOutput(data=fc2, name='softmax') - #pylint: enable=no-member - return lenet - - -def train_lenet5(num_epochs, batch_size, train_iter, val_iter, test_iter): - """train LeNet-5 model on MNIST data""" - ctx = mx.gpu(0) - lenet_model = mx.mod.Module(lenet5(), context=ctx) - - lenet_model.fit(train_iter, - eval_data=val_iter, - optimizer='sgd', - optimizer_params={'learning_rate': 0.1, 'momentum': 0.9}, - eval_metric='acc', - batch_end_callback=mx.callback.Speedometer(batch_size, 1), - num_epoch=num_epochs) - - # predict accuracy for lenet - acc = mx.metric.Accuracy() - lenet_model.score(test_iter, acc) - accuracy = acc.get()[1] - assert accuracy > 0.95, "LeNet-5 training accuracy on MNIST was too low" - return lenet_model - - -if __name__ == '__main__': - num_epochs = 10 - batch_size = 128 - model_name = 'lenet5' - model_dir = os.getenv("LENET_MODEL_DIR", "/tmp") - model_file = '%s/%s-symbol.json' % (model_dir, model_name) - params_file = '%s/%s-%04d.params' % (model_dir, model_name, num_epochs) - - if not (os.path.exists(model_file) and os.path.exists(params_file)): - mnist = mx.test_utils.get_mnist() - - _, _, _, all_test_labels = get_iters(mnist, batch_size) - - trained_lenet = train_lenet5(num_epochs, batch_size, - *get_iters(mnist, batch_size)[:-1]) - trained_lenet.save_checkpoint(model_name, num_epochs) diff --git a/tests/python/tensorrt/test_cvnets.py b/tests/python/tensorrt/test_cvnets.py deleted file mode 100644 index 4b8eb48e926c..000000000000 --- a/tests/python/tensorrt/test_cvnets.py +++ /dev/null @@ -1,174 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import gc -import gluoncv -import mxnet as mx -import numpy as np - -from mxnet import gluon -from time import time - -from mxnet.gluon.data.vision import transforms - - -def get_classif_model(model_name, use_tensorrt, ctx=mx.gpu(0), batch_size=128): - mx.contrib.tensorrt.set_use_fp16(False) - h, w = 32, 32 - net = gluoncv.model_zoo.get_model(model_name, pretrained=True) - net.hybridize() - net.forward(mx.nd.zeros((batch_size, 3, h, w))) - net.export(model_name) - _sym, arg_params, aux_params = mx.model.load_checkpoint(model_name, 0) - if use_tensorrt: - sym = _sym.get_backend_symbol('TensorRT') - arg_params, aux_params = mx.contrib.tensorrt.init_tensorrt_params(sym, arg_params, - aux_params) - else: - sym = _sym - executor = sym.simple_bind(ctx=ctx, data=(batch_size, 3, h, w), - softmax_label=(batch_size,), - grad_req='null', force_rebind=True) - executor.copy_params_from(arg_params, aux_params) - return executor - - -def cifar10_infer(model_name, use_tensorrt, num_workers, ctx=mx.gpu(0), batch_size=128): - executor = get_classif_model(model_name, use_tensorrt, ctx, batch_size) - - num_ex = 10000 - all_preds = np.zeros([num_ex, 10]) - - all_label_test = np.zeros(num_ex) - - transform_test = transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]) - ]) - - data_loader = lambda: gluon.data.DataLoader( - gluon.data.vision.CIFAR10(train=False).transform_first(transform_test), - batch_size=batch_size, shuffle=False, num_workers=num_workers) - - val_data = data_loader() - - for idx, (data, label) in enumerate(val_data): - # Skip last batch if it's undersized. - if data.shape[0] < batch_size: - continue - offset = idx * batch_size - all_label_test[offset:offset + batch_size] = label.asnumpy() - - # warm-up, but don't use result - executor.forward(is_train=False, data=data) - executor.outputs[0].wait_to_read() - - gc.collect() - val_data = data_loader() - example_ct = 0 - start = time() - - # if use_tensorrt: - for idx, (data, label) in enumerate(val_data): - # Skip last batch if it's undersized. - if data.shape[0] < batch_size: - continue - executor.forward(is_train=False, data=data) - preds = executor.outputs[0].asnumpy() - offset = idx * batch_size - all_preds[offset:offset + batch_size, :] = preds[:batch_size] - example_ct += batch_size - - all_preds = np.argmax(all_preds, axis=1) - matches = (all_preds[:example_ct] == all_label_test[:example_ct]).sum() - duration = time() - start - - return duration, 100.0 * matches / example_ct - - -def run_experiment_for(model_name, batch_size, num_workers): - print("\n===========================================") - print("Model: %s" % model_name) - print("===========================================") - print("*** Running inference using pure MXNet ***\n") - mx_duration, mx_pct = cifar10_infer(model_name=model_name, batch_size=batch_size, - num_workers=num_workers, use_tensorrt=False) - print("\nMXNet: time elapsed: %.3fs, accuracy: %.2f%%" % (mx_duration, mx_pct)) - print("\n*** Running inference using MXNet + TensorRT ***\n") - trt_duration, trt_pct = cifar10_infer(model_name=model_name, batch_size=batch_size, - num_workers=num_workers, use_tensorrt=True) - print("TensorRT: time elapsed: %.3fs, accuracy: %.2f%%" % (trt_duration, trt_pct)) - speedup = mx_duration / trt_duration - print("TensorRT speed-up (not counting compilation): %.2fx" % speedup) - - acc_diff = abs(mx_pct - trt_pct) - print("Absolute accuracy difference: %f" % acc_diff) - return speedup, acc_diff - - -def test_tensorrt_on_cifar_resnets(batch_size=32, tolerance=0.1, num_workers=1): - original_use_fp16 = mx.contrib.tensorrt.get_use_fp16() - try: - models = [ - 'cifar_resnet20_v1', - 'cifar_resnet56_v1', - 'cifar_resnet110_v1', - 'cifar_resnet20_v2', - 'cifar_resnet56_v2', - 'cifar_resnet110_v2', - 'cifar_wideresnet16_10', - 'cifar_wideresnet28_10', - 'cifar_wideresnet40_8', - 'cifar_resnext29_16x64d' - ] - - num_models = len(models) - - speedups = np.zeros(num_models, dtype=np.float32) - acc_diffs = np.zeros(num_models, dtype=np.float32) - - test_start = time() - - for idx, model in enumerate(models): - speedup, acc_diff = run_experiment_for(model, batch_size, num_workers) - speedups[idx] = speedup - acc_diffs[idx] = acc_diff - assert acc_diff < tolerance, "Accuracy difference between MXNet and TensorRT > %.2f%% for model %s" % ( - tolerance, model) - - print("Perf and correctness checks run on the following models:") - print(models) - mean_speedup = np.mean(speedups) - std_speedup = np.std(speedups) - print("\nSpeedups:") - print(speedups) - print("Speedup range: [%.2f, %.2f]" % (np.min(speedups), np.max(speedups))) - print("Mean speedup: %.2f" % mean_speedup) - print("St. dev. of speedups: %.2f" % std_speedup) - print("\nAcc. differences: %s" % str(acc_diffs)) - - test_duration = time() - test_start - - print("Test duration: %.2f seconds" % test_duration) - finally: - mx.contrib.tensorrt.set_use_fp16(original_use_fp16) - - -if __name__ == '__main__': - import nose - - nose.runmodule() diff --git a/tests/python/tensorrt/test_ops.py b/tests/python/tensorrt/test_ops.py deleted file mode 100644 index af1c453111d9..000000000000 --- a/tests/python/tensorrt/test_ops.py +++ /dev/null @@ -1,517 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import mxnet as mx -import numpy as np -from itertools import product -import copy - -from numpy.testing import assert_allclose - -import sys -import os -curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) -sys.path.insert(0, os.path.join(curr_path, '../unittest')) -from common import setup_module, with_seed - -def check_unsupported_single_sym(sym): - wrapped_sym = mx.sym.Group([mx.sym.identity(s) for s in sym]) - trt_sym = wrapped_sym.get_backend_symbol('TensorRT') - assert len(wrapped_sym.get_internals()) == len(trt_sym.get_internals()) - -def check_single_sym(sym, data_shapes, arg_params_shapes=None, aux_params_shapes=None, - rtol_fp32=1e-5, atol_fp32=0., rtol_fp16=1e-3, atol_fp16=0.): - if arg_params_shapes is None: - arg_params_shapes = {} - if aux_params_shapes is None: - aux_params_shapes = {} - for i in range(3): - data = {k: mx.nd.array(np.random.rand(*v) + 0.01, dtype='float32', ctx=mx.cpu()) - for k, v in data_shapes.items()} - arg_params = {k: mx.nd.array(np.random.rand(*v) + 0.01, dtype='float32', ctx=mx.cpu()) - for k, v in arg_params_shapes.items()} - aux_params = {k: mx.nd.array(np.random.rand(*v) + 0.01, dtype='float32', ctx=mx.cpu()) - for k, v in aux_params_shapes.items()} - wrapped_sym = mx.sym.Group([mx.sym.identity(s) for s in sym]) - - # Test FP32 MXNet Native - shapes = {} - shapes.update(data_shapes) - shapes.update(arg_params_shapes) - shapes.update(aux_params_shapes) - orig_executor = wrapped_sym.simple_bind(ctx=mx.gpu(0), grad_req='null', - force_rebind=True, **shapes) - orig_executor.copy_params_from(arg_params, aux_params) - orig_executor.forward(is_train=False, **data) - orig_outputs = [arr.asnumpy() for arr in orig_executor.outputs] - - # Test FP32 MXNet-TRT - mx.contrib.tensorrt.set_use_fp16(False) - trt_sym = wrapped_sym.get_backend_symbol('TensorRT') - assert len(trt_sym.get_internals()) < len(wrapped_sym.get_internals()) - remaining_arg_params, remaining_aux_params = \ - mx.contrib.tensorrt.init_tensorrt_params(trt_sym, arg_params, aux_params) - shapes = {} - shapes.update(data_shapes) - shapes.update({k: v.shape for k, v in remaining_arg_params.items()}) - shapes.update({k: v.shape for k, v in remaining_aux_params.items()}) - trt_fp32_executor = trt_sym.simple_bind(ctx=mx.gpu(0), grad_req='null', - force_rebind=True, **shapes) - trt_fp32_executor.copy_params_from(remaining_arg_params, remaining_aux_params) - trt_fp32_executor.forward(is_train=False, **data) - trt_fp32_outputs = [arr.asnumpy() for arr in trt_fp32_executor.outputs] - - # Test FP16 MXNet-TRT - mx.contrib.tensorrt.set_use_fp16(True) - data = {k: v.astype('float16') for k, v in data.items()} - arg_params = {k: v.astype('float16') for k, v in arg_params.items()} - aux_params = {k: v.astype('float16') for k, v in aux_params.items()} - trt_sym = wrapped_sym.get_backend_symbol('TensorRT') - assert len(trt_sym.get_internals()) < len(wrapped_sym.get_internals()) - remaining_arg_params, remaining_aux_params = \ - mx.contrib.tensorrt.init_tensorrt_params(trt_sym, arg_params, aux_params) - shapes = {} - shapes.update(data_shapes) - shapes.update({k: v.shape for k, v in remaining_arg_params.items()}) - shapes.update({k: v.shape for k, v in remaining_aux_params.items()}) - - trt_fp16_executor = trt_sym.simple_bind(ctx=mx.gpu(0), - type_dict={k: 'float16' for k in shapes.keys()}, - grad_req='null', force_rebind=True, **shapes) - trt_fp16_executor.copy_params_from(remaining_arg_params, remaining_aux_params) - trt_fp16_executor.forward(is_train=False, **data) - trt_fp16_outputs = [arr.asnumpy() for arr in trt_fp16_executor.outputs] - - for j, (orig, fp16, fp32) in enumerate(zip(orig_outputs, trt_fp16_outputs, trt_fp32_outputs)): - abs_orig = abs(orig) - diff32 = abs(fp32 - orig) - diff16 = abs(fp16.astype('float32') - orig) - _atol32 = diff32 - rtol_fp32 * abs_orig - _atol16 = diff16 - rtol_fp16 * abs_orig - print("{}: diff32({:.2E}) | diff16({:.2E}) | atol32({:.2E}) | atol16({:.2E}) | orig.min({:.2E})".format( - j, diff32.max(), diff16.max(), _atol32.max(), _atol16.max(), abs_orig.min())) - assert_allclose(fp32, orig, rtol=rtol_fp32, atol=atol_fp32) - assert_allclose(fp16, orig, rtol=rtol_fp16, atol=atol_fp16) - -@with_seed() -def test_noop(): - data = mx.sym.Variable('data') - check_unsupported_single_sym(data) - - -@with_seed() -def test_identity(): - data = mx.sym.Variable('data') - sym = mx.sym.identity(data) - check_single_sym(sym, data_shapes={'data': (8,3,32,32)}, - rtol_fp32=0., atol_fp32=0., rtol_fp16=1e-3, atol_fp16=1e-7) - - -@with_seed() -def test_convolution2d(): - data = mx.sym.Variable('data') - weight = mx.sym.Variable('weight') - bias = mx.sym.Variable('bias') - data_shape = (8,3,16,16) - num_filter = 7 - for kernel in [(3, 3), (1, 1), (3, 1)]: - for stride in [(1, 1), (2, 2), (2, 1)]: - if stride[0] > kernel[0] or stride[1] > kernel[1]: # doesn't make any sense - continue - if kernel == (3, 3) and stride == (1, 1): - atol_fp32 = 0. - rtol_fp32 = 1e-5 - atol_fp16 = 0. - rtol_fp16 = 1e-2 - else: - atol_fp32 = 0. - rtol_fp32 = 0. - atol_fp16 = 0. - rtol_fp16 = 1e-2 - for pad in [(1, 1), (0, 0), (1, 0)]: - for group in [1, 2]: - for layout in ['NCHW', 'NHWC']: - weight_shape = (num_filter, data_shape[1]) + kernel - bias_shape = (num_filter,) - sym = mx.sym.Convolution(data, weight=weight, bias=bias, kernel=kernel, - stride=stride, pad=pad, num_filter=num_filter, - no_bias=False, layout=layout) - if layout == 'NCHW': - print("kernel: {} | stride: {} | pad: {} | group: {} | layout: {} | with_bias".format( - kernel, stride, pad, group, layout)) - check_single_sym(sym, {'data': data_shape}, - {'weight': weight_shape, 'bias': bias_shape}, - rtol_fp32=rtol_fp32, atol_fp32=atol_fp32, - rtol_fp16=rtol_fp16, atol_fp16=atol_fp16) - else: - check_unsupported_single_sym(sym) - sym = mx.sym.Convolution(data, weight=weight, kernel=kernel, stride=stride, - pad=pad, num_filter=num_filter, no_bias=True, - layout=layout) - if layout == 'NCHW': - print("kernel: {} | stride: {} | pad: {} | group: {} | layout: {} | without_bias".format( - kernel, stride, pad, group, layout)) - check_single_sym(sym, {'data': data_shape}, - {'weight': weight_shape}, - rtol_fp32=rtol_fp32, atol_fp32=atol_fp32, - rtol_fp16=rtol_fp16, atol_fp16=atol_fp16) - else: - check_unsupported_single_sym(sym) - -@with_seed() -def test_deconvolution2d(): - data = mx.sym.Variable('data') - weight = mx.sym.Variable('weight') - bias = mx.sym.Variable('bias') - data_shape = (8,3,16,16) - num_filter = 7 - for kernel in [(3, 3), (1, 1), (3, 1)]: - for stride in [(1, 1), (2, 2), (2, 1)]: - if stride[0] > kernel[0] or stride[1] > kernel[1]: # doesn't make any sense - continue - if kernel == (3, 3) and stride == (1, 1): - atol_fp32 = 0. - rtol_fp32 = 5e-5 - atol_fp16 = 0. - rtol_fp16 = 1e-2 - else: - atol_fp32 = 0. - rtol_fp32 = 1e-6 - atol_fp16 = 0. - rtol_fp16 = 1e-2 - for pad in [(1, 1), (0, 0), (1, 0)]: - for group in [1, 2]: - for layout in ['NCHW', 'NHWC']: - weight_shape = (data_shape[1], num_filter) + kernel - bias_shape = (num_filter,) - sym = mx.sym.Deconvolution(data, weight=weight, bias=bias, kernel=kernel, - stride=stride, pad=pad, num_filter=num_filter, - no_bias=False, layout=layout) - if layout == 'NCHW': - print("kernel: {} | stride: {} | pad: {} | group: {} | layout: {} | with_bias".format( - kernel, stride, pad, group, layout)) - check_single_sym(sym, {'data': data_shape}, - {'weight': weight_shape, 'bias': bias_shape}, - rtol_fp32=rtol_fp32, atol_fp32=atol_fp32, - rtol_fp16=rtol_fp16, atol_fp16=atol_fp16) - else: - check_unsupported_single_sym(sym) - sym = mx.sym.Deconvolution(data, weight=weight, kernel=kernel, stride=stride, - pad=pad, num_filter=num_filter, no_bias=True, - layout=layout) - if layout == 'NCHW': - print("kernel: {} | stride: {} | pad: {} | group: {} | layout: {} | without_bias".format( - kernel, stride, pad, group, layout)) - check_single_sym(sym, {'data': data_shape}, - {'weight': weight_shape}, - rtol_fp32=rtol_fp32, atol_fp32=atol_fp32, - rtol_fp16=rtol_fp16, atol_fp16=atol_fp16) - else: - check_unsupported_single_sym(sym) - -@with_seed() -def test_fully_connected(): # TODO(cfujitsang): take care of flatten option - data = mx.sym.Variable('data') - weight = mx.sym.Variable('weight') - bias = mx.sym.Variable('bias') - data_shape = (8,64) - num_hidden = 7 - weight_shape = (num_hidden, data_shape[1]) - bias_shape = (num_hidden,) - sym = mx.sym.FullyConnected(data, weight=weight, bias=bias, no_bias=False, - num_hidden=num_hidden) - check_single_sym(sym, {'data': data_shape}, {'weight': weight_shape, 'bias': bias_shape}, - rtol_fp16=5e-3, atol_fp16=0.) - sym = mx.sym.FullyConnected(data, weight=weight, no_bias=True, num_hidden=num_hidden) - check_unsupported_single_sym(sym) - - -@with_seed() -def test_relu(): - data = mx.sym.Variable('data') - sym = mx.sym.relu(data) - for data_shape in [(10, 32), (10, 3, 32), (10, 3, 32, 32), (10, 3, 7, 32, 32)]: - check_single_sym(sym, {'data': data_shape}, rtol_fp32=0., atol_fp32=0., - rtol_fp16=1e-3, atol_fp16=1e-7) - - -@with_seed() -def test_activation(): - data = mx.sym.Variable('data') - for act_type in ['relu', 'sigmoid', 'tanh']: - sym = mx.sym.Activation(data, act_type=act_type) - for data_shape in [(10, 32), (10, 3, 32), (10, 3, 32, 32), (10,3,7,32,32)]: - check_single_sym(sym, {'data': data_shape}, rtol_fp32=0., atol_fp32=0., - rtol_fp16=1e-3, atol_fp16=1e-7) - for act_type in ['softrelu', 'softsign']: - sym = mx.sym.Activation(data, act_type=act_type) - check_unsupported_single_sym(sym) - - -@with_seed() -def test_pooling2d(): - data = mx.sym.Variable('data') - data_shape = (4, 3, 32,32) - for pool_type in ['max', 'avg', 'lp', 'sum']: - if pool_type == 'max': - rtol_fp32 = 1e-6 - atol_fp32 = 0. - rtol_fp16 = 1e-3 - atol_fp16 = 0. - else: - rtol_fp32 = 5e-6 - atol_fp32 = 0. - rtol_fp16 = 1e-3 - atol_fp16 = 0. - for layout in ['NHWC', 'NCHW']: - for (stride, pad, kernel, count_include_pad, pooling_convention) \ - in product([(2,2), (2,1)], [(0,0), (1,1)], [(2,2), (3,2)], - [True, False], ['valid', 'full']): - print("pool_type: {} | layout: {} | stride: {} | pad: {} | ".format( - pool_type, layout, stride, pad) + - "kernel: {} | count_include_pad: {} | pooling_convention: {}".format( - kernel, count_include_pad, pooling_convention)) - sym = mx.sym.Pooling(data, kernel=kernel, pool_type=pool_type, stride=stride, - pad=pad, layout=layout, count_include_pad=count_include_pad, - pooling_convention=pooling_convention) - if (layout == 'NHWC') or \ - pool_type not in ('max', 'avg') or \ - pooling_convention != 'valid' or \ - (pool_type == 'avg' and count_include_pad): - check_unsupported_single_sym(sym) - else: - check_single_sym(sym, {'data': data_shape}, - rtol_fp32=rtol_fp32, atol_fp32=atol_fp32, - rtol_fp16=rtol_fp16, atol_fp16=atol_fp16) - print("pool_type: {} | layout: {} | global_pool".format(pool_type, layout)) - sym = mx.sym.Pooling(data, global_pool=True, pool_type=pool_type, layout=layout) - if layout == 'NHWC' or pool_type not in ('max', 'avg'): - check_unsupported_single_sym(sym) - else: - if pool_type == 'max': - rtol_fp32 = 0. - atol_fp32 = 0. - rtol_fp16 = 1e-3 - atol_fp16 = 0. - else: - rtol_fp32 = 1e-5 - atol_fp32 = 0. - rtol_fp16 = 1e-3 - atol_fp16 = 0. - check_single_sym(sym, {'data': data_shape}, rtol_fp32=rtol_fp32, - atol_fp32=atol_fp32, rtol_fp16=rtol_fp16, atol_fp16=atol_fp16) - - -@with_seed() -def test_softmax_output(): - data = mx.sym.Variable('data') - label = mx.sym.Variable('label') - data_shape = (8, 100) - label_shape = (8, 100) - sym = mx.sym.SoftmaxOutput(data, label) - check_single_sym(sym, {'data': data_shape, 'label': label_shape}, - rtol_fp32=1e-6, atol_fp32=0., rtol_fp16=5e-3, atol_fp16=0.) - sym = mx.sym.SoftmaxOutput(data) - check_single_sym(sym, {'data': data_shape}, - rtol_fp32=1e-6, atol_fp32=0., rtol_fp16=5e-3, atol_fp16=0.) - - - -def check_batch_norm(sym, data_shapes, arg_params_shapes=None, aux_params_shapes=None, - rtol_fp32=1e-5, atol_fp32=1e-7, rtol_fp16=1e-2, atol_fp16=1e-3): - if arg_params_shapes is None: - arg_params_shapes = {} - if aux_params_shapes is None: - aux_params_shapes = {} - for i in range(3): - data = { - 'data': mx.nd.array(np.random.rand(*data_shapes['data']) + 0.01, - dtype='float32', ctx=mx.cpu()) - } - arg_params = { - 'gamma': mx.nd.array(np.random.rand(*arg_params_shapes['gamma']) * 0.1 + 1., - dtype='float32', ctx=mx.cpu()), - 'beta': mx.nd.array(np.random.rand(*arg_params_shapes['beta']), - dtype='float32', ctx=mx.cpu()) - } - aux_params = { - 'moving_mean': mx.nd.array( - 0.45 + np.random.rand(*aux_params_shapes['moving_mean']) * 0.1 + 0.01, - dtype='float32', ctx=mx.cpu()), - 'moving_var': mx.nd.array( - 0.95 + np.random.rand(*aux_params_shapes['moving_var']) * 0.1, - dtype='float32', ctx=mx.cpu()) - } - wrapped_sym = mx.sym.Group([mx.sym.identity(s) for s in sym]) - - # Test FP32 MXNet Native - shapes = {} - shapes.update(data_shapes) - shapes.update(arg_params_shapes) - shapes.update(aux_params_shapes) - orig_executor = wrapped_sym.simple_bind(ctx=mx.gpu(0), grad_req='null', - force_rebind=True, **shapes) - orig_executor.copy_params_from(arg_params, aux_params) - orig_executor.forward(is_train=False, **data) - orig_outputs = [arr.asnumpy() for arr in orig_executor.outputs] - - # Test FP32 MXNet-TRT - mx.contrib.tensorrt.set_use_fp16(False) - trt_sym = wrapped_sym.get_backend_symbol('TensorRT') - assert len(trt_sym.get_internals()) < len(wrapped_sym.get_internals()) - remaining_arg_params, remaining_aux_params = \ - mx.contrib.tensorrt.init_tensorrt_params(trt_sym, arg_params, aux_params) - shapes = {} - shapes.update(data_shapes) - shapes.update({k: v.shape for k, v in remaining_arg_params.items()}) - shapes.update({k: v.shape for k, v in remaining_aux_params.items()}) - trt_fp32_executor = trt_sym.simple_bind(ctx=mx.gpu(0), grad_req='null', - force_rebind=True, **shapes) - trt_fp32_executor.copy_params_from(remaining_arg_params, remaining_aux_params) - trt_fp32_executor.forward(is_train=False, **data) - trt_fp32_outputs = [arr.asnumpy() for arr in trt_fp32_executor.outputs] - - # Test FP16 MXNet-TRT - mx.contrib.tensorrt.set_use_fp16(True) - data = {k: v.astype('float16') for k, v in data.items()} - arg_params = {k: v.astype('float32') for k, v in arg_params.items()} - aux_params = {k: v.astype('float32') for k, v in aux_params.items()} - trt_sym = wrapped_sym.get_backend_symbol('TensorRT') - remaining_arg_params, remaining_aux_params = \ - mx.contrib.tensorrt.init_tensorrt_params(trt_sym, arg_params, aux_params) - shapes = {} - shapes.update(data_shapes) - shapes.update({k: v.shape for k, v in remaining_arg_params.items()}) - shapes.update({k: v.shape for k, v in remaining_aux_params.items()}) - - trt_fp16_executor = trt_sym.simple_bind(ctx=mx.gpu(0), - type_dict={k: 'float16' for k in shapes.keys()}, - grad_req='null', force_rebind=True, **shapes) - trt_fp16_executor.copy_params_from(remaining_arg_params, remaining_aux_params) - trt_fp16_executor.forward(is_train=False, **data) - trt_fp16_outputs = [arr.asnumpy() for arr in trt_fp16_executor.outputs] - - - for j, (orig, fp16, fp32) in enumerate(zip(orig_outputs, - trt_fp16_outputs, - trt_fp32_outputs)): - abs_orig = abs(orig) - diff32 = abs(fp32 - orig) - diff16 = abs(fp16.astype('float32') - orig) - _atol32 = diff32 - rtol_fp32 * abs_orig - _atol16 = diff16 - rtol_fp16 * abs_orig - print("{}: diff32({:.2E}) | diff16({:.2E}) | atol32({:.2E}) | atol16({:.2E}) | orig.min({:.2E})".format( - j, diff32.max(), diff16.max(), _atol32.max(), _atol16.max(), abs_orig.min())) - assert_allclose(fp32, orig, rtol=rtol_fp32, atol=atol_fp32) - assert_allclose(fp16.astype('float32'), orig, rtol=rtol_fp16, atol=atol_fp16) - -@with_seed() -def test_batch_norm(): - data = mx.sym.Variable('data') - gamma = mx.sym.Variable('gamma') - beta = mx.sym.Variable('beta') - moving_mean = mx.sym.Variable('moving_mean') - moving_var = mx.sym.Variable('moving_var') - data_shape = (4,3,32,32) - gamma_shape = (3,) - beta_shape = (3,) - moving_mean_shape = (3,) - moving_var_shape = (3,) - for fix_gamma in [True, False]: - for use_global_stats in [True, False]: - for axis in [0, 1, 2, 3]: - sym = mx.sym.BatchNorm(data, gamma=gamma, beta=beta, moving_mean=moving_mean, - fix_gamma=fix_gamma, moving_var=moving_var, momentum=0.9, - axis=axis, use_global_stats=use_global_stats, eps=1e-5) - if axis == 1: - check_batch_norm(sym, - {'data': data_shape}, {'gamma': gamma_shape, 'beta': beta_shape}, - {'moving_mean': moving_mean_shape, 'moving_var': moving_var_shape}, - atol_fp32=2e-7) - else: - check_unsupported_single_sym(sym) - - -@with_seed() -def test_clip(): - data = mx.sym.Variable('data') - sym = mx.sym.clip(data, 0.25, 0.75) - for data_shape in [(10, 32), (10, 3, 32), (10, 3, 32, 32), (10,3,7,32,32)]: - check_single_sym(sym, {'data': data_shape}, - rtol_fp32=0., atol_fp32=0., - rtol_fp16=1e-3, atol_fp16=0.) - - -@with_seed() -def test_concat(): - lhs = mx.sym.Variable('lhs') - rhs = mx.sym.Variable('rhs') - shape = [3, 5, 7, 9] - lhs_shape = tuple(shape) - for axis in range(1, 4): - sym = mx.sym.concat(lhs, rhs, dim=axis) - rhs_shape = copy.copy(shape) - rhs_shape[axis] = 1 - rhs_shape = tuple(rhs_shape) - check_single_sym(sym, {'lhs': lhs_shape, 'rhs': rhs_shape}, - rtol_fp32=0., atol_fp32=0., rtol_fp16=1e-3, atol_fp16=1e-7) - - -@with_seed() -def test_elemwise_ops(): - lhs = mx.sym.Variable('lhs') - rhs = mx.sym.Variable('rhs') - shape = (3, 5, 7, 9) - lhs_shape = tuple(shape) - sym = mx.sym.elemwise_add(lhs, rhs) - check_single_sym(sym, {'lhs': shape, 'rhs': shape}, - rtol_fp32=0., atol_fp32=0.) - - sym = mx.sym.elemwise_sub(lhs, rhs) - # TODO(cfujitsang): is atol_fp16 ok ? - check_single_sym(sym, {'lhs': shape, 'rhs': shape}, - rtol_fp32=0., atol_fp32=0., rtol_fp16=1e-3, atol_fp16=1e-3) - - sym = mx.sym.elemwise_mul(lhs, rhs) - check_single_sym(sym, {'lhs': shape, 'rhs': shape}, - rtol_fp32=0., atol_fp32=0., rtol_fp16=5e-3, atol_fp16=1e-7) - -@with_seed() -def test_flatten(): - data = mx.sym.Variable('data') - sym = mx.sym.flatten(data) - for data_shape in [(3, 5, 7), (3, 5, 7, 9), (3, 5, 7, 9, 11)]: - check_single_sym(sym, {'data': data_shape}, - rtol_fp32=0., atol_fp32=0., atol_fp16=1e-7) - -@with_seed() -def test_dropout(): - data = mx.sym.Variable('data') - for data_shape in [(3, 5), (3, 5, 7), (3, 5, 7, 9)]: - for mode in ['training', 'always']: - sym = mx.sym.Dropout(data, p=0.7, mode=mode) - if mode == 'training': - check_single_sym(sym, {'data': data_shape}, - rtol_fp32=0., atol_fp32=0., atol_fp16=1e-7) - else: - check_unsupported_single_sym(sym) - sym = mx.sym.Dropout(data, p=0.7, mode=mode, axes=(0,)) - check_unsupported_single_sym(sym) - -if __name__ == "__main__": - import nose - nose.runmodule() diff --git a/tests/python/tensorrt/test_resnet18.py b/tests/python/tensorrt/test_resnet18.py deleted file mode 100644 index 9fd99abb121b..000000000000 --- a/tests/python/tensorrt/test_resnet18.py +++ /dev/null @@ -1,74 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -from mxnet.gluon.model_zoo import vision -from mxnet.test_utils import assert_almost_equal -import mxnet as mx -import numpy as np -import os - -batch_shape = (1, 3, 224, 224) -url = 'https://github.com/dmlc/web-data/blob/master/mxnet/doc/tutorials/python/predict_image/cat.jpg?raw=true' -model_file_name = 'resnet18_v2_trt_test' - -def get_image(image_url): - fname = mx.test_utils.download(image_url, fname=image_url.split('/')[-1].split('?')[0]) - img = mx.image.imread(fname) - img = mx.image.imresize(img, 224, 224) # Resize - img = img.transpose((2, 0, 1)) # Channel first - img = img.expand_dims(axis=0) # Batchify - img = mx.nd.cast(img, dtype=np.float32) - return img / 255.0 - -def test_tensorrt_resnet18_feature_vect(): - print("downloading sample input") - input_data = get_image(url) - gluon_resnet18 = vision.resnet18_v2(pretrained=True) - gluon_resnet18.hybridize() - gluon_resnet18.forward(input_data) - gluon_resnet18.export(model_file_name) - sym, arg_params, aux_params = mx.model.load_checkpoint(model_file_name, 0) - - executor = sym.simple_bind(ctx=mx.gpu(), data=batch_shape, - grad_req='null', force_rebind=True) - executor.copy_params_from(arg_params, aux_params) - y = executor.forward(is_train=False, data=input_data) - trt_sym = sym.get_backend_symbol('TensorRT') - arg_params, aux_params = mx.contrib.tensorrt.init_tensorrt_params(trt_sym, arg_params, aux_params) - original_precision_value = mx.contrib.tensorrt.get_use_fp16() - try: - mx.contrib.tensorrt.set_use_fp16(True) - executor = trt_sym.simple_bind(ctx=mx.gpu(), data=batch_shape, - grad_req='null', force_rebind=True) - executor.copy_params_from(arg_params, aux_params) - y_trt = executor.forward(is_train=False, data=input_data) - mx.contrib.tensorrt.set_use_fp16(False) - executor = trt_sym.simple_bind(ctx=mx.gpu(), data=batch_shape, - grad_req='null', force_rebind=True) - executor.copy_params_from(arg_params, aux_params) - y_trt_fp32 = executor.forward(is_train=False, data=input_data) - no_trt_output = y[0].asnumpy()[0] - trt_output = y_trt[0].asnumpy()[0] - trt_fp32_output = y_trt_fp32[0].asnumpy()[0] - assert_almost_equal(no_trt_output, trt_output, 1e-1, 1e-2) - assert_almost_equal(no_trt_output, trt_fp32_output, 1e-4, 1e-4) - finally: - mx.contrib.tensorrt.set_use_fp16(original_precision_value) - -if __name__ == '__main__': - import nose - nose.runmodule() diff --git a/tests/python/tensorrt/test_tensorrt_lenet5.py b/tests/python/tensorrt/test_tensorrt_lenet5.py deleted file mode 100644 index 78f41ca53909..000000000000 --- a/tests/python/tensorrt/test_tensorrt_lenet5.py +++ /dev/null @@ -1,121 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import os -import numpy as np -import mxnet as mx -from ctypes.util import find_library - -def check_tensorrt_installation(): - assert find_library('nvinfer') is not None, "Can't find the TensorRT shared library" - -def get_iters(mnist, batch_size): - """Get MNIST iterators.""" - train_iter = mx.io.NDArrayIter(mnist['train_data'], - mnist['train_label'], - batch_size, - shuffle=True) - val_iter = mx.io.NDArrayIter(mnist['test_data'], mnist['test_label'], batch_size) - test_iter = mx.io.NDArrayIter(mnist['test_data'], mnist['test_label'], batch_size) - all_test_labels = np.array(mnist['test_label']) - return train_iter, val_iter, test_iter, all_test_labels - -def run_inference(sym, arg_params, aux_params, mnist, all_test_labels, batch_size, use_tensorrt): - """Run inference with either MXNet or TensorRT""" - - data_size = (batch_size,) + mnist['test_data'].shape[1:] - type_dict = {'data': 'float32', 'softmax_label': 'float32'} - - if use_tensorrt: - _sym = sym.get_backend_symbol('TensorRT') - arg_params, aux_params = mx.contrib.tensorrt.init_tensorrt_params(_sym, arg_params, - aux_params) - else: - _sym = sym - for k, v in arg_params.items(): - type_dict[k] = v.dtype - for k, v in aux_params.items(): - type_dict[k] = v.dtype - executor = _sym.simple_bind(ctx=mx.gpu(0), - type_dict=type_dict, - data=data_size, - softmax_label=(batch_size,), - grad_req='null', - force_rebind=True) - executor.copy_params_from(arg_params, aux_params) - - # Get this value from all_test_labels - # Also get classes from the dataset - num_ex = 10000 - all_preds = np.zeros([num_ex, 10]) - test_iter = mx.io.NDArrayIter(mnist['test_data'], mnist['test_label'], batch_size) - - example_ct = 0 - - for idx, dbatch in enumerate(test_iter): - executor.arg_dict["data"][:] = dbatch.data[0] - executor.forward(is_train=False) - offset = idx*batch_size - extent = batch_size if num_ex - offset > batch_size else num_ex - offset - all_preds[offset:offset+extent, :] = executor.outputs[0].asnumpy()[:extent] - example_ct += extent - - all_preds = np.argmax(all_preds, axis=1) - matches = (all_preds[:example_ct] == all_test_labels[:example_ct]).sum() - - percentage = 100.0 * matches / example_ct - - return percentage - - -def test_tensorrt_inference(): - """Run LeNet-5 inference comparison between MXNet and TensorRT.""" - check_tensorrt_installation() - mnist = mx.test_utils.get_mnist() - num_epochs = 10 - batch_size = 128 - model_name = 'lenet5' - model_dir = os.getenv("LENET_MODEL_DIR", "/tmp") - model_file = '%s/%s-symbol.json' % (model_dir, model_name) - params_file = '%s/%s-%04d.params' % (model_dir, model_name, num_epochs) - - _, _, _, all_test_labels = get_iters(mnist, batch_size) - - # Load serialized MXNet model (model-symbol.json + model-epoch.params) - sym, arg_params, aux_params = mx.model.load_checkpoint(model_name, num_epochs) - - print("LeNet-5 test") - print("Running inference in MXNet") - mx_pct = run_inference(sym, arg_params, aux_params, mnist, all_test_labels, - batch_size=batch_size, use_tensorrt=False) - - print("Running inference in MXNet-TensorRT") - trt_pct = run_inference(sym, arg_params, aux_params, mnist, all_test_labels, - batch_size=batch_size, use_tensorrt=True) - - print("MXNet accuracy: %f" % mx_pct) - print("MXNet-TensorRT accuracy: %f" % trt_pct) - - absolute_accuracy_diff = abs(mx_pct - trt_pct) - epsilon = 3e-2 - assert absolute_accuracy_diff < epsilon, \ - """Absolute diff. between MXNet & TensorRT accuracy (%f) exceeds threshold (%f): - MXNet = %f, TensorRT = %f""" % (absolute_accuracy_diff, epsilon, mx_pct, trt_pct) - -if __name__ == '__main__': - import nose - nose.runmodule() diff --git a/tests/python/unittest/test_extensions.py b/tests/python/unittest/test_extensions.py index d00f1494e4d5..9c62b7ff90f6 100644 --- a/tests/python/unittest/test_extensions.py +++ b/tests/python/unittest/test_extensions.py @@ -130,8 +130,6 @@ def test_subgraph(): sym = mx.sym.log(d) args = {'a':mx.nd.ones((3,2),ctx=mx.cpu()), 'b':mx.nd.ones((3,2),ctx=mx.cpu())} - arg_array = [mx.nd.ones((3,2),dtype='float32',ctx=mx.cpu()), - mx.nd.ones((3,2),dtype='float32',ctx=mx.cpu())] # baseline - regular execution in MXNet exe = sym.bind(ctx=mx.cpu(), args=args) @@ -147,14 +145,14 @@ def test_subgraph(): # with propogating shapes/types, rejecting subgraph # this tests creating the subgraph and having the subgraph prop reject it - mysym2 = sym.optimize_for("myProp", arg_array, reject=True) + mysym2 = sym.optimize_for("myProp", args, reject=True) exe2 = mysym2.bind(ctx=mx.cpu(), args=args) out2 = exe2.forward() # check that result matches one executed by MXNet assert_almost_equal(out[0].asnumpy(), out2[0].asnumpy(), rtol=1e-3, atol=1e-3) # with propogating shapes/types - mysym3 = sym.optimize_for("myProp",arg_array) + mysym3 = sym.optimize_for("myProp",args) exe3 = mysym3.bind(ctx=mx.cpu(), args=args) out3 = exe3.forward() # check that result matches one executed by MXNet diff --git a/tests/python/unittest/test_subgraph_op.py b/tests/python/unittest/test_subgraph_op.py index e414a9836ccb..81665f20057c 100644 --- a/tests/python/unittest/test_subgraph_op.py +++ b/tests/python/unittest/test_subgraph_op.py @@ -327,18 +327,20 @@ def check_subgraph_exe8(sym, subgraph_backend, op_names): then bind and compare results of the partitioned sym and the original sym.""" # bind arg_shapes, _, aux_shapes = sym.infer_shape() - arg_array = [mx.nd.random.uniform(shape=shape) for shape in arg_shapes] - aux_array = [mx.nd.random.uniform(shape=shape) for shape in aux_shapes] - exe1 = sym.bind(ctx=mx.current_context(), args=arg_array, aux_states=aux_array, grad_req='null') + arg_names = sym.list_arguments() + aux_names = sym.list_auxiliary_states() + arg_dict = {name:mx.nd.random.uniform(shape=shape) for name,shape in zip(arg_names,arg_shapes)} + aux_dict = {name:mx.nd.random.uniform(shape=shape) for name,shape in zip(aux_names,aux_shapes)} + exe1 = sym.bind(ctx=mx.current_context(), args=arg_dict, aux_states=aux_dict, grad_req='null') exe1.forward() # infer shape/type before partition before bind check_call(_LIB.MXSetSubgraphPropertyOpNamesV2(c_str(subgraph_backend), mx_uint(len(op_names)), - c_str_array(op_names))) - part_sym = sym.optimize_for(subgraph_backend, arg_array, aux_array) + c_str_array(op_names))) + part_sym = sym.optimize_for(subgraph_backend, arg_dict, aux_dict) check_call(_LIB.MXRemoveSubgraphPropertyOpNamesV2(c_str(subgraph_backend))) - exe2 = part_sym.bind(ctx=mx.current_context(), args=arg_array, aux_states=aux_array, grad_req='null') + exe2 = part_sym.bind(ctx=mx.current_context(), args=arg_dict, aux_states=aux_dict, grad_req='null') exe2.forward() # compare outputs From 351104759cb8ea69df9cc450a69c962e674015e4 Mon Sep 17 00:00:00 2001 From: Serge Panev Date: Tue, 18 Aug 2020 15:40:17 -0700 Subject: [PATCH 2/3] Remove test from Jenkins Signed-off-by: Serge Panev --- ci/jenkins/Jenkinsfile_unix_gpu | 1 - 1 file changed, 1 deletion(-) diff --git a/ci/jenkins/Jenkinsfile_unix_gpu b/ci/jenkins/Jenkinsfile_unix_gpu index 5e26a9f41380..f21944084a72 100644 --- a/ci/jenkins/Jenkinsfile_unix_gpu +++ b/ci/jenkins/Jenkinsfile_unix_gpu @@ -50,7 +50,6 @@ core_logic: { custom_steps.test_unix_python3_quantize_gpu(), custom_steps.test_unix_python3_mkldnn_gpu(), custom_steps.test_unix_python3_mkldnn_nocudnn_gpu(), - custom_steps.test_unix_python3_tensorrt_gpu(), custom_steps.test_unix_perl_gpu(), custom_steps.test_unix_r_gpu(), custom_steps.test_unix_cpp_gpu(), From a62e081496d0efd53682dfc1fcf9ca38d39e8f51 Mon Sep 17 00:00:00 2001 From: Serge Panev Date: Tue, 18 Aug 2020 16:42:04 -0700 Subject: [PATCH 3/3] Fix test Signed-off-by: Serge Panev --- example/extensions/lib_pass/test_pass.py | 26 ++++-------------------- 1 file changed, 4 insertions(+), 22 deletions(-) diff --git a/example/extensions/lib_pass/test_pass.py b/example/extensions/lib_pass/test_pass.py index a5a7e8c7ea4d..66411a69cac6 100644 --- a/example/extensions/lib_pass/test_pass.py +++ b/example/extensions/lib_pass/test_pass.py @@ -52,30 +52,12 @@ def test_model(pass_name): # execute in MXNet print('-------------------------------') print('Testing regular MXNet execution') - - exe = sym.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2)), 'b':mx.nd.ones((3,2))}) - out = exe.forward() + inputs = [a,b] + sym_block = nn.SymbolBlock(sym, inputs) + sym_block.initialize() + out = sym_block(mx.nd.ones((3,2)),mx.nd.ones((3,2))) print(out) - # Symbol optimize_for - # with propogating shapes/types - print('-------------------------------') - print('Testing pass "%s" with shapes/types' % pass_name) - aux = {} - mysym2 = sym.optimize_for(pass_name,args,aux) - print(mysym2.tojson()) - exe2 = mysym2.bind(ctx=mx.cpu(), args=args) - out2 = exe2.forward() - print(out2) - - # without propogating shapes/types - print('-------------------------------') - print('Testing pass "%s" without shapes/types' % pass_name) - mysym3 = sym.optimize_for(pass_name, myOpt='yello') - exe3 = mysym3.bind(ctx=mx.cpu(), args=args) - out3 = exe3.forward() - print(out3) - # Gluon Hybridize print('-------------------------------') print('Testing pass "%s" Gluon Hybridize with shapes/types' % pass_name)