From 8366406f5ef8695910ba16e5652b615b77474550 Mon Sep 17 00:00:00 2001 From: Wang Date: Thu, 19 Dec 2019 17:39:27 +0800 Subject: [PATCH 1/6] fix --- python/mxnet/gluon/utils.py | 35 ++++++++++++----------------------- 1 file changed, 12 insertions(+), 23 deletions(-) diff --git a/python/mxnet/gluon/utils.py b/python/mxnet/gluon/utils.py index 81a8dbaa486b..a5322edf20f3 100644 --- a/python/mxnet/gluon/utils.py +++ b/python/mxnet/gluon/utils.py @@ -70,30 +70,19 @@ def split_data(data, num_slice, batch_axis=0, even_split=True): "uneven partitioning of data."%( str(data.shape), num_slice, batch_axis, num_slice)) - step = size // num_slice - - # If size < num_slice, make fewer slices - if not even_split and size < num_slice: - step = 1 - num_slice = size - - if batch_axis == 0: - slices = [data[i*step:(i+1)*step] if i < num_slice - 1 else data[i*step:size] - for i in range(num_slice)] - elif even_split: - if is_np_array(): - slices = _mx_np.split(data, indices_or_sections=num_slice, axis=batch_axis) - else: - slices = ndarray.split(data, num_outputs=num_slice, axis=batch_axis) + Neach_section, extras = divmod(size, num_slice) + section_sizes = [0] + (extras * [Neach_section + 1] + + (num_slice - extras) * [Neach_section]) + div_points = np.array(section_sizes).cumsum() + if is_np_array(): + slices = _mx_np.split(data, indices_or_sections=div_points, axis=batch_axis) else: - if is_np_array(): - indices = [step * i for i in range(1, num_slice)] - slices = _mx_np.split(data, indices_or_sections=indices, axis=batch_axis) - else: - slices = [ndarray.slice_axis(data, batch_axis, i*step, (i+1)*step) - if i < num_slice - 1 else - ndarray.slice_axis(data, batch_axis, i*step, size) - for i in range(num_slice)] + slices = [] + sary = ndarray.swapaxes(data, batch_axis, 0) + for i in range(num_slice): + st = div_points[i] + end = div_points[i + 1] + slices.append(ndarray.swapaxes(sary[st: end], batch_axis, 0)) return slices From a61ef006024202a0e27622c303634bfebaeed9ef Mon Sep 17 00:00:00 2001 From: Wang Date: Fri, 20 Dec 2019 22:44:47 +0800 Subject: [PATCH 2/6] fix & add test --- python/mxnet/gluon/utils.py | 9 ++++----- tests/python/unittest/test_gluon.py | 7 +++++-- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/python/mxnet/gluon/utils.py b/python/mxnet/gluon/utils.py index a5322edf20f3..374fad6a6c35 100644 --- a/python/mxnet/gluon/utils.py +++ b/python/mxnet/gluon/utils.py @@ -70,19 +70,18 @@ def split_data(data, num_slice, batch_axis=0, even_split=True): "uneven partitioning of data."%( str(data.shape), num_slice, batch_axis, num_slice)) - Neach_section, extras = divmod(size, num_slice) - section_sizes = [0] + (extras * [Neach_section + 1] + - (num_slice - extras) * [Neach_section]) + n_each_section, extras = divmod(size, num_slice) + section_sizes = [0] + (extras * [n_each_section + 1] + + (num_slice - extras) * [n_each_section]) div_points = np.array(section_sizes).cumsum() if is_np_array(): slices = _mx_np.split(data, indices_or_sections=div_points, axis=batch_axis) else: slices = [] - sary = ndarray.swapaxes(data, batch_axis, 0) for i in range(num_slice): st = div_points[i] end = div_points[i + 1] - slices.append(ndarray.swapaxes(sary[st: end], batch_axis, 0)) + slices.append(ndarray.slice_axis(data, axis=batch_axis, begin=st, end=end)) return slices diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py index d27c241f9c99..7dc1522d04d2 100644 --- a/tests/python/unittest/test_gluon.py +++ b/tests/python/unittest/test_gluon.py @@ -958,11 +958,14 @@ def check_split_data(x, num_slice, batch_axis, **kwargs): mx.test_utils.assert_almost_equal(mx.nd.concat(*res, dim=batch_axis).asnumpy(), x.asnumpy()) + np_res = np.array_split(x.asnumpy(), num_slice, axis=batch_axis) + res_asnp = [s.asnumpy() for s in res] + for r1, r2 in zip(np_res, res_asnp): + assert all(r1.reshape(-1) == r2.reshape(-1)) + -@with_seed() def test_split_data(): x = mx.nd.random.uniform(shape=(128, 33, 64)) - check_split_data(x, 8, 0) check_split_data(x, 3, 1) check_split_data(x, 4, 1, even_split=False) From d55c6d5e20c863edd5bae4894fe64f9b2443f9e0 Mon Sep 17 00:00:00 2001 From: Wang Date: Sat, 21 Dec 2019 00:20:19 +0800 Subject: [PATCH 3/6] add mx.numpy test --- python/mxnet/gluon/utils.py | 2 +- tests/python/unittest/test_gluon.py | 23 +++++++++++++++++++++-- 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/python/mxnet/gluon/utils.py b/python/mxnet/gluon/utils.py index 374fad6a6c35..83ed15aed450 100644 --- a/python/mxnet/gluon/utils.py +++ b/python/mxnet/gluon/utils.py @@ -75,7 +75,7 @@ def split_data(data, num_slice, batch_axis=0, even_split=True): (num_slice - extras) * [n_each_section]) div_points = np.array(section_sizes).cumsum() if is_np_array(): - slices = _mx_np.split(data, indices_or_sections=div_points, axis=batch_axis) + slices = _mx_np.split(data, indices_or_sections=list(div_points[1: -1]), axis=batch_axis) else: slices = [] for i in range(num_slice): diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py index 7dc1522d04d2..d7e43b169b75 100644 --- a/tests/python/unittest/test_gluon.py +++ b/tests/python/unittest/test_gluon.py @@ -23,7 +23,10 @@ from mxnet.gluon import nn from mxnet.base import py_str from mxnet.test_utils import assert_almost_equal +from mxnet.util import is_np_array from mxnet.ndarray.ndarray import _STORAGE_TYPE_STR_TO_ID +from mxnet.test_utils import use_np +import mxnet.numpy as _mx_np from common import (setup_module, with_seed, assertRaises, teardown, assert_raises_cudnn_not_satisfied) import numpy as np @@ -964,8 +967,24 @@ def check_split_data(x, num_slice, batch_axis, **kwargs): assert all(r1.reshape(-1) == r2.reshape(-1)) -def test_split_data(): - x = mx.nd.random.uniform(shape=(128, 33, 64)) +def check_split_data(x, num_slice, batch_axis, **kwargs): + res = gluon.utils.split_data(x, num_slice, batch_axis, **kwargs) + assert len(res) == num_slice + if not is_np_array(): + mx.test_utils.assert_almost_equal(mx.nd.concat(*res, dim=batch_axis).asnumpy(), + x.asnumpy()) + else: + mx.test_utils.assert_almost_equal(_mx_np.concatenate(res, axis=batch_axis).asnumpy(), + x.asnumpy()) + np_res = np.array_split(x.asnumpy(), num_slice, axis=batch_axis) + res_asnp = [s.asnumpy() for s in res] + for r1, r2 in zip(np_res, res_asnp): + assert all(r1.reshape(-1) == r2.reshape(-1)) + + +@use_np +def check_split_data_np(): + x = _mx_np.random.uniform(size=(128, 33, 64)) check_split_data(x, 8, 0) check_split_data(x, 3, 1) check_split_data(x, 4, 1, even_split=False) From a1eab552ee6d49ae033479868cbe8898768d73fe Mon Sep 17 00:00:00 2001 From: Wang Date: Sat, 21 Dec 2019 13:38:30 +0800 Subject: [PATCH 4/6] fix name --- tests/python/unittest/test_gluon.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py index d7e43b169b75..394478b919d4 100644 --- a/tests/python/unittest/test_gluon.py +++ b/tests/python/unittest/test_gluon.py @@ -33,7 +33,7 @@ from numpy.testing import assert_array_equal from nose.tools import raises, assert_raises from copy import deepcopy -import warnings +import warnings∂ import json import unittest import random @@ -967,7 +967,7 @@ def check_split_data(x, num_slice, batch_axis, **kwargs): assert all(r1.reshape(-1) == r2.reshape(-1)) -def check_split_data(x, num_slice, batch_axis, **kwargs): +def test_split_data(x, num_slice, batch_axis, **kwargs): res = gluon.utils.split_data(x, num_slice, batch_axis, **kwargs) assert len(res) == num_slice if not is_np_array(): @@ -983,7 +983,7 @@ def check_split_data(x, num_slice, batch_axis, **kwargs): @use_np -def check_split_data_np(): +def test_split_data_np(): x = _mx_np.random.uniform(size=(128, 33, 64)) check_split_data(x, 8, 0) check_split_data(x, 3, 1) From daaea5cfc7c6c4bb5a7e1a26c4536bbe785b8e80 Mon Sep 17 00:00:00 2001 From: Wang Date: Sat, 21 Dec 2019 23:14:21 +0800 Subject: [PATCH 5/6] fix mis input --- tests/python/unittest/test_gluon.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py index 394478b919d4..989ffe2183ab 100644 --- a/tests/python/unittest/test_gluon.py +++ b/tests/python/unittest/test_gluon.py @@ -33,7 +33,7 @@ from numpy.testing import assert_array_equal from nose.tools import raises, assert_raises from copy import deepcopy -import warnings∂ +import warnings import json import unittest import random From 9a23b5723a0c1af24453b6cb081d53f673864286 Mon Sep 17 00:00:00 2001 From: Wang Date: Mon, 30 Dec 2019 16:56:33 +0800 Subject: [PATCH 6/6] fix test --- tests/python/unittest/test_gluon.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py index 989ffe2183ab..37b9cd7a0697 100644 --- a/tests/python/unittest/test_gluon.py +++ b/tests/python/unittest/test_gluon.py @@ -955,19 +955,8 @@ def test_deferred_init(): layer(x) -def check_split_data(x, num_slice, batch_axis, **kwargs): - res = gluon.utils.split_data(x, num_slice, batch_axis, **kwargs) - assert len(res) == num_slice - mx.test_utils.assert_almost_equal(mx.nd.concat(*res, dim=batch_axis).asnumpy(), - x.asnumpy()) - np_res = np.array_split(x.asnumpy(), num_slice, axis=batch_axis) - res_asnp = [s.asnumpy() for s in res] - for r1, r2 in zip(np_res, res_asnp): - assert all(r1.reshape(-1) == r2.reshape(-1)) - - -def test_split_data(x, num_slice, batch_axis, **kwargs): +def check_split_data(x, num_slice, batch_axis, **kwargs): res = gluon.utils.split_data(x, num_slice, batch_axis, **kwargs) assert len(res) == num_slice if not is_np_array(): @@ -982,6 +971,7 @@ def test_split_data(x, num_slice, batch_axis, **kwargs): assert all(r1.reshape(-1) == r2.reshape(-1)) +@with_seed() @use_np def test_split_data_np(): x = _mx_np.random.uniform(size=(128, 33, 64)) @@ -995,6 +985,18 @@ def test_split_data_np(): return assert False, "Should have failed" +@with_seed() +def test_split_data(): + x = mx.nd.random.uniform(shape=(128, 33, 64)) + check_split_data(x, 8, 0) + check_split_data(x, 3, 1) + check_split_data(x, 4, 1, even_split=False) + check_split_data(x, 15, 1, even_split=False) + try: + check_split_data(x, 4, 1) + except ValueError: + return + assert False, "Should have failed" @with_seed() def test_flatten():