diff --git a/python/mxnet/gluon/utils.py b/python/mxnet/gluon/utils.py
index 81a8dbaa486b..83ed15aed450 100644
--- a/python/mxnet/gluon/utils.py
+++ b/python/mxnet/gluon/utils.py
@@ -70,30 +70,18 @@ def split_data(data, num_slice, batch_axis=0, even_split=True):
             "uneven partitioning of data."%(
                 str(data.shape), num_slice, batch_axis, num_slice))
 
-    step = size // num_slice
-
-    # If size < num_slice, make fewer slices
-    if not even_split and size < num_slice:
-        step = 1
-        num_slice = size
-
-    if batch_axis == 0:
-        slices = [data[i*step:(i+1)*step] if i < num_slice - 1 else data[i*step:size]
-                  for i in range(num_slice)]
-    elif even_split:
-        if is_np_array():
-            slices = _mx_np.split(data, indices_or_sections=num_slice, axis=batch_axis)
-        else:
-            slices = ndarray.split(data, num_outputs=num_slice, axis=batch_axis)
+    n_each_section, extras = divmod(size, num_slice)
+    section_sizes = [0] + (extras * [n_each_section + 1] +
+                           (num_slice - extras) * [n_each_section])
+    div_points = np.array(section_sizes).cumsum()
+    if is_np_array():
+        slices = _mx_np.split(data, indices_or_sections=list(div_points[1: -1]), axis=batch_axis)
     else:
-        if is_np_array():
-            indices = [step * i for i in range(1, num_slice)]
-            slices = _mx_np.split(data, indices_or_sections=indices, axis=batch_axis)
-        else:
-            slices = [ndarray.slice_axis(data, batch_axis, i*step, (i+1)*step)
-                      if i < num_slice - 1 else
-                      ndarray.slice_axis(data, batch_axis, i*step, size)
-                      for i in range(num_slice)]
+        slices = []
+        for i in range(num_slice):
+            st = div_points[i]
+            end = div_points[i + 1]
+            slices.append(ndarray.slice_axis(data, axis=batch_axis, begin=st, end=end))
     return slices
 
 
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index d27c241f9c99..37b9cd7a0697 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -23,7 +23,10 @@
 from mxnet.gluon import nn
 from mxnet.base import py_str
 from mxnet.test_utils import assert_almost_equal
+from mxnet.util import is_np_array
 from mxnet.ndarray.ndarray import _STORAGE_TYPE_STR_TO_ID
+from mxnet.test_utils import use_np
+import mxnet.numpy as _mx_np
 from common import (setup_module, with_seed, assertRaises, teardown,
                     assert_raises_cudnn_not_satisfied)
 import numpy as np
@@ -952,17 +955,39 @@ def test_deferred_init():
     layer(x)
 
 
+
 def check_split_data(x, num_slice, batch_axis, **kwargs):
     res = gluon.utils.split_data(x, num_slice, batch_axis, **kwargs)
     assert len(res) == num_slice
-    mx.test_utils.assert_almost_equal(mx.nd.concat(*res, dim=batch_axis).asnumpy(),
-                                      x.asnumpy())
+    if not is_np_array():
+        mx.test_utils.assert_almost_equal(mx.nd.concat(*res, dim=batch_axis).asnumpy(),
+                                          x.asnumpy())
+    else:
+        mx.test_utils.assert_almost_equal(_mx_np.concatenate(res, axis=batch_axis).asnumpy(),
+                                          x.asnumpy())
+    np_res = np.array_split(x.asnumpy(), num_slice, axis=batch_axis)
+    res_asnp = [s.asnumpy() for s in res]
+    for r1, r2 in zip(np_res, res_asnp):
+        assert all(r1.reshape(-1) == r2.reshape(-1))
 
 
+@with_seed()
+@use_np
+def test_split_data_np():
+    x = _mx_np.random.uniform(size=(128, 33, 64))
+    check_split_data(x, 8, 0)
+    check_split_data(x, 3, 1)
+    check_split_data(x, 4, 1, even_split=False)
+    check_split_data(x, 15, 1, even_split=False)
+    try:
+        check_split_data(x, 4, 1)
+    except ValueError:
+        return
+    assert False, "Should have failed"
+
 @with_seed()
 def test_split_data():
     x = mx.nd.random.uniform(shape=(128, 33, 64))
-
     check_split_data(x, 8, 0)
     check_split_data(x, 3, 1)
     check_split_data(x, 4, 1, even_split=False)
@@ -973,7 +998,6 @@ def test_split_data():
         return
     assert False, "Should have failed"
 
-
 @with_seed()
 def test_flatten():
     flatten = nn.Flatten()