-
Notifications
You must be signed in to change notification settings - Fork 3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Paddle Inference]support miniGPT4's second part dy2st #6905
Changes from 7 commits
f67018b
a5cf43a
bee52b6
406ad66
5cf2aba
8c9fff1
329c837
78d6561
ba225e3
ca434b3
f03d084
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -85,6 +85,17 @@ def to_static(self, output_path: str, config: dict): | |
model, output_path, skip_prune_program=True | ||
) # Note(Zhengzekang): If we prune program it may cause some inference error. | ||
|
||
@staticmethod | ||
def prepare_input_ids_for_generation(bos_token_id, encoder_output=None): | ||
batch_size = 1 | ||
seq_len = 1 | ||
if bos_token_id is None: | ||
raise ValueError("`bos_token_id` should be defined when no " "`input_ids` are provided.") | ||
if encoder_output is not None: | ||
batch_size = encoder_output.shape[0] | ||
seq_len = encoder_output.shape[1] | ||
return paddle.ones([batch_size, seq_len], dtype="int64") * bos_token_id | ||
|
||
@paddle.no_grad() | ||
def generate( | ||
self, | ||
|
@@ -109,6 +120,7 @@ def generate( | |
pre_ids=None, | ||
stop_nums=None, | ||
cache_kvs=[], | ||
inputs_embeds=None, | ||
**model_kwargs, | ||
): | ||
|
||
|
@@ -136,6 +148,7 @@ def generate( | |
top_p=top_p, | ||
cache_kvs=cache_kvs, | ||
temperature=temperature, | ||
inputs_embeds=inputs_embeds, | ||
**model_kwargs, | ||
) | ||
return ret | ||
|
@@ -213,17 +226,32 @@ def update_model_kwargs_for_generation(self, cache, just_decoder, next_tokens, e | |
|
||
def sample( | ||
self, | ||
input_ids, | ||
eos_token_id, | ||
input_ids=None, | ||
eos_token_id=None, | ||
cache_kvs=[], | ||
top_p=None, | ||
temperature=None, | ||
inputs_embeds=None, | ||
**model_kwargs, | ||
): | ||
step_idx_ori = paddle.full(shape=[1], dtype="int64", fill_value=1) | ||
batch_idx = paddle.full(shape=[1], dtype="int32", fill_value=-1) | ||
|
||
if input_ids is not None and inputs_embeds is not None: | ||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") | ||
elif input_ids is None and inputs_embeds is None: | ||
raise ValueError("You have to specify either input_ids or inputs_embeds") | ||
|
||
# genereate a fake input_ids according to inputs_embeds. | ||
if input_ids is None and inputs_embeds is not None: | ||
input_ids = self.prepare_input_ids_for_generation(self.config.bos_token_id, inputs_embeds) | ||
if inputs_embeds is not None: | ||
batch, seq_len, hidden_dim = inputs_embeds.shape | ||
inputs_embeds = inputs_embeds.reshape([batch * seq_len, hidden_dim]) | ||
model_kwargs["inputs_embeds"] = inputs_embeds | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这块的逻辑是需要迁移到模型的 forward 里面去的,而不是在 generation_utils 里面,具体可参考:https://github.com/PaddlePaddle/PaddleNLP/blob/develop/paddlenlp/transformers/llama/modeling.py#L1189 在 experimental/transformers/llama/modeling.py 下面目前是没有对应的 checking,所以建议你将这部分的代码挪过去一下,非常感谢。 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
已改,辛苦review |
||
|
||
def _forward_(**args): | ||
# cache_kvs is never empty because it is passed as a parameter in def sample. | ||
model_inputs = self.prepare_inputs_for_generation(input_ids, cache_kvs, **args) | ||
return self(**model_inputs) | ||
|
||
|
@@ -294,6 +322,7 @@ def _post_process_(outputs, top_p, temperature, step_idx_ori, model_kwargs): | |
) | ||
step_idx_ori += 1 | ||
encoder_output = outputs | ||
# gives it a value, means we will entered into decoder phase. | ||
model_kwargs["cache"] = 0 | ||
|
||
# decoder | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.