
[CI] Fix adaptation prompt CI on transformers main #1465

Merged

Conversation

younesbelkada (Contributor)

As per the title, this PR makes our new workflow that runs CI on transformers main happy! Follow-up work from #1461 to cover all use cases.

cc @pacman100 @BenjaminBossan

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

BenjaminBossan (Member) left a comment

Thanks a lot for working on a fix for this issue. In general, this solution is fine with me, since it seems to solve the problem.

I had an issue, however, with understanding llama_compute_query_states, so I tried to see if I could simplify it. Please check out my comment.

On a side note, I needed to add @pytest.mark.skipif(not torch.cuda.is_bf16_supported(), reason="No BF16 support") to test_bf16_inference.
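For reference, a minimal sketch of how that guard behaves; the test body below is purely illustrative and not the actual PEFT test:

import pytest
import torch


@pytest.mark.skipif(not torch.cuda.is_bf16_supported(), reason="No BF16 support")
def test_bf16_inference():
    # Illustrative only: run a tiny bf16 computation on the GPU.
    layer = torch.nn.Linear(4, 4, device="cuda", dtype=torch.bfloat16)
    x = torch.randn(2, 4, device="cuda", dtype=torch.bfloat16)
    assert layer(x).dtype == torch.bfloat16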


cos, sin = model.rotary_emb(value_states, seq_len=kv_seq_len, position_ids=position_ids)
return (query_states * cos) + (llama_rotate_half(query_states) * sin)

BenjaminBossan (Member) left a comment on these lines:

I tried to simplify this function. What was especially confusing for me were the values of q_len, seq_len, and kv_seq_len.

E.g. we have:

    bsz, q_len, _ = hidden_states.size()
    value_states = model.v_proj(hidden_states).view(bsz, q_len, model.num_heads, model.head_dim).transpose(1, 2)
    seq_len = q_len
    kv_seq_len = value_states.shape[-2]

From this, it follows that q_len, seq_len, and kv_seq_len are all the same value, right?
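(For the record, a quick shape check with made-up dimensions bears this out; after transpose(1, 2) the value states have shape (bsz, num_heads, q_len, head_dim), so shape[-2] is q_len again.)

import torch

bsz, q_len, hidden_size, num_heads, head_dim = 2, 5, 32, 4, 8
hidden_states = torch.randn(bsz, q_len, hidden_size)
v_proj = torch.nn.Linear(hidden_size, num_heads * head_dim)

# Same view/transpose as in the function above.
value_states = v_proj(hidden_states).view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
assert value_states.shape == (bsz, num_heads, q_len, head_dim)
assert value_states.shape[-2] == q_len  # i.e. kv_seq_len == q_len here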

Also, I saw that there was a bit of code duplication when it came to calculating position_ids. My simplified version tries to take this into account:

def llama_compute_query_states(model: nn.Module, **kwargs) -> torch.Tensor:
    hidden_states = kwargs.get("hidden_states")
    position_ids = kwargs.get("position_ids")
    past_key_value = kwargs.get("past_key_value")
    bsz, q_len, _ = hidden_states.size()
    query_states = model.q_proj(hidden_states).view(bsz, q_len, model.num_heads, model.head_dim).transpose(1, 2)
    value_states = model.v_proj(hidden_states).view(bsz, q_len, model.num_heads, model.head_dim).transpose(1, 2)
    seq_len = q_len

    if past_key_value is not None:
        if isinstance(past_key_value, tuple):
            # for transformers <= 4.35
            seq_len += past_key_value[0].shape[-2]
        else:
            # since transformers 4.36, this is a DynamicCache instance
            seq_len += past_key_value.get_seq_length(model.layer_idx)

    # For transformers > 4.37.2, `position_ids` became a required argument in the rotary embedding's forward pass.
    if "position_ids" not in inspect.signature(model.rotary_emb.forward).parameters:
        # TODO we assume that position_ids is not None here, not sure if that is safe but the old code also did that
        cos, sin = model.rotary_emb(value_states, seq_len=seq_len)
        return llama_apply_rotary_pos_emb(query_states, cos, sin, position_ids)

    # XXX
    past_seen_tokens = 0
    if position_ids is None:
        # Compute position_ids, since they are required for transformers > 4.37.2
        if past_key_value is None:
            new_cache_positions = torch.arange(q_len, q_len + q_len, device=value_states.device)
        else:
            past_seen_tokens = past_key_value.get_usable_length(q_len, model.layer_idx)
            new_cache_positions = torch.arange(
                past_seen_tokens, past_seen_tokens + q_len, device=value_states.device
            )
        position_ids = new_cache_positions.unsqueeze(0)

    cos, sin = model.rotary_emb(value_states, seq_len=q_len + past_seen_tokens, position_ids=position_ids)
    return (query_states * cos) + (llama_rotate_half(query_states) * sin)

Note that the code until # XXX is basically the old state of this function, from before the changes to transformers. After that comes the code that takes into account the new changes. I tried to keep the logic identical, but please double-check that this is true.
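As an aside, the version branch above relies only on standard signature inspection. Here is a self-contained illustration of the pattern, with stand-in functions rather than the real transformers rotary embedding:

import inspect


# Stand-ins for the two rotary-embedding forward() signatures seen across
# transformers versions (illustrative only, not the actual transformers code).
def rotary_forward_old(x, seq_len=None):
    ...


def rotary_forward_new(x, seq_len=None, position_ids=None):
    ...


def takes_position_ids(fn) -> bool:
    # Same check as in llama_compute_query_states above.
    return "position_ids" in inspect.signature(fn).parameters


assert not takes_position_ids(rotary_forward_old)
assert takes_position_ids(rotary_forward_new)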

I tested the code with transformers 4.35.0, 4.37.2, and installed from main (4.38.0.dev0) and the tests passed.

younesbelkada (Contributor, Author) left a comment

Thanks very much for the investigation; the proposed solution sounds great!

BenjaminBossan (Member) left a comment

If all tests pass, I'm fine with merging.

younesbelkada (Contributor, Author)

Thanks! Locally they all seem to pass on transformers main. Merging!

younesbelkada merged commit 963e312 into huggingface:main on Feb 18, 2024
14 checks passed
younesbelkada deleted the fix-adaption-prompt-main branch on February 18, 2024 at 13:38
BenjaminBossan added a commit to BenjaminBossan/peft that referenced this pull request Mar 14, 2024
* fix adaptation prompt CI

* add fix

* forward contrib credits from discussion

* add docstring

---------

Co-authored-by: BenjaminBossan <BenjaminBossan@users.noreply.github.com>