[BUG] Zarr checkpoint loses distributed optimizer states due to lack of synchronizers on ranks that create arrays #1053
Comments
Same here. This issue can be pretty serious and needs to be fixed soon.
I confirm this is a bug and your fix looks relevant, thanks 👍. Please note the zarr format is being deprecated and in particular does not play well with the DistributedOptimizer, so I suggest updating the ckpt format to torch_dist.
Does this problem occur only when Tensor Parallelism (TP) > 1 and Data Parallelism (DP) > 1? Currently, I am using DistributedOptimizer with TP = 1 and DP > 1. Will storing checkpoints in Zarr format cause an issue?
With TP=1 it might be an issue as well, please use the torch_dist format instead.
It is okay with TP = 1 and DP > 1 in my environment.
Marking as stale. No activity in 60 days.
Same issue with "--use-distributed-optimizer --ckpt-format torch".
@Jayoprell this one is not expected, can you elaborate on the symptoms?
When using the arguments above, the error info is:
================== tensor keys: dict_keys(['param']), dp rank: 0, optim_state:{}, main_param:tensor([0., 0., 0., ..., 0., 0., 0.], ) ================
Traceback (most recent call last):
File "/workspace/Megatron-LM/pretrain_gpt.py", line 264, in <module>
pretrain(
File "/workspace/Megatron-LM/megatron/training/training.py", line 349, in pretrain
iteration, num_floating_point_operations_so_far = train(
File "/workspace/Megatron-LM/megatron/training/training.py", line 1366, in train
save_checkpoint_and_time(iteration, model, optimizer,
File "/workspace/Megatron-LM/megatron/training/training.py", line 1070, in save_checkpoint_and_time
save_checkpoint(iteration, model, optimizer, opt_param_scheduler,
File "/workspace/Megatron-LM/megatron/training/checkpointing.py", line 380, in save_checkpoint
optimizer.save_parameter_state(optim_checkpoint_name)
File "/workspace/Megatron-LM/megatron/core/optimizer/distrib_optimizer.py", line 902, in save_parameter_state
state_dict = self.get_parameter_state_dp_zero()
File "/workspace/Megatron-LM/megatron/core/optimizer/distrib_optimizer.py", line 852, in get_parameter_state_dp_zero
tensors[key].detach().cpu()
KeyError: 'exp_avg'
The tensors dict above holds the optimizer parameters, and it seems the optimizer state is empty.
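For context on why the state could be empty: in plain PyTorch, Adam's per-parameter state (including exp_avg) is only created lazily on the first step() call. The standalone snippet below (not Megatron code) shows this behavior; whether this is actually what happens in the run above is an assumption, not a confirmed diagnosis.

```python
import torch

param = torch.nn.Parameter(torch.zeros(4))
opt = torch.optim.Adam([param])

print(opt.state[param])          # {} -- no 'exp_avg' key yet, matching the KeyError above
param.grad = torch.ones_like(param)
opt.step()
print(sorted(opt.state[param]))  # ['exp_avg', 'exp_avg_sq', 'step']
```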
Describe the bug
When using a Zarr distributed checkpoint together with the distributed optimizer, each rank writes its optimizer state slice according to the ShardedTensor's flattened_range. The Zarr strategy relies on synchronizers to make this parallel writing safe, but synchronizers are only set on ranks that open existing Zarr arrays, not on the ranks that create them. As a result, the writes from the creating ranks may be lost, leaving all zeros in the corresponding slices of the checkpoint file.
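To make the failure mode concrete, here is a standalone sketch of the underlying Zarr behavior (zarr v2 API; the path is a placeholder and this is not Megatron code). Partial-chunk writes are read-modify-write operations, so when two processes write different parts of the same chunk without a shared synchronizer, one process can overwrite the other's slice with the stale zero-filled chunk it read earlier. The race is probabilistic, so a given run may or may not show lost data.

```python
import multiprocessing as mp
import zarr

PATH = "/tmp/zarr_race_demo.zarr"  # placeholder location

def writer(lo, hi, use_sync):
    # Each "rank" writes its own slice of a chunk it shares with the other writer.
    sync = zarr.ProcessSynchronizer(PATH + ".sync") if use_sync else None
    arr = zarr.open(PATH, mode="r+", synchronizer=sync)
    arr[lo:hi] = 1.0

def run(use_sync):
    # One 8-element chunk shared by both writers, initialised to zeros.
    zarr.open(PATH, mode="w", shape=(8,), chunks=(8,), dtype="f4", fill_value=0)
    procs = [mp.Process(target=writer, args=(0, 4, use_sync)),
             mp.Process(target=writer, args=(4, 8, use_sync))]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
    return zarr.open(PATH, mode="r")[:]

if __name__ == "__main__":
    print(run(use_sync=False))  # may come back with one half still zero (lost write)
    print(run(use_sync=True))   # chunk-level locking on both writers -> all ones
```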
To Reproduce
Run pretrain_gpt.py with DP > 1, TP > 1, the distributed optimizer enabled, and the Zarr checkpoint format. Then a toy test inserted after the dist_checkpointing.save call in the following block may not pass:
Megatron-LM/megatron/training/checkpointing.py, lines 405 to 407 in 86e2927
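For illustration, a toy check along those lines might look like the sketch below. This is not the reporter's exact test: state_dict and checkpoint_name are assumed to be the variables passed to dist_checkpointing.save at that point in checkpointing.py, and the nesting of the sharded state dict is ignored for brevity.

```python
import copy
import torch
from megatron.core import dist_checkpointing

# Snapshot the local shards this rank just asked to be written.
reference = {k: v.data.detach().clone() for k, v in state_dict.items()
             if hasattr(v, "data") and torch.is_tensor(v.data)}

torch.distributed.barrier()  # make sure every rank has finished writing

# Reload the checkpoint that was just saved and compare against the snapshot.
loaded = dist_checkpointing.load(copy.deepcopy(state_dict), checkpoint_name)
for key, ref in reference.items():
    assert torch.allclose(loaded[key].cpu(), ref.cpu()), (
        f"{key}: slice written by this rank came back different (all zeros?)"
    )
```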
Adding a barrier under the following line and using a larger DP size may increase the probability of reproducing the failure:
Megatron-LM/megatron/core/dist_checkpointing/strategies/zarr.py, line 64 in 86e2927
Expected behavior
All tensors should be written to disk.
Stack trace/logs
Environment (please complete the following information):
Proposed fix
Set synchronizers when creating Zarr arrays, mirroring the logic used when opening existing Zarr arrays:
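A minimal sketch of what that could look like with the zarr v2 API is below. The helper names, shapes, and structure are illustrative rather than the actual code in strategies/zarr.py; the essential point is the synchronizer= argument on the create path, mirroring the one already passed on the open path.

```python
import zarr

def _create_array_with_sync(array_path, global_shape, chunks, np_dtype):
    """Create a Zarr array on the owning rank with the same ProcessSynchronizer
    that other ranks use when opening the existing array (hypothetical helper)."""
    synchronizer = zarr.ProcessSynchronizer(str(array_path) + ".sync")
    return zarr.create(
        shape=global_shape,
        chunks=chunks,
        dtype=np_dtype,
        store=str(array_path),
        synchronizer=synchronizer,  # previously missing on the creating rank
        overwrite=True,
    )

def _open_array_with_sync(array_path):
    """Open path, shown for symmetry: this side already used a synchronizer."""
    synchronizer = zarr.ProcessSynchronizer(str(array_path) + ".sync")
    return zarr.open(str(array_path), mode="r+", synchronizer=synchronizer)
```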
Additional context