Skip to content

Commit

Permalink
support save load optimizer master_weights (#60027)
Browse files Browse the repository at this point in the history
* exclude xpu

* dedup tensor in state_dict

* polish

* support flatten and unflatten state_dict

* test flatten

* rename test

* fix dedup tensor test

* fix test

* fix load state dict

* rename

* fix test

* support save load optimizer master weights

* add comment
  • Loading branch information
pangengzheng authored Dec 28, 2023
1 parent 65e2d93 commit 76ce9bb
Show file tree
Hide file tree
Showing 10 changed files with 416 additions and 57 deletions.
36 changes: 21 additions & 15 deletions python/paddle/distributed/checkpoint/load_state_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,9 +405,9 @@ def load_state_dict(
assert isinstance(
state_dict, dict
), "The state_dict should be a dictionary."
state_dict = flatten_state_dict(state_dict)
if len(state_dict) > 0:
for val in state_dict.values():
flat_state_dict, mapping = flatten_state_dict(state_dict)
if len(flat_state_dict) > 0:
for val in flat_state_dict.values():
assert isinstance(
val, paddle.Tensor
), f"Only support dygraph Tensor now, but is {val}"
Expand All @@ -423,7 +423,7 @@ def load_state_dict(
paddle.distributed.barrier(process_group)

rank_to_files = get_rank_to_files(
path, state_dict, process_group, use_dist
path, flat_state_dict, process_group, use_dist
)
if len(rank_to_files) <= 0:
return
Expand All @@ -434,16 +434,18 @@ def load_state_dict(
)
# read_items: [ReadItem(local_tensor_index, rank, cur_offsets, storage_offsets, lengths)],
# slice the storage local tensor in (storage_offsets, lengths) to assign the current tensor in (cur_offsets, lengths) in rank.
read_items = get_read_items(path, state_dict, process_group, use_dist)
read_items = get_read_items(
path, flat_state_dict, process_group, use_dist
)
storage_file_to_state_dict = {}
logger.debug(
f"before load, state_dict:{state_dict},\n load_infos:{load_infos},\n read_items:{read_items}"
f"before load, state_dict:{flat_state_dict},\n load_infos:{load_infos},\n read_items:{read_items}"
)
state_dict_in_cpu = []
for k, v in state_dict.items():
for k, v in flat_state_dict.items():
if v.place.is_cpu_place():
state_dict_in_cpu.append(k)
state_dict[k] = v.cuda()
flat_state_dict[k] = v.cuda()
for item in read_items:
assert (
item.local_tensor_index in load_infos
Expand Down Expand Up @@ -484,15 +486,17 @@ def load_state_dict(
# The read item rank need to be assigned
if item.rank == paddle.distributed.get_rank():
assert (
item.local_tensor_index.tensor_key in state_dict
), f"item:{item}, state_dict:{state_dict}"
item.local_tensor_index.tensor_key in flat_state_dict
), f"item:{item}, state_dict:{flat_state_dict}"
cur_local_tensor = (
state_dict[
flat_state_dict[
item.local_tensor_index.tensor_key
]._local_value()
if use_dist
and state_dict[item.local_tensor_index.tensor_key].is_dist()
else state_dict[item.local_tensor_index.tensor_key]
and flat_state_dict[
item.local_tensor_index.tensor_key
].is_dist()
else flat_state_dict[item.local_tensor_index.tensor_key]
)
cur_offsets = item.cur_offset
cur_lengths = item.lengths
Expand All @@ -513,7 +517,9 @@ def load_state_dict(
else:
cur_chunk_tensor = paddle.zeros(
item.lengths,
dtype=state_dict[item.local_tensor_index.tensor_key].dtype,
dtype=flat_state_dict[
item.local_tensor_index.tensor_key
].dtype,
)

if src_rank == item.rank:
Expand All @@ -530,6 +536,6 @@ def load_state_dict(
cur_chunk_tensor, src=src_rank, group=process_group
)

for k, v in state_dict.items():
for k, v in flat_state_dict.items():
if k in state_dict_in_cpu:
state_dict[k] = v.cpu()
1 change: 1 addition & 0 deletions python/paddle/distributed/checkpoint/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,4 @@ class LocalTensorIndex:
class Metadata:
state_dict_metadata: Dict[str, List[LocalTensorMetadata]] = None
storage_metadata: Dict[LocalTensorIndex, str] = None
flat_mapping: Dict[str, Tuple[str]] = None
57 changes: 45 additions & 12 deletions python/paddle/distributed/checkpoint/save_state_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.

import os
from typing import List

import paddle
from paddle.distributed.communication.group import is_initialized
Expand Down Expand Up @@ -50,7 +49,7 @@ def check_file_name(file_name, process_group):

def merge_state_dict_metadata(global_state_dict_metadata):
assert isinstance(
global_state_dict_metadata, List
global_state_dict_metadata, list
), "The global_state_dict should be a list."
out = {}
for state_dict in global_state_dict_metadata:
Expand All @@ -64,7 +63,7 @@ def merge_state_dict_metadata(global_state_dict_metadata):
return out


def dedup_storage_metadata(global_storage_metadata):
def dedup_key_in_dict(global_storage_metadata):
out = {}
for storage_metadata in global_storage_metadata:
for key, val in storage_metadata.items():
Expand All @@ -74,6 +73,34 @@ def dedup_storage_metadata(global_storage_metadata):
return out


def dedup_tensor(
local_state_dict, local_storage_metadata, global_storage_metadata
):
"""
Dedup the replicated tensor in local state_dict.
Args:
local_state_dict(Dict[str, paddle.Tensor]): The state_dict of current rank.
local_storage_metadata(Dict[LocalTensorIndex, str]): The storage metadata of current rank.
global_storage_metadata(Dict[LocalTensorIndex, str]): The final storage metadata of all ranks.
Examples:
In rank0, local_state_dict:{"w1": t1_0, "w2": t2}, local_storage_metadata:{LocalTensorIndex("w1", (0,0)): "0_0.distcp", LocalTensorIndex("w2", (0,0)): "0_0.distcp"},
in rank1, local_state_dict:{"w1": t1_1, "w2": t2}, local_storage_metadata:{LocalTensorIndex("w1", (1,0)): "1_0.distcp", LocalTensorIndex("w2", (0,0)): "1_0.distcp"},
global_storage_metadata:{LocalTensorIndex("w1", (0,0)): "0_0.distcp", LocalTensorIndex("w1", (1,0)): "1_0.distcp", LocalTensorIndex("w2", (0, 0)): "0_0.distcp"}.
w2 is replicated in rank0 and rank1. We save it in rank0 as default thus need to remove it in other ranks.
Finally, the local_state_dict:{"w1": t1_1, "w2": t2} in rank1 update to {"w1": t1_1}.
"""

for tensor_index, file_name in global_storage_metadata.items():
rank = int(file_name.split(".")[0].split("_")[0])
if (
tensor_index in local_storage_metadata
and rank != paddle.distributed.get_rank()
):
local_state_dict.pop(tensor_index.tensor_key)


def save_state_dict(
state_dict,
path,
Expand Down Expand Up @@ -107,9 +134,9 @@ def save_state_dict(
assert isinstance(
state_dict, dict
), "The state_dict should be a dictionary."
state_dict = flatten_state_dict(state_dict)
if len(state_dict) > 0:
for val in state_dict.values():
flat_state_dict, mapping = flatten_state_dict(state_dict)
if len(flat_state_dict) > 0:
for val in flat_state_dict.values():
assert isinstance(
val, paddle.Tensor
), "Only support dygraph Tensor now, support static DistributedTensor later"
Expand All @@ -134,12 +161,12 @@ def save_state_dict(
if use_dist:
check_file_name(file_name, process_group)
# the parameter_name and order in state_dict should be the same
check_state_dict(state_dict, process_group)
check_state_dict(flat_state_dict, process_group)
metadata = Metadata()
local_state_dict = {}
local_state_dict_metadata = {}
local_storage_metadata = {}
for key, val in state_dict.items():
for key, val in flat_state_dict.items():
if isinstance(val, paddle.Tensor):
# Case1: not initialized means this tensor is placed in another mesh which do not contain this rank
if not val._is_initialized():
Expand Down Expand Up @@ -178,6 +205,7 @@ def save_state_dict(
] = file_name
global_state_dict_metadata = []
global_storage_metadata = []
global_flatten_mapping = []
if use_dist:
paddle.distributed.all_gather_object(
global_state_dict_metadata,
Expand All @@ -187,19 +215,24 @@ def save_state_dict(
paddle.distributed.all_gather_object(
global_storage_metadata, local_storage_metadata, process_group
)
paddle.distributed.all_gather_object(
global_flatten_mapping, mapping, process_group
)
else:
global_state_dict_metadata.append(local_state_dict_metadata)
global_storage_metadata.append(local_storage_metadata)
global_flatten_mapping.append(mapping)

metadata.state_dict_metadata = merge_state_dict_metadata(
global_state_dict_metadata
)
metadata.storage_metadata = dedup_storage_metadata(
global_storage_metadata
)
metadata.storage_metadata = dedup_key_in_dict(global_storage_metadata)
metadata.flat_mapping = dedup_key_in_dict(global_flatten_mapping)
if coordinator_rank == paddle.distributed.get_rank():
logger.debug(f"metadata:{metadata}")
paddle.save(metadata, os.path.join(path, f"{unique_id}.metadata"))
logger.debug(f"local_state_dict:{local_state_dict}")
# TODO(pangengzheng): del the replicated tensor in local_state_dict, now different might save the replicated tensor
dedup_tensor(
local_state_dict, local_storage_metadata, metadata.storage_metadata
)
paddle.save(local_state_dict, os.path.join(path, file_name))
44 changes: 43 additions & 1 deletion python/paddle/distributed/checkpoint/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,5 +63,47 @@ def compute_local_shape_and_global_offset(


def flatten_state_dict(state_dict):
# TODO, {"model": {"w0": xxx}} -> {model.w0: xxx}
"""
Flatten the nested dict to a flat dict.
{"model": {"w0": xxx}} -> {model.w0: xxx}
"""
flatten_state_dict = {}
mapping = {}

def _flatten(key, value):
if isinstance(value, dict):
for k, v in value.items():
assert isinstance(k, str), f"The key should be str, but is {k}"
_flatten(key + (k,), v)
elif isinstance(value, paddle.Tensor):
flatten_key_str = ".".join(key)
flatten_state_dict[flatten_key_str] = value
mapping[flatten_key_str] = key
else:
raise ValueError(
f"The value should be dict or paddle.Tensor, but is {value}"
)

_flatten((), state_dict)

return flatten_state_dict, mapping


def unflatten_state_dict(flat_state_dict, mapping):
"""
Unflatten the flat dict to a nested dict.
{model.w0: xxx} -> {"model": {"w0": xxx}}
"""
state_dict = {}
for key, value in flat_state_dict.items():
key_tuple = mapping[key]
assert isinstance(
key_tuple, tuple
), f"The key should be tuple, but is {key_tuple}"
tmp = state_dict
for i in range(len(key_tuple) - 1):
key = key_tuple[i]
tmp = tmp.setdefault(key, {})
tmp[key_tuple[-1]] = value

return state_dict
30 changes: 1 addition & 29 deletions python/paddle/optimizer/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,35 +406,7 @@ def set_state_dict(self, state_dict):
tensor.set_xpu_scale_value(
state_dict.get(var_tmp.name + ".SCALE_VALUE", -1.0)
)

model_np = np.array(tensor)

load_para = state_dict[var_tmp.name]

if isinstance(load_para, Variable):
load_para_np = np.array(load_para)
elif isinstance(load_para, core.eager.Tensor):
load_para_np = np.array(load_para)
elif isinstance(load_para, np.ndarray):
load_para_np = load_para
else:
raise RuntimeError(
f"State dict type {str(type(load_para))} not supprt"
)

assert (
model_np.shape == load_para_np.shape
), "Parameter shape not match, Dygraph Parameter [ {} ] need tensor with shape {} but load tensor with shape {}".format(
model_np.name, model_np.shape, load_para_np.shape
)

assert (
model_np.dtype == load_para_np.dtype
), "Parameter dtype not match, Dygraph Parameter [ {} ] need tensor with dtype {} but load tensor with dtype {}".format(
model_np.name, model_np.dtype, load_para_np.dtype
)

tensor.set(load_para_np, framework._current_expected_place())
var.set_value(state_dict[var_tmp.name])

def get_opti_var_name_list(self):
return self._opti_name_list
Expand Down
3 changes: 3 additions & 0 deletions test/auto_parallel/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,9 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
py_test_modules(test_gpt_with_prim MODULES test_gpt_with_prim)
set_tests_properties(test_gpt_with_prim
PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 200)
py_test_modules(test_dist_checkpoint_utils MODULES test_dist_checkpoint_utils)
set_tests_properties(test_dist_checkpoint_utils
PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 120)
py_test_modules(test_semi_auto_parallel_unshard_dtensor MODULES
test_semi_auto_parallel_unshard_dtensor)
set_tests_properties(test_semi_auto_parallel_unshard_dtensor
Expand Down
68 changes: 68 additions & 0 deletions test/auto_parallel/semi_auto_parallel_checkpoint_dedup_tensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import numpy as np

import paddle
import paddle.distributed as dist


class TestSaveStateDict:
def __init__(self):
self._ckpt_path = os.getenv("ckpt_path")

def test_dedup_tesnor(self):
w1 = paddle.arange(32).reshape([4, 8])
w2 = paddle.arange(32, 36).reshape([2, 2])
mesh = dist.ProcessMesh([0, 1])
dist_w1 = dist.shard_tensor(w1, mesh, [dist.Replicate()])
dist_w2 = dist.shard_tensor(w2, mesh, [dist.Shard(0)])
state_dict = {"w1": dist_w1, "w2": dist_w2}
# w1 is replicated in rank0 and ran1, it will only save in rank0.
# Therefore, rank0 save state_dict:{"w1": dist_w1, "w2": dist_w2}, rank1 save state_dict:{"w2": dist_w2}
dist.save_state_dict(state_dict, self._ckpt_path)
paddle.distributed.barrier()
# check
expect_local_state_dict = {}
for k, v in state_dict.items():
if k == "w1" and paddle.distributed.get_rank() != 0:
continue
expect_local_state_dict[k] = v._local_value()
data_file_path = os.path.join(
self._ckpt_path, f"{paddle.distributed.get_rank()}_0.distcp"
)
metadata_file_path = os.path.join(self._ckpt_path, "0.metadata")
assert os.path.exists(data_file_path) and os.path.exists(
metadata_file_path
)
local_state_dict = paddle.load(data_file_path)
metadata = paddle.load(metadata_file_path)

for k, local_tensor in local_state_dict.items():
assert k in expect_local_state_dict
expect_tensor = expect_local_state_dict[k]
np.testing.assert_equal(expect_tensor.numpy(), local_tensor.numpy())
for tensor_index, file_name in metadata.storage_metadata.items():
rank = int(file_name.split(".")[0].split("_")[0])
if tensor_index.tensor_key == "w1":
assert rank == 0

def run_test_case(self):
self.test_dedup_tesnor()


if __name__ == '__main__':
TestSaveStateDict().run_test_case()
Loading

0 comments on commit 76ce9bb

Please sign in to comment.