This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Problem with gluon.utils.split_data() #17117

Closed
zburning opened this issue Dec 19, 2019 · 5 comments · Fixed by #17123

Comments

@zburning
Contributor

Description

The current gluon.utils.split_data() has:

step = size // num_slice

# If size < num_slice, make fewer slices
if not even_split and size < num_slice:
    step = 1
    num_slice = size

if batch_axis == 0:
    slices = [data[i*step:(i+1)*step] if i < num_slice - 1 else data[i*step:size]
              for i in range(num_slice)]

Consider an example:
we have a tensor of shape (31, *) and want to split it into 8 slices. The function computes step = 31 // 8 = 3, so the tensor is split into 8 tensors of sizes [3, 3, 3, 3, 3, 3, 3, 10], where the last tensor is excessively large. A better result would be [4, 4, 4, 4, 4, 4, 4, 3].

Maybe we can follow np.array_split()?
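
For anyone who wants to check the numbers, here is a minimal sketch that replays the slicing logic quoted above on a plain Python list instead of an NDArray. The helper name current_split_sizes is made up for this illustration and is not part of gluon.utils; only the batch_axis == 0 branch is covered.

```python
def current_split_sizes(size, num_slice, even_split=False):
    """Return the slice lengths produced by the logic quoted above.

    Hypothetical helper for illustration only; a plain list stands in
    for the batch axis of a tensor.
    """
    step = size // num_slice

    # If size < num_slice, make fewer slices
    if not even_split and size < num_slice:
        step = 1
        num_slice = size

    data = list(range(size))
    # Every slice gets `step` items except the last, which takes the remainder.
    slices = [data[i * step:(i + 1) * step] if i < num_slice - 1 else data[i * step:size]
              for i in range(num_slice)]
    return [len(s) for s in slices]


print(current_split_sizes(31, 8))  # [3, 3, 3, 3, 3, 3, 3, 10]
```

The last slice absorbs everything left over after the first seven, which is where the size of 10 comes from.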

@zburning zburning added the Bug label Dec 19, 2019
@wkcn wkcn added the Gluon label Dec 19, 2019
@wkcn
Member

wkcn commented Dec 19, 2019

slice_len = length // num_slice
rest = length % num_slice
start = slice_len * index + min(index, rest)
end = start + slice_len + (index < rest)
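
As a rough sketch of how the four lines above behave (not necessarily the code that was merged in #17123), they can be wrapped in a loop over the slice index. The function name balanced_slice_bounds is hypothetical; the variable names follow the snippet.

```python
def balanced_slice_bounds(length, num_slice):
    """Return (start, end) pairs that spread the remainder over the first slices.

    Hypothetical helper for illustration; it applies the formula from the
    comment above to every slice index.
    """
    slice_len = length // num_slice
    rest = length % num_slice
    bounds = []
    for index in range(num_slice):
        # Each of the first `rest` slices is one element longer, so the start
        # of slice `index` is shifted by the number of longer slices before it.
        start = slice_len * index + min(index, rest)
        end = start + slice_len + (index < rest)
        bounds.append((start, end))
    return bounds


bounds = balanced_slice_bounds(31, 8)
sizes = [end - start for start, end in bounds]
print(sizes)  # [4, 4, 4, 4, 4, 4, 4, 3]
```

For length 31 and 8 slices, rest is 7, so the first seven slices get 4 elements and the last gets 3, matching the sizes proposed in the description.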

@zburning
Contributor Author

Thank you, this is a clean solution.

@leezu
Contributor

leezu commented Dec 19, 2019

Following np.array_split is a good idea. It should have been done from the beginning. Would you like to create a PR?

@zburning
Contributor Author

@leezu Yes

@sxjscience
Member

@leezu @zburning How about labeling it as a performance issue?
