From 10d21503996245f6d99a90d6f6472ddb27bdc3f8 Mon Sep 17 00:00:00 2001 From: reminisce Date: Thu, 4 Apr 2019 14:28:43 -0700 Subject: [PATCH 01/18] Remove numpy namespaces for operator registration --- python/mxnet/__init__.py | 3 +- python/mxnet/base.py | 62 ++++-- python/mxnet/ndarray/__init__.py | 2 +- python/mxnet/ndarray/numpy.py | 18 -- python/mxnet/numpy/__init__.py | 66 ------- python/mxnet/symbol/__init__.py | 2 +- python/mxnet/symbol/numpy.py | 18 -- python/mxnet/symbol/symbol.py | 3 +- src/operator/numpy/np_broadcast_reduce_op.h | 186 ------------------ .../numpy/np_broadcast_reduce_op_value.cc | 61 ------ .../numpy/np_broadcast_reduce_op_value.cu | 36 ---- src/operator/tensor/matrix_op-inl.h | 5 + tests/python/unittest/test_operator.py | 45 ++++- 13 files changed, 96 insertions(+), 411 deletions(-) delete mode 100644 python/mxnet/ndarray/numpy.py delete mode 100644 python/mxnet/numpy/__init__.py delete mode 100644 python/mxnet/symbol/numpy.py delete mode 100644 src/operator/numpy/np_broadcast_reduce_op.h delete mode 100644 src/operator/numpy/np_broadcast_reduce_op_value.cc delete mode 100644 src/operator/numpy/np_broadcast_reduce_op_value.cu diff --git a/python/mxnet/__init__.py b/python/mxnet/__init__.py index 8db83a286157..5f4f9b393e41 100644 --- a/python/mxnet/__init__.py +++ b/python/mxnet/__init__.py @@ -23,9 +23,8 @@ from .context import Context, current_context, cpu, gpu, cpu_pinned from . import engine -from .base import MXNetError +from .base import MXNetError, is_np_comp, set_np_comp, enable_np_comp, disable_np_comp from . import base -from . import numpy from . import contrib from . import ndarray from . import ndarray as nd diff --git a/python/mxnet/base.py b/python/mxnet/base.py index fe1dd00f9454..719b3cb58a6a 100644 --- a/python/mxnet/base.py +++ b/python/mxnet/base.py @@ -561,7 +561,7 @@ def _as_list(obj): return [obj] -_OP_NAME_PREFIX_LIST = ['_contrib_', '_linalg_', '_sparse_', '_image_', '_random_', '_numpy_'] +_OP_NAME_PREFIX_LIST = ['_contrib_', '_linalg_', '_sparse_', '_image_', '_random_'] def _get_op_name_prefix(op_name): @@ -607,13 +607,6 @@ def _init_op_module(root_namespace, module_name, make_op_func): # use mx.nd.contrib or mx.sym.contrib from now on contrib_module_name_old = "%s.contrib.%s" % (root_namespace, module_name) contrib_module_old = sys.modules[contrib_module_name_old] - # special handling of registering numpy ops - if module_name == 'ndarray': - numpy_module_name = "%s.numpy" % root_namespace - numpy_module = sys.modules[numpy_module_name] - else: - numpy_module_name = None - numpy_module = None submodule_dict = {} for op_name_prefix in _OP_NAME_PREFIX_LIST: submodule_dict[op_name_prefix] =\ @@ -652,16 +645,6 @@ def _init_op_module(root_namespace, module_name, make_op_func): function.__module__ = contrib_module_name_old setattr(contrib_module_old, function.__name__, function) contrib_module_old.__all__.append(function.__name__) - elif op_name_prefix == '_numpy_' and numpy_module_name is not None: - # only register numpy ops under mxnet.numpy in imperative mode - hdl = OpHandle() - check_call(_LIB.NNGetOpHandle(c_str(name), ctypes.byref(hdl))) - # TODO(reminisce): Didn't consider third level module here, e.g. mxnet.numpy.random. 
- func_name = name[len(op_name_prefix):] - function = make_op_func(hdl, name, func_name) - function.__module__ = numpy_module_name - setattr(numpy_module, function.__name__, function) - numpy_module.__all__.append(function.__name__) def _generate_op_module_signature(root_namespace, module_name, op_code_gen_func): @@ -751,3 +734,46 @@ def write_all_str(module_file, module_all_list): ctypes.pythonapi.PyCapsule_New.restype = ctypes.py_object ctypes.pythonapi.PyCapsule_GetPointer.restype = ctypes.c_void_p + + +def set_np_comp(is_np_comp): + prev = ctypes.c_int() + check_call(_LIB.MXSetIsNumpyCompatible(ctypes.c_int(is_np_comp), ctypes.byref(prev))) + return bool(prev.value) + + +def is_np_comp(): + curr = ctypes.c_bool() + check_call(_LIB.MXIsNumpyCompatible(ctypes.byref(curr))) + return curr.value + + +class _NumpyCompatibilityStateScope(object): + """Scope for managing numpy compatibility state. + + Example:: + + with _NumpyCompatibilityStateScope(True): + y = model(x) + backward([y]) + + """ + def __init__(self, is_np_comp): #pylint: disable=redefined-outer-name + self._enter_is_np_comp = is_np_comp + self._prev_is_np_comp = None + + def __enter__(self): + if self._enter_is_np_comp is not None: + self._prev_is_np_comp = set_np_comp(self._enter_is_np_comp) + + def __exit__(self, ptype, value, trace): + if self._enter_is_np_comp is not None and self._prev_is_np_comp != self._enter_is_np_comp: + set_np_comp(self._prev_is_np_comp) + + +def enable_np_comp(): + return _NumpyCompatibilityStateScope(True) + + +def disable_np_comp(): + return _NumpyCompatibilityStateScope(False) diff --git a/python/mxnet/ndarray/__init__.py b/python/mxnet/ndarray/__init__.py index a102399521cc..f09908e894d5 100644 --- a/python/mxnet/ndarray/__init__.py +++ b/python/mxnet/ndarray/__init__.py @@ -17,7 +17,7 @@ """NDArray API of MXNet.""" -from . import _internal, contrib, linalg, op, random, sparse, utils, image, ndarray, numpy +from . import _internal, contrib, linalg, op, random, sparse, utils, image, ndarray # pylint: disable=wildcard-import, redefined-builtin try: from .gen_op import * # pylint: disable=unused-wildcard-import diff --git a/python/mxnet/ndarray/numpy.py b/python/mxnet/ndarray/numpy.py deleted file mode 100644 index 0826ac8aca7f..000000000000 --- a/python/mxnet/ndarray/numpy.py +++ /dev/null @@ -1,18 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -__all__ = [] diff --git a/python/mxnet/numpy/__init__.py b/python/mxnet/numpy/__init__.py deleted file mode 100644 index e0dfda10113e..000000000000 --- a/python/mxnet/numpy/__init__.py +++ /dev/null @@ -1,66 +0,0 @@ -#!/usr/bin/env python - -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. 
See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import ctypes -from ..base import _LIB, check_call - -__all__ = [] - - -def set_np_comp(is_np_comp): - prev = ctypes.c_int() - check_call(_LIB.MXSetIsNumpyCompatible(ctypes.c_int(is_np_comp), ctypes.byref(prev))) - return bool(prev.value) - - -def is_np_comp(): - curr = ctypes.c_bool() - check_call(_LIB.MXIsNumpyCompatible(ctypes.byref(curr))) - return curr.value - - -class _NumpyCompatibilityStateScope(object): - """Scope for managing numpy compatibility state. - - Example:: - - with _NumpyCompatibilityStateScope(True): - y = model(x) - backward([y]) - - """ - def __init__(self, is_np_comp): #pylint: disable=redefined-outer-name - self._enter_is_np_comp = is_np_comp - self._prev_is_np_comp = None - - def __enter__(self): - if self._enter_is_np_comp is not None: - self._prev_is_np_comp = set_np_comp(self._enter_is_np_comp) - - def __exit__(self, ptype, value, trace): - if self._enter_is_np_comp is not None and self._prev_is_np_comp != self._enter_is_np_comp: - set_np_comp(self._prev_is_np_comp) - - -def enable_np_comp(): - return _NumpyCompatibilityStateScope(True) - - -def disable_np_comp(): - return _NumpyCompatibilityStateScope(False) diff --git a/python/mxnet/symbol/__init__.py b/python/mxnet/symbol/__init__.py index 326e4f5aff78..f438e4954aa9 100644 --- a/python/mxnet/symbol/__init__.py +++ b/python/mxnet/symbol/__init__.py @@ -17,7 +17,7 @@ """Symbol API of MXNet.""" -from . import _internal, contrib, linalg, op, random, sparse, image, symbol, numpy +from . import _internal, contrib, linalg, op, random, sparse, image, symbol # pylint: disable=wildcard-import, redefined-builtin try: from .gen_op import * # pylint: disable=unused-wildcard-import diff --git a/python/mxnet/symbol/numpy.py b/python/mxnet/symbol/numpy.py deleted file mode 100644 index 0826ac8aca7f..000000000000 --- a/python/mxnet/symbol/numpy.py +++ /dev/null @@ -1,18 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -__all__ = [] diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py index 73d4babe6f27..fde540a2682a 100644 --- a/python/mxnet/symbol/symbol.py +++ b/python/mxnet/symbol/symbol.py @@ -34,7 +34,7 @@ from ..attribute import AttrScope from ..base import _LIB, numeric_types, c_array, c_array_buf, c_str, c_str_array, c_handle_array -from ..base import mx_uint, py_str, string_types, integer_types, mx_int +from ..base import mx_uint, py_str, string_types, integer_types, mx_int, is_np_comp from ..base import NDArrayHandle, ExecutorHandle, SymbolHandle from ..base import check_call, MXNetError, NotImplementedForSymbol from ..context import Context, current_context @@ -42,7 +42,6 @@ from ..ndarray.ndarray import _STORAGE_TYPE_STR_TO_ID from ..ndarray import _ndarray_cls from ..executor import Executor -from ..numpy import is_np_comp from . import _internal from . import op from ._internal import SymbolBase, _set_symbol_class diff --git a/src/operator/numpy/np_broadcast_reduce_op.h b/src/operator/numpy/np_broadcast_reduce_op.h deleted file mode 100644 index e0379a040c3f..000000000000 --- a/src/operator/numpy/np_broadcast_reduce_op.h +++ /dev/null @@ -1,186 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! 
- * Copyright (c) 2015 by Contributors - * \file broadcast_reduce_op.h - * \brief Function definition of broadcast and reduce operators - */ -#ifndef MXNET_OPERATOR_NUMPY_NP_BROADCAST_REDUCE_OP_H_ -#define MXNET_OPERATOR_NUMPY_NP_BROADCAST_REDUCE_OP_H_ - -#include "../tensor/broadcast_reduce_op.h" - -namespace mxnet { -namespace op { - -struct NumpyReduceAxesParam : public dmlc::Parameter { - dmlc::optional> axis; - dmlc::optional dtype; - bool keepdims; - dmlc::optional initial; - DMLC_DECLARE_PARAMETER(NumpyReduceAxesParam) { - DMLC_DECLARE_FIELD(axis).set_default(dmlc::optional>()) - .describe(R"code()code"); - DMLC_DECLARE_FIELD(dtype).set_default(dmlc::optional()) - .describe(""); - DMLC_DECLARE_FIELD(keepdims).set_default(false) - .describe("If this is set to `True`, the reduced axes are left " - "in the result as dimension with size one."); - } -}; - -inline TShape NumpyReduceAxesShapeImpl(const TShape& ishape, - const dmlc::optional>& axis, - bool keepdims) { - // TODO(junwu): improve the logic - // If input is a scalar, output should be a scalar too - if (ishape.ndim() == 0) { - if (axis.has_value()) { - const nnvm::Tuple& axes = axis.value(); - if (axes.ndim() > 0) { - CHECK_EQ(axes.ndim(), 1); - CHECK(axes[0] == 0 || axes[0] == -1); - } - } - return TShape(0, -1); - } - - // axis=None, do global reduction - if (!axis.has_value()) { - if (keepdims) { - return TShape(ishape.ndim(), 1); - } else { - return TShape(0, -1); - } - } - - // axis = (), will return identity(input) - if (axis.value().ndim() == 0) { - return ishape; - } - - // axis has value - nnvm::Tuple axes(axis.value()); - for (index_t i = 0; i < axes.ndim(); i++) { - if (axes[i] < 0) { - axes[i] += ishape.ndim(); - } - } - std::sort(axes.begin(), axes.end()); - - for (index_t i = 1; i < axes.ndim(); i++) { - CHECK_LT(axes[i-1], axes[i]) - << "Reduction axes have duplicates " - << axes; - } - CHECK_LT(axes[axes.ndim()-1], ishape.ndim()) - << "Reduction axis " << axes[axes.ndim()-1] - << " Exceeds input dimensions " << ishape; - CHECK_GE(axes[0], 0) - << "Reduction axis " << axis.value() - << " Exceeds input dimensions " << ishape; - - TShape oshape; - if (keepdims) { - oshape = TShape(ishape); - } else { - oshape = TShape(ishape.ndim() - axes.ndim(), -1); - } - - if (keepdims) { - for (index_t i = 0; i < axes.ndim(); ++i) { - oshape[axes[i]] = 1; - } - } else { - for (index_t i = 0, j = 0, k = 0; i < ishape.ndim(); ++i) { - if (j < axes.ndim() && i == axes[j]) { - ++j; - continue; - } - oshape[k++] = ishape[i]; - } - } - return oshape; -} - -inline bool NumpyReduceAxesShape(const nnvm::NodeAttrs& attrs, - std::vector *in_attrs, - std::vector *out_attrs) { - CHECK_EQ(in_attrs->size(), 1U); - CHECK_EQ(out_attrs->size(), 1U); - if (!shape_is_known(in_attrs->at(0))) { - return false; - } - const NumpyReduceAxesParam& param = nnvm::get(attrs.parsed); - SHAPE_ASSIGN_CHECK(*out_attrs, 0, - NumpyReduceAxesShapeImpl((*in_attrs)[0], param.axis, param.keepdims)); - return shape_is_known(out_attrs->at(0)); -} - -template -void NumpyReduceAxesCompute(const nnvm::NodeAttrs& attrs, - const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { - const NumpyReduceAxesParam& param = nnvm::get(attrs.parsed); - if (param.axis.has_value() && param.axis.value().ndim() == 0) { - UnaryOp::IdentityCompute(attrs, ctx, inputs, req, outputs); - } - TShape small; - if (param.keepdims) { - small = outputs[0].shape_; - } else { - small = NumpyReduceAxesShapeImpl(inputs[0].shape_, param.axis, true); - } 
- - ReduceAxesComputeImpl(ctx, inputs, req, outputs, small); -} - -template -inline void NumpyReduceAxesBackwardUseNone(const nnvm::NodeAttrs& attrs, - const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { - using namespace mshadow; - using namespace mshadow::expr; - const NumpyReduceAxesParam& param = nnvm::get(attrs.parsed); - TShape small; - if (param.keepdims) { - small = inputs[0].shape_; - } else { - small = NumpyReduceAxesShapeImpl(outputs[0].shape_, param.axis, true); - } - - BroadcastComputeImpl(attrs, ctx, inputs, req, outputs, small); - if (normalize) { - Stream *s = ctx.get_stream(); - MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { - Tensor igrad = outputs[0].FlatTo1D(s); - igrad /= scalar(outputs[0].Size()/inputs[0].Size()); - }); - } -} - -} // namespace op -} // namespace mxnet -#endif // MXNET_OPERATOR_NUMPY_NP_BROADCAST_REDUCE_OP_H_ diff --git a/src/operator/numpy/np_broadcast_reduce_op_value.cc b/src/operator/numpy/np_broadcast_reduce_op_value.cc deleted file mode 100644 index c028e2368737..000000000000 --- a/src/operator/numpy/np_broadcast_reduce_op_value.cc +++ /dev/null @@ -1,61 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * Copyright (c) 2019 by Contributors - * \file np_reduce_op_value.cc - * \brief CPU Implementation of broadcast and reduce functions based on value. 
- */ - -#include "np_broadcast_reduce_op.h" - -namespace mxnet { -namespace op { - -DMLC_REGISTER_PARAMETER(NumpyReduceAxesParam); - -NNVM_REGISTER_OP(_numpy_sum) -.describe(R"code()code" ADD_FILELINE) -.set_num_inputs(1) -.set_num_outputs(1) -.set_attr_parser(ParamParser) -.set_attr("FInferShape", NumpyReduceAxesShape) -.set_attr("FInferType", ElemwiseType<1, 1>) -.set_attr("FListInputNames", - [](const NodeAttrs& attrs) { - return std::vector{"a"}; - }) -.add_argument("a", "NDArray-or-Symbol", "The input") -.add_arguments(NumpyReduceAxesParam::__FIELDS__()) -.set_attr("FCompute", NumpyReduceAxesCompute) -.set_attr("FResourceRequest", - [](const NodeAttrs& attrs) { - return std::vector{ResourceRequest::kTempSpace}; - }) -.set_attr("FGradient", ElemwiseGradUseNone{"_backward_numpy_sum"}); - -NNVM_REGISTER_OP(_backward_numpy_sum) -.set_num_outputs(1) -.set_attr_parser(ParamParser) -.set_attr("TIsBackward", true) -.set_num_inputs(1) -.set_attr("FCompute", NumpyReduceAxesBackwardUseNone); - -} // namespace op -} // namespace mxnet diff --git a/src/operator/numpy/np_broadcast_reduce_op_value.cu b/src/operator/numpy/np_broadcast_reduce_op_value.cu deleted file mode 100644 index c975b18226db..000000000000 --- a/src/operator/numpy/np_broadcast_reduce_op_value.cu +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * Copyright (c) 2019 by Contributors - * \file np_reduce_op_value.cu - * \brief GPU Implementation of reduce functions based on value. 
- */ -#include "np_broadcast_reduce_op.h" - -namespace mxnet { -namespace op { -NNVM_REGISTER_OP(_numpy_sum) -.set_attr("FCompute", NumpyReduceAxesCompute); - -NNVM_REGISTER_OP(_backward_numpy_sum) -.set_attr("FCompute", NumpyReduceAxesBackwardUseNone); - -} // namespace op -} // namespace mxnet diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h index efc8289a8174..4088b3851a5a 100644 --- a/src/operator/tensor/matrix_op-inl.h +++ b/src/operator/tensor/matrix_op-inl.h @@ -1154,9 +1154,14 @@ inline bool SliceAxisShape(const nnvm::NodeAttrs& attrs, CHECK_EQ(in_attrs->size(), 1U); CHECK_EQ(out_attrs->size(), 1U); mxnet::TShape& ishape = (*in_attrs)[0]; + if (!mxnet::ndim_is_known(ishape)) return false; int axis; index_t begin, end; GetSliceAxisParams(param, ishape, &axis, &begin, &end); + if (!mxnet::dim_size_is_known(ishape, axis)) { + SHAPE_ASSIGN_CHECK(*out_attrs, 0, ishape); + return false; + } mxnet::TShape shape(ishape.ndim(), -1); for (int i = 0; i < ishape.ndim(); ++i) { if (static_cast(i) == axis) { diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 985405424dda..3d2291870a3c 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -4398,7 +4398,7 @@ def test_invalid_reps(): assert_exception(mx.nd.tile, MXNetError, data, (1, 0, 3)) test_normal_case() - with mx.numpy.enable_np_comp(): + with mx.enable_np_comp(): test_empty_tensor() test_empty_reps() test_tile_backward() @@ -4459,7 +4459,7 @@ def test_zero_depth(): test_normal_case(index_type=np.float64) test_normal_case(index_type=np.float32) test_normal_case(index_type=np.float16) - with mx.numpy.enable_np_comp(): + with mx.enable_np_comp(): test_empty_indices() test_zero_depth() @@ -6840,6 +6840,20 @@ def check_slice_axis_partial_infer(data, axis, begin, end, expected_out_shape): check_slice_axis_partial_infer(var1, 0, 0, 5, (5, 0)) check_slice_axis_partial_infer(var1, 1, 0, 5, (10, 0)) + with mx.enable_np_comp(): + var1 = mx.sym.var(name="data", shape=(-1, 20)) + check_slice_partial_infer(var1, (None, None), (None, 10), [], (-1, 10)) + check_slice_partial_infer(var1, (None, None), (None, 10), (None, 2), (-1, 5)) + check_slice_partial_infer(var1, (None, 3), (None, 10), [], (-1, 7)) + check_slice_partial_infer(var1, (None, 3), (5, 10), [], (-1, 7)) + check_slice_partial_infer(var1, (2, 3), (None, 10), [], (-1, 7)) + check_slice_partial_infer(var1, (2, 3), (None, 10), (None, 1), (-1, 7)) + check_slice_partial_infer(var1, (2, 3), (None, 10), (3, 3), (-1, 3)) + + var1 = mx.sym.var(name='data', shape=(10, -1)) + check_slice_axis_partial_infer(var1, 0, 0, 5, (5, -1)) + check_slice_axis_partial_infer(var1, 1, 0, 5, (10, -1)) + @with_seed() def test_float16_min_max(): @@ -7878,6 +7892,33 @@ def test_image_normalize(): check_numeric_gradient(img_norm_sym, [data_in_4d], atol=0.001) +def test_scalar_tensor_creation(): + assertRaises(MXNetError, mx.nd.zeros, shape=()) + assertRaises(MXNetError, mx.nd.ones, shape=()) + with mx.enable_np_comp(): + data_mx = mx.nd.ones(shape=()) + data_np = np.ones((), dtype=data_mx.dtype) + assert same(data_mx.asnumpy(), data_np) + + +def test_zero_size_tensor_creation(): + assertRaises(MXNetError, mx.nd.zeros, shape=(0, 1, 3, 0)) + assertRaises(MXNetError, mx.nd.ones, shape=(0, 1, 3, 0)) + with mx.enable_np_comp(): + data_mx = mx.nd.ones(shape=(0, 1, 0, 4)) + data_np = np.ones(shape=data_mx.shape, dtype=data_mx.dtype) + assert same(data_mx.asnumpy(), data_np) + + +def 
test_concat_with_zero_size_tensor(): + with mx.enable_np_comp(): + data1 = mx.nd.ones((0, 8, 12)) + data2 = mx.nd.ones((3, 8, 12)) + data3 = mx.nd.ones((0, 8, 12)) + ret = mx.nd.Concat(data1, data2, data3, dim=0) + assert ret.shape == (3, 8, 12) + + if __name__ == '__main__': import nose nose.runmodule() From 4c2d34b91aa6617e26fe4918e2b55306f54d22c3 Mon Sep 17 00:00:00 2001 From: reminisce Date: Thu, 4 Apr 2019 15:30:13 -0700 Subject: [PATCH 02/18] Fix bug when shape is compeltely unknown --- python/mxnet/base.py | 73 +++++++++++++++++++++++ src/c_api/c_api_common.h | 4 +- tests/python/unittest/test_infer_shape.py | 16 +++++ 3 files changed, 92 insertions(+), 1 deletion(-) diff --git a/python/mxnet/base.py b/python/mxnet/base.py index 719b3cb58a6a..bef0950f7a13 100644 --- a/python/mxnet/base.py +++ b/python/mxnet/base.py @@ -737,12 +737,32 @@ def write_all_str(module_file, module_all_list): def set_np_comp(is_np_comp): + """ + Turns on/off NumPy compatibility. NumPy-compatibility is turned off by default in backend. + + Parameters + ---------- + is_np_comp : bool + Indicates whether to turn on/off NumPy compatibility. + + Returns + ------- + A bool value indicating the previous state of NumPy compatibility. + """ prev = ctypes.c_int() check_call(_LIB.MXSetIsNumpyCompatible(ctypes.c_int(is_np_comp), ctypes.byref(prev))) return bool(prev.value) def is_np_comp(): + """ + Checks whether the NumPy compatibility is currently turned on. + NumPy-compatibility is turned off by default in backend. + + Returns + ------- + A bool value indicating whether the NumPy compatibility is currently on. + """ curr = ctypes.c_bool() check_call(_LIB.MXIsNumpyCompatible(ctypes.byref(curr))) return curr.value @@ -772,8 +792,61 @@ def __exit__(self, ptype, value, trace): def enable_np_comp(): + """Returns a NumPy compatibility state scope to be used in 'with' statement + and captures code that needs the compatibility. + + Example:: + + with mx.enable_np_comp(): + # A scalar tensor's shape is `()`, whose `ndim` is `0`. + scalar = mx.nd.ones(shape=()) + assert scalar.shape == () + + # In NumPy compatible mode, 0 in a shape means that dimension contains zero elements. + data = mx.sym.var("data", shape=(0, 2, 3)) + ret = mx.sym.sin(data) + arg_shapes, out_shapes, _ = ret.infer_shape() + assert arg_shapes[0] == (0, 2, 3) + assert out_shapes[0] == (0, 2, 3) + + # -1 means unknown shape dimension size in the new NumPy-compatible shape definition + data = mx.sym.var("data", shape=(-1, 2, 3)) + ret = mx.sym.sin(data) + arg_shapes, out_shapes, _ = ret.infer_shape_partial() + assert arg_shapes[0] == (-1, 2, 3) + assert out_shapes[0] == (-1, 2, 3) + + # When a shape is completely unknown in NumPy-compatible mode, it is + # represented as `None` in Python. + data = mx.sym.var("data") + ret = mx.sym.sin(data) + arg_shapes, out_shapes, _ = ret.infer_shape_partial() + assert arg_shapes[0] is None + assert out_shapes[0] is None + """ return _NumpyCompatibilityStateScope(True) def disable_np_comp(): + """Returns a state scope with NumPy-compatibility disabled to be used in 'with' statement + and captures code that does not need the compatibility. + + Example:: + + with mx.disable_np_comp(): + # 0 means unknown shape dimension size in the legacy shape definition. 
+ data = mx.sym.var("data", shape=(0, 2, 3)) + ret = mx.sym.sin(data) + arg_shapes, out_shapes, _ = ret.infer_shape_partial() + assert arg_shapes[0] == (0, 2, 3) + assert out_shapes[0] == (0, 2, 3) + + # When a shape is completely unknown in the legacy mode (default), its ndim is + # equal to 0 and it is represented as `()` in Python. + data = mx.sym.var("data") + ret = mx.sym.sin(data) + arg_shapes, out_shapes, _ = ret.infer_shape_partial() + assert arg_shapes[0] == () + assert out_shapes[0] == () + """ return _NumpyCompatibilityStateScope(False) diff --git a/src/c_api/c_api_common.h b/src/c_api/c_api_common.h index 55608b950866..8be192b22c53 100644 --- a/src/c_api/c_api_common.h +++ b/src/c_api/c_api_common.h @@ -100,7 +100,9 @@ struct MXAPIThreadLocalEntry { for (size_t i = 0; i < shapes.size(); ++i) { ndim->at(i) = shapes[i].ndim(); data->at(i) = ptr; - ptr = mxnet::ShapeTypeCast(shapes[i].begin(), shapes[i].end(), ptr); + if (shapes[i].ndim() > 0) { + ptr = mxnet::ShapeTypeCast(shapes[i].begin(), shapes[i].end(), ptr); + } } } }; diff --git a/tests/python/unittest/test_infer_shape.py b/tests/python/unittest/test_infer_shape.py index 73654a604135..e0b4d35ea9aa 100644 --- a/tests/python/unittest/test_infer_shape.py +++ b/tests/python/unittest/test_infer_shape.py @@ -147,6 +147,21 @@ def test_fc_infer_type(): assert arg_type_dict[k] == v +def test_shape_completely_unknown(): + data = mx.sym.var("data") + ret = mx.sym.sin(data) + arg_shapes, out_shapes, _ = ret.infer_shape_partial() + assert arg_shapes[0] == () + assert out_shapes[0] == () + + with mx.enable_np_comp(): + data = mx.sym.var("data") + ret = mx.sym.sin(data) + arg_shapes, out_shapes, _ = ret.infer_shape_partial() + assert arg_shapes[0] is None + assert out_shapes[0] is None + + if __name__ == "__main__": test_mlp2_infer_shape() test_mlp2_infer_error() @@ -156,3 +171,4 @@ def test_fc_infer_type(): test_incomplete_infer_slicechannel() test_incomplete_infer_convolution() test_incomplete_infer_concat() + test_shape_completely_unknown() From 8dfd96546235f784ec0d369f1cd6cf6012caf1d5 Mon Sep 17 00:00:00 2001 From: reminisce Date: Thu, 4 Apr 2019 15:40:42 -0700 Subject: [PATCH 03/18] Fix singed/unsigned compare warning --- src/operator/contrib/index_copy-inl.h | 2 +- src/operator/pad-inl.h | 2 +- src/operator/swapaxis-inl.h | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/operator/contrib/index_copy-inl.h b/src/operator/contrib/index_copy-inl.h index 35f88916da20..9f78f0593ed1 100644 --- a/src/operator/contrib/index_copy-inl.h +++ b/src/operator/contrib/index_copy-inl.h @@ -64,7 +64,7 @@ inline bool IndexCopyShape(const nnvm::NodeAttrs& attrs, CHECK_EQ(in_attrs->at(1).ndim(), 1); // Shape matching CHECK_EQ(in_attrs->at(0).ndim(), in_attrs->at(2).ndim()); - for (size_t i = 0; i < in_attrs->at(0).ndim(); ++i) { + for (int i = 0; i < in_attrs->at(0).ndim(); ++i) { if (i == 0) { CHECK_GE(in_attrs->at(0)[i], in_attrs->at(2)[i]); } else { diff --git a/src/operator/pad-inl.h b/src/operator/pad-inl.h index 140d7099e817..89b0ab7780b6 100644 --- a/src/operator/pad-inl.h +++ b/src/operator/pad-inl.h @@ -230,7 +230,7 @@ class PadProp : public OperatorProperty { } } mxnet::TShape oshape = dshape; - for (size_t i = 0; i < dshape.ndim(); ++i) { + for (int i = 0; i < dshape.ndim(); ++i) { oshape[i] = param_.pad_width[2 * i] + param_.pad_width[2 * i + 1] + dshape[i]; } diff --git a/src/operator/swapaxis-inl.h b/src/operator/swapaxis-inl.h index 41cb940d957a..7335daa48392 100644 --- a/src/operator/swapaxis-inl.h +++ 
b/src/operator/swapaxis-inl.h @@ -69,7 +69,7 @@ class SwapAxisOp : public Operator { void Reshape2Five(mshadow::Shape<5> *inter_shape, const mxnet::TShape &shape, - uint32_t dim1, uint32_t dim2) { + int dim1, int dim2) { using namespace mshadow; using namespace mshadow::expr; int ndim_in = shape.ndim(); From 439e1771d2cc005b797f11cac2106e90db0bd2d4 Mon Sep 17 00:00:00 2001 From: reminisce Date: Thu, 4 Apr 2019 16:13:07 -0700 Subject: [PATCH 04/18] Fix CI --- src/operator/contrib/bounding_box-inl.h | 3 +-- src/operator/nn/concat.cc | 3 +-- src/operator/nn/deconvolution-inl.h | 2 +- src/operator/nn/mkldnn/mkldnn_slice.cc | 4 ++-- src/operator/tensor/matrix_op-inl.h | 3 ++- 5 files changed, 7 insertions(+), 8 deletions(-) diff --git a/src/operator/contrib/bounding_box-inl.h b/src/operator/contrib/bounding_box-inl.h index 6ea4e8097b6c..686f1666a310 100644 --- a/src/operator/contrib/bounding_box-inl.h +++ b/src/operator/contrib/bounding_box-inl.h @@ -94,9 +94,8 @@ inline bool BoxNMSShape(const nnvm::NodeAttrs& attrs, const BoxNMSParam& param = nnvm::get(attrs.parsed); CHECK_EQ(in_attrs->size(), 1U); CHECK_EQ(out_attrs->size(), 2U); - // TODO(@junrushao1994): verify with Joshua Z. Zhang about this operator if (mxnet::op::shape_is_none(in_attrs->at(0)) - && mxnet::op::shape_is_none(out_attrs->at(0))) { + && mxnet::op::shape_is_none(out_attrs->at(0))) { return false; } diff --git a/src/operator/nn/concat.cc b/src/operator/nn/concat.cc index b534ee58e85c..411773b41f94 100644 --- a/src/operator/nn/concat.cc +++ b/src/operator/nn/concat.cc @@ -235,8 +235,7 @@ bool SupportMKLDNNConcat(const std::vector &arrs) { if (arr.IsView()) return false; if (arr.dtype() != mshadow::kFloat32) return false; int ndim = arr.shape().ndim(); - unsigned mkldnn_ndims = - static_cast(arr.GetMKLDNNData()->get_primitive_desc().desc().data.ndims); + const int mkldnn_ndims = arr.GetMKLDNNData()->get_primitive_desc().desc().data.ndims; if (!(ndim == 2 || ndim == 4) || ndim != mkldnn_ndims) return false; } return true; diff --git a/src/operator/nn/deconvolution-inl.h b/src/operator/nn/deconvolution-inl.h index e82a073ea08d..1eeccb02e030 100644 --- a/src/operator/nn/deconvolution-inl.h +++ b/src/operator/nn/deconvolution-inl.h @@ -143,7 +143,7 @@ struct DeconvolutionParam : public dmlc::Parameter { } } } else { - for (int i = 0; i < (int) ndim; i++) { + for (int i = 0; i < static_cast(ndim); i++) { o_pad[i] = i < pad.ndim() ? pad[i] : 0; o_adj[i] = i < adj.ndim() ? 
adj[i] : 0; } diff --git a/src/operator/nn/mkldnn/mkldnn_slice.cc b/src/operator/nn/mkldnn/mkldnn_slice.cc index 96a8afdab6e2..2a817a25a5b8 100644 --- a/src/operator/nn/mkldnn/mkldnn_slice.cc +++ b/src/operator/nn/mkldnn/mkldnn_slice.cc @@ -37,10 +37,10 @@ MKLDNNSliceFwd::MKLDNNSliceFwd(const SliceParam ¶m, const NDArray &out) { const mxnet::TShape ishape = in.shape(); const mxnet::TShape oshape = out.shape(); - uint32_t N = ishape.ndim(); + const int N = ishape.ndim(); mkldnn::memory::dims dims(N); mkldnn::memory::dims offsets(N); - for (uint32_t i = 0; i < N; ++i) { + for (int i = 0; i < N; ++i) { int s = 0; if (i < param.begin.ndim() && param.begin[i]) { s = *param.begin[i]; diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h index 4088b3851a5a..44dbd78f0c0b 100644 --- a/src/operator/tensor/matrix_op-inl.h +++ b/src/operator/tensor/matrix_op-inl.h @@ -209,7 +209,8 @@ inline bool ReshapeShape(const nnvm::NodeAttrs& attrs, oshape[inf_idx] = dshape.Size() / oshape.Size(); } } else { - return shape_is_known((*out_attrs)[0]) && ReverseReshapeInferShape(&(*in_attrs)[0], (*out_attrs)[0]); + return shape_is_known((*out_attrs)[0]) + && ReverseReshapeInferShape(&(*in_attrs)[0], (*out_attrs)[0]); } ReverseReshapeInferShape(&dshape, oshape); #if 0 From 51713e8343e1f39de73f353922470d2006f7c74e Mon Sep 17 00:00:00 2001 From: reminisce Date: Thu, 4 Apr 2019 16:35:52 -0700 Subject: [PATCH 05/18] Fix pylint --- python/mxnet/base.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/mxnet/base.py b/python/mxnet/base.py index bef0950f7a13..916e74182f94 100644 --- a/python/mxnet/base.py +++ b/python/mxnet/base.py @@ -736,13 +736,13 @@ def write_all_str(module_file, module_all_list): ctypes.pythonapi.PyCapsule_GetPointer.restype = ctypes.c_void_p -def set_np_comp(is_np_comp): +def set_np_comp(flag): """ Turns on/off NumPy compatibility. NumPy-compatibility is turned off by default in backend. Parameters ---------- - is_np_comp : bool + flag : bool Indicates whether to turn on/off NumPy compatibility. Returns @@ -750,7 +750,7 @@ def set_np_comp(is_np_comp): A bool value indicating the previous state of NumPy compatibility. """ prev = ctypes.c_int() - check_call(_LIB.MXSetIsNumpyCompatible(ctypes.c_int(is_np_comp), ctypes.byref(prev))) + check_call(_LIB.MXSetIsNumpyCompatible(ctypes.c_int(flag), ctypes.byref(prev))) return bool(prev.value) From 6a13bcaeb8717cf00a881a0cdd458f07df9d88dc Mon Sep 17 00:00:00 2001 From: reminisce Date: Thu, 4 Apr 2019 21:32:03 -0700 Subject: [PATCH 06/18] Avoid launching gpu kernels for zero-size output tensors --- src/operator/mxnet_op.h | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/operator/mxnet_op.h b/src/operator/mxnet_op.h index d8fc5031e4ff..636c64d94c34 100644 --- a/src/operator/mxnet_op.h +++ b/src/operator/mxnet_op.h @@ -714,6 +714,7 @@ struct Kernel { /*! \brief Launch GPU kernel */ template inline static void Launch(mshadow::Stream *s, int N, Args... args) { + if (0 == N) return; using namespace mshadow::cuda; int ngrid = std::min(kMaxGridNum, (N + kBaseThreadNum - 1) / kBaseThreadNum); mxnet_generic_kernel @@ -724,6 +725,7 @@ struct Kernel { template inline static void LaunchEx(mshadow::Stream *s, const int N, Args... 
args) { + if (0 == N) return; using namespace mshadow::cuda; int ngrid = std::min(kMaxGridNum, (N + kBaseThreadNum - 1) / kBaseThreadNum); mxnet_generic_kernel_ex From d8fbba39d50e1d7226cafdccd8ddaef99aa80b5b Mon Sep 17 00:00:00 2001 From: reminisce Date: Thu, 4 Apr 2019 22:28:49 -0700 Subject: [PATCH 07/18] Fix test_ndarray --- tests/python/unittest/test_ndarray.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py index 1bdb7d51df67..f8f52c2cee24 100644 --- a/tests/python/unittest/test_ndarray.py +++ b/tests/python/unittest/test_ndarray.py @@ -123,7 +123,7 @@ def test_ndarray_setitem(): # numpy assignment for empty axis for trivial_shape in [(), (1,), (1, 1), (1, 1, 1)]: if trivial_shape == tuple(): - with mx.numpy.enable_np_comp(): + with mx.enable_np_comp(): x = mx.nd.zeros(trivial_shape) else: x = mx.nd.zeros(trivial_shape) From 9c4e2082cb9081c9285fcdc44cbcd2aa5eea9ae5 Mon Sep 17 00:00:00 2001 From: reminisce Date: Fri, 5 Apr 2019 13:57:35 -0700 Subject: [PATCH 08/18] Fix binary broadcast with zero-size tensors --- .../tensor/elemwise_binary_broadcast_op.h | 15 +++++-------- tests/python/gpu/test_operator_gpu.py | 22 ++++++++++--------- 2 files changed, 18 insertions(+), 19 deletions(-) diff --git a/src/operator/tensor/elemwise_binary_broadcast_op.h b/src/operator/tensor/elemwise_binary_broadcast_op.h index 64a4d7cc15ff..73019fa8389b 100644 --- a/src/operator/tensor/elemwise_binary_broadcast_op.h +++ b/src/operator/tensor/elemwise_binary_broadcast_op.h @@ -61,16 +61,13 @@ inline bool BinaryBroadcastShape(const nnvm::NodeAttrs& attrs, int l = 1, r = 1; if (i >= bl) l = lhs[i-bl]; if (i >= br) r = rhs[i-br]; + if (!mxnet::dim_size_is_known(l) || !mxnet::dim_size_is_known(r)) continue; if (l != r) { - if (l == 0 || r == 0) { - // TODO(junwu): here is not compatible with NumPy. - // For example, (2, 3) cannot broadcast to (2, 0, 3). - out[i] = 0; - } else { - CHECK(l == 1 || r == 1) - << "operands could not be broadcast together with shapes " << lhs << " " << rhs; - out[i] = std::max(l, r); - } + // Make it compatible with NumPy. + // For example, (2, 3) cannot broadcast to (2, 0, 3), but (1, 3) can broadcast to (2, 0, 3). + CHECK(l == 1 || r == 1) + << "operands could not be broadcast together with shapes " << lhs << " " << rhs; + out[i] = (l == 1 ? r : l); } else { out[i] = l; } diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py index fbbfc53a9a5e..f8ebe9517bac 100644 --- a/tests/python/gpu/test_operator_gpu.py +++ b/tests/python/gpu/test_operator_gpu.py @@ -1963,19 +1963,21 @@ def check_proposal_consistency(op, batch_size, with_nms=False): # The following 2 functions launch 0-thread kernels, an error that should be caught and signaled. 
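# (Aside, before the two kernel-error checks below.) A minimal, hypothetical
# sketch of the NumPy-compatible broadcasting rule changed in
# elemwise_binary_broadcast_op.h above: a size-1 dimension may broadcast
# against a zero-size dimension, e.g. (1, 3) with (2, 0, 3) -> (2, 0, 3),
# while (2, 3) with (2, 0, 3) is rejected. It assumes the `mx` import of this
# test module and the `mx.enable_np_comp()` scope added in python/mxnet/base.py;
# the helper name is illustrative only and is not called anywhere.
def _np_comp_zero_size_broadcast_sketch():
    with mx.enable_np_comp():
        lhs = mx.nd.ones((1, 3))      # size-1 dim broadcasts across the 0-size dim
        rhs = mx.nd.ones((2, 0, 3))   # zero-size tensors are legal in np-comp mode
        out = mx.nd.broadcast_add(lhs, rhs)
        assert out.shape == (2, 0, 3)  # the zero-size dimension is preserved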
def kernel_error_check_imperative(): os.environ['MXNET_ENGINE_TYPE'] = 'NaiveEngine' - a = mx.nd.array([1,2,3],ctx=mx.gpu(0)) - b = mx.nd.array([],ctx=mx.gpu(0)) - c = (a / b).asnumpy() + with mx.enable_np_comp(): + a = mx.nd.array([1,2,3],ctx=mx.gpu(0)) + b = mx.nd.array([],ctx=mx.gpu(0)) + c = (a / b).asnumpy() def kernel_error_check_symbolic(): os.environ['MXNET_ENGINE_TYPE'] = 'NaiveEngine' - a = mx.sym.Variable('a') - b = mx.sym.Variable('b') - c = a / b - f = c.bind(mx.gpu(0), { 'a':mx.nd.array([1,2,3],ctx=mx.gpu(0)), - 'b':mx.nd.array([],ctx=mx.gpu(0))}) - f.forward() - g = f.outputs[0].asnumpy() + with mx.enable_np_comp(): + a = mx.sym.Variable('a') + b = mx.sym.Variable('b') + c = a / b + f = c.bind(mx.gpu(0), { 'a':mx.nd.array([1,2,3],ctx=mx.gpu(0)), + 'b':mx.nd.array([],ctx=mx.gpu(0))}) + f.forward() + g = f.outputs[0].asnumpy() def test_kernel_error_checking(): # Running tests that may throw exceptions out of worker threads will stop CI testing From 5a96f97389870b7c262a4561c5c3bc6bc2802ea9 Mon Sep 17 00:00:00 2001 From: reminisce Date: Fri, 5 Apr 2019 14:55:53 -0700 Subject: [PATCH 09/18] Better error message for infer shape failure in imperative --- src/imperative/imperative_utils.h | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/imperative/imperative_utils.h b/src/imperative/imperative_utils.h index d058df4b3806..24e6c49d5edb 100644 --- a/src/imperative/imperative_utils.h +++ b/src/imperative/imperative_utils.h @@ -131,12 +131,16 @@ inline void SetShapeType(const Context& ctx, std::stringstream os; os << "Operator " << attrs.op->name << " inferring shapes failed.\n"; os << "input shapes:\n"; - for (auto& nd : inputs) { - os << nd->shape() << '\n'; + for (const auto& s : in_shapes) { + os << s << '\n'; } os << "output shapes:\n"; - for (auto& nd : outputs) { - os << nd->shape() << '\n'; + for (const auto& s : out_shapes) { + os << s << '\n'; + } + os << "operator attributes:\n"; + for (const auto& kv : attrs.dict) { + os << kv.first << " : " << kv.second << '\n'; } LOG(FATAL) << os.str(); } From a5ac53149612582a69f60960cf4b2b2bd1a80f43 Mon Sep 17 00:00:00 2001 From: reminisce Date: Fri, 5 Apr 2019 22:49:58 -0700 Subject: [PATCH 10/18] Fix TShape constructor ambiguity on certain platforms --- include/mxnet/tuple.h | 11 ++++++++--- src/common/utils.h | 2 +- src/operator/contrib/deformable_convolution-inl.h | 6 +++--- src/operator/pooling_v1-inl.h | 6 +++--- src/operator/quantization/dequantize-inl.h | 2 +- .../quantization/mkldnn/mkldnn_requantize-inl.h | 2 +- src/operator/quantization/quantize-inl.h | 2 +- src/operator/quantization/quantize_v2-inl.h | 2 +- src/operator/quantization/quantized_conv.cc | 4 ++-- .../quantization/quantized_fully_connected.cc | 4 ++-- src/operator/quantization/requantize-inl.h | 2 +- src/operator/tensor/broadcast_reduce_op.h | 6 +++--- src/operator/tensor/histogram-inl.h | 8 ++++---- src/operator/tensor/matrix_op-inl.h | 4 ++-- tests/cpp/misc/serialization.cc | 2 +- 15 files changed, 34 insertions(+), 29 deletions(-) diff --git a/include/mxnet/tuple.h b/include/mxnet/tuple.h index c5a358628ccd..de18eea3e007 100644 --- a/include/mxnet/tuple.h +++ b/include/mxnet/tuple.h @@ -390,7 +390,7 @@ class TShape : public Tuple { * \param ndim the number of dimension * \param value the dimension size for all dims */ - inline TShape(int ndim, int value = -1) { // NOLINT(*) + inline TShape(const int ndim, const dim_t value) { // NOLINT(*) this->SetDim(ndim); if (ndim > 0) { std::fill_n(begin(), ndim, value); @@ -422,12 +422,17 @@ 
class TShape : public Tuple { this->swap(s); } /*! - * \brief construct the Tuple from content of iterator + * \brief construct the Tuple from content of iterator. + * This function is enforced with template arguments of random access iterator types. + * This is necessary to distinguish from another constructor: TShape(const int, const dim_t). * \param begin the beginning of iterator * \param end end the end of the iterator * \tparam RandomAccessIterator iterator type */ - template + template::iterator_category, + std::random_access_iterator_tag>::value, int>::type = 0> inline TShape(RandomAccessIterator begin, RandomAccessIterator end) { this->assign(begin, end); diff --git a/src/common/utils.h b/src/common/utils.h index 4fb398d883a6..6cdb869ff9ae 100644 --- a/src/common/utils.h +++ b/src/common/utils.h @@ -776,7 +776,7 @@ inline void ConvertToNumpyShape(mxnet::ShapeVector* shapes) { */ inline void ConvertToLegacyShape(mxnet::TShape* shape) { if (!mxnet::ndim_is_known(*shape)) { - *shape = mxnet::TShape(0); + *shape = mxnet::TShape(0, -1); } else { for (int j = 0; j < shape->ndim(); ++j) { if (!mxnet::dim_size_is_known(*shape, j)) { diff --git a/src/operator/contrib/deformable_convolution-inl.h b/src/operator/contrib/deformable_convolution-inl.h index a7e22f548151..000d703066d7 100644 --- a/src/operator/contrib/deformable_convolution-inl.h +++ b/src/operator/contrib/deformable_convolution-inl.h @@ -69,11 +69,11 @@ struct DeformableConvolutionParam : public dmlc::Parameter layout; DMLC_DECLARE_PARAMETER(DeformableConvolutionParam) { DMLC_DECLARE_FIELD(kernel).describe("Convolution kernel size: (h, w) or (d, h, w)"); - DMLC_DECLARE_FIELD(stride).set_default(mxnet::TShape(0)) + DMLC_DECLARE_FIELD(stride).set_default(mxnet::TShape(0, -1)) .describe("Convolution stride: (h, w) or (d, h, w). Defaults to 1 for each dimension."); - DMLC_DECLARE_FIELD(dilate).set_default(mxnet::TShape(0)) + DMLC_DECLARE_FIELD(dilate).set_default(mxnet::TShape(0, -1)) .describe("Convolution dilate: (h, w) or (d, h, w). Defaults to 1 for each dimension."); - DMLC_DECLARE_FIELD(pad).set_default(mxnet::TShape(0)) + DMLC_DECLARE_FIELD(pad).set_default(mxnet::TShape(0, -1)) .describe("Zero pad for convolution: (h, w) or (d, h, w). 
Defaults to no padding."); DMLC_DECLARE_FIELD(num_filter).set_range(1, 100000) .describe("Convolution filter(channel) number"); diff --git a/src/operator/pooling_v1-inl.h b/src/operator/pooling_v1-inl.h index 22a166cbb6cc..4241b08a0c5e 100644 --- a/src/operator/pooling_v1-inl.h +++ b/src/operator/pooling_v1-inl.h @@ -55,7 +55,7 @@ struct PoolingV1Param : public dmlc::Parameter { int pooling_convention; bool global_pool; DMLC_DECLARE_PARAMETER(PoolingV1Param) { - DMLC_DECLARE_FIELD(kernel).set_default(mxnet::TShape(0)) + DMLC_DECLARE_FIELD(kernel).set_default(mxnet::TShape(0, -1)) .enforce_nonzero() .describe("pooling kernel size: (y, x) or (d, y, x)"); @@ -73,11 +73,11 @@ struct PoolingV1Param : public dmlc::Parameter { .add_enum("valid", pool_v1_enum::kValid) .describe("Pooling convention to be applied."); - DMLC_DECLARE_FIELD(stride).set_default(mxnet::TShape(0)) + DMLC_DECLARE_FIELD(stride).set_default(mxnet::TShape(0, -1)) .enforce_nonzero() .describe("stride: for pooling (y, x) or (d, y, x)"); - DMLC_DECLARE_FIELD(pad).set_default(mxnet::TShape(0)) + DMLC_DECLARE_FIELD(pad).set_default(mxnet::TShape(0, -1)) .describe("pad for pooling: (y, x) or (d, y, x)"); } }; diff --git a/src/operator/quantization/dequantize-inl.h b/src/operator/quantization/dequantize-inl.h index 88199bc2591d..7c91ad507fd9 100644 --- a/src/operator/quantization/dequantize-inl.h +++ b/src/operator/quantization/dequantize-inl.h @@ -99,7 +99,7 @@ inline bool DequantizeShape(const nnvm::NodeAttrs& attrs, CHECK_EQ(out_attrs->size(), 1U); for (size_t i = 1; i < 3; ++i) { - SHAPE_ASSIGN_CHECK(*in_attrs, i, mxnet::TShape({1})); + SHAPE_ASSIGN_CHECK(*in_attrs, i, mxnet::TShape(1, 1)); } SHAPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0)); diff --git a/src/operator/quantization/mkldnn/mkldnn_requantize-inl.h b/src/operator/quantization/mkldnn/mkldnn_requantize-inl.h index 45713589dd48..ac414c72d51a 100644 --- a/src/operator/quantization/mkldnn/mkldnn_requantize-inl.h +++ b/src/operator/quantization/mkldnn/mkldnn_requantize-inl.h @@ -115,7 +115,7 @@ static void MKLDNNRequantizeForward(const nnvm::NodeAttrs& attrs, const size_t actual_float_size = sizeof(float); const size_t actual_quantized_size = sizeof(SrcDType); const size_t temp_reduce_size = ConfigReduce(s, - inputs[0].shape(), mxnet::TShape({1}), &src_shape, &dst_shape); + inputs[0].shape(), mxnet::TShape(1, 1), &src_shape, &dst_shape); Tensor temp_space = ctx.requested[0].get_space_typed( Shape1(2*actual_float_size+2*actual_quantized_size+temp_reduce_size), s); diff --git a/src/operator/quantization/quantize-inl.h b/src/operator/quantization/quantize-inl.h index 2c267a76a571..7b856579a7b5 100644 --- a/src/operator/quantization/quantize-inl.h +++ b/src/operator/quantization/quantize-inl.h @@ -120,7 +120,7 @@ inline bool QuantizeShape(const nnvm::NodeAttrs& attrs, CHECK_EQ(out_attrs->size(), 3U); for (size_t i = 1; i < 3; ++i) { - SHAPE_ASSIGN_CHECK(*in_attrs, i, mxnet::TShape({1})); + SHAPE_ASSIGN_CHECK(*in_attrs, i, mxnet::TShape(1, 1)); } SHAPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0)); diff --git a/src/operator/quantization/quantize_v2-inl.h b/src/operator/quantization/quantize_v2-inl.h index 02ace6c39fac..9ebb645e1ba6 100644 --- a/src/operator/quantization/quantize_v2-inl.h +++ b/src/operator/quantization/quantize_v2-inl.h @@ -175,7 +175,7 @@ void QuantizeV2Compute(const nnvm::NodeAttrs &attrs, const OpContext &ctx, mxnet::TShape src_shape, dst_shape; const size_t actual_float_size = sizeof(float); const size_t temp_reduce_size = ConfigReduce( - s, 
inputs[0].shape_, mxnet::TShape({1}), &src_shape, &dst_shape); + s, inputs[0].shape_, mxnet::TShape(1, 1), &src_shape, &dst_shape); Tensor temp_space = ctx.requested[0].get_space_typed( Shape1(2 * actual_float_size + temp_reduce_size), s); const int dev_id = ctx.run_ctx.ctx.dev_id; diff --git a/src/operator/quantization/quantized_conv.cc b/src/operator/quantization/quantized_conv.cc index 1a801ee50744..aa3f5ce1ad61 100644 --- a/src/operator/quantization/quantized_conv.cc +++ b/src/operator/quantization/quantized_conv.cc @@ -78,8 +78,8 @@ bool QuantizedConvShape(const nnvm::NodeAttrs& attrs, oshape[W] = (AddPad(dshape[W], param.pad[1]) - wshape[W]) / param.stride[1] + 1; SHAPE_ASSIGN_CHECK(*out_shape, 0, oshape); - SHAPE_ASSIGN_CHECK(*out_shape, 1, mxnet::TShape({1})); - SHAPE_ASSIGN_CHECK(*out_shape, 2, mxnet::TShape({1})); + SHAPE_ASSIGN_CHECK(*out_shape, 1, mxnet::TShape(1, 1)); + SHAPE_ASSIGN_CHECK(*out_shape, 2, mxnet::TShape(1, 1)); return true; } diff --git a/src/operator/quantization/quantized_fully_connected.cc b/src/operator/quantization/quantized_fully_connected.cc index cc4365f818d2..e42ea3020352 100644 --- a/src/operator/quantization/quantized_fully_connected.cc +++ b/src/operator/quantization/quantized_fully_connected.cc @@ -75,8 +75,8 @@ bool QuantizedFullyConnectedShape(const nnvm::NodeAttrs& attrs, } else { SHAPE_ASSIGN_CHECK(*out_shape, 0, Shape2(dshape[0], param.num_hidden)); } - SHAPE_ASSIGN_CHECK(*out_shape, 1, mxnet::TShape({1})); - SHAPE_ASSIGN_CHECK(*out_shape, 2, mxnet::TShape({1})); + SHAPE_ASSIGN_CHECK(*out_shape, 1, mxnet::TShape(1, 1)); + SHAPE_ASSIGN_CHECK(*out_shape, 2, mxnet::TShape(1, 1)); return true; } diff --git a/src/operator/quantization/requantize-inl.h b/src/operator/quantization/requantize-inl.h index 21d58d4607eb..9106c7fe4716 100644 --- a/src/operator/quantization/requantize-inl.h +++ b/src/operator/quantization/requantize-inl.h @@ -111,7 +111,7 @@ void RequantizeForward(const nnvm::NodeAttrs& attrs, const size_t actual_float_size = sizeof(float); const size_t actual_quantized_size = sizeof(SrcDType); const size_t temp_reduce_size = ConfigReduce( - s, inputs[0].shape_, mxnet::TShape({1}), &src_shape, &dst_shape); + s, inputs[0].shape_, mxnet::TShape(1, 1), &src_shape, &dst_shape); Tensor temp_space = ctx.requested[0].get_space_typed( Shape1(2*actual_float_size+2*actual_quantized_size+temp_reduce_size), s); diff --git a/src/operator/tensor/broadcast_reduce_op.h b/src/operator/tensor/broadcast_reduce_op.h index 2d432c9563d6..555fa5ec5c4c 100644 --- a/src/operator/tensor/broadcast_reduce_op.h +++ b/src/operator/tensor/broadcast_reduce_op.h @@ -129,9 +129,9 @@ struct BroadcastAxesParam : public dmlc::Parameter { mxnet::TShape axis; mxnet::TShape size; DMLC_DECLARE_PARAMETER(BroadcastAxesParam) { - DMLC_DECLARE_FIELD(axis).set_default(mxnet::TShape(0)) + DMLC_DECLARE_FIELD(axis).set_default(mxnet::TShape(0, -1)) .describe("The axes to perform the broadcasting."); - DMLC_DECLARE_FIELD(size).set_default(mxnet::TShape(0)) + DMLC_DECLARE_FIELD(size).set_default(mxnet::TShape(0, -1)) .describe("Target sizes of the broadcasting axes."); } }; @@ -139,7 +139,7 @@ struct BroadcastAxesParam : public dmlc::Parameter { struct BroadcastToParam : public dmlc::Parameter { mxnet::TShape shape; DMLC_DECLARE_PARAMETER(BroadcastToParam) { - DMLC_DECLARE_FIELD(shape).set_default(mxnet::TShape(0)) + DMLC_DECLARE_FIELD(shape).set_default(mxnet::TShape(0, -1)) .describe("The shape of the desired array." " We can set the dim to zero if it's same as the original." 
" E.g `A = broadcast_to(B, shape=(10, 0, 0))` " diff --git a/src/operator/tensor/histogram-inl.h b/src/operator/tensor/histogram-inl.h index 9cf9c490bba2..7194445d7b52 100644 --- a/src/operator/tensor/histogram-inl.h +++ b/src/operator/tensor/histogram-inl.h @@ -86,9 +86,9 @@ inline bool HistogramOpShape(const nnvm::NodeAttrs& attrs, if (has_cnt) { // if cnt is specified, the output histogram has shape (cnt,) // while output bins has shape (cnt+1,) - const int bin_cnt = param.bin_cnt.value(); - SHAPE_ASSIGN_CHECK(*out_attrs, 0, mxnet::TShape({bin_cnt})); - SHAPE_ASSIGN_CHECK(*out_attrs, 1, mxnet::TShape({bin_cnt + 1})); + const dim_t bin_cnt = param.bin_cnt.value(); + SHAPE_ASSIGN_CHECK(*out_attrs, 0, mxnet::TShape(1, bin_cnt)); + SHAPE_ASSIGN_CHECK(*out_attrs, 1, mxnet::TShape(1, bin_cnt + 1)); } else { // if cnt is not specified, the output histogram has shape (bins.Size() - 1) // while output bins has same shape as input bins @@ -97,7 +97,7 @@ inline bool HistogramOpShape(const nnvm::NodeAttrs& attrs, CHECK_EQ(oshape.ndim(), 1U) << "bins argument should be an 1D vector"; CHECK_GE(oshape.Size(), 2U) << "number of bounds should be >= 2"; - SHAPE_ASSIGN_CHECK(*out_attrs, 0, mxnet::TShape({(oshape[0] - 1)})); + SHAPE_ASSIGN_CHECK(*out_attrs, 0, mxnet::TShape(1, oshape[0] - 1)); SHAPE_ASSIGN_CHECK(*out_attrs, 1, in_attrs->at(1)); } diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h index 44dbd78f0c0b..12b82fc3ecfe 100644 --- a/src/operator/tensor/matrix_op-inl.h +++ b/src/operator/tensor/matrix_op-inl.h @@ -59,7 +59,7 @@ struct ReshapeParam : public dmlc::Parameter { .set_default(false) .describe("If true then the special values are inferred from right to left"); DMLC_DECLARE_FIELD(target_shape) - .set_default(mxnet::TShape(0)) + .set_default(mxnet::TShape(0, -1)) .describe("(Deprecated! Use ``shape`` instead.) " "Target new shape. One and only one dim can be 0, " "in which case it will be inferred from the rest of dims"); @@ -241,7 +241,7 @@ inline bool FlattenShape(const nnvm::NodeAttrs& attrs, struct TransposeParam : public dmlc::Parameter { mxnet::TShape axes; DMLC_DECLARE_PARAMETER(TransposeParam) { - DMLC_DECLARE_FIELD(axes).set_default(mxnet::TShape(0)) + DMLC_DECLARE_FIELD(axes).set_default(mxnet::TShape(0, -1)) .describe("Target axis order. 
By default the axes will be inverted."); } }; diff --git a/tests/cpp/misc/serialization.cc b/tests/cpp/misc/serialization.cc index 77014238c2fa..2509a43c27ee 100644 --- a/tests/cpp/misc/serialization.cc +++ b/tests/cpp/misc/serialization.cc @@ -48,7 +48,7 @@ TEST(SerializerTest, OutputMapCorrect) { std::map > output_map; output_map.emplace("output_0", std::make_tuple(1, mxnet::TShape({23, 12, 63, 432}), 0, 1)); output_map.emplace("another_output", std::make_tuple(2, mxnet::TShape({23, 123}), 14, -23)); - output_map.emplace("last_output", std::make_tuple(0, mxnet::TShape({0}), -1, 0)); + output_map.emplace("last_output", std::make_tuple(0, mxnet::TShape(1, 0), -1, 0)); std::string serialized_data; common::Serialize(output_map, &serialized_data); std::map > deserialized_output_map; From 9af65b006bffd56c199dfdf6e923cb745d3b25bc Mon Sep 17 00:00:00 2001 From: reminisce Date: Fri, 5 Apr 2019 23:12:43 -0700 Subject: [PATCH 11/18] Fix mkldnn build failure --- src/ndarray/ndarray.cc | 4 ++-- tests/cpp/include/test_util.h | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc index 604000028bf1..f5aac36a48eb 100644 --- a/src/ndarray/ndarray.cc +++ b/src/ndarray/ndarray.cc @@ -549,7 +549,7 @@ const mkldnn::memory *NDArray::GetMKLDNNDataReorder( // If they have different shapes, we need to reshape the array first. // Since this method will only be used inside an operator, we can call // MKLDNNDataReshape to reshape an array. - mxnet::TShape required_shape(desc2.data.ndims); + mxnet::TShape required_shape(desc2.data.ndims, -1); for (int i = 0; i < desc2.data.ndims; i++) required_shape[i] = desc2.data.dims[i]; NDArray reshaped = MKLDNNDataReshape(required_shape); @@ -575,7 +575,7 @@ NDArray NDArray::Reorder2Default() const { // create new ndarray from mkldnn layout mkldnn::memory::desc from_desc = ptr_->mkl_mem_->GetPrimitiveDesc().desc(); - mxnet::TShape tshape(from_desc.data.ndims); + mxnet::TShape tshape(from_desc.data.ndims, -1); for (int i = 0; i < from_desc.data.ndims; i++) tshape[i] = from_desc.data.dims[i]; NDArray ret(tshape, ctx(), false, dtype()); mkldnn::memory::primitive_desc def_pd = ptr_->mkl_mem_->GetPrimitiveDesc(format); diff --git a/tests/cpp/include/test_util.h b/tests/cpp/include/test_util.h index e0caddbcd027..b0114e1721ef 100644 --- a/tests/cpp/include/test_util.h +++ b/tests/cpp/include/test_util.h @@ -353,14 +353,14 @@ inline StreamType& print_blob_(const RunContext& ctx, if (dim == 1) { // probably a 1d tensor (mshadow::Tensor is deprecated) - TBlob changed(blob.dptr(), mxnet::TShape(3), blob.dev_mask(), blob.dev_id()); + TBlob changed(blob.dptr(), mxnet::TShape(3, -1), blob.dev_mask(), blob.dev_id()); changed.shape_[0] = 1; changed.shape_[1] = 1; changed.shape_[2] = blob.shape_[0]; return print_blob_(ctx, &os, changed, false, false, add_endl); } else if (dim == 2) { // probably a 2d tensor (mshadow::Tensor is deprecated) - TBlob changed(blob.dptr(), mxnet::TShape(4), blob.dev_mask(), blob.dev_id()); + TBlob changed(blob.dptr(), mxnet::TShape(4, -1), blob.dev_mask(), blob.dev_id()); changed.shape_[0] = 1; changed.shape_[1] = 1; changed.shape_[2] = blob.shape_[0]; From bdcfd1a52cbfb24bf8dafdded9044c34a720cc32 Mon Sep 17 00:00:00 2001 From: reminisce Date: Fri, 5 Apr 2019 23:33:16 -0700 Subject: [PATCH 12/18] Fix build failure in gpu and cpp test --- src/operator/nn/cudnn/cudnn_convolution-inl.h | 2 +- src/operator/nn/cudnn/cudnn_deconvolution-inl.h | 2 +- tests/cpp/operator/batchnorm_test.cc | 4 ++-- 3 files 
changed, 4 insertions(+), 4 deletions(-) diff --git a/src/operator/nn/cudnn/cudnn_convolution-inl.h b/src/operator/nn/cudnn/cudnn_convolution-inl.h index 44d1c3c36e99..679e0cd1057b 100644 --- a/src/operator/nn/cudnn/cudnn_convolution-inl.h +++ b/src/operator/nn/cudnn/cudnn_convolution-inl.h @@ -1016,7 +1016,7 @@ class CuDNNConvolutionOp { template inline Shape Strides(const mxnet::TShape &s) { int ndim = s.ndim(); - mxnet::TShape strides(ndim); + mxnet::TShape strides(ndim, -1); for (int i = 0; i != ndim; ++i) strides[i] = s.ProdShape(i+1, ndim); return strides.get(); diff --git a/src/operator/nn/cudnn/cudnn_deconvolution-inl.h b/src/operator/nn/cudnn/cudnn_deconvolution-inl.h index f652dd85bd41..adb6caf1c028 100644 --- a/src/operator/nn/cudnn/cudnn_deconvolution-inl.h +++ b/src/operator/nn/cudnn/cudnn_deconvolution-inl.h @@ -934,7 +934,7 @@ class CuDNNDeconvolutionOp { template inline Shape Strides(const mxnet::TShape &s) { int ndim = s.ndim(); - mxnet::TShape strides(ndim); + mxnet::TShape strides(ndim, -1); for (int i = 0; i != ndim; ++i) strides[i] = s.ProdShape(i+1, ndim); return strides.get(); diff --git a/tests/cpp/operator/batchnorm_test.cc b/tests/cpp/operator/batchnorm_test.cc index d74493a0f7fb..ed0e70b831f1 100644 --- a/tests/cpp/operator/batchnorm_test.cc +++ b/tests/cpp/operator/batchnorm_test.cc @@ -1266,7 +1266,7 @@ static void testSaveAndLoad(const std::vector& dims, ChannelAxisTestData data; data.channel_data_ = inputChannelData; - mxnet::TShape shape(dims.size()); + mxnet::TShape shape(dims.size(), -1); for (size_t i = 0, n = dims.size(); i < n; ++i) { shape[i] = index_t(dims[i]); } @@ -1322,7 +1322,7 @@ static mxnet::TShape MakeShape(const std::vector& shape, } CHECK_LT(channelAxis, shape.size() + 1); const index_t dim = index_t(shape.size()) + 1; - mxnet::TShape newShape(dim); + mxnet::TShape newShape(dim, -1); for (size_t x = 0; x < static_cast(channelAxis); ++x) { newShape[x] = index_t(shape[x]); } From ff1ac1e2956f59d505e889b5a6c6d83ba7680de0 Mon Sep 17 00:00:00 2001 From: reminisce Date: Sat, 6 Apr 2019 00:31:50 -0700 Subject: [PATCH 13/18] Fix gpu cpp test build with mkldnn --- tests/cpp/include/test_mkldnn.h | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/cpp/include/test_mkldnn.h b/tests/cpp/include/test_mkldnn.h index a379dab7bf90..f1682772a14a 100644 --- a/tests/cpp/include/test_mkldnn.h +++ b/tests/cpp/include/test_mkldnn.h @@ -49,7 +49,7 @@ inline static mkldnn::memory::primitive_desc GetMemPD(const mxnet::TShape s, int inline static mkldnn::memory::primitive_desc GetExpandedMemPD( mkldnn::memory::primitive_desc pd, float scale, int dim = 0) { CHECK(dim < pd.desc().data.ndims) << "dimension cannot be larger than total dimensions of input"; - mxnet::TShape s(pd.desc().data.ndims); + mxnet::TShape s(pd.desc().data.ndims, -1); for (size_t i = 0; i < pd.desc().data.ndims; i++) s[i] = pd.desc().data.dims[i]; s[dim] = static_cast(s[dim] * scale); @@ -165,7 +165,7 @@ inline static TestArrayShapes GetTestArrayShapes(bool spatial_data_format = fals std::vector pds; { // 1D - mxnet::TShape s(1); + mxnet::TShape s(1, -1); s[0] = 279936; shapes.push_back(s); pds.push_back(GetMemPD(s, dtype, mkldnn::memory::format::x)); @@ -175,7 +175,7 @@ inline static TestArrayShapes GetTestArrayShapes(bool spatial_data_format = fals } { // 2D - mxnet::TShape s(2); + mxnet::TShape s(2, -1); s[0] = 96; s[1] = 2916; shapes.push_back(s); @@ -187,12 +187,12 @@ inline static TestArrayShapes GetTestArrayShapes(bool spatial_data_format = fals } { // 
4D - mxnet::TShape s1(4); + mxnet::TShape s1(4, -1); s1[0] = 10; s1[1] = 96; s1[2] = 54; s1[3] = 54; shapes.push_back(s1); pds.push_back(GetMemPD(s1, dtype, mkldnn::memory::format::nchw)); - mxnet::TShape s2(4); + mxnet::TShape s2(4, -1); s2[0] = 96; s2[1] = 3; s2[2] = 11; s2[3] = 11; shapes.push_back(s2); pds.push_back(GetMemPD(s2, dtype, mkldnn::memory::format::oihw)); @@ -204,7 +204,7 @@ inline static TestArrayShapes GetTestArrayShapes(bool spatial_data_format = fals } { // 5D - mxnet::TShape s(5); + mxnet::TShape s(5, -1); s[0] = 96; s[1] = 1; s[2] = 3; s[3] = 11; s[4] = 11; shapes.push_back(s); pds.push_back(GetMemPD(s, dtype, mkldnn::memory::format::goihw)); @@ -259,7 +259,7 @@ enum ArrayTypes { inline NDArray CreateKernelNDArray(mxnet::TShape kernel, int num_filters, mxnet::TShape input, bool is_deconv = false) { CHECK_EQ(kernel.ndim(), 2) << "mkldnn only supports 2d filters on 4d inputs"; - mxnet::TShape target_shape(4); + mxnet::TShape target_shape(4, -1); target_shape[0] = is_deconv ? input[1] : num_filters; target_shape[1] = is_deconv ? num_filters : input[1]; target_shape[2] = kernel[0]; @@ -470,7 +470,7 @@ inline std::vector GetTestOutputArrays( in_arrs.emplace_back(arr0.Slice(1, shape[0] + 1), "Reshaped NDArray"); } - mxnet::TShape s(1); + mxnet::TShape s(1, -1); if (types & ArrayTypes::NormalReused) { // Type 5. // Get a reused version. @@ -528,7 +528,7 @@ inline std::vector GetTestOutputArrays( // Type 8, 9. // Get a reused version. - mxnet::TShape s(1); + mxnet::TShape s(1, -1); s[0] = shape.Size(); NDArray arr = NDArray(s, Context()); arr = arr.AsArray(shape, arr.dtype()); From 719e7dc010a6bb36bf4cf7b74d2adbf45bfa9870 Mon Sep 17 00:00:00 2001 From: reminisce Date: Sat, 6 Apr 2019 13:14:00 -0700 Subject: [PATCH 14/18] Fix mkldnn cpp test --- tests/cpp/operator/mkldnn_operator_test.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/cpp/operator/mkldnn_operator_test.cc b/tests/cpp/operator/mkldnn_operator_test.cc index 559ab5da0ccc..961785dcfc87 100644 --- a/tests/cpp/operator/mkldnn_operator_test.cc +++ b/tests/cpp/operator/mkldnn_operator_test.cc @@ -916,13 +916,13 @@ void TestFullyConnectedOp(const OpAttrs &forward_attrs, const OpAttrs &backwards if (in_shape.ndim() < 2) continue; - mxnet::TShape wt_shape(2); + mxnet::TShape wt_shape(2, -1); wt_shape[0] = num_hid; wt_shape[1] = GetFCWeightDim2(in_shape); NDArray weights(wt_shape, Context()); InitDefaultArray(&weights, false); - mxnet::TShape bias_shape(1); + mxnet::TShape bias_shape(1, -1); bias_shape[0] = num_hid; NDArray bias(bias_shape, Context()); InitDefaultArray(&bias, false); @@ -931,7 +931,7 @@ void TestFullyConnectedOp(const OpAttrs &forward_attrs, const OpAttrs &backwards inputs[1] = &weights; inputs[2] = &bias; - mxnet::TShape out_shape(2); + mxnet::TShape out_shape(2, -1); out_shape[0] = in_shape[0]; out_shape[1] = num_hid; From d621688119abd94fc28e4030d312db5b29876d9d Mon Sep 17 00:00:00 2001 From: reminisce Date: Sat, 6 Apr 2019 23:16:27 -0700 Subject: [PATCH 15/18] Fix concatenating zero-size tensors --- include/mxnet/tuple.h | 2 +- src/operator/channel_op_common.h | 4 ++++ src/operator/tensor/init_op.h | 2 ++ tests/python/unittest/test_operator.py | 9 +++++++++ 4 files changed, 16 insertions(+), 1 deletion(-) diff --git a/include/mxnet/tuple.h b/include/mxnet/tuple.h index de18eea3e007..c5c0ccd548df 100644 --- a/include/mxnet/tuple.h +++ b/include/mxnet/tuple.h @@ -627,7 +627,7 @@ inline bool ndim_is_known(const TShape& x) { } /*! brief check if a shape's dim size is known. 
*/ -inline bool dim_size_is_known(const int dim_size) { +inline bool dim_size_is_known(const dim_t dim_size) { CHECK_GE(dim_size, -1) << "shape dim size must be >= -1, while received " << dim_size; return dim_size != -1; } diff --git a/src/operator/channel_op_common.h b/src/operator/channel_op_common.h index 1afc13ad2594..43f689d2defa 100644 --- a/src/operator/channel_op_common.h +++ b/src/operator/channel_op_common.h @@ -45,6 +45,8 @@ inline void concatenate_helper(const std::vector(out, begin, end), req, input[i]); begin = end; @@ -80,6 +82,8 @@ void split_helper(const mshadow::Tensor &input, size_t size = out.size(); index_t begin = 0; for (size_t i = 0; i < size; ++i) { + // If out[i] is a zero-size tensor, do nothing. + if (out[i].shape_.Size() == 0) continue; index_t end = begin + out[i].size(cdim); Assign(out[i], req[i], slice(input, begin, end)); begin = end; diff --git a/src/operator/tensor/init_op.h b/src/operator/tensor/init_op.h index bcad602c95c0..b2e3830064ae 100644 --- a/src/operator/tensor/init_op.h +++ b/src/operator/tensor/init_op.h @@ -278,6 +278,8 @@ inline bool InitStorageType(const nnvm::NodeAttrs& attrs, */ template void Fill(mshadow::Stream *s, const TBlob& b, const OpReqType req, ValueType val) { + // If b is a zero-size tensor, do nothing. + if (b.Size() == 0) return; if (req != kNullOp) { const size_t size = b.Size(); if (val == 0) { diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 3d2291870a3c..236dbdbb2639 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -7892,6 +7892,7 @@ def test_image_normalize(): check_numeric_gradient(img_norm_sym, [data_in_4d], atol=0.001) +@with_seed() def test_scalar_tensor_creation(): assertRaises(MXNetError, mx.nd.zeros, shape=()) assertRaises(MXNetError, mx.nd.ones, shape=()) @@ -7901,6 +7902,7 @@ def test_scalar_tensor_creation(): assert same(data_mx.asnumpy(), data_np) +@with_seed() def test_zero_size_tensor_creation(): assertRaises(MXNetError, mx.nd.zeros, shape=(0, 1, 3, 0)) assertRaises(MXNetError, mx.nd.ones, shape=(0, 1, 3, 0)) @@ -7910,6 +7912,7 @@ def test_zero_size_tensor_creation(): assert same(data_mx.asnumpy(), data_np) +@with_seed() def test_concat_with_zero_size_tensor(): with mx.enable_np_comp(): data1 = mx.nd.ones((0, 8, 12)) @@ -7918,6 +7921,12 @@ def test_concat_with_zero_size_tensor(): ret = mx.nd.Concat(data1, data2, data3, dim=0) assert ret.shape == (3, 8, 12) + data1 = mx.nd.ones((0, 3, 10)) + data2 = mx.nd.ones((0, 4, 10)) + data3 = mx.nd.ones((0, 5, 10)) + ret = mx.nd.Concat(data1, data2, data3, dim=1) + assert ret.shape == (0, 12, 10) + if __name__ == '__main__': import nose From b4cc4711fe03c238b49347a13b71c3130ed7ebae Mon Sep 17 00:00:00 2001 From: reminisce Date: Sun, 7 Apr 2019 13:23:17 -0700 Subject: [PATCH 16/18] Avoid letting mkldnn handle zero-size tensors in concat --- src/operator/nn/concat.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/operator/nn/concat.cc b/src/operator/nn/concat.cc index 411773b41f94..8fb229889332 100644 --- a/src/operator/nn/concat.cc +++ b/src/operator/nn/concat.cc @@ -234,6 +234,8 @@ bool SupportMKLDNNConcat(const std::vector &arrs) { for (auto &arr : arrs) { if (arr.IsView()) return false; if (arr.dtype() != mshadow::kFloat32) return false; + // DO not support zero-size tensors. 
+ if (arr.shape().Size() == 0) return false; int ndim = arr.shape().ndim(); const int mkldnn_ndims = arr.GetMKLDNNData()->get_primitive_desc().desc().data.ndims; if (!(ndim == 2 || ndim == 4) || ndim != mkldnn_ndims) return false; From 8663de3c5acae6230c2093984ee2d1e35753de7c Mon Sep 17 00:00:00 2001 From: reminisce Date: Sun, 7 Apr 2019 14:42:56 -0700 Subject: [PATCH 17/18] Fix quantized_concat infer shape --- src/operator/quantization/quantized_concat.cc | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/operator/quantization/quantized_concat.cc b/src/operator/quantization/quantized_concat.cc index 2cc2ec9d0374..d6aeb41da1f8 100644 --- a/src/operator/quantization/quantized_concat.cc +++ b/src/operator/quantization/quantized_concat.cc @@ -35,23 +35,23 @@ static bool ConcatShape(const nnvm::NodeAttrs& attrs, mxnet::ShapeVector* in_sha CHECK_EQ(out_shape->size(), 3U); mxnet::TShape dshape; index_t size = 0; - bool has_zero = false; + bool has_unknown_dim_size = false; int axis = -1; for (int i = 0; i < param_.num_args; ++i) { mxnet::TShape tmp = (*in_shape)[i]; - if (tmp.ndim()) { + if (tmp.ndim() > 0) { axis = CheckAxis(param_.dim, tmp.ndim()); - has_zero = tmp[axis] == 0 || has_zero; + has_unknown_dim_size = !mxnet::dim_size_is_known(tmp, axis) || has_unknown_dim_size; size += tmp[axis]; - tmp[axis] = 0; + tmp[axis] = -1; shape_assign(&dshape, tmp); } } mxnet::TShape tmp = (*out_shape)[0]; - if (tmp.ndim()) { + if (tmp.ndim() > 0) { axis = CheckAxis(param_.dim, tmp.ndim()); - tmp[axis] = 0; + tmp[axis] = -1; shape_assign(&dshape, tmp); } @@ -62,7 +62,7 @@ static bool ConcatShape(const nnvm::NodeAttrs& attrs, mxnet::ShapeVector* in_sha << "Incompatible input shape: expected " << dshape << ", got " << (*in_shape)[i]; } - if (!has_zero) dshape[axis] = size; + if (!has_unknown_dim_size) dshape[axis] = size; CHECK(shape_assign(&(*out_shape)[0], dshape)) << "Incompatible output shape: expected " << dshape << ", got " << (*out_shape)[0]; @@ -71,7 +71,7 @@ static bool ConcatShape(const nnvm::NodeAttrs& attrs, mxnet::ShapeVector* in_sha } SHAPE_ASSIGN_CHECK(*out_shape, 1, mxnet::TShape{1}); SHAPE_ASSIGN_CHECK(*out_shape, 2, mxnet::TShape{1}); - return dshape.Size() != 0; + return shape_is_known(dshape); } static bool ConcatType(const nnvm::NodeAttrs& attrs, std::vector* in_type, From 9e42544f43cc8d24979d6cfcd1b0298ce5f72f36 Mon Sep 17 00:00:00 2001 From: reminisce Date: Sun, 7 Apr 2019 14:54:04 -0700 Subject: [PATCH 18/18] Try to fix perl c api --- perl-package/AI-MXNetCAPI/mxnet.i | 84 +++++++++++++++---------------- 1 file changed, 42 insertions(+), 42 deletions(-) diff --git a/perl-package/AI-MXNetCAPI/mxnet.i b/perl-package/AI-MXNetCAPI/mxnet.i index 0e6a05ea9695..0ecf5b3a9cc3 100644 --- a/perl-package/AI-MXNetCAPI/mxnet.i +++ b/perl-package/AI-MXNetCAPI/mxnet.i @@ -641,8 +641,8 @@ int MXNDArrayReshape64(NDArrayHandle handle, * \return 0 when success, -1 when failure happens */ int MXNDArrayGetShape(NDArrayHandle handle, - mx_uint *out_dim, - const mx_uint **out_pdata); + int *out_dim, + const int **out_pdata); /*! 
* \brief get the content of the data in NDArray * \param handle the handle to the ndarray @@ -1290,20 +1290,20 @@ int MXSymbolGrad(SymbolHandle sym, * \return 0 when success, -1 when failure happens */ int MXSymbolInferShape(SymbolHandle sym, - mx_uint num_args, - const char** in, - const mx_uint *in, - const mx_uint *in, - mx_uint *in_shape_size, - const mx_uint **in_shape_ndim, - const mx_uint ***in_shape_data, - mx_uint *out_shape_size, - const mx_uint **out_shape_ndim, - const mx_uint ***out_shape_data, - mx_uint *aux_shape_size, - const mx_uint **aux_shape_ndim, - const mx_uint ***aux_shape_data, - int *out); + mx_uint num_args, + const char** in, + const mx_uint *in, + const int *in, + mx_uint *in_shape_size, + const int **in_shape_ndim, + const int ***in_shape_data, + mx_uint *out_shape_size, + const int **out_shape_ndim, + const int ***out_shape_data, + mx_uint *aux_shape_size, + const int **aux_shape_ndim, + const int ***aux_shape_data, + int *out); /*! * \brief partially infer shape of unknown input shapes given the known one. * @@ -1332,16 +1332,16 @@ int MXSymbolInferShapePartial(SymbolHandle sym, mx_uint num_args, const char** in, const mx_uint *in, - const mx_uint *in, + const int *in, mx_uint *in_shape_size, - const mx_uint **in_shape_ndim, - const mx_uint ***in_shape_data, + const int **in_shape_ndim, + const int ***in_shape_data, mx_uint *out_shape_size, - const mx_uint **out_shape_ndim, - const mx_uint ***out_shape_data, + const int **out_shape_ndim, + const int ***out_shape_data, mx_uint *aux_shape_size, - const mx_uint **aux_shape_ndim, - const mx_uint ***aux_shape_data, + const int **aux_shape_ndim, + const int ***aux_shape_data, int *out); /*! @@ -1547,7 +1547,7 @@ int MXExecutorSimpleBind(SymbolHandle symbol_handle, const char** in, // provided_grad_req_types, const mx_uint num_provided_arg_shapes, const char** in, // provided_arg_shape_names, - const mx_uint* in, // provided_arg_shape_data, + const int* in, // provided_arg_shape_data, const mx_uint* in, // provided_arg_shape_idx, const mx_uint num_provided_arg_dtypes, const char** in, // provided_arg_dtype_names, @@ -1593,24 +1593,24 @@ int MXExecutorSimpleBind(SymbolHandle symbol_handle, * \return a new executor */ int MXExecutorReshape(int partial_shaping, - int allow_up_sizing, - int dev_type, - int dev_id, - mx_uint num_map_keys, - const char** in, - const int* in, - const int* in, - const mx_uint num_provided_arg_shapes, - const char** in, - const mx_uint* in, - const mx_uint* in, - mx_uint* couple_out_size, - NDArrayHandle** out_first_array, - NDArrayHandle** out_second_array, - mx_uint* out_size, - NDArrayHandle** out_array, - ExecutorHandle shared_exec, - ExecutorHandle *out); + int allow_up_sizing, + int dev_type, + int dev_id, + mx_uint num_map_keys, + const char** in, + const int* in, + const int* in, + const mx_uint num_provided_arg_shapes, + const char** in, + const int* in, + const mx_uint* in, + mx_uint* couple_out_size, + NDArrayHandle** out_first_array, + NDArrayHandle** out_second_array, + mx_uint* out_size, + NDArrayHandle** out_array, + ExecutorHandle shared_exec, + ExecutorHandle *out); /*! * \brief set a call back to notify the completion of operation