Commit

[Fluid] Fix fluid sync params buffers. (#4878)
ZHUI authored Feb 22, 2023
1 parent 75f3130 commit 9a00c15
Showing 2 changed files with 35 additions and 21 deletions.
49 changes: 29 additions & 20 deletions examples/language_model/gpt-3/dygraph/run_pretrain.py
@@ -12,36 +12,40 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import math
import os
import random
import time
import sys
import time

import numpy as np
import paddle
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer import (
DygraphShardingOptimizer,
)
from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
from paddle.distributed.fleet.utils.hybrid_parallel_util import (
fused_allreduce_gradients,
)
from paddle.distributed.sharding import group_sharded_parallel
from visualdl import LogWriter
from modeling import GPTModel, GPTForPretraining, GPTPretrainingCriterion, GPTForPretrainingPipe
from paddlenlp.transformers import GPTTokenizer, GPTChineseTokenizer
from paddlenlp.utils.log import logger

from paddlenlp.transformers import GPTChineseTokenizer, GPTTokenizer
from paddlenlp.utils import profiler
from paddlenlp.utils.log import logger

# to import data_tools
filepath = os.path.abspath(os.path.dirname(__file__))
sys.path.insert(0, os.path.join(filepath, "../"))

from dataset import create_pretrained_dataset
from args import parse_args
import lr
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer import DygraphShardingOptimizer
from paddle.fluid.dygraph.parallel import sync_params_buffers
from paddle.distributed.fleet.utils.hybrid_parallel_util import fused_allreduce_gradients

# add sharding stage2/3
from paddle.distributed.sharding import group_sharded_parallel
import lr # noqa e402
from args import parse_args # noqa e402
from dataset import create_pretrained_dataset # noqa e402
from modeling import ( # noqa e402
GPTForPretraining,
GPTForPretrainingPipe,
GPTModel,
GPTPretrainingCriterion,
)

MODEL_CLASSES = {
"gpt": (GPTForPretraining, GPTTokenizer),
@@ -268,6 +272,11 @@ def do_train(args):
# TODO(Baibaifan): combine ShardingStage1/2/3 and fleet.distributed_model in feature
if args.sharding_stage in [2, 3]:
if args.dp_degree > 1:
try:
from paddle.fluid.dygraph.parallel import sync_params_buffers
except ImportError:
from paddle.distributed.parallel import sync_params_buffers

sync_params_buffers(model, comm_group=dp_group, src_rank=dp_group.ranks[0])

scaler = scaler if args.use_pure_fp16 else None
@@ -287,7 +296,7 @@
logger.warning("No optimizer checkpoint file found in %s." % opt_path)

global_step = 0
tic_train = time.time()
# tic_train = time.time()
for epoch in range(args.num_train_epochs):
files = get_train_data_file(args)
files.sort()
@@ -414,7 +423,7 @@ def do_train(args):
log_writer.add_scalar("loss", float(loss), global_step)
log_writer.add_scalar("learning_rate", optimizer.get_lr(), global_step)

tic_train = time.time()
# tic_train = time.time()
train_reader_cost = 0.0
train_run_cost = 0.0

7 changes: 6 additions & 1 deletion paddlenlp/trainer/trainer.py
@@ -41,7 +41,6 @@
from paddle.distributed.fleet.utils.hybrid_parallel_util import (
fused_allreduce_gradients,
)
from paddle.fluid.dygraph.parallel import sync_params_buffers
from paddle.io import DataLoader, Dataset, DistributedBatchSampler
from tqdm.auto import tqdm

@@ -1201,6 +1200,12 @@ def _wrap_model(self, model, training=True):
else:
# sync params (broadcast) buffers in dp group
if self.args.dp_degree > 1:
try:
from paddle.fluid.dygraph.parallel import sync_params_buffers
except ImportError:
# fix for new api in paddlepaddle v2.5
from paddle.distributed.parallel import sync_params_buffers

hcg = fleet.get_hybrid_communicate_group()
dp_group = hcg.get_data_parallel_group()
sync_params_buffers(model, comm_group=dp_group, src_rank=dp_group.ranks[0])
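For reference, a minimal standalone sketch of the compatibility pattern this commit applies at both call sites: sync_params_buffers is imported from its old paddle.fluid.dygraph.parallel location and, if that fails, from paddle.distributed.parallel, where the API lives from PaddlePaddle v2.5 onward (per the comment in the trainer.py hunk above). The broadcast_dp_params helper name is illustrative only; the fleet calls mirror those in the diff and assume the hybrid-parallel environment is already initialized.

try:
    # Pre-2.5 location of the API (removed from newer releases).
    from paddle.fluid.dygraph.parallel import sync_params_buffers
except ImportError:
    # New location from PaddlePaddle v2.5 onward.
    from paddle.distributed.parallel import sync_params_buffers

from paddle.distributed import fleet


def broadcast_dp_params(model):
    # Broadcast parameters and buffers from the first rank of the
    # data-parallel group, matching the call sites in this commit.
    hcg = fleet.get_hybrid_communicate_group()
    dp_group = hcg.get_data_parallel_group()
    sync_params_buffers(model, comm_group=dp_group, src_rank=dp_group.ranks[0])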
