Enable input_embeds for ChatGLM / ChatGLMForConditionalGeneration #5775

Merged (4 commits) on May 18, 2023
10 changes: 7 additions & 3 deletions paddlenlp/transformers/chatglm/modeling.py
@@ -501,13 +501,13 @@ def forward(
         elif input_ids is not None:
             batch_size, seq_length = input_ids.shape[:2]
         elif inputs_embeds is not None:
-            batch_size, seq_length, _ = inputs_embeds.shape[:2]
+            batch_size, seq_length = inputs_embeds.shape[:2]
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")
 
         if inputs_embeds is None:
             inputs_embeds = self.word_embeddings(input_ids)
-            inputs_embeds = inputs_embeds.transpose([1, 0, 2])
+        inputs_embeds = inputs_embeds.transpose([1, 0, 2])
Collaborator: LGTM

Collaborator: A small gotcha here is that input_ids is treated as the default input; the export and inference logic uses input_ids by default.

Contributor Author: Other transformer models (ERNIE, for example) also give input_ids a default of None, and PretrainedModel instances check that input_ids and inputs_embeds are not both None and not both provided, so this should be fine here.


         if cache is None:
             if self.config.pre_seq_len is not None:
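(A quick standalone sketch of what this hunk fixes; tensor sizes are arbitrary. The old line sliced .shape to two values but unpacked three names, and the transpose now also runs for caller-provided embeddings, so both paths end up in the same [seq_length, batch_size, hidden_size] layout.)

import paddle

inputs_embeds = paddle.randn([2, 16, 128])  # [batch_size, seq_length, hidden_size]

# Old code raised ValueError: `batch_size, seq_length, _ = inputs_embeds.shape[:2]`
# slices to two values but binds three names.
batch_size, seq_length = inputs_embeds.shape[:2]  # fixed unpacking

# The transpose is applied unconditionally now, so embeddings passed in by the
# caller get the same [seq_length, batch_size, hidden_size] layout as the ones
# computed from input_ids.
inputs_embeds = inputs_embeds.transpose([1, 0, 2])
print(batch_size, seq_length, inputs_embeds.shape)  # 2 16 [16, 2, 128]
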
@@ -690,6 +690,10 @@ def forward(
         use_cache: bool = None,
         return_dict: bool = None,
     ):
+        if input_ids is None:
+            assert position_ids is not None, "`position_ids` must be explicitly specified when input_ids is None."
+            assert attention_mask is not None, "`attention_mask` must be explicitly specified when input_ids is None."
+
         if attention_mask is None:
             attention_mask = self.get_masks(input_ids)

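(Viewed from the caller's side, the checks added above mean a call that supplies only inputs_embeds fails fast. A minimal standalone sketch of that contract; the helper function is illustrative and not part of the PaddleNLP API.)

import paddle

def check_embedding_inputs(input_ids=None, inputs_embeds=None,
                           attention_mask=None, position_ids=None):
    # Mirrors the new guard: once input_ids may be None, attention_mask and
    # position_ids can no longer be derived from it, so the caller must pass them.
    if input_ids is None and inputs_embeds is None:
        raise ValueError("You have to specify either input_ids or inputs_embeds")
    if input_ids is None:
        assert position_ids is not None, "`position_ids` must be explicitly specified when input_ids is None."
        assert attention_mask is not None, "`attention_mask` must be explicitly specified when input_ids is None."

embeds = paddle.randn([2, 8, 128])
try:
    check_embedding_inputs(inputs_embeds=embeds)  # masks missing -> fails fast
except AssertionError as err:
    print(err)
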
@@ -826,7 +830,7 @@ def update_model_kwargs_for_generation(

     def forward(
         self,
-        input_ids,
+        input_ids=None,
Contributor: With input_ids defaulting to None here, the `position_ids is None` and `attention_mask is None` branches in ChatGLMModel need to be adjusted, since both of them depend on input_ids:

1) For what can support None, derive the input_ids-dependent values from inputs_embeds when input_ids is None (sketched below).
2) For what cannot support None, raise an explicit exception that states the reason.

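(For completeness, a rough sketch of option 1), which is not the approach the diff above takes: derive the input_ids-dependent values from inputs_embeds. The all-ones mask and sequential positions below are deliberately simplified placeholders, not the real ChatGLM mask/position layout.)

import paddle

def defaults_from_embeds(inputs_embeds):
    # Derive what normally comes from input_ids from the embeddings instead.
    batch_size, seq_length = inputs_embeds.shape[:2]
    attention_mask = paddle.ones([batch_size, seq_length], dtype="int64")
    position_ids = paddle.arange(seq_length, dtype="int64").expand([batch_size, seq_length])
    return attention_mask, position_ids

embeds = paddle.randn([2, 8, 128])
mask, pos = defaults_from_embeds(embeds)
print(mask.shape, pos.shape)  # [2, 8] [2, 8]
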
Contributor Author: Changed following option 2): input_ids and attention_mask / position_ids are now required not to be None at the same time.

Contributor Author: line 693

         position_ids=None,
         attention_mask=None,
         cache=None,