From 49520692fb80898fd53ab708831ae54bb53073b7 Mon Sep 17 00:00:00 2001
From: DrownFish19
Date: Thu, 4 Jan 2024 20:20:02 +0800
Subject: [PATCH] [GPT-3] Fix shared weights sync for PipelineLayer (#7775)

---
 paddlenlp/transformers/model_utils.py | 17 ++++++++++++++++-
 1 file changed, 16 insertions(+), 1 deletion(-)

diff --git a/paddlenlp/transformers/model_utils.py b/paddlenlp/transformers/model_utils.py
index 75c455850044..d420e72f4317 100644
--- a/paddlenlp/transformers/model_utils.py
+++ b/paddlenlp/transformers/model_utils.py
@@ -42,7 +42,10 @@
 )
 from huggingface_hub.utils import EntryNotFoundError
 from paddle import Tensor
-from paddle.distributed.fleet.meta_parallel.parallel_layers import SharedLayerDesc
+from paddle.distributed.fleet.meta_parallel.parallel_layers import (
+    PipelineLayer,
+    SharedLayerDesc,
+)
 from paddle.nn import Embedding, Layer
 
 # TODO(fangzeyang) Temporary fix and replace by paddle framework downloader later
@@ -935,6 +938,18 @@ def _post_init(self, original_init, *args, **kwargs):
         ):
             self.init_weights()
 
+        # Note:
+        # 1. PipelineLayer will create parameters for each layer and
+        # call `_synchronize_shared_weights()` to synchronize the shared parameters.
+        # 2. When setting the model `state_dict`, `_synchronize_shared_weights` will be called to
+        # synchronize the shared parameters.
+        # However, `self._init_weights` will re-initialize the parameters without
+        # synchronizing the shared parameters. If the following step does not load a checkpoint,
+        # the shared parameters will be different.
+
+        if isinstance(self, PipelineLayer):
+            self._synchronize_shared_weights()
+
     def _init_weights(self, layer):
         """
         Initialize the weights. This method should be overridden by derived class.
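
For context, a minimal sketch of the pattern this fix relies on: after re-initializing
weights, a pipeline-parallel model must re-synchronize its shared (tied) parameters,
otherwise the copies held by different stages diverge. This sketch assumes tied weights
are declared via SharedLayerDesc; `reinit_and_sync` is a hypothetical helper name used
only for illustration, not a PaddleNLP API.

    import paddle
    from paddle.distributed.fleet.meta_parallel.parallel_layers import PipelineLayer

    def reinit_and_sync(model: paddle.nn.Layer) -> None:
        """Re-initialize parameters, then restore agreement on shared (tied) weights."""
        # Re-initialization touches every local parameter independently, so the
        # copies of a shared weight held by different pipeline stages drift apart.
        for param in model.parameters():
            param.set_value(
                paddle.normal(mean=0.0, std=0.02, shape=param.shape).astype(param.dtype)
            )

        # PipelineLayer tracks which parameters are shared across stages (declared
        # via SharedLayerDesc) and broadcasts the owning stage's copy to the others,
        # which is exactly the call the patch adds after weight initialization.
        if isinstance(model, PipelineLayer):
            model._synchronize_shared_weights()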