Enable input_embeds for ChatGLM / ChatGLMForConditionalGeneration (#5775)

* Enable input_embeds for ChatGLM / ChatGLMForConditionalGeneration

Fix a typo and minor bugs so the model can take inputs_embeds instead of input_ids

* Update modeling.py

* Update modeling.py

* Update modeling.py
parap1uie-s committed May 18, 2023
1 parent 86f1dda commit 140752d
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions paddlenlp/transformers/chatglm/modeling.py
@@ -501,13 +501,13 @@ def forward(
         elif input_ids is not None:
             batch_size, seq_length = input_ids.shape[:2]
         elif inputs_embeds is not None:
-            batch_size, seq_length, _ = inputs_embeds.shape[:2]
+            batch_size, seq_length = inputs_embeds.shape[:2]
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")

         if inputs_embeds is None:
             inputs_embeds = self.word_embeddings(input_ids)
-            inputs_embeds = inputs_embeds.transpose([1, 0, 2])
+        inputs_embeds = inputs_embeds.transpose([1, 0, 2])

         if cache is None:
             if self.config.pre_seq_len is not None:
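
This hunk fixes two silent breakages on the embeddings path: `shape[:2]` yields only two values, so the old three-way unpack raises, and the transpose to the stack's time-major layout previously ran only when embeddings were looked up from `input_ids`. A standalone sketch with made-up tensor sizes (not the model's real hidden size) showing both:

    import paddle

    # Made-up shapes for illustration: [batch_size, seq_length, hidden_size].
    inputs_embeds = paddle.randn([2, 8, 16])

    # Old code: `shape[:2]` has two elements, so unpacking three raises
    # "ValueError: not enough values to unpack (expected 3, got 2)".
    try:
        batch_size, seq_length, _ = inputs_embeds.shape[:2]
    except ValueError as err:
        print(err)

    # Fixed code: unpack exactly the two sliced dimensions.
    batch_size, seq_length = inputs_embeds.shape[:2]

    # The stack works in [seq_length, batch_size, hidden_size] layout, so
    # the transpose must also apply to caller-supplied embeddings, not only
    # to ones looked up inside the `if inputs_embeds is None:` branch.
    inputs_embeds = inputs_embeds.transpose([1, 0, 2])
    print(inputs_embeds.shape)  # [8, 2, 16]

Dedenting the transpose keeps the time-major layout invariant regardless of whether the embeddings were looked up from ids or passed in directly.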
@@ -690,6 +690,10 @@ def forward(
         use_cache: bool = None,
         return_dict: bool = None,
     ):
+        if input_ids is None:
+            assert position_ids is not None, "`position_ids` must be explicitly specified when input_ids is None."
+            assert attention_mask is not None, "`attention_mask` must be explicitly specified when input_ids is None."
+
         if attention_mask is None:
             attention_mask = self.get_masks(input_ids)
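
Because the attention mask (via `get_masks`) and the position ids are derived from `input_ids`, an embeddings-only call cannot reconstruct them, and the new asserts surface that as a clear error instead of a downstream failure. A standalone sketch of the contract; the function below is illustrative, not the model's actual forward, and the ChatGLM tensor shapes (2D position ids of shape [batch, 2, seq]) are assumptions:

    import paddle

    def forward_contract(input_ids=None, position_ids=None, attention_mask=None, inputs_embeds=None):
        """Illustrative stand-in for the guarded forward; mirrors only the new checks."""
        if input_ids is None:
            assert position_ids is not None, "`position_ids` must be explicitly specified when input_ids is None."
            assert attention_mask is not None, "`attention_mask` must be explicitly specified when input_ids is None."

    embeds = paddle.randn([1, 4, 16])  # [batch, seq, hidden], made-up sizes

    # Fails fast: no ids, and nothing to derive masks or positions from.
    try:
        forward_contract(inputs_embeds=embeds)
    except AssertionError as err:
        print(err)

    # Passes: the caller supplies what would otherwise be derived from ids.
    position_ids = paddle.zeros([1, 2, 4], dtype="int64")
    attention_mask = paddle.zeros([1, 1, 4, 4], dtype="int64")
    forward_contract(inputs_embeds=embeds, position_ids=position_ids, attention_mask=attention_mask)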

@@ -826,7 +830,7 @@ def update_model_kwargs_for_generation(

     def forward(
         self,
-        input_ids,
+        input_ids=None,
         position_ids=None,
         attention_mask=None,
         cache=None,
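
With `input_ids` now defaulting to None, an embeddings-first call becomes possible end to end. A hedged sketch of that path; the checkpoint name, the tokenizer emitting ChatGLM-compatible `attention_mask`/`position_ids`, and `get_input_embeddings()` being wired up for this model are all assumptions to adapt to your setup:

    import paddle
    from paddlenlp.transformers import ChatGLMForConditionalGeneration, ChatGLMTokenizer

    # Assumed checkpoint name; substitute whatever ChatGLM weights you use.
    name = "THUDM/chatglm-6b"
    tokenizer = ChatGLMTokenizer.from_pretrained(name)
    model = ChatGLMForConditionalGeneration.from_pretrained(name)
    model.eval()

    inputs = tokenizer("Hello", return_tensors="pd")

    # Look the ids up once, then drop them and feed embeddings directly.
    inputs_embeds = model.get_input_embeddings()(inputs["input_ids"])

    # Per the new guard, masks and positions must come from the caller; here
    # we assume the tokenizer output already has the shapes ChatGLM expects.
    with paddle.no_grad():
        outputs = model(
            input_ids=None,
            inputs_embeds=inputs_embeds,
            attention_mask=inputs["attention_mask"],
            position_ids=inputs["position_ids"],
        )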
