diff --git a/.gitmodules b/.gitmodules index 836d824a6f5a..900500a647a9 100644 --- a/.gitmodules +++ b/.gitmodules @@ -29,3 +29,6 @@ [submodule "3rdparty/onnx-tensorrt"] path = 3rdparty/onnx-tensorrt url = https://github.com/onnx/onnx-tensorrt.git +[submodule "3rdparty/ngraph-mxnet-bridge"] + path = 3rdparty/ngraph-mxnet-bridge + url = https://github.com/NervanaSystems/ngraph-mxnet-bridge diff --git a/3rdparty/ngraph-mxnet-bridge b/3rdparty/ngraph-mxnet-bridge new file mode 160000 index 000000000000..9af5ed90f273 --- /dev/null +++ b/3rdparty/ngraph-mxnet-bridge @@ -0,0 +1 @@ +Subproject commit 9af5ed90f273dd97f305abacef6b3ff3a682efbe diff --git a/CMakeLists.txt b/CMakeLists.txt index d8ef524bb389..7c7a4db0f11f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -28,6 +28,7 @@ mxnet_option(USE_CUDNN "Build with cudnn support" ON) # one could se mxnet_option(USE_SSE "Build with x86 SSE instruction support" ON IF NOT ARM) mxnet_option(USE_F16C "Build with x86 F16C instruction support" ON) # autodetects support if ON mxnet_option(USE_LAPACK "Build with lapack support" ON) +mxnet_option(USE_NGRAPH "Build with nGraph support" OFF) mxnet_option(USE_MKL_IF_AVAILABLE "Use MKL if found" ON) mxnet_option(USE_MKLML_MKL "Use MKLDNN variant of MKL (if MKL found)" ON IF USE_MKL_IF_AVAILABLE AND (NOT APPLE)) mxnet_option(USE_MKLDNN "Use MKLDNN variant of MKL (if MKL found)" ON IF USE_MKL_IF_AVAILABLE AND (NOT APPLE) AND (NOT MSVC) AND (CMAKE_HOST_SYSTEM_PROCESSOR STREQUAL "x86_64") AND (NOT CMAKE_CROSSCOMPILING)) @@ -230,7 +231,6 @@ if(ENABLE_TESTCOVERAGE) if(NOT GCOV_PATH) message(FATAL_ERROR "gcov not found! Aborting...") endif() # NOT GCOV_PATH - set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} --coverage") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} --coverage") set(CMAKE_LINKER_FLAGS "${CMAKE_LINKER_FLAGS} --coverage") diff --git a/Makefile b/Makefile index 5c5e77fe3a58..12354cfc1464 100644 --- a/Makefile +++ b/Makefile @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. 
+.DEFAULT_GOAL := all ROOTDIR = $(CURDIR) TPARTYDIR = $(ROOTDIR)/3rdparty @@ -78,13 +79,17 @@ ifeq ($(USE_MKLDNN), 1) MKLDNNROOT = $(ROOTDIR)/3rdparty/mkldnn/build/install MKLROOT = $(ROOTDIR)/3rdparty/mkldnn/build/install export USE_MKLML = 1 + MKLDNN_INCLUDE_DIR = $(MKLDNNROOT)/include + MKLDNN_LIB_DIR = $(MKLDNNROOT)/lib endif include $(TPARTYDIR)/mshadow/make/mshadow.mk include $(DMLC_CORE)/make/dmlc.mk -# all tge possible warning tread -WARNFLAGS= -Wall -Wsign-compare +include 3rdparty/ngraph-mxnet-bridge/ngraph.mk + +# all the possible warning flags +WARNFLAGS= -Wall -Wsign-compare -Wno-comment CFLAGS = -DMSHADOW_FORCE_STREAM $(WARNFLAGS) ifeq ($(DEV), 1) @@ -101,6 +106,10 @@ endif CFLAGS += -I$(TPARTYDIR)/mshadow/ -I$(TPARTYDIR)/dmlc-core/include -fPIC -I$(NNVM_PATH)/include -I$(DLPACK_PATH)/include -I$(TPARTYDIR)/tvm/include -Iinclude $(MSHADOW_CFLAGS) LDFLAGS = -pthread $(MSHADOW_LDFLAGS) $(DMLC_LDFLAGS) +ifeq ($(USE_NGRAPH),1) + CFLAGS += $(NGRAPH_CFLAGS) +endif + ifeq ($(ENABLE_TESTCOVERAGE), 1) CFLAGS += --coverage LDFLAGS += --coverage @@ -411,6 +420,10 @@ else EXTRA_CUOBJ = endif +ifeq ($(USE_NGRAPH), 1) + EXTRA_OBJ += $(NGRAPH_BRIDGE_OBJ) +endif + # plugin PLUGIN_OBJ = PLUGIN_CUOBJ = @@ -473,23 +486,23 @@ endif # For quick compile test, used smaller subset ALLX_DEP= $(ALL_DEP) -build/src/%.o: src/%.cc | mkldnn +build/src/%.o: src/%.cc | mkldnn ngraph @mkdir -p $(@D) $(CXX) -std=c++11 -c $(CFLAGS) -MMD -c $< -o $@ -build/src/%_gpu.o: src/%.cu | mkldnn +build/src/%_gpu.o: src/%.cu | mkldnn ngraph @mkdir -p $(@D) $(NVCC) $(NVCCFLAGS) $(CUDA_ARCH) -Xcompiler "$(CFLAGS)" --generate-dependencies -MT build/src/$*_gpu.o $< >build/src/$*_gpu.d $(NVCC) -c -o $@ $(NVCCFLAGS) $(CUDA_ARCH) -Xcompiler "$(CFLAGS)" $< # A nvcc bug cause it to generate "generic/xxx.h" dependencies from torch headers. # Use CXX to generate dependency instead.
-build/plugin/%_gpu.o: plugin/%.cu +build/plugin/%_gpu.o: plugin/%.cu | ngraph @mkdir -p $(@D) $(CXX) -std=c++11 $(CFLAGS) -MM -MT build/plugin/$*_gpu.o $< >build/plugin/$*_gpu.d $(NVCC) -c -o $@ $(NVCCFLAGS) $(CUDA_ARCH) -Xcompiler "$(CFLAGS)" $< -build/plugin/%.o: plugin/%.cc | mkldnn +build/plugin/%.o: plugin/%.cc | mkldnn ngraph @mkdir -p $(@D) $(CXX) -std=c++11 -c $(CFLAGS) -MMD -c $< -o $@ @@ -515,7 +528,9 @@ lib/libmxnet.a: $(ALLX_DEP) lib/libmxnet.so: $(ALLX_DEP) @mkdir -p $(@D) - $(CXX) $(CFLAGS) -shared -o $@ $(filter-out %libnnvm.a, $(filter %.o %.a, $^)) $(LDFLAGS) \ + $(CXX) $(CFLAGS) -shared -o $@ $(filter-out %libnnvm.a, $(filter %.o %.a, $^)) \ + $(NGRAPH_LDFLAGS_FOR_SHARED_LIBS) \ + $(LDFLAGS) \ -Wl,${WHOLE_ARCH} $(filter %libnnvm.a, $^) -Wl,${NO_WHOLE_ARCH} ifeq ($(USE_MKLDNN), 1) ifeq ($(UNAME_S), Darwin) @@ -544,7 +559,9 @@ bin/im2rec: tools/im2rec.cc $(ALLX_DEP) $(BIN) : @mkdir -p $(@D) - $(CXX) $(CFLAGS) -std=c++11 -o $@ $(filter %.cpp %.o %.c %.a %.cc, $^) $(LDFLAGS) + $(CXX) $(CFLAGS) -std=c++11 -o $@ $(filter %.cpp %.o %.c %.a %.cc, $^) \ + $(LDFLAGS) \ + $(NGRAPH_LDFLAGS_FOR_PROGS_IN_BIN) # CPP Package ifeq ($(USE_CPP_PACKAGE), 1) @@ -656,7 +673,7 @@ clean: rclean cyclean $(EXTRA_PACKAGES_CLEAN) $(RM) -r $(patsubst %, %/*.d, $(EXTRA_OPERATORS)) $(patsubst %, %/*/*.d, $(EXTRA_OPERATORS)) $(RM) -r $(patsubst %, %/*.o, $(EXTRA_OPERATORS)) $(patsubst %, %/*/*.o, $(EXTRA_OPERATORS)) else -clean: rclean mkldnn_clean cyclean testclean $(EXTRA_PACKAGES_CLEAN) +clean: rclean ngraph_clean mkldnn_clean cyclean testclean $(EXTRA_PACKAGES_CLEAN) $(RM) -r build lib bin *~ */*~ */*/*~ */*/*/*~ (cd scala-package && mvn clean) || true cd $(DMLC_CORE); $(MAKE) clean; cd - diff --git a/NGRAPH_README.md b/NGRAPH_README.md new file mode 100644 index 000000000000..0eefcad97ebc --- /dev/null +++ b/NGRAPH_README.md @@ -0,0 +1,55 @@ + + + + + + + + + + + + + + + + + +# nGraph - MXNet Integration +MXNet nGraph integration is based on [Unified integration with external backend libraries](https://cwiki.apache.org/confluence/display/MXNET/Unified+integration+with+external+backend+libraries) + +After building MXNet with nGraph support, users can enable the nGraph backend by setting the `MXNET_SUBGRAPH_BACKEND="ngraph"` environment variable. + +Gluon support is experimental and may or may not yield good performance. Gluon-nGraph +integration can be enabled by setting the environment variable `MXNET_NGRAPH_GLUON=1`.
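+
+A minimal usage sketch (the model, shapes, and values below are purely illustrative;
+the environment variable must be set before the executor is bound):
+
+```python
+import os
+# Select the nGraph partitioner before binding any executor.
+os.environ["MXNET_SUBGRAPH_BACKEND"] = "ngraph"
+
+import mxnet as mx
+
+data = mx.sym.Variable("data")
+fc = mx.sym.FullyConnected(data, num_hidden=64)
+net = mx.sym.SoftmaxOutput(fc, name="softmax")
+
+# Partitioning happens inside simple_bind: supported subgraphs are replaced
+# with _ngraph_subgraph_op nodes and compiled by the nGraph library.
+exe = net.simple_bind(ctx=mx.cpu(), data=(32, 128))
+out = exe.forward(is_train=False, data=mx.nd.ones((32, 128)))[0]
+print(out.shape)  # (32, 64)
+```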
+## Building with nGraph support +MXNet's experimental support for the Intel nGraph graph compiler can be enabled +using MXNet's build system. Support is currently limited to Linux-based operating systems; Mac and Windows +support will be added in future releases. + +When building MXNet with experimental nGraph integration enabled, MXNet's build +system builds its own copy of the nGraph-supplied libraries. Upon successful +completion of an nGraph-enabled build, these libraries and related symbolic links +can be found in the same build directory as `libmxnet.so`. + +If building with GNU Make, use the command: + +`make -j USE_NGRAPH=1` + +If building with CMake, use the command: + +`mkdir build && cd build && cmake ../ -DUSE_NGRAPH=1 && make -j` + +## Runtime environment variables +Some environment variables influence the behavior of the +nGraph-enabled MXNet software and supporting libraries. Here is a partial list of those variables: + +| Variable | Description | | :-------- | :---------- | | `OMP_NUM_THREADS` | Suggested value: `16`. For more information please see [here](https://software.intel.com/en-us/mkl-windows-developer-guide-setting-the-number-of-threads-using-an-openmp-environment-variable) | | `KMP_AFFINITY` | Suggested value: `granularity=fine,compact,1,0`. For more information please see [here](https://software.intel.com/en-us/node/522691). | | `MXNET_NGRAPH_VERBOSE_GRAPH` | When set to `1`, nGraph-enabled MXNet will create in the current directory a JSON file representing each subgraph being compiled by the nGraph library. Each of these JSON files is a graph serialization that can be loaded by nGraph's `ngraph::deserialize` functions. | ## Supported nGraph back-ends The nGraph library supports a number of hardware and software backends, including `"CPU"`, `"INTERPRETER"` (reference kernels), `"GPU"`, and `"IntelGPU"`. The current experimental integration enables the `"CPU"` backend by default. More backends will be supported in future releases. diff --git a/amalgamation/amalgamation.py b/amalgamation/amalgamation.py index 0a4be02b8ff9..20213067c68c 100644 --- a/amalgamation/amalgamation.py +++ b/amalgamation/amalgamation.py @@ -30,7 +30,8 @@ 'opencv2/opencv.hpp', 'sys/stat.h', 'sys/types.h', 'cuda.h', 'cuda_fp16.h', 'omp.h', 'onnx/onnx.pb.h', 'execinfo.h', 'packet/sse-inl.h', 'emmintrin.h', 'thrust/device_vector.h', 'cusolverDn.h', 'internal/concurrentqueue_internal_debug.h', 'relacy/relacy_std.hpp', - 'relacy_shims.h', 'ittnotify.h', 'shared_mutex' + 'relacy_shims.h', 'ittnotify.h', 'shared_mutex', 'ngraph/ngraph.hpp', 'ngraph_imperative.h', + 'ngraph_nnvm_utils.h', ] minimum = int(sys.argv[6]) if len(sys.argv) > 5 else 0 diff --git a/ci/docker/runtime_functions.sh b/ci/docker/runtime_functions.sh index de1b7795ce69..70c97a1d1743 100755 --- a/ci/docker/runtime_functions.sh +++ b/ci/docker/runtime_functions.sh @@ -564,6 +564,18 @@ build_ubuntu_cpu_mkldnn_mkl() { -j$(nproc) } +build_ubuntu_cpu_ngraph() { + set -ex + + build_ccache_wrappers + + make \ + ENABLE_TESTCOVERAGE=1 \ + USE_BLAS=openblas \ + USE_NGRAPH=1 \ + -j$(nproc) +} + build_ubuntu_gpu() { build_ubuntu_gpu_cuda91_cudnn7 } @@ -834,6 +846,16 @@ unittest_ubuntu_tensorrt_gpu() { nosetests-3.4 $NOSE_COVERAGE_ARGUMENTS $NOSE_TIMER_ARGUMENTS --with-xunit --xunit-file nosetests_trt_gpu.xml --verbose --nocapture tests/python/tensorrt/ } +unittest_ubuntu_cpu_ngraph() { + set -ex + export MXNET_SUBGRAPH_BACKEND="ngraph" + export PYTHONPATH=./python/ + export MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0 + export LD_LIBRARY_PATH=/work/mxnet/lib:$LD_LIBRARY_PATH + nosetests-3.4 $NOSE_COVERAGE_ARGUMENTS --with-xunit --xunit-file nosetests_unittest.xml --verbose tests/python/unittest -e "test_monitor" -e "test_op_output_names_monitor" -e "test_op_all_names_monitor" -e "test_zero_prop" + nosetests-3.4 $NOSE_COVERAGE_ARGUMENTS --with-xunit --xunit-file nosetests_ngraph.xml --verbose tests/python/ngraph +} + # quantization gpu currently only runs on P3 instances # need to separte it from unittest_ubuntu_python2_gpu() unittest_ubuntu_python2_quantization_gpu() { diff --git a/ci/jenkins/Jenkins_steps.groovy b/ci/jenkins/Jenkins_steps.groovy index cfbf484756e5..fb8c439d8678 100644 --- a/ci/jenkins/Jenkins_steps.groovy +++ b/ci/jenkins/Jenkins_steps.groovy @@ -34,6 +34,7 @@ mx_cmake_lib = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/dmlc-core/li mx_cmake_lib_debug = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests' mx_cmake_mkldnn_lib = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/dmlc-core/libdmlc.a,
build/tests/mxnet_unit_tests, build/3rdparty/openmp/runtime/src/libomp.so, build/3rdparty/mkldnn/src/libmkldnn.so.0' mx_mkldnn_lib = 'lib/libmxnet.so, lib/libmxnet.a, lib/libiomp5.so, lib/libmkldnn.so.0, lib/libmklml_intel.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a' +mx_ngraph_lib = 'lib/libmxnet.so, lib/libmxnet.a, lib/libiomp5.so, lib/libmkldnn.so.0, lib/libmklml_intel.so, lib/libcpu_backend.so, lib/libngraph.so, lib/libtbb.so.2, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a' mx_tensorrt_lib = 'build/libmxnet.so, lib/libnvonnxparser_runtime.so.0, lib/libnvonnxparser.so.0, lib/libonnx_proto.so, lib/libonnx.so' mx_lib_cpp_examples = 'lib/libmxnet.so, lib/libmxnet.a, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a, 3rdparty/ps-lite/build/libps.a, deps/lib/libprotobuf-lite.a, deps/lib/libzmq.a, build/cpp-package/example/*' mx_lib_cpp_examples_cpu = 'build/libmxnet.so, build/cpp-package/example/*' @@ -119,6 +120,20 @@ def compile_unix_openblas_debug_cpu() { }] } +def compile_unix_ngraph_cpu() { + return ['CPU: NGRAPH': { + node(NODE_LINUX_CPU) { + ws('workspace/build-ngraph-cpu') { + timeout(time: max_time, unit: 'MINUTES') { + utils.init_git() + utils.docker_run('ubuntu_cpu', 'build_ubuntu_cpu_ngraph', false) + utils.pack_lib('ngraph_cpu', mx_ngraph_lib, true) + } + } + } + }] +} + def compile_unix_mkl_cpu() { return ['CPU: MKL': { node(NODE_LINUX_CPU) { @@ -717,6 +732,24 @@ def test_unix_python2_mkldnn_cpu() { }] } +def test_unix_python3_ngraph_cpu() { + return ['Python3: nGraph-CPU': { + node(NODE_LINUX_CPU) { + ws('workspace/build-ngraph-cpu') { + timeout(time: max_time, unit: 'MINUTES') { + try { + utils.unpack_and_init('ngraph_cpu', mx_ngraph_lib, true) + utils.docker_run('ubuntu_cpu', 'unittest_ubuntu_cpu_ngraph', false) + utils.publish_test_coverage() + } finally { + utils.collect_test_results_unix('nosetests_unittest.xml', 'nosetests_python3_ngraph_cpu.xml') + } + } + } + } + }] +} + def test_unix_python3_mkldnn_cpu() { return ['Python3: MKLDNN-CPU': { node(NODE_LINUX_CPU) { diff --git a/ci/jenkins/Jenkinsfile_unix_cpu b/ci/jenkins/Jenkinsfile_unix_cpu index 919381ebccd4..dbdfddcbcba1 100644 --- a/ci/jenkins/Jenkinsfile_unix_cpu +++ b/ci/jenkins/Jenkinsfile_unix_cpu @@ -37,6 +37,7 @@ core_logic: { custom_steps.compile_unix_cpu_openblas(), custom_steps.compile_unix_openblas_debug_cpu(), custom_steps.compile_unix_mkl_cpu(), + custom_steps.compile_unix_ngraph_cpu(), custom_steps.compile_unix_mkldnn_cpu(), custom_steps.compile_unix_mkldnn_mkl_cpu() ]) @@ -48,6 +49,7 @@ core_logic: { custom_steps.test_unix_python3_mkl_cpu(), custom_steps.test_unix_python2_mkldnn_cpu(), custom_steps.test_unix_python3_mkldnn_cpu(), + custom_steps.test_unix_python3_ngraph_cpu(), custom_steps.test_unix_python3_mkldnn_mkl_cpu(), custom_steps.test_unix_scala_cpu(), custom_steps.test_unix_scala_mkldnn_cpu(), diff --git a/make/config.mk b/make/config.mk index 8a1aa2c165c4..f73d10d3027e 100644 --- a/make/config.mk +++ b/make/config.mk @@ -92,6 +92,7 @@ USE_OPENCV = 1 #whether use libjpeg-turbo for image decode without OpenCV wrapper USE_LIBJPEG_TURBO = 0 + #add the path to libjpeg-turbo library USE_LIBJPEG_TURBO_PATH = NONE @@ -103,6 +104,9 @@ USE_OPENMP = 1 # you can disable it explicity with USE_MKLDNN = 0 USE_MKLDNN = +# whether to use the nGraph library +USE_NGRAPH = 0 + # whether use NNPACK library USE_NNPACK = 0 diff --git a/src/c_api/c_api_symbolic.cc b/src/c_api/c_api_symbolic.cc index e07716267288..6b0cfeb29940 100644 --- a/src/c_api/c_api_symbolic.cc 
+++ b/src/c_api/c_api_symbolic.cc @@ -32,6 +32,9 @@ #include "../operator/operator_common.h" #include "../executor/exec_pass.h" #include "../operator/subgraph/subgraph_property.h" +#if MXNET_USE_NGRAPH == 1 +#include <ngraph_imperative.h> +#endif namespace mxnet { namespace op { @@ -76,6 +79,10 @@ int MXListAllOpNames(nn_uint *out_size, const char ***out_array) { mxnet::op::RegisterLegacyOpProp(); mxnet::op::RegisterLegacyNDFunc(); +#if MXNET_USE_NGRAPH == 1 + // ngraph imperative interface + ngraph_bridge::InitImperative(); +#endif return NNListAllOpNames(out_size, out_array); } diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc index 436eae37d785..980b67069e8b 100644 --- a/src/executor/graph_executor.cc +++ b/src/executor/graph_executor.cc @@ -1532,6 +1532,7 @@ static nnvm::Symbol PartitionGraph(const nnvm::Symbol& src, g = InferForwardAttrs(g, arg_shapes, arg_dtypes, arg_stypes, default_ctx, ctx_map, in_arg_ctxes, aux_state_ctxes); subgraph_prop->SetAttr("graph", g); + subgraph_prop->SetAttr("grad_reqs", grad_req_types); auto it = op::SubgraphPropertyOpNameSet::Get()->find(prop_name); // assign a op name set to the subgraph property if it has been provided by users if (it != op::SubgraphPropertyOpNameSet::Get()->end()) { @@ -1661,10 +1662,12 @@ Executor *Executor::SimpleBind(nnvm::Symbol symbol, std::unordered_map<std::string, NDArray>* shared_buffer, Executor* shared_exec) { auto exec = new exec::GraphExecutor(); - if (!exec->subgraph_property().empty()) { + if (!exec->subgraph_property().empty() && group2ctx.empty()) { symbol = exec::PartitionGraph(symbol, exec->subgraph_property(), arg_shape_map, arg_dtype_map, arg_stype_map, default_ctx, group2ctx, in_arg_ctxes, aux_state_ctxes, grad_req_types); + } else if (!group2ctx.empty()) { + LOG(WARNING) << "MXNET_SUBGRAPH_BACKEND does not currently support heterogeneous execution"; } exec->Init(symbol, default_ctx, group2ctx, in_arg_ctxes, arg_grad_ctxes, aux_state_ctxes, @@ -1686,8 +1689,12 @@ Executor *Executor::Bind(nnvm::Symbol symbol, auto exec = new exec::GraphExecutor(); std::vector<NDArray> tmp_in_args = in_args; if (!exec->subgraph_property().empty()) { - symbol = exec::PartitionGraph(symbol, exec->subgraph_property(), &tmp_in_args, aux_states, + if (group2ctx.empty()) { + symbol = exec::PartitionGraph(symbol, exec->subgraph_property(), &tmp_in_args, aux_states, default_ctx, group2ctx, grad_req_type); + } else { + LOG(WARNING) << "MXNET_SUBGRAPH_BACKEND does not currently support heterogeneous execution"; + } } exec->Init(symbol, default_ctx, group2ctx, tmp_in_args, arg_grad_store, grad_req_type, aux_states,
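
The two guards above make subgraph partitioning and `group2ctx`-based placement mutually exclusive. A minimal sketch of a bind that would take the warning path instead of being partitioned (the group name, shapes, and contexts are illustrative):

```python
import os
os.environ["MXNET_SUBGRAPH_BACKEND"] = "ngraph"

import mxnet as mx

data = mx.sym.Variable("data")
with mx.AttrScope(ctx_group="dev1"):  # assign this layer to context group "dev1"
    fc = mx.sym.FullyConnected(data, num_hidden=8)

# Passing group2ctx requests heterogeneous execution, so the partitioner
# is skipped and the warning above is logged instead.
exe = fc.simple_bind(ctx=mx.cpu(), group2ctx={"dev1": mx.cpu(0)}, data=(4, 16))
```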
diff --git a/src/operator/contrib/ngraph-inl.h b/src/operator/contrib/ngraph-inl.h new file mode 100644 index 000000000000..003ac7f05c04 --- /dev/null +++ b/src/operator/contrib/ngraph-inl.h @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2018 Intel Corporation + * \file ngraph-inl.h + * \brief ngraph subgraph property for mxnet +*/ + +#ifndef MXNET_OPERATOR_CONTRIB_NGRAPH_INL_H_ +#define MXNET_OPERATOR_CONTRIB_NGRAPH_INL_H_ + +#if MXNET_USE_NGRAPH +#include +#include +#include +#include + +#include + +#include "../subgraph/common.h" +#include "../subgraph/subgraph_property.h" + +namespace mxnet { +namespace op { + +class SgNgraphSelector : public SubgraphSelector { + public: + // Public methods to implement the subgraph selector API + explicit SgNgraphSelector(std::shared_ptr<ngraph_bridge::Compiler> compiler) + : compiler_(compiler), valid(compiler_->get_node_map().size() > 0) {} + + bool Select(const nnvm::Node &n) override { return is_node_selected(n); } + + bool SelectInput(const nnvm::Node &n, const nnvm::Node &new_node) override { + return is_node_selected(n, &new_node); + } + + bool SelectOutput(const nnvm::Node &n, const nnvm::Node &new_node) override { + return is_node_selected(n, &new_node); + } + std::vector<nnvm::Node *> Filter( + const std::vector<nnvm::Node *> &candidates) { + if (candidates.size() == 1 && candidates[0]->inputs.size() == 0) { + return std::vector<nnvm::Node *>(); + } else { + return candidates; + } + } + + private: + const std::shared_ptr<ngraph_bridge::Compiler> compiler_; + const bool valid; + // get_node is a utility function to translate NNVM Nodes to + // the IR nodes inside the ngraph_bridge::Compiler, this is + // primarily utilized to help determine nGraph support + ngraph_bridge::NodePtr get_node(const nnvm::Node *n) { + if (n) { + auto &entry_map = compiler_->get_ngraph().entry_map_; + ngraph_bridge::MapEntry tmp{compiler_->get_node_map().at(n).get(), 0}; + if (entry_map.count(tmp)) { + return entry_map[tmp]; + } + } + return nullptr; + } + // is_node_selected queries the ngraph_bridge::Compiler to determine if both + // the current and next NNVM nodes are supported by nGraph. + // This allows us to meld nGraph's analysis with PartitionGraph. + bool is_node_selected(const nnvm::Node &n, const nnvm::Node *next = nullptr) { + bool selected = false; + if (!valid) return selected; + + auto nn = get_node(&n); + auto nnext = get_node(next); + + selected = nn && nn->in_ngraph_; + if (next) { + selected = + selected && nnext->in_ngraph_ && nn->subgraph_ == nnext->subgraph_; + } + return selected; + } +}; + +class SgNgraphProperty : public SubgraphProperty { + public: + static SubgraphPropertyPtr Create() { + return std::make_shared<SgNgraphProperty>(); + } + + bool NeedGraphAttrs() const override { return true; } + // Create a subgraph node based on a symbol + nnvm::NodePtr CreateSubgraphNode( + const nnvm::Symbol &sym, const int subgraph_id = 0) const override { + nnvm::NodePtr n = nnvm::Node::Create(); + n->attrs.op = Op::Get("_ngraph_subgraph_op"); + n->attrs.name = "_ngraph_subgraph_op" + std::to_string(subgraph_id); + n->attrs.subgraphs.push_back(std::make_shared<nnvm::Symbol>(sym)); + return n; + } + // Create a subgraph node based on a graph with inferred shapes, types + // and storage types, then compile it with nGraph and store the + // ngraph_bridge::Compiler object in NNVM's node attributes for execution. + nnvm::NodePtr CreateSubgraphNode( + const nnvm::Graph &sg, const int subgraph_id = 0) const override { + nnvm::Symbol sym; + sym.outputs = sg.outputs; + auto n = CreateSubgraphNode(sym, subgraph_id); + auto grad_req_map = GetAttr<std::vector<OpReqType>>("grad_reqs"); + auto compiler = std::make_shared<ngraph_bridge::Compiler>(sg, grad_req_map); + compiler->GetNgraph(); + n->attrs.parsed = compiler; + return n; + } + // Create a Subgraph Selector with an embedded ngraph_bridge::Compiler for + // nGraph support analysis + SubgraphSelectorPtr CreateSubgraphSelector() const override { + if (!compiler_) { + auto &orig_graph = GetAttr<nnvm::Graph>("graph"); + auto grad_req_map = GetAttr<std::vector<OpReqType>>("grad_reqs"); + compiler_ = std::make_shared<ngraph_bridge::Compiler>(orig_graph, + grad_req_map, true); + } + return std::make_shared<SgNgraphSelector>(compiler_); + } + + private: + mutable std::shared_ptr<ngraph_bridge::Compiler> compiler_; +}; + +} // namespace op +} // namespace mxnet + +#endif // MXNET_USE_NGRAPH + +#endif // MXNET_OPERATOR_CONTRIB_NGRAPH_INL_H_
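
End to end, the selector and property above are exercised by comparing a partitioned executor against an unpartitioned one, in the same spirit as the Python tests added later in this patch. A sketch (the network is illustrative; assumes an nGraph-enabled build):

```python
import os
import numpy as np
import mxnet as mx

def forward_outputs(backend):
    # The backend name is read from the environment each time an executor is bound.
    if backend is None:
        os.environ.pop("MXNET_SUBGRAPH_BACKEND", None)
    else:
        os.environ["MXNET_SUBGRAPH_BACKEND"] = backend
    x = mx.sym.Variable("x")
    y = mx.sym.relu(mx.sym.FullyConnected(x, num_hidden=16))
    exe = y.simple_bind(ctx=mx.cpu(), x=(8, 32))
    for arr in exe.arg_arrays:
        arr[:] = 1.0  # identical deterministic parameters for both runs
    return exe.forward(is_train=False)[0].asnumpy()

assert np.allclose(forward_outputs(None), forward_outputs("ngraph"), atol=1e-5)
```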
diff --git a/src/operator/contrib/ngraph.cc b/src/operator/contrib/ngraph.cc new file mode 100644 index 000000000000..5281550bf5c1 --- /dev/null +++ b/src/operator/contrib/ngraph.cc @@ -0,0 +1,315 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*!
+ * Copyright (c) 2018 Intel Corporation + * \file ngraph.cc + * \brief ngraph subgraph property for mxnet +*/ + +#if MXNET_USE_NGRAPH +#include +#include +#include +#include +#include +#include + +#include "../subgraph/common.h" +#include "../subgraph/subgraph_property.h" +#include "./ngraph-inl.h" + +namespace mxnet { +namespace op { + +std::shared_ptr<ngraph_bridge::Graph> get_ngraph(const NodeAttrs &attrs) { + auto compiler = + nnvm::get<std::shared_ptr<ngraph_bridge::Compiler>>(attrs.parsed); + return compiler->GetNgraph(); +} + +class NgraphSubgraphOperator { + public: + explicit NgraphSubgraphOperator(std::shared_ptr<ngraph_bridge::Graph> ngraph) + : ngraph_(ngraph) {} + void Forward(const OpContext &ctx, const std::vector<NDArray> &inputs, + const std::vector<OpReqType> &req, + const std::vector<NDArray> &outputs); + void Backward(const OpContext &ctx, const std::vector<NDArray> &inputs, + const std::vector<OpReqType> &req, + const std::vector<NDArray> &outputs); + + private: + std::shared_ptr<ngraph_bridge::Graph> ngraph_; +}; + +void NgraphSubgraphOperator::Forward(const OpContext &ctx, + const std::vector<NDArray> &inputs, + const std::vector<OpReqType> &req, + const std::vector<NDArray> &outputs) { + compute_forward(ctx, ngraph_, inputs, req, outputs); +} + +void NgraphSubgraphOperator::Backward(const OpContext &ctx, + const std::vector<NDArray> &inputs, + const std::vector<OpReqType> &req, + const std::vector<NDArray> &outputs) { + compute_backward(ctx, ngraph_, inputs, req, outputs); +} + +OpStatePtr CreateNgraphSubgraphOpState(const NodeAttrs &attrs, Context ctx, + const std::vector<TShape> &in_shapes, + const std::vector<int> &in_types) { + return OpStatePtr::Create<NgraphSubgraphOperator>(get_ngraph(attrs)); +} + +void NgraphSubgraphOpForward(const OpStatePtr &state_ptr, const OpContext &ctx, + const std::vector<NDArray> &inputs, + const std::vector<OpReqType> &req, + const std::vector<NDArray> &outputs) { + NgraphSubgraphOperator &op = state_ptr.get_state<NgraphSubgraphOperator>(); + op.Forward(ctx, inputs, req, outputs); +} + +void NgraphSubgraphOpBackward(const OpStatePtr &state_ptr, const OpContext &ctx, + const std::vector<NDArray> &inputs, + const std::vector<OpReqType> &req, + const std::vector<NDArray> &outputs) { + NgraphSubgraphOperator &op = state_ptr.get_state<NgraphSubgraphOperator>(); + op.Backward(ctx, inputs, req, outputs); +} + +std::vector<nnvm::NodeEntry> NgraphSubgraphGradient( + const nnvm::NodePtr &n, const std::vector<nnvm::NodeEntry> &ograds) { + auto graph = get_ngraph(n->attrs); + const bool zero_grad = check_zero_grad(graph); + graph->zero_grad = zero_grad; + auto is_loss = graph->is_loss; + auto p = nnvm::Node::Create(); + p->attrs.op = nnvm::Op::Get("_backward_ngraph_subgraph_op"); + p->attrs.parsed = n->attrs.parsed; + if (std::find(begin(is_loss), end(is_loss), true) == end(is_loss) && + zero_grad && graph->num_outputs_ == 1) { + return mxnet::op::MakeZeroGradNodes(n, ograds); + } + if (!graph->need_grad) { + LOG(FATAL) + << "NGRAPH_BRIDGE: This graph was compiled as inference but " + << "is called in training"; + } + p->attrs.name = n->attrs.name + "_backward"; + p->attrs.dict = n->attrs.dict; + p->control_deps.emplace_back(n); + if (p->op()->attr_parser != nullptr) { + p->op()->attr_parser(&(p->attrs)); + } + if (!zero_grad) { + for (size_t i = 0; i < graph->num_adjoints_; ++i) { + if (!is_loss[i]) { + p->inputs.push_back(ograds[i]); + } + } + } + p->inputs.insert(p->inputs.end(), n->inputs.begin(), n->inputs.end()); + for (unsigned i = graph->outputs_.size(); + i < graph->fprop_cache->fprop->get_results().size(); ++i) { + p->inputs.emplace_back(nnvm::NodeEntry{n, i, 0}); + } + std::vector<nnvm::NodeEntry> ret; + for (unsigned i = 0; i < p->num_outputs(); ++i) { + ret.emplace_back(nnvm::NodeEntry{p, i, 0}); + } + return ret; +} + +std::vector<std::string> NgraphSubgraphListNodeNames( + const std::vector<ngraph_bridge::NodePtr> &nodes) { + std::vector<std::string> names; + for (const auto &n : nodes) { + names.emplace_back(n->name_); + } + return names; +}
+std::vector<std::string> NgraphSubgraphListInputNames( + const nnvm::NodeAttrs &attrs) { + auto graph = get_ngraph(attrs); + return NgraphSubgraphListNodeNames(graph->inputs_); +} +std::vector<std::string> NgraphSubgraphListOutputNames( + const nnvm::NodeAttrs &attrs) { + auto graph = get_ngraph(attrs); + auto names = NgraphSubgraphListNodeNames(graph->outputs_); + for (size_t i = names.size(); i < graph->get_results().size(); ++i) { + names.push_back(graph->name_ + "_output_" + std::to_string(i)); + } + return names; +} +bool NgraphSubgraphInferShape(const nnvm::NodeAttrs &attrs, + std::vector<TShape> *in_attrs, + std::vector<TShape> *out_attrs) { + auto compiler = + nnvm::get<std::shared_ptr<ngraph_bridge::Compiler>>(attrs.parsed); + auto graph = get_ngraph(attrs); + + ngraph_check(in_attrs != nullptr); + ngraph_check(out_attrs != nullptr); + ngraph_check(in_attrs->size() == graph->inputs_.size()); + ngraph_check(out_attrs->size() == graph->get_results().size()); + + if ((graph->inputs_.size() > 0) && + (*in_attrs)[0] != graph->inputs_[0]->shape_) { + compiler->ReshapeGraph(*in_attrs); + graph = compiler->GetNgraph(); + } + for (size_t i = 0; i < graph->inputs_.size(); ++i) { + (*in_attrs)[i] = graph->inputs_[i]->shape_; + } + size_t i = 0; + for (const auto& output : graph->get_results()) { + auto tmp_shape = ngraph_bridge::NShape_to_TShape(output->get_shape()); + (*out_attrs)[i] = tmp_shape; + i += 1; + } + return true; +} +bool NgraphSubgraphInferType(const nnvm::NodeAttrs &attrs, + std::vector<int> *iattr, std::vector<int> *oattr) { + auto graph = get_ngraph(attrs); + + ngraph_check(iattr != nullptr); + ngraph_check(oattr != nullptr); + ngraph_check(iattr->size() == graph->inputs_.size()); + ngraph_check(oattr->size() == graph->get_results().size()); + + for (size_t i = 0; i < graph->inputs_.size(); ++i) { + (*iattr)[i] = graph->inputs_[i]->dtype_; + } + std::vector<int> dtypes; + for (const auto& output : graph->get_results()) { + dtypes.push_back(ngraph_bridge::getType(output->get_element_type())); + } + for (size_t i = 0; i < dtypes.size(); ++i) { + mxnet::op::type_assign(&((*oattr)[i]), dtypes[i]); + } + return true; +} + +bool NgraphSubgraphInferStorageType(const nnvm::NodeAttrs &attrs, + const int dev_mask, + mxnet::DispatchMode *dispatch_mode, + std::vector<int> *in_attrs, + std::vector<int> *out_attrs) { + ngraph_check(dispatch_mode != nullptr); + ngraph_check(in_attrs != nullptr); + ngraph_check(out_attrs != nullptr); + DISPATCH_MODE_ASSIGN_CHECK(dispatch_mode, 0, DispatchMode::kFComputeEx); + if (in_attrs->size() > 0) { + mxnet::op::storage_type_assign(in_attrs, mxnet::kDefaultStorage, + dispatch_mode, + mxnet::DispatchMode::kFComputeEx); + } + return mxnet::op::storage_type_assign(out_attrs, mxnet::kDefaultStorage, + dispatch_mode, + mxnet::DispatchMode::kFComputeEx); +} +bool NgraphSubgraphBackwardInferStorageType(const nnvm::NodeAttrs &attrs, + const int dev_mask, + mxnet::DispatchMode *dispatch_mode, + std::vector<int> *in_attrs, + std::vector<int> *out_attrs) { + ngraph_check(dispatch_mode != nullptr); + ngraph_check(in_attrs != nullptr); + ngraph_check(out_attrs != nullptr); + DISPATCH_MODE_ASSIGN_CHECK(dispatch_mode, 0, DispatchMode::kFComputeEx); + mxnet::op::storage_type_assign(in_attrs, mxnet::kDefaultStorage, + dispatch_mode, + mxnet::DispatchMode::kFComputeEx); + return mxnet::op::storage_type_assign(out_attrs, mxnet::kDefaultStorage, + dispatch_mode, + mxnet::DispatchMode::kFComputeEx); +} +std::vector<uint32_t> NGraphSubgraphMutateInputs(const nnvm::NodeAttrs &attrs) { + auto graph = get_ngraph(attrs); + std::vector<uint32_t> mutate_vars;
+ for (size_t i = 0; i < graph->inputs_.size(); ++i) { + if (graph->inputs_[i]->type_ == ngraph_bridge::NodeType::kAux) { + mutate_vars.emplace_back(i); + } + } + return mutate_vars; +} + +NNVM_REGISTER_OP(_ngraph_subgraph_op) + .describe(R"code(_ngraph_subgraph_op)code" ADD_FILELINE) + .set_num_inputs([](const NodeAttrs &attrs) { + auto graph = get_ngraph(attrs); + return graph->inputs_.size(); + }) + .set_num_outputs([](const NodeAttrs &attrs) { + auto graph = get_ngraph(attrs); + return graph->get_results().size(); + }) + .set_attr<nnvm::FNumVisibleOutputs>("FNumVisibleOutputs", + [](const NodeAttrs& attrs) { + auto graph = get_ngraph(attrs); + return graph->outputs_.size(); + }) + .set_attr<nnvm::FListInputNames>("FListInputNames", + NgraphSubgraphListInputNames) + .set_attr<nnvm::FListOutputNames>("FListOutputNames", + NgraphSubgraphListOutputNames) + .set_attr<FCreateOpState>("FCreateOpState", CreateNgraphSubgraphOpState) + .set_attr<nnvm::FInferShape>("FInferShape", NgraphSubgraphInferShape) + .set_attr<nnvm::FInferType>("FInferType", NgraphSubgraphInferType) + .set_attr<FInferStorageType>("FInferStorageType", + NgraphSubgraphInferStorageType) + .set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", + NgraphSubgraphOpForward) + .set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", + NgraphSubgraphOpForward) + .set_attr<nnvm::FGradient>("FGradient", NgraphSubgraphGradient) + .set_attr<nnvm::FMutateInputs>("FMutateInputs", NGraphSubgraphMutateInputs) + .set_attr<std::string>("key_var_num_args", "num_args") + .add_argument("data", "NDArray-or-Symbol[]", "input data list"); + +NNVM_REGISTER_OP(_backward_ngraph_subgraph_op) + .set_num_inputs([](const NodeAttrs &attrs) { + auto graph = get_ngraph(attrs); + int mode = static_cast<int>(ngraph_bridge::GraphExeMode::kTrain); + return graph->fprop_cache->bprop->get_parameters().size() + + graph->cached_aux_positions[mode].size(); + }) + .set_num_outputs([](const NodeAttrs &attrs) { + auto graph = get_ngraph(attrs); + return graph->fprop_cache->bprop->get_results().size(); + }) + .set_attr<bool>("TIsBackward", true) + .set_attr<bool>("TIsLayerOpBackward", true) + .set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", + NgraphSubgraphOpBackward) + .set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", + NgraphSubgraphOpBackward) + .set_attr<FInferStorageType>("FInferStorageType", + NgraphSubgraphBackwardInferStorageType); +MXNET_REGISTER_SUBGRAPH_PROPERTY(ngraph, SgNgraphProperty); + +} // namespace op +} // namespace mxnet + +#endif // MXNET_USE_NGRAPH diff --git a/src/operator/subgraph/partition_graph.cc b/src/operator/subgraph/partition_graph.cc index 90a14caa510b..c0c51ebd3c93 100644 --- a/src/operator/subgraph/partition_graph.cc +++ b/src/operator/subgraph/partition_graph.cc @@ -30,6 +30,7 @@ #include #include "./subgraph_property.h" +#include "../../executor/exec_pass.h" namespace nnvm { NodePtr CreateVariableNode(const std::string& name); @@ -238,21 +239,23 @@ bool LabelSubgraph(const Graph& g, const std::vector<const nnvm::Node*>& snodes) { if (ancestor == descendant) return true; std::stack<const nnvm::Node*> s; + std::unordered_set<const nnvm::Node*> visited(snodes.begin(), snodes.end()); s.push(descendant); size_t count = 0; while (!s.empty()) { - CHECK_LT(count, indexed_graph.num_nodes()) << "Finding ancestor failed. There is probably" " a loop in the graph"; + CHECK_LT(count, indexed_graph.num_nodes()) + << "Finding ancestor failed. There is probably a loop in the graph";
++count; const nnvm::Node* top = s.top(); s.pop(); + visited.insert(top); if (top == ancestor) { return true; } for (const auto& entry : top->inputs) { // when searching for the ancestor, the path cannot cross any subgraph node - auto it = std::find(snodes.begin(), snodes.end(), entry.node.get()); - if (it == snodes.end()) { + // or a previously visited node + if (visited.count(entry.node.get()) == 0) { s.push(entry.node.get()); } } @@ -616,6 +619,87 @@ void CutGraphInputs(const std::vector<nnvm::NodeEntry*> &input_entries, } } +/*! + * \brief Infer attrs for subgraph, given input nodes of subgraph from original graph + */ +nnvm::Graph InferSubgraphAttrs( + Graph* g, const std::vector<nnvm::NodeEntry>& orig_input_entries, + const std::unordered_map<const nnvm::Node*, nnvm::Symbol>& subgraphs, + nnvm::Graph&& sg) { + // return if partition without attrs + if (!g->HasAttr("context")) return std::move(sg); + const auto &idx_og = g->indexed_graph(); + const auto &idx_g = sg.indexed_graph(); + CHECK_EQ(idx_g.input_nodes().size(), orig_input_entries.size()); + + auto num_nodes = idx_g.num_node_entries(); + + auto orig_ctx = g->GetAttr<exec::ContextVector>("context"); + auto orig_dev_masks = g->GetAttr<exec::DevMaskVector>("dev_mask"); + + auto oshapes = g->GetAttr<mxnet::ShapeVector>("shape"); + auto odtypes = g->GetAttr<nnvm::DTypeVector>("dtype"); + auto ostypes = g->GetAttr<StorageTypeVector>("storage_type"); + + exec::ContextVector contexts(idx_g.num_nodes(), orig_ctx[0]); + mxnet::ShapeVector shapes(num_nodes); + nnvm::DTypeVector types(num_nodes, -1); + StorageTypeVector stypes(num_nodes, kUndefinedStorage); + exec::DevMaskVector dev_masks(idx_g.num_nodes(), orig_ctx[0].dev_mask()); + + nnvm::DFSVisit(sg.outputs, [&](const nnvm::NodePtr node) { + if (idx_og.exist(node.get())) { + auto nid = idx_g.node_id(node.get()); + auto onid = idx_og.node_id(node.get()); + contexts[nid] = orig_ctx[onid]; + dev_masks[nid] = orig_dev_masks[onid]; + } + }); + + // copy shapes/types from original graph if available + const auto &input_nids = idx_g.input_nodes(); + for (size_t i = 0; i < input_nids.size(); i++) { + auto nid = input_nids[i]; + auto eid = idx_g.entry_id(input_nids[i], 0); + uint32_t onid = 0; + uint32_t oeid = 0; + // get nodes ids from original graph, or previous subgraphs. + if (idx_og.exist(orig_input_entries[i].node.get())) { + onid = idx_og.node_id(orig_input_entries[i].node.get()); + oeid = idx_og.entry_id(orig_input_entries[i]); + } else { + auto previous = subgraphs.at(orig_input_entries[i].node.get()); + onid = idx_og.node_id(previous.outputs[orig_input_entries[i].index].node.get()); + oeid = idx_og.entry_id(previous.outputs[orig_input_entries[i].index]); + } + + // copy ctx/mask + contexts[nid] = orig_ctx[onid]; + dev_masks[nid] = orig_dev_masks[onid]; + + // copy shapes/types + shapes[eid] = oshapes[oeid]; + types[eid] = odtypes[oeid]; + stypes[eid] = ostypes[oeid]; + } + + sg.attrs["context"] = std::make_shared<nnvm::any>(std::move(contexts)); + + sg.attrs["shape"] = std::make_shared<nnvm::any>(std::move(shapes)); + sg = exec::InferShape(std::move(sg)); + CHECK_EQ(sg.GetAttr<size_t>("shape_num_unknown_nodes"), 0U); + + sg.attrs["dtype"] = std::make_shared<nnvm::any>(std::move(types)); + sg = exec::InferType(std::move(sg)); + CHECK_EQ(sg.GetAttr<size_t>("dtype_num_unknown_nodes"), 0U); + + sg.attrs["dev_mask"] = std::make_shared<nnvm::any>(std::move(dev_masks)); + sg.attrs["storage_type"] = std::make_shared<nnvm::any>(std::move(stypes)); + sg = exec::InferStorageType(std::move(sg)); + CHECK_EQ(sg.GetAttr<size_t>("storage_type_num_unknown_nodes"), 0U); + return std::move(sg); +} + /*!
* \brief Replace a set of nodes belonging to the same subgraph with a subgrpah node * and keep the subgraph in the subgraph node. The input entries and output entries @@ -625,6 +709,7 @@ void CreateSubgraphNode(Graph* g, const std::vector<SimpleNodePtr>& simple_nodes, const std::vector<SimpleNode*>& subgraph_nodes, const size_t subgraph_id, + std::unordered_map<const nnvm::Node*, nnvm::Symbol>* subgraphs, std::unordered_map<const nnvm::NodeEntry*, size_t>* entry_top_order_map) { #if DEBUG_SUBGRAPH LOG(INFO) << "Searching for input entries..."; @@ -646,9 +731,19 @@ for (size_t i = 0; i < output_entries.size(); ++i) { sym.outputs[i] = *output_entries[i]; } - const SubgraphPropertyPtr& subg_prop = g->GetAttr<SubgraphPropertyPtr>("subgraph_property"); - nnvm::NodePtr n = subg_prop->CreateSubgraphNode(sym, subgraph_id); + const SubgraphPropertyPtr& subg_prop = g->GetAttr<SubgraphPropertyPtr>("subgraph_property"); + nnvm::NodePtr n; + if (!subg_prop->NeedGraphAttrs()) { + n = subg_prop->CreateSubgraphNode(sym, subgraph_id); + } else { + nnvm::Graph subgraph; + subgraph.outputs = sym.outputs; + // update subgraph attrs + subgraph = InferSubgraphAttrs(g, orig_input_entries, *subgraphs, std::move(subgraph)); + n = subg_prop->CreateSubgraphNode(subgraph, subgraph_id); + } + subgraphs->insert({n.get(), sym}); // Connect the external nodes to the subgraph node. subg_prop->ConnectSubgraphOutputs(n, &output_entries); subg_prop->ConnectSubgraphInputs(n, &input_entries, &orig_input_entries); @@ -749,13 +844,14 @@ Graph PartitionGraph(Graph&& g) { CreateSimpleGraph(g, &simple_nodes); std::vector<std::vector<SimpleNode*>> subgraph_nodes; FindSubgraphs(&g, *subg_prop, simple_nodes, &subgraph_nodes); + std::unordered_map<const nnvm::Node*, nnvm::Symbol> subgraphs; for (size_t i = 0; i < subgraph_nodes.size(); ++i) { #if DEBUG_SUBGRAPH std::set<SimpleNode*> simple_node_set(subgraph_nodes[i].begin(), subgraph_nodes[i].end()); CHECK_EQ(simple_node_set.size(), subgraph_nodes[i].size()); PrintSubgraph(subgraph_nodes[i]); #endif - CreateSubgraphNode(&g, simple_nodes, subgraph_nodes[i], i, &entry_top_order_map); + CreateSubgraphNode(&g, simple_nodes, subgraph_nodes[i], i, &subgraphs, &entry_top_order_map); } return g; } diff --git a/src/operator/subgraph/subgraph_property.h b/src/operator/subgraph/subgraph_property.h index d115d3498e86..cbba73b7282c 100644 --- a/src/operator/subgraph/subgraph_property.h +++ b/src/operator/subgraph/subgraph_property.h @@ -104,6 +104,18 @@ class SubgraphProperty { */ virtual nnvm::NodePtr CreateSubgraphNode(const nnvm::Symbol &sym, const int subgraph_id = 0) const = 0; + /*! + * \brief Create an nnvm node for a given subgraph using a graph with attrs. Here users + * can customize how to execute the operators in the subgraph. + * \param g the graph with attrs to create subgraph node + * \param subgraph_id subgraph id + */ + virtual nnvm::NodePtr CreateSubgraphNode(const nnvm::Graph &g, + const int subgraph_id = 0) const { + nnvm::Symbol sym; + sym.outputs = g.outputs; + return CreateSubgraphNode(sym, subgraph_id); + } /*! * \brief Connect subgraph internal output with external output entries. * By default, each output entry will connect to an unique internal output. @@ -128,6 +140,12 @@ class SubgraphProperty { std::vector<nnvm::NodeEntry>* orig_input_entries) const { subgraph_node->inputs = *orig_input_entries; } + /*! + * \brief Infer subgraph attrs before creating subgraph node, if needed. + */ + virtual bool NeedGraphAttrs() const { + return false; + } /*! * \brief Set an attr with name in the attr map. */
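
The `grad_reqs` attribute stored by `PartitionGraph` (see the `graph_executor.cc` change above) is what lets a property that returns `true` from `NeedGraphAttrs()` compile a training-ready subgraph. A sketch of a bind that exercises the backward operator registered earlier (illustrative network; assumes an nGraph-enabled build):

```python
import os
os.environ["MXNET_SUBGRAPH_BACKEND"] = "ngraph"

import mxnet as mx

x = mx.sym.Variable("x")
net = mx.sym.SoftmaxOutput(mx.sym.FullyConnected(x, num_hidden=4), name="softmax")

# grad_req="write" is recorded per input and threaded into the subgraph
# property through the "grad_reqs" graph attribute.
exe = net.simple_bind(ctx=mx.cpu(), x=(2, 8), grad_req="write")
exe.forward(is_train=True, x=mx.nd.ones((2, 8)))
exe.backward()  # dispatches to _backward_ngraph_subgraph_op for compiled subgraphs
print(exe.grad_arrays[0].shape)  # gradient w.r.t. x: (2, 8)
```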
diff --git a/tests/cpp/unittest.mk b/tests/cpp/unittest.mk index 746ee2f096f1..29750ddf75ec 100644 --- a/tests/cpp/unittest.mk +++ b/tests/cpp/unittest.mk @@ -62,7 +62,8 @@ build/tests/cpp/engine/%.o : tests/cpp/engine/%.cc | mkldnn $(CXX) -c -std=c++11 $(TEST_CFLAGS) -I$(GTEST_INC) -o build/tests/cpp/engine/$*.o $(filter %.cc %.a, $^) $(TEST): $(TEST_OBJ) lib/libmxnet.so gtest.a - $(CXX) -std=c++11 $(TEST_CFLAGS) -I$(GTEST_INC) -o $@ $^ $(TEST_LDFLAGS) + $(CXX) -std=c++11 $(TEST_CFLAGS) -I$(GTEST_INC) -o $@ $^ $(TEST_LDFLAGS) \ + $(NGRAPH_LDFLAGS_FOR_CPP_UNIT_TESTS_PROG) runtest: $(TEST) LD_LIBRARY_PATH=$(shell pwd)/lib:$(LD_LIBRARY_PATH) $(TEST) diff --git a/tests/python/ngraph/test_ngraph.py b/tests/python/ngraph/test_ngraph.py new file mode 100644 index 000000000000..8115a2f68c97 --- /dev/null +++ b/tests/python/ngraph/test_ngraph.py @@ -0,0 +1,154 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import print_function +import numpy as np +import mxnet as mx +import os +import unittest + + +def binary_op_ex(sym, x_shape, y_shape): + np.random.seed(0) + x_npy = np.random.randint(0, 10, size=x_shape).astype(np.float32) + y_npy = np.random.randint(0, 10, size=y_shape).astype(np.float32) + exe = sym.simple_bind(ctx=mx.cpu(), x=x_shape, y=y_shape) + mx_out = exe.forward(is_train=True, x=x_npy, y=y_npy)[0].asnumpy() + exe.backward() + return mx_out + + +def test_broadcast_op_no_head_grad(): + x = mx.symbol.Variable("x") + y = mx.symbol.Variable("y") + z = mx.sym.broadcast_not_equal(x, y) + binary_op_ex(z, (1, 10), (10, 1)) + + +def test_broadcast_mix_logic_op(): + x_shape = (1, 10) + y_shape = (10, 1) + x = mx.symbol.Variable("x") + y = mx.symbol.Variable("y") + z1 = mx.sym.broadcast_mul(x, y) + z2 = mx.sym.broadcast_not_equal(z1, y) + z3 = mx.sym.broadcast_mul(z1, z2) + z4 = mx.sym.broadcast_equal(z1, z3) + z5 = mx.sym.broadcast_not_equal(z3, z4) + z6 = mx.sym.broadcast_mul(z5, z4) + z = mx.sym.broadcast_equal(z6, x) + + binary_op_ex(z, (1, 10), (10, 1)) + +def test_batch_normalized_softmax_grad(): + xpu = mx.cpu() + x = mx.sym.Variable('x') + label = mx.sym.Variable('label') + x_nd = mx.nd.array([[1, 6, 4, 2],[1, 6, 4, 2]], ctx=xpu) + grad_x = mx.nd.zeros((2,4), ctx=xpu) + label_nd = mx.nd.array([1,1], ctx=xpu) + + sym = mx.sym.SoftmaxOutput(data=x, label=label, ignore_label=0, + use_ignore=False, normalization="batch") + ex = sym.bind(ctx=xpu, args={'x': x_nd, 'label': label_nd}, + args_grad={'x': grad_x}) + + ex.forward(is_train=True) + softmax_out = ex.outputs[0].asnumpy() + expected_softmax_out = [[0.005806628, 0.861780069, 0.116629249, 0.015784052], + [0.005806628, 0.861780069, 0.116629249, 0.015784052]] + assert np.isclose(softmax_out, expected_softmax_out).all() + + ex.backward(is_train=True) + grad_out =
ex.grad_arrays[0].asnumpy() + k = int(label_nd[0].asscalar()) + expected_grad_out = np.zeros((2,4)) + expected_grad_out[:, k] = - 1 + assert np.isclose(grad_out , (expected_softmax_out + expected_grad_out) / 2).all() + +@unittest.skip("Flaky test https://github.com/apache/incubator-mxnet/issues/14301") +def test_valid_normalized_softmax_grad(): + xpu = mx.cpu() + x = mx.sym.Variable('x') + label = mx.sym.Variable('label') + x_nd = mx.nd.array([[1, 6, 4, 2],[1, 6, 4, 2]], ctx=xpu) + grad_x = mx.nd.zeros((2,4), ctx=xpu) + label_nd = mx.nd.array([1,1], ctx=xpu) + + sym = mx.sym.SoftmaxOutput(data=x, label=label, ignore_label=0, + use_ignore=True, normalization="valid") + ex = sym.bind(ctx=xpu, args={'x': x_nd, 'label': label_nd}, + args_grad={'x': grad_x}) + + ex.forward(is_train=True) + softmax_out = ex.outputs[0].asnumpy() + expected_softmax_out = [[0.005806628, 0.861780069, 0.116629249, 0.015784052], + [0.005806628, 0.861780069, 0.116629249, 0.015784052]] + assert np.isclose(softmax_out, expected_softmax_out).all() + + ex.backward(is_train=True) + grad_out = ex.grad_arrays[0].asnumpy() + k = int(label_nd[0].asscalar()) + expected_grad_out = np.zeros((2,4)) + expected_grad_out[:, k] = - 1 + + assert np.isclose(grad_out, (expected_softmax_out + expected_grad_out) + / sum(label_nd.asnumpy() != 0)).all() + +def test_valid_make_loss(): + xpu = mx.cpu() + x = mx.sym.Variable('x') + label = mx.sym.Variable('label') + x_nd = mx.nd.array([[0, 1, 1, 0], + [1, 1, 1, .1]], ctx=xpu) + grad_x = mx.nd.zeros((2,4), ctx=xpu) + label_nd = mx.nd.array([1,1], ctx=xpu) + + sym = mx.sym.MakeLoss(x, normalization="valid", valid_thresh=0.2) + ex = sym.bind(ctx=xpu, args={'x': x_nd, 'label': label_nd}, + args_grad={'x': grad_x}) + + ex.forward(is_train=True) + out = ex.outputs[0].asnumpy() + expected_out = [[0, 1, 1, 0], + [1, 1, 1, .1]] + assert np.isclose(out, expected_out).all() + + ex.backward(is_train=True) + grad_out = ex.grad_arrays[0].asnumpy() + expected_grad_out = np.ones((2,4))/5 + + assert np.isclose(grad_out, expected_grad_out).all() + +def test_stop_gradient(): + v1 = mx.nd.array([[1, 2]]) + v2 = mx.nd.array([[0, 1]]) + a = mx.sym.Variable('a') + b = mx.sym.Variable('b') + b_stop_grad = mx.sym.stop_gradient(3 * b) + loss = mx.sym.MakeLoss(b_stop_grad + a) + + executor = loss.simple_bind(ctx=mx.cpu(), a=(1,2), b=(1,2)) + executor.forward(is_train=True, a=v1, b=v2) + assert np.isclose(executor.outputs[0].asnumpy(), [1,5]).all() + executor.backward() + assert np.isclose(executor.grad_arrays[0].asnumpy(), [0,0]).all() + assert np.isclose(executor.grad_arrays[1].asnumpy(), [1,1]).all() + +if __name__ == '__main__': + import nose + nose.runmodule() diff --git a/tests/python/unittest/test_executor.py b/tests/python/unittest/test_executor.py index 2bc696fd4e43..bb75a00a77ed 100644 --- a/tests/python/unittest/test_executor.py +++ b/tests/python/unittest/test_executor.py @@ -51,7 +51,7 @@ def check_bind_with_uniform(uf, gf, dim, sf=None, lshape=None, rshape=None): args={'rhs': rhs_arr, 'lhs': lhs_arr}, args_grad={'lhs': lhs_grad, 'rhs': rhs_grad}) - executor.forward() + executor.forward(is_train = True) exec3.forward() exec4.forward() out2 = executor.outputs[0].asnumpy() diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py index 34380dc00314..222df05e7fca 100644 --- a/tests/python/unittest/test_gluon.py +++ b/tests/python/unittest/test_gluon.py @@ -860,13 +860,13 @@ def test_export(): module.forward(mx.io.DataBatch([data], None), is_train=False) mod_out, = module.get_outputs() 
- assert_almost_equal(out.asnumpy(), mod_out.asnumpy()) + assert_almost_equal(out.asnumpy(), mod_out.asnumpy(), atol=1e-5) model2 = gluon.model_zoo.vision.resnet18_v1(prefix='resnet', ctx=ctx) model2.collect_params().load('gluon-0000.params', ctx) out2 = model2(data) - assert_almost_equal(out.asnumpy(), out2.asnumpy()) + assert_almost_equal(out.asnumpy(), out2.asnumpy(), atol=1e-5) @with_seed() def test_import(): diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 6bb815066c80..6013aa709b05 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -1170,8 +1170,8 @@ def test_rsqrt_cos_sin(): @with_seed() def test_maximum_minimum(): - data1 = mx.symbol.Variable('data') - data2 = mx.symbol.Variable('data') + data1 = mx.symbol.Variable('data1') + data2 = mx.symbol.Variable('data2') shape = (3, 4) data_tmp1 = np.random.rand(3,4) data_tmp2 = np.random.rand(3,4) diff --git a/tests/python/unittest/test_subgraph_op.py b/tests/python/unittest/test_subgraph_op.py index 40d609ad3541..e39ef0ba4bb4 100644 --- a/tests/python/unittest/test_subgraph_op.py +++ b/tests/python/unittest/test_subgraph_op.py @@ -51,8 +51,7 @@ def _check_subgraph_exe1(sym, op_names): partitioned_exe.forward() assert len(exe.outputs) == len(partitioned_exe.outputs) for i in range(len(exe.outputs)): - assert_almost_equal((exe.outputs[i] - partitioned_exe.outputs[i]).abs().sum().asnumpy(), - np.zeros(shape=(1,))) + assert_almost_equal(exe.outputs[i].asnumpy(), partitioned_exe.outputs[i].asnumpy()) def _check_subgraph_exe2(sym, op_names): """Use env var MXNET_SUBGRAPH_BACKEND=default to trigger graph partitioning in simple_bind @@ -84,7 +83,7 @@ def get_executor(sym, subgraph_backend=None, op_names=None, original_exec=None): outputs2 = partitioned_exec.outputs assert len(outputs1) == len(outputs2) for i in range(len(outputs1)): - assert_almost_equal((outputs1[i] - outputs2[i]).abs().sum().asnumpy(), np.zeros(shape=(1,))) + assert_almost_equal(outputs1[i].asnumpy(), outputs2[i].asnumpy()) def _check_subgraph_exe3(sym, op_names): """Use the partitioned sym to bind an executor and compare the outputs