Commit
update split_param loading
DesmonDay committed Oct 24, 2024
1 parent 4ab0df1 commit ff0ebc2
Showing 1 changed file with 15 additions and 12 deletions.
27 changes: 15 additions & 12 deletions paddlenlp/trainer/plugins/unified_checkpoint_sharding_v2.py
@@ -224,12 +224,10 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected
        for shard_file in resolved_archive_file:
            if expected_keys.isdisjoint(sharded_metadata["file_map"][os.path.split(shard_file)[-1]]):
                continue

            if model.config.tensor_parallel_degree > 1:
-                state_dict = load_state_dict(shard_file, tp_actions, expected_keys, device="expected")
+                state_dict = load_state_dict(shard_file, tp_actions, expected_keys, device="cpu")
            else:
-                state_dict = load_state_dict(shard_file, None, expected_keys, device="expected")
+                state_dict = load_state_dict(shard_file, None, expected_keys, device="cpu")
            returned_state_dict.update(state_dict)
            del state_dict
            gc.collect()

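Context for the hunk above: loading each shard with device="cpu" keeps full shard files off the accelerator, and tensors are moved to the expected place only after they have been filtered and split per rank (see the `_copy_to` calls added later in this diff). Below is a minimal, hedged sketch of that load-to-CPU-then-move pattern; `load_shard` is a hypothetical helper standing in for the real `load_state_dict`, not the PaddleNLP API.

import paddle

def gather_expected_keys(shard_files, expected_keys, load_shard):
    # Load every shard to CPU, keep only the keys this rank expects,
    # then move just the kept tensors to the current expected place.
    merged = {}
    for shard_file in shard_files:
        cpu_state = load_shard(shard_file)  # hypothetical: returns {name: CPU paddle.Tensor}
        for name in list(cpu_state.keys()):
            if name in expected_keys:
                merged[name] = cpu_state.pop(name)
        del cpu_state  # release the rest of the shard before reading the next file
    place = paddle.framework._current_expected_place()
    return {name: t._copy_to(place, False) for name, t in merged.items()}

Loading to CPU first trades an extra host-to-device copy for a smaller peak device footprint, since only the tensors a rank actually keeps ever reach the accelerator.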
@@ -238,13 +236,6 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected

    # get tp params
    state_dict_optim = load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected_keys_optim)

-    if has_master_weights:
-        state_dict_master_weight = load_resolved_archive_file(
-            resolved_archive_file_mw,
-            sharded_metadata_mw,
-            expected_keys,
-            is_master_weights=True,
-        )

    # need to split param for different sharding rank, maybe need to deal with oom issue.
    for key in list(state_dict_optim.keys()):
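The block deleted in this hunk is not gone: the same load_resolved_archive_file call for master weights reappears after the optimizer-state loop later in this diff. Reordering it means the optimizer states are split and moved before the FP32 master weights are ever read, which should lower peak host memory (the nearby comment already flags a possible OOM issue). A rough, hedged sketch of the before/after ordering; the helper callables are hypothetical stand-ins, not PaddleNLP APIs.

def old_order(load_optim, load_master, split_and_move, has_master_weights):
    optim = load_optim()                                    # CPU copy of the optimizer states
    master = load_master() if has_master_weights else {}    # second large dict held at the same time
    return split_and_move(optim), split_and_move(master)

def new_order(load_optim, load_master, split_and_move, has_master_weights):
    optim = split_and_move(load_optim())                    # optimizer states are split and released first
    master = split_and_move(load_master()) if has_master_weights else {}
    return optim, master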
@@ -266,15 +257,24 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected
                        paddle.zeros([padding_end - padding_start], dtype=state_dict_optim[key].dtype),
                    )
                )

        if has_master_weights:
            key_name = "_".join([static_name, FP32_MASTER, key_name[1]])
        else:
            key_name = "_".join([static_name, key_name[1]])

+        state_dict_optim[key] = state_dict_optim[key]._copy_to(paddle.framework._current_expected_place(), False)

        returned_optim_state_dict[key_name] = state_dict_optim.pop(key)
        returned_optim_state_dict[key_name].name = key_name

+    if has_master_weights:
+        state_dict_master_weight = load_resolved_archive_file(
+            resolved_archive_file_mw,
+            sharded_metadata_mw,
+            expected_keys,
+            is_master_weights=True,
+        )

        for key in list(state_dict_master_weight.keys()):
            static_name = struct2static_name_mappings.get(key, None)
            if state_dict_master_weight[key].numel().item() > 1:
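For reference, the loop above (and the master-weights hunk that follows) flattens each tensor, keeps only this rank's [begin, end) slice, zero-pads the tail when the padded size runs past the parameter's real length, and only then copies the result to the expected place. A self-contained, hedged sketch of that per-rank split, assuming begin/end/index/numel/padded_size carry the same meaning as in the unchanged parts of this function:

import paddle

def split_for_rank(tensor, begin, end, index, numel, padded_size):
    # Slice a flattened parameter down to this sharding rank's share,
    # padding with zeros where the padded size exceeds the real numel.
    flat = tensor.reshape([-1])
    piece = flat[begin - index : end - index]    # offsets are shifted by this shard's start index
    padding_start = max(begin, index + numel)
    padding_end = min(end, index + padded_size)
    if padding_start < padding_end:
        piece = paddle.concat(
            (piece, paddle.zeros([padding_end - padding_start], dtype=piece.dtype))
        )
    # move the (now small) per-rank slice to the expected device, non-blocking
    return piece._copy_to(paddle.framework._current_expected_place(), False)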
@@ -292,6 +292,9 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected
                            paddle.zeros([padding_end - padding_start], dtype=state_dict_master_weight[key].dtype),
                        )
                    )
+            state_dict_master_weight[key] = state_dict_master_weight[key]._copy_to(
+                paddle.framework._current_expected_place(), False
+            )
returned_optim_state_dict["master_weights"][static_name] = state_dict_master_weight.pop(key)
returned_optim_state_dict["master_weights"][static_name].name = "_".join([static_name, FP32_MASTER])


