[Paddle Inference] support miniGPT4's second part dy2st #6905

Merged
merged 11 commits on Sep 7, 2023
29 changes: 20 additions & 9 deletions llm/predictor.py
@@ -74,11 +74,16 @@ class PredictorArgument:
inference_model: bool = field(default=False, metadata={"help": "whether use InferenceModel to do generation"})
batch_size: int = field(default=1, metadata={"help": "The batch size of data."})
max_batch_size: int = field(default=None, metadata={"help": "The max batch size of data during serving."})
benchmark: bool = field(
    default=False,
    metadata={
        "help": "If benchmark set as `True`, we will force model decode to max_length, which is helpful to compute throughput. "
    },
)
llm_for_img2txt: bool = field(
    default=False, metadata={"help": "whether this llm model is used for img2txt, such as miniGPT4, blip2."}
)


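The new llm_for_img2txt flag is consumed further down in create_predictor (next hunk). A minimal usage sketch, not taken from this PR, assuming PdArgumentParser from paddlenlp.trainer and that the PredictorArgument dataclass above is importable:

from paddlenlp.trainer import PdArgumentParser

# parse CLI arguments (including the new llm_for_img2txt switch) into PredictorArgument
parser = PdArgumentParser(PredictorArgument)
(predictor_args,) = parser.parse_args_into_dataclasses()

if predictor_args.llm_for_img2txt:
    # img2txt path (miniGPT4, blip2): create_predictor selects
    # LlamaForminiGPT4InferenceModel for llama checkpoints
    ...
else:
    # plain text generation with LlamaForCausalLMInferenceModel
    ...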
@@ -571,13 +576,19 @@ def create_predictor(
# TODO(wj-Mcat): complete AutoInferenceModel & AutoPredictor
config = AutoConfig.from_pretrained(predictor_args.model_name_or_path)
if "llama" in config.architectures[0].lower():
from paddlenlp.experimental.transformers import (
LlamaForCausalLMInferenceModel,
)
if predictor_args.llm_for_img2txt:
# we use llama for img2txt.
from paddlenlp.experimental.transformers import (
LlamaForminiGPT4InferenceModel as LlamaInferenceModel,
Contributor (review comment), suggested change:
-    LlamaForminiGPT4InferenceModel as LlamaInferenceModel,
+    LlamaForMiniGPT4InferenceModel as LlamaInferenceModel,
)
else:
from paddlenlp.experimental.transformers import (
LlamaForCausalLMInferenceModel as LlamaInferenceModel,
)

config.tensor_parallel_degree = tensor_parallel_degree
config.tensor_parallel_rank = tensor_parallel_rank
model = LlamaForCausalLMInferenceModel.from_pretrained(
model = LlamaInferenceModel.from_pretrained(
predictor_args.model_name_or_path, config=config, dtype=predictor_args.dtype
)
model.eval()
33 changes: 31 additions & 2 deletions paddlenlp/experimental/transformers/generation_utils.py
@@ -85,6 +85,17 @@ def to_static(self, output_path: str, config: dict):
model, output_path, skip_prune_program=True
) # Note(Zhengzekang): If we prune program it may cause some inference error.

@staticmethod
def prepare_input_ids_for_generation(bos_token_id, encoder_output=None):
batch_size = 1
seq_len = 1
if bos_token_id is None:
raise ValueError("`bos_token_id` should be defined when no " "`input_ids` are provided.")
if encoder_output is not None:
batch_size = encoder_output.shape[0]
seq_len = encoder_output.shape[1]
return paddle.ones([batch_size, seq_len], dtype="int64") * bos_token_id

@paddle.no_grad()
def generate(
self,
@@ -109,6 +120,7 @@ def generate(
pre_ids=None,
stop_nums=None,
cache_kvs=[],
inputs_embeds=None,
**model_kwargs,
):

@@ -136,6 +148,7 @@ def generate(
top_p=top_p,
cache_kvs=cache_kvs,
temperature=temperature,
inputs_embeds=inputs_embeds,
**model_kwargs,
)
return ret
@@ -213,17 +226,32 @@ def update_model_kwargs_for_generation(self, cache, just_decoder, next_tokens, e

def sample(
self,
input_ids,
eos_token_id,
input_ids=None,
eos_token_id=None,
cache_kvs=[],
top_p=None,
temperature=None,
inputs_embeds=None,
**model_kwargs,
):
step_idx_ori = paddle.full(shape=[1], dtype="int64", fill_value=1)
batch_idx = paddle.full(shape=[1], dtype="int32", fill_value=-1)

if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is None and inputs_embeds is None:
raise ValueError("You have to specify either input_ids or inputs_embeds")

# generate a fake input_ids tensor according to inputs_embeds.
if input_ids is None and inputs_embeds is not None:
input_ids = self.prepare_input_ids_for_generation(self.config.bos_token_id, inputs_embeds)
if inputs_embeds is not None:
batch, seq_len, hidden_dim = inputs_embeds.shape
inputs_embeds = inputs_embeds.reshape([batch * seq_len, hidden_dim])
model_kwargs["inputs_embeds"] = inputs_embeds
Contributor (review comment): This logic should be moved into the model's forward method rather than kept in generation_utils; for reference see https://github.com/PaddlePaddle/PaddleNLP/blob/develop/paddlenlp/transformers/llama/modeling.py#L1189. There is currently no corresponding handling under experimental/transformers/llama/modeling.py, so please move this part of the code there. Many thanks.

Contributor Author (reply): Done, please review.

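For reference, a minimal sketch of the pattern the reviewer points to, modeled on the non-experimental paddlenlp/transformers/llama/modeling.py; the code actually moved by this PR lives in experimental/transformers/llama/modeling.py and may differ in detail:

def forward(self, input_ids=None, inputs_embeds=None, **kwargs):
    # reject ambiguous input and require at least one of the two
    if input_ids is not None and inputs_embeds is not None:
        raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
    if input_ids is None and inputs_embeds is None:
        raise ValueError("You have to specify either input_ids or inputs_embeds")
    if inputs_embeds is None:
        # derive the embeddings from input_ids when no precomputed embeddings are given
        inputs_embeds = self.embed_tokens(input_ids)
    # ... the rest of the transformer forward pass consumes inputs_embeds ...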

def _forward_(**args):
# cache_kvs is never empty because it is passed in as a parameter of sample().
model_inputs = self.prepare_inputs_for_generation(input_ids, cache_kvs, **args)
return self(**model_inputs)

@@ -294,6 +322,7 @@ def _post_process_(outputs, top_p, temperature, step_idx_ori, model_kwargs):
)
step_idx_ori += 1
encoder_output = outputs
# assign it a value to indicate that we have entered the decoder phase.
model_kwargs["cache"] = 0

# decoder
120 changes: 119 additions & 1 deletion paddlenlp/experimental/transformers/llama/modeling.py
@@ -33,7 +33,7 @@
)
from paddlenlp.transformers.model_utils import register_base_model

__all__ = ["LlamaInferenceModel", "LlamaForCausalLMInferenceModel"]
__all__ = ["LlamaInferenceModel", "LlamaForCausalLMInferenceModel", "LlamaForminiGPT4InferenceModel"]


class FusedLlamaRMSNorm(nn.Layer):
@@ -165,6 +165,7 @@ def forward(
return_dict=False,
**kwargs,
):
# kwargs["cache"] is used used to distinguish between encoder and decoder phase.
past_key_values = kwargs.get("cache", None)
is_decoder = past_key_values is not None

@@ -345,14 +346,19 @@ def prepare_inputs_for_generation(
position_ids = kwargs.get("position_ids", None)
attention_mask = kwargs.get("attention_mask", None)
cache = kwargs.get("cache", None)
inputs_embeds = kwargs.get("inputs_embeds", None)
if cache is not None:
input_ids = tgt_ids
position_ids = tgt_pos
attention_mask = (tgt_generation_mask - 1) * 1e4
# set inputs_embeds to None in the decoder phase;
# in the forward function it will then be derived from input_ids.
inputs_embeds = None
else:
attention_mask = (attention_mask - 1) * 1e4
model_inputs = {
"input_ids": input_ids,
"inputs_embeds": inputs_embeds,
"position_ids": position_ids,
"attention_mask": attention_mask,
"cache_kvs": cache_kvs,
@@ -432,3 +438,115 @@ def set_state_dict(self, state_dict):
if "lm_head.weight" in state_dict:
self.lm_head.weight.set_value(state_dict["lm_head.weight"])
self.llama.set_state_dict({k: state_dict[k] for k in state_dict.keys()})


class LlamaForminiGPT4InferenceModel(LlamaForCausalLMInferenceModel):
"""
This class is almost identical to LlamaForCausalLMInferenceModel.
It is used only for miniGPT4's second part.
"""

# This function implements miniGPT4's second part and is used only for miniGPT4.
@paddle.no_grad()
def generate_text_with_image_features(
self,
image_features: paddle.Tensor,
first_input_ids: paddle.Tensor,
second_input_ids: paddle.Tensor,
attention_mask: paddle.Tensor,
position_ids=None,
penalty_score=None,
frequency_score=None,
presence_score=None,
min_length=None,
max_length=None,
temperature=None,
top_p=None,
eos_token_id=None,
seq_len_encoder=None,
seq_len_decoder=None,
step_idx=None,
stop_flags=None,
tgt_ids=None,
tgt_pos=None,
tgt_generation_mask=None,
pre_ids=None,
stop_nums=None,
cache_kvs=[],
inputs_embeds=None,
**generate_kwargs
) -> paddle.Tensor:

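# Build the prompt at the embedding level: embed the text before and after the
# image placeholder, then splice the projected image features in between.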
first_embeds = self.llama.embed_tokens(first_input_ids)
second_embeds = self.llama.embed_tokens(second_input_ids)
image_features = paddle.cast(image_features, dtype=first_embeds.dtype)
inputs_embeds = paddle.concat([first_embeds, image_features, second_embeds], axis=1)

outputs = self.generate(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
position_ids=position_ids,
penalty_score=penalty_score,
frequency_score=frequency_score,
presence_score=presence_score,
min_length=min_length,
max_length=max_length,
temperature=temperature,
top_p=top_p,
eos_token_id=eos_token_id,
seq_len_encoder=seq_len_encoder,
seq_len_decoder=seq_len_decoder,
step_idx=step_idx,
stop_flags=stop_flags,
tgt_ids=tgt_ids,
tgt_pos=tgt_pos,
tgt_generation_mask=tgt_generation_mask,
pre_ids=pre_ids,
stop_nums=stop_nums,
cache_kvs=cache_kvs,
)
return outputs

# overrides the to_static function defined in generation_utils.py
def to_static(self, output_path: str, config: dict):
dtype = config.get("dtype", paddle.get_default_dtype())
cache_kvs_shapes = self.get_cache_kvs_shape(self.config, max_length=config.get("max_length", None))
input_spec = [
paddle.static.InputSpec(
shape=[None, None, None], dtype="float32", name="image_features"
), # image_features
paddle.static.InputSpec(shape=[None, None], dtype="int64", name="first_input_ids"), # first_input_ids
paddle.static.InputSpec(shape=[None, None], dtype="int64", name="second_input_ids"), # second_input_ids
paddle.static.InputSpec(shape=[None, None], dtype=dtype, name="attention_mask"), # attention_mask
paddle.static.InputSpec(shape=[None, None], dtype="int64", name="position_ids"), # position_ids
paddle.static.InputSpec(shape=[None, 1], dtype="float32", name="penalty_score"), # penalty_score
paddle.static.InputSpec(shape=[None, 1], dtype="float32", name="frequency_score"), # frequency_score
paddle.static.InputSpec(shape=[None, 1], dtype="float32", name="presence_score"), # presence_score
paddle.static.InputSpec(shape=[None, 1], dtype="int64", name="min_length"), # min_decode_length
paddle.static.InputSpec(shape=[None, 1], dtype="int64", name="max_length"), # max_decode_length
paddle.static.InputSpec(shape=[None, 1], dtype="float32", name="temperature"), # temperature
paddle.static.InputSpec(shape=[None, 1], dtype="float32", name="top_p"), # top_p
paddle.static.InputSpec(shape=[None], dtype="int64", name="eos_token_id"), # eos_token_id
paddle.static.InputSpec(shape=[None, 1], dtype="int32", name="seq_len_encoder"), # seq_len_encoder
paddle.static.InputSpec(shape=[None, 1], dtype="int32", name="seq_len_decoder"), # seq_len_decoder
paddle.static.InputSpec(shape=[None, 1], dtype="int64", name="step_idx"), # step_idx
paddle.static.InputSpec(shape=[None, 1], dtype="bool", name="stop_flags"), # stop_flags
paddle.static.InputSpec(shape=[None, 1], dtype="int64", name="tgt_ids"), # tgt_ids
paddle.static.InputSpec(shape=[None, 1], dtype="int64", name="tgt_pos"), # tgt_pos
paddle.static.InputSpec(
shape=[None, 1, 1, None], dtype=dtype, name="tgt_generation_mask"
), # tgt_generation_mask
paddle.static.InputSpec(shape=[None, None], dtype="int64", name="pre_ids"), # pre_ids
paddle.static.InputSpec(shape=[1], dtype="int64", name="stop_nums"), # stop_nums
[
paddle.static.InputSpec(
shape=shape,
dtype=dtype,
name="cache_kvs_{}".format(i),
)
for i, shape in enumerate(cache_kvs_shapes)
], # cache_kvs
]

model = paddle.jit.to_static(self.generate_text_with_image_features, input_spec=input_spec)
paddle.jit.save(model, output_path)
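For completeness, a sketch of how the exported static graph could then be loaded with the Paddle Inference API; the file names below are placeholders following paddle.jit.save's .pdmodel/.pdiparams convention and are not specified by this PR:

import paddle.inference as paddle_infer

# program and weights written by paddle.jit.save(model, output_path) above
config = paddle_infer.Config("minigpt4_second_part.pdmodel", "minigpt4_second_part.pdiparams")
config.enable_use_gpu(1000, 0)  # 1000 MB initial GPU memory pool on device 0
predictor = paddle_infer.create_predictor(config)

# the input names match the InputSpec names declared in to_static above
print(predictor.get_input_names())  # ["image_features", "first_input_ids", "second_input_ids", ...]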