From fa3da80c3f304cf363a33df32a6e5e0f194d420f Mon Sep 17 00:00:00 2001 From: Leonard Lausen Date: Thu, 9 Jul 2020 21:09:55 -0700 Subject: [PATCH 1/3] Fix mx.symbol.numpy._Symbol.__deepcopy__ logic error Performed shallow copy instead of deep copy --- python/mxnet/symbol/numpy/_symbol.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py index d3521cad1274..9b193f850a93 100644 --- a/python/mxnet/symbol/numpy/_symbol.py +++ b/python/mxnet/symbol/numpy/_symbol.py @@ -283,7 +283,7 @@ def __neg__(self): return negative(self) def __deepcopy__(self, _): - return super(_Symbol, self).as_np_ndarray() + return super().__deepcopy__(_).as_np_ndarray() def __eq__(self, other): """x.__eq__(y) <=> x == y""" From 06a8252e087dffba7f4b643f212673b9b89b54cd Mon Sep 17 00:00:00 2001 From: Leonard Lausen Date: Fri, 10 Jul 2020 21:37:54 +0000 Subject: [PATCH 2/3] Test --- tests/python/unittest/test_symbol.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tests/python/unittest/test_symbol.py b/tests/python/unittest/test_symbol.py index 1c84af0c668e..a0ddc8a127e0 100644 --- a/tests/python/unittest/test_symbol.py +++ b/tests/python/unittest/test_symbol.py @@ -479,3 +479,13 @@ def test_infershape_happens_for_all_ops_in_graph(): assert False +def test_symbol_copy(): + a = mx.sym.Variable('a') + b = copy.copy(a) + b._set_attr(name='b') + assert a.name == 'a' and b.name == 'b' + + a = mx.sym.numpy._Variable('a') + b = copy.copy(a) + b._set_attr(name='b') + assert a.name == 'a' and b.name == 'b' From dc3ce0888721f44b2c5a9a0b065b4b640cda6012 Mon Sep 17 00:00:00 2001 From: Leonard Lausen Date: Wed, 15 Jul 2020 17:28:43 +0000 Subject: [PATCH 3/3] Fix test --- tests/python/unittest/test_symbol.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/unittest/test_symbol.py b/tests/python/unittest/test_symbol.py index a0ddc8a127e0..910b6ca15499 100644 --- a/tests/python/unittest/test_symbol.py +++ b/tests/python/unittest/test_symbol.py @@ -485,7 +485,7 @@ def test_symbol_copy(): b._set_attr(name='b') assert a.name == 'a' and b.name == 'b' - a = mx.sym.numpy._Variable('a') + a = mx.sym.Variable('a').as_np_ndarray() b = copy.copy(a) b._set_attr(name='b') assert a.name == 'a' and b.name == 'b'