Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions benchmark/python/ffi/benchmark_ffi.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ def prepare_workloads():
OpArgMngr.add_workload("random.uniform", low=0, high=1, size=1)
OpArgMngr.add_workload("where", pool['2x3'], pool['2x3'], pool['2x1'])
OpArgMngr.add_workload("may_share_memory", pool['2x3'][:0], pool['2x3'][:1])
OpArgMngr.add_workload("roll", pool["2x2"], 1, axis=0)
OpArgMngr.add_workload("rot90", pool["2x2"], 2)


def benchmark_helper(f, *args, **kwargs):
Expand Down
64 changes: 0 additions & 64 deletions python/mxnet/_numpy_op_doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,70 +538,6 @@ def _np_reshape(a, newshape, order='C', out=None):
"""


def _np_roll(a, shift, axis=None):
"""
Roll array elements along a given axis.

Elements that roll beyond the last position are re-introduced at
the first.

Parameters
----------
a : ndarray
Input array.
shift : int or tuple of ints
The number of places by which elements are shifted. If a tuple,
then `axis` must be a tuple of the same size, and each of the
given axes is shifted by the corresponding number. If an int
while `axis` is a tuple of ints, then the same value is used for
all given axes.
axis : int or tuple of ints, optional
Axis or axes along which elements are shifted. By default, the
array is flattened before shifting, after which the original
shape is restored.

Returns
-------
res : ndarray
Output array, with the same shape as `a`.

Notes
-----
Supports rolling over multiple dimensions simultaneously.

Examples
--------
>>> x = np.arange(10)
>>> np.roll(x, 2)
array([8., 9., 0., 1., 2., 3., 4., 5., 6., 7.])
>>> np.roll(x, -2)
array([2., 3., 4., 5., 6., 7., 8., 9., 0., 1.])

>>> x2 = np.reshape(x, (2,5))
>>> x2
array([[0., 1., 2., 3., 4.],
[5., 6., 7., 8., 9.]])
>>> np.roll(x2, 1)
array([[9., 0., 1., 2., 3.],
[4., 5., 6., 7., 8.]])
>>> np.roll(x2, -1)
array([[1., 2., 3., 4., 5.],
[6., 7., 8., 9., 0.]])
>>> np.roll(x2, 1, axis=0)
array([[5., 6., 7., 8., 9.],
[0., 1., 2., 3., 4.]])
>>> np.roll(x2, -1, axis=0)
array([[5., 6., 7., 8., 9.],
[0., 1., 2., 3., 4.]])
>>> np.roll(x2, 1, axis=1)
array([[4., 0., 1., 2., 3.],
[9., 5., 6., 7., 8.]])
>>> np.roll(x2, -1, axis=1)
array([[1., 2., 3., 4., 0.],
[6., 7., 8., 9., 5.]])
"""


def _np_trace(a, offset=0, axis1=0, axis2=1, out=None):
"""
Return the sum along diagonals of the array.
Expand Down
70 changes: 68 additions & 2 deletions python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
'diag_indices_from', 'hanning', 'hamming', 'blackman', 'flip', 'flipud', 'fliplr',
'hypot', 'bitwise_and', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', 'unique', 'lcm',
'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer',
'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'rot90', 'einsum',
'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'roll', 'rot90', 'einsum',
'true_divide', 'nonzero', 'quantile', 'percentile', 'shares_memory', 'may_share_memory',
'diff', 'ediff1d', 'resize', 'polyval', 'nan_to_num', 'isnan', 'isinf', 'isposinf', 'isneginf', 'isfinite',
'where', 'bincount', 'pad', 'cumsum']
Expand Down Expand Up @@ -6296,6 +6296,72 @@ def less_equal(x1, x2, out=None):
_npi.greater_equal_scalar, out)


@set_module('mxnet.ndarray.numpy')
def roll(a, shift, axis=None):
"""
Roll array elements along a given axis.

Elements that roll beyond the last position are re-introduced at
the first.

Parameters
----------
a : ndarray
Input array.
shift : int or tuple of ints
The number of places by which elements are shifted. If a tuple,
then `axis` must be a tuple of the same size, and each of the
given axes is shifted by the corresponding number. If an int
while `axis` is a tuple of ints, then the same value is used for
all given axes.
axis : int or tuple of ints, optional
Axis or axes along which elements are shifted. By default, the
array is flattened before shifting, after which the original
shape is restored.

Returns
-------
res : ndarray
Output array, with the same shape as `a`.

Notes
-----
Supports rolling over multiple dimensions simultaneously.

Examples
--------
>>> x = np.arange(10)
>>> np.roll(x, 2)
array([8., 9., 0., 1., 2., 3., 4., 5., 6., 7.])
>>> np.roll(x, -2)
array([2., 3., 4., 5., 6., 7., 8., 9., 0., 1.])

>>> x2 = np.reshape(x, (2,5))
>>> x2
array([[0., 1., 2., 3., 4.],
[5., 6., 7., 8., 9.]])
>>> np.roll(x2, 1)
array([[9., 0., 1., 2., 3.],
[4., 5., 6., 7., 8.]])
>>> np.roll(x2, -1)
array([[1., 2., 3., 4., 5.],
[6., 7., 8., 9., 0.]])
>>> np.roll(x2, 1, axis=0)
array([[5., 6., 7., 8., 9.],
[0., 1., 2., 3., 4.]])
>>> np.roll(x2, -1, axis=0)
array([[5., 6., 7., 8., 9.],
[0., 1., 2., 3., 4.]])
>>> np.roll(x2, 1, axis=1)
array([[4., 0., 1., 2., 3.],
[9., 5., 6., 7., 8.]])
>>> np.roll(x2, -1, axis=1)
array([[1., 2., 3., 4., 0.],
[6., 7., 8., 9., 5.]])
"""
return _api_internal.roll(a, shift, axis)


@set_module('mxnet.ndarray.numpy')
def rot90(m, k=1, axes=(0, 1)):
"""
Expand Down Expand Up @@ -6339,7 +6405,7 @@ def rot90(m, k=1, axes=(0, 1)):
[[5., 7.],
[4., 6.]]])
"""
return _npi.rot90(m, k=k, axes=axes)
return _api_internal.rot90(m, k, axes)


@set_module('mxnet.ndarray.numpy')
Expand Down
68 changes: 67 additions & 1 deletion python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@
'flip', 'flipud', 'fliplr', 'around', 'round', 'round_', 'arctan2', 'hypot',
'bitwise_and', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad',
'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal',
'greater', 'less', 'greater_equal', 'less_equal', 'rot90', 'einsum', 'true_divide', 'nonzero',
'greater', 'less', 'greater_equal', 'less_equal', 'roll', 'rot90', 'einsum', 'true_divide', 'nonzero',
'quantile', 'percentile', 'shares_memory', 'may_share_memory', 'diff', 'ediff1d', 'resize', 'matmul',
'nan_to_num', 'isnan', 'isinf', 'isposinf', 'isneginf', 'isfinite', 'polyval', 'where', 'bincount',
'pad', 'cumsum']
Expand Down Expand Up @@ -8170,6 +8170,72 @@ def less_equal(x1, x2, out=None):
return _mx_nd_np.less_equal(x1, x2, out)


@set_module('mxnet.numpy')
def roll(a, shift, axis=None):
"""
Roll array elements along a given axis.

Elements that roll beyond the last position are re-introduced at
the first.

Parameters
----------
a : ndarray
Input array.
shift : int or tuple of ints
The number of places by which elements are shifted. If a tuple,
then `axis` must be a tuple of the same size, and each of the
given axes is shifted by the corresponding number. If an int
while `axis` is a tuple of ints, then the same value is used for
all given axes.
axis : int or tuple of ints, optional
Axis or axes along which elements are shifted. By default, the
array is flattened before shifting, after which the original
shape is restored.

Returns
-------
res : ndarray
Output array, with the same shape as `a`.

Notes
-----
Supports rolling over multiple dimensions simultaneously.

Examples
--------
>>> x = np.arange(10)
>>> np.roll(x, 2)
array([8., 9., 0., 1., 2., 3., 4., 5., 6., 7.])
>>> np.roll(x, -2)
array([2., 3., 4., 5., 6., 7., 8., 9., 0., 1.])

>>> x2 = np.reshape(x, (2,5))
>>> x2
array([[0., 1., 2., 3., 4.],
[5., 6., 7., 8., 9.]])
>>> np.roll(x2, 1)
array([[9., 0., 1., 2., 3.],
[4., 5., 6., 7., 8.]])
>>> np.roll(x2, -1)
array([[1., 2., 3., 4., 5.],
[6., 7., 8., 9., 0.]])
>>> np.roll(x2, 1, axis=0)
array([[5., 6., 7., 8., 9.],
[0., 1., 2., 3., 4.]])
>>> np.roll(x2, -1, axis=0)
array([[5., 6., 7., 8., 9.],
[0., 1., 2., 3., 4.]])
>>> np.roll(x2, 1, axis=1)
array([[4., 0., 1., 2., 3.],
[9., 5., 6., 7., 8.]])
>>> np.roll(x2, -1, axis=1)
array([[1., 2., 3., 4., 0.],
[6., 7., 8., 9., 5.]])
"""
return _mx_nd_np.roll(a, shift, axis=axis)


@set_module('mxnet.numpy')
def rot90(m, k=1, axes=(0, 1)):
"""
Expand Down
37 changes: 36 additions & 1 deletion python/mxnet/symbol/numpy/_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
'diag_indices_from', 'hanning', 'hamming', 'blackman', 'flip', 'flipud', 'fliplr',
'hypot', 'bitwise_and', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', 'unique', 'lcm',
'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer',
'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'rot90', 'einsum',
'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'roll', 'rot90', 'einsum',
'true_divide', 'quantile', 'percentile', 'shares_memory', 'may_share_memory', 'diff', 'ediff1d',
'resize', 'polyval', 'nan_to_num', 'isnan', 'isinf', 'isposinf', 'isneginf', 'isfinite',
'where', 'bincount', 'pad', 'cumsum']
Expand Down Expand Up @@ -5841,6 +5841,41 @@ def less_equal(x1, x2, out=None):
_npi.greater_equal_scalar, out)


@set_module('mxnet.symbol.numpy')
def roll(a, shift, axis=None):
"""
Roll array elements along a given axis.

Elements that roll beyond the last position are re-introduced at
the first.

Parameters
----------
a : _Symbol
Input array.
shift : int or tuple of ints
The number of places by which elements are shifted. If a tuple,
then `axis` must be a tuple of the same size, and each of the
given axes is shifted by the corresponding number. If an int
while `axis` is a tuple of ints, then the same value is used for
all given axes.
axis : int or tuple of ints, optional
Axis or axes along which elements are shifted. By default, the
array is flattened before shifting, after which the original
shape is restored.

Returns
-------
res : _Symbol
Output array, with the same shape as `a`.

Notes
-----
Supports rolling over multiple dimensions simultaneously.
"""
return _npi.roll(a, shift, axis=axis)


@set_module('mxnet.symbol.numpy')
def rot90(m, k=1, axes=(0, 1)):
"""
Expand Down
56 changes: 56 additions & 0 deletions src/api/operator/numpy/np_matrix_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@
* \brief Implementation of the API of functions in src/operator/tensor/matrix_op.cc
*/
#include <mxnet/api_registry.h>
#include <mxnet/runtime/packed_func.h>
#include "../utils.h"
#include "../../../operator/tensor/matrix_op-inl.h"
#include "../../../operator/numpy/np_matrix_op-inl.h"

namespace mxnet {

Expand Down Expand Up @@ -85,4 +87,58 @@ MXNET_REGISTER_API("_npi.split")
*ret = ADT(0, ndarray_handles.begin(), ndarray_handles.end());
});

MXNET_REGISTER_API("_npi.roll")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
static const nnvm::Op* op = Op::Get("_npi_roll");
nnvm::NodeAttrs attrs;
op::NumpyRollParam param;
if (args[1].type_code() == kNull) {
param.shift = dmlc::nullopt;
} else if (args[1].type_code() == kDLInt) {
param.shift = TShape(1, args[1].operator int64_t());
} else {
param.shift = TShape(args[1].operator ObjectRef());
}
if (args[2].type_code() == kNull) {
param.axis = dmlc::nullopt;
} else if (args[2].type_code() == kDLInt) {
param.axis = TShape(1, args[2].operator int64_t());
} else {
param.axis = TShape(args[2].operator ObjectRef());
}
attrs.parsed = std::move(param);
attrs.op = op;
SetAttrDict<op::NumpyRollParam>(&attrs);
NDArray* inputs[] = {args[0].operator mxnet::NDArray*()};
int num_inputs = 1;
int num_outputs = 0;
auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr);
*ret = ndoutputs[0];
});

MXNET_REGISTER_API("_npi.rot90")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
static const nnvm::Op* op = Op::Get("_npi_rot90");
nnvm::NodeAttrs attrs;
op::NumpyRot90Param param;
param.k = args[1].operator int();
if (args[2].type_code() == kNull) {
param.axes = dmlc::nullopt;
} else if (args[2].type_code() == kDLInt) {
param.axes = TShape(1, args[2].operator int64_t());
} else {
param.axes = TShape(args[2].operator ObjectRef());
}
attrs.parsed = std::move(param);
attrs.op = op;
SetAttrDict<op::NumpyRot90Param>(&attrs);
NDArray* inputs[] = {args[0].operator mxnet::NDArray*()};
int num_inputs = 1;
int num_outputs = 0;
auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr);
*ret = ndoutputs[0];
});

} // namespace mxnet
Loading