fix: set cos sin in max_seq_len #203
Conversation
Thanks for the PR. Just for my understanding, what happens if we do a forward pass with inputs on the GPU? Wouldn't that lead to a clash, since we would be using different devices?
We set cos and sin on the device we are training on every time, so cos and sin are always on the training device. Check whittle/whittle/models/gpt/model.py, line 301 in 748e117.
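For context, a minimal sketch of the device handling described above (hypothetical module and names, not the actual whittle model code): the cos/sin buffers may start out on CPU but are moved to the device of the incoming batch on every forward pass, so attention never mixes devices.

```python
import torch

# Minimal sketch (hypothetical names, not the actual whittle model code):
# cos/sin may start out on CPU, but are moved to the device of the current
# input on every forward pass, so attention never mixes devices.
class RoPEDeviceSketch(torch.nn.Module):
    def __init__(self, cos: torch.Tensor, sin: torch.Tensor):
        super().__init__()
        self.register_buffer("cos", cos, persistent=False)
        self.register_buffer("sin", sin, persistent=False)

    def forward(self, idx: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        # Re-materialize cos/sin on the device of the incoming batch.
        cos = self.cos.to(idx.device)[: idx.size(1)]
        sin = self.sin.to(idx.device)[: idx.size(1)]
        return cos, sin
```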
While reviewing, I discovered that the KV cache does not work at all with sub-networks because of how the rope_cache is constructed. I'll implement a fix, along with setting the KV cache outside of forward (while making sure cos and sin are on the right device).
- fix the rope_n_elem value for subnets in the KV cache and in forward (when input_pos is not None)
- track sub_network_rope_n_elem
- in case max_seq_length changes, use sub_network_rope_n_elem
@aaronkl @rheasukthanker what do you think? Except for KV cache calls, the output should be the same as before. We should probably add some KV cache tests (probably in a separate issue), although I'm not sure the KV cache is that important for subnets - for inference, you would probably extract the subnet first. We could also disable the KV cache for subnets.
- since rope_n_elem is determined in the config along with the head size, we need to pass it to the config manually (see the sketch below for the overall idea)
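To make the bullets above concrete, here is a rough sketch under simplified, hypothetical names (SubNetworkRopeSketch, set_sub_network); it is not the actual whittle implementation, only an illustration of tracking sub_network_rope_n_elem and using it whenever the RoPE cache is rebuilt:

```python
import torch

# Hypothetical sketch: track a per-sub-network rope_n_elem and use it
# (instead of the full-network value) whenever the RoPE cache is rebuilt,
# e.g. after max_seq_length changes or when the KV cache is set up.
class SubNetworkRopeSketch:
    def __init__(self, head_size: int, rotary_percentage: float = 1.0):
        # Full-network value, derived from the config.
        self.rope_n_elem = int(rotary_percentage * head_size)
        # Sub-network value, updated whenever a sub-network is selected.
        self.sub_network_rope_n_elem = self.rope_n_elem

    def set_sub_network(self, sub_network_head_size: int, rotary_percentage: float = 1.0):
        # rope_n_elem depends on the head size, so it has to be recomputed
        # for the sampled sub-network.
        self.sub_network_rope_n_elem = int(rotary_percentage * sub_network_head_size)

    def rope_cache(self, max_seq_length: int, device=None):
        # Use the sub-network value so cos/sin have the right last dimension
        # for the sub-network's heads (relevant for the KV cache and for
        # forward when input_pos is not None).
        n_elem = self.sub_network_rope_n_elem
        theta = 1.0 / (
            10000 ** (torch.arange(0, n_elem, 2, device=device).float() / n_elem)
        )
        seq_idx = torch.arange(max_seq_length, device=device).float()
        idx_theta = torch.outer(seq_idx, theta)
        return torch.cos(idx_theta), torch.sin(idx_theta)
```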
While I agree with the fixes, there are a couple of things to keep in mind (this is also the reason I left KV caching untouched earlier). I am noting them down here for completeness, and we should look into them in more detail later:
Thanks for the feedback. I agree that KV caching is inapplicable if different networks are sampled. However, the speedup could still help in one particular use case - subnet evaluation. It's probably easier to extract the subnet and convert it into a litgpt model, but it could be useful to have it here for flexibility. Also, if we don't want to use it for subnets at all, we could raise a warning/error when calling it.
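As a rough illustration of the warning/error option mentioned above (set_kv_cache_guarded and is_sub_network_active are assumed names, not the actual whittle API):

```python
import warnings

# Hypothetical wrapper: warn (or raise) when the KV cache is requested while a
# sub-network is active, since KV caching for subnets is untested.
def set_kv_cache_guarded(model, batch_size: int, **kwargs):
    if getattr(model, "is_sub_network_active", False):
        warnings.warn(
            "KV caching with an active sub-network is untested; "
            "consider extracting the sub-network into a litgpt model first.",
            UserWarning,
        )
    return model.set_kv_cache(batch_size=batch_size, **kwargs)
```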
Thanks for the detailed discussion. I agree it might actually be interesting to see the effect of the KV cache on sub-network evaluation. @gabikadlecova could you open another issue for that so we can track it for a future PR?
I agree with the last use case. Let's keep things as they are for now. We need more thorough tests for kv_caching, though. @gabikadlecova could you create an issue for that, and perhaps summarize the points discussed here in it?
Sure, I'll create it |
Reference Issues/PRs
#194
What does this implement/fix? Explain your changes.
Moves the cos, sin initialization to the CPU at init (instead of the "meta" device in litgpt), to avoid failures with FSDP.
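For illustration, a minimal sketch of the approach, assuming a hypothetical build_rope_cache_on_cpu helper (the actual whittle/litgpt code may differ):

```python
import torch

# Hypothetical sketch: build the RoPE cos/sin cache on CPU at __init__ time
# (rather than on the "meta" device), so FSDP can shard and materialize the
# model without hitting uninitialized meta tensors.
def build_rope_cache_on_cpu(max_seq_length: int, rope_n_elem: int, base: int = 10000):
    device = torch.device("cpu")  # explicit CPU instead of "meta"
    theta = 1.0 / (
        base ** (torch.arange(0, rope_n_elem, 2, device=device).float() / rope_n_elem)
    )
    seq_idx = torch.arange(max_seq_length, device=device).float()
    idx_theta = torch.outer(seq_idx, theta)
    # cos/sin are later moved to the training device inside forward.
    return torch.cos(idx_theta), torch.sin(idx_theta)
```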
Minimal Example / How should this PR be tested?
Any other comments?
By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the
terms of your choice.