
Replicating samples across devices (SP / TP enablement) #597

Merged: 13 commits merged into main from james/wrd on Feb 22, 2024
Conversation

@knighton (Contributor) commented Feb 9, 2024

1. World

  • Streaming's model of the world (the run/the job) is (num nodes × ranks per node) rank processes plus (num nodes × ranks per node × workers per rank) worker processes.
  • For example, 8 ranks/node and 8 workers/rank -> 8 + 8 * 8 = 72 StreamingDataset replicas per node.
  • Currently they coordinate via filelocks and shared memory.
  • A World is just our wrapper around torch.dist and torch.utils.data.get_worker_info (a minimal sketch follows this list).
  • The World just tells you which one you are, out of how many (nodes, ranks/node, workers/rank).
  • There's always only one World, semantically speaking.
  • However, from within a rank process, get_worker_info() will say we are worker 0 of 1; one has to keep that in mind.
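
A minimal sketch of such a wrapper, to make the bookkeeping above concrete. This is not Streaming's actual World class; the class name, attribute names, and the ranks_per_node argument are illustrative assumptions.

from dataclasses import dataclass

import torch.distributed as dist
from torch.utils.data import get_worker_info


@dataclass
class TinyWorld:
    """Where this process sits in (nodes, ranks/node, workers/rank)."""
    num_nodes: int
    ranks_per_node: int
    rank: int               # global rank index
    workers_per_rank: int
    worker_of_rank: int     # worker index within this rank's dataloader

    @classmethod
    def detect(cls, ranks_per_node: int) -> 'TinyWorld':
        # Rank info comes from torch.distributed, if it has been initialized.
        if dist.is_available() and dist.is_initialized():
            rank = dist.get_rank()
            num_ranks = dist.get_world_size()
        else:
            rank, num_ranks = 0, 1
        # Worker info comes from the DataLoader. In the rank (main) process,
        # get_worker_info() returns None, which reads as "worker 0 of 1".
        info = get_worker_info()
        return cls(num_nodes=max(1, num_ranks // ranks_per_node),
                   ranks_per_node=ranks_per_node,
                   rank=rank,
                   workers_per_rank=info.num_workers if info is not None else 1,
                   worker_of_rank=info.id if info is not None else 0)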

2. StreamingDataset argument replication: Optional[int]

  • replication iterates the same samples for groups of adjacent GPUs (ranks); see the usage sketch after this list.
  • It does this by scaling down the number of ranks and repeating them.
  • But the underlying data doesn't change, so each (merged) rank's partition is correspondingly scaled up in length.
  • It prefers to merge ranks within the same node first, then, if it needs to merge further, across node boundaries.
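
A hedged usage sketch of the argument described above. The replication argument itself is what this PR adds; the paths, batch size, and worker count are placeholder assumptions.

from torch.utils.data import DataLoader

from streaming import StreamingDataset

# Groups of 4 adjacent ranks iterate identical samples (e.g. one TP/SP group of 4 GPUs).
dataset = StreamingDataset(remote='s3://my-bucket/my-dataset',  # placeholder remote path
                           local='/tmp/my-dataset',             # placeholder local cache
                           batch_size=8,
                           replication=4)
loader = DataLoader(dataset, batch_size=8, num_workers=8)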

streaming/base/dataset.py (review thread, outdated, resolved)
@snarayan21 (Collaborator) left a comment:

some clarifying questions

rank = self.rank // ratio
num_ranks = self.num_ranks // ratio
worker = rank * self.workers_per_rank + self.worker_of_rank
if self.ranks_per_node <= num_ranks:
@snarayan21 (Collaborator) commented:

@knighton mind adding some comments to the code here? I'm trying to understand what modifications are happening, and I think there may be a bug or a better way of doing this, but it's pretty hard for me to reason about without completely understanding what you're trying to accomplish here.

iiuc ratio is the number of consecutive GPUs that should be sharing the same samples -- and should have the same World information. num_ranks here is the # of TP blocks, and rank is the TP block index. I'm not seeing why we need to check if self.ranks_per_node <= num_ranks, and it seems if we set num_nodes to 1, then we'll have download duplication across all the nodes which will be pretty bad. The TP ratio should be the degree of duplication of the World objects and I'm not sure that this does that entirely correctly...
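
An editorially annotated copy of the excerpt, under the reading given above (which the author confirms in the next comment); the comments are not from the PR itself:

rank = self.rank // ratio              # index of this rank's replication (TP) block
num_ranks = self.num_ranks // ratio    # number of replication (TP) blocks in the run
worker = rank * self.workers_per_rank + self.worker_of_rank  # worker index in the downscaled world
if self.ranks_per_node <= num_ranks:   # enough TP blocks to fill a whole node?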

@knighton (Contributor, Author) replied:

> iiuc ratio is the number of consecutive GPUs that should be sharing the same samples -- and should have the same World information. num_ranks here is the # of TP blocks, and rank is the TP block index.

Exactly (for iterating purposes, modulo that bug for coordinating purposes which I'm about to fix).

> I'm not seeing why we need to check if self.ranks_per_node <= num_ranks, and it seems if we set num_nodes to 1, then we'll have download duplication across all the nodes which will be pretty bad.

The PR originally applied tensor parallelism intra-node only, but then iiuc VC noted TP blocks may be inter-node here: #597 (comment)

@knighton (Contributor, Author) commented Feb 9, 2024:

Spelling out the logic for my own benefit:

If num TP blocks maps 1:1 to or exceeds true local world size:
    Scale down perceived num nodes to TP blocks / true local world size
    Don't need to scale down its perceived local world size because we scaled nodes
else:
    Just one perceived node, and what happens in the node, stays in the node
    Map ranks to fewer TP blocks by scaling down its perceived local world size
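
A small Python rendering of that branch, just to pin it down (function and variable names are mine, not the PR's; num_ranks is already divided by the replication ratio):

def downscale_world(num_ranks: int, ranks_per_node: int) -> tuple[int, int]:
    """Return the perceived (num_nodes, ranks_per_node) after replication."""
    if ranks_per_node <= num_ranks:
        # Enough TP blocks to fill whole nodes: shrink the node count,
        # keep the local world size as-is.
        return num_ranks // ranks_per_node, ranks_per_node
    # Fewer TP blocks than ranks in one node: one perceived node,
    # shrink its local world size instead.
    return 1, num_ranks


# e.g. 2 nodes x 8 GPUs with replication ratio 4 -> 16 // 4 = 4 TP blocks:
assert downscale_world(num_ranks=4, ranks_per_node=8) == (1, 4)
# same hardware with ratio 2 -> 8 TP blocks:
assert downscale_world(num_ranks=8, ranks_per_node=8) == (1, 8)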

@knighton (Contributor, Author) followed up:

Could rewrite it to:

num_nodes = (num_ranks + self.ranks_per_node - 1) // self.ranks_per_node
ranks_per_node = num_ranks // num_nodes
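
A worked instance of that rewrite, under the same assumed shape as above (2 nodes × 8 GPUs, replication ratio 4, so 4 TP blocks):

num_ranks, true_ranks_per_node = 4, 8  # illustrative numbers only
num_nodes = (num_ranks + true_ranks_per_node - 1) // true_ranks_per_node  # ceil(4 / 8) = 1
ranks_per_node = num_ranks // num_nodes                                   # 4 // 1 = 4
assert (num_nodes, ranks_per_node) == (1, 4)  # matches the else branch above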

@knighton (Contributor, Author) added:

> modulo that bug for coordinating purposes which I'm about to fix

Fixed in a17419e.

streaming/base/world.py (review thread, outdated, resolved)
@andreamad8 commented:

Now I get an error like:

  File "/streaming/streaming/base/dataset.py", line 1458, in __iter__
    sample_ids = self._get_work(world, epoch, sample_in_epoch)
  File "/streaming/streaming/base/dataset.py", line 1002, in _get_work
    shape_shm, data_shm = self._share_work(epoch_sample_ids)
  File "/streaming/streaming/base/dataset.py", line 949, in _share_work
    shape_shm = SharedMemory(name=name, create=True, size=size, auto_cleanup=False)
  File "/streaming/streaming/base/shared/memory.py", line 41, in __init__
    shm = BuiltinSharedMemory(name, create, size)
  File "lib/python3.9/multiprocessing/shared_memory.py", line 103, in __init__
    self._fd = _posixshmem.shm_open(
FileExistsError: [Errno 17] File exists: '/000000_epoch_shape'

@knighton (Contributor, Author) commented Feb 9, 2024:

> Now I get an error like: FileExistsError: [Errno 17] File exists: '/000000_epoch_shape'

Investigating...

@knighton (Contributor, Author) commented Feb 9, 2024:

@andreamad8 how about now? You may need to call streaming.base.util.clean_stale_shared_memory()
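
For anyone else hitting the FileExistsError above, the cleanup call named in that comment looks like this (run it before re-constructing the dataset):

from streaming.base.util import clean_stale_shared_memory

# Clear shared-memory blocks left behind by a previous, interrupted run.
clean_stale_shared_memory()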

@andreamad8 commented Feb 9, 2024:

Ok works now :)

I can also test with multiple nodes, but I would need a bit more time (1 hourish)

@knighton (Contributor, Author) commented Feb 9, 2024:

We will begrudge you one sidereal hour

@andreamad8 replied:

It was a bit longer than an hour :)

Yes, it works. I tested with up to 16 nodes and different TensorParallel sizes (1, 4, 8).

@snarayan21 (Collaborator) commented:

Cleaned up the PR so the naming is not TP-specific and fixed a bug that was preventing determinism (both elastic and non-elastic). Added tests as well. Pending regression tests, this should be good to go!

@snarayan21 snarayan21 changed the title Custom world and/or tensor parallel Replicating samples across devices (sequence / tensor parallelism enablement) Feb 22, 2024
@snarayan21 snarayan21 changed the title Replicating samples across devices (sequence / tensor parallelism enablement) Replicating samples across devices (SP / TP enablement) Feb 22, 2024
@snarayan21 (Collaborator) left a comment:

lgtm

@snarayan21 snarayan21 merged commit 0bb81e8 into main Feb 22, 2024
6 checks passed
@snarayan21 snarayan21 deleted the james/wrd branch February 22, 2024 19:09