Conversation pipeline fixes #26795
Conversation
The documentation is not available anymore as the PR was closed or merged.
Also cc @lewtun - now you should actually be able to just use this pipeline in the docstrings instead of needing to do it manually.
Looks good to me. We could set a default max_length on the class rather than hardcoding it, wdyt?
```python
n = model_inputs["input_ids"].shape[1]
if max_length - minimum_tokens < n:
```
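For context, a hedged sketch of what this length check guards: if the prompt leaves less than `minimum_tokens` of room under `max_length`, the oldest prompt tokens get dropped. The helper name and trimming details below are assumptions for illustration, not the verbatim pipeline code.

```python
import torch

def truncate_for_generation(model_inputs, max_length, minimum_tokens):
    """Hypothetical helper: ensure at least `minimum_tokens` of generation
    room remain under `max_length`, trimming the oldest prompt tokens."""
    n = model_inputs["input_ids"].shape[1]
    if max_length - minimum_tokens < n:
        keep = max_length - minimum_tokens  # prompt tokens we can keep (assumed)
        model_inputs["input_ids"] = model_inputs["input_ids"][:, -keep:]
        if "attention_mask" in model_inputs:
            model_inputs["attention_mask"] = model_inputs["attention_mask"][:, -keep:]
    return model_inputs

# Example: a 30-token prompt with max_length=32 and minimum_tokens=10
# keeps only the last 22 tokens.
inputs = {"input_ids": torch.ones(1, 30, dtype=torch.long)}
truncate_for_generation(inputs, max_length=32, minimum_tokens=10)
```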
nice cleanup!
```diff
@@ -268,6 +268,10 @@ def __call__(self, conversations: Union[Conversation, List[Conversation]], num_w
+        # Otherwise the threads will require a Conversation copy.
+        # This will definitely hinder performance on GPU, but has to be opted
+        # in because of this BC change.
+        if isinstance(conversations, list) and isinstance(conversations[0], dict):
```
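For readers following the thread, a hedged sketch of what this branch enables — the wrapping into `Conversation` is an assumption drawn from the diff context, not the verbatim pipeline code:

```python
from transformers import Conversation

# Plain chat-format input, as a user might pass it to the pipeline directly.
conversations = [{"role": "user", "content": "Hi there!"}]

# Hedged sketch: wrap a naked list of message dicts into a Conversation,
# assuming the Conversation constructor accepts chat-format messages.
if isinstance(conversations, list) and isinstance(conversations[0], dict):
    conversations = Conversation(conversations)
```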
can we update the doc to mention the type of input we expect?
Done!
LGTM -- assuming the input documentation is updated :)
```diff
         conversation = model_inputs.pop("conversation")
-        generate_kwargs["max_length"] = max_length
+        if "max_length" not in generate_kwargs and "max_new_tokens" not in generate_kwargs:
+            generate_kwargs["max_new_tokens"] = 256
```
👍
It should be safe to add this default, as `generate` should throw a warning if the generation goes beyond what's supported by the model.
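As a usage note, a small sketch of how the fallback interacts with caller-supplied settings (model choice illustrative): generation kwargs passed at call time are forwarded to `generate`, so the 256-token default only applies when the caller sets neither `max_length` nor `max_new_tokens`.

```python
from transformers import Conversation, pipeline

chatbot = pipeline("conversational", model="facebook/blenderbot-400M-distill")
conversation = Conversation("What's the best way to learn Python?")

# An explicit kwarg wins over the pipeline's max_new_tokens=256 fallback.
result = chatbot(conversation, max_new_tokens=64)
print(result)
```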
Force-pushed from d9d9dea to 89e48f6:

* Adjust length limits and allow naked conversation list inputs
* Adjust length limits and allow naked conversation list inputs
* Maybe use a slightly more reasonable limit than 1024
* Skip tests for old models that never supported this anyway
* Cleanup input docstrings
* More docstring cleanup + skip failing TF test
* Make fixup
This PR makes a couple of fixes to `ConversationalPipeline` to make it a lot easier to use:

* The pipeline now accepts plain lists of message dicts as input, not just `Conversation` objects. The `Conversation` class is quite hard for users to discover, and this is a lot more intuitive.
* The pipeline no longer takes `max_length` from the model, because very few models set this parameter, and so it's almost always the default `PretrainedConfig` value of 20, which is very low. Before this change, most calls to `ConversationalPipeline` produced no output or unnecessarily truncated the input because this limit was hit. We change the pipeline to use `max_new_tokens` instead, which is more modern.

cc @ArthurZucker for pipeline review and @gante if he has any comments about setting the generation parameters properly!
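To illustrate the first fix, a hedged usage sketch (model name illustrative; assumes chat-format lists of role/content dicts are now accepted directly):

```python
from transformers import pipeline

chatbot = pipeline("conversational", model="facebook/blenderbot-400M-distill")

# No Conversation object needed: pass a chat-format list of dicts directly.
messages = [{"role": "user", "content": "Any recommendations for a good sci-fi novel?"}]
result = chatbot(messages)
print(result)
```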