This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Problem with gluon.utils.split_data()  #17117

@zburning

Description

The current gluon.utils.split_data() has:

step = size // num_slice

# If size < num_slice, make fewer slices
if not even_split and size < num_slice:
    step = 1
    num_slice = size

if batch_axis == 0:
    slices = [data[i*step:(i+1)*step] if i < num_slice - 1 else data[i*step:size]
              for i in range(num_slice)]

Consider an example:
We have a tensor of shape (31, *) and want to split it into 8 slices. According to the function, step will be 31 // 8 = 3, so the tensor is split into 8 tensors of sizes [3, 3, 3, 3, 3, 3, 3, 10]: the first 7 slices take 3 rows each and the last slice takes the remaining 31 - 21 = 10 rows, which is excessively large. A better result would be [4, 4, 4, 4, 4, 4, 4, 3].
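
The sizes in this example can be checked without MXNet. The following is a minimal sketch that mirrors only the slicing arithmetic quoted above; `current_slice_sizes` is a hypothetical helper name used for illustration and is not part of gluon.utils.

```python
def current_slice_sizes(size, num_slice, even_split=False):
    """Return the slice lengths implied by the quoted snippet (batch_axis == 0)."""
    step = size // num_slice

    # If size < num_slice, make fewer slices (same rule as the quoted code)
    if not even_split and size < num_slice:
        step = 1
        num_slice = size

    # Every slice has length `step`, except the last one, which runs to `size`
    bounds = [(i * step, (i + 1) * step if i < num_slice - 1 else size)
              for i in range(num_slice)]
    return [stop - start for start, stop in bounds]


print(current_slice_sizes(31, 8))  # [3, 3, 3, 3, 3, 3, 3, 10]
```

The whole remainder (31 % 8 = 7 rows) lands in the last slice, which is why it grows to 10.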

Maybe we can follow np.array_split()?
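
A rough sketch of what following np.array_split() could look like: spread the remainder over the leading slices, so that the first `size % num_slice` slices get one extra row. `balanced_split` is a hypothetical name, the sketch works on a NumPy array along axis 0 only, and it is an illustration of the idea rather than a patch for gluon.utils.split_data().

```python
import numpy as np


def balanced_split(data, num_slice):
    """Split `data` along axis 0 into slices whose lengths differ by at most 1.

    Assumes data.shape[0] >= 1.
    """
    size = data.shape[0]
    num_slice = min(num_slice, size)       # make fewer slices if size < num_slice
    base, extra = divmod(size, num_slice)  # `extra` slices get one more row

    slices, start = [], 0
    for i in range(num_slice):
        stop = start + base + (1 if i < extra else 0)
        slices.append(data[start:stop])
        start = stop
    return slices


data = np.arange(31 * 2).reshape(31, 2)
print([s.shape[0] for s in balanced_split(data, 8)])     # [4, 4, 4, 4, 4, 4, 4, 3]
print([s.shape[0] for s in np.array_split(data, 8)])     # [4, 4, 4, 4, 4, 4, 4, 3]
```

For the (31, *) example, both produce the [4, 4, 4, 4, 4, 4, 4, 3] layout suggested above.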

