diff --git a/README.md b/README.md
index ec387f8d0813..600db8333d2d 100644
--- a/README.md
+++ b/README.md
@@ -70,6 +70,7 @@ Unified Checkpoint 大模型存储格式在模型参数分布上支持动态扩
 | [LLama2](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/llama) | meta-llama/Llama-2-7b, meta-llama/Llama-2-7b-chat, meta-llama/Llama-2-13b, meta-llama/Llama-2-13b-chat, meta-llama/Llama-2-70b, meta-llama/Llama-2-70b-chat |
 | [LLama3](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/llama) | meta-llama/Meta-Llama-3-8B, meta-llama/Meta-Llama-3-8B-Instruct, meta-llama/Meta-Llama-3-70B, meta-llama/Meta-Llama-3-70B-Instruct |
 | [LLama3.1](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/llama) | meta-llama/Meta-Llama-3.1-8B, meta-llama/Meta-Llama-3.1-8B-Instruct, meta-llama/Meta-Llama-3.1-70B, meta-llama/Meta-Llama-3.1-70B-Instruct, meta-llama/Meta-Llama-3.1-405B, meta-llama/Meta-Llama-3.1-405B-Instruct, meta-llama/Llama-Guard-3-8B |
+| [LLama3.2](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/llama) | meta-llama/Llama-3.2-1B, meta-llama/Llama-3.2-1B-Instruct, meta-llama/Llama-3.2-3B, meta-llama/Llama-3.2-3B-Instruct, meta-llama/Llama-Guard-3-1B |
 | [Baichuan](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/baichuan) | baichuan-inc/Baichuan-7B, baichuan-inc/Baichuan-13B-Base, baichuan-inc/Baichuan-13B-Chat |
 | [Baichuan2](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/baichuan) | baichuan-inc/Baichuan2-7B-Base, baichuan-inc/Baichuan2-7B-Chat, baichuan-inc/Baichuan2-13B-Base, baichuan-inc/Baichuan2-13B-Chat |
 | [Bloom](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/bloom) | bigscience/bloom-560m, bigscience/bloom-560m-bf16, bigscience/bloom-1b1, bigscience/bloom-3b, bigscience/bloom-7b1, bigscience/bloomz-560m, bigscience/bloomz-1b1, bigscience/bloomz-3b, bigscience/bloomz-7b1-mt, bigscience/bloomz-7b1-p3, bigscience/bloomz-7b1, bellegroup/belle-7b-2m |
@@ -85,7 +86,7 @@ Unified Checkpoint 大模型存储格式在模型参数分布上支持动态扩
 | [Qwen2](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/qwen/) | Qwen/Qwen2-0.5B, Qwen/Qwen2-0.5B-Instruct, Qwen/Qwen2-1.5B, Qwen/Qwen2-1.5B-Instruct, Qwen/Qwen2-7B, Qwen/Qwen2-7B-Instruct, Qwen/Qwen2-72B, Qwen/Qwen2-72B-Instruct, Qwen/Qwen2-57B-A14B, Qwen/Qwen2-57B-A14B-Instruct |
 | [Qwen2-Math](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/qwen/) | Qwen/Qwen2-Math-1.5B, Qwen/Qwen2-Math-1.5B-Instruct, Qwen/Qwen2-Math-7B, Qwen/Qwen2-Math-7B-Instruct, Qwen/Qwen2-Math-72B, Qwen/Qwen2-Math-72B-Instruct, Qwen/Qwen2-Math-RM-72B |
 | [Qwen2.5](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/qwen/) | Qwen/Qwen2.5-0.5B, Qwen/Qwen2.5-0.5B-Instruct, Qwen/Qwen2.5-1.5B, Qwen/Qwen2.5-1.5B-Instruct, Qwen/Qwen2.5-3B, Qwen/Qwen2.5-3B-Instruct, Qwen/Qwen2.5-7B, Qwen/Qwen2.5-7B-Instruct, Qwen/Qwen2.5-14B, Qwen/Qwen2.5-14B-Instruct, Qwen/Qwen2.5-32B, Qwen/Qwen2.5-32B-Instruct, Qwen/Qwen2.5-72B, Qwen/Qwen2.5-72B-Instruct |
-| [Qwen2.5-Math](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/qwen/) | Qwen/Qwen2.5-Math-1.5B, Qwen/Qwen2.5-Math-1.5B-Instruct, Qwen/Qwen2.5-Math-7B, Qwen/Qwen2.5-Math-7B-Instruct, Qwen/Qwen2.5-Math-72B, Qwen/Qwen2.5-Math-72B-Instruct, Qwen/Qwen2.5-Math-RM-72B |
+| [Qwen2.5-Math](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/qwen/) | Qwen/Qwen2.5-Math-1.5B, Qwen/Qwen2.5-Math-1.5B-Instruct, Qwen/Qwen2.5-Math-7B, Qwen/Qwen2.5-Math-7B-Instruct, Qwen/Qwen2.5-Math-72B, Qwen/Qwen2.5-Math-72B-Instruct, Qwen/Qwen2.5-Math-RM-72B |
 | [Qwen2.5-Coder](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/qwen/) | Qwen/Qwen2.5-Coder-1.5B, Qwen/Qwen2.5-Coder-1.5B-Instruct, Qwen/Qwen2.5-Coder-7B, Qwen/Qwen2.5-Coder-7B-Instruct |
 | [Yuan2](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/yuan/) | IEITYuan/Yuan2-2B, IEITYuan/Yuan2-51B, IEITYuan/Yuan2-102B |
@@ -96,9 +97,6 @@ Unified Checkpoint 大模型存储格式在模型参数分布上支持动态扩
 |:---------------------:|:--------:|:------------:|:--------:|:------------:|:------:|:------:|:----------:|
 | | | 基础能力 | 序列并行 | stage1 | stage2 | stage3 | |
 | Llama | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
-| Llama2 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
-| Llama3 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
-| Llama3.1 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
 | Qwen | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
 | Qwen1.5 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
 | Qwen2 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
@@ -119,7 +117,7 @@ Unified Checkpoint 大模型存储格式在模型参数分布上支持动态扩
 | 模型名称/能力支持 | Pretrain | SFT | LoRA | Prefix Tuning | DPO | RLHF | Quantization | Torch convert |
 |:------------------:|:--------:|:---:|:----:|:-------------:|:---:|:----:|:------------:|:-------------:|
-| LLaMA | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
+| Llama | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
 | Qwen | ✅ | ✅ | ✅ | ✅ | ✅ | 🚧 | 🚧 | ✅ |
 | Mixtral | ✅ | ✅ | ✅ | ❌ | 🚧 | 🚧 | 🚧 | 🚧 |
 | Mistral | ✅ | ✅ | ✅ | ✅ | ✅ | 🚧 | 🚧 | ✅ |
@@ -151,7 +149,7 @@ Unified Checkpoint 大模型存储格式在模型参数分布上支持动态扩
 * python >= 3.8
 * paddlepaddle >= 3.0.0b0
 
-如果您尚未安装PaddlePaddle，请参考 [飞桨官网](https://www.paddlepaddle.org.cn/) 进行安装。
+如果您尚未安装 PaddlePaddle，请参考 [飞桨官网](https://www.paddlepaddle.org.cn/) 进行安装。
 
 ### pip 安装
diff --git a/paddlenlp/transformers/llama/modeling.py b/paddlenlp/transformers/llama/modeling.py
index d1415d2e9565..df96485ae06e 100755
--- a/paddlenlp/transformers/llama/modeling.py
+++ b/paddlenlp/transformers/llama/modeling.py
@@ -173,7 +173,7 @@ def assign_kv_heads(num_kv_heads: int, num_gpus: int):
     return assignment_list
 
 
-def parallel_matmul(x: Tensor, y: Tensor, tensor_parallel_output=True):
+def parallel_matmul(x: Tensor, y: Tensor, transpose_y=False, tensor_parallel_output=True):
     is_fleet_init = True
     tensor_parallel_degree = 1
     try:
@@ -191,7 +191,7 @@ def parallel_matmul(x: Tensor, y: Tensor, tensor_parallel_output=True):
     if is_fleet_init and tensor_parallel_degree > 1 and y_is_distributed:
         # if not running under distributed.launch, it will raise AttributeError: 'Fleet' object has no attribute '_hcg'
         input_parallel = paddle.distributed.collective._c_identity(x, group=model_parallel_group)
-        logits = paddle.matmul(input_parallel, y, transpose_y=False)
+        logits = paddle.matmul(input_parallel, y, transpose_y=transpose_y)
 
         if tensor_parallel_output:
             return logits
@@ -199,7 +199,7 @@ def parallel_matmul(x: Tensor, y: Tensor, tensor_parallel_output=True):
         return paddle.distributed.collective._c_concat(logits, group=model_parallel_group)
 
     else:
-        logits = paddle.matmul(x, y, transpose_y=False)
+        logits = paddle.matmul(x, y, transpose_y=transpose_y)
         return logits
 
 
@@ -1267,7 +1267,8 @@ def _get_name_mappings(cls, config: LlamaConfig) -> list[StateDictNameMapping]:
             for mapping in model_mappings:
                 mapping[0] = "model." + mapping[0]
                 mapping[1] = "llama." + mapping[1]
-            model_mappings.append(["lm_head.weight", "lm_head.weight", "transpose"])
+            if not config.tie_word_embeddings:
+                model_mappings.append(["lm_head.weight", "lm_head.weight", "transpose"])
 
         mappings = [StateDictNameMapping(*mapping, index=index) for index, mapping in enumerate(model_mappings)]
         return mappings
@@ -1288,13 +1289,17 @@ def get_tensor_parallel_split_mappings(num_layers):
             final_actions = {}
 
             base_actions = {
-                "lm_head.weight": partial(fn, is_column=True),
                 # Row Linear
                 "embed_tokens.weight": partial(fn, is_column=False),
                 "layers.0.self_attn.o_proj.weight": partial(fn, is_column=False),
                 "layers.0.mlp.down_proj.weight": partial(fn, is_column=False),
             }
 
+            if config.tie_word_embeddings:
+                base_actions["lm_head.weight"] = partial(fn, is_column=False)
+            else:
+                base_actions["lm_head.weight"] = partial(fn, is_column=True)
+
             if not config.vocab_size % config.tensor_parallel_degree == 0:
                 base_actions.pop("lm_head.weight")
                 base_actions.pop("embed_tokens.weight")
@@ -1842,7 +1847,7 @@ def backward(ctx, grad):
 
 
 class LlamaLMHead(nn.Layer):
-    def __init__(self, config: LlamaConfig):
+    def __init__(self, config: LlamaConfig, embedding_weights=None, transpose_y=False):
         super(LlamaLMHead, self).__init__()
         self.config = config
         if config.tensor_parallel_degree > 1 and config.vocab_size % config.tensor_parallel_degree == 0:
@@ -1850,21 +1855,32 @@ def __init__(self, config: LlamaConfig):
         else:
             vocab_size = config.vocab_size
 
-        if vocab_size != config.vocab_size:
-            with get_rng_state_tracker().rng_state():
+        self.transpose_y = transpose_y
+        if transpose_y:
+            if embedding_weights is not None:
+                self.weight = embedding_weights
+            else:
                 self.weight = self.create_parameter(
-                    shape=[config.hidden_size, vocab_size],
+                    shape=[vocab_size, config.hidden_size],
                     dtype=paddle.get_default_dtype(),
                 )
         else:
-            self.weight = self.create_parameter(
-                shape=[config.hidden_size, vocab_size],
-                dtype=paddle.get_default_dtype(),
-            )
+            if vocab_size != config.vocab_size:
+                with get_rng_state_tracker().rng_state():
+                    self.weight = self.create_parameter(
+                        shape=[config.hidden_size, vocab_size],
+                        dtype=paddle.get_default_dtype(),
+                    )
+            else:
+                self.weight = self.create_parameter(
+                    shape=[config.hidden_size, vocab_size],
+                    dtype=paddle.get_default_dtype(),
+                )
         # Must set distributed attr for Tensor Parallel !
         self.weight.is_distributed = True if (vocab_size != config.vocab_size) else False
         if self.weight.is_distributed:
-            self.weight.split_axis = 1
+            # for tie_word_embeddings
+            self.weight.split_axis = 0 if self.transpose_y else 1
         if get_env_device() == "xpu":
             try:
                 from paddle_xpu.layers.nn import (  # noqa: F401
@@ -1892,22 +1908,33 @@ def forward(self, hidden_states, tensor_parallel_output=None):
 
         if get_env_device() == "xpu" and self.xpu_parallel_matmul is not None:
             logits = self.xpu_parallel_matmul(
-                hidden_states, self.weight, tensor_parallel_output=tensor_parallel_output, training=self.training
+                hidden_states,
+                self.weight,
+                transpose_y=self.transpose_y,
+                tensor_parallel_output=tensor_parallel_output,
+                training=self.training,
             )
         else:
-            logits = parallel_matmul(hidden_states, self.weight, tensor_parallel_output=tensor_parallel_output)
+            logits = parallel_matmul(
+                hidden_states, self.weight, transpose_y=self.transpose_y, tensor_parallel_output=tensor_parallel_output
+            )
         return logits
 
 
 class LlamaForCausalLM(LlamaPretrainedModel):
     enable_to_static_method = True
+    _tied_weights_keys = ["lm_head.weight"]
 
     def __init__(self, config):
         super().__init__(config)
         self.config = config
 
         self.llama = LlamaModel(config)
-        self.lm_head = LlamaLMHead(config)
+        if config.tie_word_embeddings:
+            self.lm_head = LlamaLMHead(config, embedding_weights=self.llama.embed_tokens.weight, transpose_y=True)
+            self.tie_weights()
+        else:
+            self.lm_head = LlamaLMHead(config)
         self.criterion = LlamaPretrainingCriterion(config)
 
     def get_input_embeddings(self):
diff --git a/paddlenlp/transformers/llama/modeling_pp.py b/paddlenlp/transformers/llama/modeling_pp.py
index eaf0c1bed534..2efb06a90304 100644
--- a/paddlenlp/transformers/llama/modeling_pp.py
+++ b/paddlenlp/transformers/llama/modeling_pp.py
@@ -17,7 +17,11 @@
 import paddle
 import paddle.distributed.fleet as fleet
 import paddle.nn as nn
-from paddle.distributed.fleet.meta_parallel import LayerDesc, PipelineLayer
+from paddle.distributed.fleet.meta_parallel import (
+    LayerDesc,
+    PipelineLayer,
+    SharedLayerDesc,
+)
 from paddle.distributed.fleet.utils import recompute
 
 from paddlenlp.transformers.model_utils import PipelinePretrainedModel
@@ -102,6 +106,13 @@ def return_args(
     return ret
 
 
+def get_attr(layer, name):
+    if getattr(layer, name, None) is not None:
+        return getattr(layer, name, None)
+    else:
+        return get_attr(layer._layer, name)
+
+
 class LlamaEmbeddingPipe(nn.Layer):
     """Extends LlamaEmbeddings to forward attention_mask through the pipeline."""
 
@@ -119,6 +130,10 @@ def __init__(self, config):
         else:
             self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
 
+    @property
+    def embedding_weight(self):
+        return get_attr(self.embed_tokens, "weight")
+
     def forward(self, args):
         """_summary_
 
@@ -269,6 +284,15 @@ def forward(self, args):
         return self.norm(hidden_states)
 
 
+class LlamaLMHeadPipe(LlamaLMHead):
+    def __init__(self, config, transpose_y=False):
+        super(LlamaLMHeadPipe, self).__init__(config, transpose_y=transpose_y)
+
+    @property
+    def embedding_weight(self):
+        return get_attr(self, "weight")
+
+
 class LlamaForCausalLMPipe(PipelinePretrainedModel, PipelineLayer):
     """LlamaForPretraining adapted for pipeline parallelism.
 
@@ -332,14 +356,35 @@ def get_hcg():
         config.tensor_parallel_degree = tensor_parallel_degree
         config.tensor_parallel_rank = tensor_parallel_rank
 
-        self.add_sequential_layer(LayerDesc(LlamaEmbeddingPipe, config=config), "llama")
+        if config.tie_word_embeddings:
+            self.add_sequential_layer(
+                SharedLayerDesc(
+                    "llama_shared_weight", LlamaEmbeddingPipe, shared_weight_attr="embedding_weight", config=config
+                ),
+                "llama",
+            )
+        else:
+            self.add_sequential_layer(LayerDesc(LlamaEmbeddingPipe, config=config), "llama")
+
         for i in range(config.num_hidden_layers):
             self.add_sequential_layer(
                 LayerDesc(LlamaDecoderLayerPipe, config=config, layerwise_recompute=i not in self.no_recompute_layers),
                 f"llama.layers.{i}",
             )
         self.add_sequential_layer(LayerDesc(LlamaRMSNormPipe, config=config), "llama")
-        self.add_head(config)
+        if config.tie_word_embeddings:
+            self.add_sequential_layer(
+                SharedLayerDesc(
+                    "llama_shared_weight",
+                    LlamaLMHeadPipe,
+                    shared_weight_attr="embedding_weight",
+                    config=config,
+                    **{"transpose_y": True},
+                ),
+                "lm_head",
+            )
+        else:
+            self.add_sequential_layer(LayerDesc(LlamaLMHeadPipe, config=config), "lm_head")
 
         recompute_interval = 0
@@ -366,8 +411,5 @@ def get_hcg():
         # DON'T init PipelinePretrainedModel
         # PipelinePretrainedModel.__init__(self.super(), config=config)
 
-    def add_head(self, config):
-        self.add_sequential_layer(LayerDesc(LlamaLMHead, config=config), "lm_head")
-
     def get_loss_fn(self, config):
        return LlamaPretrainingCriterion(config)
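For reference, the core of the `tie_word_embeddings` change above is that the `[vocab_size, hidden_size]` embedding table is reused as the LM head, so the output projection runs with `transpose_y=True` instead of keeping a separate `[hidden_size, vocab_size]` `lm_head.weight`. Below is a minimal single-card sketch of that idea in plain Paddle; it is not part of the patch, the tiny shapes are made up for illustration, and tensor/pipeline parallelism is ignored:

```python
import paddle
import paddle.nn as nn

vocab_size, hidden_size, batch_size, seq_len = 32, 16, 2, 5

# Token embedding table; its weight has shape [vocab_size, hidden_size].
embed_tokens = nn.Embedding(vocab_size, hidden_size)
token_ids = paddle.randint(0, vocab_size, shape=[batch_size, seq_len])

hidden_states = embed_tokens(token_ids)  # [batch_size, seq_len, hidden_size]
# ... decoder layers would transform hidden_states here ...

# Tied LM head: reuse the embedding table as the output projection.
# Because the shared weight is [vocab_size, hidden_size], the matmul needs
# transpose_y=True -- this mirrors the new transpose_y branch in parallel_matmul.
logits_tied = paddle.matmul(hidden_states, embed_tokens.weight, transpose_y=True)
assert logits_tied.shape == [batch_size, seq_len, vocab_size]

# Untied LM head (the pre-existing path): a separate [hidden_size, vocab_size]
# parameter multiplied without transposition.
lm_head_weight = paddle.create_parameter(shape=[hidden_size, vocab_size], dtype="float32")
logits_untied = paddle.matmul(hidden_states, lm_head_weight, transpose_y=False)
assert logits_untied.shape == [batch_size, seq_len, vocab_size]
```

In the patch itself, the same sharing is achieved by passing `self.llama.embed_tokens.weight` into `LlamaLMHead` for the single-card and tensor-parallel cases (with `split_axis` flipped accordingly), and by registering `LlamaEmbeddingPipe` and `LlamaLMHeadPipe` under one `SharedLayerDesc` key for pipeline parallelism.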