
Better default values for StreamingDataset args #479

Merged (12 commits, Oct 27, 2023)

Conversation

@snarayan21 (Collaborator) commented on Oct 20, 2023

Description of changes:

Adding better defaults to StreamingDataset. This helps ensure that StreamingDataset just works out of the box for most cases, giving high-quality shuffles without degrading throughput or sacrificing resumption. This depends on the relaxed partitioning algo being merged as well (#476).

Testing:

Multi-stream deterministic resumption tests ran successfully (NCN = num_canonical_nodes); a sketch of the check follows the list below:

  • Yellow: Full 300 batches on 1 node, NCN = 4.
  • Light Blue: First 100 batches on 1 node, NCN = 4.
  • Green: Middle 100 batches on 2 nodes, NCN = 4.
  • Purple: Last 100 batches on 4 nodes, NCN = 4.

[Screenshot: metric curves for the four runs above.]
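As a rough illustration of what this test verifies, here is a minimal sketch in Python, assuming a hypothetical helper `run_segment(...)` (not part of this PR or the library) that runs training on the given number of nodes and returns the sequence of global batches produced:

```python
# Hypothetical illustration only: `run_segment` is an assumed test helper,
# not a real API in this repo.

def check_deterministic_resumption(run_segment) -> None:
    full = run_segment(start_batch=0, num_batches=300, nodes=1)      # yellow
    first = run_segment(start_batch=0, num_batches=100, nodes=1)     # light blue
    middle = run_segment(start_batch=100, num_batches=100, nodes=2)  # green
    last = run_segment(start_batch=200, num_batches=100, nodes=4)    # purple
    # Elastic determinism: stitching the three partial runs must reproduce
    # the single full run exactly, despite the changing node counts.
    assert first + middle + last == full
```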

Summary of Changes:

  • predownload defaults to 8x the device batch size.
  • partition_algo defaults to relaxed.
  • num_canonical_nodes defaults to 64 * the number of physical nodes in the run only when using py1s or py2s shuffling. In all other cases, it defaults to the number of physical nodes. relaxed partitioning takes care of resumption cases.
  • shuffle_algo defaults to py1e, which better balances downloads while giving good shuffle quality.
  • shuffle_block_size defaults to max(4_000_000 // num_canonical_nodes, 1 << 18), where 1 << 18 == 262,144. This value was selected based on the shuffle quality experiments shown below, where a shuffle strength of 4 million samples (the brown dot) gave train loss stdev and batch composition stdev comparable to stronger shuffles, but with fewer downloads. A sketch of how these defaults compose appears after the figures.

[Figures: shuffle strength vs. batch composition stdev; shuffle strength vs. loss deviation stdev.]
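To make the interplay of these defaults concrete, here is a minimal sketch of how they compose. The inputs `device_batch_size`, `physical_nodes`, and `shuffle_algo` are hypothetical names for illustration; this restates the rules above, not the library's internal code:

```python
# A sketch of how the new StreamingDataset defaults compose (illustrative
# only; names and structure are assumptions, not the library's internals).

def default_predownload(device_batch_size: int) -> int:
    # predownload: 8x the per-device batch size.
    return 8 * device_batch_size

def default_num_canonical_nodes(physical_nodes: int, shuffle_algo: str) -> int:
    # 64x the physical node count only for py1s/py2s shuffling; otherwise
    # the node count itself (relaxed partitioning handles resumption).
    if shuffle_algo in ('py1s', 'py2s'):
        return 64 * physical_nodes
    return physical_nodes

def default_shuffle_block_size(num_canonical_nodes: int) -> int:
    # Target a shuffle strength of ~4 million samples, floored at
    # 1 << 18 == 262,144.
    return max(4_000_000 // num_canonical_nodes, 1 << 18)

# Example: a 2-node run with the default py1e shuffle, device batch size 32.
ncn = default_num_canonical_nodes(physical_nodes=2, shuffle_algo='py1e')
print(default_predownload(32))          # 256
print(ncn)                              # 2
print(default_shuffle_block_size(ncn))  # 2_000_000
```

With these rules, the 4_000_000 // num_canonical_nodes term dominates until NCN exceeds roughly 15 (4_000_000 / 262_144 ≈ 15.3), after which the 1 << 18 floor takes over.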

Defaults will be propagated and updated as appropriate in the diffusion and llm-foundry repos in conjunction with this PR and #476.


Merge Checklist:

Put an x without space in the boxes that apply. If you are unsure about any checklist, please don't hesitate to ask. We are here to help! This is simply a reminder of what we are going to look for before merging your pull request.

General

  • I have read the contributor guidelines
  • This is a documentation change or typo fix. If so, skip the rest of this checklist.
  • I certify that the changes I am introducing will be backward compatible, and I have discussed concerns about this, if any, with the MosaicML team.
  • I have updated any necessary documentation, including README and API docs (if appropriate).

Tests

  • I ran pre-commit on my change. (check out the pre-commit section of prerequisites)
  • I have added tests that prove my fix is effective or that my feature works (if appropriate).
  • I ran the tests locally to make sure they pass. (check out testing)
  • I have added unit and/or integration tests as appropriate to ensure backward compatibility of the changes.

streaming/base/batching/per_stream.py (review thread on the docstring change):

    - of this size, and samples within each block are shuffled. Defaults to ``1 << 18``.
    + shuffle_block_size (int, optional): Unit of shuffle. A canonical node's samples are split
    +     into blocks of this size, and samples within each block are shuffled. If ``None``, its
    +     value is calculated as ``max(4_000_000 // num_canonical_nodes, 1 << 18)``. Defaults to
    +     ``None``.
Collaborator:

Any specific reason for 4_000_000? Just wondering.

Collaborator (Author):

This was based on the shuffle quality experiments, where a shuffle strength of 4_000_000 (i.e., each batch is drawn from 4_000_000 samples) gave good shuffle quality without making the number of downloads required too high.

Contributor:

how about max(1 << 18, 1 << 22 // NCN) lol?

@snarayan21 (Collaborator, Author) replied on Oct 26, 2023:

That works too, it's essentially the same, right? Gonna keep as-is for now.
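A quick numeric check (a sketch, not from the thread) confirms the two defaults are essentially the same: 1 << 22 == 4_194_304, within about 5% of 4_000_000. Note that the suggested form needs parentheses in Python, since `//` binds tighter than `<<`:

```python
# Compare the merged default against the reviewer's suggested spelling.
for ncn in (1, 2, 8, 16, 64, 256):
    current = max(4_000_000 // ncn, 1 << 18)
    suggested = max((1 << 22) // ncn, 1 << 18)  # parens: // binds tighter than <<
    print(f'NCN={ncn}: current={current:>9,} suggested={suggested:>9,}')
```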

Other review threads (resolved):
  • streaming/base/dataset.py
  • streaming/base/batching/per_stream.py
  • streaming/base/batching/random.py
  • streaming/base/batching/stratified.py
@snarayan21 changed the title from "Better defaults" to "Better defaults for StreamingDataset" on Oct 26, 2023
@snarayan21 changed the title from "Better defaults for StreamingDataset" to "Better default values for StreamingDataset args" on Oct 26, 2023
@karan6181 (Collaborator) left a comment:

LGTM. @knighton Can you take a look at it?

@knighton (Contributor) left a comment:

but note lint/test failure

@snarayan21 merged commit 93bf054 into mosaicml:main on Oct 27, 2023 (6 checks passed).