[CI] Fix adaptation prompt CI on transformers main #1465
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Thanks a lot for working on a fix for this issue. In general, this solution is fine with me, since it seems to solve the problem.
I had an issue, however, with understanding llama_compute_query_states, so I tried to see if I could simplify it. Please check out my comment.
On a side note, I needed to add @pytest.mark.skipif(not torch.cuda.is_bf16_supported(), reason="No BF16 support") to test_bf16_inference.
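For reference, the added marker would look roughly like this; the test body is elided and, in the real suite, the test presumably sits on the existing test class, so treat this as a sketch of the decorator only:

import pytest
import torch

# Skip the test on machines whose GPU (or lack thereof) cannot run bfloat16.
@pytest.mark.skipif(not torch.cuda.is_bf16_supported(), reason="No BF16 support")
def test_bf16_inference():
    ...  # existing test body unchanged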
cos, sin = model.rotary_emb(value_states, seq_len=kv_seq_len, position_ids=position_ids)
return (query_states * cos) + (llama_rotate_half(query_states) * sin)
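(For anyone reading along: llama_rotate_half presumably mirrors the rotate_half helper from transformers' Llama modeling code, i.e. something like the sketch below; the exact PEFT helper isn't shown in this hunk, so this is only an illustration.)

import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Rotate half the hidden dims: (x1, x2) -> (-x2, x1), as used when applying RoPE.
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)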
I tried to simplify this function. What was especially confusing for me were the values of q_len, seq_len, and kv_seq_len.
E.g. we have:
bsz, q_len, _ = hidden_states.size()
value_states = model.v_proj(hidden_states).view(bsz, q_len, model.num_heads, model.head_dim).transpose(1, 2)
seq_len = q_len
kv_seq_len = value_states.shape[-2]
From this, it follows that q_len, seq_len, and kv_seq_len are all the same value, right?
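A small standalone check of that claim (with made-up shapes standing in for the real model's projection, so purely illustrative):

import torch
import torch.nn as nn

bsz, q_len, num_heads, head_dim = 2, 5, 4, 8
hidden_states = torch.randn(bsz, q_len, num_heads * head_dim)
v_proj = nn.Linear(num_heads * head_dim, num_heads * head_dim, bias=False)

value_states = v_proj(hidden_states).view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
seq_len = q_len
kv_seq_len = value_states.shape[-2]  # value_states is (bsz, num_heads, q_len, head_dim)
assert q_len == seq_len == kv_seq_len  # all equal to 5 when there is no past cache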
Also, I saw that there was a bit of code duplication when it came to calculating position_ids. My simplified version tries to take this into account:
import inspect

import torch
import torch.nn as nn

# llama_apply_rotary_pos_emb and llama_rotate_half are the module-level helpers
# defined alongside this function.

def llama_compute_query_states(model: nn.Module, **kwargs) -> torch.Tensor:
    hidden_states = kwargs.get("hidden_states")
    position_ids = kwargs.get("position_ids")
    past_key_value = kwargs.get("past_key_value")
    bsz, q_len, _ = hidden_states.size()
    query_states = model.q_proj(hidden_states).view(bsz, q_len, model.num_heads, model.head_dim).transpose(1, 2)
    value_states = model.v_proj(hidden_states).view(bsz, q_len, model.num_heads, model.head_dim).transpose(1, 2)

    seq_len = q_len
    if past_key_value is not None:
        if isinstance(past_key_value, tuple):
            # for transformers <= 4.35
            seq_len += past_key_value[0].shape[-2]
        else:
            # since transformers 4.36, this is a DynamicCache instance
            seq_len += past_key_value.get_seq_length(model.layer_idx)

    # For transformers > 4.37.2, `position_ids` became a required argument in the rotary embedding's forward pass.
    if "position_ids" not in inspect.signature(model.rotary_emb.forward).parameters:
        # TODO we assume that position_ids is not None here, not sure if that is safe but the old code also did that
        cos, sin = model.rotary_emb(value_states, seq_len=seq_len)
        return llama_apply_rotary_pos_emb(query_states, cos, sin, position_ids)

    # XXX
    past_seen_tokens = 0
    if position_ids is None:
        # Compute position_ids, since they are required for transformers > 4.37.2
        if past_key_value is None:
            new_cache_positions = torch.arange(q_len, q_len + q_len, device=value_states.device)
        else:
            past_seen_tokens = past_key_value.get_usable_length(q_len, model.layer_idx)
            new_cache_positions = torch.arange(
                past_seen_tokens, past_seen_tokens + q_len, device=value_states.device
            )
        position_ids = new_cache_positions.unsqueeze(0)

    cos, sin = model.rotary_emb(value_states, seq_len=q_len + past_seen_tokens, position_ids=position_ids)
    return (query_states * cos) + (llama_rotate_half(query_states) * sin)
Note that the code until # XXX is basically the old state of this function, from before the changes to transformers. After that comes the code that takes the new changes into account. I tried to keep the logic identical, but please double-check that this is true.
I tested the code with transformers 4.35.0, 4.37.2, and an install from main (4.38.0.dev0), and the tests passed.
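For context, a rough sketch of how this helper ends up being called from the adaption prompt attention wrapper; the argument names on the right-hand side are assumed to be the local variables of the calling forward pass, mirroring LlamaAttention.forward's keyword arguments:

query_states = llama_compute_query_states(
    model=attn_module,              # the wrapped LlamaAttention module (assumed name)
    hidden_states=hidden_states,    # (bsz, q_len, hidden_size)
    position_ids=position_ids,      # may be None; the helper computes them when missing
    past_key_value=past_key_value,  # tuple (transformers <= 4.35), DynamicCache (>= 4.36), or None
)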
Thanks very much for the investigation, the proposed solution sounds great!
If all tests pass, I'm fine with merging.
Thanks! Locally they all seem to pass on transformers main. Merging!
* fix adaptation prompt CI
* add fix
* forward contrib credits from discussion
* add docstring

---------

Co-authored-by: BenjaminBossan <BenjaminBossan@users.noreply.github.com>
As per title, this PR makes our new workflow that runs CI on transformers main happy! Follow-up work from #1461 to cover all use cases.
cc @pacman100 @BenjaminBossan