
Conversation pipeline fixes #26795

Merged: 7 commits merged into main from conversation_pipeline_fixes on Oct 16, 2023

Conversation

Rocketknight1 (Member)

This PR makes two fixes to ConversationalPipeline that make it much easier to use:

  • Inputs can now simply be conversations in the standard list-of-dicts format (see the sketch below). I think the Conversation class is quite hard for users to discover, and this is a lot more intuitive.
  • We no longer read max_length, because very few models set this parameter, so it almost always falls back to the PretrainedConfig default of 20, which is very low. Before this change, most calls to ConversationalPipeline produced no output or truncated the input unnecessarily because this limit was hit. The pipeline now uses max_new_tokens instead, which is the more modern parameter.
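
A rough usage sketch of the new input format; the checkpoint, prompt, and printed result are illustrative assumptions, not part of this PR:

from transformers import pipeline

# Hypothetical example: any conversational checkpoint would do here.
chatbot = pipeline("conversational", model="facebook/blenderbot-400M-distill")

# A plain list of {"role": ..., "content": ...} dicts is now accepted directly;
# no Conversation object needs to be constructed first.
messages = [{"role": "user", "content": "What's the best way to learn to program?"}]

# max_new_tokens can still be set per call; otherwise the pipeline default applies.
result = chatbot(messages, max_new_tokens=64)
print(result)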

cc @ArthurZucker for pipeline review and @gante if he has any comments about setting the generation parameters properly!

HuggingFaceDocBuilderDev commented Oct 13, 2023:

The documentation is not available anymore as the PR was closed or merged.

Rocketknight1 (Member, Author)

Also cc @lewtun - you should now be able to use this pipeline directly in the docstrings, instead of reproducing the flow manually with text-generation and apply_chat_template!
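
For reference, a minimal sketch of the manual pattern being replaced; the checkpoint and prompt are assumptions, not taken from this PR:

from transformers import AutoTokenizer, pipeline

# Hypothetical checkpoint: any model that ships a chat template would do.
model_id = "HuggingFaceH4/zephyr-7b-beta"
tokenizer = AutoTokenizer.from_pretrained(model_id)
generator = pipeline("text-generation", model=model_id, tokenizer=tokenizer)

messages = [{"role": "user", "content": "Tell me a fun fact."}]
# The manual steps: render the chat template yourself, then call text-generation.
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
output = generator(prompt, max_new_tokens=64)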

ArthurZucker (Collaborator) left a comment:
Looks good to me. We could set a default max_length on the class rather than hardcoding it, wdyt?
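
(A hypothetical sketch of that suggestion, i.e. a class-level default instead of a hardcoded literal; the helper name _apply_default_length is invented for illustration, and this is not the code that was merged:)

class ConversationalPipeline:
    # Class attribute: subclasses or callers can override the default
    # instead of editing a literal buried in the method body.
    default_max_new_tokens = 256

    def _apply_default_length(self, generate_kwargs):
        # Only apply the default when the caller set neither length argument.
        if "max_length" not in generate_kwargs and "max_new_tokens" not in generate_kwargs:
            generate_kwargs["max_new_tokens"] = self.default_max_new_tokens
        return generate_kwargs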

n = model_inputs["input_ids"].shape[1]
if max_length - minimum_tokens < n:
ArthurZucker (Collaborator): nice cleanup!
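
For context on why this path was a problem: with max_length inherited from the PretrainedConfig default of 20, the truncation check above fired for almost any realistic prompt. A hedged reconstruction with illustrative numbers (the minimum_tokens value is an assumption):

# Illustrative numbers only: max_length=20 is the old config default;
# minimum_tokens=10 is an assumed pipeline default.
max_length, minimum_tokens = 20, 10
n = 15  # token length of even a short prompt
if max_length - minimum_tokens < n:
    print("input would be truncated")  # fires for nearly every real conversation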

@@ -268,6 +268,10 @@ def __call__(self, conversations: Union[Conversation, List[Conversation]], num_w
# Otherwise the threads will require a Conversation copy.
# This will definitely hinder performance on GPU, but has to be opted
# in because of this BC change.
if isinstance(conversations, list) and isinstance(conversations[0], dict):
ArthurZucker (Collaborator):
can we update the doc to mention the type of input we expect?

Rocketknight1 (Member, Author):
Done!
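
A sketch of what the dispatch in the hunk above enables; whether Conversation accepts the message list directly in its constructor is an assumption about the updated class, not stated in this thread:

from transformers import Conversation

# A bare list of {"role", "content"} dicts is detected and wrapped
# into a Conversation before the pipeline batches it.
conversations = [{"role": "user", "content": "Hi there!"}]
if isinstance(conversations, list) and isinstance(conversations[0], dict):
    conversations = Conversation(conversations)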

gante (Member) left a comment:
LGTM -- assuming the input documentation is updated :)

 conversation = model_inputs.pop("conversation")
-generate_kwargs["max_length"] = max_length
+if "max_length" not in generate_kwargs and "max_new_tokens" not in generate_kwargs:
+    generate_kwargs["max_new_tokens"] = 256
gante (Member):
👍

It should be safe to add this default, as generate should throw a warning if the generation goes beyond what's supported by the model.

Rocketknight1 force-pushed the conversation_pipeline_fixes branch from d9d9dea to 89e48f6 on October 16, 2023 at 15:17
Rocketknight1 merged commit 14b04b4 into main on Oct 16, 2023 (3 checks passed)
Rocketknight1 deleted the conversation_pipeline_fixes branch on October 16, 2023 at 16:27
EduardoPach pushed a commit to EduardoPach/transformers that referenced this pull request Nov 19, 2023
* Adjust length limits and allow naked conversation list inputs

* Adjust length limits and allow naked conversation list inputs

* Maybe use a slightly more reasonable limit than 1024

* Skip tests for old models that never supported this anyway

* Cleanup input docstrings

* More docstring cleanup + skip failing TF test

* Make fixup