Support Llama3 (#8315)
* support llama-3

* Add llama-3 tokenizer

* fix for llama3
ZHUI authored Apr 23, 2024
1 parent d4062e5 commit 590cee9
Showing 5 changed files with 312 additions and 12 deletions.
3 changes: 2 additions & 1 deletion llm/finetune_generation.py
@@ -45,6 +45,7 @@
     AutoConfig,
     AutoModelForCausalLM,
     AutoTokenizer,
+    Llama3Tokenizer,
     LlamaTokenizer,
 )
 from paddlenlp.utils.log import logger
@@ -232,7 +233,7 @@ def neft_post_hook(module, input, output):
     if tokenizer.chat_template is not None:
         data_args.eval_with_do_generation = False
 
-    if isinstance(tokenizer, LlamaTokenizer):
+    if isinstance(tokenizer, LlamaTokenizer) or isinstance(tokenizer, Llama3Tokenizer):
         tokenizer.pad_token_id = tokenizer.eos_token_id
 
     if data_args.dataset_name_or_path is None:
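For context, Llama-style tokenizers ship without a dedicated padding token, so the fine-tuning script reuses the end-of-sequence token for padding. A minimal sketch of the same setup in isolation (the checkpoint name is an illustrative assumption, not taken from this diff):

# Sketch: reuse EOS as the pad token when a Llama/Llama-3 tokenizer defines none.
from paddlenlp.transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")  # illustrative name
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id  # padding now uses the EOS id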
21 changes: 14 additions & 7 deletions paddlenlp/transformers/auto/tokenizer.py
@@ -189,13 +189,20 @@ def _get_tokenizer_class_from_config(cls, pretrained_model_name_or_path, config_
         init_class = init_kwargs.pop("tokenizer_class", None)
 
         if init_class:
-            class_name = cls._name_mapping[init_class]
-            import_class = import_module(f"paddlenlp.transformers.{class_name}.tokenizer")
-            tokenizer_class = getattr(import_class, init_class)
-            if use_fast:
-                fast_tokenizer_class = cls._get_fast_tokenizer_class(init_class, class_name)
-                tokenizer_class = fast_tokenizer_class if fast_tokenizer_class else tokenizer_class
-            return tokenizer_class
+            if init_class in cls._name_mapping:
+                class_name = cls._name_mapping[init_class]
+                import_class = import_module(f"paddlenlp.transformers.{class_name}.tokenizer")
+                tokenizer_class = getattr(import_class, init_class)
+                if use_fast:
+                    fast_tokenizer_class = cls._get_fast_tokenizer_class(init_class, class_name)
+                    tokenizer_class = fast_tokenizer_class if fast_tokenizer_class else tokenizer_class
+                return tokenizer_class
+            else:
+                import_class = import_module("paddlenlp.transformers")
+                tokenizer_class = getattr(import_class, init_class, None)
+                assert tokenizer_class is not None, f"Can't find tokenizer {init_class}"
+                return tokenizer_class
 
         # If no `init_class`, we use pattern recognition to recognize the tokenizer class.
         else:
             # TODO: Potential issue https://github.com/PaddlePaddle/PaddleNLP/pull/3786#discussion_r1024689810
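The new branch lets a tokenizer_class recorded in tokenizer_config.json (for example Llama3Tokenizer, which has no entry in the hard-coded name mapping) be resolved straight from the top-level paddlenlp.transformers namespace. A standalone sketch of that resolution order, with an illustrative config dict and a trimmed-down stand-in for the mapping:

# Sketch of the fallback lookup; the dicts below are illustrative stand-ins.
from importlib import import_module

init_kwargs = {"tokenizer_class": "Llama3Tokenizer"}   # as read from tokenizer_config.json
name_mapping = {"LlamaTokenizer": "llama"}             # stand-in for cls._name_mapping

init_class = init_kwargs.pop("tokenizer_class", None)
if init_class in name_mapping:
    # Known class: import its dedicated tokenizer module.
    module = import_module(f"paddlenlp.transformers.{name_mapping[init_class]}.tokenizer")
    tokenizer_class = getattr(module, init_class)
else:
    # Unknown class: fall back to the package root, where it is re-exported.
    module = import_module("paddlenlp.transformers")
    tokenizer_class = getattr(module, init_class, None)
    assert tokenizer_class is not None, f"Can't find tokenizer {init_class}"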
2 changes: 2 additions & 0 deletions paddlenlp/transformers/llama/configuration.py
@@ -147,6 +147,7 @@ def __init__(
         num_key_value_heads=None,
         initializer_range=0.02,
         rms_norm_eps=1e-6,
+        rope_theta=10000.0,
         use_cache=True,
         use_recompute=False,
         recompute_granularity="full",
@@ -188,6 +189,7 @@ def __init__(
 
         self.initializer_range = initializer_range
         self.rms_norm_eps = rms_norm_eps
+        self.rope_theta = rope_theta
 
         self.use_cache = use_cache
         self.use_recompute = use_recompute
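rope_theta is the base of the rotary position embedding: per pair of head dimensions the inverse frequency is base^(-2i/d), so a larger base gives longer rotation periods. Llama 1/2 checkpoints use the default 10000.0, while Llama 3 configs set a much larger value (500000.0). A small sketch of how the base enters the frequency table:

import numpy as np

def rope_inv_freq(head_dim: int, base: float = 10000.0) -> np.ndarray:
    # One inverse frequency per pair of dimensions: base ** (-2i / head_dim).
    return 1.0 / (base ** (np.arange(0, head_dim, 2, dtype="float64") / head_dim))

print(rope_inv_freq(128, base=10000.0)[:3])    # Llama-1/2 default base
print(rope_inv_freq(128, base=500000.0)[:3])   # larger base used by Llama-3 configs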
9 changes: 6 additions & 3 deletions paddlenlp/transformers/llama/modeling.py
@@ -813,24 +813,28 @@ def _init_rope(self):
             self.rotary_emb = LlamaRotaryEmbedding(
                 self.head_dim,
                 max_position_embeddings=self.max_position_embeddings,
+                base=self.config.rope_theta,
             )
         elif self.config.rope_scaling_type == "linear":
             self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
                 self.head_dim,
                 max_position_embeddings=self.max_position_embeddings,
                 scaling_factor=self.config.rope_scaling_factor,
+                base=self.config.rope_theta,
             )
         elif self.config.rope_scaling_type == "ntk":
             self.rotary_emb = LlamaNTKScalingRotaryEmbedding(
                 self.head_dim,
                 max_position_embeddings=self.max_position_embeddings,
                 scaling_factor=self.config.rope_scaling_factor,
+                base=self.config.rope_theta,
             )
         elif self.config.rope_scaling_type == "dynamic_ntk":
             self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
                 self.head_dim,
                 max_position_embeddings=self.max_position_embeddings,
                 scaling_factor=self.config.rope_scaling_factor,
+                base=self.config.rope_theta,
             )
         else:
             raise ValueError(f"Unknown RoPE scaling type {self.config.rope_scaling_type}")
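Each rotary-embedding variant now receives the configured base instead of a hard-coded 10000. As a rough sketch of how the base and the two common scaling modes interact (function and parameter names here are illustrative, not the PaddleNLP classes; the NTK branch follows the widely used NTK-aware heuristic):

import numpy as np

def rope_angles(positions, head_dim, base=10000.0, scaling_factor=1.0, mode=None):
    # Linear scaling compresses positions; NTK scaling enlarges the base instead.
    if mode == "linear":
        positions = np.asarray(positions, dtype="float64") / scaling_factor
    elif mode == "ntk":
        base = base * scaling_factor ** (head_dim / (head_dim - 2))
    inv_freq = 1.0 / (base ** (np.arange(0, head_dim, 2, dtype="float64") / head_dim))
    return np.outer(positions, inv_freq)   # angles that feed the cos/sin tables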
@@ -903,6 +907,7 @@ def forward(
         query_states = self.q_proj(hidden_states)
         key_states = self.k_proj(hidden_states)
         value_states = self.v_proj(hidden_states)
+
         if self.reshard_layer is not None:
             if self.sequence_parallel:
                 assert self.seq_length % self.config.sep_parallel_degree == 0
@@ -1027,7 +1032,6 @@ def forward(
             value_states = paddle.concat([past_key_value[1], value_states], axis=1)
 
         past_key_value = (key_states, value_states) if use_cache else None
-
         if self.kv_indices is not None:
             key_states = paddle.index_select(key_states, self.kv_indices, axis=2)
             value_states = paddle.index_select(value_states, self.kv_indices, axis=2)
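For reference, the layout implied by this code is [batch, seq_len, num_kv_heads, head_dim]: cached keys/values grow along axis 1, and kv_indices (when set) selects a subset of heads along axis 2. A toy sketch of that append-and-select step, with illustrative shapes:

import paddle

past_k = paddle.randn([1, 4, 8, 64])    # keys cached from earlier decoding steps
new_k = paddle.randn([1, 1, 8, 64])     # key for the current token
key_states = paddle.concat([past_k, new_k], axis=1)        # sequence axis grows: 4 -> 5

kv_indices = paddle.to_tensor([0, 2, 4, 6])                # keep a subset of KV heads
key_states = paddle.index_select(key_states, kv_indices, axis=2)
print(key_states.shape)                                    # [1, 5, 4, 64]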
@@ -1036,7 +1040,7 @@
         # repeat k/v heads if n_kv_heads < n_heads
         # paddle version > 2.6 or develop support flash-attn with gqa/mqa
         paddle_version = float(paddle.__version__[:3])
-        if (paddle_version != 0.0) and (paddle_version <= 2.6):
+        if not self.config.use_flash_attention or ((paddle_version != 0.0) and (paddle_version <= 2.6)):
             key_states = repeat_kv(key_states, self.num_key_value_groups)
             value_states = repeat_kv(value_states, self.num_key_value_groups)
 
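repeat_kv duplicates each key/value head so grouped-query attention can run through the plain attention path; Paddle's flash-attention kernels on versions newer than 2.6 handle GQA/MQA natively, which is why the expansion is now skipped when flash attention is enabled. A minimal sketch of the expansion, assuming a [batch, seq_len, num_kv_heads, head_dim] layout (the helper name is illustrative):

import paddle

def repeat_kv_sketch(x: paddle.Tensor, n_rep: int) -> paddle.Tensor:
    # Tile each KV head n_rep times so num_kv_heads * n_rep == num_query_heads.
    batch, seq_len, num_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    x = x.unsqueeze(3).tile([1, 1, 1, n_rep, 1])
    return x.reshape([batch, seq_len, num_kv_heads * n_rep, head_dim])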
@@ -1560,7 +1564,6 @@ def forward(
         else:
             attention_mask = attention_mask.astype("bool")
         hidden_states = inputs_embeds
-
         # decoder layers
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None
(Diff for the remaining changed file is not shown here.)
