Commit a6bde28

support qwen2 bf16/wint8

yuanlehome committed Aug 7, 2024
1 parent bc75f59 commit a6bde28
Showing 4 changed files with 963 additions and 1 deletion.
53 changes: 52 additions & 1 deletion llm/predict/predictor.py
@@ -547,9 +547,9 @@ def _preprocess(self, source):
            # alibi encoder
            alibi_slopes = get_alibi_slopes(self.model_config.n_head)
            inputs["position_ids"] = paddle.to_tensor(alibi_slopes, dtype="float32")

            arange_tensor_encoder = paddle.arange(self.config.total_max_length, dtype=self.config.dtype)
            alibi = alibi_slopes[None, :, None, None] * arange_tensor_encoder

            if self.model_config.tensor_parallel_degree > 1:
                block_size = self.model_config.n_head // self.model_config.tensor_parallel_degree
                alibi = alibi[
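
The hunk above builds the ALiBi bias by broadcasting per-head slopes over encoder positions. A minimal standalone sketch of that math, with a numpy stand-in for get_alibi_slopes (power-of-two head counts assumed for brevity; the real helper lives in PaddleNLP):

import numpy as np

def alibi_slopes(n_heads: int) -> np.ndarray:
    # Standard ALiBi slopes: 2^(-8/n), 2^(-16/n), ..., one per head.
    start = 2 ** (-8.0 / n_heads)
    return start ** np.arange(1, n_heads + 1)

n_head, max_len = 8, 16
slopes = alibi_slopes(n_head)   # shape [n_head]
positions = np.arange(max_len)  # shape [max_len]
# Same broadcast as the diff: alibi_slopes[None, :, None, None] * positions.
alibi = slopes[None, :, None, None] * positions
print(alibi.shape)  # (1, 8, 1, 16)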
@@ -1380,6 +1380,32 @@ def create_predictor(
                dtype=predictor_args.dtype,
            )
            model.eval()
        elif "qwen2" in config.architectures[0].lower():
            if predictor_args.block_attn:
                config.max_seq_len = predictor_args.total_max_length
                config.block_size = predictor_args.block_size
                from paddlenlp.experimental.transformers import (
                    Qwen2ForCausalLMBlockInferenceModel as Qwen2InferenceModel,
                )

                model = Qwen2InferenceModel.from_pretrained(
                    predictor_args.model_name_or_path,
                    config=config,
                    dtype=predictor_args.dtype,
                    tensor_parallel_degree=tensor_parallel_degree,
                    tensor_parallel_rank=tensor_parallel_rank,
                )
            else:
                from paddlenlp.experimental.transformers import (
                    Qwen2ForCausalLMInferenceModel as Qwen2InferenceModel,
                )

                model = Qwen2InferenceModel.from_pretrained(
                    predictor_args.model_name_or_path,
                    config=config,
                    dtype=predictor_args.dtype,
                )
            model.eval()
        elif "qwen" in config.architectures[0].lower():
            if model_args.model_type == "qwen-img2txt":
                # we use qwen for img2txt.
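
The new branch above dispatches "qwen2" architectures to the experimental inference models in dynamic mode, choosing the block-attention variant when predictor_args.block_attn is set. A hedged sketch of loading that variant directly, outside the predictor script (model id and sizes are illustrative, not taken from this commit):

from paddlenlp.transformers import AutoConfig
from paddlenlp.experimental.transformers import (
    Qwen2ForCausalLMBlockInferenceModel,
)

model_path = "Qwen/Qwen2-7B-Instruct"  # assumed model id, for illustration
config = AutoConfig.from_pretrained(model_path)
config.block_size = 64     # predictor_args.block_size in the diff
config.max_seq_len = 8192  # predictor_args.total_max_length in the diff

# bf16 path shown here; the wint8 path is driven by quant_type configuration
# instead (see the static-mode hunk below).
model = Qwen2ForCausalLMBlockInferenceModel.from_pretrained(
    model_path, config=config, dtype="bfloat16"
)
model.eval()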
@@ -1405,6 +1431,16 @@

elif predictor_args.mode == "static":
config = AutoConfig.from_pretrained(predictor_args.model_name_or_path)
config.quant_type = predictor_args.quant_type
config.cachekv_int8_type = predictor_args.cachekv_int8_type

if config.quantization_config.quant_type is not None:
predictor_args.quant_type = config.quantization_config.quant_type
config.quant_type = config.quantization_config.quant_type
if "c8" in config.quant_type:
predictor_args.cachekv_int8_type = "static"
config.cachekv_int8_type = "static"

if "llama" in config.architectures[0].lower():
if predictor_args.block_attn:
config.block_size = predictor_args.block_size
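
The first half of this hunk establishes a precedence rule: a quant_type stored in the checkpoint's quantization_config overrides the command-line value, and any quant type containing "c8" additionally forces static cache-KV int8. A hedged standalone sketch of that resolution (function name and example values are illustrative, not the predictor's actual code):

def resolve_quant(cli_quant_type, ckpt_quant_type, cli_cachekv_int8_type=None):
    quant_type = cli_quant_type
    cachekv_int8_type = cli_cachekv_int8_type
    if ckpt_quant_type is not None:
        # Checkpoint-side quantization_config wins over the CLI flag.
        quant_type = ckpt_quant_type
        if "c8" in ckpt_quant_type:
            cachekv_int8_type = "static"  # cache-KV int8 baked into the checkpoint
    return quant_type, cachekv_int8_type

print(resolve_quant("weight_only_int8", None))  # ('weight_only_int8', None)
print(resolve_quant(None, "a8w8c8"))            # ('a8w8c8', 'static')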
@@ -1471,6 +1507,21 @@
            cache_kvs_shape = GPTForCausalLMInferenceModel.get_cache_kvs_shape(
                config, predictor_args.batch_size, predictor_args.total_max_length
            )
        elif "qwen2" in config.architectures[0].lower():
            if predictor_args.block_attn:
                config.block_size = predictor_args.block_size
                config.max_seq_len = predictor_args.total_max_length
                from paddlenlp.experimental.transformers import (
                    Qwen2ForCausalLMBlockInferenceModel as Qwen2InferenceModel,
                )
            else:
                from paddlenlp.experimental.transformers import (
                    Qwen2ForCausalLMInferenceModel as Qwen2InferenceModel,
                )
            cache_kvs_shape = Qwen2InferenceModel.get_cache_kvs_shape(
                config, predictor_args.batch_size, predictor_args.total_max_length
            )

        elif "qwen" in config.architectures[0].lower():
            from paddlenlp.experimental.transformers import (
                QWenForCausalLMInferenceModel,
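
In static mode the model itself is loaded from the exported static graph, so the branch above only needs the class to report its KV-cache shapes for pre-allocation. A hedged sketch of that query (model id and sizes are illustrative):

from paddlenlp.transformers import AutoConfig
from paddlenlp.experimental.transformers import (
    Qwen2ForCausalLMBlockInferenceModel,
)

config = AutoConfig.from_pretrained("Qwen/Qwen2-7B-Instruct")  # assumed id
config.block_size = 64
config.max_seq_len = 8192

# Positional arguments follow the diff: (config, batch_size, total_max_length).
cache_kvs_shape = Qwen2ForCausalLMBlockInferenceModel.get_cache_kvs_shape(
    config, 1, 8192
)
print(cache_kvs_shape)  # per-layer cache shapes; exact layout is model-defined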
1 change: 1 addition & 0 deletions paddlenlp/experimental/transformers/__init__.py
@@ -20,3 +20,4 @@
from .llama import *
from .opt import *
from .qwen import *
from .qwen2 import *

Check warning (Codecov / codecov/patch): added line paddlenlp/experimental/transformers/__init__.py#L23 was not covered by tests.
15 changes: 15 additions & 0 deletions paddlenlp/experimental/transformers/qwen2/__init__.py
@@ -0,0 +1,15 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .modeling import *
Fourth changed file not rendered in this view (presumably paddlenlp/experimental/transformers/qwen2/modeling.py, carrying the remaining ~895 additions).
