Merged
28 commits
0c7d128  add bitwise_left/right_shift (barry-jin, Sep 16, 2021)
15f1768  add more methods (barry-jin, Sep 16, 2021)
46e6e1f  add mshadow_op.h (barry-jin, Sep 16, 2021)
27277dc  fix (barry-jin, Sep 16, 2021)
ab41f89  fix lint & add tests (barry-jin, Sep 18, 2021)
1305a91  fix (barry-jin, Sep 18, 2021)
3805946  update operator_tune.cc (barry-jin, Sep 20, 2021)
9e47620  update amp list (barry-jin, Sep 20, 2021)
a917329  add rtc functions (barry-jin, Sep 21, 2021)
8c99fff  fix bitwise rtc functions & numpy op gpu test overriding issue (barry-jin, Oct 12, 2021)
2385214  Merge remote-tracking branch 'upstream/master' into lr-shift (barry-jin, Oct 12, 2021)
519e86f  clang-format (barry-jin, Oct 12, 2021)
a4917fb  fix ci (barry-jin, Oct 12, 2021)
0ae498c  Merge remote-tracking branch 'upstream/master' into lr-shift (barry-jin, Oct 12, 2021)
f91bf12  Merge branch 'master' into lr-shift (barry-jin, Oct 18, 2021)
76dd734  add int16 support (barry-jin, Oct 20, 2021)
5475246  merge (barry-jin, Oct 20, 2021)
0832a03  Merge branch 'lr-shift' of https://github.com/barry-jin/incubator-mxn… (barry-jin, Oct 20, 2021)
49c25db  add MXNET_INT_TYPE_SWITCH_EXT (barry-jin, Oct 20, 2021)
d923ac1  merge (barry-jin, Oct 20, 2021)
2e60412  Merge branch 'master' into lr-shift (barry-jin, Oct 27, 2021)
83e4d66  fix conflict (barry-jin, Oct 29, 2021)
afcd3ae  merge (barry-jin, Oct 29, 2021)
c215202  solve conflict (barry-jin, Oct 30, 2021)
d3f6c07  fix sanity check (barry-jin, Oct 30, 2021)
955f65f  fix lint (barry-jin, Oct 31, 2021)
2eeed8a  fix (barry-jin, Oct 31, 2021)
c930a98  fix lint (barry-jin, Oct 31, 2021)
6 changes: 6 additions & 0 deletions python/mxnet/amp/lists/symbol_fp16.py
@@ -195,6 +195,12 @@
'_npi_bitwise_or_scalar',
'_npi_bitwise_xor',
'_npi_bitwise_xor_scalar',
'_npi_bitwise_left_shift',
'_npi_bitwise_left_shift_scalar',
'_npi_bitwise_right_shift',
'_npi_bitwise_right_shift_scalar',
'_npi_rbitwise_left_shift_scalar',
'_npi_rbitwise_right_shift_scalar',
'_npi_blackman',
'_npi_boolean_mask_assign_scalar',
'_npi_boolean_mask_assign_tensor',
80 changes: 79 additions & 1 deletion python/mxnet/ndarray/numpy/_op.py
@@ -51,7 +51,7 @@
'diff', 'ediff1d', 'resize', 'polyval', 'nan_to_num', 'isnan', 'isinf', 'isposinf', 'isneginf', 'isfinite',
'atleast_1d', 'atleast_2d', 'atleast_3d', 'fill_diagonal', 'squeeze',
'where', 'bincount', 'rollaxis', 'diagflat', 'repeat', 'prod', 'pad', 'cumsum', 'sum', 'diag', 'diagonal',
'positive', 'logaddexp', 'floor_divide']
'positive', 'logaddexp', 'floor_divide', 'bitwise_left_shift', 'bitwise_right_shift']


@set_module('mxnet.ndarray.numpy')
@@ -10015,3 +10015,81 @@ def sum(a, axis=None, dtype=None, out=None, keepdims=None, initial=None, where=N
raise ValueError("only where=None or where=True cases are supported for now")
return _api_internal.sum(a, axis, dtype, keepdims, initial, out)
# pylint:enable=redefined-outer-name, too-many-arguments


@set_module('mxnet.ndarray.numpy')
def bitwise_left_shift(x1, x2, out=None):
r"""
Shift the bits of an integer to the left. Bits are shifted to the left by
appending x2 zeros at the right of x1. Since the internal representation of
numbers is in binary format, this operation is equivalent to ``x1 * 2**x2``.

Parameters
----------
x1 : ndarray or scalar
Input values.
x2 : ndarray or scalar
Number of zeros to append to x1. Has to be non-negative. If x1.shape != x2.shape,
they must be broadcastable to a common shape (which becomes the shape of the output).
out : ndarray, optional
A location into which the result is stored. If provided, it must have a shape that the
inputs broadcast to. If not provided or None, a freshly-allocated array is returned.

Returns
-------
out : ndarray
Result.

Examples
--------
>>> np.binary_repr(5)
'101'
>>> np.bitwise_left_shift(5, 2)
20
>>> np.binary_repr(20)
'10100'
>>> np.bitwise_left_shift(5, np.array([1,2,3]))
array([10, 20, 40])
"""
if isinstance(x1, numeric_types) and isinstance(x2, numeric_types):
return _np.left_shift(x1, x2, out=out)
return _api_internal.bitwise_left_shift(x1, x2, out)


@set_module('mxnet.ndarray.numpy')
def bitwise_right_shift(x1, x2, out=None):
r"""
Shift the bits of an integer to the right. Bits are shifted to the right by
x2. Because the internal representation of numbers is in binary format,
this operation is equivalent to ``x1 // 2**x2``.

Parameters
----------
x1 : ndarray or scalar
Input values.
x2 : ndarray or scalar
Number of bits to remove from the right of x1. If x1.shape != x2.shape,
they must be broadcastable to a common shape (which becomes the shape of the output).
out : ndarray, optional
A location into which the result is stored. If provided, it must have a shape that the
inputs broadcast to. If not provided or None, a freshly-allocated array is returned.

Returns
-------
out : ndarray
Result.

Examples
--------
>>> np.binary_repr(10)
'1010'
>>> np.bitwise_right_shift(10, 1)
5
>>> np.binary_repr(5)
'101'
>>> np.bitwise_right_shift(10, np.array([1,2,3]))
array([5, 2, 1])
"""
if isinstance(x1, numeric_types) and isinstance(x2, numeric_types):
return _np.right_shift(x1, x2, out=out)
return _api_internal.bitwise_right_shift(x1, x2, out)
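
The two functions above each take one of two paths: if both operands are Python scalars, the call falls through to official NumPy; anything involving an ndarray goes to the backend via _api_internal. A minimal sketch of both paths (it assumes an MXNet build that includes this PR, and the printed values are the mathematically expected results rather than a captured session):

from mxnet import np

a = np.array([5, 5, 5], dtype='int32')
print(np.bitwise_left_shift(a, 2))    # backend kernel path: [20 20 20]
print(np.bitwise_right_shift(a, 1))   # backend kernel path: [2 2 2]
print(np.bitwise_left_shift(5, 2))    # both scalars: official NumPy handles it, 20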
104 changes: 103 additions & 1 deletion python/mxnet/numpy/multiarray.py
@@ -81,7 +81,7 @@
'nan_to_num', 'isnan', 'isinf', 'isposinf', 'isneginf', 'isfinite', 'polyval', 'where', 'bincount',
'atleast_1d', 'atleast_2d', 'atleast_3d', 'fill_diagonal', 'squeeze',
'diagflat', 'repeat', 'prod', 'pad', 'cumsum', 'sum', 'rollaxis', 'diag', 'diagonal',
'positive', 'logaddexp', 'floor_divide', 'permute_dims']
'positive', 'logaddexp', 'floor_divide', 'permute_dims', 'bitwise_left_shift', 'bitwise_right_shift']

__all__ += fallback.__all__

@@ -1057,6 +1057,16 @@ def __rxor__(self, other):
"""x.__rxor__(y) <=> y ^ x"""
return bitwise_xor(other, self)

@wrap_mxnp_np_ufunc
def __lshift__(self, other):
"""x.__lshift__(y) <=> x << y"""
return bitwise_left_shift(self, other)

@wrap_mxnp_np_ufunc
def __rshift__(self, other):
"""x.__rshift__(y) <=> x >> y"""
return bitwise_right_shift(self, other)

@wrap_mxnp_np_ufunc
def __iand__(self, other):
"""x.__iand__(y) <=> x &= y"""
@@ -1072,6 +1082,26 @@ def __ixor__(self, other):
"""x.__ixor__(y) <=> x ^= y"""
return bitwise_xor(self, other, out=self)

@wrap_mxnp_np_ufunc
def __ilshift__(self, other):
"""x.__ilshift__(y) <=> x <<= y"""
return bitwise_left_shift(self, other, out=self)

@wrap_mxnp_np_ufunc
def __irshift__(self, other):
"""x.__irshift__(y) <=> x >>= y"""
return bitwise_right_shift(self, other, out=self)

@wrap_mxnp_np_ufunc
def __rlshift__(self, other):
"""x.__rlshift__(y) <=> y << x"""
return bitwise_left_shift(other, self)

@wrap_mxnp_np_ufunc
def __rrshift__(self, other):
"""x.__rrshift__(y) <=> y >> x"""
return bitwise_right_shift(other, self)

def __round__(self, n=0):
"""x.__round__(n)"""
return round(self, decimals=n)
@@ -13033,3 +13063,75 @@ def sum(a, axis=None, dtype=None, out=None, keepdims=None, initial=None, where=N
"""
return _mx_nd_np.sum(a, axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial, where=where)
# pylint: enable=redefined-outer-name, too-many-arguments


@set_module('mxnet.numpy')
def bitwise_left_shift(x1, x2, out=None):
r"""
Shift the bits of an integer to the left. Bits are shifted to the left by
appending x2 zeros at the right of x1. Since the internal representation of
numbers is in binary format, this operation is equivalent to ``x1 * 2**x2``.

Parameters
----------
x1 : ndarray or scalar
Input values.
x2 : ndarray or scalar
Number of zeros to append to x1. Has to be non-negative. If x1.shape != x2.shape,
they must be broadcastable to a common shape (which becomes the shape of the output).
out : ndarray, optional
A location into which the result is stored. If provided, it must have a shape that the
inputs broadcast to. If not provided or None, a freshly-allocated array is returned.

Returns
-------
out : ndarray
Result.

Examples
--------
>>> np.binary_repr(5)
'101'
>>> np.bitwise_left_shift(5, 2)
20
>>> np.binary_repr(20)
'10100'
"""
return _mx_nd_np.bitwise_left_shift(x1, x2, out)


@set_module('mxnet.numpy')
def bitwise_right_shift(x1, x2, out=None):
r"""
Shift the bits of an integer to the right. Bits are shifted to the right by
x2. Because the internal representation of numbers is in binary format,
this operation is equivalent to ``x1 // 2**x2``.

Parameters
----------
x1 : ndarray or scalar
Input values.
x2 : ndarray or scalar
Number of bits to remove from the right of x1. If x1.shape != x2.shape,
they must be broadcastable to a common shape (which becomes the shape of the output).
out : ndarray, optional
A location into which the result is stored. If provided, it must have a shape that the
inputs broadcast to. If not provided or None, a freshly-allocated array is returned.

Returns
-------
out : ndarray
Result.

Examples
--------
>>> np.binary_repr(10)
'1010'
>>> np.bitwise_right_shift(10, 1)
5
>>> np.binary_repr(5)
'101'
>>> np.bitwise_right_shift(10, np.array([1,2,3]))
array([5, 2, 1])
"""
return _mx_nd_np.bitwise_right_shift(x1, x2, out)
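
Together with the dunder methods added to ndarray above, the native shift operators now route through these functions. A short sketch of the intended behavior (again assuming this PR is applied; outputs are the expected values, not a recorded session):

from mxnet import np

x = np.array([1, 2, 3], dtype='int32')
y = x << 1     # __lshift__  -> bitwise_left_shift(x, 1):   [2 4 6]
z = 16 >> x    # __rrshift__ -> bitwise_right_shift(16, x): [8 4 2]
x <<= 2        # __ilshift__ shifts in place via out=x:     x is now [4 8 12]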
18 changes: 18 additions & 0 deletions src/api/operator/numpy/np_elemwise_broadcast_op.cc
@@ -189,4 +189,22 @@ MXNET_REGISTER_API("_npi.ldexp").set_body([](runtime::MXNetArgs args, runtime::M
UFuncHelper(args, ret, op, op_scalar, op_rscalar);
});

MXNET_REGISTER_API("_npi.bitwise_left_shift")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_bitwise_left_shift");
const nnvm::Op* op_scalar = Op::Get("_npi_bitwise_left_shift_scalar");
const nnvm::Op* op_rscalar = Op::Get("_npi_rbitwise_left_shift_scalar");
UFuncHelper(args, ret, op, op_scalar, op_rscalar);
});

MXNET_REGISTER_API("_npi.bitwise_right_shift")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_bitwise_right_shift");
const nnvm::Op* op_scalar = Op::Get("_npi_bitwise_right_shift_scalar");
const nnvm::Op* op_rscalar = Op::Get("_npi_rbitwise_right_shift_scalar");
UFuncHelper(args, ret, op, op_scalar, op_rscalar);
});

} // namespace mxnet
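
Each front-end call is backed by three registered operators so every operand layout is covered: tensor-tensor, tensor-scalar, and scalar-tensor (the rscalar variant). The sketch below is a hedged Python rendering of how UFuncHelper appears to pick among them; is_ndarray and the op_* names are placeholders for illustration, not the real C++ helpers:

def dispatch_left_shift(x1, x2):
    # Hypothetical pseudo-dispatch mirroring UFuncHelper's operator choice.
    if is_ndarray(x1) and is_ndarray(x2):
        return op(x1, x2)          # _npi_bitwise_left_shift
    elif is_ndarray(x1):
        return op_scalar(x1, x2)   # _npi_bitwise_left_shift_scalar, x2 is the scalar
    else:
        return op_rscalar(x2, x1)  # _npi_rbitwise_left_shift_scalar, x1 is the scalar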
44 changes: 44 additions & 0 deletions src/common/cuda/rtc/backward_functions-inl.h
@@ -426,6 +426,50 @@ copysign_grad(const DType val,
return (val >= 0 && val2 >= 0) || (val < 0 && val2 < 0) ? 1 : -1;
}

template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2>
bitwise_left_shift_grad(const DType val,
const DType2 val2) {
return op::power(static_cast<DType>(2), val2);
}

template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2>
bitwise_left_shift_rgrad(const DType val,
const DType2 val2) {
using type = mixed_type<DType, DType2>;
return val * op::power(static_cast<DType>(2), val2) * op::log(static_cast<type>(2));
}

template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2>
rbitwise_left_shift_grad(const DType val,
const DType2 val2) {
using type = mixed_type<DType, DType2>;
return val2 * op::power(static_cast<DType>(2), val) * op::log(static_cast<type>(2));
}

template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2>
bitwise_right_shift_grad(const DType val,
const DType2 val2) {
return op::power(0.5f, val2);
}

template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2>
bitwise_right_shift_rgrad(const DType val,
const DType2 val2) {
return val * op::power(0.5f, val2) * op::log(0.5f);
}

template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2>
rbitwise_right_shift_grad(const DType val,
const DType2 val2) {
return val2 * op::power(0.5f, val) * op::log(0.5f);
}

template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2>
arctan2_grad(const DType val,
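
Shifts are step functions of their integer arguments, so these backward functions differentiate smooth surrogates instead: x * 2**y for the left shift and x * 0.5**y for the right shift. A small NumPy finite-difference check of the left-shift formulas (a sketch of the math only; it does not exercise the RTC kernels):

import numpy as np

f = lambda x, y: x * 2.0 ** y   # smooth surrogate of x << y
x, y, eps = 3.0, 2.0, 1e-6

# df/dx == 2**y                 (bitwise_left_shift_grad)
assert np.isclose((f(x + eps, y) - f(x, y)) / eps, 2.0 ** y, rtol=1e-4)
# df/dy == x * 2**y * ln(2)     (bitwise_left_shift_rgrad)
assert np.isclose((f(x, y + eps) - f(x, y)) / eps, x * 2.0 ** y * np.log(2.0), rtol=1e-4)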
32 changes: 32 additions & 0 deletions src/common/cuda/rtc/forward_functions-inl.h
@@ -617,6 +617,38 @@ __device__ inline mixed_type<DType, DType2> bitwise_and(const DType a,
return real_a & real_b;
}

template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2> bitwise_left_shift(const DType a,
const DType2 b) {
const mixed_type<DType, DType2> real_a = a;
const mixed_type<DType, DType2> real_b = b;
return real_a << real_b;
}

template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2> rbitwise_left_shift(const DType a,
const DType2 b) {
const mixed_type<DType, DType2> real_a = a;
const mixed_type<DType, DType2> real_b = b;
return real_b << real_a;
}

template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2> bitwise_right_shift(const DType a,
const DType2 b) {
const mixed_type<DType, DType2> real_a = a;
const mixed_type<DType, DType2> real_b = b;
return real_a >> real_b;
}

template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2> rbitwise_right_shift(const DType a,
const DType2 b) {
const mixed_type<DType, DType2> real_a = a;
const mixed_type<DType, DType2> real_b = b;
return real_b >> real_a;
}

DEFINE_BINARY_MATH_FUNC(arctan2, ::atan2, ::atan2f)

template <typename DType, typename DType2>
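
The rbitwise_* variants apply the shift with the operands swapped; judging by the rscalar registrations above, they back expressions like 2 << x, where the tensor reaches the kernel as the first argument and the scalar as the second. A plain-NumPy illustration of the swapped semantics (hypothetical values):

import numpy as np

x = np.array([1, 2, 3])
# rbitwise_left_shift(a, b) returns b << a, so with a = x and b = 2
# it produces what `2 << x` should yield:
np.testing.assert_array_equal(np.left_shift(2, x), np.array([4, 8, 16]))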
2 changes: 1 addition & 1 deletion src/engine/threaded_engine.cc
@@ -712,7 +712,7 @@ void ThreadedEngine::OnCompleteGPU(Engine* engine, void* sync_info, const dmlc::

ThreadedOpr* threaded_opr = static_cast<OprBlock*>(info->opr_block)->opr;
auto* event_pool = static_cast<CUDAEventPool*>(info->event_pool);
auto [event, event_pool_idx] = event_pool->GetNextEvent();
auto [event, event_pool_idx] = event_pool->GetNextEvent(); // NOLINT(*)
auto ev = event.lock();
MSHADOW_CUDA_CALL(cudaEventRecord(*ev, worker_stream->stream_));
for (auto* read_var : threaded_opr->const_vars) {
38 changes: 38 additions & 0 deletions src/operator/mshadow_op.h
@@ -812,6 +812,44 @@ MXNET_BINARY_MATH_OP(bitwise_xor, static_cast<int64_t>(a) ^ static_cast<int64_t>

MXNET_BINARY_MATH_OP(bitwise_or, static_cast<int64_t>(a) | static_cast<int64_t>(b));

#pragma GCC diagnostic push
#if __GNUC__ >= 7
#pragma GCC diagnostic ignored "-Wint-in-bool-context"
#pragma GCC diagnostic ignored "-Wbool-compare"
#endif

/*! \brief used for generate element of bitwise_left_shift */
MXNET_BINARY_MATH_OP(bitwise_left_shift, static_cast<int64_t>(a) << static_cast<int64_t>(b));

MXNET_BINARY_MATH_OP(bitwise_left_shift_grad, math::pow(2.0f, static_cast<int64_t>(b)));

MXNET_BINARY_MATH_OP(bitwise_left_shift_rgrad,
static_cast<int64_t>(a) * math::pow(2.0f, static_cast<int64_t>(b)) *
math::log(2.0f));

MXNET_BINARY_MATH_OP(rbitwise_left_shift, static_cast<int64_t>(b) << static_cast<int64_t>(a));

MXNET_BINARY_MATH_OP(rbitwise_left_shift_grad,
static_cast<int64_t>(b) * math::pow(2.0f, static_cast<int64_t>(a)) *
math::log(2.0f));

/*! \brief used for generate element of bitwise_right_shift */
MXNET_BINARY_MATH_OP(bitwise_right_shift, static_cast<int64_t>(a) >> static_cast<int64_t>(b));

MXNET_BINARY_MATH_OP(bitwise_right_shift_grad, math::pow(0.5f, static_cast<int64_t>(b)));

MXNET_BINARY_MATH_OP(bitwise_right_shift_rgrad,
static_cast<int64_t>(a) * math::pow(0.5f, static_cast<int64_t>(b)) *
math::log(0.5f));

MXNET_BINARY_MATH_OP(rbitwise_right_shift, static_cast<int64_t>(b) >> static_cast<int64_t>(a));

MXNET_BINARY_MATH_OP(rbitwise_right_shift_grad,
static_cast<int64_t>(b) * math::pow(0.5f, static_cast<int64_t>(a)) *
math::log(0.5f));

#pragma GCC diagnostic pop

MXNET_UNARY_MATH_OP(square_root, math::sqrt(a));

MXNET_UNARY_MATH_OP(square_root_grad, 0.5f / math::id(a));
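
The CPU ops mirror the GPU scheme: operands are cast to int64 before shifting so the operation is well defined across the supported integer widths, and the gradients use the same smooth surrogates. One identity that motivates bitwise_right_shift_grad's pow(0.5, b): an arithmetic right shift equals floor division by 2**b, negative values included (a plain NumPy check):

import numpy as np

a = np.array([10, -10], dtype=np.int64)
assert np.array_equal(a >> 1, a // 2)   # both give [5, -5]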