
[BUG] Zarr checkpoint loses distributed optimizer states due to lack of synchronizers on ranks that create arrays #1053

Open · LouChao98 opened this issue Aug 30, 2024 · 9 comments
Describe the bug

When using a Zarr distributed checkpoint together with a distributed optimizer, each rank writes optimizer states according to the ShardedTensor's flattened_range. The Zarr strategy relies on synchronizers to make these parallel writes safe. However, synchronizers are not set on ranks that create Zarr arrays; the current implementation only adds them on ranks that open existing arrays. Consequently, writes from the creating ranks may be lost, leaving all zeros at the corresponding slices in the saved file.
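For context, here is a minimal standalone sketch (not Megatron code; the directory, array name, shape, and slices are made up) of the zarr v2 pattern the strategy relies on: every process that writes into a shared chunk must pass the same ProcessSynchronizer, including the process that creates the array.

import os
import numpy as np
import zarr

ckpt_dir = '/tmp/zarr_sync_demo'  # placeholder directory, not a real checkpoint path
os.makedirs(ckpt_dir, exist_ok=True)

sync = zarr.ProcessSynchronizer(os.path.join(ckpt_dir, 'exp_avg.sync'))

# Creating rank: if synchronizer is omitted here (as in the current create path),
# this rank's read-modify-write of the shared chunk is not protected by the lock.
arr = zarr.create(
    (1024,), chunks=(1024,), dtype=np.float32,
    store=os.path.join(ckpt_dir, 'exp_avg'),
    fill_value=None, write_empty_chunks=True,
    synchronizer=sync,
)
arr[:512] = 1.0  # writes the first half of the single shared chunk

# Other ranks open the existing array; the current implementation sets the
# synchronizer only on this open path.
arr_other = zarr.open(os.path.join(ckpt_dir, 'exp_avg'), mode='r+', synchronizer=sync)
arr_other[512:] = 2.0  # writes the second half of the same chunk

With a single chunk, both slice assignments are read-modify-write operations on the same chunk file, which is exactly the situation flattened_range creates and why the lock must be taken on both the create and open paths.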

To Reproduce

Run pretrain_gpt.py with DP > 1, TP > 1, and the following arguments:

--save {SOME_WHERE} \
--save-interval 10 \
--use-dist-ckpt \
--dist-ckpt-format zarr \
--use-distributed-optimizer \
--weight-decay 0.1  # ensures nonzero gradients everywhere, which makes the test below easy

Then a toy test inserted right after the dist_checkpointing.save call, as shown in the following block, may not pass:

from copy import deepcopy  # if not already imported in this module

async_save_request = dist_checkpointing.save(state_dict, checkpoint_name, save_strategy,
                                             async_sharded_save=args.async_save,
                                             validate_access_integrity=validate_sharding_integrity)

# Load the just-saved checkpoint and compare it against the in-memory optimizer state.
state_dict_for_load = generate_state_dict(args, model, optimizer, opt_param_scheduler, rng_state,
                                          use_dist_ckpt, iteration, optim_sd_kwargs=optim_sd_kwargs)
load_strategy = get_default_load_sharded_strategy(checkpoint_name)
state_dict_loaded = dist_checkpointing.load(state_dict_for_load, checkpoint_name, load_strategy)

opt_state_correct = {key: deepcopy(value) for key, value in optimizer.optimizer.state.items()}
optimizer.load_state_dict(state_dict_loaded['optimizer'])
opt_state_loaded = optimizer.optimizer.state

for oi, opt in state_dict_loaded['optimizer']['param_state'].items():
    for key in ['exp_avg', 'exp_avg_sq']:
        if isinstance(state_dict['optimizer']['param_state'][oi][key], list):
            param_key = ';'.join([item.key for item in state_dict['optimizer']['param_state'][oi][key]])
        else:
            param_key = state_dict['optimizer']['param_state'][oi][key].key
        # After a few training steps it is unlikely that any model or optimizer tensor
        # is exactly zero, so an all-zero tensor indicates an error.
        assert opt[key].abs().sum() != 0, param_key

Adding a barrier right after the following line, and using a larger DP size, may increase the probability of reproducing the failure (see the sketch below):

arrays = _create_or_open_zarr_arrays(sharded_tensors, checkpoint_dir)
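A debugging-only sketch of that change (assuming torch.distributed is already initialized at this point, as it is during distributed checkpoint saving):

arrays = _create_or_open_zarr_arrays(sharded_tensors, checkpoint_dir)
# Debugging aid only: line the ranks up right after array creation/opening so the
# unsynchronized writes from the creating ranks are more likely to collide.
torch.distributed.barrier()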

Expected behavior

All tensors should be written to the disk.


Environment (please complete the following information):

  • Megatron-LM commit ID 9bcd417
  • PyTorch version 2.3.0a0+ebedce2
  • CUDA version 12.3
  • NCCL version 2.20.3

Proposed fix

Set the synchronizer when creating Zarr arrays as well, mirroring the logic already used when opening existing arrays:

arr = zarr.create(
    sharded_tensor.global_shape,
    dtype=np_dtype,
    store=checkpoint_dir / sharded_tensor.key,
    chunks=sharded_tensor.max_allowed_chunks(),
    compressor=None,
    fill_value=None,
    write_empty_chunks=True,
    # added: use the same ProcessSynchronizer as on the open path
    synchronizer=(
        zarr.ProcessSynchronizer(str(checkpoint_dir / f'{sharded_tensor.key}.sync'))
        if sharded_tensor.flattened_range is not None
        else None
    ),
)



TissueC commented Aug 30, 2024

Same here. This issue can be pretty serious, and needs to be fixed very soon.

mikolajblaz (Contributor) commented:

I confirm this is a bug and your fix looks relevant, thanks 👍

Please note the zarr format is being deprecated and in particular does not play well with the DistributedOptimizer, so I suggest updating the ckpt format to --dist-ckpt-format torch_dist.


santha96 commented Sep 3, 2024

Does this problem occur only when Tensor Parallelism (TP) > 1 and Data Parallelism (DP) > 1? Currently, I am using DistributedOptimizer with TP = 1 and DP > 1. Will storing checkpoints in Zarr format cause an issue?

mikolajblaz (Contributor) commented:

> Does this problem occur only when Tensor Parallelism (TP) > 1 and Data Parallelism (DP) > 1? Currently, I am using DistributedOptimizer with TP = 1 and DP > 1. Will storing checkpoints in Zarr format cause an issue?

With TP=1 it might be an issue as well; please use --dist-ckpt-format torch_dist.


TissueC commented Sep 3, 2024

> Does this problem occur only when Tensor Parallelism (TP) > 1 and Data Parallelism (DP) > 1? Currently, I am using DistributedOptimizer with TP = 1 and DP > 1. Will storing checkpoints in Zarr format cause an issue?

It is okay with TP = 1 and DP > 1 in my environment.


github-actions bot commented Nov 4, 2024

Marking as stale. No activity in 60 days.

github-actions bot added the "stale (No activity in 60 days on issue or PR)" label on Nov 4, 2024
Jayoprell commented:

Same issue with --use-distributed-optimizer --ckpt-format torch.
Megatron: core_r0.9.0, commit 1afee59

mikolajblaz (Contributor) commented:

> Same issue with --use-distributed-optimizer --ckpt-format torch. Megatron: core_r0.9.0, commit 1afee59

@Jayoprell this one is not expected; can you elaborate on the symptoms? With --ckpt-format torch we don't do any file-based synchronization.


Jayoprell commented Nov 7, 2024

When using --use-distributed-optimizer with --ckpt-format torch, save_checkpoint gathers the optimizer state of all other ranks onto DP rank 0, which then saves it to a file. So it should not have any synchronization problem?

The error info:

================== tensor keys: dict_keys(['param']), dp rank: 0, optim_state:{}, main_param:tensor([0., 0., 0.,  ..., 0., 0., 0.], ) ================
Traceback (most recent call last):
  File "/workspace/Megatron-LM/pretrain_gpt.py", line 264, in <module>
    pretrain(
  File "/workspace/Megatron-LM/megatron/training/training.py", line 349, in pretrain
    iteration, num_floating_point_operations_so_far = train(
  File "/workspace/Megatron-LM/megatron/training/training.py", line 1366, in train
    save_checkpoint_and_time(iteration, model, optimizer,
  File "/workspace/Megatron-LM/megatron/training/training.py", line 1070, in save_checkpoint_and_time
    save_checkpoint(iteration, model, optimizer, opt_param_scheduler,
  File "/workspace/Megatron-LM/megatron/training/checkpointing.py", line 380, in save_checkpoint
    optimizer.save_parameter_state(optim_checkpoint_name)
  File "/workspace/Megatron-LM/megatron/core/optimizer/distrib_optimizer.py", line 902, in save_parameter_state
    state_dict = self.get_parameter_state_dp_zero()
  File "/workspace/Megatron-LM/megatron/core/optimizer/distrib_optimizer.py", line 852, in get_parameter_state_dp_zero
    tensors[key].detach().cpu()
KeyError: 'exp_avg'

The tensors above hold the optimizer parameters, and it seems that the optimizer state is empty.

github-actions bot removed the "stale (No activity in 60 days on issue or PR)" label on Nov 7, 2024