diff --git a/llm/llama/auto_parallel/run_pretrain_3D_auto.py b/llm/llama/auto_parallel/run_pretrain_3D_auto.py new file mode 100644 index 000000000000..1d8bbe8b73ea --- /dev/null +++ b/llm/llama/auto_parallel/run_pretrain_3D_auto.py @@ -0,0 +1,723 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +GPT/Llama auto parallel pretraining scripts. +""" +import os +import random +import sys +import types +from dataclasses import dataclass, field +from typing import List, Optional + +import numpy as np +import paddle +import paddle.distributed as dist +from paddle.distributed import fleet +from paddle.io import DataLoader, DistributedBatchSampler + +from paddlenlp.trainer import PdArgumentParser, Trainer, TrainingArguments +from paddlenlp.transformers import ( + AutoTokenizer, + CosineAnnealingWithWarmupDecay, + LinearAnnealingWithWarmupDecay, + LlamaConfig, + LlamaForCausalLM3DAuto, +) +from paddlenlp.utils.log import logger + +MODEL_CLASSES = { + "llama": (LlamaConfig, LlamaForCausalLM3DAuto), +} + + +from collections import OrderedDict + +from paddlenlp.data.causal_dataset import ( + build_train_valid_test_datasets, + check_data_split, + print_rank_0, +) + + +def add_start_docstrings(*docstr): + def docstring_decorator(fn): + fn.__doc__ = "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "") + return fn + + return docstring_decorator + + +@dataclass +@add_start_docstrings(TrainingArguments.__doc__) +class PreTrainingArguments(TrainingArguments): + min_learning_rate: float = field( + default=1e-5, + metadata={"help": "Minimum learning rate deacyed to."}, + ) + decay_steps: float = field( + default=None, + metadata={ + "help": "The steps use to control the learing rate. If the step > decay_steps, will use the min_learning_rate." + }, + ) + enable_linear_fused_grad_add: bool = field( + default=False, + metadata={ + "help": "Enable fused linear grad add strategy, which will reduce elementwise add for grad accumulation in the backward of nn.Linear ." + }, + ) + parallel_mode: str = field(default="hybrid", metadata={"help": ""}) + + pipeline_schedule_mode: str = field( + default="1F1B", metadata={"help": "The pipeline schedule mode, support FThenB, 1F1B, VPP and Eager-1F1B."} + ) + + +@dataclass +class DataArguments: + """ + Arguments pertaining to what data we are going to input our model for training and evaluating. + Using `PdArgumentParser` we can turn this class into argparse arguments to be able to + specify them on the command line. + """ + + input_dir: str = field( + default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} + ) + split: str = field(default="949,50,1", metadata={"help": "Train/valid/test data split."}) + + max_seq_length: int = field( + default=1024, + metadata={ + "help": "The maximum total input sequence length after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded." 
+ }, + ) + share_folder: bool = field( + default=False, + metadata={"help": "Use share folder for data dir and output dir on multi machine."}, + ) + + data_impl: str = field(default="mmap", metadata={"help": "The format of the preprocessed data."}) + skip_warmup: bool = field( + default=True, + metadata={"help": "Whether to skip the warmup process of mmap files."}, + ) + data_cache: str = field(default=None, metadata={"help": "The path of the cached dataset."}) + + +@dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config/tokenizer we are going to pre-train from. + """ + + model_type: Optional[str] = field( + default="llama", metadata={"help": "Only support for llama pre-training for now."} + ) + model_name_or_path: str = field( + default="__internal_testing__/tiny-random-llama", + metadata={ + "help": "Path to pretrained model or model identifier from https://paddlenlp.readthedocs.io/zh/latest/model_zoo/transformers.html" + }, + ) + tokenizer_name_or_path: Optional[str] = field( + default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} + ) + + config_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} + ) + vocab_size: Optional[int] = field( + default=None, + metadata={ + "help": ".Vocabulary size of the Llama model. Defines the number of different tokens that can be represented by the `inputs_ids`" + }, + ) + hidden_size: Optional[int] = field(default=None, metadata={"help": "Dimension of the hidden representations."}) + intermediate_size: Optional[int] = field(default=None, metadata={"help": "Dimension of the MLP representations."}) + num_hidden_layers: Optional[int] = field( + default=None, metadata={"help": "Number of hidden layers in the Transformer encoder."} + ) + num_attention_heads: Optional[int] = field( + default=None, + metadata={"help": "Number of attention heads for each attention layer in the Transformer encoder."}, + ) + use_flash_attention: bool = field( + default=False, + metadata={"help": "use_flash_attention"}, + ) + use_fused_rms_norm: bool = field( + default=False, + metadata={"help": "llama, use_fused_rms_norm"}, + ) + fuse_attention_qkv: bool = field( + default=False, + metadata={"help": "whether to fuse attention qkv"}, + ) + fuse_attention_ffn: bool = field( + default=False, + metadata={"help": "whether to fuse first up and gate proj in mlp block"}, + ) + recompute_granularity: str = field( + default="full", + metadata={"help": "Choose among ['full', 'core_attn', 'full_attn']"}, + ) + virtual_pp_degree: int = field( + default=1, + metadata={"help": "virtual_pp_degree"}, + ) + continue_training: bool = field( + default=False, + metadata={ + "help": "Pre-training from existing paddlenlp model weights. Default False and model will train from scratch. If set True, the model_name_or_path argument must exist in the paddlenlp models." 
+ }, + ) + sequence_parallel: bool = field( + default=False, + metadata={"help": "whether to use sequence parallel"}, + ) + fuse_sequence_parallel_allreduce: bool = field( + default=False, + metadata={"help": "whether to use fuse sequence parallel allreduce"}, + ) + use_fused_rope: Optional[bool] = field( + default=False, + metadata={"help": "Enable rope fusion or not."}, + ) + no_recompute_layers: Optional[List[int]] = field( + default=None, + metadata={"help": "Specify the full transformer layers that should not be recomputed."}, + ) + pp_recompute_interval: int = field( + default=1, + metadata={ + "help": "The interval for the number of layers at which recomputation occurs. A value of 0 indicates no recomputation. Default is 0." + }, + ) + recompute_use_reentrant: bool = field( + default=False, + metadata={"help": "recompute_use_reentrant"}, + ) + + +def create_pretrained_dataset( + data_args, + training_args, + data_file, + tokenizer, + need_data=True, +): + + check_data_split(data_args.split, training_args.do_train, training_args.do_eval, training_args.do_predict) + + train_val_test_num_samples = [ + training_args.per_device_train_batch_size + * training_args.data_parallel_degree + * training_args.max_steps + * training_args.gradient_accumulation_steps, + training_args.per_device_eval_batch_size + * training_args.data_parallel_degree + * training_args.eval_iters + * (training_args.max_steps // training_args.eval_steps + 1), + training_args.per_device_eval_batch_size * training_args.data_parallel_degree * training_args.test_iters, + ] + + print_rank_0(" > datasets target sizes (minimum size):") + if training_args.do_train: + print_rank_0(" train: {}".format(train_val_test_num_samples[0])) + if training_args.do_eval: + print_rank_0(" validation: {}".format(train_val_test_num_samples[1])) + if training_args.do_predict: + print_rank_0(" test: {}".format(train_val_test_num_samples[2])) + + # Build the datasets. + print("====data seed====", training_args.seed) + train_dataset, valid_dataset, test_dataset = build_train_valid_test_datasets( + data_prefix=data_file, + data_impl=data_args.data_impl, + splits_string=data_args.split, + train_val_test_num_samples=train_val_test_num_samples, + seq_length=data_args.max_seq_length, + seed=training_args.seed, + skip_warmup=data_args.skip_warmup, + share_folder=data_args.share_folder, + data_cache_path=data_args.data_cache, + need_data=need_data, + ) + + def print_dataset(data, mode="train"): + logger.info(f"Sample data for {mode} mode.") + # input_ids, loss_mask, attention_mask, position_ids, labels = data + input_ids = data["text"] + + logger.info(tokenizer._decode(input_ids)) + + from paddlenlp.data import Stack + + def _collate_data(data, stack_fn=Stack()): + tokens_ = stack_fn([x["text"] for x in data]) + + labels = tokens_[:, 1:] + tokens = tokens_[:, :-1] + + return { + "input_ids": tokens, + "labels": labels, + } + + if need_data: + if training_args.do_train: + print_dataset(train_dataset[0], "train") + if training_args.do_eval: + print_dataset(valid_dataset[0], "valid") + if training_args.do_predict: + print_dataset(test_dataset[0], "test") + + return train_dataset, valid_dataset, test_dataset, _collate_data + + +def get_train_data_file(args): + if len(args.input_dir.split()) > 1: + # weight-1 data-prefix-1 weight-2 data-prefix-2 ... 
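+ # e.g. (hypothetical paths) --input_dir "0.7 ./data/corpus_a 0.3 ./data/corpus_b"
+ # pairs a sampling weight with each data prefix; the directory branch below assigns every discovered prefix a weight of 1.0 instead.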
+ return args.input_dir.split() + else: + files = [ + os.path.join(args.input_dir, f) + for f in os.listdir(args.input_dir) + if (os.path.isfile(os.path.join(args.input_dir, f)) and ("_idx.npz" in str(f) or ".idx" in str(f))) + ] + files = [x.replace("_idx.npz", "") for x in files] + files = [x.replace(".idx", "") for x in files] # add + + if len(files) > 1: + ret = [] + logger.info("You are using multi-dataset:") + for x in files: + ret.append(1.0) + ret.append(x) + logger.info(" > set weight of %s dataset to 1.0" % x) + return ret + + return files + + +def create_optimizer(model, lr_scheduler, training_args): + decay_parameters = [ + p.name + for n, p in model.named_parameters() + if (not any(nd in n for nd in ["bias", "norm"])) or n == "llama.norm.weight" + ] + + def apply_decay_param_fun(x): + return x in decay_parameters + + optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args) + optimizer = optimizer_cls( + learning_rate=lr_scheduler if lr_scheduler is None else lr_scheduler, + apply_decay_param_fun=apply_decay_param_fun, + parameters=model.parameters(), + weight_decay=training_args.weight_decay, + grad_clip=paddle.nn.ClipGradByGlobalNorm(training_args.max_grad_norm) + if training_args.max_grad_norm > 0 + else None, + **optimizer_kwargs, + ) + + return optimizer + + +def print_config(args, key=""): + """ + print config values + """ + logger.info("=" * 60) + if args is None: + args = args + key = "Training" + import paddlenlp + + logger.info("{:^40}".format("{} Configuration Arguments".format(key))) + logger.info("{:30}: {}".format("paddle commit id", paddle.version.commit)) + logger.info("{:30}: {}".format("paddlenlp commit id", paddlenlp.version.commit)) + + for a in dir(args): + if a[:2] != "__": # don't print double underscore methods + v = getattr(args, a) + if not isinstance(v, types.MethodType): + logger.info("{:30}: {}".format(a, v)) + + logger.info("") + + +def init_seed(seed: int = 1234, args=None): + if args is None: + random.seed(seed) + np.random.seed(seed) + paddle.seed(seed) + + if args is not None: + if args.use_hybrid_parallel: + from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker + + random.seed(args.seed + args.dataset_rank) + np.random.seed(args.seed + args.dataset_rank) + paddle.seed(args.seed + args.dataset_rank) + + # local_seed/ global_seed is used to control dropout in ModelParallel + local_seed = args.seed + 59999 + args.tensor_parallel_rank * 10 + args.pipeline_parallel_rank * 1000 + global_seed = args.seed + 100003 + args.dataset_rank + tracker = get_rng_state_tracker() + + if "global_seed" not in tracker.states_: + tracker.add("global_seed", global_seed) + if "local_seed" not in tracker.states_: + tracker.add("local_seed", local_seed) + else: + random.seed(args.seed) + np.random.seed(args.seed) + paddle.seed(args.seed) + + +def get_mesh(pp_idx=0): + mesh = fleet.auto.get_mesh() + if "pp" in mesh.dim_names: + mesh = mesh.get_mesh_with_dim("pp")[pp_idx] + return mesh + + +def shard_fn(layer, mesh_idx, placements): + paran_name = layer.weight.name + layer.weight = dist.shard_tensor(layer.weight, get_mesh(mesh_idx), placements) + layer.weight.name = paran_name + + +def main(): + parser = PdArgumentParser((ModelArguments, DataArguments, PreTrainingArguments)) + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) + else: + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + 
if training_args.enable_linear_fused_grad_add: + from fused_layers import mock_layers + + mock_layers() + + if model_args.tokenizer_name_or_path is None: + model_args.tokenizer_name_or_path = model_args.model_name_or_path + + if data_args.data_cache is not None: + os.makedirs(data_args.data_cache, exist_ok=True) + + init_seed(args=training_args) + paddle.set_device(training_args.device) + if paddle.distributed.get_world_size() > 1: + paddle.distributed.init_parallel_env() + + training_args.eval_iters = 10 + training_args.test_iters = training_args.eval_iters * 10 + + # Log model and data config + training_args.print_config(model_args, "Model") + training_args.print_config(data_args, "Data") + + # Log on each process the small summary: + logger.warning( + f"Process rank: {training_args.local_rank}, device: {training_args.device}, world_size: {training_args.world_size}, " + + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16 or training_args.bf16}" + ) + + config_class, model_class = MODEL_CLASSES[model_args.model_type] + + tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name_or_path) + + config = config_class.from_pretrained(model_args.model_name_or_path) + + config.seq_length = data_args.max_seq_length + # There are some technique extend RotaryEmbedding context. so don't change max_position_embeddings + if not model_args.continue_training: + config.max_position_embeddings = max(config.max_position_embeddings, data_args.max_seq_length) + + if not model_args.continue_training: + config.vocab_size = max(config.vocab_size, ((tokenizer.vocab_size - 1) // 128 + 1) * 128) + logger.info(f"Reset vocab size to {config.vocab_size} for batter amp peformance.") + + if model_args.no_recompute_layers is not None: + model_args.no_recompute_layers.sort() + + config.hidden_size = model_args.hidden_size if model_args.hidden_size is not None else config.hidden_size + config.intermediate_size = ( + model_args.intermediate_size if model_args.intermediate_size is not None else config.intermediate_size + ) + config.num_hidden_layers = ( + model_args.num_hidden_layers if model_args.num_hidden_layers is not None else config.num_hidden_layers + ) + config.num_attention_heads = ( + model_args.num_attention_heads if model_args.num_attention_heads is not None else config.num_attention_heads + ) + + config.use_flash_attention = model_args.use_flash_attention + config.use_fused_rms_norm = model_args.use_fused_rms_norm + config.fuse_attention_qkv = model_args.fuse_attention_qkv + config.fuse_attention_ffn = model_args.fuse_attention_ffn + config.recompute_granularity = model_args.recompute_granularity + config.virtual_pp_degree = model_args.virtual_pp_degree + config.sequence_parallel = model_args.sequence_parallel + config.fuse_sequence_parallel_allreduce = model_args.fuse_sequence_parallel_allreduce + config.use_fused_rope = model_args.use_fused_rope + config.no_recompute_layers = model_args.no_recompute_layers + config.pp_recompute_interval = model_args.pp_recompute_interval + config.recompute_use_reentrant = model_args.recompute_use_reentrant + + config.use_recompute = training_args.recompute + config.tensor_parallel_degree = training_args.tensor_parallel_degree + config.tensor_parallel_rank = training_args.tensor_parallel_rank + + print("Final pre-training config:", config) + + # Set the dtype for loading model + dtype = "float32" + if training_args.fp16_opt_level == "O2": + if training_args.fp16: + dtype = "float16" + if training_args.bf16: + dtype = 
"bfloat16" + + print("======M M M M======", model_class) + model = model_class._from_config(config, dtype=dtype) + # load model + # load_model(model) + shard_model(model) + # Create the learning_rate sheduler and optimizer + if training_args.decay_steps is None: + training_args.decay_steps = training_args.max_steps + warmup_steps = training_args.warmup_ratio * training_args.max_steps + + lr_scheduler = None + if training_args.lr_scheduler_type.value == "cosine": + lr_scheduler = CosineAnnealingWithWarmupDecay( + max_lr=training_args.learning_rate, + min_lr=training_args.min_learning_rate, + warmup_step=warmup_steps, + decay_step=training_args.decay_steps, + last_epoch=0, + ) + elif training_args.lr_scheduler_type.value == "linear": + lr_scheduler = LinearAnnealingWithWarmupDecay( + max_lr=training_args.learning_rate, + min_lr=training_args.min_learning_rate, + warmup_step=warmup_steps, + decay_step=training_args.decay_steps, + last_epoch=0, + ) + + data_file = get_train_data_file(data_args) + train_dataset, _, _, data_collator = create_pretrained_dataset( + data_args, + training_args, + data_file, + tokenizer, + need_data=training_args.should_load_dataset, + ) + + optimizer = create_optimizer(model, lr_scheduler, training_args) + + def loss_func(loss, outputs): + return loss + + print_config(training_args) + + # create sampler and dataloader + # each rank read (training_args.per_device_train_batch_size * training_args.data_parallel_degree) samples + print( + "dp_rank: ", dist.get_rank() // (training_args.pipeline_parallel_degree * training_args.tensor_parallel_degree) + ) + print( + f"===> worldsize = {training_args.per_device_train_batch_size} rank: {dist.get_rank() // (training_args.pipeline_parallel_degree * training_args.tensor_parallel_degree)}" + ) + train_sampler = DistributedBatchSampler( + train_dataset, + batch_size=training_args.per_device_train_batch_size, + shuffle=False, + num_replicas=training_args.data_parallel_degree, + rank=dist.get_rank() // (training_args.pipeline_parallel_degree * training_args.tensor_parallel_degree), + drop_last=training_args.dataloader_drop_last, + ) + + train_dataloader = DataLoader( + train_dataset, + batch_sampler=train_sampler, + collate_fn=data_collator, + num_workers=training_args.dataloader_num_workers, + ) + + num_update_steps_per_epoch = len(train_dataloader) // training_args.gradient_accumulation_steps + num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1) + num_train_epochs = training_args.max_steps // num_update_steps_per_epoch + int( + training_args.max_steps % num_update_steps_per_epoch > 0 + ) + + global_step = 1 + tr_loss = float(0) + + # hack: create dp group for distributed input data to align dygraph parallel loss. + dp_group = None + global_mesh = fleet.auto.get_mesh().get_mesh_with_dim("pp").mesh + mesh_shape = global_mesh.shape + for id in range(mesh_shape[0]): + pp_mesh = global_mesh[id] + for i in range(pp_mesh.shape[-1]): + ranks = pp_mesh[:, i] + print("dp ranks: ", ranks) + group = dist.new_group(ranks) + if dist.get_rank() in ranks: + dp_group = group + assert dp_group is not None + + model.train() + optimizer = dist.shard_optimizer(optimizer) + for epoch_idx in range(num_train_epochs): + for step, inputs in enumerate(train_dataloader): + input_ids, labels = inputs["input_ids"], inputs["labels"] + + input_id = input_ids[0][0].numpy() + label = labels[0][0].numpy() + + # hack for align dygraph parallel. 
+ if dp_group is not None: + cur_rank = dist.get_rank() + res = [] + dist.all_gather(res, paddle.Tensor(input_ids, place=paddle.CUDAPlace(cur_rank)), group=dp_group) + input_ids = paddle.concat(res) + input_ids = dist.shard_tensor(input_ids, get_mesh(), [dist.Shard(0), dist.Replicate()]) + + res = [] + dist.all_gather(res, paddle.Tensor(labels, place=paddle.CUDAPlace(cur_rank)), group=dp_group) + labels = paddle.concat(res) + labels = dist.shard_tensor(labels, get_mesh(-1), [dist.Shard(0), dist.Replicate()]) + + res = model(input_ids, labels=labels) + + # add criterion in the future. + tr_loss_step = res[0] + + if training_args.gradient_accumulation_steps > 1: + tr_loss_step /= training_args.gradient_accumulation_steps + + # do backward every micro step. + tr_loss_step.backward() + tr_loss += tr_loss_step + + if global_step % training_args.gradient_accumulation_steps == 0: + # print_grad(model) + optimizer.step() + lr_scheduler.step() + # print_param(model) + optimizer.clear_grad() + print( + f"global_step {global_step // training_args.gradient_accumulation_steps};input id {input_id}; label {label}; loss {tr_loss.numpy()} lr: {optimizer.get_lr()}" + ) + tr_loss = 0 + + if global_step // training_args.gradient_accumulation_steps >= training_args.max_steps: + break + + global_step += 1 + + +def shard_model(model): + pp_stage = 0 + for name, layer in model.named_sublayers(include_self=False): + if hasattr(layer, "ipp"): + pp_stage = layer.ipp + # print(f"name {name},pp_stage {pp_stage}==>", type(layer)) + if "embed_tokens" in name: + # embedding only support column split now. it will update in the future + shard_fn(layer, 0, [dist.Replicate(), dist.Shard(1)]) + for n in [ + "self_attn.q_proj", + "self_attn.k_proj", + "self_attn.v_proj", + "self_attn.qkv_proj", + "gate_proj", + "up_proj", + "gate_up_fused_proj", + ]: + if n in name: + shard_fn(layer, pp_stage, [dist.Replicate(), dist.Shard(1)]) + break + for n in ["self_attn.o_proj", "down_proj"]: + if n in name: + shard_fn(layer, pp_stage, [dist.Replicate(), dist.Shard(0)]) + break + if "lm_head" in name: + shard_fn(layer, -1, [dist.Replicate(), dist.Shard(1)]) + + +def load_model(model): + model_state_dict = model.state_dict() + state_dict = paddle.load("hand/all.pdparams") + tmp = OrderedDict() + (tmp, state_dict) = (state_dict, tmp) + for (k, v) in tmp.items(): + k = map_structure_name(k) + state_dict[k] = v + model.set_state_dict(state_dict) + assert len(model_state_dict) == len(state_dict), f"{len(model_state_dict)} vs {len(state_dict)}" + """ + print("=======model_state_dict=======") + for (k,v) in model_state_dict.items(): + print(f"{k}=>{v.shape}") + """ + print("=======state_dict=======") + for (k, v) in state_dict.items(): + assert k in model_state_dict + print(f"{k}=>{v.shape}") + + +def print_grad(model): + model_state_dict = model.state_dict() + name_mapping = {v.name: k for (k, v) in model_state_dict.items()} + for p in model.parameters(): + assert p.name in name_mapping + if p.grad is not None: + print(f"{name_mapping[p.name]} {p.name}_grad shape: {p.grad.shape} md5sum: {p.grad._md5sum()}") + + +def print_param(model): + model_state_dict = model.state_dict() + name_mapping = {v.name: k for (k, v) in model_state_dict.items()} + for p in model.parameters(): + assert p.name in name_mapping + if p.grad is not None: + print(f"{name_mapping[p.name]} {p.name} shape: {p.shape} md5sum: {p._md5sum()}") + + +def map_structure_name(k): + fs = k.split(".") + idx = int(fs[1]) + if idx == 0: + return "llama.embed_tokens.weight" + if idx == 
33: + return "llama.norm.weight" + if idx == 34: + return "lm_head.weight" + else: + return f"llama.layers.{idx-1}." + ".".join(fs[2:]) + + +if __name__ == "__main__": + main() diff --git a/llm/llama/auto_parallel/run_pretrain_3D_auto.sh b/llm/llama/auto_parallel/run_pretrain_3D_auto.sh new file mode 100644 index 000000000000..46a158bdeb79 --- /dev/null +++ b/llm/llama/auto_parallel/run_pretrain_3D_auto.sh @@ -0,0 +1,78 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# just for debug + +set -x +unset CUDA_VISIBLE_DEVICES + +export FLAGS_call_stack_level=3 +export FLAGS_use_cuda_managed_memory=true + +task_name="llama_auto_dp2mp2pp2" +rm -rf output/$task_name/ +rm -rf "output/$task_name""_log" + +export SOT_LOG_LEVEL=4 +export PYTHONPATH=../../../:$PYTHONPATH +#ulimit -c unlimited +#export GLOG_v=10 + +rm -rf log_auto + +export FLAGS_embedding_deterministic=1 +export FLAGS_cudnn_deterministic=1 +export NVIDIA_TF32_OVERRIDE=0 + +python3.8 -u -m paddle.distributed.launch \ + --gpus "0,1,2,3,4,5,6,7" \ + --log_dir "auto_3d" \ + run_pretrain_3D_auto.py \ + --model_type "llama" \ + --model_name_or_path "facebook/llama-7b" \ + --tokenizer_name_or_path "facebook/llama-7b" \ + --input_dir "./data" \ + --output_dir "output/$task_name" \ + --split 949,50,1 \ + --max_seq_length 2048 \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 2 \ + --gradient_accumulation_steps 2 \ + --use_flash_attention 0 \ + --use_fused_rms_norm 1 \ + --fp16 0 \ + --fp16_opt_level "O2" \ + --scale_loss 1024 \ + --pipeline_parallel_degree 2 \ + --tensor_parallel_degree 2 \ + --sharding_parallel_degree 1 \ + --learning_rate 0.0001 \ + --min_learning_rate 0.00001 \ + --max_steps 20000 \ + --save_steps 5000000 \ + --weight_decay 0.01 \ + --warmup_ratio 0.01 \ + --logging_steps 1\ + --dataloader_num_workers 1 \ + --sharding "" \ + --eval_steps 1000000 \ + --disable_tqdm true \ + --continue_training 0\ + --recompute 0 \ + --do_train \ + --do_eval \ + --device "gpu" \ + --data_impl "mmap" \ + --parallel_mode "auto" \ + --max_grad_norm 1.0 \ diff --git a/llm/llama/auto_parallel/run_pretrain_3D_hand.py b/llm/llama/auto_parallel/run_pretrain_3D_hand.py new file mode 100644 index 000000000000..7a353d3f9bbf --- /dev/null +++ b/llm/llama/auto_parallel/run_pretrain_3D_hand.py @@ -0,0 +1,813 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +GPT/Llama auto parallel pretraining scripts. 
+""" +import os +import random +import sys +import types +from collections import OrderedDict +from dataclasses import dataclass, field +from typing import List, Optional + +import numpy as np +import paddle +import paddle.distributed as dist +from paddle.autograd import PyLayer +from paddle.distributed import fleet +from paddle.io import DataLoader, DistributedBatchSampler + +from paddlenlp.data.causal_dataset import ( + build_train_valid_test_datasets, + check_data_split, + print_rank_0, +) +from paddlenlp.trainer import PdArgumentParser, Trainer, TrainingArguments +from paddlenlp.trainer.utils.reshard import NodeModelState, all_gather_state_dict +from paddlenlp.transformers import ( + AutoConfig, + AutoModelForCausalLM, + AutoModelForCausalLMPipe, + AutoTokenizer, + CosineAnnealingWithWarmupDecay, + LinearAnnealingWithWarmupDecay, +) +from paddlenlp.utils.log import logger + + +def add_start_docstrings(*docstr): + def docstring_decorator(fn): + fn.__doc__ = "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "") + return fn + + return docstring_decorator + + +@dataclass +@add_start_docstrings(TrainingArguments.__doc__) +class PreTrainingArguments(TrainingArguments): + min_learning_rate: float = field( + default=1e-5, + metadata={"help": "Minimum learning rate deacyed to."}, + ) + decay_steps: float = field( + default=None, + metadata={ + "help": "The steps use to control the learing rate. If the step > decay_steps, will use the min_learning_rate." + }, + ) + enable_linear_fused_grad_add: bool = field( + default=False, + metadata={ + "help": "Enable fused linear grad add strategy, which will reduce elementwise add for grad accumulation in the backward of nn.Linear ." + }, + ) + parallel_mode: str = field(default="hybrid", metadata={"help": ""}) + + +@dataclass +class DataArguments: + """ + Arguments pertaining to what data we are going to input our model for training and evaluating. + Using `PdArgumentParser` we can turn this class into argparse arguments to be able to + specify them on the command line. + """ + + input_dir: str = field( + default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} + ) + split: str = field(default="949,50,1", metadata={"help": "Train/valid/test data split."}) + + max_seq_length: int = field( + default=1024, + metadata={ + "help": "The maximum total input sequence length after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded." + }, + ) + share_folder: bool = field( + default=False, + metadata={"help": "Use share folder for data dir and output dir on multi machine."}, + ) + + data_impl: str = field(default="mmap", metadata={"help": "The format of the preprocessed data."}) + skip_warmup: bool = field( + default=True, + metadata={"help": "Whether to skip the warmup process of mmap files."}, + ) + data_cache: str = field(default=None, metadata={"help": "The path of the cached dataset."}) + + +@dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config/tokenizer we are going to pre-train from. 
+ """ + + model_type: Optional[str] = field( + default="llama", metadata={"help": "Only support for llama pre-training for now."} + ) + model_name_or_path: str = field( + default="__internal_testing__/tiny-random-llama", + metadata={ + "help": "Path to pretrained model or model identifier from https://paddlenlp.readthedocs.io/zh/latest/model_zoo/transformers.html" + }, + ) + tokenizer_name_or_path: Optional[str] = field( + default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} + ) + + config_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} + ) + use_flash_attention: bool = field( + default=False, + metadata={"help": "use_flash_attention"}, + ) + use_fused_rms_norm: bool = field( + default=False, + metadata={"help": "llama, use_fused_rms_norm"}, + ) + fuse_attention_qkv: bool = field( + default=False, + metadata={"help": "whether to fuse attention qkv"}, + ) + fuse_attention_ffn: bool = field( + default=False, + metadata={"help": "whether to fuse first up and gate proj in mlp block"}, + ) + recompute_granularity: str = field( + default="full", + metadata={"help": "Choose among ['full', 'core_attn', 'full_attn']"}, + ) + virtual_pp_degree: int = field( + default=1, + metadata={"help": "virtual_pp_degree"}, + ) + continue_training: bool = field( + default=False, + metadata={ + "help": "Pre-training from existing paddlenlp model weights. Default False and model will train from scratch. If set True, the model_name_or_path argument must exist in the paddlenlp models." + }, + ) + sequence_parallel: bool = field( + default=False, + metadata={"help": "whether to use sequence parallel"}, + ) + fuse_sequence_parallel_allreduce: bool = field( + default=False, + metadata={"help": "whether to use fuse sequence parallel allreduce"}, + ) + use_fused_rope: Optional[bool] = field( + default=False, + metadata={"help": "Enable rope fusion or not."}, + ) + no_recompute_layers: Optional[List[int]] = field( + default=None, + metadata={"help": "Specify the full transformer layers that should not be recomputed."}, + ) + pp_recompute_interval: int = field( + default=1, + metadata={ + "help": "The interval for the number of layers at which recomputation occurs. A value of 0 indicates no recomputation. Default is 0." + }, + ) + recompute_use_reentrant: bool = field( + default=False, + metadata={"help": "recompute_use_reentrant"}, + ) + + +def create_pretrained_dataset( + data_args, + training_args, + data_file, + tokenizer, + need_data=True, +): + + check_data_split(data_args.split, training_args.do_train, training_args.do_eval, training_args.do_predict) + + train_val_test_num_samples = [ + training_args.per_device_train_batch_size + * training_args.data_parallel_degree + * training_args.max_steps + * training_args.gradient_accumulation_steps, + training_args.per_device_eval_batch_size + * training_args.data_parallel_degree + * training_args.eval_iters + * (training_args.max_steps // training_args.eval_steps + 1), + training_args.per_device_eval_batch_size * training_args.data_parallel_degree * training_args.test_iters, + ] + + print_rank_0(" > datasets target sizes (minimum size):") + if training_args.do_train: + print_rank_0(" train: {}".format(train_val_test_num_samples[0])) + if training_args.do_eval: + print_rank_0(" validation: {}".format(train_val_test_num_samples[1])) + if training_args.do_predict: + print_rank_0(" test: {}".format(train_val_test_num_samples[2])) + + # Build the datasets. 
+ train_dataset, valid_dataset, test_dataset = build_train_valid_test_datasets( + data_prefix=data_file, + data_impl=data_args.data_impl, + splits_string=data_args.split, + train_val_test_num_samples=train_val_test_num_samples, + seq_length=data_args.max_seq_length, + seed=training_args.seed, + skip_warmup=data_args.skip_warmup, + share_folder=data_args.share_folder, + data_cache_path=data_args.data_cache, + need_data=need_data, + ) + + def print_dataset(data, mode="train"): + logger.info(f"Sample data for {mode} mode.") + # input_ids, loss_mask, attention_mask, position_ids, labels = data + input_ids = data["text"] + + logger.info(tokenizer._decode(input_ids)) + + from paddlenlp.data import Stack + + def _collate_data(data, stack_fn=Stack()): + tokens_ = stack_fn([x["text"] for x in data]) + + labels = tokens_[:, 1:] + tokens = tokens_[:, :-1] + + return { + "input_ids": tokens, + "labels": labels, + } + + if need_data: + if training_args.do_train: + print_dataset(train_dataset[0], "train") + if training_args.do_eval: + print_dataset(valid_dataset[0], "valid") + if training_args.do_predict: + print_dataset(test_dataset[0], "test") + + return train_dataset, valid_dataset, test_dataset, _collate_data + + +def get_train_data_file(args): + if len(args.input_dir.split()) > 1: + # weight-1 data-prefix-1 weight-2 data-prefix-2 ... + return args.input_dir.split() + else: + files = [ + os.path.join(args.input_dir, f) + for f in os.listdir(args.input_dir) + if (os.path.isfile(os.path.join(args.input_dir, f)) and ("_idx.npz" in str(f) or ".idx" in str(f))) + ] + files = [x.replace("_idx.npz", "") for x in files] + files = [x.replace(".idx", "") for x in files] # add + + if len(files) > 1: + ret = [] + logger.info("You are using multi-dataset:") + for x in files: + ret.append(1.0) + ret.append(x) + logger.info(" > set weight of %s dataset to 1.0" % x) + return ret + + return files + + +def create_optimizer(model, lr_scheduler, training_args): + decay_parameters = [ + p.name + for n, p in model.named_parameters() + if (not any(nd in n for nd in ["bias", "norm"])) or "llama.norm.weight" in n + ] + + def apply_decay_param_fun(x): + return x in decay_parameters + + optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args) + optimizer = optimizer_cls( + learning_rate=lr_scheduler if lr_scheduler is None else lr_scheduler, + apply_decay_param_fun=apply_decay_param_fun, + parameters=model.parameters(), + weight_decay=training_args.weight_decay, + grad_clip=paddle.nn.ClipGradByGlobalNorm(training_args.max_grad_norm) + if training_args.max_grad_norm > 0 + else None, + **optimizer_kwargs, + ) + + return optimizer + + +def print_config(args, key=""): + """ + print config values + """ + logger.info("=" * 60) + if args is None: + args = args + key = "Training" + import paddlenlp + + logger.info("{:^40}".format("{} Configuration Arguments".format(key))) + logger.info("{:30}: {}".format("paddle commit id", paddle.version.commit)) + logger.info("{:30}: {}".format("paddlenlp commit id", paddlenlp.version.commit)) + + for a in dir(args): + if a[:2] != "__": # don't print double underscore methods + v = getattr(args, a) + if not isinstance(v, types.MethodType): + logger.info("{:30}: {}".format(a, v)) + + logger.info("") + + +def init_seed(seed: int = 1234, args=None): + if args is None: + random.seed(seed) + np.random.seed(seed) + paddle.seed(seed) + + if args is not None: + if args.use_hybrid_parallel: + from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker + + 
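+ # offset the base seeds by dataset_rank so different data-parallel groups get
+ # different randomness, while all mp/pp ranks inside one group stay in sync.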
random.seed(args.seed + args.dataset_rank) + np.random.seed(args.seed + args.dataset_rank) + paddle.seed(args.seed + args.dataset_rank) + + # local_seed/ global_seed is used to control dropout in ModelParallel + local_seed = args.seed + 59999 + args.tensor_parallel_rank * 10 + args.pipeline_parallel_rank * 1000 + global_seed = args.seed + 100003 + args.dataset_rank + tracker = get_rng_state_tracker() + + if "global_seed" not in tracker.states_: + tracker.add("global_seed", global_seed) + if "local_seed" not in tracker.states_: + tracker.add("local_seed", local_seed) + else: + random.seed(args.seed) + np.random.seed(args.seed) + paddle.seed(args.seed) + + +def get_mesh(pp_idx=0): + mesh = fleet.auto.get_mesh() + if "pp" in mesh.dim_names: + mesh = mesh.get_mesh_with_dim("pp")[pp_idx] + return mesh + + +def _prepare_pipeline_inputs_func(inputs): + first_stage_keys = ["input_ids", "attention_mask", "position_ids"] + last_stage_keys = ["labels"] + + def get_expected_keys(inputs, keys): + ret = tuple([inputs.pop(k) for k in keys if k in inputs]) + if len(ret) == 1: + ret = ret[0] + return ret + + if type(inputs) is dict or type(inputs) is OrderedDict: + return [ + get_expected_keys(inputs, first_stage_keys), + get_expected_keys(inputs, last_stage_keys), + ] + + keys = list(inputs[0].keys()) + inputs_batch = {key: [data.pop(key) for data in inputs] for key in keys} + return [ + get_expected_keys(inputs_batch, first_stage_keys), + get_expected_keys(inputs_batch, last_stage_keys), + ] + + +def main(): + parser = PdArgumentParser((ModelArguments, DataArguments, PreTrainingArguments)) + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) + else: + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + if training_args.enable_linear_fused_grad_add: + from fused_layers import mock_layers + + mock_layers() + + if model_args.tokenizer_name_or_path is None: + model_args.tokenizer_name_or_path = model_args.model_name_or_path + + if data_args.data_cache is not None: + os.makedirs(data_args.data_cache, exist_ok=True) + + init_seed(args=training_args) + paddle.set_device(training_args.device) + if paddle.distributed.get_world_size() > 1: + paddle.distributed.init_parallel_env() + strategy = fleet.DistributedStrategy() + strategy.hybrid_configs = { + "dp_degree": training_args.data_parallel_degree, + "mp_degree": training_args.tensor_parallel_degree, + "pp_degree": training_args.pipeline_parallel_degree, + "sharding_degree": 1, + } + fleet.init(is_collective=True, strategy=strategy) + + training_args.eval_iters = 10 + training_args.test_iters = training_args.eval_iters * 10 + + # Log model and data config + training_args.print_config(model_args, "Model") + training_args.print_config(data_args, "Data") + + # Log on each process the small summary: + logger.warning( + f"Process rank: {training_args.local_rank}, device: {training_args.device}, world_size: {training_args.world_size}, " + + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16 or training_args.bf16}" + ) + + tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name_or_path) + config = AutoConfig.from_pretrained(model_args.model_name_or_path) + + config.seq_length = data_args.max_seq_length + # There are some technique extend RotaryEmbedding context. 
so don't change max_position_embeddings + if not model_args.continue_training: + config.max_position_embeddings = max(config.max_position_embeddings, data_args.max_seq_length) + + if not model_args.continue_training: + config.vocab_size = max(config.vocab_size, ((tokenizer.vocab_size - 1) // 128 + 1) * 128) + logger.info(f"Reset vocab size to {config.vocab_size} for batter amp peformance.") + + if model_args.no_recompute_layers is not None: + model_args.no_recompute_layers.sort() + + config.use_flash_attention = model_args.use_flash_attention + config.use_fused_rms_norm = model_args.use_fused_rms_norm + config.fuse_attention_qkv = model_args.fuse_attention_qkv + config.fuse_attention_ffn = model_args.fuse_attention_ffn + config.recompute_granularity = model_args.recompute_granularity + config.virtual_pp_degree = model_args.virtual_pp_degree + config.sequence_parallel = model_args.sequence_parallel + config.fuse_sequence_parallel_allreduce = model_args.fuse_sequence_parallel_allreduce + config.use_fused_rope = model_args.use_fused_rope + config.no_recompute_layers = model_args.no_recompute_layers + config.pp_recompute_interval = model_args.pp_recompute_interval + config.recompute_use_reentrant = model_args.recompute_use_reentrant + + config.use_recompute = training_args.recompute + config.tensor_parallel_degree = training_args.tensor_parallel_degree + config.tensor_parallel_rank = training_args.tensor_parallel_rank + + print("Final pre-training config:", config) + + # Set the dtype for loading model + dtype = "float32" + if training_args.fp16_opt_level == "O2": + if training_args.fp16: + dtype = "float16" + if training_args.bf16: + dtype = "bfloat16" + + model_class = AutoModelForCausalLM + if training_args.pipeline_parallel_degree > 1: + model_class = AutoModelForCausalLMPipe + + if model_args.continue_training: + model = model_class.from_pretrained( + model_args.model_name_or_path, + config=config, + dtype=dtype, + ) + else: + model = model_class.from_config(config, dtype=dtype) + + print("====type===", type(model)) + + # Create the learning_rate sheduler and optimizer + if training_args.decay_steps is None: + training_args.decay_steps = training_args.max_steps + warmup_steps = training_args.warmup_ratio * training_args.max_steps + + lr_scheduler = None + if training_args.lr_scheduler_type.value == "cosine": + lr_scheduler = CosineAnnealingWithWarmupDecay( + max_lr=training_args.learning_rate, + min_lr=training_args.min_learning_rate, + warmup_step=warmup_steps, + decay_step=training_args.decay_steps, + last_epoch=0, + ) + elif training_args.lr_scheduler_type.value == "linear": + lr_scheduler = LinearAnnealingWithWarmupDecay( + max_lr=training_args.learning_rate, + min_lr=training_args.min_learning_rate, + warmup_step=warmup_steps, + decay_step=training_args.decay_steps, + last_epoch=0, + ) + + data_file = get_train_data_file(data_args) + train_dataset, _, _, data_collator = create_pretrained_dataset( + data_args, + training_args, + data_file, + tokenizer, + need_data=training_args.should_load_dataset, + ) + + optimizer = create_optimizer(model, lr_scheduler, training_args) + + model = fleet.distributed_model(model) + optimizer = fleet.distributed_optimizer(optimizer) + # skip grad sync + load_model(model) + assert optimizer._dp_enable + # hack for align with auto + # optimizer._dp_enable = False + + def loss_func(loss): + return loss + # hcg = fleet.get_hybrid_communicate_group() + # group = hcg.get_data_parallel_group() + # return LossMean.apply(loss, group) + + 
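+ # NOTE: loss_func is the identity here; the commented-out LossMean path would
+ # all-gather and average the loss across a two-way data-parallel group when enabled.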
print_config(training_args) + + # create sampler and dataloader + # each rank read (training_args.per_device_train_batch_size * training_args.data_parallel_degree) samples + print( + "dp_rank: ", dist.get_rank() // (training_args.pipeline_parallel_degree * training_args.tensor_parallel_degree) + ) + train_sampler = DistributedBatchSampler( + train_dataset, + batch_size=training_args.per_device_train_batch_size, + shuffle=False, + num_replicas=training_args.data_parallel_degree, + rank=dist.get_rank() // (training_args.pipeline_parallel_degree * training_args.tensor_parallel_degree), + drop_last=training_args.dataloader_drop_last, + ) + + train_dataloader = DataLoader( + train_dataset, + batch_sampler=train_sampler, + collate_fn=data_collator, + num_workers=training_args.dataloader_num_workers, + ) + + num_update_steps_per_epoch = len(train_dataloader) // training_args.gradient_accumulation_steps + num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1) + num_train_epochs = training_args.max_steps // num_update_steps_per_epoch + int( + training_args.max_steps % num_update_steps_per_epoch > 0 + ) + + global_step = 1 + pp_data_buffer = [] + load_model(model) + model.train() + model._prepare_pipeline_inputs_func = _prepare_pipeline_inputs_func + for epoch_idx in range(num_train_epochs): + for step, inputs in enumerate(train_dataloader): + input_ids, labels = inputs["input_ids"], inputs["labels"] + print(f"===> input_ids: {input_ids._md5sum()}") + print(f"===> labels: {labels._md5sum()}") + pp_data_buffer.append(inputs) + if len(pp_data_buffer) < training_args.gradient_accumulation_steps: + continue + + pp_inputs = model._prepare_pipeline_inputs_func(pp_data_buffer) + model.micro_batch_size = training_args.per_device_train_batch_size + model.accumulate_steps = training_args.gradient_accumulation_steps + + pp_inputs = model._prepare_training(pp_inputs, optimizer, lr_scheduler) + + loss = model.forward_backward_pipeline(pp_inputs) + # hack for align with auto + # sync_grad(model) + # print_grad(model) + optimizer.step() + # print_param(model) + lr_scheduler.step() + optimizer.clear_grad() + + print(f"global_step {global_step}; loss {loss.item()}; ls {optimizer.get_lr()}") + pp_data_buffer.clear() + + if global_step >= 1: + # save_model(model) + sys.exit(0) + + global_step += 1 + + +def save_model(model): + hcg = fleet.get_hybrid_communicate_group() + dp_rank = hcg.get_data_parallel_rank() + mp_degree = hcg.get_model_parallel_world_size() + mp_rank = hcg.get_model_parallel_rank() + pp_rank = hcg.get_stage_id() + if dp_rank > 0: + return + state_dict = model.state_dict() + for (k, v) in state_dict.items(): + print(f"{k}=>{v.name} {v.shape}") + paddle.save(state_dict, f"hand/pp{pp_rank:02d}mp{mp_rank:02d}.pdparams") + group = hcg.get_model_parallel_group() + + # evenly ditribute param + node_model_state = NodeModelState() + node_model_state.add_weights(state_dict, mp_rank) + + def merge_func(k, v): + assert len(v) == mp_degree + tensor_list = [e[1] for e in v] + return merge_mp_tensor_list(k, tensor_list) + + node_model_state = node_model_state.even_distribute(group) + node_model_state = node_model_state.collapse_key().merge_items(merge_func) + + def filter_func(name): + return True + + all_state_dict = all_gather_state_dict(node_model_state.model_weights, filter_func, group) + if mp_rank > 0: + return + paddle.save(all_state_dict, f"hand/pp{pp_rank:02d}.pdparams") + group = hcg.get_pipe_parallel_group() + all_state_dict = all_gather_state_dict(all_state_dict, filter_func, group) + if 
pp_rank > 0: + return + paddle.save(all_state_dict, "hand/all.pdparams") + + +def merge_tensor(tensor_list, fuse_num, axis): + if fuse_num > 1: + part_list = [paddle.split(e, num_or_sections=fuse_num, axis=axis) for e in tensor_list] + fuse_list = [paddle.concat(x=e, axis=axis) for e in zip(*part_list)] + return paddle.concat(x=fuse_list, axis=axis) + else: + return paddle.concat(x=tensor_list, axis=axis) + + +def load_model(model): + hcg = fleet.get_hybrid_communicate_group() + mp_rank = hcg.get_model_parallel_rank() + pp_rank = hcg.get_stage_id() + state_dict = paddle.load(f"hand/pp{pp_rank:02d}mp{mp_rank:02d}.pdparams") + model.set_state_dict(state_dict) + + +class LossMean(PyLayer): + @staticmethod + def forward(ctx, inp, group): + with paddle.no_grad(): + inps = [] + paddle.distributed.all_gather(inps, inp, group=group) + return (inps[0] + inps[1]) / 2.0 + + @staticmethod + def backward(ctx, grad): + return grad + + +def sync_grad(model): + model_state_dict = model.state_dict() + name_mapping = {v.name: k for (k, v) in model_state_dict.items()} + for p in model.parameters(): + assert p.name in name_mapping + grad = p.grad + reduce_dp(grad) + + +def print_grad(model): + model_state_dict = model.state_dict() + name_mapping = {v.name: k for (k, v) in model_state_dict.items()} + for p in model.parameters(): + assert p.name in name_mapping + grad = p.grad + grad = merge_mp(name_mapping[p.name], grad) + print(f"{name_mapping[p.name]} {p.name}_grad shape: {grad.shape} md5sum: {grad._md5sum()}") + + +def print_param(model): + model_state_dict = model.state_dict() + name_mapping = {v.name: k for (k, v) in model_state_dict.items()} + for p in model.parameters(): + tmp = merge_mp(name_mapping[p.name], p) + print(f"{name_mapping[p.name]} {p.name} shape: {tmp.shape} md5sum: {tmp._md5sum()}") + + +def merge_mp(k, input): + hcg = fleet.get_hybrid_communicate_group() + mp_degree = hcg.get_model_parallel_world_size() + if mp_degree <= 1: + return input + else: + group = hcg.get_model_parallel_group() + with paddle.no_grad(): + inps = [] + paddle.distributed.all_gather(inps, input, group=group) + return merge_mp_tensor_list(k, inps) + + +def concat_dp(input): + hcg = fleet.get_hybrid_communicate_group() + dp_degree = hcg.get_data_parallel_world_size() + if dp_degree <= 1: + return input + else: + group = hcg.get_data_parallel_group() + return concat(input, 0, group) + + +def concat(input, axis, group): + with paddle.no_grad(): + inps = [] + paddle.distributed.all_gather(inps, input, group=group) + return paddle.concat(x=inps, axis=axis) + + +def reduce_dp(input): + hcg = fleet.get_hybrid_communicate_group() + dp_degree = hcg.get_data_parallel_world_size() + if dp_degree <= 1: + return input + else: + group = hcg.get_data_parallel_group() + with paddle.no_grad(): + paddle.distributed.all_reduce(input, group=group) + return input + + +def map_structure_name(k): + if "_layers.llama" in k: + return k + hcg = fleet.get_hybrid_communicate_group() + pp_degree = hcg.get_pipe_parallel_world_size() + if pp_degree < 2: + return k + fs = k.split(".") + idx = int(fs[1]) + if idx == 0: + return "_layers.llama.embed_tokens.weight" + if idx == 33: + return "_layers.llama.norm.weight" + if idx == 34: + return "_layers.lm_head.weight" + else: + return f"_layers.llama.layers.{idx-1}." 
+ ".".join(fs[2:]) + + +def merge_mp_tensor_list(k, tensor_list): + # merge by col + k = map_structure_name(k) + if "self_attn.qkv_proj.weight" in k: + return merge_tensor(tensor_list, 3, 1) + elif "self_attn.qkv_proj.bias" in k: + return merge_tensor(tensor_list, 3, 0) + elif "self_attn.q_proj.weight" in k: + return merge_tensor(tensor_list, 1, 1) + elif "self_attn.k_proj.weight" in k: + return merge_tensor(tensor_list, 1, 1) + elif "self_attn.v_proj.weight" in k: + return merge_tensor(tensor_list, 1, 1) + elif "mlp.up_gate_proj.weight" in k: + return merge_tensor(tensor_list, 2, 1) + elif "mlp.up_proj.weight" in k: + return merge_tensor(tensor_list, 1, 1) + elif "mlp.gate_proj.weight" in k: + return merge_tensor(tensor_list, 1, 1) + elif "lm_head.weight" in k: + return merge_tensor(tensor_list, 1, 1) + elif "mlp.up_gate_proj.bias" in k: + return merge_tensor(tensor_list, 2, 0) + # merge by row + elif "self_attn.o_proj.weight" in k: + return merge_tensor(tensor_list, 1, 0) + elif "mlp.down_proj.weight" in k: + return merge_tensor(tensor_list, 1, 0) + elif "embed_tokens.weight" in k: + return merge_tensor(tensor_list, 1, 0) + else: + assert "norm" in k, k + # duplicate + return tensor_list[0] + + +if __name__ == "__main__": + main() diff --git a/llm/llama/auto_parallel/run_pretrain_3D_hand.sh b/llm/llama/auto_parallel/run_pretrain_3D_hand.sh new file mode 100644 index 000000000000..588c2e8707dd --- /dev/null +++ b/llm/llama/auto_parallel/run_pretrain_3D_hand.sh @@ -0,0 +1,73 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# just for debug + +set -x +unset CUDA_VISIBLE_DEVICES + +export FLAGS_call_stack_level=3 +export FLAGS_use_cuda_managed_memory=true +task_name="llama_auto_dp2mp2pp2" +rm -rf output/$task_name/ +rm -rf "output/$task_name""_log" + +export SOT_LOG_LEVEL=4 +export PYTHONPATH=../../../:$PYTHONPATH +#ulimit -c unlimited +#export GLOG_v=10 +export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" +export FLAGS_embedding_deterministic=1 +export FLAGS_cudnn_deterministic=1 +export NVIDIA_TF32_OVERRIDE=0 +python3.8 -u -m paddle.distributed.launch \ + --gpus "0, 1,2,3,4,5,6,7" \ + --log_dir "hand_3d" \ + run_pretrain_3D_hand.py \ + --model_type "llama" \ + --model_name_or_path "facebook/llama-7b" \ + --tokenizer_name_or_path "facebook/llama-7b" \ + --input_dir "./data" \ + --output_dir "output/$task_name" \ + --split 949,50,1 \ + --max_seq_length 2048 \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 2 \ + --gradient_accumulation_steps 2 \ + --use_flash_attention 0 \ + --use_fused_rms_norm 1 \ + --fp16 0 \ + --fp16_opt_level "O2" \ + --scale_loss 1024 \ + --pipeline_parallel_degree 2 \ + --tensor_parallel_degree 2 \ + --sharding_parallel_degree 1 \ + --learning_rate 0.0001 \ + --min_learning_rate 0.00001 \ + --max_steps 20000 \ + --save_steps 5000000 \ + --weight_decay 0.01 \ + --warmup_ratio 0.01 \ + --logging_steps 1\ + --dataloader_num_workers 1 \ + --sharding "" \ + --eval_steps 1000000 \ + --disable_tqdm true \ + --continue_training 0\ + --recompute 0 \ + --do_train \ + --do_eval \ + --device "gpu" \ + --data_impl "mmap" \ + --max_grad_norm 1.0 \ diff --git a/model_zoo/gpt-3/external_ops/fused_ln/layer_norm_cuda.cu b/model_zoo/gpt-3/external_ops/fused_ln/layer_norm_cuda.cu index 68c563b09235..6cdb31f31461 100644 --- a/model_zoo/gpt-3/external_ops/fused_ln/layer_norm_cuda.cu +++ b/model_zoo/gpt-3/external_ops/fused_ln/layer_norm_cuda.cu @@ -58,8 +58,9 @@ std::vector RMSLnFwd(const paddle::Tensor &x, auto place = x.place(); auto y = paddle::empty(x_shape, scale.type(), place); - auto invvar = paddle::empty({rows}, paddle::DataType::FLOAT32, place); - + auto variance_shape = x_shape; + variance_shape.pop_back(); + auto invvar = paddle::empty(variance_shape, paddle::DataType::FLOAT32, place); cuda_rms_norm(x, scale, rows, cols, epsilon, &y, &invvar); return {y, invvar}; } @@ -104,9 +105,9 @@ std::vector> RMSLnFwdInferShape( std::vector x_shape, std::vector scale_shape, float epsilon) { - int rows, cols; - GetRowsCols(x_shape, &rows, &cols); - return {x_shape, {rows}}; + auto variance_shape = x_shape; + variance_shape.pop_back(); + return {x_shape, variance_shape}; } std::vector LnFwdInferDtype(paddle::DataType x_dtype, diff --git a/paddlenlp/trainer/utils/reshard/common.py b/paddlenlp/trainer/utils/reshard/common.py index b75b134f0ff5..cc834862e299 100644 --- a/paddlenlp/trainer/utils/reshard/common.py +++ b/paddlenlp/trainer/utils/reshard/common.py @@ -217,7 +217,8 @@ def collapse(state, l): for key in state_keys: assert len(key) == 2 k, rank = key - assert len(k) == l + if isinstance(k, tuple): + assert len(k) == l if k != pre: pre = k state[k] = [] @@ -468,7 +469,8 @@ def merge(state, l): (state, tmp_state) = (tmp_state, state) state_keys = list(tmp_state.keys()) for key in state_keys: - assert len(key) == l + if isinstance(key, tuple): + assert len(key) == l v = tmp_state[key] v = sorted(v, key=lambda x: x[0]) state[key] = merge_func(key, v) diff --git a/paddlenlp/transformers/llama/__init__.py b/paddlenlp/transformers/llama/__init__.py index 9ba209fcbd39..0d0c008d3960 100644 --- 
a/paddlenlp/transformers/llama/__init__.py +++ b/paddlenlp/transformers/llama/__init__.py @@ -14,6 +14,7 @@ from .configuration import * from .modeling import * +from .modeling_3D_auto import * from .modeling_auto import * from .modeling_pp import * from .tokenizer import * diff --git a/paddlenlp/transformers/llama/modeling_3D_auto.py b/paddlenlp/transformers/llama/modeling_3D_auto.py new file mode 100644 index 000000000000..86ee620a15a8 --- /dev/null +++ b/paddlenlp/transformers/llama/modeling_3D_auto.py @@ -0,0 +1,1259 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Paddle Llama model""" +from __future__ import annotations + +import math +import warnings +from functools import partial +from typing import Optional, Tuple + +import paddle +import paddle.distributed as dist +import paddle.nn.functional as F +from paddle import nn +from paddle.distributed import fleet +from paddle.distributed.fleet.utils import recompute + +try: + from paddle.incubate.nn.functional import fused_rotary_position_embedding +except ImportError: + fused_rotary_position_embedding = None + +from paddlenlp.transformers.conversion_utils import ( + StateDictNameMapping, + init_name_mappings, +) +from paddlenlp.transformers.model_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithCrossAttentions, +) +from paddlenlp.transformers.model_utils import PretrainedModel, register_base_model + +from .configuration import ( + LLAMA_PRETRAINED_INIT_CONFIGURATION, + LLAMA_PRETRAINED_RESOURCE_FILES_MAP, + LlamaConfig, +) +from .modeling import ( + LlamaDynamicNTKScalingRotaryEmbedding, + LlamaLinearScalingRotaryEmbedding, + LlamaNTKScalingRotaryEmbedding, + LlamaRotaryEmbedding, + _expand_2d_mask, + apply_rotary_pos_emb, + build_alibi_tensor, + get_triangle_upper_mask, + is_casual_mask, + repeat_kv, + rms_norm_fused, +) + +try: + from paddle.nn.functional.flash_attention import flash_attention +except: + flash_attention = None + +__all__ = [ + "LlamaForCausalLM3DAuto", +] + + +def get_mesh(pp_idx=0): + mesh = fleet.auto.get_mesh() + if "pp" in mesh.dim_names: + mesh = mesh.get_mesh_with_dim("pp")[pp_idx] + return mesh + + +def _make_causal_mask(input_ids_shape, past_key_values_length): + """ + Make causal mask used for self-attention + """ + batch_size, target_length = input_ids_shape # target_length: seq_len + + mask = paddle.tril(paddle.ones((target_length, target_length), dtype="bool")) + + if past_key_values_length > 0: + # [tgt_len, tgt_len + past_len] + mask = paddle.concat([paddle.ones([target_length, past_key_values_length], dtype="bool"), mask], axis=-1) + + # [bs, 1, tgt_len, tgt_len + past_len] + return mask[None, None, :, :].expand([batch_size, 1, target_length, target_length + past_key_values_length]) + + +attention_cnt = 0 + + +def scaled_dot_product_attention( + query_states, + config, + key_states, + value_states, + attention_mask, + 
    output_attentions,
+    alibi=None,
+):
+    bsz, q_len, num_heads, head_dim = query_states.shape
+    _, kv_seq_len, _, _ = value_states.shape
+
+    if config.use_flash_attention and flash_attention:
+        # Flash Attention currently ignores the attention mask
+        # (the current Flash Attention kernel doesn't support an attention mask).
+        # Paddle Flash Attention input: [bz, seqlen, nhead, head_dim]
+        # Torch Flash Attention input: [bz, nhead, seqlen, head_dim]
+        if alibi is not None:
+            attention_mask = attention_mask.cast(alibi.dtype) + alibi
+        version = paddle.version.full_version
+        if version != "0.0.0" and version <= "2.5.2":
+            attn_output, attn_weights = flash_attention(
+                query_states,
+                key_states,
+                value_states,
+                causal=True,
+                return_softmax=output_attentions,
+            )
+        else:
+            attn_output = F.scaled_dot_product_attention(
+                query_states,
+                key_states,
+                value_states,
+                attn_mask=attention_mask,
+                is_causal=attention_mask is None,
+            )
+            attn_weights = None
+
+        attn_output = attn_output.reshape([bsz, q_len, head_dim * num_heads])
+        return (attn_output, attn_weights) if output_attentions else attn_output
+    else:
+        # [bz, seqlen, nhead, head_dim] -> [bs, nhead, seq_len, head_dim]
+        query_states = paddle.transpose(query_states, [0, 2, 1, 3])
+        # merge with the next transpose
+        key_states = paddle.transpose(key_states, [0, 2, 1, 3])
+        value_states = paddle.transpose(value_states, [0, 2, 1, 3])
+        global attention_cnt  # debug counter, incremented at the end of this branch
+        # matmul and divide by sqrt(head_dim)
+        attn_weights = paddle.matmul(query_states / math.sqrt(head_dim), key_states.transpose([0, 1, 3, 2]))
+        # then add alibi bias
+        if alibi is not None:
+            alibi = alibi.reshape([bsz, num_heads, 1, -1])
+            attn_weights = attn_weights + alibi
+
+        if list(attn_weights.shape) != [bsz, num_heads, q_len, kv_seq_len]:
+            raise ValueError(
+                f"Attention weights should be of shape {(bsz, num_heads, q_len, kv_seq_len)}, but is"
+                f" {attn_weights.shape}"
+            )
+
+        # NOTE: we only call get_triangle_upper_mask under the PP setup.
+        # FIXME(ZHUI): when we use pipeline parallel, the attention_mask can be None,
+        # so we just build a triangle_upper_mask here.
+        if attention_mask is None:
+            attention_mask = get_triangle_upper_mask(attn_weights)
+
+        attention_mask = attention_mask.reshape([bsz, 1, q_len, kv_seq_len])
+        if list(attention_mask.shape) != [bsz, 1, q_len, kv_seq_len]:
+            raise ValueError(
+                f"Attention mask should be of shape {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.shape}"
+            )
+        attn_weights = attn_weights + attention_mask
+        # the same softmax path is used in both static and dynamic mode
+        attn_weights = F.softmax(attn_weights, axis=-1, dtype="float32").astype(query_states.dtype)
+        attn_output = paddle.matmul(attn_weights, value_states)
+        attn_output = attn_output.transpose([0, 2, 1, 3])
+        attn_output = attn_output.reshape([bsz, q_len, head_dim
* num_heads]) + attention_cnt = attention_cnt + 1 + return (attn_output, attn_weights) if output_attentions else attn_output + + +class LlamaRMSNormAuto(nn.Layer): + def __init__(self, config): + super().__init__() + self.hidden_size = config.hidden_size + self.weight = paddle.create_parameter( + shape=[self.hidden_size], + dtype=paddle.get_default_dtype(), + default_initializer=nn.initializer.Constant(1.0), + ) + self.variance_epsilon = config.rms_norm_eps + self.config = config + + def forward(self, hidden_states): + if self.config.use_fused_rms_norm: + tmp = rms_norm_fused(hidden_states, self.weight, self.variance_epsilon) + print(f"rms {tmp.placements}") + return tmp + + if paddle.in_dynamic_mode(): + variance = hidden_states.astype("float32").pow(2).mean(-1, keepdim=True) + hidden_states = paddle.rsqrt(variance + self.variance_epsilon) * hidden_states + else: + variance = hidden_states.astype("float32").pow(2).mean(-1, keepdim=True) + hidden_states = paddle.rsqrt(variance + self.variance_epsilon) * hidden_states + + if self.weight.dtype in [paddle.float16, paddle.bfloat16]: + hidden_states = paddle.cast(hidden_states, self.weight.dtype) + + return hidden_states * self.weight + + +class LlamaMLPAuto(nn.Layer): + def __init__(self, config, ipp: Optional[int] = None): + super().__init__() + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.fuse_attention_ffn = config.fuse_attention_ffn + self.ipp = ipp + self.config = config + + if config.fuse_attention_ffn: + self.gate_up_fused_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias_attr=False) + """ + self.gate_up_fused_proj.weight = dist.shard_tensor( + self.gate_up_fused_proj.weight, + get_mesh(self.ipp), + [dist.Replicate(), dist.Shard(1)], + ) + """ + else: + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias_attr=False) + """ + self.gate_proj.weight = dist.shard_tensor( + self.gate_proj.weight, + get_mesh(self.ipp), + [dist.Replicate(), dist.Shard(1)], + ) + """ + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias_attr=False) + """ + self.up_proj.weight = dist.shard_tensor( + self.up_proj.weight, + get_mesh(self.ipp), + [dist.Replicate(), dist.Shard(1)], + ) + """ + + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias_attr=False) + """ + self.down_proj.weight = dist.shard_tensor( + self.down_proj.weight, + get_mesh(self.ipp), + [dist.Replicate(), dist.Shard(0)], + ) + """ + + def forward(self, x): + if self.fuse_attention_ffn: + gate_out, up_out = paddle.chunk(self.gate_up_fused_proj(x), chunks=2, axis=-1) + out = self.down_proj(F.silu(gate_out) * up_out) + else: + out = self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x)) + return out + + +class LlamaAttentionAuto(nn.Layer): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False, ipp: Optional[int] = None): + super().__init__() + + self.config = config + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + + self.head_dim = self.hidden_size // config.num_attention_heads + + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + + self.max_position_embeddings = config.max_position_embeddings + self.seq_length = config.seq_length + + self.fuse_attention_qkv = config.fuse_attention_qkv + if self.fuse_attention_qkv and 
config.num_attention_heads != config.num_key_value_heads: + raise ValueError( + f"fuse_attention_qkv can't be True when num_attention_heads {config.num_attention_heads}!= num_key_value_heads {config.num_key_value_heads}" + ) + + self.kv_indices = None + # Note that we will actually perform a recompute only if both enable_recompute and layerwise_recompute are set to True + # Enable_recompute defaults to False and is controlled by Trainer + self.enable_recompute = False + self.layerwise_recompute = layerwise_recompute + self.recompute_granularity = config.recompute_granularity + self.ipp = ipp + + self.use_fused_rope = config.use_fused_rope + if self.use_fused_rope: + if "gpu" not in paddle.device.get_device() or fused_rotary_position_embedding is None: + warnings.warn( + "Enable fuse rope in the config, but fuse rope is not available. " + "Will disable fuse rope. Try using latest gpu version of Paddle." + ) + self.use_fused_rope = False + + if self.fuse_attention_qkv: + self.qkv_proj = nn.Linear( + self.hidden_size, + 3 * self.hidden_size, + bias_attr=False, + ) + """ + self.qkv_proj.weight = dist.shard_tensor( + self.qkv_proj.weight, + get_mesh(self.ipp), + [dist.Replicate(), dist.Shard(1)], + ) + """ + else: + self.q_proj = nn.Linear( + self.hidden_size, + self.hidden_size, + bias_attr=False, + ) + """ + self.q_proj.weight = dist.shard_tensor( + self.q_proj.weight, + get_mesh(self.ipp), + [dist.Replicate(), dist.Shard(1)], + ) + """ + + self.k_proj = nn.Linear( + self.hidden_size, + self.config.num_key_value_heads * self.head_dim, + bias_attr=False, + ) + """ + self.k_proj.weight = dist.shard_tensor( + self.k_proj.weight, + get_mesh(self.ipp), + [dist.Replicate(), dist.Shard(1)], + ) + """ + + self.v_proj = nn.Linear( + self.hidden_size, + self.config.num_key_value_heads * self.head_dim, + bias_attr=False, + ) + """ + self.v_proj.weight = dist.shard_tensor( + self.v_proj.weight, + get_mesh(self.ipp), + [dist.Replicate(), dist.Shard(1)], + ) + """ + + self.o_proj = nn.Linear( + self.hidden_size, + self.hidden_size, + bias_attr=False, + ) + """ + self.o_proj.weight = dist.shard_tensor( + self.o_proj.weight, + get_mesh(self.ipp), + [dist.Replicate(), dist.Shard(0)], + ) + """ + + if config.rope: + self._init_rope() + + self.config = config + + def _init_rope(self): + if self.config.rope_scaling_type is None: + self.rotary_emb = LlamaRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + ) + elif self.config.rope_scaling_type == "linear": + self.rotary_emb = LlamaLinearScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=self.config.rope_scaling_factor, + ) + elif self.config.rope_scaling_type == "ntk": + self.rotary_emb = LlamaNTKScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=self.config.rope_scaling_factor, + ) + elif self.config.rope_scaling_type == "dynamic_ntk": + self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=self.config.rope_scaling_factor, + ) + else: + raise ValueError(f"Unknown RoPE scaling type {self.config.rope_scaling_type}") + + def forward( + self, + hidden_states, + position_ids: Optional[Tuple[paddle.Tensor]] = None, + past_key_value: Optional[Tuple[paddle.Tensor]] = None, + attention_mask: Optional[paddle.Tensor] = None, + output_attentions: bool = False, + use_cache: bool = False, + alibi: Optional[paddle.Tensor] 
= None, + ) -> Tuple[paddle.Tensor, Optional[paddle.Tensor], Optional[Tuple[paddle.Tensor]]]: + """Input shape: Batch x Time x Channel""" + # [bs, seq_len, num_head * head_dim] -> [seq_len / n, bs, num_head * head_dim] (n is model parallelism) + # print(f"attention input md5sum {hidden_states._md5sum()}") + if self.fuse_attention_qkv: + target_shape = [0, 0, self.num_heads, 3 * self.head_dim] + mix_layer = self.qkv_proj(hidden_states) + mix_layer = paddle.reshape_(mix_layer, target_shape) + query_states, key_states, value_states = paddle.split(mix_layer, num_or_sections=3, axis=-1) + else: + target_query_shape = [0, 0, self.num_heads, self.head_dim] + target_key_value_shape = [0, 0, self.num_key_value_heads, self.head_dim] + + query_states = self.q_proj(hidden_states).reshape(shape=target_query_shape) + key_states = self.k_proj(hidden_states).reshape(shape=target_key_value_shape) + value_states = self.v_proj(hidden_states).reshape(shape=target_key_value_shape) + + kv_seq_len = key_states.shape[-3] + + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-3] + + if self.config.rope: + if self.use_fused_rope: + assert past_key_value is None, "fuse rotary not support cache kv for now" + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states, _ = fused_rotary_position_embedding( + query_states, + key_states, + v=None, + sin=sin, + cos=cos, + position_ids=position_ids, + use_neox_rotary_style=False, + ) + else: + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + # hack here, because elementwise infer spmd not support broadcast now + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + # [bs, seq_len, num_head, head_dim] + if past_key_value is not None: + # reuse k, v, self_attention + key_states = paddle.concat([past_key_value[0], key_states], axis=1) + value_states = paddle.concat([past_key_value[1], value_states], axis=1) + + past_key_value = (key_states, value_states) if use_cache else None + + if self.kv_indices is not None: + key_states = paddle.index_select(key_states, self.kv_indices, axis=2) + value_states = paddle.index_select(value_states, self.kv_indices, axis=2) + + # TODO(wj-Mcat): use broadcast strategy when n_kv_heads = 1 + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + has_gradient = not (query_states.stop_gradient and key_states.stop_gradient and value_states.stop_gradient) + if ( + self.enable_recompute + and self.layerwise_recompute + and has_gradient + and self.recompute_granularity == "core_attn" + ): + outputs = recompute(scaled_dot_product_attention)( + query_states, + self.config, + key_states, + value_states, + attention_mask, + output_attentions, + alibi, + use_reentrant=self.config.recompute_use_reentrant, + ) + else: + outputs = scaled_dot_product_attention( + query_states, + self.config, + key_states, + value_states, + attention_mask, + output_attentions, + alibi, + ) + if output_attentions: + attn_output, attn_weights = outputs + else: + attn_output = outputs + + # if sequence_parallel is true, out shape are [q_len / n, bs, num_head * head_dim] + # else their shape are [bs, q_len, num_head * head_dim], n is mp parallelism. 
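+        # o_proj maps the concatenated heads back to hidden_size. Its weight is
+        # meant to be sharded along axis 0 (see the commented-out
+        # dist.shard_tensor call in __init__), i.e. the row-parallel mate of the
+        # column-parallel q/k/v projections, so the auto-parallel engine is
+        # expected to insert the corresponding reduction across the mp group.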
+ attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + outputs = (attn_output,) + + if output_attentions: + outputs += (attn_weights,) + + if use_cache: + outputs += (past_key_value,) + + if type(outputs) is tuple and len(outputs) == 1: + outputs = outputs[0] + + return outputs + + +class LlamaDecoderLayerAuto(nn.Layer): + def __init__(self, config, layerwise_recompute: bool = False, ipp: Optional[int] = None, idx=None): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.self_attn = LlamaAttentionAuto(config, layerwise_recompute, ipp) + self.mlp = LlamaMLPAuto(config, ipp) + self.input_layernorm = LlamaRMSNormAuto(config) + self.post_attention_layernorm = LlamaRMSNormAuto(config) + # Note that we will actually perform a recompute only if both enable_recompute and layerwise_recompute are set to True + # Enable_recompute defaults to False and is controlled by Trainer + self.enable_recompute = False + self.layerwise_recompute = layerwise_recompute + self.recompute_granularity = config.recompute_granularity + self.ipp = ipp + self.idx = idx + + def forward( + self, + hidden_states: paddle.Tensor, + position_ids: Optional[Tuple[paddle.Tensor]] = None, + attention_mask: Optional[paddle.Tensor] = None, + output_attentions: Optional[bool] = False, + past_key_value: Optional[Tuple[paddle.Tensor]] = None, + use_cache: Optional[bool] = False, + alibi: Optional[paddle.Tensor] = None, + ) -> Tuple[paddle.Tensor, Optional[Tuple[paddle.Tensor, paddle.Tensor]]]: + """ + Args: + hidden_states (`paddle.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`paddle.Tensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `cache` key value states are returned and can be used to speed up decoding + (see `cache`). 
+ cache (`Tuple(paddle.Tensor)`, *optional*): cached past key and value projection states + """ + + # [bs * seq_len, embed_dim] -> [seq_len * bs / n, embed_dim] (sequence_parallel) + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + """ + if self.idx == 0: + print(f"input_layernorm_{self.idx} shape: {hidden_states.shape} md5sum: {hidden_states._md5sum()}") + """ + # Self Attention + has_gradient = not hidden_states.stop_gradient + if ( + self.enable_recompute + and self.layerwise_recompute + and has_gradient + and self.recompute_granularity == "full_attn" + ): + outputs = recompute(self.self_attn)( + hidden_states, + position_ids, + past_key_value, + attention_mask, + output_attentions, + use_cache, + alibi, + use_reentrant=self.config.recompute_use_reentrant, + ) + else: + outputs = self.self_attn( + hidden_states, + position_ids, + past_key_value, + attention_mask, + output_attentions, + use_cache, + alibi, + ) + + if type(outputs) is tuple: + hidden_states = outputs[0] + else: + hidden_states = outputs + + if output_attentions: + self_attn_weights = outputs[1] + + if use_cache: + present_key_value = outputs[2 if output_attentions else 1] + + hidden_states = residual + hidden_states + """ + if self.idx == 0: + print(f"att_{self.idx} shape: {hidden_states.shape} md5sum: {hidden_states._md5sum()}") + """ + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + """ + if self.idx == 0: + print( + f"post_attention_layernorm_{self.idx} shape: {hidden_states.shape} md5sum: {hidden_states._md5sum()}" + ) + """ + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + # md5 = hidden_states._md5sum() + # print(f"decoder_{self.idx} shape: {hidden_states.shape} md5sum: {md5}") + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + # remove empty tuple for pipeline parallel + if type(outputs) is tuple and len(outputs) == 1: + outputs = outputs[0] + + return outputs + + +class LlamaPretrainedModelAuto(PretrainedModel): + config_class = LlamaConfig + base_model_prefix = "llama" + pretrained_init_configuration = LLAMA_PRETRAINED_INIT_CONFIGURATION + pretrained_resource_files_map = LLAMA_PRETRAINED_RESOURCE_FILES_MAP + _keys_to_ignore_on_load_unexpected = [r"self_attn.rotary_emb.inv_freq"] + + @classmethod + def _get_name_mappings(cls, config: LlamaConfig) -> list[StateDictNameMapping]: + mappings: list[StateDictNameMapping] = [] + model_mappings = [ + ["embed_tokens.weight"], + ["norm.weight"], + ] + for layer_index in range(config.num_hidden_layers): + layer_mappings = [ + [f"layers.{layer_index}.self_attn.q_proj.weight", None, "transpose"], + [f"layers.{layer_index}.self_attn.k_proj.weight", None, "transpose"], + [f"layers.{layer_index}.self_attn.v_proj.weight", None, "transpose"], + [f"layers.{layer_index}.self_attn.o_proj.weight", None, "transpose"], + [f"layers.{layer_index}.self_attn.rotary_emb.inv_freq"], + [f"layers.{layer_index}.mlp.gate_proj.weight", None, "transpose"], + [f"layers.{layer_index}.mlp.down_proj.weight", None, "transpose"], + [f"layers.{layer_index}.mlp.up_proj.weight", None, "transpose"], + [f"layers.{layer_index}.input_layernorm.weight"], + [f"layers.{layer_index}.post_attention_layernorm.weight"], + ] + model_mappings.extend(layer_mappings) + + init_name_mappings(mappings=model_mappings) + # base-model prefix "LlamaModelAuto" + if "LlamaModelAuto" not in 
config.architectures: + for mapping in model_mappings: + mapping[0] = "model." + mapping[0] + mapping[1] = "llama." + mapping[1] + model_mappings.append(["lm_head.weight", "lm_head.weight", "transpose"]) + + mappings = [StateDictNameMapping(*mapping, index=index) for index, mapping in enumerate(model_mappings)] + return mappings + + @classmethod + def _get_tensor_parallel_mappings(cls, config: LlamaConfig, is_split=True): + + from paddlenlp.transformers.conversion_utils import split_or_merge_func + + fn = split_or_merge_func( + is_split=is_split, + tensor_parallel_degree=config.tensor_parallel_degree, + tensor_parallel_rank=config.tensor_parallel_rank, + num_attention_heads=config.num_attention_heads, + ) + + def get_tensor_parallel_split_mappings(num_layers): + final_actions = {} + + base_actions = { + "lm_head.weight": partial(fn, is_column=True), + # Row Linear + "embed_tokens.weight": partial(fn, is_column=False), + "layers.0.self_attn.o_proj.weight": partial(fn, is_column=False), + "layers.0.mlp.down_proj.weight": partial(fn, is_column=False), + } + + # Column Linear + if config.fuse_attention_qkv: + base_actions["layers.0.self_attn.qkv_proj.weight"] = partial(fn, is_column=True) + else: + base_actions["layers.0.self_attn.q_proj.weight"] = partial(fn, is_column=True) + # if we have enough num_key_value_heads to split, then split it. + if config.num_key_value_heads % config.tensor_parallel_degree == 0: + base_actions["layers.0.self_attn.k_proj.weight"] = partial(fn, is_column=True) + base_actions["layers.0.self_attn.v_proj.weight"] = partial(fn, is_column=True) + + if config.fuse_attention_ffn: + base_actions["layers.0.mlp.gate_up_fused_proj.weight"] = partial( + fn, is_column=True, is_naive_2fuse=True + ) + else: + base_actions["layers.0.mlp.gate_proj.weight"] = partial(fn, is_column=True) + base_actions["layers.0.mlp.up_proj.weight"] = partial(fn, is_column=True) + + for key, action in base_actions.items(): + if "layers.0." in key: + for i in range(num_layers): + final_actions[key.replace("layers.0.", f"layers.{i}.")] = action + final_actions[key] = action + + return final_actions + + mappings = get_tensor_parallel_split_mappings(config.num_hidden_layers) + + return mappings + + ''' + def _init_weights(self, layer): + """Initialization hook""" + if isinstance( + layer, + ( + nn.Linear, + nn.Embedding, + LlamaLMHeadAuto, + ), + ): + # In the dygraph mode, use the `set_value` to reset the parameter directly, + # and reset the `state_dict` to update parameter in static mode. + + if isinstance(layer.weight, paddle.Tensor): + layer.weight.set_value( + paddle.tensor.normal( + mean=0.0, + std=self.config.initializer_range + if hasattr(self.config, "initializer_range") + else self.llama.config.initializer_range, + shape=layer.weight.shape, + ) + ) + # Layer.apply is DFS https://github.com/PaddlePaddle/Paddle/blob/a6f5021fcc58b21f4414bae6bf4731ef6971582c/python/paddle/nn/layer/layers.py#L527-L530 + # sublayer is init first + # scale RowParallelLinear weight + with paddle.no_grad(): + if isinstance(layer, LlamaMLPAuto): + factor = 1 / math.sqrt(2 * self.config.num_hidden_layers) + layer.down_proj.weight.scale_(factor) + if isinstance(layer, LlamaAttentionAuto): + factor = 1 / math.sqrt(2 * self.config.num_hidden_layers) + layer.o_proj.weight.scale_(factor) + ''' + + +@register_base_model +class LlamaModelAuto(LlamaPretrainedModelAuto): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`LlamaDecoderLayerAuto`] + Args: + config: LlamaConfig + """ + + def __init__(self, config: LlamaConfig): + super().__init__(config) + self.vocab_size = config.vocab_size + self.hidden_size = config.hidden_size + self.recompute_granularity = config.recompute_granularity + self.no_recompute_layers = config.no_recompute_layers if config.no_recompute_layers is not None else [] + + # Recompute defaults to False and is controlled by Trainer + self.enable_recompute = False + self.embed_tokens = nn.Embedding( + self.vocab_size, + self.hidden_size, + ) + """ + self.embed_tokens.weight = dist.shard_tensor( + self.embed_tokens.weight, + get_mesh(), + [dist.Replicate(), dist.Shard(1)], + ) + """ + + def get_layer_ipp(layer_index): + mesh = fleet.auto.get_mesh() + if "pp" not in mesh.dim_names: + return None + else: + pp_degree = mesh.get_dim_size("pp") + layer_per_stage = math.ceil(config.num_hidden_layers / pp_degree) + return layer_index // layer_per_stage + + self.layers = nn.LayerList( + [ + LlamaDecoderLayerAuto(config, i not in self.no_recompute_layers, get_layer_ipp(i), i) + for i in range(config.num_hidden_layers) + ] + ) + self.norm = LlamaRMSNormAuto(config) + + self.gradient_checkpointing = False + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @staticmethod + def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values_length, dtype): + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + if len(attention_mask.shape) == 2: + expanded_attn_mask = _expand_2d_mask(attention_mask, dtype, tgt_length=input_shape[-1]) + # For decoding phase in generation, seq_length = 1, we don't need to add causal mask + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, past_key_values_length=past_key_values_length + ) + # NOTE(zhaoyingli): infer spmd does not support [seq_len, seq_len] --> [batch, 1, seq_len, seq_len] in data_parallel + + combined_attention_mask = dist.shard_tensor( + combined_attention_mask, get_mesh(), [dist.Shard(0), dist.Replicate()] + ) + + expanded_attn_mask = expanded_attn_mask & combined_attention_mask + # [bsz, seq_len, seq_len] -> [bsz, 1, seq_len, seq_len] + elif len(attention_mask.shape) == 3: + expanded_attn_mask = attention_mask.unsqueeze(1).astype("bool") + # if attention_mask is already 4-D, do nothing + else: + expanded_attn_mask = attention_mask + else: + expanded_attn_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length) + # Convert bool attention_mask to float attention mask, which will be added to attention_scores later + expanded_attn_mask = paddle.where(expanded_attn_mask, 0.0, paddle.finfo(dtype).min).astype(dtype) + return expanded_attn_mask + + def forward( + self, + input_ids=None, + position_ids=None, + attention_mask=None, + inputs_embeds=None, + use_cache=None, + past_key_values=None, + output_attentions=False, + output_hidden_states=None, + return_dict=False, + **kwargs, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not 
None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + if past_key_values is None: + past_key_values = tuple([None] * len(self.layers)) + + seq_length_with_past = seq_length + cache_length = 0 + if past_key_values[0] is not None: + cache_length = paddle.shape(past_key_values[0][0])[1] + seq_length_with_past += cache_length + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + # print(f"inputs_embeds: {inputs_embeds.shape} md5sum: {inputs_embeds._md5sum()}") + + # embed positions + if attention_mask is None: + # [bs, seq_len] + attention_mask = paddle.ones((batch_size, seq_length_with_past), dtype=paddle.bool) + + if self.config.alibi: + alibi = build_alibi_tensor(attention_mask, self.config.num_attention_heads, dtype=inputs_embeds.dtype) + alibi = alibi.reshape([batch_size * self.config.num_attention_heads, 1, seq_length_with_past]) + else: + alibi = None + + if position_ids is None: + position_ids = paddle.arange(seq_length, dtype="int64").expand((batch_size, seq_length)) + # NOTE(zhaoyingli): infer spmd does not support [seq_len] --> [batch, seq_len] in data_parallel + position_ids = dist.shard_tensor(position_ids, get_mesh(), [dist.Shard(0), dist.Replicate()]) + + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype + ) # [bs, 1, seq_len, seq_len] + if self.config.use_flash_attention: + is_casual = is_casual_mask(attention_mask) + if is_casual and alibi is None: + attention_mask = None + hidden_states = inputs_embeds + hidden_states = dist.reshard(hidden_states, get_mesh(), [dist.Shard(0), dist.Replicate()]) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + pre_ipp = 0 + for idx, (decoder_layer) in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + past_key_value = past_key_values[idx] if past_key_values is not None else None + + has_gradient = not hidden_states.stop_gradient + + if decoder_layer.ipp is not None and pre_ipp != decoder_layer.ipp: + hidden_states = dist.reshard( + hidden_states, + get_mesh(decoder_layer.ipp), + [dist.Shard(0), dist.Replicate()], + ) + position_ids = dist.reshard( + position_ids, + get_mesh(decoder_layer.ipp), + [dist.Shard(0), dist.Replicate()], + ) + attention_mask = dist.reshard( + attention_mask, + get_mesh(decoder_layer.ipp), + [dist.Shard(0), dist.Replicate()], + ) + + if ( + self.enable_recompute + and idx not in self.no_recompute_layers + and has_gradient + and self.recompute_granularity == "full" + ): + layer_outputs = recompute(decoder_layer)( + hidden_states, + position_ids, + attention_mask, + output_attentions, + past_key_value, + use_cache, + alibi=alibi, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + position_ids, + attention_mask, + output_attentions, + past_key_value, + use_cache, + alibi=alibi, + ) + + pre_ipp = decoder_layer.ipp + + if type(layer_outputs) is tuple: + hidden_states = layer_outputs[0] + else: + hidden_states = layer_outputs + + if output_attentions: + all_self_attns += 
(layer_outputs[1],) + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=None, + ) + + +loss_cnt = 0 + + +class LlamaPretrainingCriterionAuto(paddle.nn.Layer): + """ + Criterion for Llama. + It calculates the final loss. + """ + + def __init__(self, config): + + super(LlamaPretrainingCriterionAuto, self).__init__() + self.ignore_index = getattr(config, "ignore_index", -100) + self.config = config + self.enable_parallel_cross_entropy = config.tensor_parallel_degree > 1 and config.tensor_parallel_output + self.loss_func = paddle.nn.CrossEntropyLoss(reduction="none", ignore_index=self.ignore_index) + + def forward(self, prediction_scores, masked_lm_labels): + global loss_cnt + if self.enable_parallel_cross_entropy: + if prediction_scores.shape[-1] == self.config.vocab_size: + warnings.warn( + f"enable_parallel_cross_entropy, the vocab_size should be splited: {prediction_scores.shape[-1]}, {self.config.vocab_size}" + ) + self.loss_func = paddle.nn.CrossEntropyLoss(reduction="none", ignore_index=self.ignore_index) + # print(f"prediction_scores_{loss_cnt}: {prediction_scores.shape} md5sum: {prediction_scores._md5sum()}") + masked_lm_loss = self.loss_func(prediction_scores.astype("float32"), masked_lm_labels.unsqueeze(2)) + # print(f"masked_lm_loss_{loss_cnt}: {masked_lm_loss.shape} md5sum: {masked_lm_loss._md5sum()}") + # skip ignore_index which loss == 0 + # masked_lm_loss = masked_lm_loss[masked_lm_loss > 0].astype("float32") + # TODO: solve the issue of conditional block + masked_lm_loss = paddle.masked_select(masked_lm_loss, masked_lm_loss > 0).astype("float32") + loss = paddle.mean(masked_lm_loss) + loss_cnt = loss_cnt + 1 + return loss + + +class LlamaLMHeadAuto(nn.Layer): + def __init__(self, config: LlamaConfig): + super(LlamaLMHeadAuto, self).__init__() + self.config = config + vocab_size = config.vocab_size + self.weight = self.create_parameter( + shape=[config.hidden_size, vocab_size], + dtype=paddle.get_default_dtype(), + ) + """ + self.weight = dist.shard_tensor( + self.create_parameter( + shape=[config.hidden_size, vocab_size], + dtype=paddle.get_default_dtype(), + ), + get_mesh(-1), + [dist.Replicate(), dist.Shard(1)], + ) + """ + + def forward(self, hidden_states, tensor_parallel_output=None): + if tensor_parallel_output is None: + tensor_parallel_output = self.config.tensor_parallel_output + # print(f"llamaout shape: {hidden_states.shape} md5sum: {hidden_states._md5sum()}") + logits = paddle.matmul(hidden_states, self.weight, transpose_y=False) + # print(f"logit {logits.dist_attr}") + return logits + + +class LlamaForCausalLM3DAuto(LlamaPretrainedModelAuto): + enable_to_static_method = True + + def __init__(self, config): + super().__init__(config) + self.config = config + + # dygraph auto_parallel do not support lazy now! 
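+        # (paddle.LazyGuard would create the parameters without allocating or
+        #  initializing them up front; the commented block below appears to be
+        #  kept for reference until dygraph auto parallel supports lazy init.)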
+ # with paddle.LazyGuard(): + # self.llama = LlamaModelAuto(config) + # self.lm_head = LlamaLMHeadAuto(config) + # self.criterion = LlamaPretrainingCriterionAuto(config) + + self.llama = LlamaModelAuto(config) + self.lm_head = LlamaLMHeadAuto(config) + self.criterion = LlamaPretrainingCriterionAuto(config) + + def get_input_embeddings(self): + return self.llama.embed_tokens + + def set_input_embeddings(self, value): + self.llama.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.llama = decoder + + def get_decoder(self): + return self.llama + + def prepare_inputs_for_generation( + self, input_ids, use_cache=False, past_key_values=None, inputs_embeds=None, **kwargs + ): + batch_size, seq_length = input_ids.shape + position_ids = kwargs.get("position_ids", paddle.arange(seq_length).expand((batch_size, seq_length))) + attention_mask = kwargs.get("attention_mask", None) + if past_key_values: + input_ids = input_ids[:, -1].unsqueeze(axis=-1) + position_ids = position_ids[:, -1].unsqueeze(-1) + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + } + ) + return model_inputs + + def _get_model_inputs_spec(self, dtype: str): + return { + "input_ids": paddle.static.InputSpec(shape=[None, None], dtype="int64"), + "attention_mask": paddle.static.InputSpec(shape=[None, None], dtype="int64"), + "position_ids": paddle.static.InputSpec(shape=[None, None], dtype="int64"), + } + + @staticmethod + def update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False): + # update cache + if isinstance(outputs, tuple) and len(outputs) > 1 and not isinstance(outputs[1], paddle.Tensor): + model_kwargs["past_key_values"] = outputs[1] + + if isinstance(outputs, CausalLMOutputWithCrossAttentions) and "past_key_values" in outputs: + model_kwargs["past_key_values"] = outputs.past_key_values + + # update position_ids + if "position_ids" in model_kwargs and model_kwargs["position_ids"] is not None: + position_ids = model_kwargs["position_ids"] + model_kwargs["position_ids"] = paddle.concat([position_ids, position_ids[..., -1:] + 1], axis=-1) + + if not is_encoder_decoder and "attention_mask" in model_kwargs: + attention_mask = model_kwargs["attention_mask"] + model_kwargs["attention_mask"] = paddle.concat( + [attention_mask, paddle.ones([attention_mask.shape[0], 1], dtype=attention_mask.dtype)], axis=-1 + ) + + return model_kwargs + + def forward( + self, + input_ids=None, + labels=None, + position_ids=None, + attention_mask=None, + inputs_embeds=None, + use_cache=False, + past_key_values=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + input_ids.stop_gradient = True + + if not input_ids.is_dist(): + input_ids = dist.shard_tensor(input_ids, get_mesh(), [dist.Shard(0), dist.Replicate()]) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if 
return_dict is not None else self.config.use_return_dict + outputs = self.llama( + input_ids, # [bs, seq_len] + position_ids=position_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + past_key_values=past_key_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] # [bs, seq_len, dim] + + # if labels is None,means we need full output, instead of tensor_parallel_output + # tensor_parallel_output is togather with ParallelCrossEntropy + tensor_parallel_output = ( + self.config.tensor_parallel_output and labels is not None and self.config.tensor_parallel_degree > 1 + ) + + logits = self.lm_head(hidden_states, tensor_parallel_output=tensor_parallel_output) + + loss = None + if labels is not None: + labels.stop_gradient = True + if not labels.is_dist(): + labels = dist.shard_tensor(labels, get_mesh(-1), [dist.Shard(0), dist.Replicate()]) + loss = self.criterion(logits, labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/scripts/distribute/ci_case_auto.sh b/scripts/distribute/ci_case_auto.sh index f22c57832853..52e15505c259 100644 --- a/scripts/distribute/ci_case_auto.sh +++ b/scripts/distribute/ci_case_auto.sh @@ -45,11 +45,13 @@ function gpt_case_list_auto() { } function llama_case_list_auto() { - llama_auto_recompute_bs8_fp32_DP1-MP1-PP1 - llama_auto_recompute_bs16_fp32_DP2-MP1-PP1 - llama_auto_recompute_bs16_fp32_DP2-MP2-PP1 - llama_auto_recompute_bs16_fp32_DP2-MP2-PP2 - llama_auto_recompute_bs16_fp32_DP2-MP2-PP2-VPP2-Sharding2_stage2 + llama_dygraph_auto_bs4_fp32_DP2-MP2-PP2 + + llama_static_auto_recompute_bs8_fp32_DP1-MP1-PP1 + llama_static_auto_recompute_bs16_fp32_DP2-MP1-PP1 + llama_static_auto_recompute_bs16_fp32_DP2-MP2-PP1 + llama_static_auto_recompute_bs16_fp32_DP2-MP2-PP2 + llama_static_auto_recompute_bs16_fp32_DP2-MP2-PP2-VPP2-Sharding2_stage2 } function gpt_case_list_auto_pir() { @@ -834,7 +836,7 @@ function gpt_auto_sp_acc_check() { echo "=========== $FUNCNAME run end ===========" } -function llama_auto_recompute_bs8_fp32_DP1-MP1-PP1() { +function llama_static_auto_recompute_bs8_fp32_DP1-MP1-PP1() { echo "=========== $FUNCNAME run begin ===========" export PYTHONPATH=$root_path/:$PYTHONPATH export FLAGS_call_stack_level=2 @@ -900,7 +902,7 @@ function llama_auto_recompute_bs8_fp32_DP1-MP1-PP1() { echo "=========== $FUNCNAME run end ===========" } -function llama_auto_recompute_bs16_fp32_DP2-MP1-PP1() { +function llama_static_auto_recompute_bs16_fp32_DP2-MP1-PP1() { echo "=========== $FUNCNAME run begin ===========" export PYTHONPATH=$root_path/:$PYTHONPATH export FLAGS_call_stack_level=2 @@ -966,7 +968,7 @@ function llama_auto_recompute_bs16_fp32_DP2-MP1-PP1() { echo "=========== $FUNCNAME run end ===========" } -function llama_auto_recompute_bs16_fp32_DP2-MP2-PP1() { +function llama_static_auto_recompute_bs16_fp32_DP2-MP2-PP1() { echo "=========== $FUNCNAME run begin ===========" export PYTHONPATH=$root_path/:$PYTHONPATH export FLAGS_call_stack_level=2 @@ -1032,7 +1034,7 @@ function llama_auto_recompute_bs16_fp32_DP2-MP2-PP1() { echo "=========== $FUNCNAME run end ===========" } -function llama_auto_recompute_bs16_fp32_DP2-MP2-PP2() { +function 
llama_static_auto_recompute_bs16_fp32_DP2-MP2-PP2() { echo "=========== $FUNCNAME run begin ===========" export PYTHONPATH=$root_path/:$PYTHONPATH export FLAGS_call_stack_level=2 @@ -1098,7 +1100,7 @@ function llama_auto_recompute_bs16_fp32_DP2-MP2-PP2() { echo "=========== $FUNCNAME run end ===========" } -function llama_auto_recompute_bs16_fp32_DP2-MP2-PP2-VPP2-Sharding2_stage2() { +function llama_static_auto_recompute_bs16_fp32_DP2-MP2-PP2-VPP2-Sharding2_stage2() { echo "=========== $FUNCNAME run begin ===========" export PYTHONPATH=$root_path/:$PYTHONPATH export FLAGS_call_stack_level=2 @@ -1165,6 +1167,73 @@ function llama_auto_recompute_bs16_fp32_DP2-MP2-PP2-VPP2-Sharding2_stage2() { check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem} echo "=========== $FUNCNAME run end ===========" } + +function llama_dygraph_auto_bs4_fp32_DP2-MP2-PP2() { + echo "=========== $FUNCNAME run begin ===========" + export PYTHONPATH=$root_path/:$PYTHONPATH + export FLAGS_call_stack_level=3 + export NVIDIA_TF32_OVERRIDE=0 + + task_name="llama_auto_bs16_dp2mp2pp2" + case_out_dir="output/$task_name" + case_log_dir="output/$task_name""_log" + rm -rf $case_out_dir + rm -rf $case_log_dir + + python -u -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" --log_dir $case_log_dir run_pretrain_3D_auto.py \ + --model_type "llama" \ + --model_name_or_path "facebook/llama-7b" \ + --tokenizer_name_or_path "facebook/llama-7b" \ + --input_dir "./data" \ + --output_dir $case_out_dir \ + --split 949,50,1 \ + --max_seq_length 2048 \ + --hidden_size 1024 \ + --intermediate_size 3072 \ + --num_hidden_layers 8 \ + --num_attention_heads 32 \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 2 \ + --gradient_accumulation_steps 2 \ + --use_flash_attention 0 \ + --use_fused_rms_norm 0 \ + --fp16 0 \ + --fp16_opt_level "O2" \ + --scale_loss 1024 \ + --pipeline_parallel_degree 2 \ + --tensor_parallel_degree 2 \ + --sharding_parallel_degree 1 \ + --learning_rate 0.0001 \ + --min_learning_rate 0.00001 \ + --max_steps 10 \ + --save_steps 5000 \ + --weight_decay 0.01 \ + --warmup_ratio 0.01 \ + --logging_steps 1 \ + --dataloader_num_workers 1 \ + --sharding "" \ + --eval_steps 1000000 \ + --disable_tqdm true \ + --continue_training 0 \ + --recompute 0 \ + --do_train \ + --do_eval \ + --device "gpu" \ + --data_impl "mmap" \ + --parallel_mode "auto" \ + --max_grad_norm 1.0 \ + >>${log_path}/$FUNCNAME 2>&1 + loss=`cat $case_log_dir/workerlog.2 | grep 'global_step 10' | awk -F '; loss' '{print $2}' | awk -F 'lr' '{print $1}'` + ips=-1 + mem=-1 + echo "result: loss=$loss ips=$ips mem=$mem" + loss_base=9.543781280517578 + ips_base=-1 + mem_base=-1 + check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem} + echo "=========== $FUNCNAME run end ===========" +} + ############ case end ############ function check_result() {