
Support models that don't output past_key_values #91

Merged
1 commit merged into guidance-ai:main on May 22, 2023

Conversation

younesbelkada (Contributor) commented May 22, 2023

What does this PR do?

Originally pointed out in huggingface/transformers#22797 (comment) by @fullstackwebdev

By design, some models in transformers do not output past_key_values.
This is the case for RWKV, a new architecture recently integrated into Hugging Face's transformers: huggingface/transformers#22797
RWKV is an 'attention-free' LLM: it does not rely on the past-key-value mechanism to cache the model's computation, since tokens are always processed one by one.
This PR adds support for these models by returning None when past_key_values is not present in the model's output; the generate method in transformers should automatically take care of the rest under the hood.
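For illustration, here is a minimal sketch of the kind of fallback this PR adds (the helper name is hypothetical, not the actual guidance code): if the model output carries no past_key_values, fall back to None and let transformers' generate handle the rest.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Hypothetical helper mirroring the safety check added in this PR:
# read past_key_values from a model output if present, otherwise return None.
def get_past_key_values(model_output):
    return getattr(model_output, "past_key_values", None)

tokenizer = AutoTokenizer.from_pretrained("RWKV/rwkv-4-169m-pile")
model = AutoModelForCausalLM.from_pretrained("RWKV/rwkv-4-169m-pile")

input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids
with torch.no_grad():
    out = model(input_ids=input_ids, use_cache=True)

# RWKV returns a recurrent `state` instead of past_key_values, so this prints None.
print(get_past_key_values(out))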

To reproduce

Simply run the snippet below:

import guidance

# RWKV is attention-free and does not return past_key_values
guidance.llm = guidance.llms.Transformers("RWKV/rwkv-4-169m-pile", device=0)

program = guidance('''Hello my name is {{gen max_tokens=10}}''')
print(program())

Commit: add safety checker to retrieve `None` in case the model has no `past_key_values`
slundberg (Collaborator)

Thanks! I have not read up on RWKV yet, so I'll take a look at that and this PR and then merge, assuming it all looks good.

slundberg (Collaborator) commented May 22, 2023

@younesbelkada So, after digging into things a bit, it seems this PR would make RWKV work but would disable all the Guidance acceleration we normally get. Is there an easy way to reuse the state vector for RWKV? Basically, if we have a program like the one below, we want to save the state at the end of the first generation, do a batch computation that extends the state with the fixed text between the generations, and then run the second generation.

import guidance

guidance.llm = guidance.llms.Transformers("RWKV/rwkv-4-169m-pile", device=0)

program = guidance('''Hello my name is {{gen 'name' max_tokens=10 stop=" "}}, and I have a story titled "{{gen 'title'}}"''')
print(program())

We can merge as is, but it is not a good long-term solution for performance-sensitive uses (meaning any time you are waiting a while for results).
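
For context, a rough sketch of the prefix reuse this acceleration relies on with a standard KV-cache model (gpt2 is used here only as a stand-in; this is not guidance's actual implementation). The open question above is how to do the equivalent with RWKV's recurrent state.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Rough sketch of KV-cache reuse: keep the cache from the first pass and feed
# only the new tokens plus that cache on the next pass, so the prefix is not
# recomputed.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

prefix = tokenizer("Hello my name is", return_tensors="pt")
with torch.no_grad():
    out = model(**prefix, use_cache=True)  # cache now covers the prefix

extension = tokenizer(" Alice, and I have a story titled", return_tensors="pt")
with torch.no_grad():
    # Only the new tokens are processed; the cached prefix is reused.
    out = model(input_ids=extension.input_ids,
                past_key_values=out.past_key_values,
                use_cache=True)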

younesbelkada (Contributor, Author)

Thank you very much for reviewing!
I believe one can retrieve the state vector from outputs.state: https://github.com/huggingface/transformers/blob/e69feab8a13cf6cbf99fd6f3ff6cbc105d2183d9/src/transformers/models/rwkv/modeling_rwkv.py#LL533C1-L533C1
However, this might require some work on guidance, so it is probably better to do a proper accelerated RWKV integration in a follow-up PR.
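
For illustration, a minimal sketch of carrying that state across calls with the transformers RWKV implementation (the prompts are arbitrary and this is the assumed usage pattern, not guidance code):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# RWKV returns a recurrent `state` instead of past_key_values and accepts it
# back via the `state` argument, so a prefix does not have to be reprocessed.
tokenizer = AutoTokenizer.from_pretrained("RWKV/rwkv-4-169m-pile")
model = AutoModelForCausalLM.from_pretrained("RWKV/rwkv-4-169m-pile")

first = tokenizer("Hello my name is", return_tensors="pt").input_ids
with torch.no_grad():
    out = model(input_ids=first, use_cache=True)  # out.state holds the recurrent state

# Continue from the saved state with only the new tokens.
second = tokenizer(" Alice, and I have a story titled", return_tensors="pt").input_ids
with torch.no_grad():
    out = model(input_ids=second, state=out.state, use_cache=True)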

slundberg (Collaborator)

Sounds good. I'll merge this and am happy to review any follow-up PR that does proper acceleration. Also, for the benefit of transformers, it might be good to consider how to expose session-based state caching in a more standard way; guidance has to use monkey patching to get what we want, which is probably not a great long-term solution.

slundberg merged commit 0a64fa9 into guidance-ai:main on May 22, 2023
younesbelkada deleted the patch-1 branch on May 22, 2023 at 21:23