[Models] Add Llama-3.2 #9199

Merged · 3 commits · Sep 27, 2024

10 changes: 4 additions & 6 deletions README.md
@@ -70,6 +70,7 @@ Unified Checkpoint 大模型存储格式在模型参数分布上支持动态扩
| [LLama2](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/llama) | meta-llama/Llama-2-7b, meta-llama/Llama-2-7b-chat, meta-llama/Llama-2-13b, meta-llama/Llama-2-13b-chat, meta-llama/Llama-2-70b, meta-llama/Llama-2-70b-chat |
| [LLama3](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/llama) | meta-llama/Meta-Llama-3-8B, meta-llama/Meta-Llama-3-8B-Instruct, meta-llama/Meta-Llama-3-70B, meta-llama/Meta-Llama-3-70B-Instruct |
| [LLama3.1](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/llama) | meta-llama/Meta-Llama-3.1-8B, meta-llama/Meta-Llama-3.1-8B-Instruct, meta-llama/Meta-Llama-3.1-70B, meta-llama/Meta-Llama-3.1-70B-Instruct, meta-llama/Meta-Llama-3.1-405B, meta-llama/Meta-Llama-3.1-405B-Instruct, meta-llama/Llama-Guard-3-8B |
| [LLama3.2](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/llama) | meta-llama/Llama-3.2-1B, meta-llama/Llama-3.2-1B-Instruct, meta-llama/Llama-3.2-3B, meta-llama/Llama-3.2-3B-Instruct, meta-llama/Llama-Guard-3-1B |
| [Baichuan](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/baichuan) | baichuan-inc/Baichuan-7B, baichuan-inc/Baichuan-13B-Base, baichuan-inc/Baichuan-13B-Chat |
| [Baichuan2](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/baichuan) | baichuan-inc/Baichuan2-7B-Base, baichuan-inc/Baichuan2-7B-Chat, baichuan-inc/Baichuan2-13B-Base, baichuan-inc/Baichuan2-13B-Chat |
| [Bloom](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/bloom) | bigscience/bloom-560m, bigscience/bloom-560m-bf16, bigscience/bloom-1b1, bigscience/bloom-3b, bigscience/bloom-7b1, bigscience/bloomz-560m, bigscience/bloomz-1b1, bigscience/bloomz-3b, bigscience/bloomz-7b1-mt, bigscience/bloomz-7b1-p3, bigscience/bloomz-7b1, bellegroup/belle-7b-2m |
@@ -85,7 +86,7 @@ Unified Checkpoint 大模型存储格式在模型参数分布上支持动态扩
| [Qwen2](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/qwen/) | Qwen/Qwen2-0.5B, Qwen/Qwen2-0.5B-Instruct, Qwen/Qwen2-1.5B, Qwen/Qwen2-1.5B-Instruct, Qwen/Qwen2-7B, Qwen/Qwen2-7B-Instruct, Qwen/Qwen2-72B, Qwen/Qwen2-72B-Instruct, Qwen/Qwen2-57B-A14B, Qwen/Qwen2-57B-A14B-Instruct |
| [Qwen2-Math](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/qwen/) | Qwen/Qwen2-Math-1.5B, Qwen/Qwen2-Math-1.5B-Instruct, Qwen/Qwen2-Math-7B, Qwen/Qwen2-Math-7B-Instruct, Qwen/Qwen2-Math-72B, Qwen/Qwen2-Math-72B-Instruct, Qwen/Qwen2-Math-RM-72B |
| [Qwen2.5](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/qwen/) | Qwen/Qwen2.5-0.5B, Qwen/Qwen2.5-0.5B-Instruct, Qwen/Qwen2.5-1.5B, Qwen/Qwen2.5-1.5B-Instruct, Qwen/Qwen2.5-3B, Qwen/Qwen2.5-3B-Instruct, Qwen/Qwen2.5-7B, Qwen/Qwen2.5-7B-Instruct, Qwen/Qwen2.5-14B, Qwen/Qwen2.5-14B-Instruct, Qwen/Qwen2.5-32B, Qwen/Qwen2.5-32B-Instruct, Qwen/Qwen2.5-72B, Qwen/Qwen2.5-72B-Instruct |
| [Qwen2.5-Math](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/qwen/) | Qwen/Qwen2.5-Math-1.5B, Qwen/Qwen2.5-Math-1.5B-Instruct, Qwen/Qwen2.5-Math-7B, Qwen/Qwen2.5-Math-7B-Instruct, Qwen/Qwen2.5-Math-72B, Qwen/Qwen2.5-Math-72B-Instruct, Qwen/Qwen2.5-Math-RM-72B |
| [Qwen2.5-Math](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/qwen/) | Qwen/Qwen2.5-Math-1.5B, Qwen/Qwen2.5-Math-1.5B-Instruct, Qwen/Qwen2.5-Math-7B, Qwen/Qwen2.5-Math-7B-Instruct, Qwen/Qwen2.5-Math-72B, Qwen/Qwen2.5-Math-72B-Instruct, Qwen/Qwen2.5-Math-RM-72B |
Collaborator Author commented:
Whitespace-only change; the content is unchanged.

| [Qwen2.5-Coder](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/qwen/) | Qwen/Qwen2.5-Coder-1.5B, Qwen/Qwen2.5-Coder-1.5B-Instruct, Qwen/Qwen2.5-Coder-7B, Qwen/Qwen2.5-Coder-7B-Instruct |
| [Yuan2](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/yuan/) | IEITYuan/Yuan2-2B, IEITYuan/Yuan2-51B, IEITYuan/Yuan2-102B |

@@ -96,9 +97,6 @@ Unified Checkpoint 大模型存储格式在模型参数分布上支持动态扩
|:---------------------:|:--------:|:------------:|:--------:|:------------:|:------:|:------:|:----------:|
| | | 基础能力 | 序列并行 | stage1 | stage2 | stage3 | |
| Llama | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
Collaborator Author commented:
Standardize the LLaMA/Llama naming across the different versions.

| Llama2 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| Llama3 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| Llama3.1 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| Qwen | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| Qwen1.5 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| Qwen2 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
@@ -119,7 +117,7 @@ Unified Checkpoint 大模型存储格式在模型参数分布上支持动态扩

| 模型名称/能力支持 | Pretrain | SFT | LoRA | Prefix Tuning | DPO | RLHF | Quantization | Torch convert |
|:------------------:|:--------:|:---:|:----:|:-------------:|:---:|:----:|:------------:|:-------------:|
| LLaMA | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| Llama | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| Qwen | ✅ | ✅ | ✅ | ✅ | ✅ | 🚧 | 🚧 | ✅ |
| Mixtral | ✅ | ✅ | ✅ | ❌ | 🚧 | 🚧 | 🚧 | 🚧 |
| Mistral | ✅ | ✅ | ✅ | ✅ | ✅ | 🚧 | 🚧 | ✅ |
@@ -151,7 +149,7 @@ Unified Checkpoint 大模型存储格式在模型参数分布上支持动态扩
* python >= 3.8
* paddlepaddle >= 3.0.0b0

如果您尚未安装PaddlePaddle,请参考 [飞桨官网](https://www.paddlepaddle.org.cn/) 进行安装。
如果您尚未安装 PaddlePaddle,请参考 [飞桨官网](https://www.paddlepaddle.org.cn/) 进行安装。

### pip 安装

61 changes: 44 additions & 17 deletions paddlenlp/transformers/llama/modeling.py
@@ -173,7 +173,7 @@
return assignment_list


def parallel_matmul(x: Tensor, y: Tensor, tensor_parallel_output=True):
def parallel_matmul(x: Tensor, y: Tensor, transpose_y=False, tensor_parallel_output=True):
is_fleet_init = True
tensor_parallel_degree = 1
try:
@@ -191,15 +191,15 @@
if is_fleet_init and tensor_parallel_degree > 1 and y_is_distributed:
# if not running under distributed.launch, it will raise AttributeError: 'Fleet' object has no attribute '_hcg'
input_parallel = paddle.distributed.collective._c_identity(x, group=model_parallel_group)
logits = paddle.matmul(input_parallel, y, transpose_y=False)
logits = paddle.matmul(input_parallel, y, transpose_y=transpose_y)

if tensor_parallel_output:
return logits

return paddle.distributed.collective._c_concat(logits, group=model_parallel_group)

else:
logits = paddle.matmul(x, y, transpose_y=False)
logits = paddle.matmul(x, y, transpose_y=transpose_y)
return logits
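
Editor's note (not part of the diff): the new `transpose_y` argument exists because a tied LM head reuses the embedding table, which is stored as `[vocab_size, hidden_size]`, while an untied head keeps a `[hidden_size, vocab_size]` weight. A minimal single-card sketch of the shape logic, assuming only a working `paddle` install (the variable names are illustrative, not from the PR):

```python
# Sketch: why a tied LM head needs transpose_y=True. Single card, no tensor parallelism.
import paddle

batch, seq_len, hidden, vocab = 2, 4, 8, 16
hidden_states = paddle.randn([batch, seq_len, hidden])

# Untied head: weight stored as [hidden, vocab]; a plain matmul gives logits.
untied_w = paddle.randn([hidden, vocab])
logits_untied = paddle.matmul(hidden_states, untied_w, transpose_y=False)

# Tied head: the embedding table is [vocab, hidden], so the same projection
# needs transpose_y=True to produce [batch, seq, vocab] logits.
embedding_w = paddle.randn([vocab, hidden])
logits_tied = paddle.matmul(hidden_states, embedding_w, transpose_y=True)

assert logits_untied.shape == logits_tied.shape == [batch, seq_len, vocab]
```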


@@ -1267,7 +1267,8 @@
for mapping in model_mappings:
mapping[0] = "model." + mapping[0]
mapping[1] = "llama." + mapping[1]
model_mappings.append(["lm_head.weight", "lm_head.weight", "transpose"])
if not config.tie_word_embeddings:
model_mappings.append(["lm_head.weight", "lm_head.weight", "transpose"])

mappings = [StateDictNameMapping(*mapping, index=index) for index, mapping in enumerate(model_mappings)]
return mappings
@@ -1288,13 +1289,17 @@
final_actions = {}

base_actions = {
"lm_head.weight": partial(fn, is_column=True),
# Row Linear
"embed_tokens.weight": partial(fn, is_column=False),
"layers.0.self_attn.o_proj.weight": partial(fn, is_column=False),
"layers.0.mlp.down_proj.weight": partial(fn, is_column=False),
}

if config.tie_word_embeddings:
base_actions["lm_head.weight"] = partial(fn, is_column=False)

else:
base_actions["lm_head.weight"] = partial(fn, is_column=True)

if not config.vocab_size % config.tensor_parallel_degree == 0:
base_actions.pop("lm_head.weight")
base_actions.pop("embed_tokens.weight")
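
Editor's note (not part of the diff): the `is_column` switch above follows from the weight layout. With tied embeddings the head weight is `[vocab_size, hidden_size]`, so the vocabulary dimension sits on axis 0 and each tensor-parallel rank takes a row slice (`is_column=False`); untied, the layout is `[hidden_size, vocab_size]` and the shard is a column slice (`is_column=True`). The single-card slicing below only illustrates the two layouts; the real code goes through PaddleNLP's `split_or_merge_func`:

```python
# Illustration only: row vs. column sharding of lm_head.weight under tensor parallelism.
import paddle

vocab, hidden, tp_degree, rank = 16, 8, 4, 1
shard = vocab // tp_degree

# Tied layout [vocab, hidden]: vocab on axis 0 -> row split (is_column=False).
tied_weight = paddle.randn([vocab, hidden])
tied_shard = tied_weight[rank * shard:(rank + 1) * shard, :]      # shape [4, 8]

# Untied layout [hidden, vocab]: vocab on axis 1 -> column split (is_column=True).
untied_weight = paddle.randn([hidden, vocab])
untied_shard = untied_weight[:, rank * shard:(rank + 1) * shard]  # shape [8, 4]

print(tied_shard.shape, untied_shard.shape)
```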
@@ -1842,29 +1847,40 @@


class LlamaLMHead(nn.Layer):
def __init__(self, config: LlamaConfig):
def __init__(self, config: LlamaConfig, embedding_weights=None, transpose_y=False):
super(LlamaLMHead, self).__init__()
self.config = config
if config.tensor_parallel_degree > 1 and config.vocab_size % config.tensor_parallel_degree == 0:
vocab_size = config.vocab_size // config.tensor_parallel_degree
else:
vocab_size = config.vocab_size

if vocab_size != config.vocab_size:
with get_rng_state_tracker().rng_state():
self.transpose_y = transpose_y
if transpose_y:
if embedding_weights is not None:
self.weight = embedding_weights

else:
self.weight = self.create_parameter(
shape=[config.hidden_size, vocab_size],
shape=[vocab_size, config.hidden_size],
dtype=paddle.get_default_dtype(),
)
else:
self.weight = self.create_parameter(
shape=[config.hidden_size, vocab_size],
dtype=paddle.get_default_dtype(),
)
if vocab_size != config.vocab_size:
with get_rng_state_tracker().rng_state():
self.weight = self.create_parameter(

shape=[config.hidden_size, vocab_size],
dtype=paddle.get_default_dtype(),
)
else:
self.weight = self.create_parameter(
shape=[config.hidden_size, vocab_size],
dtype=paddle.get_default_dtype(),
)
# Must set distributed attr for Tensor Parallel !
self.weight.is_distributed = True if (vocab_size != config.vocab_size) else False
if self.weight.is_distributed:
self.weight.split_axis = 1
# for tie_word_embeddings
self.weight.split_axis = 0 if self.transpose_y else 1

if get_env_device() == "xpu":
try:
from paddle_xpu.layers.nn import ( # noqa: F401
@@ -1892,22 +1908,33 @@

if get_env_device() == "xpu" and self.xpu_parallel_matmul is not None:
logits = self.xpu_parallel_matmul(
hidden_states, self.weight, tensor_parallel_output=tensor_parallel_output, training=self.training
hidden_states,
self.weight,
transpose_y=self.transpose_y,
tensor_parallel_output=tensor_parallel_output,
training=self.training,
)
else:
logits = parallel_matmul(hidden_states, self.weight, tensor_parallel_output=tensor_parallel_output)
logits = parallel_matmul(
hidden_states, self.weight, transpose_y=self.transpose_y, tensor_parallel_output=tensor_parallel_output
)
return logits


class LlamaForCausalLM(LlamaPretrainedModel):
enable_to_static_method = True
_tied_weights_keys = ["lm_head.weight"]

def __init__(self, config):
super().__init__(config)
self.config = config

self.llama = LlamaModel(config)
self.lm_head = LlamaLMHead(config)
if config.tie_word_embeddings:
self.lm_head = LlamaLMHead(config, embedding_weights=self.llama.embed_tokens.weight, transpose_y=True)
self.tie_weights()

else:
self.lm_head = LlamaLMHead(config)
self.criterion = LlamaPretrainingCriterion(config)

def get_input_embeddings(self):
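
Editor's note (not part of the diff): when `config.tie_word_embeddings` is set, `LlamaForCausalLM` now passes `self.llama.embed_tokens.weight` into the head, so the embedding and the LM head share one parameter and `tie_weights()` keeps them consistent. A toy sketch of that wiring (plain `paddle`, hypothetical class names, not the PaddleNLP implementation):

```python
# Toy sketch of weight tying: the head reuses the embedding parameter and applies it
# with transpose_y=True, mirroring LlamaLMHead(config, embedding_weights=..., transpose_y=True).
import paddle
import paddle.nn as nn


class TiedLMHead(nn.Layer):
    def __init__(self, embedding_weight):
        super().__init__()
        self.weight = embedding_weight   # shared [vocab, hidden] parameter
        self.transpose_y = True

    def forward(self, hidden_states):
        return paddle.matmul(hidden_states, self.weight, transpose_y=self.transpose_y)


vocab, hidden = 32, 16
embed = nn.Embedding(vocab, hidden)
head = TiedLMHead(embed.weight)

tokens = paddle.randint(0, vocab, [2, 5])
logits = head(embed(tokens))             # shape [2, 5, vocab]

# One parameter object: updating the embedding also updates the head, and vice versa.
assert head.weight is embed.weight
```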
54 changes: 48 additions & 6 deletions paddlenlp/transformers/llama/modeling_pp.py
@@ -17,7 +17,11 @@
import paddle
import paddle.distributed.fleet as fleet
import paddle.nn as nn
from paddle.distributed.fleet.meta_parallel import LayerDesc, PipelineLayer
from paddle.distributed.fleet.meta_parallel import (
LayerDesc,
PipelineLayer,
SharedLayerDesc,
)
from paddle.distributed.fleet.utils import recompute

from paddlenlp.transformers.model_utils import PipelinePretrainedModel
@@ -102,6 +106,13 @@
return ret


def get_attr(layer, name):
if getattr(layer, name, None) is not None:
return getattr(layer, name, None)

else:
return get_attr(layer._layer, name)



class LlamaEmbeddingPipe(nn.Layer):
"""Extends LlamaEmbeddings to forward attention_mask through the pipeline."""

Expand All @@ -119,6 +130,10 @@
else:
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)

@property
def embedding_weight(self):
return get_attr(self.embed_tokens, "weight")


def forward(self, args):
"""_summary_

@@ -269,6 +284,15 @@
return self.norm(hidden_states)


class LlamaLMHeadPipe(LlamaLMHead):
def __init__(self, config, transpose_y=False):
super(LlamaLMHeadPipe, self).__init__(config, transpose_y=transpose_y)


@property
def embedding_weight(self):
return get_attr(self, "weight")



class LlamaForCausalLMPipe(PipelinePretrainedModel, PipelineLayer):
"""LlamaForPretraining adapted for pipeline parallelism.

@@ -332,14 +356,35 @@
config.tensor_parallel_degree = tensor_parallel_degree
config.tensor_parallel_rank = tensor_parallel_rank

self.add_sequential_layer(LayerDesc(LlamaEmbeddingPipe, config=config), "llama")
if config.tie_word_embeddings:
self.add_sequential_layer(

SharedLayerDesc(
"llama_shared_weight", LlamaEmbeddingPipe, shared_weight_attr="embedding_weight", config=config
),
"llama",
)
else:
self.add_sequential_layer(LayerDesc(LlamaEmbeddingPipe, config=config), "llama")


for i in range(config.num_hidden_layers):
self.add_sequential_layer(
LayerDesc(LlamaDecoderLayerPipe, config=config, layerwise_recompute=i not in self.no_recompute_layers),
f"llama.layers.{i}",
)
self.add_sequential_layer(LayerDesc(LlamaRMSNormPipe, config=config), "llama")
self.add_head(config)
if config.tie_word_embeddings:
self.add_sequential_layer(

SharedLayerDesc(
"llama_shared_weight",
LlamaLMHeadPipe,
shared_weight_attr="embedding_weight",
config=config,
**{"transpose_y": True},
),
"lm_head",
)
else:
self.add_sequential_layer(LayerDesc(LlamaLMHeadPipe, config=config), "lm_head")


recompute_interval = 0

@@ -366,8 +411,5 @@
# DON'T init PipelinePretrainedModel
# PipelinePretrainedModel.__init__(self.super(), config=config)

def add_head(self, config):
self.add_sequential_layer(LayerDesc(LlamaLMHead, config=config), "lm_head")

def get_loss_fn(self, config):
return LlamaPretrainingCriterion(config)
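
Editor's note (not part of the diff): in the pipeline-parallel model, the first stage (`LlamaEmbeddingPipe`) and the last stage (`LlamaLMHeadPipe`) both register a `SharedLayerDesc` under the same key, `llama_shared_weight`, and expose the tied tensor through the same `embedding_weight` attribute; that shared key/attribute pair is what lets the pipeline runtime treat the two weights as one logical parameter. The hand-wired, single-process sketch below only illustrates that attribute contract; the actual synchronization across pipeline ranks is not reproduced here:

```python
# Single-process sketch of the shared_weight_attr contract used by SharedLayerDesc.
import paddle
import paddle.nn as nn


class EmbeddingStage(nn.Layer):
    def __init__(self, vocab, hidden):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab, hidden)

    @property
    def embedding_weight(self):          # same attribute name as the head stage
        return self.embed_tokens.weight


class HeadStage(nn.Layer):
    def __init__(self, vocab, hidden):
        super().__init__()
        self.weight = self.create_parameter([vocab, hidden])

    @property
    def embedding_weight(self):          # same attribute name as the embedding stage
        return self.weight

    def forward(self, h):
        return paddle.matmul(h, self.weight, transpose_y=True)


emb, head = EmbeddingStage(32, 16), HeadStage(32, 16)
# Done by hand here; in pipeline parallelism the runtime keeps these in sync across stages.
head.weight.set_value(emb.embedding_weight)
print(bool(paddle.allclose(emb.embedding_weight, head.embedding_weight)))
```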