
fix: set cos sin in max_seq_len #203

Merged
merged 5 commits into main from cos-sin-fix on Dec 11, 2024

Conversation

@rheasukthanker
Collaborator

Reference Issues/PRs

#194

What does this implement/fix? Explain your changes.

Moves the cos, sin initialization to the CPU at init (instead of the "meta" device used in litgpt), to avoid failures with FSDP.
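A minimal sketch of the idea (hypothetical module and names, not the actual whittle/litgpt code), assuming the usual RoPE formulation: the cos/sin buffers are created on the CPU at construction time instead of on the "meta" device, so FSDP never has to materialize them, and they simply follow the input's device in forward.

```python
import torch
import torch.nn as nn


class RopeCacheSketch(nn.Module):
    """Hypothetical sketch: cos/sin are built on CPU at init and moved later."""

    def __init__(self, max_seq_len: int, rope_n_elem: int, base: int = 10000):
        super().__init__()
        # Build the cache on CPU (not on the "meta" device), so FSDP init does not fail.
        theta = 1.0 / (base ** (torch.arange(0, rope_n_elem, 2).float() / rope_n_elem))
        idx_theta = torch.outer(torch.arange(max_seq_len).float(), theta)
        # Non-persistent buffers: .to(device) moves them together with the module.
        self.register_buffer("cos", torch.cos(idx_theta), persistent=False)
        self.register_buffer("sin", torch.sin(idx_theta), persistent=False)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        # Slice to the current sequence length and follow the input's device.
        T = x.size(1)
        return self.cos[:T].to(x.device), self.sin[:T].to(x.device)
```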

Minimal Example / How should this PR be tested?

Any other comments?


By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the
terms of your choice.

@rheasukthanker rheasukthanker changed the title set cos sin in max_seq_len fix:set cos sin in max_seq_len Dec 6, 2024
@rheasukthanker rheasukthanker changed the title fix:set cos sin in max_seq_len fix: set cos sin in max_seq_len Dec 6, 2024
@aaronkl
Collaborator

aaronkl commented Dec 8, 2024

Thanks for the PR. Just for my understanding, what happens if we do a forward pass with inputs on the GPU? Wouldn't that lead to a clash, since we use different devices?

@rheasukthanker
Collaborator Author

rheasukthanker commented Dec 8, 2024

> Thanks for the PR. Just for my understanding, what happens if we do a forward pass with inputs on the GPU? Wouldn't that lead to a clash, since we use different devices?

We set cos and sin on the device we are training on every time, so cos and sin are always on the training device. Check:

cos, sin = self.rope_cache(
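For illustration, a standalone sketch of that pattern (rope_cache here is a hypothetical helper, not the exact whittle signature): the cache is rebuilt on whatever device the inputs live on, so GPU inputs never clash with a CPU-resident cache.

```python
import torch


def rope_cache(seq_len: int, n_elem: int, device: torch.device, base: int = 10000):
    """Hypothetical helper: build cos/sin directly on the given device every call."""
    theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, device=device).float() / n_elem))
    idx_theta = torch.outer(torch.arange(seq_len, device=device).float(), theta)
    return torch.cos(idx_theta), torch.sin(idx_theta)


# Whatever device the batch is on, cos/sin follow it.
x = torch.randn(2, 16, 64)  # (batch, seq_len, n_embd); CPU here, CUDA during training
cos, sin = rope_cache(seq_len=x.size(1), n_elem=16, device=x.device)
```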

@gabikadlecova
Collaborator

While reviewing, I discovered that the KV cache does not work at all with subnetworks because of how the rope_cache is constructed. I'll implement a fix along with setting the KV cache outside of forward (while making sure cos and sin are on the right device).

- fix correct rope_n_elem value for subnets in KV cache and forward (when input_pos is not None)
- track sub_network_rope_n_elem
- in case max_seq_length changes, use sub_network_rope_n_elem
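A rough sketch of the bookkeeping these commits describe (sub_network_rope_n_elem follows the commit messages above; the class and the remaining names are hypothetical):

```python
import torch
import torch.nn as nn


class SuperNetSketch(nn.Module):
    """Hypothetical supernet skeleton illustrating the rope_n_elem bookkeeping."""

    def __init__(self, head_size: int = 64, rotary_percentage: float = 1.0):
        super().__init__()
        self.rope_n_elem = int(rotary_percentage * head_size)
        # Also track the value for the currently active sub-network.
        self.sub_network_rope_n_elem = self.rope_n_elem

    def set_sub_network(self, sub_network_head_size: int, rotary_percentage: float = 1.0):
        # A sub-network's rope_n_elem must be recomputed from its own head size,
        # not reused from the supernet.
        self.sub_network_rope_n_elem = int(rotary_percentage * sub_network_head_size)

    def reset_cache(self, max_seq_length: int, device: torch.device):
        # If max_seq_length changes, rebuild cos/sin with the sub-network value
        # so the shapes stay consistent with the active sub-network.
        n = self.sub_network_rope_n_elem
        theta = 1.0 / (10000 ** (torch.arange(0, n, 2, device=device).float() / n))
        idx_theta = torch.outer(torch.arange(max_seq_length, device=device).float(), theta)
        self.cos, self.sin = torch.cos(idx_theta), torch.sin(idx_theta)
```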
@gabikadlecova gabikadlecova removed their request for review December 9, 2024 14:16
@gabikadlecova
Collaborator

@aaronkl @rheasukthanker what do you think? Except for KV cache calls, the output should be the same as before. Now no rope_cache calls occur in forward.

We should probably add some KV cache tests (probably in a separate issue), although I'm not sure if the KV cache is that important for subnets - for inference, you would probably extract the subnet first. We could also disable the KV cache for subnets.

- since rope_n_elem is determined in the config along with head size, we need to pass it to the config manually
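To illustrate that point, a hedged sketch of a config where rope_n_elem is derived from the head size (field names loosely modeled on litgpt's Config, but treat them as assumptions):

```python
from dataclasses import dataclass


@dataclass
class ConfigSketch:
    """Hypothetical config: rope_n_elem is derived, not set directly."""

    n_embd: int = 768
    n_head: int = 12
    rotary_percentage: float = 1.0
    head_size: int | None = None
    rope_n_elem: int | None = None

    def __post_init__(self):
        if self.head_size is None:
            self.head_size = self.n_embd // self.n_head
        # Because rope_n_elem is computed from head_size, a sub-network with a
        # different head size needs it passed (or recomputed) explicitly.
        if self.rope_n_elem is None:
            self.rope_n_elem = int(self.rotary_percentage * self.head_size)


# For a sub-network config, rope_n_elem has to be set manually to match its head size:
sub_cfg = ConfigSketch(n_embd=512, n_head=8, rope_n_elem=int(1.0 * 64))
```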
@gabikadlecova gabikadlecova linked an issue Dec 9, 2024 that may be closed by this pull request
@rheasukthanker
Collaborator Author

rheasukthanker commented Dec 11, 2024

> While reviewing, I discovered that the KV cache does not work at all with subnetworks because of how the rope_cache is constructed. I'll implement a fix along with setting the KV cache outside of forward (while making sure cos and sin are on the right device).

While I agree with the fixes, there are a couple of things to keep in mind (this is also the reason I left KV caching untouched earlier). I am noting them down here for completeness; we should look into this in more detail later:

  1. Given how KV caching works (https://neptune.ai/blog/transformers-key-value-caching), the KV sizes depend on num_heads, head_size, and the number of query groups, so a KV cache can only be used for a fixed dense model (the KV sizes for subsequent tokens must match). Since the supernet is adaptive when sampling subnetworks, the KV sizes change depending on the network sampled at inference, which renders KV caching inapplicable when different networks are sampled during inference (see the sketch after this list).
  2. KV caching is only used for inference-time gains (i.e. for deployed models). Since we can convert a dense model found by whittle directly into a litgpt model and use the KV cache there, I don't think there is real utility in adapting KV caching for whittle, which is more training/finetuning focussed.
  3. KV caching also fails when num_heads, query groups, or head sizes differ per layer, since it would then have to be implemented in a layer-specific manner (https://github.com/Lightning-AI/litgpt/blob/main/litgpt/model.py#L574). Since we no longer support layer-specific dimensions, this is also not an issue currently.
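To make point 1 concrete, a small illustration of how the cached K/V shapes are tied to the sampled architecture (the shape layout is a simplification, not whittle's or litgpt's exact cache layout):

```python
import torch


def kv_cache_shape(batch: int, n_query_groups: int, max_seq_len: int, head_size: int):
    """Hypothetical illustration: the cache shape is fixed by the architecture."""
    return (batch, n_query_groups, max_seq_len, head_size)


# Supernet vs. a sampled sub-network: the cached K/V tensors have different shapes,
# so a cache allocated for one cannot be reused for the other.
super_k = torch.zeros(kv_cache_shape(1, n_query_groups=8, max_seq_len=128, head_size=64))
sub_k = torch.zeros(kv_cache_shape(1, n_query_groups=4, max_seq_len=128, head_size=32))
print(super_k.shape, sub_k.shape)  # torch.Size([1, 8, 128, 64]) torch.Size([1, 4, 128, 32])
```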

@gabikadlecova
Collaborator


Thanks for the feedback. I agree that KV caching is inapplicable if different networks are sampled. However, the speedup could still help in one particular use case - subnet evaluation:
a) sample a subnetwork and set up a clean KV cache for it
b) evaluate on a task that requires multi-token generation
c) repeat

It's probably easier to extract the subnet and convert it into a litgpt model, but it could be useful to have it here for flexibility. Also, if we don't want to use it for subnets at all, we could raise a warning/error when forward is called with input_pos not None.
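As a sketch of that evaluation loop (all names here, e.g. search_space, set_sub_network, reset_kv_cache, generate_and_score, are placeholders rather than existing whittle APIs):

```python
def evaluate_subnetworks(model, search_space, task, num_candidates: int):
    """Hypothetical loop following steps a)-c) above."""
    results = []
    for _ in range(num_candidates):
        sub_config = search_space.sample()       # a) sample a subnetwork...
        model.set_sub_network(**sub_config)
        model.reset_kv_cache()                   #    ...and start from a clean KV cache
        score = task.generate_and_score(model)   # b) evaluate with multi-token generation
        results.append((sub_config, score))      # c) repeat for the next candidate
    return results
```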

@rheasukthanker rheasukthanker merged commit 015b7fe into main Dec 11, 2024
9 checks passed
@rheasukthanker rheasukthanker deleted the cos-sin-fix branch December 11, 2024 14:36
@aaronkl
Collaborator

aaronkl commented Dec 11, 2024

Thanks for the detailed discussion. I agree it might actually be interesting to see the effect of the KV cache on sub-network evaluation. @gabikadlecova Could you open another issue for that so we can track it for a future PR?

@rheasukthanker
Collaborator Author


I agree with the last use case. Let's keep things as they are for now. We do need more thorough tests for kv_caching, though. @gabikadlecova could you create an issue for that, and perhaps summarize the points discussed here in it?

@gabikadlecova
Collaborator

Sure, I'll create it

Development

Successfully merging this pull request may close these issues.

calling reset_parameters after init leads to crash