Replicating samples across devices (SP / TP enablement) #597
Conversation
some clarifying questions
streaming/base/world.py (outdated)

```python
rank = self.rank // ratio
num_ranks = self.num_ranks // ratio
worker = rank * self.workers_per_rank + self.worker_of_rank
if self.ranks_per_node <= num_ranks:
```
@knighton mind adding some comments to the code here? I'm trying to understand what modifications are happening, and I think there may be a bug or a better way of doing this, but it's pretty hard for me to reason about without completely understanding what you're trying to accomplish here.
iiuc `ratio` is the number of consecutive GPUs that should be sharing the same samples -- and should have the same `World` information. `num_ranks` here is the # of TP blocks, and `rank` is the TP block index. I'm not seeing why we need to check `if self.ranks_per_node <= num_ranks`, and it seems that if we set `num_nodes` to 1, then we'll have download duplication across all the nodes, which will be pretty bad. The TP ratio should be the degree of duplication of the `World` objects, and I'm not sure that this does that entirely correctly...
> iiuc `ratio` is the number of consecutive GPUs that should be sharing the same samples -- and should have the same `World` information. `num_ranks` here is the # of TP blocks, and `rank` is the TP block index.
Exactly (for iterating purposes, modulo that bug for coordinating purposes, which I am about to fix).
> I'm not seeing why we need to check `if self.ranks_per_node <= num_ranks`, and it seems if we set `num_nodes` to 1, then we'll have download duplication across all the nodes which will be pretty bad.
The PR originally applied tensor parallelism intra-node only, but then iiuc VC noted that TP blocks may be inter-node here: #597 (comment)
Spelling out the logic for my own benefit (see the sketch below):

```
if num TP blocks maps 1:1 to, or exceeds, the true local world size:
    scale down the perceived num nodes to (TP blocks / true local world size)
    (no need to scale down the perceived local world size, since we scaled nodes)
else:
    just one perceived node, and what happens in the node stays in the node:
    map ranks to fewer TP blocks by scaling down the perceived local world size
```
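A minimal Python sketch of that branching (the function name and signature are illustrative, not the actual `streaming/base/world.py` code):

```python
def perceived_topology(num_ranks: int, ranks_per_node: int) -> tuple[int, int]:
    """Rescale (num_nodes, ranks_per_node) to TP-block coordinates.

    Here `num_ranks` is the number of TP blocks (true world size // ratio)
    and `ranks_per_node` is the true local world size.
    """
    if ranks_per_node <= num_ranks:
        # Enough TP blocks to fill whole nodes: scale down the perceived
        # node count; the perceived local world size is unchanged.
        return num_ranks // ranks_per_node, ranks_per_node
    # All TP blocks fit within one node: one perceived node, and map ranks
    # to fewer TP blocks by scaling down the perceived local world size.
    return 1, num_ranks
```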
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could rewrite it to:

```python
num_nodes = (num_ranks + self.ranks_per_node - 1) // self.ranks_per_node
ranks_per_node = num_ranks // num_nodes
```
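Worked through with illustrative numbers (assuming a true local world size of 8), the ceiling division covers both branches above:

```python
ranks_per_node = 8                  # true local world size (assumed)
for num_ranks in (4, 16):           # number of TP blocks
    num_nodes = (num_ranks + ranks_per_node - 1) // ranks_per_node
    print(num_ranks, num_nodes, num_ranks // num_nodes)
# 4 TP blocks  -> 1 perceived node, 4 ranks/node (local world size scaled down)
# 16 TP blocks -> 2 perceived nodes, 8 ranks/node (node count scaled down)
```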
> modulo that bug for coordinating purposes which am about to fix
Now I get an error like:

Investigating...
@andreamad8 how about now? You may need to call …
Ok, works now :) I can also test with multiple nodes, but I would need a bit more time (an hour-ish).
We will begrudge you one sidereal hour
It was a bit longer than an hour :) Yes, it works. I tested up to 16 nodes with different TensorParallel sizes: 1, 4, and 8.
Cleaned up the PR so the naming isn't TP-specific, and fixed a bug that was preventing determinism (both elastic and non-elastic). Added tests as well. Pending regression tests, this should be good to go!
lgtm

1. `World` is just our wrapper around `torch.dist` and `torch.utils.data.get_worker_info`. `World` just tells you which one you are, out of how many (nodes, ranks/node, workers/rank). Outside of iteration there is no `World`, semantically speaking: `get_worker_info()` will say we are worker 0 of 1. One has to keep that in mind (see the first sketch below).
2. New `StreamingDataset` argument `replication: Optional[int]`: `replication` iterates the same samples for groups of adjacent GPUs (ranks) (see the usage sketch below).
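Illustrating point 1 — a minimal sketch, assuming a stock PyTorch `DataLoader`, of why "worker 0 of 1" is what you see outside of iteration:

```python
from torch.utils.data import get_worker_info

# In the main process (outside a DataLoader worker), get_worker_info()
# returns None, so a World-style wrapper must default to worker 0 of 1.
info = get_worker_info()
worker_id = info.id if info is not None else 0
num_workers = info.num_workers if info is not None else 1
print(f'worker {worker_id} of {num_workers}')  # prints: worker 0 of 1
```

And a usage sketch for point 2, with placeholder remote/local paths — a TP degree of 4 means each group of 4 adjacent ranks iterates the same samples:

```python
from streaming import StreamingDataset

dataset = StreamingDataset(
    remote='s3://my-bucket/my-dataset',  # hypothetical dataset location
    local='/tmp/my-dataset',
    replication=4,  # groups of 4 adjacent ranks see identical samples
)
```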