-
Notifications
You must be signed in to change notification settings - Fork 124
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add kv_cache to LLM #244
add kv_cache to LLM #244
Conversation
@vchiley Cached kv values shift the positions as well. Maybe you want to shift the position embeddings in the following?
Compare with this in HF https://github.com/huggingface/transformers/blob/60d51ef5123d949fd8c59cd4d3254e711541d278/src/transformers/models/gpt2/modeling_gpt2.py#L801 In our fork of mosaic models, we have the kv cache and the relevant part looks like the following: if past_key_values is None:
past_key_values = [None] * self.cfg.n_layers
past_position = 0
else:
assert len(past_key_values) == self.cfg.n_layers
# get the key tensor whose spec should be (batch, seq, n_head, head_dim), and
# collect the `seq`, so that we shift the position embedding later.
past_position = past_key_values[0][0].size(1)
tok_emb = self.transformer.wte(input_ids) # type: ignore
if self.alibi:
x = tok_emb
else:
if S + past_position > self.cfg.max_seq_len:
raise ValueError(
f'Cannot forward input with past sequence length {past_position} and current sequence length '
f'{S + 1}, this model only supports total sequence length <= {self.cfg.max_seq_len}.'
)
pos = torch.arange(past_position, S + past_position, dtype=torch.long,
device=input_ids.device).unsqueeze(0)
pos_emb = self.transformer.wpe(pos) # type: ignore
x = tok_emb + pos_emb |
f8a03e7
to
26431f7
Compare
68afe03
to
aaf0658
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, can you train a model and make sure nothing is broken?
Could you explain the reason for separating out query_padding_mask? |
Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com>
@dskhudia this formulates a generic attn fn |
Note: we should have a conversation about if all |
This pr include past_key_values (ie kv_cache) in the LLM so that inference can be accelerated.
We also become explicit about how we apply
padding_mask
for querys and keys.Shoutout: @dakinggg for working through some of this with me.
cc @dskhudia @alextrott16 @samhavens for after training / inference