[Unified Checkpoint] Add split param and refactor code (PaddlePaddle#9240)

* [Unified checkpoint] update optimizer async save signal

* update paddlepaddle

* split param

* add save for split param

* fix save split_param

* add load uc split_param

* update uc files

* update uc files

* update split_param loading

* mkdir unified_checkpoint directory

* rename file

* update async handler

* update files

---------

Co-authored-by: gongenlei <gongenlei@baidu.com>
DesmonDay and gongel authored Oct 28, 2024
1 parent 81f5ab5 commit c9d5673
Showing 15 changed files with 3,236 additions and 2,575 deletions.
2,569 changes: 0 additions & 2,569 deletions paddlenlp/trainer/plugins/unified_checkpoint.py

This file was deleted.

6 changes: 1 addition & 5 deletions paddlenlp/trainer/trainer.py
@@ -113,7 +113,6 @@
 from .argparser import strtobool
 from .integrations import get_reporting_integration_callbacks
 from .plugins.timer import RuntimeTimer, get_timers, set_timers
-from .plugins.unified_checkpoint import UnifiedCheckpointHandler
 from .trainer_callback import (
     CallbackHandler,
     DefaultFlowCallback,
@@ -144,6 +143,7 @@
     speed_metrics,
 )
 from .training_args import TrainingArguments
+from .unified_checkpoint import UnifiedCheckpointHandler
 from .utils import reshard as reshard_util
 from .utils.async_save import AsyncSaver
 from .utils.helper import ( # nested_truncate,
@@ -598,7 +598,6 @@ def _load_from_checkpoint(self, resume_from_checkpoint=None):
         if use_unified_checkpoint:
             self.unified_checkpoint_handler.load_unified_checkpoint(
                 self.model,
-                self.optimizer,
                 resume_from_checkpoint,
             )
             logger.info(f"Loading model from {resume_from_checkpoint} using unified checkpoint.")
@@ -1241,7 +1240,6 @@ def fused_allreduce_gradients_no_sync(paramlist, hcg):
         if self.args.unified_checkpoint:
             self.unified_checkpoint_handler.load_unified_checkpoint(
                 self.model,
-                self.optimizer,
                 self.state.best_model_checkpoint,
             )
             if self.args.sharding_parallel_degree > 1 or self.args.data_parallel_degree > 1:
@@ -1289,7 +1287,6 @@ def _load_best_model_from_peft_checkpoint(self):
         if self.args.unified_checkpoint:
             self.unified_checkpoint_handler.load_unified_checkpoint(
                 self.model,
-                self.optimizer,
                 self.state.best_model_checkpoint,
             )
             if self.args.sharding_parallel_degree > 1 or self.args.data_parallel_degree > 1:
@@ -2775,7 +2772,6 @@ def _load_optimizer_and_scheduler(self, checkpoint):
             opt_state_dict = None
         else:
             opt_state_dict = self.unified_checkpoint_handler.load_unified_optimizer(
-                args=self.args,
                 model=self.model,
                 optimizer=self.optimizer,
                 resume_from_checkpoint=checkpoint,
6 changes: 6 additions & 0 deletions paddlenlp/trainer/training_args.py
@@ -1402,6 +1402,12 @@ def is_segment_parallel_supported():
                 f"but got logging_steps={self.logging_steps}."
             )
 
+            if "split_param" in sharding_parallel_config:
+                assert self.sharding == [ShardingOption.SHARD_OP], "Only sharding stage1 support split_param."
+                assert (
+                    self.amp_master_grad
+                ), "If `split_param` in sharding_parallel_config, `amp_master_grad` must be True."
+
             fleet.init(is_collective=True, strategy=strategy)
             logger.info(strategy)
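
The check added above only accepts `split_param` together with sharding stage1 and `amp_master_grad`. A standalone, hedged restatement of that check (plain strings stand in for `ShardingOption.SHARD_OP`; the function name is illustrative, not a PaddleNLP API):

def validate_split_param(sharding, sharding_parallel_config, amp_master_grad):
    # Mirrors the assertions above; "stage1" stands in for ShardingOption.SHARD_OP.
    if "split_param" in sharding_parallel_config:
        assert sharding == ["stage1"], "Only sharding stage1 support split_param."
        assert amp_master_grad, "If `split_param` in sharding_parallel_config, `amp_master_grad` must be True."


validate_split_param(["stage1"], "split_param", True)    # passes
# validate_split_param(["stage2"], "split_param", True)  # would fail the first assert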

15 changes: 15 additions & 0 deletions paddlenlp/trainer/unified_checkpoint/__init__.py
@@ -0,0 +1,15 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .unified_checkpoint import UnifiedCheckpointHandler
250 changes: 250 additions & 0 deletions paddlenlp/trainer/unified_checkpoint/async_handler.py
@@ -0,0 +1,250 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Asynchronous unified checkpoint handler."""

import multiprocessing
import os
import time
from multiprocessing import shared_memory

import paddle
import paddle.distributed as dist

from paddlenlp.transformers.utils import is_safetensors_available
from paddlenlp.utils.log import logger

if is_safetensors_available():
    from safetensors.numpy import save_file as safe_save_file

from .shared_memory_utils import (
    _read_state_dict_from_shm,
    _traverse_copy_to_shm,
    create_meta_dict,
)

__all__ = ["AsyncCheckpointHandler"]


class AsyncCheckpointHandler:
    def __init__(self, args):
        # Mainly for asynchronous saving.
        self.args = args
        self.global_rank = paddle.distributed.get_rank() if paddle.distributed.get_world_size() > 1 else -1

        self._shm_model_weight = None
        self._shm_master_weight = None
        self._shm_optimizer_weight = None
        self._meta_dict_model = None
        self._meta_dict_master_weight = None
        self._meta_dict_optim = None
        self._process_model_weight = None
        self._process_master_weight = None
        self._process_optimizer_weight = None
        self._lock = None
        self._shared_save_model_flag = None
        self._shared_save_master_weight_flag = None
        self._shared_save_optimizer_flag = None

        if "async_save" in self.args.unified_checkpoint_config:
            self._lock = multiprocessing.Lock()
            self._shared_save_model_path = multiprocessing.Array("c", 100000)
            self._shared_save_model_signal_path = multiprocessing.Array("c", 100000)
            self._shared_save_master_weight_path = multiprocessing.Array("c", 100000)
            self._shared_save_master_weight_signal_path = multiprocessing.Array("c", 100000)
            self._shared_save_optimizer_path = multiprocessing.Array("c", 100000)
            self._shared_save_optimizer_signal_path = multiprocessing.Array("c", 100000)
            self._shared_save_model_flag = multiprocessing.Array("i", 1)
            self._shared_save_master_weight_flag = multiprocessing.Array("i", 1)
            self._shared_save_optimizer_flag = multiprocessing.Array("i", 1)
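
For reference, a minimal stdlib-only sketch (not part of this file) of the primitives `__init__` allocates when `async_save` is enabled: a `"c"` array is a fixed-size, NUL-padded byte buffer used for a path, and a one-element `"i"` array acts as a shared integer flag.

import multiprocessing

path_arr = multiprocessing.Array("c", 64)   # fixed-size byte buffer, initially all b"\x00"
flag = multiprocessing.Array("i", 1)        # flag[0] starts at 0

encoded = "checkpoint-50/model.safetensors".encode("utf-8")  # illustrative path
path_arr[: len(encoded)] = encoded          # bytes beyond the path stay NUL
flag[0] = 1

print(path_arr[:].decode("utf-8").rstrip("\x00"))  # checkpoint-50/model.safetensors
print(flag[0])                                     # 1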

    def _file_save_async_or_sync(
        self, state_dict, path, signal_path=None, is_sync=True, state_dict_type="model_weight"
    ):
        if is_sync:
            for k in list(state_dict.keys()):
                if isinstance(state_dict[k], paddle.Tensor):
                    state_dict[k] = state_dict.pop(k).cpu().numpy()
            safe_save_file(state_dict, path, metadata={"format": "np"})
        else:
            if state_dict_type == "model_weight":
                if self._shm_model_weight is None:
                    self._meta_dict_model, buffer_size = create_meta_dict(state_dict)
                    self._shm_model_weight = shared_memory.SharedMemory(create=True, size=buffer_size)
                shm_state_dict = self._shm_model_weight
                meta_dict = self._meta_dict_model
                shared_save_flag = self._shared_save_model_flag
                shared_save_path = self._shared_save_model_path
                shared_save_signal_path = self._shared_save_model_signal_path
                if self._process_model_weight is None:
                    self._process_model_weight = multiprocessing.Process(
                        target=self._save_file_async_in_process,
                        args=(
                            meta_dict,
                            self._shm_model_weight.name,
                            self._shared_save_model_flag,
                            self._shared_save_model_path,
                            self._shared_save_model_signal_path,
                            self._lock,
                            state_dict_type,
                            self.global_rank,
                        ),
                    )
                    self._process_model_weight.start()
                process = self._process_model_weight
            elif state_dict_type == "master_weight":
                if self._shm_master_weight is None:
                    self._meta_dict_master_weight, buffer_size = create_meta_dict(state_dict)
                    self._shm_master_weight = shared_memory.SharedMemory(create=True, size=buffer_size)
                shm_state_dict = self._shm_master_weight
                meta_dict = self._meta_dict_master_weight
                shared_save_flag = self._shared_save_master_weight_flag
                shared_save_path = self._shared_save_master_weight_path
                shared_save_signal_path = self._shared_save_master_weight_signal_path
                if self._process_master_weight is None:
                    self._process_master_weight = multiprocessing.Process(
                        target=self._save_file_async_in_process,
                        args=(
                            meta_dict,
                            self._shm_master_weight.name,
                            self._shared_save_master_weight_flag,
                            self._shared_save_master_weight_path,
                            self._shared_save_master_weight_signal_path,
                            self._lock,
                            "model_weight"
                            if "skip_save_model_weight" in self.args.unified_checkpoint_config
                            else state_dict_type,
                            self.global_rank,
                        ),
                    )
                    self._process_master_weight.start()
                process = self._process_master_weight
            elif state_dict_type == "optimizer_weight":
                if self._shm_optimizer_weight is None:
                    self._meta_dict_optim, buffer_size = create_meta_dict(state_dict)
                    self._shm_optimizer_weight = shared_memory.SharedMemory(create=True, size=buffer_size)
                shm_state_dict = self._shm_optimizer_weight
                meta_dict = self._meta_dict_optim
                shared_save_flag = self._shared_save_optimizer_flag
                shared_save_path = self._shared_save_optimizer_path
                shared_save_signal_path = self._shared_save_optimizer_signal_path
                if self._process_optimizer_weight is None:
                    self._process_optimizer_weight = multiprocessing.Process(
                        target=self._save_file_async_in_process,
                        args=(
                            meta_dict,
                            self._shm_optimizer_weight.name,
                            self._shared_save_optimizer_flag,
                            self._shared_save_optimizer_path,
                            self._shared_save_optimizer_signal_path,
                            self._lock,
                            state_dict_type,
                            self.global_rank,
                        ),
                    )
                    self._process_optimizer_weight.start()
                process = self._process_optimizer_weight

            while True:  # wait until no process is saving.
                flag_value = shared_save_flag[0]
                if flag_value == 0:
                    break
                if not process.is_alive():
                    raise RuntimeError(f"The process that saves {state_dict_type} has been killed unexpectedly.")
                time.sleep(0.5)
                logger.info(f"Wait for the previous save process to finish saving {state_dict_type}")
            # only save model weight or save master weight, we enter this loop.
            self._reset_and_update(shared_save_path, path)
            self._reset_and_update(shared_save_signal_path, signal_path)
            _traverse_copy_to_shm(state_dict, meta_dict, shm_state_dict.buf)
            with self._lock:
                shared_save_flag[0] = 1

    def _save_file_async_in_process(
        self,
        meta_dict,
        shm_name,
        shared_save_flag,
        shared_save_path,
        shared_save_signal_path,
        lock,
        state_dict_type,
        global_rank,
    ):
        shm = shared_memory.SharedMemory(name=shm_name)
        while True:
            flag_value = shared_save_flag[0]  # if process uses `spawn`, cannot read this value.
            if flag_value == -1:  # stop process
                break
            if flag_value == 0:  # nothing to save
                continue
            if flag_value == 1:  # need to save
                path = shared_save_path[:].decode("utf-8").rstrip("\x00")
                signal_path = shared_save_signal_path[:].decode("utf-8").rstrip("\x00")
                logger.info(f"Start to async save {path}")
                state_dict = _read_state_dict_from_shm(meta_dict, shm)  # numpy array
                safe_save_file(state_dict, path, {"format": "np"})
                del state_dict
                saved_signal_path = os.path.join(signal_path, f".{state_dict_type}.done.{global_rank}")
                paddle.save(global_rank, saved_signal_path)
                with lock:
                    shared_save_flag[0] = 0
            time.sleep(0.5)
        shm.close()
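
A minimal end-to-end sketch (stdlib only, not PaddleNLP code) of the handshake that `_file_save_async_or_sync` and `_save_file_async_in_process` implement: the trainer process writes a path into the shared array and raises the flag to 1; the saver process polls, performs the save, and drops the flag back to 0; a value of -1 asks the saver to exit. The shared-memory tensor copy is replaced here by a print.

import multiprocessing
import time


def saver_loop(flag, path_arr, lock):
    # Child process: 1 = a path is ready, 0 = idle, -1 = shut down.
    while True:
        if flag[0] == -1:
            break
        if flag[0] == 1:
            path = path_arr[:].decode("utf-8").rstrip("\x00")
            print(f"would save checkpoint to {path}")  # stand-in for safe_save_file(...)
            with lock:
                flag[0] = 0  # signal the parent that the slot is free again
        time.sleep(0.1)


if __name__ == "__main__":
    lock = multiprocessing.Lock()
    flag = multiprocessing.Array("i", 1)         # 0 = idle, 1 = save pending, -1 = stop
    path_arr = multiprocessing.Array("c", 4096)  # NUL-padded path buffer

    worker = multiprocessing.Process(target=saver_loop, args=(flag, path_arr, lock))
    worker.start()

    for step in (10, 20):
        while flag[0] != 0:                      # wait for the previous save to finish
            time.sleep(0.1)
        for i in range(len(path_arr)):           # clear stale bytes, as _reset_and_update does
            path_arr[i] = b"\0"
        path = f"checkpoint-{step}/model.safetensors".encode("utf-8")
        path_arr[: len(path)] = path
        with lock:
            flag[0] = 1                          # hand the request to the saver

    while flag[0] != 0:                          # drain, then ask the saver to exit
        time.sleep(0.1)
    flag[0] = -1
    worker.join()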

    def _reset_and_update(self, shared_array, new_value):
        # clear array
        for i in range(len(shared_array)):
            shared_array[i] = b"\0"
        # update array
        encoded_value = new_value.encode("utf-8")
        shared_array[: len(encoded_value)] = encoded_value
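
The clearing loop above matters because the `"c"` buffer is reused across saves; without it, a shorter path written over a longer one keeps the stale suffix. A small sketch (paths are illustrative):

import multiprocessing

arr = multiprocessing.Array("c", 64)

first = b"checkpoint-100/model-00002.safetensors"
arr[: len(first)] = first

second = b"ckpt-5/model.safetensors"
arr[: len(second)] = second                    # overwrite without clearing
print(arr[:].decode("utf-8").rstrip("\x00"))   # ckpt-5/model.safetensors02.safetensors (garbled)

for i in range(len(arr)):                      # what _reset_and_update does first
    arr[i] = b"\0"
arr[: len(second)] = second
print(arr[:].decode("utf-8").rstrip("\x00"))   # ckpt-5/model.safetensors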

    def unlink_shared_memory(self):
        if not ("async_save" in self.args.unified_checkpoint_config):
            return

        if self._shared_save_model_flag is not None:
            while self._shared_save_model_flag[0] > 0:  # async process is saving
                if not self._process_model_weight.is_alive():
                    raise RuntimeError("The process that saves model_weight has been killed unexpectedly.")
                time.sleep(0.5)
            self._shared_save_model_flag[0] = -1
        if self._shared_save_master_weight_flag is not None:
            while self._shared_save_master_weight_flag[0] > 0:
                if not self._process_master_weight.is_alive():
                    raise RuntimeError("The process that saves master_weight has been killed unexpectedly.")
                time.sleep(0.5)
            self._shared_save_master_weight_flag[0] = -1
        if self._shared_save_optimizer_flag is not None:
            while self._shared_save_optimizer_flag[0] > 0:
                if not self._process_optimizer_weight.is_alive():
                    raise RuntimeError("The process that saves optimizer_weight has been killed unexpectedly.")
                time.sleep(0.5)
            self._shared_save_optimizer_flag[0] = -1

        if self._shm_model_weight is not None:
            self._shm_model_weight.close()
            self._shm_model_weight.unlink()
            self._shm_model_weight = None
        if self._shm_master_weight is not None:
            self._shm_master_weight.close()
            self._shm_master_weight.unlink()
            self._shm_master_weight = None
        if self._shm_optimizer_weight is not None:
            self._shm_optimizer_weight.close()
            self._shm_optimizer_weight.unlink()
            self._shm_optimizer_weight = None

        if paddle.distributed.get_world_size() > 1:
            dist.barrier()
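
To round out the picture, a minimal sketch (stdlib only, not part of the commit) of the `SharedMemory` lifecycle that `unlink_shared_memory` finishes: the handler creates a named block and copies tensors into it, the saver process attaches to it by name, and `close()`/`unlink()` release it once saving is done.

from multiprocessing import shared_memory

owner = shared_memory.SharedMemory(create=True, size=1024)  # like self._shm_model_weight
owner.buf[:5] = b"hello"                                    # stand-in for _traverse_copy_to_shm

reader = shared_memory.SharedMemory(name=owner.name)        # what the saver does with shm_name
print(bytes(reader.buf[:5]))                                # b'hello'

reader.close()   # each attached handle closes its own mapping (shm.close() in the saver)
owner.close()
owner.unlink()   # only the owner removes the block, as unlink_shared_memory() does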