From 16bca68ba08f503222ca4c207e8a527fb4d68449 Mon Sep 17 00:00:00 2001 From: haohongxiang Date: Tue, 23 Jan 2024 03:25:34 +0000 Subject: [PATCH 1/7] support semi-auto trainer and fit Llama2 training --- .../auto_parallel/run_pretrain_3D_auto.py | 318 +++++++++--------- paddlenlp/trainer/auto_trainer.py | 183 ++++++++++ paddlenlp/trainer/trainer.py | 305 +++++++++-------- paddlenlp/trainer/trainer_utils.py | 19 ++ paddlenlp/trainer/training_args.py | 18 +- .../transformers/llama/modeling_3D_auto.py | 112 ++---- 6 files changed, 569 insertions(+), 386 deletions(-) create mode 100644 paddlenlp/trainer/auto_trainer.py diff --git a/llm/llama/auto_parallel/run_pretrain_3D_auto.py b/llm/llama/auto_parallel/run_pretrain_3D_auto.py index 1d8bbe8b73ea..1fa545555b74 100644 --- a/llm/llama/auto_parallel/run_pretrain_3D_auto.py +++ b/llm/llama/auto_parallel/run_pretrain_3D_auto.py @@ -18,6 +18,7 @@ import random import sys import types +from collections import OrderedDict from dataclasses import dataclass, field from typing import List, Optional @@ -25,9 +26,11 @@ import paddle import paddle.distributed as dist from paddle.distributed import fleet -from paddle.io import DataLoader, DistributedBatchSampler -from paddlenlp.trainer import PdArgumentParser, Trainer, TrainingArguments +from paddlenlp.ops import Topology +from paddlenlp.trainer import PdArgumentParser, TrainingArguments, get_last_checkpoint +from paddlenlp.trainer.auto_trainer import SemiAutoTrainer +from paddlenlp.trainer.trainer_utils import IntervalStrategy, _get_distributed_seeds from paddlenlp.transformers import ( AutoTokenizer, CosineAnnealingWithWarmupDecay, @@ -42,8 +45,6 @@ } -from collections import OrderedDict - from paddlenlp.data.causal_dataset import ( build_train_valid_test_datasets, check_data_split, @@ -78,11 +79,59 @@ class PreTrainingArguments(TrainingArguments): "help": "Enable fused linear grad add strategy, which will reduce elementwise add for grad accumulation in the backward of nn.Linear ." }, ) - parallel_mode: str = field(default="hybrid", metadata={"help": ""}) - + fused_linear_param_grad_add: bool = field( + default=False, + metadata={ + "help": "Enable fused_linear_param_grad pass, which should replace add_n_op with add_op for gradients accumulation." + }, + ) + job_schedule_profiler_start: int = field( + default=-1, + metadata={"help": "The step to start job_schedule_profiler."}, + ) + job_schedule_profiler_end: int = field( + default=-1, + metadata={"help": "The step to end job_schedule_profiler."}, + ) pipeline_schedule_mode: str = field( default="1F1B", metadata={"help": "The pipeline schedule mode, support FThenB, 1F1B, VPP and Eager-1F1B."} ) + sr: Optional[int] = field(default=0, metadata={"help": "The count of chunks without recompute."}) + refined_ops_patterns: Optional[List[str]] = field( + default=None, metadata={"help": "The pattern of refined recompute."} + ) + virtual_pipeline_seg_method: str = field( + default="LlamaDecoderLayerAuto", metadata={"help": "The seg method of spliting pp layer for virtual pipeline."} + ) + # NOTE(gongenlei): new add autotuner_benchmark + autotuner_benchmark: bool = field( + default=False, + metadata={"help": "Weather to run benchmark by autotuner. 
True for from_scratch and pad_max_length."}, + ) + + def __post_init__(self): + super().__post_init__() + assert self.use_auto_parallel + + # NOTE(gongenlei): new add autotuner_benchmark + if self.autotuner_benchmark: + self.max_steps = 5 + self.do_train = True + self.do_export = False + self.do_predict = False + self.do_eval = False + self.overwrite_output_dir = True + self.load_best_model_at_end = False + self.report_to = [] + self.save_strategy = IntervalStrategy.NO + self.evaluation_strategy = IntervalStrategy.NO + + if self.fused_linear_param_grad_add: + fused_passes = self.strategy.fused_passes + fused_passes.enable = True + fused_passes.fused_passes_list.append("fused_linear_param_grad_add_pass") + + logger.info(self.strategy) @dataclass @@ -244,7 +293,6 @@ def create_pretrained_dataset( print_rank_0(" test: {}".format(train_val_test_num_samples[2])) # Build the datasets. - print("====data seed====", training_args.seed) train_dataset, valid_dataset, test_dataset = build_train_valid_test_datasets( data_prefix=data_file, data_impl=data_args.data_impl, @@ -314,29 +362,17 @@ def get_train_data_file(args): return files -def create_optimizer(model, lr_scheduler, training_args): - decay_parameters = [ - p.name - for n, p in model.named_parameters() - if (not any(nd in n for nd in ["bias", "norm"])) or n == "llama.norm.weight" - ] - - def apply_decay_param_fun(x): - return x in decay_parameters +class PretrainingTrainer(SemiAutoTrainer): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) - optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args) - optimizer = optimizer_cls( - learning_rate=lr_scheduler if lr_scheduler is None else lr_scheduler, - apply_decay_param_fun=apply_decay_param_fun, - parameters=model.parameters(), - weight_decay=training_args.weight_decay, - grad_clip=paddle.nn.ClipGradByGlobalNorm(training_args.max_grad_norm) - if training_args.max_grad_norm > 0 - else None, - **optimizer_kwargs, - ) - - return optimizer + def _wrap_dist_loader(self, train_dataloader): + return dist.shard_dataloader( + dataloader=train_dataloader, + meshes=self._get_meshes_for_loader(), + input_keys=["input_ids", "labels"], + shard_dims="dp", + ) def print_config(args, key=""): @@ -367,24 +403,28 @@ def init_seed(seed: int = 1234, args=None): random.seed(seed) np.random.seed(seed) paddle.seed(seed) - - if args is not None: - if args.use_hybrid_parallel: - from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker - - random.seed(args.seed + args.dataset_rank) - np.random.seed(args.seed + args.dataset_rank) - paddle.seed(args.seed + args.dataset_rank) - - # local_seed/ global_seed is used to control dropout in ModelParallel - local_seed = args.seed + 59999 + args.tensor_parallel_rank * 10 + args.pipeline_parallel_rank * 1000 - global_seed = args.seed + 100003 + args.dataset_rank - tracker = get_rng_state_tracker() - - if "global_seed" not in tracker.states_: - tracker.add("global_seed", global_seed) - if "local_seed" not in tracker.states_: - tracker.add("local_seed", local_seed) + else: + assert not args.use_hybrid_parallel and args.use_auto_parallel + if dist.get_world_size() > 1: + topo = Topology( + dist.get_rank(), + dist.get_world_size(), + dp_degree=args.data_parallel_degree, + pp_degree=args.pipeline_parallel_degree, + mp_degree=args.tensor_parallel_degree, + sharding_degree=1, # auto_parallel's sharding is not orthogonal with dp, mp and pp + ) + + global_seed, local_seed, random_seed = _get_distributed_seeds(args.seed, 
topo) + + paddle.seed(local_seed) + random.seed(random_seed) + np.random.seed(random_seed) + + logger.info( + "The global seed is set to {}, local seed is set to {} and " + "random seed is set to {}.".format(global_seed, local_seed, random_seed) + ) else: random.seed(args.seed) np.random.seed(args.seed) @@ -440,6 +480,16 @@ def main(): + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16 or training_args.bf16}" ) + # Detecting last checkpoint. + last_checkpoint = None + if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: + last_checkpoint = get_last_checkpoint(training_args.output_dir) + if last_checkpoint is not None and training_args.resume_from_checkpoint is None: + logger.info( + f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " + "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." + ) + config_class, model_class = MODEL_CLASSES[model_args.model_type] tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name_or_path) @@ -458,6 +508,7 @@ def main(): if model_args.no_recompute_layers is not None: model_args.no_recompute_layers.sort() + config.vocab_size = model_args.vocab_size if model_args.vocab_size is not None else config.vocab_size config.hidden_size = model_args.hidden_size if model_args.hidden_size is not None else config.hidden_size config.intermediate_size = ( model_args.intermediate_size if model_args.intermediate_size is not None else config.intermediate_size @@ -486,6 +537,11 @@ def main(): config.tensor_parallel_degree = training_args.tensor_parallel_degree config.tensor_parallel_rank = training_args.tensor_parallel_rank + if training_args.strategy.pipeline.enable and config.virtual_pp_degree > 1: + pipeline = training_args.strategy.pipeline + pipeline.vpp_degree = config.virtual_pp_degree + pipeline.vpp_seg_method = training_args.virtual_pipeline_seg_method + print("Final pre-training config:", config) # Set the dtype for loading model @@ -496,11 +552,23 @@ def main(): if training_args.bf16: dtype = "bfloat16" - print("======M M M M======", model_class) - model = model_class._from_config(config, dtype=dtype) - # load model - # load_model(model) - shard_model(model) + with paddle.LazyGuard(): + model = model_class.from_config(config, dtype=dtype) + + criterion = None + + for param in model.parameters(): + assert not param._is_initialized() + param.initialize() + + if training_args.recompute: + + def fn(layer): + if hasattr(layer, "enable_recompute") and (layer.enable_recompute is False or layer.enable_recompute == 0): + layer.enable_recompute = True + + model.apply(fn) + # Create the learning_rate sheduler and optimizer if training_args.decay_steps is None: training_args.decay_steps = training_args.max_steps @@ -525,7 +593,7 @@ def main(): ) data_file = get_train_data_file(data_args) - train_dataset, _, _, data_collator = create_pretrained_dataset( + train_dataset, eval_dataset, test_dataset, data_collator = create_pretrained_dataset( data_args, training_args, data_file, @@ -533,109 +601,49 @@ def main(): need_data=training_args.should_load_dataset, ) - optimizer = create_optimizer(model, lr_scheduler, training_args) - - def loss_func(loss, outputs): - return loss - - print_config(training_args) - - # create sampler and dataloader - # each rank read (training_args.per_device_train_batch_size * training_args.data_parallel_degree) samples - print( - "dp_rank: ", dist.get_rank() // 
(training_args.pipeline_parallel_degree * training_args.tensor_parallel_degree) - ) - print( - f"===> worldsize = {training_args.per_device_train_batch_size} rank: {dist.get_rank() // (training_args.pipeline_parallel_degree * training_args.tensor_parallel_degree)}" - ) - train_sampler = DistributedBatchSampler( - train_dataset, - batch_size=training_args.per_device_train_batch_size, - shuffle=False, - num_replicas=training_args.data_parallel_degree, - rank=dist.get_rank() // (training_args.pipeline_parallel_degree * training_args.tensor_parallel_degree), - drop_last=training_args.dataloader_drop_last, - ) - - train_dataloader = DataLoader( - train_dataset, - batch_sampler=train_sampler, - collate_fn=data_collator, - num_workers=training_args.dataloader_num_workers, - ) - - num_update_steps_per_epoch = len(train_dataloader) // training_args.gradient_accumulation_steps - num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1) - num_train_epochs = training_args.max_steps // num_update_steps_per_epoch + int( - training_args.max_steps % num_update_steps_per_epoch > 0 - ) - - global_step = 1 - tr_loss = float(0) - - # hack: create dp group for distributed input data to align dygraph parallel loss. - dp_group = None - global_mesh = fleet.auto.get_mesh().get_mesh_with_dim("pp").mesh - mesh_shape = global_mesh.shape - for id in range(mesh_shape[0]): - pp_mesh = global_mesh[id] - for i in range(pp_mesh.shape[-1]): - ranks = pp_mesh[:, i] - print("dp ranks: ", ranks) - group = dist.new_group(ranks) - if dist.get_rank() in ranks: - dp_group = group - assert dp_group is not None - - model.train() - optimizer = dist.shard_optimizer(optimizer) - for epoch_idx in range(num_train_epochs): - for step, inputs in enumerate(train_dataloader): - input_ids, labels = inputs["input_ids"], inputs["labels"] - - input_id = input_ids[0][0].numpy() - label = labels[0][0].numpy() - - # hack for align dygraph parallel. - if dp_group is not None: - cur_rank = dist.get_rank() - res = [] - dist.all_gather(res, paddle.Tensor(input_ids, place=paddle.CUDAPlace(cur_rank)), group=dp_group) - input_ids = paddle.concat(res) - input_ids = dist.shard_tensor(input_ids, get_mesh(), [dist.Shard(0), dist.Replicate()]) - - res = [] - dist.all_gather(res, paddle.Tensor(labels, place=paddle.CUDAPlace(cur_rank)), group=dp_group) - labels = paddle.concat(res) - labels = dist.shard_tensor(labels, get_mesh(-1), [dist.Shard(0), dist.Replicate()]) - - res = model(input_ids, labels=labels) - - # add criterion in the future. - tr_loss_step = res[0] - - if training_args.gradient_accumulation_steps > 1: - tr_loss_step /= training_args.gradient_accumulation_steps - - # do backward every micro step. 
- tr_loss_step.backward() - tr_loss += tr_loss_step - - if global_step % training_args.gradient_accumulation_steps == 0: - # print_grad(model) - optimizer.step() - lr_scheduler.step() - # print_param(model) - optimizer.clear_grad() - print( - f"global_step {global_step // training_args.gradient_accumulation_steps};input id {input_id}; label {label}; loss {tr_loss.numpy()} lr: {optimizer.get_lr()}" - ) - tr_loss = 0 - - if global_step // training_args.gradient_accumulation_steps >= training_args.max_steps: - break + # total_train_batch_size_per_acc_step = ( + # training_args.per_device_train_batch_size * training_args.data_parallel_degree + # ) + # total_train_batch_size = total_train_batch_size_per_acc_step * training_args.gradient_accumulation_steps + + trainer = PretrainingTrainer( + model=model, + criterion=criterion, + args=training_args, + data_collator=data_collator, + train_dataset=train_dataset if training_args.do_train else None, + eval_dataset=eval_dataset if training_args.do_eval else None, + optimizers=(None, lr_scheduler), + tokenizer=tokenizer, + ) + + checkpoint = None + if training_args.resume_from_checkpoint is not None: + checkpoint = training_args.resume_from_checkpoint + elif last_checkpoint is not None: + checkpoint = last_checkpoint + + # Training + if training_args.do_train: + train_result = trainer.train(resume_from_checkpoint=checkpoint) + + # NOTE(gongenlei): new add + if not training_args.autotuner_benchmark: + metrics = train_result.metrics + if not int(os.getenv("test_ci_no_save_model", 0)): + trainer.save_model() + trainer.log_metrics("train", metrics) + trainer.save_metrics("train", metrics) + trainer.save_state() + + if training_args.do_predict: + test_ret = trainer.predict(test_dataset) + trainer.log_metrics("test", test_ret.metrics) - global_step += 1 + # if training_args.should_load_dataset: + # effective_tokens_per_second = total_effective_tokens / train_result.metrics["train_runtime"] + # print(f"Effective Tokens per second: {effective_tokens_per_second:.2f}") + # print(f"ips: {effective_tokens_per_second:.2f} tokens/s") def shard_model(model): diff --git a/paddlenlp/trainer/auto_trainer.py b/paddlenlp/trainer/auto_trainer.py new file mode 100644 index 000000000000..3a16f9c7b3f6 --- /dev/null +++ b/paddlenlp/trainer/auto_trainer.py @@ -0,0 +1,183 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Any, Dict, Optional, Union + +import numpy as np +import paddle +import paddle.distributed as dist +import paddle.nn as nn +from paddle.distributed import fleet + +from paddlenlp.trainer import Trainer + +from ..utils.log import logger +from .trainer_utils import _exec_mode_guard, has_length + + +class SemiAutoTrainer(Trainer): + def __init__(self, *args, **kwargs): + + if kwargs.get("args", None) is not None and kwargs["args"].run_static_semi_auto: + if kwargs.get("criterion", None) is None: + + def loss_func(loss, outputs): + return loss + + kwargs.update({"criterion": loss_func}) + + super().__init__(*args, **kwargs) + assert self.args.use_auto_parallel + + def _nested_gather(self, tensors): + """ + Gather value of `tensors` (tensor or list/tuple of nested tensors) and convert them to numpy before + concatenating them to `gathered` + """ + return tensors + + def _wrap_model(self, model, training=True): + self.optimizer = dist.shard_optimizer(self.optimizer) if not self.args.run_static_semi_auto else self.optimizer + + return model + + def _get_meshes_for_loader(self): + def _get_mesh(pp_idx=0): + mesh = fleet.auto.get_mesh() + if "pp" in mesh.dim_names: + mesh = mesh.get_mesh_with_dim("pp")[pp_idx] + return mesh + + meshes = [] + for pp_idx in range(self.args.pipeline_parallel_degree): + meshes.append(_get_mesh(pp_idx)) + return meshes + + def _wrap_dist_loader(self, train_dataloader): + return dist.shard_dataloader( + dataloader=train_dataloader, + meshes=self._get_meshes_for_loader(), + shard_dims="dp", + ) + + def _wrap_for_static(self, model, train_dataloader): + # TODO: convert fleet.auto.Strategy to dist.Strategy + # TODO: fix bugs in paddle/distributed/auto_parallel/api.py#L981 about sample_split of engine._prepare_data_spec + + dist_loader = self._wrap_dist_loader(train_dataloader) + + model = dist.to_static(model, dist_loader, self.criterion, self.optimizer, strategy=self.args.strategy) + return model, dist_loader + + def _wrap_for_amp_training(self): + pass + + def _print_trainable_numel(self): + if not self.args.run_static_semi_auto: + super()._print_trainable_numel() + else: + per_device_trainable_numel = sum( + np.prod(p.shape) for p in self.model._engine._model.parameters() if not p.stop_gradient + ) + logger.info(f" Number of trainable parameters = {per_device_trainable_numel:,} (per device)") + + parts_num = max(self.args.tensor_parallel_degree, 1) * max(self.args.pipeline_parallel_degree, 1) + if parts_num > 1: + all_reduce_dtype = "int64" + if paddle.get_device().split(":")[0] in ["npu", "xpu"]: + # TODO(duanyanhui): fix when NPU all_reduce supports int64 + all_reduce_dtype = "float32" + + with _exec_mode_guard("dynamic"): + trainable_numel_tensor = paddle.to_tensor(per_device_trainable_numel, dtype=all_reduce_dtype) + paddle.distributed.all_reduce(trainable_numel_tensor) + trainable_numel = int(trainable_numel_tensor.item()) // self.args.dataset_world_size + + if self.args.sep_parallel_degree > 0: + trainable_numel = trainable_numel // self.args.sep_parallel_degree + # the numel is roughly, because the tensor parallel still hold own bias or layer_norm weight without splited + # so, the trainable numel is a little bigger than real. 
+ logger.info(f" Number of trainable parameters = {trainable_numel:,} (all devices, roughly)") + + def _get_train_sampler(self) -> Optional[paddle.io.Sampler]: + if self.train_dataset is None or not has_length(self.train_dataset): + return None + + total_batch_size_per_acc_step = self.args.per_device_train_batch_size * self.args.dataset_world_size + total_batch_size = total_batch_size_per_acc_step * self.args.gradient_accumulation_steps + batch_size = total_batch_size if self.args.run_static_semi_auto else total_batch_size_per_acc_step + + return paddle.io.BatchSampler( + dataset=self.train_dataset, + shuffle=True, + batch_size=batch_size, + drop_last=self.args.dataloader_drop_last, + ) + + # return DistributedBatchSampler( + # self.train_dataset, + # batch_size=self.args.per_device_train_batch_size, + # shuffle=True, + # num_replicas=self.args.dataset_world_size, + # rank=self.args.dataset_rank, + # drop_last=self.args.dataloader_drop_last, + # ) + + def training_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor, Any]]) -> paddle.Tensor: + model.train() + + inputs = self._prepare_inputs(inputs) + + if not self.args.run_static_semi_auto: + with self.autocast_smart_context_manager(): + loss = self.compute_loss(model, inputs) + + if self.args.gradient_accumulation_steps > 1: + loss = loss / self.args.gradient_accumulation_steps + + if self.do_grad_scaling: + self.scaler.scale(loss).backward() + else: + loss.backward() + else: + input_ids, labels = tuple(inputs.values()) + loss = model(input_ids, labels) + + if self.args.pipeline_parallel_degree > 1: + self._pp_data_buffer = {} + + if loss is not None and self.args.gradient_accumulation_steps > 1: + loss = loss / self.args.gradient_accumulation_steps + + if isinstance(loss, paddle.Tensor): + return loss.detach() if loss._is_initialized() else float(0.0) + elif isinstance(loss, np.ndarray): + return np.sum(loss) + elif loss is None: + return float(0.0) + else: + return float(loss) + + def synchronize_gradients(self, *args, **kwargs): + pass + + def optimizer_step(self): + if not self.args.run_static_semi_auto: + super().optimizer_step() + else: + pass + + def _maybe_log_save_evaluate(self, tr_loss, model, epoch, ignore_keys_for_eval, **kwargs): + with _exec_mode_guard("dynamic"): + super()._maybe_log_save_evaluate(tr_loss, model, epoch, ignore_keys_for_eval, **kwargs) diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py index 70ce9033842b..48e4caa9aa02 100644 --- a/paddlenlp/trainer/trainer.py +++ b/paddlenlp/trainer/trainer.py @@ -131,6 +131,7 @@ ShardingOption, TrainerMemoryTracker, TrainOutput, + _exec_mode_guard, find_batch_size, get_last_checkpoint, get_scheduler, @@ -330,7 +331,8 @@ def __init__( "Passing `optimizers` is not allowed if sharding is enabled." "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method." 
) - if self.args.pipeline_parallel_degree > 1: + + if self.args.pipeline_parallel_degree > 1 and self.args.use_hybrid_parallel: from paddle.distributed.fleet.meta_parallel import PipelineLayer assert (isinstance(model, LoRAModel) and isinstance(model.model, PipelineLayer)) or isinstance( @@ -344,6 +346,9 @@ def __init__( ) self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK) + self._save_ckpt_func = dist.save_state_dict if self.args.use_auto_parallel else paddle.save + self._load_ckpt_func = dist.load_state_dict if self.args.use_auto_parallel else paddle.load + if args.max_steps > 0: logger.info("max_steps is given, it will override any value given in num_train_epochs") @@ -358,46 +363,8 @@ def __init__( self.enable_autocast_context_manager = True self.do_grad_scaling = True if args.fp16 else False self.amp_dtype = "float16" if args.fp16 else "bfloat16" - # fix for load saved fp16 or bf16 ckpt, decorate model first. - if self.args.fp16_opt_level == "O2": - paddle.amp.decorate( - models=model, - level=self.args.fp16_opt_level, - dtype=self.amp_dtype, - excluded_layers=QuantizationLinear, - ) - # for pipeline mode and pure tensor parallel - if self.args.pipeline_parallel_degree > 1 or ( - self.args.tensor_parallel_degree > 1 and self.sharding is None - ): - self.scaler = paddle.amp.GradScaler(init_loss_scaling=self.args.scale_loss) - if self.args.amp_master_grad: - mix_precision_utils.MixPrecisionScaler(self.scaler) # retun value has no use - self.scaler = fleet.distributed_scaler(self.scaler) - elif self.sharding is not None: - self.scaler = paddle.amp.GradScaler(init_loss_scaling=self.args.scale_loss) - if self.amp_dtype == "float16" or self.amp_dtype == "bfloat16": - if ShardingOption.SHARD_OP in self.args.sharding: - self.scaler = fleet.distributed_scaler(self.scaler) - if self.args.amp_master_grad: - mix_precision_utils.MixPrecisionScaler(self.scaler) # retun value has no use - else: - # scaler for stage2 and stage3 - from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_utils import ( - GroupShardedScaler, - ) - - if self.args.amp_master_grad: - mix_precision_utils.MixPrecisionScaler(self.scaler) # return value has no use - - self.scaler = GroupShardedScaler(self.scaler) - else: - self.do_grad_scaling = False - self.use_cuda_amp = False - self.amp_dtype = None - - else: - self.scaler = paddle.amp.GradScaler(init_loss_scaling=self.args.scale_loss) + self.scaler = paddle.amp.GradScaler(init_loss_scaling=self.args.scale_loss) + self._wrap_for_amp_training() if args.recompute: @@ -456,6 +423,41 @@ def remove_callback(self, callback): """ self.callback_handler.remove_callback(callback) + def _wrap_for_amp_training(self): + # fix for load saved fp16 or bf16 ckpt, decorate model first. 
+ if self.args.fp16_opt_level == "O2": + paddle.amp.decorate( + models=self.model, + level=self.args.fp16_opt_level, + dtype=self.amp_dtype, + excluded_layers=QuantizationLinear, + ) + # for pipeline mode and pure tensor parallel + if self.args.pipeline_parallel_degree > 1 or (self.args.tensor_parallel_degree > 1 and self.sharding is None): + if self.args.amp_master_grad: + mix_precision_utils.MixPrecisionScaler(self.scaler) # retun value has no use + self.scaler = fleet.distributed_scaler(self.scaler) + elif self.sharding is not None: + if self.amp_dtype == "float16" or self.amp_dtype == "bfloat16": + if ShardingOption.SHARD_OP in self.args.sharding: + self.scaler = fleet.distributed_scaler(self.scaler) + if self.args.amp_master_grad: + mix_precision_utils.MixPrecisionScaler(self.scaler) # retun value has no use + else: + # scaler for stage2 and stage3 + from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_utils import ( + GroupShardedScaler, + ) + + if self.args.amp_master_grad: + mix_precision_utils.MixPrecisionScaler(self.scaler) # return value has no use + + self.scaler = GroupShardedScaler(self.scaler) + else: + self.do_grad_scaling = False + self.use_cuda_amp = False + self.amp_dtype = None + def _load_from_peft_checkpoint(self, resume_from_checkpoint=None): """load state_dict from checkpoint, Only for PEFT Model. @@ -718,6 +720,11 @@ def train( self.create_optimizer_and_scheduler(num_training_steps=max_steps) self._load_optimizer_and_scheduler(resume_from_checkpoint) + if self.args.use_auto_parallel and self.args.run_static_semi_auto: + model, train_dataloader = self._wrap_for_static(model, train_dataloader) + + self.model = model + logger.info("***** Running training *****") logger.info(f" Num examples = {num_examples:,}") logger.info(f" Num Epochs = {num_train_epochs}") @@ -728,24 +735,7 @@ def train( logger.info(f" Total num train samples = {num_train_samples:,}") # per_device_trainable_numel = sum(p.numel().item() for p in model.parameters() if not p.stop_gradient) # TODO: Temporary fix since Tensor.numel() not supported in distributed mode - per_device_trainable_numel = sum(np.prod(p.shape) for p in model.parameters() if not p.stop_gradient) - logger.info(f" Number of trainable parameters = {per_device_trainable_numel:,} (per device)") - if self.args.use_hybrid_parallel: - # todo fix for pipeline_parallel_degree - parts_num = max(self.args.tensor_parallel_degree, 1) * max(self.args.pipeline_parallel_degree, 1) - if parts_num > 1: - all_reduce_dtype = "int64" - if paddle.get_device().split(":")[0] in ["npu", "xpu"]: - # TODO(duanyanhui): fix when NPU all_reduce supports int64 - all_reduce_dtype = "float32" - trainable_numel_tensor = paddle.to_tensor(per_device_trainable_numel, dtype=all_reduce_dtype) - paddle.distributed.all_reduce(trainable_numel_tensor) - trainable_numel = int(trainable_numel_tensor.item()) // self.args.dataset_world_size - if self.args.sep_parallel_degree > 0: - trainable_numel = trainable_numel // self.args.sep_parallel_degree - # the numel is roughly, because the tensor parallel still hold own bias or layer_norm weight without splited - # so, the trainable numel is a little bigger than real. 
- logger.info(f" Number of trainable parameters = {trainable_numel:,} (all devices, roughly)") + self._print_trainable_numel() start_time = time.time() self._globalstep_last_start_time = time.time() @@ -828,7 +818,8 @@ def train( self.control = self.callback_handler.on_train_begin(args, self.state, self.control) - tr_loss = paddle.to_tensor(0.0) + with _exec_mode_guard("dynamic"): + tr_loss = paddle.to_tensor(0.0) self._total_loss_scalar = 0.0 self._globalstep_last_logged = self.state.global_step @@ -889,7 +880,7 @@ def train( forbidden_no_sync = False # stage2 and stage3 should not no_sync, because the is no DDP wrapper and no_sync API # hybrid_parallel (tp or pp or sharding stage 1) should not no_sync - if self.args.use_hybrid_parallel: + if self.args.use_hybrid_parallel or self.args.use_auto_parallel: forbidden_no_sync = True availiable_no_sync = dp_enabled and not forbidden_no_sync @@ -916,12 +907,20 @@ def train( else: tr_loss_step = self.training_step(model, inputs) - tr_loss += tr_loss_step + with _exec_mode_guard("dynamic"): + tr_loss += tr_loss_step + + disable_accumulation = ( + self.args.use_auto_parallel + and self.args.pipeline_parallel_degree > 1 + and self.args.run_static_semi_auto + ) if (step_control + 1) % args.gradient_accumulation_steps == 0 or ( # last step in epoch but step is always smaller than gradient_accumulation_steps steps_in_epoch <= args.gradient_accumulation_steps and (step + 1) == steps_in_epoch + or disable_accumulation ): if self.args.pipeline_parallel_degree <= 1 and self._enable_delay_scale_loss(): tr_loss /= self.args.gradient_accumulation_steps @@ -934,80 +933,19 @@ def train( # local_rank != -1 don't means dp in networks. self.timers and self.timers("all-reduce").start() - # Case 1: Use recompute and dp / sharding stage1, - # manualy collect gradient for dp. 
- if args.recompute and availiable_no_sync: - fused_allreduce_gradients(list(model.parameters()), None) - - # Case 2: hack dp with master_grad - if dp_master_grad and not (args.recompute and availiable_no_sync): - fused_allreduce_gradients(list(model.parameters()), None) - - # Pipeline parallel mode, handle gradient reduce here to overlap - pipeline_parallel_config = ( - set(args.pipeline_parallel_config.split(" ")) if args.pipeline_parallel_degree > 1 else set() - ) - enable_dp_comm_overlap = "enable_dp_comm_overlap" in pipeline_parallel_config - enable_release_grads = "enable_release_grads" in pipeline_parallel_config - - # Case 3: Pipeline parallel mode, overlap with dp - if isinstance(self.optimizer, HybridParallelOptimizer) and not self.do_grad_scaling: - parameters_list = _obtain_optimizer_parameters_list(self.optimizer._inner_opt) - - if not enable_dp_comm_overlap: - if self.optimizer._sharding_enable: - assert reshard_util.is_sharding_opt(self.optimizer) - self.optimizer._inner_opt.reduce_gradients(list(parameters_list), self.optimizer._hcg) - - if self.optimizer._dp_enable or getattr(self.optimizer, "_sep_enable", False): - fused_allreduce_gradients(list(parameters_list), self.optimizer._hcg) + self.synchronize_gradients(availiable_no_sync, dp_master_grad) self.timers and self.timers("all-reduce").stop() self.timers and self.timers("optimizer-step").start() - if self.args.gradient_accumulation_steps > 1 and self._enable_delay_scale_loss(): - for p in model._layers.parameters(): - with paddle.no_grad(): - if hasattr(p, "main_grad") and p.main_grad is not None: - assert p.grad is None - p.main_grad.scale_(1.0 / self.args.gradient_accumulation_steps) - elif p.grad is not None: - p.grad.scale_(1.0 / self.args.gradient_accumulation_steps) - # Optimizer step self.callback_handler.on_optimizer_begin( args, self.state, self.control, scaler=self.scaler if self.do_grad_scaling else None ) - optimizer_was_run = True - if self.do_grad_scaling: - scale_before = paddle.assign(self.scaler._scale) - self.scaler.step(self.optimizer) - self.scaler.update() - scale_after = self.scaler._scale - optimizer_was_run = not self.scaler._cache_founf_inf - if not optimizer_was_run: - scale_before_value = scale_before.cpu().numpy() - scale_after_value = scale_after.cpu().numpy() - logger.warning( - f"optimizer not run, scale_before: {scale_before_value[0]}, scale_after: {scale_after_value[0]}" - ) - elif isinstance(self.optimizer, HybridParallelOptimizer): - self.optimizer._step(parameters_list) - else: - self.optimizer.step() - - self.timers and self.timers("optimizer-step").stop() - if optimizer_was_run: - self.lr_scheduler.step() + self.optimizer_step() - if enable_release_grads and args.pipeline_parallel_degree > 1: - self.optimizer.clear_grad(set_to_zero=False) - for _, buffers in model._chunk_2_comm_buffers.items(): - for buffer in buffers: - buffer._clear_grad_storage() - else: - self.optimizer.clear_grad() + self.timers and self.timers("optimizer-step").stop() self.callback_handler.on_optimizer_end( args, self.state, self.control, scaler=self.scaler if self.do_grad_scaling else None @@ -1071,7 +1009,7 @@ def train( "on multiple nodes, you should activate `--save_on_each_node`." 
) - self._total_loss_scalar += tr_loss.item() + self._total_loss_scalar += self._get_item_from_loss(tr_loss) train_loss = self._total_loss_scalar / self.state.global_step metrics = speed_metrics("train", start_time, num_samples=num_train_samples, num_steps=self.state.max_steps) @@ -1088,6 +1026,98 @@ def train( return TrainOutput(self.state.global_step, train_loss, metrics) + def _wrap_for_static(self, model, train_dataloader): + pass + + def _print_trainable_numel(self): + per_device_trainable_numel = sum(np.prod(p.shape) for p in self.model.parameters() if not p.stop_gradient) + logger.info(f" Number of trainable parameters = {per_device_trainable_numel:,} (per device)") + + if self.args.use_hybrid_parallel: + # todo fix for pipeline_parallel_degree + parts_num = max(self.args.tensor_parallel_degree, 1) * max(self.args.pipeline_parallel_degree, 1) + if parts_num > 1: + all_reduce_dtype = "int64" + if paddle.get_device().split(":")[0] in ["npu", "xpu"]: + # TODO(duanyanhui): fix when NPU all_reduce supports int64 + all_reduce_dtype = "float32" + trainable_numel_tensor = paddle.to_tensor(per_device_trainable_numel, dtype=all_reduce_dtype) + paddle.distributed.all_reduce(trainable_numel_tensor) + trainable_numel = int(trainable_numel_tensor.item()) // self.args.dataset_world_size + if self.args.sep_parallel_degree > 0: + trainable_numel = trainable_numel // self.args.sep_parallel_degree + # the numel is roughly, because the tensor parallel still hold own bias or layer_norm weight without splited + # so, the trainable numel is a little bigger than real. + logger.info(f" Number of trainable parameters = {trainable_numel:,} (all devices, roughly)") + + def synchronize_gradients(self, availiable_no_sync, dp_master_grad): + # Case 1: Use recompute and dp / sharding stage1, + # manualy collect gradient for dp. 
+ if self.args.recompute and availiable_no_sync: + fused_allreduce_gradients(list(self.model.parameters()), None) + + # Case 2: hack dp with master_grad + if dp_master_grad and not (self.args.recompute and availiable_no_sync): + fused_allreduce_gradients(list(self.model.parameters()), None) + + # Pipeline parallel mode, handle gradient reduce here to overlap + pipeline_parallel_config = ( + set(self.args.pipeline_parallel_config.split(" ")) if self.args.pipeline_parallel_degree > 1 else set() + ) + enable_dp_comm_overlap = "enable_dp_comm_overlap" in pipeline_parallel_config + + # Case 3: Pipeline parallel mode, overlap with dp + if isinstance(self.optimizer, HybridParallelOptimizer) and not self.do_grad_scaling: + parameters_list = _obtain_optimizer_parameters_list(self.optimizer._inner_opt) + + if not enable_dp_comm_overlap: + if self.optimizer._sharding_enable: + assert reshard_util.is_sharding_opt(self.optimizer) + self.optimizer._inner_opt.reduce_gradients(list(parameters_list), self.optimizer._hcg) + + if self.optimizer._dp_enable or getattr(self.optimizer, "_sep_enable", False): + fused_allreduce_gradients(list(parameters_list), self.optimizer._hcg) + + def optimizer_step(self): + if self.args.pipeline_parallel_degree > 1 and self._enable_delay_scale_loss(): + for p in self.model._layers.parameters(): + with paddle.no_grad(): + if hasattr(p, "main_grad") and p.main_grad is not None: + assert p.grad is None + p.main_grad.scale_(1.0 / self.args.gradient_accumulation_steps) + elif p.grad is not None: + p.grad.scale_(1.0 / self.args.gradient_accumulation_steps) + + optimizer_was_run = True + if self.do_grad_scaling: + scale_before = paddle.assign(self.scaler._scale) + self.scaler.step(self.optimizer) + self.scaler.update() + scale_after = self.scaler._scale + optimizer_was_run = not self.scaler._cache_founf_inf + if not optimizer_was_run: + scale_before_value = scale_before.cpu().numpy() + scale_after_value = scale_after.cpu().numpy() + logger.warning( + f"optimizer not run, scale_before: {scale_before_value[0]}, scale_after: {scale_after_value[0]}" + ) + elif isinstance(self.optimizer, HybridParallelOptimizer): + parameters_list = _obtain_optimizer_parameters_list(self.optimizer._inner_opt) + self.optimizer._step(parameters_list) + else: + self.optimizer.step() + + if optimizer_was_run: + self.lr_scheduler.step() + + if self.args.pipeline_parallel_degree > 1 and "enable_release_grads" in self.args.pipeline_parallel_config: + self.optimizer.clear_grad(set_to_zero=False) + for _, buffers in self.model._chunk_2_comm_buffers.items(): + for buffer in buffers: + buffer._clear_grad_storage() + else: + self.optimizer.clear_grad() + def _load_best_model_from_peft_checkpoint(self): convert_tp = False if isinstance(self.model, LoRAModel): @@ -1169,14 +1199,22 @@ def _print_timer(self): if timer_info or paddle_timer_info: logger.info(f"[Profile global_step: {self.state.global_step}] {timer_info} {paddle_timer_info}") + def _get_item_from_loss(self, loss): + if isinstance(loss, paddle.Tensor): + if loss.is_dist(): + return loss._local_value().item() if loss._is_initialized() else 0.0 + else: + return loss.item() if loss._is_initialized() else 0.0 + else: + return loss + def _maybe_log_save_evaluate(self, tr_loss, model, epoch, ignore_keys_for_eval, **kwargs): if self.control.should_log: logs: Dict[str, float] = {} # all_gather + mean() to get average loss over all processes - tr_loss_scalar = self._nested_gather(tr_loss).mean().item() - + tr_loss_scalar = 
self._get_item_from_loss(self._nested_gather(tr_loss).mean()) # reset tr_loss to zero tr_loss.subtract_(tr_loss) @@ -2070,7 +2108,7 @@ def _save_checkpoint(self, model, metrics=None): safe_serialization=True, ) else: - paddle.save( + self._save_ckpt_func( self.optimizer.state_dict(), os.path.join(output_dir, optimizer_name), ) @@ -2078,7 +2116,7 @@ def _save_checkpoint(self, model, metrics=None): if self.args.should_save: if not self.args.use_hybrid_parallel: logger.info("Saving optimizer files.") - paddle.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME)) + self._save_ckpt_func(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME)) # FIXME: maybe only save one copy paddle.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME)) @@ -2272,7 +2310,8 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None, merge_tensor_ logger.warning("Trainer.model is not a `PretrainedModel`, not suppor for merge_tensor_parallel.") if state_dict is None: state_dict = self.model.state_dict() - paddle.save( + + self._save_ckpt_func( state_dict, os.path.join(output_dir, _add_variant(PADDLE_WEIGHTS_NAME, self.args.weight_name_suffix)), ) @@ -2646,7 +2685,7 @@ def evaluation_loop( metrics = {} if all_losses is not None: - metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item() + metrics[f"{metric_key_prefix}_loss"] = self._get_item_from_loss(all_losses.mean()) # Prefix all keys with metric_key_prefix + '_' for key in list(metrics.keys()): diff --git a/paddlenlp/trainer/trainer_utils.py b/paddlenlp/trainer/trainer_utils.py index fb54024281a5..d10c7f110388 100644 --- a/paddlenlp/trainer/trainer_utils.py +++ b/paddlenlp/trainer/trainer_utils.py @@ -29,6 +29,7 @@ import re import threading import time +from contextlib import contextmanager from enum import Enum from typing import Dict, List, NamedTuple, Optional, Tuple, Union @@ -164,6 +165,24 @@ def set_seed(seed: int = 1234, topo=None): ) +def _switch_mode(mode="dynamic"): + assert mode in ["dynamic", "static"] + if mode == "dynamic": + paddle.disable_static() + else: + paddle.enable_static() + + +@contextmanager +def _exec_mode_guard(mode="dynamic"): + origin_mode = "dynamic" if paddle.in_dynamic_mode() else "static" + _switch_mode(mode) + try: + yield + finally: + _switch_mode(origin_mode) + + class ExplicitEnum(Enum): """ Enum with more explicit error message for missing values. 
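The `_exec_mode_guard` context manager added to trainer_utils.py above is what lets the trainer keep scalar bookkeeping (loss accumulation, logging) in dynamic mode even after the program has been switched to static execution. A minimal usage sketch (the guard and its import path come from this patch; the surrounding loss-accumulator code is only an assumed example):

    import paddle
    from paddlenlp.trainer.trainer_utils import _exec_mode_guard

    # Even if static mode is currently enabled (e.g. after dist.to_static),
    # create the loss accumulator as an ordinary dynamic-mode tensor.
    with _exec_mode_guard("dynamic"):
        tr_loss = paddle.to_tensor(0.0)
    # On exit, whichever execution mode was active beforehand is restored.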
diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py index f7a53a414e9b..0192301c2acc 100644 --- a/paddlenlp/trainer/training_args.py +++ b/paddlenlp/trainer/training_args.py @@ -747,6 +747,8 @@ class TrainingArguments: default=False, metadata={"help": "reshard pp even if pp degree in the model and pp degree in script match"}, ) + parallel_mode: str = field(default="hybrid", metadata={"help": ""}) + run_static_semi_auto: bool = field(default=True, metadata={"help": ""}) def __post_init__(self): env_local_rank = int(os.environ.get("PADDLE_RANK_IN_NODE", -1)) @@ -1140,8 +1142,8 @@ def is_segment_parallel_supported(): if len(self.sharding) > 0: self.sharding_parallel_degree = self.data_parallel_degree - sharding_parallel_degree = max(self.sharding_parallel_degree, 1) - if sharding_parallel_degree == 1 and len(self.sharding) > 0: + self.sharding_parallel_degree = max(self.sharding_parallel_degree, 1) + if self.sharding_parallel_degree == 1 and len(self.sharding) > 0: logger.warning("sharding_parallel_degree=1 means no sharding, please set sharding to empty!") self.sharding = [] @@ -1226,10 +1228,10 @@ def is_segment_parallel_supported(): "by current version of Paddle. Please try latest develop Paddle." ) - if sharding_parallel_degree > 1: + if self.sharding_parallel_degree > 1: sharding = strategy.sharding sharding.enable = True - sharding.degree = sharding_parallel_degree + sharding.degree = self.sharding_parallel_degree if ShardingOption.SHARD_OP in self.sharding: sharding.stage = 1 elif ShardingOption.SHARD_GRAD_OP in self.sharding: @@ -1279,6 +1281,7 @@ def is_segment_parallel_supported(): mesh_dims = list(filter(lambda x: x[1] > 1, list(zip(order, degree)))) if not mesh_dims: mesh_dims = [("dp", 1)] + fleet.auto.create_mesh(mesh_dims) else: world_size = paddle.distributed.get_world_size() @@ -1395,6 +1398,9 @@ def data_parallel_rank(self): if dp_group.rank == -1: return 0 return dp_group.rank + elif self.use_auto_parallel: + mesh = fleet.auto.get_mesh() + return mesh.get_dim_size("dp") else: return paddle.distributed.get_rank() @@ -1543,7 +1549,9 @@ def should_log(self): """ Whether or not the current process should produce log. 
""" - if self.log_on_each_node: + if self.use_auto_parallel: + return True + elif self.log_on_each_node: return self.local_process_index == 0 else: return self.process_index == 0 diff --git a/paddlenlp/transformers/llama/modeling_3D_auto.py b/paddlenlp/transformers/llama/modeling_3D_auto.py index 86ee620a15a8..afa2d8786767 100644 --- a/paddlenlp/transformers/llama/modeling_3D_auto.py +++ b/paddlenlp/transformers/llama/modeling_3D_auto.py @@ -56,7 +56,6 @@ apply_rotary_pos_emb, build_alibi_tensor, get_triangle_upper_mask, - is_casual_mask, repeat_kv, rms_norm_fused, ) @@ -94,9 +93,6 @@ def _make_causal_mask(input_ids_shape, past_key_values_length): return mask[None, None, :, :].expand([batch_size, 1, target_length, target_length + past_key_values_length]) -attention_cnt = 0 - - def scaled_dot_product_attention( query_states, config, @@ -143,19 +139,9 @@ def scaled_dot_product_attention( # merge with the next tranpose key_states = paddle.transpose(key_states, [0, 2, 1, 3]) value_states = paddle.transpose(value_states, [0, 2, 1, 3]) - global attention_cnt - """ - if attention_cnt == 0: - print(f"q_{attention_cnt} shape: {query_states.shape} md5: {query_states._md5sum()}") - """ + # matmul and devide by sqrt(head_dim) attn_weights = paddle.matmul(query_states / math.sqrt(head_dim), key_states.transpose([0, 1, 3, 2])) - """ - if attention_cnt == 0: - print( - f"attn_weights_{attention_cnt} shape: {attn_weights.shape} local_shape: {attn_weights._local_shape} md5sum: {attn_weights._md5sum()}" - ) - """ # then add alibi bias if alibi is not None: alibi = alibi.reshape([bsz, num_heads, 1, -1]) @@ -179,26 +165,13 @@ def scaled_dot_product_attention( f"Attention mask should be of shape {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.shape}" ) attn_weights = attn_weights + attention_mask - """ - if attention_cnt == 0: - print( - f"attn_weights_after_add_{attention_cnt} shape: {attn_weights.shape} local_shape: {attn_weights._local_shape} md5: {attn_weights._md5sum()}" - ) - """ if not paddle.in_dynamic_mode(): attn_weights = F.softmax(attn_weights, axis=-1, dtype="float32").astype(query_states.dtype) else: attn_weights = F.softmax(attn_weights, axis=-1, dtype="float32").astype(query_states.dtype) - """ - if attention_cnt == 0: - print( - f"attn_weights_after_soft_{attention_cnt} shape: {attn_weights.shape} local_shape: {attn_weights._local_shape} md5: {attn_weights._md5sum()}" - ) - """ attn_output = paddle.matmul(attn_weights, value_states) attn_output = attn_output.transpose([0, 2, 1, 3]) attn_output = attn_output.reshape([bsz, q_len, head_dim * num_heads]) - attention_cnt = attention_cnt + 1 return (attn_output, attn_weights) if output_attentions else attn_output @@ -216,9 +189,7 @@ def __init__(self, config): def forward(self, hidden_states): if self.config.use_fused_rms_norm: - tmp = rms_norm_fused(hidden_states, self.weight, self.variance_epsilon) - print(f"rms {tmp.placements}") - return tmp + return rms_norm_fused(hidden_states, self.weight, self.variance_epsilon) if paddle.in_dynamic_mode(): variance = hidden_states.astype("float32").pow(2).mean(-1, keepdim=True) @@ -244,39 +215,32 @@ def __init__(self, config, ipp: Optional[int] = None): if config.fuse_attention_ffn: self.gate_up_fused_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias_attr=False) - """ self.gate_up_fused_proj.weight = dist.shard_tensor( self.gate_up_fused_proj.weight, get_mesh(self.ipp), [dist.Replicate(), dist.Shard(1)], ) - """ else: self.gate_proj = nn.Linear(self.hidden_size, 
self.intermediate_size, bias_attr=False) - """ self.gate_proj.weight = dist.shard_tensor( self.gate_proj.weight, get_mesh(self.ipp), [dist.Replicate(), dist.Shard(1)], ) - """ + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias_attr=False) - """ self.up_proj.weight = dist.shard_tensor( self.up_proj.weight, get_mesh(self.ipp), [dist.Replicate(), dist.Shard(1)], ) - """ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias_attr=False) - """ self.down_proj.weight = dist.shard_tensor( self.down_proj.weight, get_mesh(self.ipp), [dist.Replicate(), dist.Shard(0)], ) - """ def forward(self, x): if self.fuse_attention_ffn: @@ -334,65 +298,56 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False, ipp: 3 * self.hidden_size, bias_attr=False, ) - """ self.qkv_proj.weight = dist.shard_tensor( self.qkv_proj.weight, get_mesh(self.ipp), [dist.Replicate(), dist.Shard(1)], ) - """ + else: self.q_proj = nn.Linear( self.hidden_size, self.hidden_size, bias_attr=False, ) - """ self.q_proj.weight = dist.shard_tensor( self.q_proj.weight, get_mesh(self.ipp), [dist.Replicate(), dist.Shard(1)], ) - """ self.k_proj = nn.Linear( self.hidden_size, self.config.num_key_value_heads * self.head_dim, bias_attr=False, ) - """ self.k_proj.weight = dist.shard_tensor( self.k_proj.weight, get_mesh(self.ipp), [dist.Replicate(), dist.Shard(1)], ) - """ self.v_proj = nn.Linear( self.hidden_size, self.config.num_key_value_heads * self.head_dim, bias_attr=False, ) - """ self.v_proj.weight = dist.shard_tensor( self.v_proj.weight, get_mesh(self.ipp), [dist.Replicate(), dist.Shard(1)], ) - """ self.o_proj = nn.Linear( self.hidden_size, self.hidden_size, bias_attr=False, ) - """ self.o_proj.weight = dist.shard_tensor( self.o_proj.weight, get_mesh(self.ipp), [dist.Replicate(), dist.Shard(0)], ) - """ if config.rope: self._init_rope() @@ -438,7 +393,6 @@ def forward( ) -> Tuple[paddle.Tensor, Optional[paddle.Tensor], Optional[Tuple[paddle.Tensor]]]: """Input shape: Batch x Time x Channel""" # [bs, seq_len, num_head * head_dim] -> [seq_len / n, bs, num_head * head_dim] (n is model parallelism) - # print(f"attention input md5sum {hidden_states._md5sum()}") if self.fuse_attention_qkv: target_shape = [0, 0, self.num_heads, 3 * self.head_dim] mix_layer = self.qkv_proj(hidden_states) @@ -590,10 +544,7 @@ def forward( residual = hidden_states hidden_states = self.input_layernorm(hidden_states) - """ - if self.idx == 0: - print(f"input_layernorm_{self.idx} shape: {hidden_states.shape} md5sum: {hidden_states._md5sum()}") - """ + # Self Attention has_gradient = not hidden_states.stop_gradient if ( @@ -635,23 +586,13 @@ def forward( present_key_value = outputs[2 if output_attentions else 1] hidden_states = residual + hidden_states - """ - if self.idx == 0: - print(f"att_{self.idx} shape: {hidden_states.shape} md5sum: {hidden_states._md5sum()}") - """ + # Fully Connected residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) - """ - if self.idx == 0: - print( - f"post_attention_layernorm_{self.idx} shape: {hidden_states.shape} md5sum: {hidden_states._md5sum()}" - ) - """ hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states - # md5 = hidden_states._md5sum() - # print(f"decoder_{self.idx} shape: {hidden_states.shape} md5sum: {md5}") + outputs = (hidden_states,) if output_attentions: @@ -818,13 +759,11 @@ def __init__(self, config: LlamaConfig): self.vocab_size, self.hidden_size, ) - """ self.embed_tokens.weight = dist.shard_tensor( 
self.embed_tokens.weight, get_mesh(), [dist.Replicate(), dist.Shard(1)], ) - """ def get_layer_ipp(layer_index): mesh = fleet.auto.get_mesh() @@ -865,7 +804,7 @@ def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values # NOTE(zhaoyingli): infer spmd does not support [seq_len, seq_len] --> [batch, 1, seq_len, seq_len] in data_parallel combined_attention_mask = dist.shard_tensor( - combined_attention_mask, get_mesh(), [dist.Shard(0), dist.Replicate()] + combined_attention_mask, get_mesh(), [dist.Replicate(), dist.Replicate()] ) expanded_attn_mask = expanded_attn_mask & combined_attention_mask @@ -923,7 +862,6 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - # print(f"inputs_embeds: {inputs_embeds.shape} md5sum: {inputs_embeds._md5sum()}") # embed positions if attention_mask is None: @@ -939,15 +877,16 @@ def forward( if position_ids is None: position_ids = paddle.arange(seq_length, dtype="int64").expand((batch_size, seq_length)) # NOTE(zhaoyingli): infer spmd does not support [seq_len] --> [batch, seq_len] in data_parallel - position_ids = dist.shard_tensor(position_ids, get_mesh(), [dist.Shard(0), dist.Replicate()]) + position_ids = dist.shard_tensor(position_ids, get_mesh(), [dist.Replicate(), dist.Replicate()]) - attention_mask = self._prepare_decoder_attention_mask( - attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype - ) # [bs, 1, seq_len, seq_len] if self.config.use_flash_attention: - is_casual = is_casual_mask(attention_mask) - if is_casual and alibi is None: - attention_mask = None + # attention_mask in flash_attn is always None for pretrain + attention_mask = None + else: + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype + ) # [bs, 1, seq_len, seq_len] + hidden_states = inputs_embeds hidden_states = dist.reshard(hidden_states, get_mesh(), [dist.Shard(0), dist.Replicate()]) @@ -1039,9 +978,6 @@ def forward( ) -loss_cnt = 0 - - class LlamaPretrainingCriterionAuto(paddle.nn.Layer): """ Criterion for Llama. 
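The hunks in this file enable the same `dist.shard_tensor` placement pattern for the embedding, attention, MLP and LM-head weights. A standalone sketch of that pattern, assuming a hypothetical 2x2 ["dp", "mp"] process mesh and arbitrary layer sizes (the APIs themselves are the ones used in this file):

    import paddle
    import paddle.distributed as dist
    from paddle.distributed import fleet

    # Assumes a 4-card job launched with paddle.distributed.launch.
    fleet.auto.create_mesh([("dp", 2), ("mp", 2)])

    layer = paddle.nn.Linear(1024, 4096, bias_attr=False)
    # Replicate the weight over the data-parallel axis and shard its
    # column (output) dimension over the model-parallel axis.
    layer.weight = dist.shard_tensor(
        layer.weight,
        fleet.auto.get_mesh(),
        [dist.Replicate(), dist.Shard(1)],
    )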
@@ -1057,22 +993,19 @@ def __init__(self, config): self.loss_func = paddle.nn.CrossEntropyLoss(reduction="none", ignore_index=self.ignore_index) def forward(self, prediction_scores, masked_lm_labels): - global loss_cnt if self.enable_parallel_cross_entropy: if prediction_scores.shape[-1] == self.config.vocab_size: warnings.warn( f"enable_parallel_cross_entropy, the vocab_size should be splited: {prediction_scores.shape[-1]}, {self.config.vocab_size}" ) self.loss_func = paddle.nn.CrossEntropyLoss(reduction="none", ignore_index=self.ignore_index) - # print(f"prediction_scores_{loss_cnt}: {prediction_scores.shape} md5sum: {prediction_scores._md5sum()}") + masked_lm_loss = self.loss_func(prediction_scores.astype("float32"), masked_lm_labels.unsqueeze(2)) - # print(f"masked_lm_loss_{loss_cnt}: {masked_lm_loss.shape} md5sum: {masked_lm_loss._md5sum()}") # skip ignore_index which loss == 0 # masked_lm_loss = masked_lm_loss[masked_lm_loss > 0].astype("float32") # TODO: solve the issue of conditional block masked_lm_loss = paddle.masked_select(masked_lm_loss, masked_lm_loss > 0).astype("float32") loss = paddle.mean(masked_lm_loss) - loss_cnt = loss_cnt + 1 return loss @@ -1085,7 +1018,6 @@ def __init__(self, config: LlamaConfig): shape=[config.hidden_size, vocab_size], dtype=paddle.get_default_dtype(), ) - """ self.weight = dist.shard_tensor( self.create_parameter( shape=[config.hidden_size, vocab_size], @@ -1094,14 +1026,11 @@ def __init__(self, config: LlamaConfig): get_mesh(-1), [dist.Replicate(), dist.Shard(1)], ) - """ def forward(self, hidden_states, tensor_parallel_output=None): if tensor_parallel_output is None: tensor_parallel_output = self.config.tensor_parallel_output - # print(f"llamaout shape: {hidden_states.shape} md5sum: {hidden_states._md5sum()}") logits = paddle.matmul(hidden_states, self.weight, transpose_y=False) - # print(f"logit {logits.dist_attr}") return logits @@ -1209,9 +1138,7 @@ def forward( return_dict=None, ): input_ids.stop_gradient = True - - if not input_ids.is_dist(): - input_ids = dist.shard_tensor(input_ids, get_mesh(), [dist.Shard(0), dist.Replicate()]) + input_ids = dist.shard_tensor(input_ids, get_mesh(), [dist.Shard(0), dist.Replicate()]) output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -1242,8 +1169,7 @@ def forward( loss = None if labels is not None: labels.stop_gradient = True - if not labels.is_dist(): - labels = dist.shard_tensor(labels, get_mesh(-1), [dist.Shard(0), dist.Replicate()]) + labels = dist.shard_tensor(labels, get_mesh(-1), [dist.Shard(0), dist.Replicate()]) loss = self.criterion(logits, labels) if not return_dict: From 6a381c31731a12c71b1fd495a264ed4262feedfb Mon Sep 17 00:00:00 2001 From: haohongxiang Date: Tue, 23 Jan 2024 09:49:03 +0000 Subject: [PATCH 2/7] support shard_dataloader in dynamic semi-auto --- .../auto_parallel/run_pretrain_3D_auto.py | 11 +- paddlenlp/trainer/auto_trainer.py | 40 ++-- paddlenlp/trainer/trainer.py | 220 ++++++++---------- paddlenlp/trainer/training_args.py | 16 +- .../transformers/llama/modeling_3D_auto.py | 163 ++++++++----- scripts/distribute/ci_case_auto.sh | 1 + 6 files changed, 232 insertions(+), 219 deletions(-) diff --git a/llm/llama/auto_parallel/run_pretrain_3D_auto.py b/llm/llama/auto_parallel/run_pretrain_3D_auto.py index 1fa545555b74..8ae34de2b9d2 100644 --- a/llm/llama/auto_parallel/run_pretrain_3D_auto.py +++ 
b/llm/llama/auto_parallel/run_pretrain_3D_auto.py @@ -366,13 +366,10 @@ class PretrainingTrainer(SemiAutoTrainer): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - def _wrap_dist_loader(self, train_dataloader): - return dist.shard_dataloader( - dataloader=train_dataloader, - meshes=self._get_meshes_for_loader(), - input_keys=["input_ids", "labels"], - shard_dims="dp", - ) + def get_train_dataloader(self): + dist_loader = super().get_train_dataloader() + dist_loader._input_keys = ["input_ids", "labels"] + return dist_loader def print_config(args, key=""): diff --git a/paddlenlp/trainer/auto_trainer.py b/paddlenlp/trainer/auto_trainer.py index 3a16f9c7b3f6..9f1693649b47 100644 --- a/paddlenlp/trainer/auto_trainer.py +++ b/paddlenlp/trainer/auto_trainer.py @@ -40,6 +40,8 @@ def loss_func(loss, outputs): super().__init__(*args, **kwargs) assert self.args.use_auto_parallel + self.global_mesh = fleet.auto.get_mesh() + def _nested_gather(self, tensors): """ Gather value of `tensors` (tensor or list/tuple of nested tensors) and convert them to numpy before @@ -54,31 +56,16 @@ def _wrap_model(self, model, training=True): def _get_meshes_for_loader(self): def _get_mesh(pp_idx=0): - mesh = fleet.auto.get_mesh() - if "pp" in mesh.dim_names: - mesh = mesh.get_mesh_with_dim("pp")[pp_idx] - return mesh + return self.global_mesh.get_mesh_with_dim("pp")[pp_idx] meshes = [] for pp_idx in range(self.args.pipeline_parallel_degree): meshes.append(_get_mesh(pp_idx)) return meshes - def _wrap_dist_loader(self, train_dataloader): - return dist.shard_dataloader( - dataloader=train_dataloader, - meshes=self._get_meshes_for_loader(), - shard_dims="dp", - ) - def _wrap_for_static(self, model, train_dataloader): - # TODO: convert fleet.auto.Strategy to dist.Strategy - # TODO: fix bugs in paddle/distributed/auto_parallel/api.py#L981 about sample_split of engine._prepare_data_spec - - dist_loader = self._wrap_dist_loader(train_dataloader) - - model = dist.to_static(model, dist_loader, self.criterion, self.optimizer, strategy=self.args.strategy) - return model, dist_loader + model = dist.to_static(model, train_dataloader, self.criterion, self.optimizer, strategy=self.args.strategy) + return model def _wrap_for_amp_training(self): pass @@ -125,14 +112,15 @@ def _get_train_sampler(self) -> Optional[paddle.io.Sampler]: drop_last=self.args.dataloader_drop_last, ) - # return DistributedBatchSampler( - # self.train_dataset, - # batch_size=self.args.per_device_train_batch_size, - # shuffle=True, - # num_replicas=self.args.dataset_world_size, - # rank=self.args.dataset_rank, - # drop_last=self.args.dataloader_drop_last, - # ) + def get_train_dataloader(self): + train_dataloader = super().get_train_dataloader() + dist_loader = dist.shard_dataloader( + dataloader=train_dataloader, + meshes=self._get_meshes_for_loader(), + shard_dims="dp", + ) + + return dist_loader def training_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor, Any]]) -> paddle.Tensor: model.train() diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py index 48e4caa9aa02..13982d33ee16 100644 --- a/paddlenlp/trainer/trainer.py +++ b/paddlenlp/trainer/trainer.py @@ -131,7 +131,6 @@ ShardingOption, TrainerMemoryTracker, TrainOutput, - _exec_mode_guard, find_batch_size, get_last_checkpoint, get_scheduler, @@ -720,11 +719,6 @@ def train( self.create_optimizer_and_scheduler(num_training_steps=max_steps) self._load_optimizer_and_scheduler(resume_from_checkpoint) - if self.args.use_auto_parallel and 
self.args.run_static_semi_auto: - model, train_dataloader = self._wrap_for_static(model, train_dataloader) - - self.model = model - logger.info("***** Running training *****") logger.info(f" Num examples = {num_examples:,}") logger.info(f" Num Epochs = {num_train_epochs}") @@ -735,7 +729,24 @@ def train( logger.info(f" Total num train samples = {num_train_samples:,}") # per_device_trainable_numel = sum(p.numel().item() for p in model.parameters() if not p.stop_gradient) # TODO: Temporary fix since Tensor.numel() not supported in distributed mode - self._print_trainable_numel() + per_device_trainable_numel = sum(np.prod(p.shape) for p in model.parameters() if not p.stop_gradient) + logger.info(f" Number of trainable parameters = {per_device_trainable_numel:,} (per device)") + if self.args.use_hybrid_parallel: + # todo fix for pipeline_parallel_degree + parts_num = max(self.args.tensor_parallel_degree, 1) * max(self.args.pipeline_parallel_degree, 1) + if parts_num > 1: + all_reduce_dtype = "int64" + if paddle.get_device().split(":")[0] in ["npu", "xpu"]: + # TODO(duanyanhui): fix when NPU all_reduce supports int64 + all_reduce_dtype = "float32" + trainable_numel_tensor = paddle.to_tensor(per_device_trainable_numel, dtype=all_reduce_dtype) + paddle.distributed.all_reduce(trainable_numel_tensor) + trainable_numel = int(trainable_numel_tensor.item()) // self.args.dataset_world_size + if self.args.sep_parallel_degree > 0: + trainable_numel = trainable_numel // self.args.sep_parallel_degree + # the numel is roughly, because the tensor parallel still hold own bias or layer_norm weight without splited + # so, the trainable numel is a little bigger than real. + logger.info(f" Number of trainable parameters = {trainable_numel:,} (all devices, roughly)") start_time = time.time() self._globalstep_last_start_time = time.time() @@ -818,8 +829,7 @@ def train( self.control = self.callback_handler.on_train_begin(args, self.state, self.control) - with _exec_mode_guard("dynamic"): - tr_loss = paddle.to_tensor(0.0) + tr_loss = paddle.to_tensor(0.0) self._total_loss_scalar = 0.0 self._globalstep_last_logged = self.state.global_step @@ -880,7 +890,7 @@ def train( forbidden_no_sync = False # stage2 and stage3 should not no_sync, because the is no DDP wrapper and no_sync API # hybrid_parallel (tp or pp or sharding stage 1) should not no_sync - if self.args.use_hybrid_parallel or self.args.use_auto_parallel: + if self.args.use_hybrid_parallel: forbidden_no_sync = True availiable_no_sync = dp_enabled and not forbidden_no_sync @@ -907,20 +917,12 @@ def train( else: tr_loss_step = self.training_step(model, inputs) - with _exec_mode_guard("dynamic"): - tr_loss += tr_loss_step - - disable_accumulation = ( - self.args.use_auto_parallel - and self.args.pipeline_parallel_degree > 1 - and self.args.run_static_semi_auto - ) + tr_loss += tr_loss_step if (step_control + 1) % args.gradient_accumulation_steps == 0 or ( # last step in epoch but step is always smaller than gradient_accumulation_steps steps_in_epoch <= args.gradient_accumulation_steps and (step + 1) == steps_in_epoch - or disable_accumulation ): if self.args.pipeline_parallel_degree <= 1 and self._enable_delay_scale_loss(): tr_loss /= self.args.gradient_accumulation_steps @@ -933,20 +935,81 @@ def train( # local_rank != -1 don't means dp in networks. self.timers and self.timers("all-reduce").start() - self.synchronize_gradients(availiable_no_sync, dp_master_grad) + # Case 1: Use recompute and dp / sharding stage1, + # manualy collect gradient for dp. 
+ if args.recompute and availiable_no_sync: + fused_allreduce_gradients(list(model.parameters()), None) + + # Case 2: hack dp with master_grad + if dp_master_grad and not (args.recompute and availiable_no_sync): + fused_allreduce_gradients(list(model.parameters()), None) + + # Pipeline parallel mode, handle gradient reduce here to overlap + pipeline_parallel_config = ( + set(args.pipeline_parallel_config.split(" ")) if args.pipeline_parallel_degree > 1 else set() + ) + enable_dp_comm_overlap = "enable_dp_comm_overlap" in pipeline_parallel_config + enable_release_grads = "enable_release_grads" in pipeline_parallel_config + + # Case 3: Pipeline parallel mode, overlap with dp + if isinstance(self.optimizer, HybridParallelOptimizer) and not self.do_grad_scaling: + parameters_list = _obtain_optimizer_parameters_list(self.optimizer._inner_opt) + + if not enable_dp_comm_overlap: + if self.optimizer._sharding_enable: + assert reshard_util.is_sharding_opt(self.optimizer) + self.optimizer._inner_opt.reduce_gradients(list(parameters_list), self.optimizer._hcg) + + if self.optimizer._dp_enable or getattr(self.optimizer, "_sep_enable", False): + fused_allreduce_gradients(list(parameters_list), self.optimizer._hcg) self.timers and self.timers("all-reduce").stop() self.timers and self.timers("optimizer-step").start() + if self.args.gradient_accumulation_steps > 1 and self._enable_delay_scale_loss(): + for p in model._layers.parameters(): + with paddle.no_grad(): + if hasattr(p, "main_grad") and p.main_grad is not None: + assert p.grad is None + p.main_grad.scale_(1.0 / self.args.gradient_accumulation_steps) + elif p.grad is not None: + p.grad.scale_(1.0 / self.args.gradient_accumulation_steps) + # Optimizer step self.callback_handler.on_optimizer_begin( args, self.state, self.control, scaler=self.scaler if self.do_grad_scaling else None ) - - self.optimizer_step() + optimizer_was_run = True + if self.do_grad_scaling: + scale_before = paddle.assign(self.scaler._scale) + self.scaler.step(self.optimizer) + self.scaler.update() + scale_after = self.scaler._scale + optimizer_was_run = not self.scaler._cache_founf_inf + if not optimizer_was_run: + scale_before_value = scale_before.cpu().numpy() + scale_after_value = scale_after.cpu().numpy() + logger.warning( + f"optimizer not run, scale_before: {scale_before_value[0]}, scale_after: {scale_after_value[0]}" + ) + elif isinstance(self.optimizer, HybridParallelOptimizer): + self.optimizer._step(parameters_list) + else: + self.optimizer.step() self.timers and self.timers("optimizer-step").stop() + if optimizer_was_run: + self.lr_scheduler.step() + + if enable_release_grads and args.pipeline_parallel_degree > 1: + self.optimizer.clear_grad(set_to_zero=False) + for _, buffers in model._chunk_2_comm_buffers.items(): + for buffer in buffers: + buffer._clear_grad_storage() + else: + self.optimizer.clear_grad() + self.callback_handler.on_optimizer_end( args, self.state, self.control, scaler=self.scaler if self.do_grad_scaling else None ) @@ -1009,7 +1072,7 @@ def train( "on multiple nodes, you should activate `--save_on_each_node`." 
) - self._total_loss_scalar += self._get_item_from_loss(tr_loss) + self._total_loss_scalar += tr_loss.item() train_loss = self._total_loss_scalar / self.state.global_step metrics = speed_metrics("train", start_time, num_samples=num_train_samples, num_steps=self.state.max_steps) @@ -1026,98 +1089,6 @@ def train( return TrainOutput(self.state.global_step, train_loss, metrics) - def _wrap_for_static(self, model, train_dataloader): - pass - - def _print_trainable_numel(self): - per_device_trainable_numel = sum(np.prod(p.shape) for p in self.model.parameters() if not p.stop_gradient) - logger.info(f" Number of trainable parameters = {per_device_trainable_numel:,} (per device)") - - if self.args.use_hybrid_parallel: - # todo fix for pipeline_parallel_degree - parts_num = max(self.args.tensor_parallel_degree, 1) * max(self.args.pipeline_parallel_degree, 1) - if parts_num > 1: - all_reduce_dtype = "int64" - if paddle.get_device().split(":")[0] in ["npu", "xpu"]: - # TODO(duanyanhui): fix when NPU all_reduce supports int64 - all_reduce_dtype = "float32" - trainable_numel_tensor = paddle.to_tensor(per_device_trainable_numel, dtype=all_reduce_dtype) - paddle.distributed.all_reduce(trainable_numel_tensor) - trainable_numel = int(trainable_numel_tensor.item()) // self.args.dataset_world_size - if self.args.sep_parallel_degree > 0: - trainable_numel = trainable_numel // self.args.sep_parallel_degree - # the numel is roughly, because the tensor parallel still hold own bias or layer_norm weight without splited - # so, the trainable numel is a little bigger than real. - logger.info(f" Number of trainable parameters = {trainable_numel:,} (all devices, roughly)") - - def synchronize_gradients(self, availiable_no_sync, dp_master_grad): - # Case 1: Use recompute and dp / sharding stage1, - # manualy collect gradient for dp. 
- if self.args.recompute and availiable_no_sync: - fused_allreduce_gradients(list(self.model.parameters()), None) - - # Case 2: hack dp with master_grad - if dp_master_grad and not (self.args.recompute and availiable_no_sync): - fused_allreduce_gradients(list(self.model.parameters()), None) - - # Pipeline parallel mode, handle gradient reduce here to overlap - pipeline_parallel_config = ( - set(self.args.pipeline_parallel_config.split(" ")) if self.args.pipeline_parallel_degree > 1 else set() - ) - enable_dp_comm_overlap = "enable_dp_comm_overlap" in pipeline_parallel_config - - # Case 3: Pipeline parallel mode, overlap with dp - if isinstance(self.optimizer, HybridParallelOptimizer) and not self.do_grad_scaling: - parameters_list = _obtain_optimizer_parameters_list(self.optimizer._inner_opt) - - if not enable_dp_comm_overlap: - if self.optimizer._sharding_enable: - assert reshard_util.is_sharding_opt(self.optimizer) - self.optimizer._inner_opt.reduce_gradients(list(parameters_list), self.optimizer._hcg) - - if self.optimizer._dp_enable or getattr(self.optimizer, "_sep_enable", False): - fused_allreduce_gradients(list(parameters_list), self.optimizer._hcg) - - def optimizer_step(self): - if self.args.pipeline_parallel_degree > 1 and self._enable_delay_scale_loss(): - for p in self.model._layers.parameters(): - with paddle.no_grad(): - if hasattr(p, "main_grad") and p.main_grad is not None: - assert p.grad is None - p.main_grad.scale_(1.0 / self.args.gradient_accumulation_steps) - elif p.grad is not None: - p.grad.scale_(1.0 / self.args.gradient_accumulation_steps) - - optimizer_was_run = True - if self.do_grad_scaling: - scale_before = paddle.assign(self.scaler._scale) - self.scaler.step(self.optimizer) - self.scaler.update() - scale_after = self.scaler._scale - optimizer_was_run = not self.scaler._cache_founf_inf - if not optimizer_was_run: - scale_before_value = scale_before.cpu().numpy() - scale_after_value = scale_after.cpu().numpy() - logger.warning( - f"optimizer not run, scale_before: {scale_before_value[0]}, scale_after: {scale_after_value[0]}" - ) - elif isinstance(self.optimizer, HybridParallelOptimizer): - parameters_list = _obtain_optimizer_parameters_list(self.optimizer._inner_opt) - self.optimizer._step(parameters_list) - else: - self.optimizer.step() - - if optimizer_was_run: - self.lr_scheduler.step() - - if self.args.pipeline_parallel_degree > 1 and "enable_release_grads" in self.args.pipeline_parallel_config: - self.optimizer.clear_grad(set_to_zero=False) - for _, buffers in self.model._chunk_2_comm_buffers.items(): - for buffer in buffers: - buffer._clear_grad_storage() - else: - self.optimizer.clear_grad() - def _load_best_model_from_peft_checkpoint(self): convert_tp = False if isinstance(self.model, LoRAModel): @@ -1199,22 +1170,14 @@ def _print_timer(self): if timer_info or paddle_timer_info: logger.info(f"[Profile global_step: {self.state.global_step}] {timer_info} {paddle_timer_info}") - def _get_item_from_loss(self, loss): - if isinstance(loss, paddle.Tensor): - if loss.is_dist(): - return loss._local_value().item() if loss._is_initialized() else 0.0 - else: - return loss.item() if loss._is_initialized() else 0.0 - else: - return loss - def _maybe_log_save_evaluate(self, tr_loss, model, epoch, ignore_keys_for_eval, **kwargs): if self.control.should_log: logs: Dict[str, float] = {} # all_gather + mean() to get average loss over all processes - tr_loss_scalar = self._get_item_from_loss(self._nested_gather(tr_loss).mean()) + tr_loss_scalar = 
self._nested_gather(tr_loss).mean().item() + # reset tr_loss to zero tr_loss.subtract_(tr_loss) @@ -2275,6 +2238,7 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None, merge_tensor_ self.model.save_pretrained( output_dir, variant=self.args.weight_name_suffix, + save_function=self._save_ckpt_func, merge_tensor_parallel=merge_tensor_parallel, is_main_process=self.args.should_save, max_shard_size="1024GB", @@ -2293,6 +2257,7 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None, merge_tensor_ config_to_save=config_to_save, merge_tensor_parallel=merge_tensor_parallel, variant=weight_name_suffix, + save_function=self._save_ckpt_func, is_main_process=self.args.should_save, max_shard_size="1024GB", ) @@ -2301,6 +2266,7 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None, merge_tensor_ output_dir, merge_tensor_parallel=merge_tensor_parallel, variant=self.args.weight_name_suffix, + save_function=self._save_ckpt_func, is_main_process=self.args.should_save, max_shard_size="1024GB", ) @@ -2327,6 +2293,7 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None, merge_tensor_ config_to_save=config_to_save, merge_tensor_parallel=merge_tensor_parallel, variant=weight_name_suffix, + save_function=self._save_ckpt_func, is_main_process=self.args.should_save, max_shard_size="1024GB", ) @@ -2335,6 +2302,7 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None, merge_tensor_ output_dir, merge_tensor_parallel=merge_tensor_parallel, variant=self.args.weight_name_suffix, + save_function=self._save_ckpt_func, is_main_process=self.args.should_save, max_shard_size="1024GB", ) @@ -2685,7 +2653,7 @@ def evaluation_loop( metrics = {} if all_losses is not None: - metrics[f"{metric_key_prefix}_loss"] = self._get_item_from_loss(all_losses.mean()) + metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item() # Prefix all keys with metric_key_prefix + '_' for key in list(metrics.keys()): diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py index 0192301c2acc..601bc549cb9d 100644 --- a/paddlenlp/trainer/training_args.py +++ b/paddlenlp/trainer/training_args.py @@ -27,6 +27,7 @@ from typing import Any, Dict, List, Optional import paddle +import paddle.distributed as dist from paddle.distributed import fleet from ..utils.log import logger @@ -1278,10 +1279,7 @@ def is_segment_parallel_supported(): self.strategy = strategy order = ["dp", "pp", "mp"] degree = [self.data_parallel_degree, self.pipeline_parallel_degree, self.tensor_parallel_degree] - mesh_dims = list(filter(lambda x: x[1] > 1, list(zip(order, degree)))) - if not mesh_dims: - mesh_dims = [("dp", 1)] - + mesh_dims = list(zip(order, degree)) fleet.auto.create_mesh(mesh_dims) else: world_size = paddle.distributed.get_world_size() @@ -1400,7 +1398,7 @@ def data_parallel_rank(self): return dp_group.rank elif self.use_auto_parallel: mesh = fleet.auto.get_mesh() - return mesh.get_dim_size("dp") + return mesh.get_rank_by_dim_and_process_id("dp", dist.get_rank()) else: return paddle.distributed.get_rank() @@ -1437,6 +1435,9 @@ def tensor_parallel_rank(self): hcg = fleet.get_hybrid_communicate_group() tp_group = hcg.get_model_parallel_group() return max(tp_group.rank, 0) + elif self.use_auto_parallel: + mesh = fleet.auto.get_mesh() + return mesh.get_rank_by_dim_and_process_id("mp", dist.get_rank()) else: return 0 @@ -1446,6 +1447,9 @@ def pipeline_parallel_rank(self): hcg = fleet.get_hybrid_communicate_group() rank = hcg.get_stage_id() return max(rank, 0) + elif 
self.use_auto_parallel: + mesh = fleet.auto.get_mesh() + return mesh.get_rank_by_dim_and_process_id("pp", dist.get_rank()) else: return 0 @@ -1588,6 +1592,8 @@ def should_save_model_state(self): else: if self.should_save_sharding_stage1_model: return True + elif self.use_auto_parallel: + return True elif self.use_hybrid_parallel: # save on dataset rank 0 return self.sharding_parallel_rank == 0 and self.data_parallel_rank == 0 diff --git a/paddlenlp/transformers/llama/modeling_3D_auto.py b/paddlenlp/transformers/llama/modeling_3D_auto.py index afa2d8786767..62da0545ae2d 100644 --- a/paddlenlp/transformers/llama/modeling_3D_auto.py +++ b/paddlenlp/transformers/llama/modeling_3D_auto.py @@ -53,6 +53,7 @@ LlamaNTKScalingRotaryEmbedding, LlamaRotaryEmbedding, _expand_2d_mask, + _make_causal_mask, apply_rotary_pos_emb, build_alibi_tensor, get_triangle_upper_mask, @@ -77,22 +78,6 @@ def get_mesh(pp_idx=0): return mesh -def _make_causal_mask(input_ids_shape, past_key_values_length): - """ - Make causal mask used for self-attention - """ - batch_size, target_length = input_ids_shape # target_length: seq_len - - mask = paddle.tril(paddle.ones((target_length, target_length), dtype="bool")) - - if past_key_values_length > 0: - # [tgt_len, tgt_len + past_len] - mask = paddle.concat([paddle.ones([target_length, past_key_values_length], dtype="bool"), mask], axis=-1) - - # [bs, 1, tgt_len, tgt_len + past_len] - return mask[None, None, :, :].expand([batch_size, 1, target_length, target_length + past_key_values_length]) - - def scaled_dot_product_attention( query_states, config, @@ -106,14 +91,12 @@ def scaled_dot_product_attention( _, kv_seq_len, _, _ = value_states.shape if config.use_flash_attention and flash_attention: - # Flash Attention now ignore attention mask - # Current Flash Attention doesn't support attn maskt # Paddle Flash Attention input [ bz, seqlen, nhead, head_dim] # Torch Flash Attention input [ bz, nhead, seqlen, head_dim] - if alibi is not None: - attention_mask = attention_mask.cast(alibi.dtype) + alibi version = paddle.version.full_version if version != "0.0.0" and version <= "2.5.2": + if alibi is not None: + raise ValueError("Flash Attention doesn't support alibi") attn_output, attn_weights = flash_attention( query_states, key_states, @@ -122,6 +105,9 @@ def scaled_dot_product_attention( return_softmax=output_attentions, ) else: + if alibi is not None: + alibi = alibi.reshape([bsz, num_heads, 1, -1]) + attention_mask = attention_mask.cast(alibi.dtype) + alibi attn_output = F.scaled_dot_product_attention( query_states, key_states, @@ -164,11 +150,10 @@ def scaled_dot_product_attention( raise ValueError( f"Attention mask should be of shape {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.shape}" ) + attn_weights = attn_weights + attention_mask - if not paddle.in_dynamic_mode(): - attn_weights = F.softmax(attn_weights, axis=-1, dtype="float32").astype(query_states.dtype) - else: - attn_weights = F.softmax(attn_weights, axis=-1, dtype="float32").astype(query_states.dtype) + attn_weights = F.softmax(attn_weights, axis=-1, dtype="float32").astype(query_states.dtype) + attn_output = paddle.matmul(attn_weights, value_states) attn_output = attn_output.transpose([0, 2, 1, 3]) attn_output = attn_output.reshape([bsz, q_len, head_dim * num_heads]) @@ -393,6 +378,14 @@ def forward( ) -> Tuple[paddle.Tensor, Optional[paddle.Tensor], Optional[Tuple[paddle.Tensor]]]: """Input shape: Batch x Time x Channel""" # [bs, seq_len, num_head * head_dim] -> [seq_len / n, bs, num_head * head_dim] (n is 
model parallelism) + # enter tp region + if self.config.sequence_parallel: + hidden_states = dist.reshard( + hidden_states, + get_mesh(self.ipp), + [dist.Shard(1), dist.Replicate()], + ) + if self.fuse_attention_qkv: target_shape = [0, 0, self.num_heads, 3 * self.head_dim] mix_layer = self.qkv_proj(hidden_states) @@ -406,6 +399,11 @@ def forward( key_states = self.k_proj(hidden_states).reshape(shape=target_key_value_shape) value_states = self.v_proj(hidden_states).reshape(shape=target_key_value_shape) + if self.config.sequence_parallel: + query_states = paddle.transpose(query_states, [1, 0, 2, 3]) + key_states = paddle.transpose(key_states, [1, 0, 2, 3]) + value_states = paddle.transpose(value_states, [1, 0, 2, 3]) + kv_seq_len = key_states.shape[-3] if past_key_value is not None: @@ -453,7 +451,8 @@ def forward( and has_gradient and self.recompute_granularity == "core_attn" ): - outputs = recompute(scaled_dot_product_attention)( + outputs = recompute( + scaled_dot_product_attention, query_states, self.config, key_states, @@ -482,6 +481,14 @@ def forward( # else their shape are [bs, q_len, num_head * head_dim], n is mp parallelism. attn_output = self.o_proj(attn_output) + # enter sp region + if self.config.sequence_parallel: + attn_output = paddle.transpose(attn_output, [1, 0, 2]) + attn_output = dist.reshard( + attn_output, + get_mesh(self.ipp), + [dist.Shard(1), dist.Shard(0)], + ) if not output_attentions: attn_weights = None @@ -500,7 +507,7 @@ def forward( class LlamaDecoderLayerAuto(nn.Layer): - def __init__(self, config, layerwise_recompute: bool = False, ipp: Optional[int] = None, idx=None): + def __init__(self, config, layerwise_recompute: bool = False, ipp: Optional[int] = None): super().__init__() self.config = config self.hidden_size = config.hidden_size @@ -514,7 +521,6 @@ def __init__(self, config, layerwise_recompute: bool = False, ipp: Optional[int] self.layerwise_recompute = layerwise_recompute self.recompute_granularity = config.recompute_granularity self.ipp = ipp - self.idx = idx def forward( self, @@ -553,7 +559,8 @@ def forward( and has_gradient and self.recompute_granularity == "full_attn" ): - outputs = recompute(self.self_attn)( + outputs = recompute( + self.self_attn, hidden_states, position_ids, past_key_value, @@ -590,7 +597,25 @@ def forward( # Fully Connected residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) + + # enter tp region + if self.config.sequence_parallel: + hidden_states = dist.reshard( + hidden_states, + get_mesh(self.ipp), + [dist.Shard(1), dist.Replicate()], + ) + hidden_states = self.mlp(hidden_states) + + # enter sp region + if self.config.sequence_parallel: + hidden_states = dist.reshard( + hidden_states, + get_mesh(self.ipp), + [dist.Shard(1), dist.Shard(0)], + ) + hidden_states = residual + hidden_states outputs = (hidden_states,) @@ -759,6 +784,7 @@ def __init__(self, config: LlamaConfig): self.vocab_size, self.hidden_size, ) + self.embed_tokens.weight = dist.shard_tensor( self.embed_tokens.weight, get_mesh(), @@ -776,7 +802,7 @@ def get_layer_ipp(layer_index): self.layers = nn.LayerList( [ - LlamaDecoderLayerAuto(config, i not in self.no_recompute_layers, get_layer_ipp(i), i) + LlamaDecoderLayerAuto(config, i not in self.no_recompute_layers, get_layer_ipp(i)) for i in range(config.num_hidden_layers) ] ) @@ -784,6 +810,10 @@ def get_layer_ipp(layer_index): self.gradient_checkpointing = False + self.placements = ( + [dist.Shard(1), dist.Shard(0)] if self.config.sequence_parallel else [dist.Shard(0), 
dist.Replicate()] + ) + def get_input_embeddings(self): return self.embed_tokens @@ -802,9 +832,10 @@ def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values input_shape, past_key_values_length=past_key_values_length ) # NOTE(zhaoyingli): infer spmd does not support [seq_len, seq_len] --> [batch, 1, seq_len, seq_len] in data_parallel - combined_attention_mask = dist.shard_tensor( - combined_attention_mask, get_mesh(), [dist.Replicate(), dist.Replicate()] + combined_attention_mask, + get_mesh(), + [dist.Replicate(), dist.Replicate()], ) expanded_attn_mask = expanded_attn_mask & combined_attention_mask @@ -879,6 +910,10 @@ def forward( # NOTE(zhaoyingli): infer spmd does not support [seq_len] --> [batch, seq_len] in data_parallel position_ids = dist.shard_tensor(position_ids, get_mesh(), [dist.Replicate(), dist.Replicate()]) + if self.config.sequence_parallel: + # [B, S, H] -> [S, B, H] + inputs_embeds = paddle.transpose(inputs_embeds, [1, 0, 2]) + if self.config.use_flash_attention: # attention_mask in flash_attn is always None for pretrain attention_mask = None @@ -888,14 +923,14 @@ def forward( ) # [bs, 1, seq_len, seq_len] hidden_states = inputs_embeds - hidden_states = dist.reshard(hidden_states, get_mesh(), [dist.Shard(0), dist.Replicate()]) + hidden_states = dist.reshard(hidden_states, get_mesh(), self.placements) # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None next_decoder_cache = () if use_cache else None - pre_ipp = 0 + pre_ipp = None for idx, (decoder_layer) in enumerate(self.layers): if output_hidden_states: all_hidden_states += (hidden_states,) @@ -907,17 +942,21 @@ def forward( hidden_states = dist.reshard( hidden_states, get_mesh(decoder_layer.ipp), - [dist.Shard(0), dist.Replicate()], + self.placements, ) position_ids = dist.reshard( position_ids, get_mesh(decoder_layer.ipp), [dist.Shard(0), dist.Replicate()], ) - attention_mask = dist.reshard( - attention_mask, - get_mesh(decoder_layer.ipp), - [dist.Shard(0), dist.Replicate()], + attention_mask = ( + dist.reshard( + attention_mask, + get_mesh(decoder_layer.ipp), + [dist.Shard(0), dist.Replicate()], + ) + if attention_mask is not None + else attention_mask ) if ( @@ -926,7 +965,8 @@ def forward( and has_gradient and self.recompute_granularity == "full" ): - layer_outputs = recompute(decoder_layer)( + layer_outputs = recompute( + decoder_layer, hidden_states, position_ids, attention_mask, @@ -1000,10 +1040,26 @@ def forward(self, prediction_scores, masked_lm_labels): ) self.loss_func = paddle.nn.CrossEntropyLoss(reduction="none", ignore_index=self.ignore_index) - masked_lm_loss = self.loss_func(prediction_scores.astype("float32"), masked_lm_labels.unsqueeze(2)) - # skip ignore_index which loss == 0 - # masked_lm_loss = masked_lm_loss[masked_lm_loss > 0].astype("float32") - # TODO: solve the issue of conditional block + # Force Replicated to match dy & st + prediction_scores1 = dist.reshard( + prediction_scores, + get_mesh(-1), + [dist.Replicate(), dist.Replicate()], + ) + masked_lm_labels1 = dist.reshard(masked_lm_labels, get_mesh(-1), [dist.Replicate(), dist.Replicate()]) + + # Force entropy same kernel + if isinstance(prediction_scores1, paddle.Tensor): + masked_lm_loss = self.loss_func( + prediction_scores1.astype("float32")._use_gpudnn(False), + masked_lm_labels1.unsqueeze(2), + ) + else: + masked_lm_loss = self.loss_func( + prediction_scores1.astype("float32"), + masked_lm_labels1.unsqueeze(2), + ) + masked_lm_loss = 
paddle.masked_select(masked_lm_loss, masked_lm_loss > 0).astype("float32") loss = paddle.mean(masked_lm_loss) return loss @@ -1019,10 +1075,7 @@ def __init__(self, config: LlamaConfig): dtype=paddle.get_default_dtype(), ) self.weight = dist.shard_tensor( - self.create_parameter( - shape=[config.hidden_size, vocab_size], - dtype=paddle.get_default_dtype(), - ), + self.weight, get_mesh(-1), [dist.Replicate(), dist.Shard(1)], ) @@ -1041,12 +1094,6 @@ def __init__(self, config): super().__init__(config) self.config = config - # dygraph auto_parallel do not support lazy now! - # with paddle.LazyGuard(): - # self.llama = LlamaModelAuto(config) - # self.lm_head = LlamaLMHeadAuto(config) - # self.criterion = LlamaPretrainingCriterionAuto(config) - self.llama = LlamaModelAuto(config) self.lm_head = LlamaLMHeadAuto(config) self.criterion = LlamaPretrainingCriterionAuto(config) @@ -1138,7 +1185,6 @@ def forward( return_dict=None, ): input_ids.stop_gradient = True - input_ids = dist.shard_tensor(input_ids, get_mesh(), [dist.Shard(0), dist.Replicate()]) output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -1157,6 +1203,14 @@ def forward( ) hidden_states = outputs[0] # [bs, seq_len, dim] + # enter tp region + if self.config.sequence_parallel: + hidden_states = dist.reshard( + hidden_states, + get_mesh(-1), + [dist.Shard(1), dist.Replicate()], + ) + hidden_states = paddle.transpose(hidden_states, [1, 0, 2]) # if labels is None,means we need full output, instead of tensor_parallel_output # tensor_parallel_output is togather with ParallelCrossEntropy @@ -1169,7 +1223,6 @@ def forward( loss = None if labels is not None: labels.stop_gradient = True - labels = dist.shard_tensor(labels, get_mesh(-1), [dist.Shard(0), dist.Replicate()]) loss = self.criterion(logits, labels) if not return_dict: diff --git a/scripts/distribute/ci_case_auto.sh b/scripts/distribute/ci_case_auto.sh index 52e15505c259..f46b20e35a1e 100644 --- a/scripts/distribute/ci_case_auto.sh +++ b/scripts/distribute/ci_case_auto.sh @@ -1221,6 +1221,7 @@ function llama_dygraph_auto_bs4_fp32_DP2-MP2-PP2() { --device "gpu" \ --data_impl "mmap" \ --parallel_mode "auto" \ + --run_static_semi_auto 0 \ --max_grad_norm 1.0 \ >>${log_path}/$FUNCNAME 2>&1 loss=`cat $case_log_dir/workerlog.2 | grep 'global_step 10' | awk -F '; loss' '{print $2}' | awk -F 'lr' '{print $1}'` From b3e64c362e3adc4d72bbae303092d8f891b6fdf3 Mon Sep 17 00:00:00 2001 From: haohongxiang Date: Thu, 25 Jan 2024 07:38:37 +0000 Subject: [PATCH 3/7] rewrite traning loop --- .../auto_parallel/run_pretrain_3D_auto.py | 13 +- paddlenlp/trainer/auto_trainer.py | 543 ++++++++++++++++-- paddlenlp/trainer/trainer.py | 90 ++- paddlenlp/trainer/training_args.py | 4 +- .../transformers/llama/modeling_3D_auto.py | 52 +- 5 files changed, 585 insertions(+), 117 deletions(-) diff --git a/llm/llama/auto_parallel/run_pretrain_3D_auto.py b/llm/llama/auto_parallel/run_pretrain_3D_auto.py index 8ae34de2b9d2..730dba2d66aa 100644 --- a/llm/llama/auto_parallel/run_pretrain_3D_auto.py +++ b/llm/llama/auto_parallel/run_pretrain_3D_auto.py @@ -37,11 +37,12 @@ LinearAnnealingWithWarmupDecay, LlamaConfig, LlamaForCausalLM3DAuto, + LlamaPretrainingCriterion3DAuto, ) from paddlenlp.utils.log import logger MODEL_CLASSES = { - "llama": (LlamaConfig, LlamaForCausalLM3DAuto), + "llama": (LlamaConfig, LlamaForCausalLM3DAuto, 
LlamaPretrainingCriterion3DAuto), } @@ -487,7 +488,7 @@ def main(): "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." ) - config_class, model_class = MODEL_CLASSES[model_args.model_type] + config_class, model_class, criterion_class = MODEL_CLASSES[model_args.model_type] tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name_or_path) @@ -551,8 +552,7 @@ def main(): with paddle.LazyGuard(): model = model_class.from_config(config, dtype=dtype) - - criterion = None + criterion = criterion_class(config) for param in model.parameters(): assert not param._is_initialized() @@ -598,11 +598,6 @@ def fn(layer): need_data=training_args.should_load_dataset, ) - # total_train_batch_size_per_acc_step = ( - # training_args.per_device_train_batch_size * training_args.data_parallel_degree - # ) - # total_train_batch_size = total_train_batch_size_per_acc_step * training_args.gradient_accumulation_steps - trainer = PretrainingTrainer( model=model, criterion=criterion, diff --git a/paddlenlp/trainer/auto_trainer.py b/paddlenlp/trainer/auto_trainer.py index 9f1693649b47..17c56303d1d0 100644 --- a/paddlenlp/trainer/auto_trainer.py +++ b/paddlenlp/trainer/auto_trainer.py @@ -12,18 +12,47 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional, Union +import os +import random +import time +from typing import Any, Dict, List, Optional, Union import numpy as np import paddle import paddle.distributed as dist import paddle.nn as nn from paddle.distributed import fleet +from paddle.io import DistributedBatchSampler +from tqdm.auto import tqdm from paddlenlp.trainer import Trainer +from ..transformers.segment_parallel_utils import split_inputs_sequence_dim +from ..utils.batch_sampler import DistributedBatchSampler as NlpDistributedBatchSampler from ..utils.log import logger -from .trainer_utils import _exec_mode_guard, has_length +from .argparser import strtobool +from .trainer_callback import DefaultFlowCallback, ProgressCallback, TrainerState +from .trainer_utils import ( # set_hyrbid_parallel_seed, + PREFIX_CHECKPOINT_DIR, + TrainOutput, + _exec_mode_guard, + has_length, + speed_metrics, +) +from .utils.helper import distributed_file, distributed_isfile # nested_truncate, + +DEFAULT_CALLBACKS = [DefaultFlowCallback] +DEFAULT_PROGRESS_CALLBACK = ProgressCallback + +# Name of the files used for checkpointing +TRAINING_ARGS_NAME = "training_args.bin" +TRAINER_STATE_NAME = "trainer_state.json" + +MODEL_NAME = "model" +OPTIMIZER_NAME = "optimizer" +DIST_CKPT_NAME = "dist_ckpt" +SCHEDULER_NAME = "scheduler.pdparams" +SCALER_NAME = "scaler.pdparams" class SemiAutoTrainer(Trainer): @@ -42,6 +71,9 @@ def loss_func(loss, outputs): self.global_mesh = fleet.auto.get_mesh() + self.mesh_in_dp = self.global_mesh.get_mesh_with_dim("dp")[self.args.data_parallel_rank] + self.comm_group_in_dp = dist.new_group(list(self.mesh_in_dp.process_ids)) + def _nested_gather(self, tensors): """ Gather value of `tensors` (tensor or list/tuple of nested tensors) and convert them to numpy before @@ -49,11 +81,6 @@ def _nested_gather(self, tensors): """ return tensors - def _wrap_model(self, model, training=True): - self.optimizer = dist.shard_optimizer(self.optimizer) if not self.args.run_static_semi_auto else self.optimizer - - return model - def _get_meshes_for_loader(self): def _get_mesh(pp_idx=0): return self.global_mesh.get_mesh_with_dim("pp")[pp_idx] @@ -63,39 +90,333 @@ def _get_mesh(pp_idx=0): 
meshes.append(_get_mesh(pp_idx)) return meshes - def _wrap_for_static(self, model, train_dataloader): - model = dist.to_static(model, train_dataloader, self.criterion, self.optimizer, strategy=self.args.strategy) - return model + def _wrap_for_auto(self, model, train_dataloader): + if self.args.run_static_semi_auto: + return dist.to_static(model, train_dataloader, self.criterion, self.optimizer, strategy=self.args.strategy) + else: + self.optimizer = dist.shard_optimizer(self.optimizer) + return model def _wrap_for_amp_training(self): pass - def _print_trainable_numel(self): - if not self.args.run_static_semi_auto: - super()._print_trainable_numel() + def _get_item_from_loss(self, loss): + if isinstance(loss, paddle.Tensor): + if loss.is_dist(): + return loss._local_value().item() if loss._is_initialized() else 0.0 + else: + return loss.item() if loss._is_initialized() else 0.0 else: - per_device_trainable_numel = sum( - np.prod(p.shape) for p in self.model._engine._model.parameters() if not p.stop_gradient + return loss + + def train( + self, + resume_from_checkpoint: Optional[Union[str, bool]] = None, + ignore_keys_for_eval: Optional[List[str]] = None, + ): + """ + Main training entry point. + + Args: + resume_from_checkpoint (`str` or `bool`, *optional*): + If a `str`, local path to a saved checkpoint as saved by a previous instance of [`Trainer`]. If a + `bool` and equals `True`, load the last checkpoint in *args.output_dir* as saved by a previous instance + of [`Trainer`]. If present, training will resume from the model/optimizer/scheduler states loaded here. + ignore_keys_for_eval (`List[str]`, *optional*) + A list of keys in the output of your model (if it is a dictionary) that should be ignored when + gathering predictions for evaluation during the training. 
+ """ + args = self.args + self.is_in_train = True + + self._sync_resume_states(resume_from_checkpoint) + + # memory metrics - must set up as early as possible + self._memory_tracker.start() + + if not self.args.should_load_sharding_stage1_model: + self._load_from_checkpoint(resume_from_checkpoint) + + train_dataloader = self.get_train_dataloader() + total_train_batch_size = args.train_batch_size * args.gradient_accumulation_steps * args.dataset_world_size + ( + len_dataloader, + max_steps, + num_train_epochs, + num_update_steps_per_epoch, + num_examples, + num_train_samples, + ) = self._get_train_steps_and_samples(args, train_dataloader, total_train_batch_size) + + delay_optimizer_creation = False + + if not delay_optimizer_creation: + self.create_optimizer_and_scheduler(num_training_steps=max_steps) + + self.state = TrainerState() + + model = self._wrap_for_auto(self.model_wrapped, train_dataloader) + + # for the rest of this function `model` is the outside model, whether it was wrapped or not + if model is not self.model: + self.model = model + + if delay_optimizer_creation: + self.create_optimizer_and_scheduler(num_training_steps=max_steps) + + self._load_optimizer_and_scheduler(resume_from_checkpoint) + + self._print_trainable_numel() + + start_time = time.time() + self._globalstep_last_start_time = time.time() + self.state.epoch = 0 + epochs_trained = 0 + steps_trained_in_current_epoch = 0 + steps_trained_progress_bar = None + + # Check if continuing training from a checkpoint + if ( + resume_from_checkpoint is not None + and distributed_isfile(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)) + and not self.args.ignore_load_lr_and_optim + ): + self.state = TrainerState.load_from_json( + distributed_file(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)) ) - logger.info(f" Number of trainable parameters = {per_device_trainable_numel:,} (per device)") + if self.args.world_size > 1: + global_step_list = [] + paddle.distributed.all_gather( + global_step_list, paddle.to_tensor([self.state.global_step], dtype="int64") + ) + assert ( + paddle.sum(paddle.stack(global_step_list) - global_step_list[0]) == 0 + ), f"Error, get different globel step, please check! step list: {[x.item() for x in global_step_list]}" + + epochs_trained = self.state.global_step // num_update_steps_per_epoch + if not args.ignore_data_skip: + steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch) + steps_trained_in_current_epoch *= args.gradient_accumulation_steps + else: + steps_trained_in_current_epoch = 0 + + logger.info(" Continuing training from checkpoint, will skip to saved global_step") + logger.info(f" Continuing training from epoch {epochs_trained}") + logger.info(f" Continuing training from global step {self.state.global_step}") + if not args.ignore_data_skip: + logger.info( + f" Will skip the first {epochs_trained} epochs then the first {steps_trained_in_current_epoch} " + "batches in the first epoch. If this takes a lot of time, you can add the `--ignore_data_skip` " + "flag to your launch command, but you will resume the training on data already seen by your model." 
+ ) + if self.is_local_process_zero() and not args.disable_tqdm: + steps_trained_progress_bar = tqdm(total=steps_trained_in_current_epoch) + steps_trained_progress_bar.set_description("Skipping the first batches") + if not args.ignore_data_skip: + if isinstance(train_dataloader, paddle.io.DataLoader) and isinstance( + train_dataloader.batch_sampler, NlpDistributedBatchSampler + ): + consumed_samples = ( + self.state.global_step + * args.train_batch_size + * args.gradient_accumulation_steps + * args.dataset_world_size + ) + train_dataloader.batch_sampler.set_epoch(consumed_samples=consumed_samples) + logger.info(f"Set DistributedBatchSampler consumed_samples to {consumed_samples}") + + epoch_iterator = train_dataloader + # steps_in_epoch = len(epoch_iterator) + steps_in_epoch = ( + len(epoch_iterator) if len_dataloader is not None else args.max_steps * args.gradient_accumulation_steps + ) + if len_dataloader is not None: + if self.args.gradient_accumulation_steps > len(epoch_iterator): + logger.warning( + f"changing accumulation step from `{self.args.gradient_accumulation_steps}` to `{len(epoch_iterator)}` to avoid, cross epoch accumulate" + ) + self.args.gradient_accumulation_steps = len(epoch_iterator) + + self.callback_handler.model = self.model + self.callback_handler.optimizer = self.optimizer + self.callback_handler.lr_scheduler = self.lr_scheduler + self.callback_handler.train_dataloader = train_dataloader + + self.state.max_steps = int(max_steps) + self.state.num_train_epochs = num_train_epochs + self.state.is_local_process_zero = self.is_local_process_zero() + self.state.is_world_process_zero = self.is_world_process_zero() - parts_num = max(self.args.tensor_parallel_degree, 1) * max(self.args.pipeline_parallel_degree, 1) - if parts_num > 1: - all_reduce_dtype = "int64" - if paddle.get_device().split(":")[0] in ["npu", "xpu"]: - # TODO(duanyanhui): fix when NPU all_reduce supports int64 - all_reduce_dtype = "float32" + self.control = self.callback_handler.on_train_begin(args, self.state, self.control) + + with _exec_mode_guard("dynamic"): + tr_loss = paddle.to_tensor(0.0) + + self._total_loss_scalar = 0.0 + self._globalstep_last_logged = self.state.global_step + + if self.args.device == "npu" and self.args.flatten_param_grads: + from .plugins.npu_plugin import npu_accelerate_plugin + + npu_accelerate_plugin(self.optimizer) + + self.timers and self.timers("read-data").start() + + for epoch in range(epochs_trained, num_train_epochs): + if isinstance(train_dataloader, paddle.io.DataLoader) and isinstance( + train_dataloader.batch_sampler, DistributedBatchSampler + ): + train_dataloader.batch_sampler.set_epoch(epoch) + + step_control = 0 # used in loop control, reset to 0 after every step + self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control) + + for step, inputs in enumerate(epoch_iterator): + if self.args.use_hybrid_parallel and self.args.sep_parallel_degree > 1: + inputs = split_inputs_sequence_dim(inputs) + self.timers and self.timers("read-data").stop() + os.environ["TRAINER_GLOBAL_STEP"] = str(self.state.global_step) + self.callback_handler.on_load_data_end(args, self.state, self.control, inputs=inputs) + + # Skip past any already trained steps if resuming training + # for paddlenlp.utils.batch_sampler.DistributedBatchSampler + # We use consumed_samples to reset the status + if isinstance(train_dataloader, paddle.io.DataLoader) and isinstance( + train_dataloader.batch_sampler, NlpDistributedBatchSampler + ): + if step == 0: + if 
steps_trained_progress_bar is not None: + steps_trained_progress_bar.update(steps_trained_in_current_epoch) + steps_trained_progress_bar.close() + steps_trained_progress_bar = None + self._load_rng_state(resume_from_checkpoint) + step += steps_trained_in_current_epoch + elif steps_trained_in_current_epoch > 0: + steps_trained_in_current_epoch -= 1 + if steps_trained_progress_bar is not None: + steps_trained_progress_bar.update(1) + if steps_trained_in_current_epoch == 0: + self._load_rng_state(resume_from_checkpoint) + continue + elif steps_trained_progress_bar is not None: + steps_trained_progress_bar.close() + steps_trained_progress_bar = None + + if step_control % args.gradient_accumulation_steps == 0: + self.control = self.callback_handler.on_step_begin(args, self.state, self.control) + self.timers and self.timers("forward-backward").start() + + tr_loss_step = self.training_step(model, inputs) with _exec_mode_guard("dynamic"): - trainable_numel_tensor = paddle.to_tensor(per_device_trainable_numel, dtype=all_reduce_dtype) - paddle.distributed.all_reduce(trainable_numel_tensor) - trainable_numel = int(trainable_numel_tensor.item()) // self.args.dataset_world_size + tr_loss += tr_loss_step + + disable_accumulation = self.args.pipeline_parallel_degree > 1 and self.args.run_static_semi_auto + + if (step_control + 1) % args.gradient_accumulation_steps == 0 or ( + # last step in epoch but step is always smaller than gradient_accumulation_steps + steps_in_epoch <= args.gradient_accumulation_steps + and (step + 1) == steps_in_epoch + or disable_accumulation + ): + if self.args.pipeline_parallel_degree <= 1 and self._enable_delay_scale_loss(): + tr_loss /= self.args.gradient_accumulation_steps + + self.timers and self.timers("forward-backward").stop() + + self.timers and self.timers("optimizer-step").start() + + # Optimizer step + self.callback_handler.on_optimizer_begin( + args, self.state, self.control, scaler=self.scaler if self.do_grad_scaling else None + ) + + self.optimizer_step() + + self.timers and self.timers("optimizer-step").stop() + + self.callback_handler.on_optimizer_end( + args, self.state, self.control, scaler=self.scaler if self.do_grad_scaling else None + ) + + self.state.global_step += 1 + self.state.epoch = epoch + (step + 1) / steps_in_epoch + self.control = self.callback_handler.on_step_end(args, self.state, self.control) + self._maybe_log_save_evaluate(tr_loss, model, epoch, ignore_keys_for_eval, inputs=inputs) + self._print_timer() + step_control = 0 + else: + self.control = self.callback_handler.on_substep_end(args, self.state, self.control) + step_control += 1 + + if self.control.should_epoch_stop or self.control.should_training_stop: + break - if self.args.sep_parallel_degree > 0: - trainable_numel = trainable_numel // self.args.sep_parallel_degree - # the numel is roughly, because the tensor parallel still hold own bias or layer_norm weight without splited - # so, the trainable numel is a little bigger than real. - logger.info(f" Number of trainable parameters = {trainable_numel:,} (all devices, roughly)") + self.timers and self.timers("read-data").start() + + if step < 0: + logger.warning( + f"There seems to be not a single sample in your epoch_iterator, stopping training at step" + f" {self.state.global_step}! This is expected if you're using an IterableDataset and set" + f" num_steps ({self.state.max_steps}) higher than the number of available samples." 
+ ) + self.control.should_training_stop = True + + self.control = self.callback_handler.on_epoch_end(args, self.state, self.control) + self._maybe_log_save_evaluate(tr_loss, model, epoch, ignore_keys_for_eval, inputs=inputs) + + if self.control.should_training_stop: + break + + if args.past_index and hasattr(self, "_past"): + # Clean the state at the end of training + delattr(self, "_past") + + logger.info("\nTraining completed. \n") + + self._total_loss_scalar += self._get_item_from_loss(tr_loss) + train_loss = self._total_loss_scalar / self.state.global_step + + metrics = speed_metrics("train", start_time, num_samples=num_train_samples, num_steps=self.state.max_steps) + + metrics["train_loss"] = train_loss + + self.is_in_train = False + + self._memory_tracker.stop_and_update_metrics(metrics) + + self.log(metrics) + + self.control = self.callback_handler.on_train_end(args, self.state, self.control) + + return TrainOutput(self.state.global_step, train_loss, metrics) + + def _print_trainable_numel(self): + if not self.args.run_static_semi_auto: + per_device_trainable_numel = sum(np.prod(p.shape) for p in self.model.parameters() if not p.stop_gradient) + else: + per_device_trainable_numel = sum( + np.prod(p.shape) for p in self.model._engine._model.parameters() if not p.stop_gradient + ) + logger.info(f" Number of trainable parameters = {per_device_trainable_numel:,} (per device)") + + parts_num = max(self.args.tensor_parallel_degree, 1) * max(self.args.pipeline_parallel_degree, 1) + if parts_num > 1: + all_reduce_dtype = "int64" + if paddle.get_device().split(":")[0] in ["npu", "xpu"]: + # TODO(duanyanhui): fix when NPU all_reduce supports int64 + all_reduce_dtype = "float32" + + with _exec_mode_guard("dynamic"): + trainable_numel_tensor = paddle.to_tensor(per_device_trainable_numel, dtype=all_reduce_dtype) + paddle.distributed.all_reduce(trainable_numel_tensor) + trainable_numel = int(trainable_numel_tensor.item()) // self.args.dataset_world_size + + if self.args.sep_parallel_degree > 0: + trainable_numel = trainable_numel // self.args.sep_parallel_degree + # the numel is roughly, because the tensor parallel still hold own bias or layer_norm weight without splited + # so, the trainable numel is a little bigger than real. 
+ logger.info(f" Number of trainable parameters = {trainable_numel:,} (all devices, roughly)") def _get_train_sampler(self) -> Optional[paddle.io.Sampler]: if self.train_dataset is None or not has_length(self.train_dataset): @@ -122,31 +443,38 @@ def get_train_dataloader(self): return dist_loader + def dynamic_traning(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor, Any]]) -> paddle.Tensor: + with self.autocast_smart_context_manager(): + loss = self.compute_loss(model, inputs) + + if self.args.gradient_accumulation_steps > 1: + loss = loss / self.args.gradient_accumulation_steps + + if self.do_grad_scaling: + self.scaler.scale(loss).backward() + else: + loss.backward() + + return loss + + def static_traning(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor, Any]]) -> paddle.Tensor: + input_ids, labels = tuple(inputs.values()) + loss = model(input_ids, labels) + + if loss is not None and self.args.gradient_accumulation_steps > 1: + loss = loss / self.args.gradient_accumulation_steps + + return loss + def training_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor, Any]]) -> paddle.Tensor: model.train() inputs = self._prepare_inputs(inputs) if not self.args.run_static_semi_auto: - with self.autocast_smart_context_manager(): - loss = self.compute_loss(model, inputs) - - if self.args.gradient_accumulation_steps > 1: - loss = loss / self.args.gradient_accumulation_steps - - if self.do_grad_scaling: - self.scaler.scale(loss).backward() - else: - loss.backward() + loss = self.dynamic_traning(model, inputs) else: - input_ids, labels = tuple(inputs.values()) - loss = model(input_ids, labels) - - if self.args.pipeline_parallel_degree > 1: - self._pp_data_buffer = {} - - if loss is not None and self.args.gradient_accumulation_steps > 1: - loss = loss / self.args.gradient_accumulation_steps + loss = self.static_traning(model, inputs) if isinstance(loss, paddle.Tensor): return loss.detach() if loss._is_initialized() else float(0.0) @@ -157,15 +485,126 @@ def training_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor, else: return float(loss) - def synchronize_gradients(self, *args, **kwargs): - pass - def optimizer_step(self): if not self.args.run_static_semi_auto: - super().optimizer_step() + optimizer_was_run = True + if self.do_grad_scaling: + scale_before = paddle.assign(self.scaler._scale) + self.scaler.step(self.optimizer) + self.scaler.update() + scale_after = self.scaler._scale + optimizer_was_run = not self.scaler._cache_founf_inf + if not optimizer_was_run: + scale_before_value = scale_before.cpu().numpy() + scale_after_value = scale_after.cpu().numpy() + logger.warning( + f"optimizer not run, scale_before: {scale_before_value[0]}, scale_after: {scale_after_value[0]}" + ) + else: + self.optimizer.step() + + if optimizer_was_run: + self.lr_scheduler.step() + + self.optimizer.clear_grad() else: pass def _maybe_log_save_evaluate(self, tr_loss, model, epoch, ignore_keys_for_eval, **kwargs): with _exec_mode_guard("dynamic"): super()._maybe_log_save_evaluate(tr_loss, model, epoch, ignore_keys_for_eval, **kwargs) + + def _save_checkpoint(self, model, metrics=None): + + # Save model checkpoint + checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}" + run_dir = self.args.output_dir + output_dir = f"{run_dir}/{checkpoint_folder}" + + if self.args.should_save or self.args.should_save_model_state: + os.makedirs(output_dir, exist_ok=True) + + if self.args.should_save: + logger.info(f"Saving checkpoinit files into {output_dir}") + + 
if self.args.should_save_model_state: + + optim_state_dict = self.optimizer.state_dict() + optim_state_dict.pop("LR_Scheduler", None) + + state_dict = { + MODEL_NAME: self.model.state_dict(), + OPTIMIZER_NAME: optim_state_dict, + } + + self._save_ckpt_func(state_dict, os.path.join(output_dir, DIST_CKPT_NAME), self.comm_group_in_dp) + logger.info(f"Model weights and optimizer states saved in {output_dir}/{DIST_CKPT_NAME}") + + # FIXME: maybe only save one copy + paddle.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME)) + + if self.do_grad_scaling: + paddle.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME)) + + # Determine the new best metric / best model checkpoint + if metrics is not None and self.args.metric_for_best_model is not None: + metric_to_check = self.args.metric_for_best_model + if not metric_to_check.startswith("eval_"): + metric_to_check = f"eval_{metric_to_check}" + metric_value = metrics[metric_to_check] + + operator = np.greater if self.args.greater_is_better else np.less + if ( + self.state.best_metric is None + or self.state.best_model_checkpoint is None + or operator(metric_value, self.state.best_metric) + ): + self.state.best_metric = metric_value + self.state.best_model_checkpoint = output_dir + + # Save the Trainer state + if self.args.should_save: + self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME)) + + # Save RNG state in non-distributed training + rng_states = { + "python": random.getstate(), + "numpy": np.random.get_state(), + "cuda": [k.current_seed() for k in paddle.get_rng_state()], + "cpu": paddle.framework.core.default_cpu_generator().get_state().current_seed(), + } + # if self.args.use_hybrid_parallel: + # rng_states[ + # "hybrid_parallel_rng_state_tracker" + # ] = fleet.meta_parallel.get_rng_state_tracker().get_states_tracker() + + if self.args.world_size > 1: + rng_states_list = [] + paddle.distributed.all_gather_object(rng_states_list, rng_states) + if self.args.should_save: + os.makedirs(output_dir, exist_ok=True) + paddle.save(rng_states_list, os.path.join(output_dir, f"rng_state_{self.args.world_size}.pth")) + else: + os.makedirs(output_dir, exist_ok=True) + paddle.save(rng_states, os.path.join(output_dir, "rng_state.pth")) + + if strtobool(os.getenv("FLAG_LLM_PDC", "False")): + # save checkpoint_done file to ensure checkpoint is complete + if self.args.should_save_model_state and self.args.should_save: + # For ckpt integrity + paddle.save(self.state.global_step, os.path.join(output_dir, ".checkpoint_done")) + + def _save(self, output_dir: Optional[str] = None, state_dict=None, merge_tensor_parallel=False): + output_dir = output_dir if output_dir is not None else self.args.output_dir + os.makedirs(output_dir, exist_ok=True) + logger.info(f"Saving model checkpoint to {output_dir}") + + if self.args.should_save: + if self.tokenizer is not None: + self.tokenizer.save_pretrained(output_dir) + # Good practice: save your training arguments together with the trained model + paddle.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME)) + + if self.args.should_save_model_state: + self._save_ckpt_func(self.model.state_dict(), os.path.join(output_dir, MODEL_NAME), self.comm_group_in_dp) + logger.info(f"Model weights saved in {output_dir}/{MODEL_NAME}") diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py index 13982d33ee16..ab7c42d0af81 100644 --- a/paddlenlp/trainer/trainer.py +++ b/paddlenlp/trainer/trainer.py @@ -602,51 +602,27 @@ def 
_wrap_model_and_load_sharded_checkpoint(self, resume_from_checkpoint): self._load_from_checkpoint(resume_from_checkpoint) return model - def train( - self, - resume_from_checkpoint: Optional[Union[str, bool]] = None, - ignore_keys_for_eval: Optional[List[str]] = None, - ): - """ - Main training entry point. - - Args: - resume_from_checkpoint (`str` or `bool`, *optional*): - If a `str`, local path to a saved checkpoint as saved by a previous instance of [`Trainer`]. If a - `bool` and equals `True`, load the last checkpoint in *args.output_dir* as saved by a previous instance - of [`Trainer`]. If present, training will resume from the model/optimizer/scheduler states loaded here. - ignore_keys_for_eval (`List[str]`, *optional*) - A list of keys in the output of your model (if it is a dictionary) that should be ignored when - gathering predictions for evaluation during the training. - """ - args = self.args - self.is_in_train = True - + def _sync_resume_states(self, resume_from_checkpoint): logger.info(f"Starting training from resume_from_checkpoint : {resume_from_checkpoint}") # The resume_from_checkpoint could be None in some machine node. # Here we reset None to temp directory. - if args.world_size > 1: + if self.args.world_size > 1: is_resume_from_checkpoint = paddle.to_tensor([resume_from_checkpoint is not None]) paddle.distributed.all_reduce(is_resume_from_checkpoint) is_resume_from_checkpoint = is_resume_from_checkpoint.item() + if is_resume_from_checkpoint > 0 and is_resume_from_checkpoint < paddle.distributed.get_world_size(): if resume_from_checkpoint is None: resume_from_checkpoint = os.path.join(self.args.output_dir, "local_tempdir") if os.path.exists(resume_from_checkpoint) and self.args.local_rank == 0: shutil.rmtree(resume_from_checkpoint) os.makedirs(resume_from_checkpoint, exist_ok=True) - logger.info(f"Reset resume_from_checkpoint to temp directory : {resume_from_checkpoint}") - # memory metrics - must set up as early as possible - self._memory_tracker.start() - - if not self.args.should_load_sharding_stage1_model: - self._load_from_checkpoint(resume_from_checkpoint) + logger.info(f"Reset resume_from_checkpoint to temp directory : {resume_from_checkpoint}") - train_dataloader = self.get_train_dataloader() + def _get_train_steps_and_samples(self, args, train_dataloader, total_train_batch_size): - total_train_batch_size = args.train_batch_size * args.gradient_accumulation_steps * args.dataset_world_size len_dataloader = None if has_length(train_dataloader): len_dataloader = len(train_dataloader) @@ -683,6 +659,56 @@ def train( f"args.max_steps must be set to a positive value if dataloader does not have a length, was {args.max_steps}" ) + logger.info("***** Running training *****") + logger.info(f" Num examples = {num_examples:,}") + logger.info(f" Num Epochs = {num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {max_steps:,}") + logger.info(f" Total num train samples = {num_train_samples:,}") + + return len_dataloader, max_steps, num_train_epochs, num_update_steps_per_epoch, num_examples, num_train_samples + + def train( + self, + resume_from_checkpoint: Optional[Union[str, bool]] = None, + ignore_keys_for_eval: Optional[List[str]] = None, + ): + """ + Main training entry point. 
+ + Args: + resume_from_checkpoint (`str` or `bool`, *optional*): + If a `str`, local path to a saved checkpoint as saved by a previous instance of [`Trainer`]. If a + `bool` and equals `True`, load the last checkpoint in *args.output_dir* as saved by a previous instance + of [`Trainer`]. If present, training will resume from the model/optimizer/scheduler states loaded here. + ignore_keys_for_eval (`List[str]`, *optional*) + A list of keys in the output of your model (if it is a dictionary) that should be ignored when + gathering predictions for evaluation during the training. + """ + args = self.args + self.is_in_train = True + + self._sync_resume_states(resume_from_checkpoint) + + # memory metrics - must set up as early as possible + self._memory_tracker.start() + + if not self.args.should_load_sharding_stage1_model: + self._load_from_checkpoint(resume_from_checkpoint) + + train_dataloader = self.get_train_dataloader() + total_train_batch_size = args.train_batch_size * args.gradient_accumulation_steps * args.dataset_world_size + ( + len_dataloader, + max_steps, + num_train_epochs, + num_update_steps_per_epoch, + num_examples, + num_train_samples, + ) = self._get_train_steps_and_samples(args, train_dataloader, total_train_batch_size) + # delay_optimizer_creation = ( # self.sharding is not None # and ShardingOption.SHARD_OP in self.args.sharding @@ -1170,13 +1196,17 @@ def _print_timer(self): if timer_info or paddle_timer_info: logger.info(f"[Profile global_step: {self.state.global_step}] {timer_info} {paddle_timer_info}") + def _get_item_from_loss(self, loss): + assert isinstance(loss, paddle.Tensor) and loss._is_initialized() + return loss.item() + def _maybe_log_save_evaluate(self, tr_loss, model, epoch, ignore_keys_for_eval, **kwargs): if self.control.should_log: logs: Dict[str, float] = {} # all_gather + mean() to get average loss over all processes - tr_loss_scalar = self._nested_gather(tr_loss).mean().item() + tr_loss_scalar = self._get_item_from_loss(self._nested_gather(tr_loss).mean()) # reset tr_loss to zero tr_loss.subtract_(tr_loss) diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py index 601bc549cb9d..299fe29ee202 100644 --- a/paddlenlp/trainer/training_args.py +++ b/paddlenlp/trainer/training_args.py @@ -1574,6 +1574,8 @@ def should_save(self): if self.save_on_each_node: return self.local_process_index == 0 else: + if self.use_auto_parallel: + return self.data_parallel_rank == 0 return self.process_index == 0 @property @@ -1593,7 +1595,7 @@ def should_save_model_state(self): if self.should_save_sharding_stage1_model: return True elif self.use_auto_parallel: - return True + return self.data_parallel_rank == 0 elif self.use_hybrid_parallel: # save on dataset rank 0 return self.sharding_parallel_rank == 0 and self.data_parallel_rank == 0 diff --git a/paddlenlp/transformers/llama/modeling_3D_auto.py b/paddlenlp/transformers/llama/modeling_3D_auto.py index 62da0545ae2d..1f72d3a68f6d 100644 --- a/paddlenlp/transformers/llama/modeling_3D_auto.py +++ b/paddlenlp/transformers/llama/modeling_3D_auto.py @@ -68,6 +68,7 @@ __all__ = [ "LlamaForCausalLM3DAuto", + "LlamaPretrainingCriterion3DAuto", ] @@ -1018,7 +1019,7 @@ def forward( ) -class LlamaPretrainingCriterionAuto(paddle.nn.Layer): +class LlamaPretrainingCriterion3DAuto(paddle.nn.Layer): """ Criterion for Llama. It calculates the final loss. 
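The hunks in this file rename the pretraining criterion to LlamaPretrainingCriterion3DAuto and, further below, drop the in-model loss so that forward() returns raw logits. A rough usage sketch of the resulting split (not part of the patch: only the two imported classes come from this file; the helper function and its arguments are hypothetical):

# Minimal sketch, assuming the module layout in this diff: the model produces
# logits and the standalone criterion turns them into the masked-LM loss.
from paddlenlp.transformers.llama.modeling_3D_auto import (
    LlamaForCausalLM3DAuto,
    LlamaPretrainingCriterion3DAuto,
)

def compute_pretraining_loss(model, criterion, input_ids, labels):
    logits = model(input_ids)         # forward() now returns logits only
    labels.stop_gradient = True       # labels never need gradients
    return criterion(logits, labels)  # loss is computed outside the model

In SemiAutoTrainer._wrap_for_auto this criterion is the loss argument handed to dist.to_static together with the sharded dataloader and optimizer.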
@@ -1026,7 +1027,7 @@ class LlamaPretrainingCriterionAuto(paddle.nn.Layer): def __init__(self, config): - super(LlamaPretrainingCriterionAuto, self).__init__() + super(LlamaPretrainingCriterion3DAuto, self).__init__() self.ignore_index = getattr(config, "ignore_index", -100) self.config = config self.enable_parallel_cross_entropy = config.tensor_parallel_degree > 1 and config.tensor_parallel_output @@ -1041,23 +1042,23 @@ def forward(self, prediction_scores, masked_lm_labels): self.loss_func = paddle.nn.CrossEntropyLoss(reduction="none", ignore_index=self.ignore_index) # Force Replicated to match dy & st - prediction_scores1 = dist.reshard( + prediction_scores = dist.reshard( prediction_scores, get_mesh(-1), [dist.Replicate(), dist.Replicate()], ) - masked_lm_labels1 = dist.reshard(masked_lm_labels, get_mesh(-1), [dist.Replicate(), dist.Replicate()]) + masked_lm_labels = dist.reshard(masked_lm_labels, get_mesh(-1), [dist.Replicate(), dist.Replicate()]) # Force entropy same kernel - if isinstance(prediction_scores1, paddle.Tensor): + if isinstance(prediction_scores, paddle.Tensor): masked_lm_loss = self.loss_func( - prediction_scores1.astype("float32")._use_gpudnn(False), - masked_lm_labels1.unsqueeze(2), + prediction_scores.astype("float32")._use_gpudnn(False), + masked_lm_labels.unsqueeze(2), ) else: masked_lm_loss = self.loss_func( - prediction_scores1.astype("float32"), - masked_lm_labels1.unsqueeze(2), + prediction_scores.astype("float32"), + masked_lm_labels.unsqueeze(2), ) masked_lm_loss = paddle.masked_select(masked_lm_loss, masked_lm_loss > 0).astype("float32") @@ -1096,7 +1097,6 @@ def __init__(self, config): self.llama = LlamaModelAuto(config) self.lm_head = LlamaLMHeadAuto(config) - self.criterion = LlamaPretrainingCriterionAuto(config) def get_input_embeddings(self): return self.llama.embed_tokens @@ -1220,19 +1220,21 @@ def forward( logits = self.lm_head(hidden_states, tensor_parallel_output=tensor_parallel_output) - loss = None - if labels is not None: - labels.stop_gradient = True - loss = self.criterion(logits, labels) + return logits - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return CausalLMOutputWithCrossAttentions( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) + # loss = None + # if labels is not None: + # labels.stop_gradient = True + # loss = self.criterion(logits, labels) + + # if not return_dict: + # output = (logits,) + outputs[1:] + # return (loss,) + output if loss is not None else output + + # return CausalLMOutputWithCrossAttentions( + # loss=loss, + # logits=logits, + # past_key_values=outputs.past_key_values, + # hidden_states=outputs.hidden_states, + # attentions=outputs.attentions, + # ) From fe6b45d8666e466da6c4ebb6f5c64e9baebd7e4b Mon Sep 17 00:00:00 2001 From: haohongxiang Date: Mon, 29 Jan 2024 05:50:58 +0000 Subject: [PATCH 4/7] refactor traning loop --- .../auto_parallel/run_pretrain_3D_auto.py | 4 +-- paddlenlp/trainer/auto_trainer.py | 35 ++++++++++++------- paddlenlp/trainer/training_args.py | 4 +++ .../transformers/llama/modeling_3D_auto.py | 2 +- 4 files changed, 29 insertions(+), 16 deletions(-) diff --git a/llm/llama/auto_parallel/run_pretrain_3D_auto.py b/llm/llama/auto_parallel/run_pretrain_3D_auto.py index 730dba2d66aa..e53e7cefe18e 100644 --- a/llm/llama/auto_parallel/run_pretrain_3D_auto.py +++ b/llm/llama/auto_parallel/run_pretrain_3D_auto.py @@ -367,8 +367,8 @@ 
class PretrainingTrainer(SemiAutoTrainer): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - def get_train_dataloader(self): - dist_loader = super().get_train_dataloader() + def _wrap_for_dist_loader(self, train_dataloader): + dist_loader = super()._wrap_for_dist_loader(train_dataloader) dist_loader._input_keys = ["input_ids", "labels"] return dist_loader diff --git a/paddlenlp/trainer/auto_trainer.py b/paddlenlp/trainer/auto_trainer.py index bc5837df1a1c..a288d12dcab2 100644 --- a/paddlenlp/trainer/auto_trainer.py +++ b/paddlenlp/trainer/auto_trainer.py @@ -82,6 +82,9 @@ def _nested_gather(self, tensors): """ return tensors + def _wrap_model(self, model, training=True): + return model + def _get_meshes_for_loader(self): def _get_mesh(pp_idx=0): return self.global_mesh.get_mesh_with_dim("pp")[pp_idx] @@ -91,12 +94,25 @@ def _get_mesh(pp_idx=0): meshes.append(_get_mesh(pp_idx)) return meshes + def _wrap_for_dist_loader(self, train_dataloader): + dist_loader = dist.shard_dataloader( + dataloader=train_dataloader, + meshes=self._get_meshes_for_loader(), + shard_dims="dp", + ) + return dist_loader + def _wrap_for_auto(self, model, train_dataloader): + dist_loader = self._wrap_for_dist_loader(train_dataloader) + if self.args.run_static_semi_auto: - return dist.to_static(model, train_dataloader, self.criterion, self.optimizer, strategy=self.args.strategy) + return ( + dist.to_static(model, dist_loader, self.criterion, self.optimizer, strategy=self.args.strategy), + dist_loader, + ) else: self.optimizer = dist.shard_optimizer(self.optimizer) - return model + return model, dist_loader def _wrap_for_amp_training(self): pass @@ -213,6 +229,9 @@ def _inner_training_loop( npu_accelerate_plugin(self.optimizer) + model, dist_loader = self._wrap_for_auto(model, train_dataloader) + train_dataloader = dist_loader() + self.timers and self.timers("read-data").start() for epoch in range(epochs_trained, num_train_epochs): @@ -224,7 +243,7 @@ def _inner_training_loop( step_control = 0 # used in loop control, reset to 0 after every step self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control) - for step, inputs in enumerate(epoch_iterator): + for step, inputs in enumerate(train_dataloader): if self.args.use_hybrid_parallel and self.args.sep_parallel_degree > 1: inputs = split_inputs_sequence_dim(inputs) self.timers and self.timers("read-data").stop() @@ -359,16 +378,6 @@ def _get_train_sampler(self) -> Optional[paddle.io.Sampler]: drop_last=self.args.dataloader_drop_last, ) - def get_train_dataloader(self): - train_dataloader = super().get_train_dataloader() - dist_loader = dist.shard_dataloader( - dataloader=train_dataloader, - meshes=self._get_meshes_for_loader(), - shard_dims="dp", - ) - - return dist_loader - def dynamic_traning(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor, Any]]) -> paddle.Tensor: with self.autocast_smart_context_manager(): loss = self.compute_loss(model, inputs) diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py index 811ce09c73eb..8ecc7ab77b5b 100644 --- a/paddlenlp/trainer/training_args.py +++ b/paddlenlp/trainer/training_args.py @@ -1609,12 +1609,16 @@ def _no_sync_in_gradient_accumulation(self): @property def should_save_sharding_stage1_model(self): + if self.use_auto_parallel: + return False return ( ShardingOption.SHARD_OP in self.sharding and self.sharding_parallel_degree > 1 and self.save_sharded_model ) @property def should_load_sharding_stage1_model(self): + if 
self.use_auto_parallel: + return False return ( ShardingOption.SHARD_OP in self.sharding and self.sharding_parallel_degree > 1 and self.load_sharded_model ) diff --git a/paddlenlp/transformers/llama/modeling_3D_auto.py b/paddlenlp/transformers/llama/modeling_3D_auto.py index 1f72d3a68f6d..7b1a6345bcdb 100644 --- a/paddlenlp/transformers/llama/modeling_3D_auto.py +++ b/paddlenlp/transformers/llama/modeling_3D_auto.py @@ -1045,7 +1045,7 @@ def forward(self, prediction_scores, masked_lm_labels): prediction_scores = dist.reshard( prediction_scores, get_mesh(-1), - [dist.Replicate(), dist.Replicate()], + [dist.Replicate(), dist.Replicate(), dist.Replicate()], ) masked_lm_labels = dist.reshard(masked_lm_labels, get_mesh(-1), [dist.Replicate(), dist.Replicate()]) From b1265e3dac890ede138f3932d2c0d3a1b427888d Mon Sep 17 00:00:00 2001 From: haohongxiang Date: Tue, 30 Jan 2024 08:59:58 +0000 Subject: [PATCH 5/7] refine args of auto trainer --- llm/llama/auto_parallel/run_auto.sh | 2 +- llm/llama/auto_parallel/run_auto_sp.sh | 2 +- .../auto_parallel/run_pretrain_3D_auto.py | 4 +- .../auto_parallel/run_pretrain_3D_auto.sh | 2 +- .../auto_parallel/run_pretrain_3D_hand.py | 1 - llm/llama/auto_parallel/run_pretrain_auto.py | 6 +- paddlenlp/trainer/auto_trainer.py | 18 +++--- paddlenlp/trainer/trainer.py | 4 +- paddlenlp/trainer/training_args.py | 59 ++++++++----------- scripts/distribute/ci_case_auto.sh | 16 ++--- 10 files changed, 51 insertions(+), 63 deletions(-) diff --git a/llm/llama/auto_parallel/run_auto.sh b/llm/llama/auto_parallel/run_auto.sh index f8a114870dad..27fb9b61fe2e 100644 --- a/llm/llama/auto_parallel/run_auto.sh +++ b/llm/llama/auto_parallel/run_auto.sh @@ -68,6 +68,6 @@ python -u -m paddle.distributed.launch \ --do_eval \ --device "gpu" \ --data_impl "mmap" \ - --parallel_mode "auto" + --enable_auto_parallel 1 # --resume_from_checkpoint "output/llama_auto_serial/checkpoint-2" \ diff --git a/llm/llama/auto_parallel/run_auto_sp.sh b/llm/llama/auto_parallel/run_auto_sp.sh index 4e13d1fbfb21..e45a47bf2a64 100644 --- a/llm/llama/auto_parallel/run_auto_sp.sh +++ b/llm/llama/auto_parallel/run_auto_sp.sh @@ -68,7 +68,7 @@ python -u -m paddle.distributed.launch \ --do_eval \ --device "gpu" \ --data_impl "mmap" \ - --parallel_mode "auto" \ + --enable_auto_parallel 1 \ --sequence_parallel true \ # --resume_from_checkpoint "output/llama_auto_serial/checkpoint-2" \ diff --git a/llm/llama/auto_parallel/run_pretrain_3D_auto.py b/llm/llama/auto_parallel/run_pretrain_3D_auto.py index ad5511fd7691..0866dce8fada 100644 --- a/llm/llama/auto_parallel/run_pretrain_3D_auto.py +++ b/llm/llama/auto_parallel/run_pretrain_3D_auto.py @@ -112,7 +112,7 @@ class PreTrainingArguments(TrainingArguments): def __post_init__(self): super().__post_init__() - assert self.use_auto_parallel + assert self.enable_auto_parallel # NOTE(gongenlei): new add autotuner_benchmark if self.autotuner_benchmark: @@ -402,7 +402,7 @@ def init_seed(seed: int = 1234, args=None): np.random.seed(seed) paddle.seed(seed) else: - assert not args.use_hybrid_parallel and args.use_auto_parallel + assert not args.use_hybrid_parallel and args.enable_auto_parallel if dist.get_world_size() > 1: topo = Topology( dist.get_rank(), diff --git a/llm/llama/auto_parallel/run_pretrain_3D_auto.sh b/llm/llama/auto_parallel/run_pretrain_3D_auto.sh index 46a158bdeb79..a1b0cd35ef54 100644 --- a/llm/llama/auto_parallel/run_pretrain_3D_auto.sh +++ b/llm/llama/auto_parallel/run_pretrain_3D_auto.sh @@ -74,5 +74,5 @@ python3.8 -u -m paddle.distributed.launch \ --do_eval 
\ --device "gpu" \ --data_impl "mmap" \ - --parallel_mode "auto" \ + --enable_auto_parallel 1 \ --max_grad_norm 1.0 \ diff --git a/llm/llama/auto_parallel/run_pretrain_3D_hand.py b/llm/llama/auto_parallel/run_pretrain_3D_hand.py index 7a353d3f9bbf..ae58bdc146e4 100644 --- a/llm/llama/auto_parallel/run_pretrain_3D_hand.py +++ b/llm/llama/auto_parallel/run_pretrain_3D_hand.py @@ -74,7 +74,6 @@ class PreTrainingArguments(TrainingArguments): "help": "Enable fused linear grad add strategy, which will reduce elementwise add for grad accumulation in the backward of nn.Linear ." }, ) - parallel_mode: str = field(default="hybrid", metadata={"help": ""}) @dataclass diff --git a/llm/llama/auto_parallel/run_pretrain_auto.py b/llm/llama/auto_parallel/run_pretrain_auto.py index 59584e45984f..f3e171da47c9 100644 --- a/llm/llama/auto_parallel/run_pretrain_auto.py +++ b/llm/llama/auto_parallel/run_pretrain_auto.py @@ -99,7 +99,6 @@ class PreTrainingArguments(TrainingArguments): "help": "The steps use to control the learing rate. If the step > decay_steps, will use the min_learning_rate." }, ) - parallel_mode: str = field(default="hybrid", metadata={"help": ""}) fused_linear_param_grad_add: bool = field( default=False, metadata={ @@ -114,7 +113,6 @@ class PreTrainingArguments(TrainingArguments): default=-1, metadata={"help": "The step to end job_schedule_profiler."}, ) - parallel_mode: str = field(default="hybrid", metadata={"help": ""}) pipeline_schedule_mode: str = field( default="1F1B", metadata={"help": "The pipeline schedule mode, support FThenB, 1F1B, VPP and Eager-1F1B."} ) @@ -128,7 +126,7 @@ class PreTrainingArguments(TrainingArguments): def __post_init__(self): super().__post_init__() - assert self.use_auto_parallel + assert self.enable_auto_parallel if self.fused_linear_param_grad_add: fused_passes = self.strategy.fused_passes fused_passes.enable = True @@ -414,7 +412,7 @@ def init_seed(seed: int = 1234, args=None): np.random.seed(seed) paddle.seed(seed) else: - assert not args.use_hybrid_parallel and args.use_auto_parallel + assert not args.use_hybrid_parallel and args.enable_auto_parallel if dist.get_world_size() > 1: topo = Topology( dist.get_rank(), diff --git a/paddlenlp/trainer/auto_trainer.py b/paddlenlp/trainer/auto_trainer.py index 1ff3a9b26198..90226f3f6d9e 100644 --- a/paddlenlp/trainer/auto_trainer.py +++ b/paddlenlp/trainer/auto_trainer.py @@ -48,7 +48,7 @@ class AutoTrainer(Trainer): def __init__(self, *args, **kwargs): - if kwargs.get("args", None) is not None and kwargs["args"].run_static_auto: + if kwargs.get("args", None) is not None and kwargs["args"].to_static: if kwargs.get("criterion", None) is None: def loss_func(loss, outputs): @@ -57,7 +57,7 @@ def loss_func(loss, outputs): kwargs.update({"criterion": loss_func}) super().__init__(*args, **kwargs) - assert self.args.use_auto_parallel + assert self.args.enable_auto_parallel self.global_mesh = fleet.auto.get_mesh() @@ -107,7 +107,7 @@ def _wrap_for_dist_loader(self, train_dataloader): def _wrap_for_auto(self, model, train_dataloader): dist_loader = self._wrap_for_dist_loader(train_dataloader) - if self.args.run_static_auto: + if self.args.to_static: return ( dist.to_static(model, dist_loader, self.criterion, self.optimizer, strategy=self.args.strategy), dist_loader, @@ -136,8 +136,8 @@ def _split_batches_for_accumulation(self, inputs): if self.args.gradient_accumulation_steps == 1: return [inputs] - # if self.args.run_static_auto: - if self.args.run_static_auto and self.args.pipeline_parallel_degree > 1: + # if 
self.args.to_static: + if self.args.to_static and self.args.pipeline_parallel_degree > 1: return [inputs] local_batches = [{} for i in range(self.args.gradient_accumulation_steps)] @@ -281,8 +281,8 @@ def _inner_training_loop( with _exec_mode_guard("dynamic"): tr_loss += tr_loss_step - disable_accumulation = self.args.pipeline_parallel_degree > 1 and self.args.run_static_auto - # disable_accumulation = self.args.run_static_auto + disable_accumulation = self.args.pipeline_parallel_degree > 1 and self.args.to_static + # disable_accumulation = self.args.to_static if (step_control + 1) % args.gradient_accumulation_steps == 0 or ( # last step in epoch but step is always smaller than gradient_accumulation_steps @@ -399,7 +399,7 @@ def training_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor, inputs = self._prepare_inputs(inputs) - if not self.args.run_static_auto: + if not self.args.to_static: loss = self.dynamic_traning(model, inputs) else: loss = self.static_traning(model, inputs) @@ -414,7 +414,7 @@ def training_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor, return float(loss) def optimizer_step(self): - if not self.args.run_static_auto: + if not self.args.to_static: optimizer_was_run = True if self.do_grad_scaling: scale_before = paddle.assign(self.scaler._scale) diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py index 0d8fdd02d809..8adc97008845 100644 --- a/paddlenlp/trainer/trainer.py +++ b/paddlenlp/trainer/trainer.py @@ -345,8 +345,8 @@ def __init__( ) self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK) - self._save_ckpt_func = dist.save_state_dict if self.args.use_auto_parallel else paddle.save - self._load_ckpt_func = dist.load_state_dict if self.args.use_auto_parallel else paddle.load + self._save_ckpt_func = dist.save_state_dict if self.args.enable_auto_parallel else paddle.save + self._load_ckpt_func = dist.load_state_dict if self.args.enable_auto_parallel else paddle.load if args.max_steps > 0: logger.info("max_steps is given, it will override any value given in num_train_epochs") diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py index b4e08858b058..1944b5226300 100644 --- a/paddlenlp/trainer/training_args.py +++ b/paddlenlp/trainer/training_args.py @@ -724,7 +724,7 @@ class TrainingArguments: ) to_static: Optional[bool] = field( default=False, - metadata={"help": "Enable training under @to_static."}, + metadata={"help": ("Whether to train model under static mode by jit.to_static or distributed.to_static.")}, ) unified_checkpoint_config: Optional[str] = field( default="", @@ -748,18 +748,10 @@ class TrainingArguments: default=False, metadata={"help": "reshard pp even if pp degree in the model and pp degree in script match"}, ) - parallel_mode: str = field( - default="hybrid", - metadata={ - "help": ( - "Which parallel mode to use for distributed training.\n" - "Following options are supports:\n" - "- hybrid: under the hybrid parallel mode with combined distributed strategies.\n" - "- auto: under the auto parallel mode with AutoTrainer \n" - ) - }, + enable_auto_parallel: Optional[bool] = field( + default=False, + metadata={"help": "whether to run distributed training in auto parallel mode"}, ) - run_static_auto: bool = field(default=True, metadata={"help": "whether to run static graph in auto parallel mode"}) def __post_init__(self): env_local_rank = int(os.environ.get("PADDLE_RANK_IN_NODE", -1)) @@ -847,7 +839,6 @@ def __post_init__(self): 
self.optim = OptimizerNames(self.optim) self.use_hybrid_parallel = False - self.use_auto_parallel = False if isinstance(self.sharding, bool): self.sharding = "stage1" if self.sharding else "" @@ -870,13 +861,9 @@ def __post_init__(self): if len(self.sharding) == 0 and self.sharding_parallel_degree > 0: warnings.warn("`--sharding_parallel_degree` is useful only when `--sharding` is specified.") - try: - self.use_auto_parallel = self.parallel_mode == "auto" - except: - pass + world_size = paddle.distributed.get_world_size() - if paddle.distributed.get_world_size() > 1: - world_size = paddle.distributed.get_world_size() + if world_size > 1: tensor_parallel_degree = max(self.tensor_parallel_degree, 1) sep_parallel_degree = max(self.sep_parallel_degree, 1) pipeline_parallel_degree = max(self.pipeline_parallel_degree, 1) @@ -919,9 +906,16 @@ def __post_init__(self): self.pipeline_parallel_degree = -1 self.sep_parallel_degree = -1 - if self.use_hybrid_parallel and self.use_auto_parallel: + if self.use_hybrid_parallel and self.enable_auto_parallel: self.use_hybrid_parallel = False + if self.to_static: + assert world_size == 1 or self.enable_auto_parallel, ( + "It's not supported for training in static mode except the following cases : " + "1. world_size == 1, which means single-card training while no parallelism is used; " + "2. enable_auto_parallel is set to True, which means the training will be executed in static mode of auto parallel." + ) + if self.distributed_dataloader and not (self.tensor_parallel_degree > 1 or self.pipeline_parallel_degree > 1): warnings.warn("We set `distributed_dataloader` to False if tp_degree <= 1 and pp_degree <= 1") self.distributed_dataloader = False @@ -933,7 +927,6 @@ def __post_init__(self): # use_hybrid_parallel if self.use_hybrid_parallel: - world_size = paddle.distributed.get_world_size() if ShardingOption.OFFLOAD in self.sharding: warnings.warn("`offload` is not supported NOW!") @@ -1138,8 +1131,7 @@ def is_segment_parallel_supported(): fleet.init(is_collective=True, strategy=strategy) logger.info(strategy) - elif self.use_auto_parallel: - world_size = paddle.distributed.get_world_size() + elif self.enable_auto_parallel: self.tensor_parallel_degree = max(self.tensor_parallel_degree, 1) self.pipeline_parallel_degree = max(self.pipeline_parallel_degree, 1) @@ -1290,7 +1282,6 @@ def is_segment_parallel_supported(): mesh_dims = list(zip(order, degree)) fleet.auto.create_mesh(mesh_dims) else: - world_size = paddle.distributed.get_world_size() if world_size > 1: if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(): if self.unified_checkpoint: @@ -1404,7 +1395,7 @@ def data_parallel_rank(self): if dp_group.rank == -1: return 0 return dp_group.rank - elif self.use_auto_parallel: + elif self.enable_auto_parallel: mesh = fleet.auto.get_mesh() return mesh.get_rank_by_dim_and_process_id("dp", dist.get_rank()) else: @@ -1414,7 +1405,7 @@ def data_parallel_rank(self): def dataset_rank(self): if self.use_hybrid_parallel: return max(self.sharding_parallel_degree, 1) * self.data_parallel_rank + self.sharding_parallel_rank - elif self.use_auto_parallel: + elif self.enable_auto_parallel: return self.data_parallel_rank else: return paddle.distributed.get_rank() @@ -1423,7 +1414,7 @@ def dataset_rank(self): def dataset_world_size(self): if self.use_hybrid_parallel: return max(self.sharding_parallel_degree, 1) * max(self.data_parallel_degree, 1) - elif self.use_auto_parallel: + elif self.enable_auto_parallel: return max(self.data_parallel_degree, 1) 
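            # NOTE: with enable_auto_parallel the dataset is sharded along the dp
            # axis only, so the dataset world size is the data-parallel degree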
else: return paddle.distributed.get_world_size() @@ -1443,7 +1434,7 @@ def tensor_parallel_rank(self): hcg = fleet.get_hybrid_communicate_group() tp_group = hcg.get_model_parallel_group() return max(tp_group.rank, 0) - elif self.use_auto_parallel: + elif self.enable_auto_parallel: mesh = fleet.auto.get_mesh() return mesh.get_rank_by_dim_and_process_id("mp", dist.get_rank()) else: @@ -1455,7 +1446,7 @@ def pipeline_parallel_rank(self): hcg = fleet.get_hybrid_communicate_group() rank = hcg.get_stage_id() return max(rank, 0) - elif self.use_auto_parallel: + elif self.enable_auto_parallel: mesh = fleet.auto.get_mesh() return mesh.get_rank_by_dim_and_process_id("pp", dist.get_rank()) else: @@ -1561,7 +1552,7 @@ def should_log(self): """ Whether or not the current process should produce log. """ - if self.use_auto_parallel: + if self.enable_auto_parallel: return True elif self.log_on_each_node: return self.local_process_index == 0 @@ -1582,7 +1573,7 @@ def should_save(self): if self.save_on_each_node: return self.local_process_index == 0 else: - if self.use_auto_parallel: + if self.enable_auto_parallel: return True return self.process_index == 0 @@ -1602,7 +1593,7 @@ def should_save_model_state(self): else: if self.should_save_sharding_stage1_model: return True - elif self.use_auto_parallel: + elif self.enable_auto_parallel: return True elif self.use_hybrid_parallel: # save on dataset rank 0 @@ -1619,7 +1610,7 @@ def _no_sync_in_gradient_accumulation(self): @property def should_save_sharding_stage1_model(self): - if self.use_auto_parallel: + if self.enable_auto_parallel: return False return ( ShardingOption.SHARD_OP in self.sharding and self.sharding_parallel_degree > 1 and self.save_sharded_model @@ -1627,7 +1618,7 @@ def should_save_sharding_stage1_model(self): @property def should_load_sharding_stage1_model(self): - if self.use_auto_parallel: + if self.enable_auto_parallel: return False return ( ShardingOption.SHARD_OP in self.sharding and self.sharding_parallel_degree > 1 and self.load_sharded_model diff --git a/scripts/distribute/ci_case_auto.sh b/scripts/distribute/ci_case_auto.sh index f33f7e39f3a8..f68f45604f44 100644 --- a/scripts/distribute/ci_case_auto.sh +++ b/scripts/distribute/ci_case_auto.sh @@ -890,7 +890,7 @@ function llama_static_auto_recompute_bs8_fp32_DP1-MP1-PP1() { --do_eval \ --device "gpu" \ --data_impl "mmap" \ - --parallel_mode "auto" \ + --enable_auto_parallel 1 \ >>${log_path}/$FUNCNAME 2>&1 loss=`cat $case_log_dir/workerlog.0 | grep 'global_step: 10' | awk -F 'loss: ' '{print $2}' | awk -F ',' '{print $1}'` ips=-1 @@ -956,7 +956,7 @@ function llama_static_auto_recompute_bs16_fp32_DP2-MP1-PP1() { --do_eval \ --device "gpu" \ --data_impl "mmap" \ - --parallel_mode "auto" \ + --enable_auto_parallel 1 \ >>${log_path}/$FUNCNAME 2>&1 loss=`cat $case_log_dir/workerlog.0 | grep 'global_step: 10' | awk -F 'loss: ' '{print $2}' | awk -F ',' '{print $1}'` ips=-1 @@ -1022,7 +1022,7 @@ function llama_static_auto_recompute_bs16_fp32_DP2-MP2-PP1() { --do_eval \ --device "gpu" \ --data_impl "mmap" \ - --parallel_mode "auto" \ + --enable_auto_parallel 1 \ >>${log_path}/$FUNCNAME 2>&1 loss=`cat $case_log_dir/workerlog.0 | grep 'global_step: 10' | awk -F 'loss: ' '{print $2}' | awk -F ',' '{print $1}'` ips=-1 @@ -1088,7 +1088,7 @@ function llama_static_auto_recompute_bs16_fp32_DP2-MP2-PP2() { --do_eval \ --device "gpu" \ --data_impl "mmap" \ - --parallel_mode "auto" \ + --enable_auto_parallel 1 \ >>${log_path}/$FUNCNAME 2>&1 loss=`cat $case_log_dir/workerlog.2 | grep 'global_step: 10' 
| awk -F 'loss: ' '{print $2}' | awk -F ',' '{print $1}'` ips=-1 @@ -1156,7 +1156,7 @@ function llama_static_auto_recompute_bs16_fp32_DP2-MP2-PP2-VPP2-Sharding2_stage2 --do_eval \ --device "gpu" \ --data_impl "mmap" \ - --parallel_mode "auto" \ + --enable_auto_parallel 1 \ >>${log_path}/$FUNCNAME 2>&1 loss=`cat $case_log_dir/workerlog.3 | grep 'global_step: 10' | awk -F 'loss: ' '{print $2}' | awk -F ',' '{print $1}'` ips=-1 @@ -1225,7 +1225,7 @@ function llama_static_auto_recompute_bs16_fp16_DP2-MP2-PP2-VPP2-Sharding2_stage2 --do_eval \ --device "gpu" \ --data_impl "mmap" \ - --parallel_mode "auto" \ + --enable_auto_parallel 1 \ >>${log_path}/$FUNCNAME 2>&1 loss=`cat $case_log_dir/workerlog.3 | grep 'global_step: 10' | awk -F 'loss: ' '{print $2}' | awk -F ',' '{print $1}'` ips=-1 @@ -1290,8 +1290,8 @@ function llama_dygraph_auto_bs4_fp32_DP2-MP2-PP2() { --do_eval \ --device "gpu" \ --data_impl "mmap" \ - --parallel_mode "auto" \ - --run_static_auto 0 \ + --enable_auto_parallel 1 \ + --to_static 0 \ --max_grad_norm 1.0 \ >>${log_path}/$FUNCNAME 2>&1 loss=`cat $case_log_dir/workerlog.2 | grep 'global_step 10' | awk -F '; loss' '{print $2}' | awk -F 'lr' '{print $1}'` From 5509d9aaa6b39f834502b479ea171843aeef0e69 Mon Sep 17 00:00:00 2001 From: haohongxiang Date: Wed, 31 Jan 2024 03:31:08 +0000 Subject: [PATCH 6/7] broadcast loss --- paddlenlp/trainer/auto_trainer.py | 20 +++++++++++--------- paddlenlp/trainer/training_args.py | 10 ++++++++++ scripts/distribute/ci_case_auto.sh | 2 +- 3 files changed, 22 insertions(+), 10 deletions(-) diff --git a/paddlenlp/trainer/auto_trainer.py b/paddlenlp/trainer/auto_trainer.py index 90226f3f6d9e..2413855c4b4d 100644 --- a/paddlenlp/trainer/auto_trainer.py +++ b/paddlenlp/trainer/auto_trainer.py @@ -60,10 +60,7 @@ def loss_func(loss, outputs): assert self.args.enable_auto_parallel self.global_mesh = fleet.auto.get_mesh() - - self.mesh_in_dp = self.global_mesh.get_mesh_with_dim("dp")[self.args.data_parallel_rank] - self.mesh_for_pp = self.mesh_in_dp.get_mesh_with_dim("mp")[self.args.tensor_parallel_rank] - self.comm_group_in_pp = dist.new_group(list(self.mesh_for_pp.process_ids)) + self.comm_group_in_pp = fleet.get_hybrid_communicate_group().get_pipe_parallel_group() def _nested_gather(self, tensors): """ @@ -79,8 +76,7 @@ def _nested_gather(self, tensors): if self.args.pipeline_parallel_degree <= 1: return super()._nested_gather(tr_loss) - assert len(self.comm_group_in_pp.ranks) >= 2 - paddle.distributed.broadcast(tr_loss, src=max(self.comm_group_in_pp.ranks), group=self.comm_group_in_pp).wait() + paddle.distributed.broadcast(tr_loss, src=self.comm_group_in_pp.ranks[-1], group=self.comm_group_in_pp) return super()._nested_gather(tr_loss) @@ -143,9 +139,12 @@ def _split_batches_for_accumulation(self, inputs): local_batches = [{} for i in range(self.args.gradient_accumulation_steps)] for key, value in inputs.items(): - local_datas = value.split(self.args.gradient_accumulation_steps, axis=0) + ori_mesh, ori_placements = value.process_mesh, value.placements + replicate_value = dist.reshard(value, ori_mesh, [dist.Replicate(), dist.Replicate()]) + local_datas = replicate_value.split(self.args.gradient_accumulation_steps, axis=0) + for index, data in enumerate(local_datas): - local_batches[index].update({key: data}) + local_batches[index].update({key: dist.reshard(data, ori_mesh, ori_placements)}) return local_batches @@ -378,7 +377,7 @@ def dynamic_traning(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor with self.autocast_smart_context_manager(): 
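            # compute the loss under the AMP-aware autocast context; it is scaled by
            # 1/gradient_accumulation_steps below so accumulated micro-batch gradients
            # match a single full-batch step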
loss = self.compute_loss(model, inputs) - if self.args.gradient_accumulation_steps > 1: + if loss is not None and self.args.gradient_accumulation_steps > 1: loss = loss / self.args.gradient_accumulation_steps if self.do_grad_scaling: @@ -392,6 +391,9 @@ def static_traning(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor, input_ids, labels = tuple(inputs.values()) loss = model(input_ids, labels) + if loss is not None and self.args.gradient_accumulation_steps > 1: + loss = loss / self.args.gradient_accumulation_steps + return loss def training_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor, Any]]) -> paddle.Tensor: diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py index 1944b5226300..2cc319f14e7b 100644 --- a/paddlenlp/trainer/training_args.py +++ b/paddlenlp/trainer/training_args.py @@ -1281,6 +1281,16 @@ def is_segment_parallel_supported(): degree = [self.data_parallel_degree, self.pipeline_parallel_degree, self.tensor_parallel_degree] mesh_dims = list(zip(order, degree)) fleet.auto.create_mesh(mesh_dims) + + # init hcg for communication in trainer + strategy = fleet.DistributedStrategy() + strategy.hybrid_configs = { + "dp_degree": self.data_parallel_degree, + "mp_degree": self.tensor_parallel_degree, + "pp_degree": self.pipeline_parallel_degree, + } + fleet.init(is_collective=True, strategy=strategy) + else: if world_size > 1: if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(): diff --git a/scripts/distribute/ci_case_auto.sh b/scripts/distribute/ci_case_auto.sh index f68f45604f44..9cea90fac86f 100644 --- a/scripts/distribute/ci_case_auto.sh +++ b/scripts/distribute/ci_case_auto.sh @@ -1298,7 +1298,7 @@ function llama_dygraph_auto_bs4_fp32_DP2-MP2-PP2() { ips=-1 mem=-1 echo "result: loss=$loss ips=$ips mem=$mem" - loss_base=9.543781280517578 + loss_base=9.60352325 ips_base=-1 mem_base=-1 check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem} From f13a0bf9ed3edc62ba1279b482262ed50bfb7c80 Mon Sep 17 00:00:00 2001 From: haohongxiang Date: Wed, 31 Jan 2024 08:49:24 +0000 Subject: [PATCH 7/7] add auto ci cases --- scripts/distribute/ci_case_auto.sh | 150 +++++++++++++++++++++++++++-- 1 file changed, 143 insertions(+), 7 deletions(-) diff --git a/scripts/distribute/ci_case_auto.sh b/scripts/distribute/ci_case_auto.sh index 9cea90fac86f..109594687059 100644 --- a/scripts/distribute/ci_case_auto.sh +++ b/scripts/distribute/ci_case_auto.sh @@ -45,7 +45,9 @@ function gpt_case_list_auto() { } function llama_case_list_auto() { - llama_dygraph_auto_bs4_fp32_DP2-MP2-PP2 + llama_dygraph_auto_bs8_fp32_DP2 + llama_dygraph_auto_bs8_fp32_DP2-MP2 + llama_dygraph_auto_bs8_fp32_DP2-MP2-PP2 llama_static_auto_recompute_bs8_fp32_DP1-MP1-PP1 llama_static_auto_recompute_bs16_fp32_DP2-MP1-PP1 @@ -1238,13 +1240,147 @@ function llama_static_auto_recompute_bs16_fp16_DP2-MP2-PP2-VPP2-Sharding2_stage2 echo "=========== $FUNCNAME run end ===========" } -function llama_dygraph_auto_bs4_fp32_DP2-MP2-PP2() { +function llama_dygraph_auto_bs8_fp32_DP2() { echo "=========== $FUNCNAME run begin ===========" export PYTHONPATH=$root_path/:$PYTHONPATH export FLAGS_call_stack_level=3 export NVIDIA_TF32_OVERRIDE=0 - task_name="llama_auto_bs16_dp2mp2pp2" + task_name="llama_auto_bs8_dp2" + case_out_dir="output/$task_name" + case_log_dir="output/$task_name""_log" + rm -rf $case_out_dir + rm -rf $case_log_dir + + python -u -m paddle.distributed.launch --gpus "0,1" --log_dir $case_log_dir run_pretrain_3D_auto.py \ + 
--model_type "llama" \ + --model_name_or_path "facebook/llama-7b" \ + --tokenizer_name_or_path "facebook/llama-7b" \ + --input_dir "./data" \ + --output_dir $case_out_dir \ + --split 949,50,1 \ + --max_seq_length 2048 \ + --hidden_size 1024 \ + --intermediate_size 3072 \ + --num_hidden_layers 8 \ + --num_attention_heads 32 \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 4 \ + --gradient_accumulation_steps 4 \ + --use_flash_attention 0 \ + --use_fused_rms_norm 0 \ + --fp16 0 \ + --fp16_opt_level "O2" \ + --scale_loss 1024 \ + --pipeline_parallel_degree 1 \ + --tensor_parallel_degree 1 \ + --sharding_parallel_degree 1 \ + --learning_rate 0.0001 \ + --min_learning_rate 0.00001 \ + --max_steps 10 \ + --save_steps 5000 \ + --weight_decay 0.01 \ + --warmup_ratio 0.01 \ + --logging_steps 1 \ + --dataloader_num_workers 1 \ + --sharding "" \ + --eval_steps 1000000 \ + --disable_tqdm true \ + --continue_training 0 \ + --recompute 0 \ + --do_train \ + --do_eval \ + --device "gpu" \ + --data_impl "mmap" \ + --enable_auto_parallel 1 \ + --to_static 0 \ + --max_grad_norm 1.0 \ + >>${log_path}/$FUNCNAME 2>&1 + loss=`cat $case_log_dir/workerlog.0 | grep 'global_step: 10' | awk -F 'loss: ' '{print $2}' | awk -F ',' '{print $1}'` + ips=-1 + mem=-1 + echo "result: loss=$loss ips=$ips mem=$mem" + loss_base=9.52781677 + ips_base=-1 + mem_base=-1 + check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem} + echo "=========== $FUNCNAME run end ===========" +} + +function llama_dygraph_auto_bs8_fp32_DP2-MP2() { + echo "=========== $FUNCNAME run begin ===========" + export PYTHONPATH=$root_path/:$PYTHONPATH + export FLAGS_call_stack_level=3 + export NVIDIA_TF32_OVERRIDE=0 + + task_name="llama_auto_bs8_dp2mp2" + case_out_dir="output/$task_name" + case_log_dir="output/$task_name""_log" + rm -rf $case_out_dir + rm -rf $case_log_dir + + python -u -m paddle.distributed.launch --gpus "0,1,2,3" --log_dir $case_log_dir run_pretrain_3D_auto.py \ + --model_type "llama" \ + --model_name_or_path "facebook/llama-7b" \ + --tokenizer_name_or_path "facebook/llama-7b" \ + --input_dir "./data" \ + --output_dir $case_out_dir \ + --split 949,50,1 \ + --max_seq_length 2048 \ + --hidden_size 1024 \ + --intermediate_size 3072 \ + --num_hidden_layers 8 \ + --num_attention_heads 32 \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 4 \ + --gradient_accumulation_steps 4 \ + --use_flash_attention 0 \ + --use_fused_rms_norm 0 \ + --fp16 0 \ + --fp16_opt_level "O2" \ + --scale_loss 1024 \ + --pipeline_parallel_degree 1 \ + --tensor_parallel_degree 2 \ + --sharding_parallel_degree 1 \ + --learning_rate 0.0001 \ + --min_learning_rate 0.00001 \ + --max_steps 10 \ + --save_steps 5000 \ + --weight_decay 0.01 \ + --warmup_ratio 0.01 \ + --logging_steps 1 \ + --dataloader_num_workers 1 \ + --sharding "" \ + --eval_steps 1000000 \ + --disable_tqdm true \ + --continue_training 0 \ + --recompute 0 \ + --do_train \ + --do_eval \ + --device "gpu" \ + --data_impl "mmap" \ + --enable_auto_parallel 1 \ + --to_static 0 \ + --max_grad_norm 1.0 \ + >>${log_path}/$FUNCNAME 2>&1 + loss=`cat $case_log_dir/workerlog.0 | grep 'global_step: 10' | awk -F 'loss: ' '{print $2}' | awk -F ',' '{print $1}'` + ips=-1 + mem=-1 + echo "result: loss=$loss ips=$ips mem=$mem" + loss_base=9.40659046 + ips_base=-1 + mem_base=-1 + check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem} + echo "=========== $FUNCNAME run end ===========" +} + +function llama_dygraph_auto_bs8_fp32_DP2-MP2-PP2() { 
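+    # dp2 x mp2 x pp2 (8 ranks); with per-device batch 1 and grad-accum 4 the
+    # global batch size is 8, matching the bs8 in the case name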
+ echo "=========== $FUNCNAME run begin ===========" + export PYTHONPATH=$root_path/:$PYTHONPATH + export FLAGS_call_stack_level=3 + export NVIDIA_TF32_OVERRIDE=0 + + task_name="llama_auto_bs8_dp2mp2pp2" case_out_dir="output/$task_name" case_log_dir="output/$task_name""_log" rm -rf $case_out_dir @@ -1263,8 +1399,8 @@ function llama_dygraph_auto_bs4_fp32_DP2-MP2-PP2() { --num_hidden_layers 8 \ --num_attention_heads 32 \ --per_device_train_batch_size 1 \ - --per_device_eval_batch_size 2 \ - --gradient_accumulation_steps 2 \ + --per_device_eval_batch_size 4 \ + --gradient_accumulation_steps 4 \ --use_flash_attention 0 \ --use_fused_rms_norm 0 \ --fp16 0 \ @@ -1294,11 +1430,11 @@ function llama_dygraph_auto_bs4_fp32_DP2-MP2-PP2() { --to_static 0 \ --max_grad_norm 1.0 \ >>${log_path}/$FUNCNAME 2>&1 - loss=`cat $case_log_dir/workerlog.2 | grep 'global_step 10' | awk -F '; loss' '{print $2}' | awk -F 'lr' '{print $1}'` + loss=`cat $case_log_dir/workerlog.0 | grep 'global_step: 10' | awk -F 'loss: ' '{print $2}' | awk -F ',' '{print $1}'` ips=-1 mem=-1 echo "result: loss=$loss ips=$ips mem=$mem" - loss_base=9.60352325 + loss_base=9.38319206 ips_base=-1 mem_base=-1 check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}