This repository was archived by the owner on Nov 17, 2023. It is now read-only.
Merged
13 changes: 9 additions & 4 deletions include/mxnet/tuple.h
@@ -390,7 +390,7 @@ class TShape : public Tuple<dim_t> {
* \param ndim the number of dimensions
* \param value the dimension size for all dims
*/
-  inline TShape(int ndim, int value = -1) {  // NOLINT(*)
+  inline TShape(const int ndim, const dim_t value) {  // NOLINT(*)
this->SetDim(ndim);
if (ndim > 0) {
std::fill_n(begin(), ndim, value);
@@ -422,12 +422,17 @@ class TShape : public Tuple<dim_t> {
this->swap(s);
}
/*!
-   * \brief construct the Tuple from content of iterator
+   * \brief construct the Tuple from the content of an iterator range.
+   *        The template arguments are constrained to random access iterator types;
+   *        this is necessary to disambiguate from the constructor TShape(const int, const dim_t).
   * \param begin the beginning of the iterator range
   * \param end the end of the iterator range
* \tparam RandomAccessIterator iterator type
*/
-  template<typename RandomAccessIterator>
+  template<typename RandomAccessIterator,
+           typename std::enable_if<
+             std::is_same<typename std::iterator_traits<RandomAccessIterator>::iterator_category,
+                          std::random_access_iterator_tag>::value, int>::type = 0>
inline TShape(RandomAccessIterator begin,
RandomAccessIterator end) {
this->assign(begin, end);
@@ -622,7 +627,7 @@ inline bool ndim_is_known(const TShape& x) {
}

/*! \brief check if a shape's dim size is known. */
-inline bool dim_size_is_known(const int dim_size) {
+inline bool dim_size_is_known(const dim_t dim_size) {
CHECK_GE(dim_size, -1) << "shape dim size must be >= -1, while received " << dim_size;
return dim_size != -1;
}
84 changes: 42 additions & 42 deletions perl-package/AI-MXNetCAPI/mxnet.i
@@ -641,8 +641,8 @@ int MXNDArrayReshape64(NDArrayHandle handle,
* \return 0 on success, -1 on failure
*/
int MXNDArrayGetShape(NDArrayHandle handle,
-                      mx_uint *out_dim,
-                      const mx_uint **out_pdata);
+                      int *out_dim,
+                      const int **out_pdata);
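
At the Python frontend this declaration is consumed through ctypes. A minimal sketch of reading a shape through the new int-based signature (the helper name read_shape is ours, and we assume handle is a valid NDArrayHandle; per the new convention in tuple.h, an ndim of -1 would mean the shape is fully unknown):

    import ctypes
    from mxnet.base import _LIB, check_call

    def read_shape(handle):
        # ndim and the dim buffer are plain C ints after this change
        # (previously mx_uint), so -1 can round-trip as "unknown".
        ndim = ctypes.c_int()
        pdata = ctypes.POINTER(ctypes.c_int)()
        check_call(_LIB.MXNDArrayGetShape(handle, ctypes.byref(ndim),
                                          ctypes.byref(pdata)))
        if ndim.value < 0:
            return None  # fully unknown shape (assumed meaning of ndim == -1)
        return tuple(pdata[:ndim.value])
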
/*!
* \brief get the content of the data in NDArray
* \param handle the handle to the ndarray
@@ -1290,20 +1290,20 @@ int MXSymbolGrad(SymbolHandle sym,
* \return 0 on success, -1 on failure
*/
int MXSymbolInferShape(SymbolHandle sym,
-                      mx_uint num_args,
-                      const char** in,
-                      const mx_uint *in,
-                      const mx_uint *in,
-                      mx_uint *in_shape_size,
-                      const mx_uint **in_shape_ndim,
-                      const mx_uint ***in_shape_data,
-                      mx_uint *out_shape_size,
-                      const mx_uint **out_shape_ndim,
-                      const mx_uint ***out_shape_data,
-                      mx_uint *aux_shape_size,
-                      const mx_uint **aux_shape_ndim,
-                      const mx_uint ***aux_shape_data,
-                      int *out);
+                      mx_uint num_args,
+                      const char** in,
+                      const mx_uint *in,
+                      const int *in,
+                      mx_uint *in_shape_size,
+                      const int **in_shape_ndim,
+                      const int ***in_shape_data,
+                      mx_uint *out_shape_size,
+                      const int **out_shape_ndim,
+                      const int ***out_shape_data,
+                      mx_uint *aux_shape_size,
+                      const int **aux_shape_ndim,
+                      const int ***aux_shape_data,
+                      int *out);
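
Moving the ndim and shape-data arrays from mx_uint to int is what lets -1 travel through shape inference as "unknown" instead of wrapping around as an unsigned value. A minimal sketch of the observable effect at the Python level (assuming NumPy compatibility is enabled; see enable_np_comp in python/mxnet/base.py below):

    import mxnet as mx

    with mx.enable_np_comp():
        # -1 marks an unknown dimension and survives partial inference.
        data = mx.sym.var("data", shape=(-1, 2, 3))
        arg_shapes, out_shapes, _ = mx.sym.sin(data).infer_shape_partial()
        assert arg_shapes[0] == (-1, 2, 3)
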
/*!
* \brief partially infer the shapes of unknown inputs given the known ones.
*
@@ -1332,16 +1332,16 @@ int MXSymbolInferShapePartial(SymbolHandle sym,
mx_uint num_args,
const char** in,
const mx_uint *in,
-                             const mx_uint *in,
+                             const int *in,
                              mx_uint *in_shape_size,
-                             const mx_uint **in_shape_ndim,
-                             const mx_uint ***in_shape_data,
+                             const int **in_shape_ndim,
+                             const int ***in_shape_data,
                              mx_uint *out_shape_size,
-                             const mx_uint **out_shape_ndim,
-                             const mx_uint ***out_shape_data,
+                             const int **out_shape_ndim,
+                             const int ***out_shape_data,
                              mx_uint *aux_shape_size,
-                             const mx_uint **aux_shape_ndim,
-                             const mx_uint ***aux_shape_data,
+                             const int **aux_shape_ndim,
+                             const int ***aux_shape_data,
int *out);

/*!
@@ -1547,7 +1547,7 @@ int MXExecutorSimpleBind(SymbolHandle symbol_handle,
const char** in, // provided_grad_req_types,
const mx_uint num_provided_arg_shapes,
const char** in, // provided_arg_shape_names,
-                         const mx_uint* in,  // provided_arg_shape_data,
+                         const int* in,  // provided_arg_shape_data,
const mx_uint* in, // provided_arg_shape_idx,
const mx_uint num_provided_arg_dtypes,
const char** in, // provided_arg_dtype_names,
@@ -1593,24 +1593,24 @@ int MXExecutorSimpleBind(SymbolHandle symbol_handle,
* \return a new executor
*/
int MXExecutorReshape(int partial_shaping,
-                      int allow_up_sizing,
-                      int dev_type,
-                      int dev_id,
-                      mx_uint num_map_keys,
-                      const char** in,
-                      const int* in,
-                      const int* in,
-                      const mx_uint num_provided_arg_shapes,
-                      const char** in,
-                      const mx_uint* in,
-                      const mx_uint* in,
-                      mx_uint* couple_out_size,
-                      NDArrayHandle** out_first_array,
-                      NDArrayHandle** out_second_array,
-                      mx_uint* out_size,
-                      NDArrayHandle** out_array,
-                      ExecutorHandle shared_exec,
-                      ExecutorHandle *out);
+                      int allow_up_sizing,
+                      int dev_type,
+                      int dev_id,
+                      mx_uint num_map_keys,
+                      const char** in,
+                      const int* in,
+                      const int* in,
+                      const mx_uint num_provided_arg_shapes,
+                      const char** in,
+                      const int* in,
+                      const mx_uint* in,
+                      mx_uint* couple_out_size,
+                      NDArrayHandle** out_first_array,
+                      NDArrayHandle** out_second_array,
+                      mx_uint* out_size,
+                      NDArrayHandle** out_array,
+                      ExecutorHandle shared_exec,
+                      ExecutorHandle *out);
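
This is the C call behind Executor.reshape in the Python frontend. A minimal sketch of exercising it from Python (the network and shapes are illustrative only):

    import mxnet as mx

    data = mx.sym.var("data")
    net = mx.sym.FullyConnected(data, num_hidden=10)
    exe = net.simple_bind(mx.cpu(), data=(32, 20))
    # Rebind to a larger batch; the provided shape dims now cross the
    # C API as signed ints rather than mx_uint.
    exe2 = exe.reshape(allow_up_sizing=True, data=(64, 20))
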

/*!
* \brief set a call back to notify the completion of operation
3 changes: 1 addition & 2 deletions python/mxnet/__init__.py
@@ -23,9 +23,8 @@

from .context import Context, current_context, cpu, gpu, cpu_pinned
from . import engine
-from .base import MXNetError
+from .base import MXNetError, is_np_comp, set_np_comp, enable_np_comp, disable_np_comp
 from . import base
-from . import numpy
from . import contrib
from . import ndarray
from . import ndarray as nd
135 changes: 117 additions & 18 deletions python/mxnet/base.py
@@ -561,7 +561,7 @@ def _as_list(obj):
return [obj]


-_OP_NAME_PREFIX_LIST = ['_contrib_', '_linalg_', '_sparse_', '_image_', '_random_', '_numpy_']
+_OP_NAME_PREFIX_LIST = ['_contrib_', '_linalg_', '_sparse_', '_image_', '_random_']


def _get_op_name_prefix(op_name):
@@ -607,13 +607,6 @@ def _init_op_module(root_namespace, module_name, make_op_func):
# use mx.nd.contrib or mx.sym.contrib from now on
contrib_module_name_old = "%s.contrib.%s" % (root_namespace, module_name)
contrib_module_old = sys.modules[contrib_module_name_old]
-    # special handling of registering numpy ops
-    if module_name == 'ndarray':
-        numpy_module_name = "%s.numpy" % root_namespace
-        numpy_module = sys.modules[numpy_module_name]
-    else:
-        numpy_module_name = None
-        numpy_module = None
submodule_dict = {}
for op_name_prefix in _OP_NAME_PREFIX_LIST:
submodule_dict[op_name_prefix] =\
@@ -652,16 +645,6 @@ def _init_op_module(root_namespace, module_name, make_op_func):
function.__module__ = contrib_module_name_old
setattr(contrib_module_old, function.__name__, function)
contrib_module_old.__all__.append(function.__name__)
-        elif op_name_prefix == '_numpy_' and numpy_module_name is not None:
-            # only register numpy ops under mxnet.numpy in imperative mode
-            hdl = OpHandle()
-            check_call(_LIB.NNGetOpHandle(c_str(name), ctypes.byref(hdl)))
-            # TODO(reminisce): Didn't consider third level module here, e.g. mxnet.numpy.random.
-            func_name = name[len(op_name_prefix):]
-            function = make_op_func(hdl, name, func_name)
-            function.__module__ = numpy_module_name
-            setattr(numpy_module, function.__name__, function)
-            numpy_module.__all__.append(function.__name__)


def _generate_op_module_signature(root_namespace, module_name, op_code_gen_func):
@@ -751,3 +734,119 @@ def write_all_str(module_file, module_all_list):

ctypes.pythonapi.PyCapsule_New.restype = ctypes.py_object
ctypes.pythonapi.PyCapsule_GetPointer.restype = ctypes.c_void_p


def set_np_comp(flag):
"""
Turns NumPy compatibility on or off; it is off by default in the backend (see the usage sketch below).

Parameters
----------
flag : bool
Indicates whether to turn on/off NumPy compatibility.

Returns
-------
A bool value indicating the previous state of NumPy compatibility.
"""
prev = ctypes.c_int()
check_call(_LIB.MXSetIsNumpyCompatible(ctypes.c_int(flag), ctypes.byref(prev)))
return bool(prev.value)


def is_np_comp():
"""
Checks whether NumPy compatibility is currently turned on.
NumPy compatibility is off by default in the backend.

Returns
-------
A bool value indicating whether NumPy compatibility is currently on.
"""
curr = ctypes.c_bool()
check_call(_LIB.MXIsNumpyCompatible(ctypes.byref(curr)))
return curr.value


class _NumpyCompatibilityStateScope(object):
"""Scope for managing numpy compatibility state.

Example::

with _NumpyCompatibilityStateScope(True):
y = model(x)
backward([y])

"""
def __init__(self, is_np_comp):  # pylint: disable=redefined-outer-name
self._enter_is_np_comp = is_np_comp
self._prev_is_np_comp = None

def __enter__(self):
if self._enter_is_np_comp is not None:
self._prev_is_np_comp = set_np_comp(self._enter_is_np_comp)

def __exit__(self, ptype, value, trace):
if self._enter_is_np_comp is not None and self._prev_is_np_comp != self._enter_is_np_comp:
set_np_comp(self._prev_is_np_comp)


def enable_np_comp():
"""Returns a NumPy compatibility state scope to be used in 'with' statement
and captures code that needs the compatibility.

Example::

with mx.enable_np_comp():
# A scalar tensor's shape is `()`, whose `ndim` is `0`.
scalar = mx.nd.ones(shape=())
assert scalar.shape == ()

# In NumPy compatible mode, 0 in a shape means that dimension contains zero elements.
data = mx.sym.var("data", shape=(0, 2, 3))
ret = mx.sym.sin(data)
arg_shapes, out_shapes, _ = ret.infer_shape()
assert arg_shapes[0] == (0, 2, 3)
assert out_shapes[0] == (0, 2, 3)

# -1 means unknown shape dimension size in the new NumPy-compatible shape definition
data = mx.sym.var("data", shape=(-1, 2, 3))
ret = mx.sym.sin(data)
arg_shapes, out_shapes, _ = ret.infer_shape_partial()
assert arg_shapes[0] == (-1, 2, 3)
assert out_shapes[0] == (-1, 2, 3)

# When a shape is completely unknown in NumPy-compatible mode, it is
# represented as `None` in Python.
data = mx.sym.var("data")
ret = mx.sym.sin(data)
arg_shapes, out_shapes, _ = ret.infer_shape_partial()
assert arg_shapes[0] is None
assert out_shapes[0] is None
"""
return _NumpyCompatibilityStateScope(True)


def disable_np_comp():
"""Returns a state scope with NumPy-compatibility disabled to be used in 'with' statement
and captures code that does not need the compatibility.

Example::

with mx.disable_np_comp():
# 0 means unknown shape dimension size in the legacy shape definition.
data = mx.sym.var("data", shape=(0, 2, 3))
ret = mx.sym.sin(data)
arg_shapes, out_shapes, _ = ret.infer_shape_partial()
assert arg_shapes[0] == (0, 2, 3)
assert out_shapes[0] == (0, 2, 3)

# When a shape is completely unknown in the legacy mode (default), its ndim is
# equal to 0 and it is represented as `()` in Python.
data = mx.sym.var("data")
ret = mx.sym.sin(data)
arg_shapes, out_shapes, _ = ret.infer_shape_partial()
assert arg_shapes[0] == ()
assert out_shapes[0] == ()
"""
return _NumpyCompatibilityStateScope(False)
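
Taken together, a short usage sketch of the helpers defined above (minimal and illustrative; the scope restores the previous state on exit, so scopes nest):

    import mxnet as mx

    assert not mx.is_np_comp()       # compatibility is off by default
    prev = mx.set_np_comp(True)      # turn it on, remember the old state
    assert mx.is_np_comp() and not prev
    mx.set_np_comp(prev)             # restore whatever was set before

    with mx.enable_np_comp():
        assert mx.is_np_comp()
        with mx.disable_np_comp():   # nested scope; __exit__ restores
            assert not mx.is_np_comp()
        assert mx.is_np_comp()
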
2 changes: 1 addition & 1 deletion python/mxnet/ndarray/__init__.py
@@ -17,7 +17,7 @@

"""NDArray API of MXNet."""

-from . import _internal, contrib, linalg, op, random, sparse, utils, image, ndarray, numpy
+from . import _internal, contrib, linalg, op, random, sparse, utils, image, ndarray
# pylint: disable=wildcard-import, redefined-builtin
try:
from .gen_op import * # pylint: disable=unused-wildcard-import
18 changes: 0 additions & 18 deletions python/mxnet/ndarray/numpy.py

This file was deleted.
