From 113435e5aa6d42dbd32dffbe975625b14ccb3dfc Mon Sep 17 00:00:00 2001 From: Chiyuan Zhang Date: Tue, 6 Oct 2015 00:29:10 -0400 Subject: [PATCH] __neg__ should return new array instead of changing self --- python/mxnet/ndarray.py | 2 +- tests/python/unittest/test_ndarray.py | 15 ++++++++++++++- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/python/mxnet/ndarray.py b/python/mxnet/ndarray.py index b4e328d67618..b74d93de3768 100644 --- a/python/mxnet/ndarray.py +++ b/python/mxnet/ndarray.py @@ -115,7 +115,7 @@ def __mul__(self, other): raise TypeError('type %s not supported' % str(type(other))) def __neg__(self): - return NDArray._mul_scalar(self, -1.0, out=self) + return NDArray._mul_scalar(self, -1.0) def __imul__(self, other): if isinstance(other, NDArray): diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py index 8a3761dfe013..36a1672bc636 100644 --- a/tests/python/unittest/test_ndarray.py +++ b/tests/python/unittest/test_ndarray.py @@ -49,6 +49,18 @@ def test_ndarray_elementwise(): check_with_uniform(lambda x, y: x * y, 2, dim) check_with_uniform(lambda x, y: x / y, 2, dim) +def test_ndarray_negate(): + npy = np.random.uniform(-10, 10, (2,3,4)) + arr = mx.nd.array(npy) + assert reldiff(npy, arr.asnumpy()) < 1e-6 + assert reldiff(-npy, (-arr).asnumpy()) < 1e-6 + + # a final check to make sure the negation (-) is not implemented + # as inplace operation, so the contents of arr does not change after + # we compute (-arr) + assert reldiff(npy, arr.asnumpy()) < 1e-6 + + def test_ndarray_copy(): c = mx.nd.array(np.random.uniform(-10, 10, (10, 10))) d = c.copyto(mx.Context('cpu', 0)) @@ -67,7 +79,7 @@ def test_ndarray_scalar(): c[:] = 2 assert(np.sum(c.asnumpy()) - 200 < 1e-5) d = -c + 2 - assert(np.sum(c.asnumpy()) < 1e-5) + assert(np.sum(d.asnumpy()) < 1e-5) def test_ndarray_pickle(): np.random.seed(0) @@ -142,6 +154,7 @@ def test_dot(): test_ndarray_saveload() test_ndarray_copy() test_ndarray_elementwise() + test_ndarray_negate() test_ndarray_scalar() test_clip() test_dot()