
Add RWKV-4 #22797

Merged
sgugger merged 31 commits into main from add_rwkv on May 9, 2023

Conversation

sgugger
Collaborator

@sgugger sgugger commented Apr 16, 2023

What does this PR do?

This PR is a draft; while there is a working implementation of the model, there is still a lot to do :-)

This PR adds the RWKV model from BlinkDL/RWKV-LM, which is an RNN-like Transformer: it has an attention layer and a feed-forward, but the attention is linear and can be expressed recurrently (more details coming in the doc page of the model).
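
To give an intuition of the recurrence, here is a minimal sketch of the idea (illustrative only: it ignores the numerical-stability tricks used in the real implementation, such as the running-maximum subtraction, and naive_wkv is a made-up name for this example):

import torch

def naive_wkv(time_decay, time_first, key, value):
    # key, value: (batch, seq_len, hidden); time_decay, time_first: (hidden,)
    batch_size, seq_len, hidden_size = key.shape
    num_state = torch.zeros(batch_size, hidden_size)  # running sum of decayed exp(k_i) * v_i
    den_state = torch.zeros(batch_size, hidden_size)  # running sum of decayed exp(k_i)
    decay = torch.exp(-torch.exp(time_decay))         # per-channel decay applied at every step
    outputs = []
    for t in range(seq_len):
        current_key, current_value = key[:, t], value[:, t]
        # the current token gets an extra "bonus" weight time_first
        numerator = num_state + torch.exp(time_first + current_key) * current_value
        denominator = den_state + torch.exp(time_first + current_key)
        outputs.append(numerator / denominator)
        # fold the current token into the state; older contributions decay geometrically
        num_state = decay * num_state + torch.exp(current_key) * current_value
        den_state = decay * den_state + torch.exp(current_key)
    return torch.stack(outputs, dim=1)

Because only num_state and den_state need to be carried between steps, generation can proceed one token at a time like an RNN, without a growing key/value cache.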

Here is a code snippet to play with the model:

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("sgugger/rwkv-7b-pile", torch_dtype=torch.float16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("sgugger/rwkv-7b-pile")

prompt = "\nIn a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese."

inputs = tokenizer(prompt, return_tensors="pt").to(0)
output = model.generate(inputs["input_ids"], max_new_tokens=400, top_p=0.8, do_sample=True)
print(tokenizer.decode(output[0].tolist()))

To use the chat models (called raven):

from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

model_id = "ybelkada/rwkv-raven-1b5"
model = AutoModelForCausalLM.from_pretrained(model_id).to(0)
tokenizer = AutoTokenizer.from_pretrained(model_id)

question = "Tell me about ravens"
prompt = f"### Instruction: {question}\n### Response:"

inputs = tokenizer(prompt, return_tensors="pt").to(0)
output = model.generate(inputs["input_ids"], max_new_tokens=100)

print(tokenizer.decode(output[0].tolist(), skip_special_tokens=True))

Fixes #20737
Fixes #17230

TODO:

  • Write documentation of the model explaining the linear attention and the recurrent formulas in the code
  • Make the model compatible with generate
  • Add output_attentions/output_hidden_states API
  • Convert more models and check the conversion script is compatible
  • Tweak the CUDA kernels with state so they use the provided state for initialization
  • Write tests that pass
  • Add attention mask to be able to batch sentences (might be in a followup PR)

cc @ArthurZucker and @younesbelkada

@ArthurZucker ArthurZucker mentioned this pull request Apr 21, 2023
    # TODO: maybe jit, otherwise move inside forward
    def extract_key_value(self, hidden, state=None):
        # Mix hidden with the previous timestep to produce key, value, receptance
        shifted = self.time_shift(hidden) if state is None or hidden.size(1) != 1 else state[1][:, :, self.layer_id]

It seems this shift mistakenly drops the previous hidden stored in state when provided with a sequence of length larger than 1. When state is not None and hidden.size(1) != 1, it should cut the last token and prepend the token from state.
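
A sketch of the suggested behavior, following the naming of the snippet above (illustrative, not necessarily the exact patch that was merged):

if state is None:
    shifted = self.time_shift(hidden)
else:
    # cut the last token and prepend the hidden state carried over from the previous pass,
    # so the shift stays consistent across chunk boundaries
    previous = state[1][:, :, self.layer_id].unsqueeze(1)  # (batch, 1, hidden)
    shifted = torch.cat([previous, hidden[:, :-1]], dim=1)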

Collaborator Author

Indeed, thanks for the pointer!

Collaborator Author

Should be fixed now!

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Apr 28, 2023

The documentation is not available anymore as the PR was closed or merged.

sgugger and others added 14 commits April 27, 2023 21:06
- fix common tests
- fix configuration default values
- add CI test for checking state computation
- fix some CI tests
- fix config docstring
- fix failing tests
- add output_attention / output_hidden_states
- override test_initialization
- fix failing CIs
- fix sharded case
- add new arguments
@younesbelkada younesbelkada requested a review from amyeroberts May 4, 2023 15:26
@younesbelkada
Contributor

IMO the model is in nice shape! Would love to have a round of review before I transfer the weights to the proper organization!

@@ -93,15 +93,20 @@


class ConfigTester(object):
-    def __init__(self, parent, config_class=None, has_text_modality=True, **kwargs):
+    def __init__(self, parent, config_class=None, has_text_modality=True, common_properties=None, **kwargs):
Collaborator Author

Not sure why this change is added here?

Contributor

Because in RWKV, from my understanding, there is no notion of attention heads. This default test expects num_attention_heads to always exist, so I decided to make it slightly modular and accept custom common_properties. Since we might have more models like that in the future, I thought it would be a good idea to make it modular.
Happy to revert it / override that test instead if you think that's better.
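
For reference, a rough sketch of what the modular check could look like (this only approximates the change, it is not the exact test code):

class ConfigTester(object):
    def __init__(self, parent, config_class=None, has_text_modality=True, common_properties=None, **kwargs):
        self.parent = parent
        self.config_class = config_class
        self.has_text_modality = has_text_modality
        self.inputs_dict = kwargs
        # default to the usual trio unless the model (e.g. RWKV, which has no heads) overrides it
        self.common_properties = (
            ["hidden_size", "num_attention_heads", "num_hidden_layers"]
            if common_properties is None
            else common_properties
        )

    def create_and_test_config_common_properties(self):
        config = self.config_class(**self.inputs_dict)
        for prop in self.common_properties:
            self.parent.assertTrue(hasattr(config, prop), msg=f"`{prop}` does not exist")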

Collaborator Author

Thanks for the explanation!

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
Contributor

@younesbelkada younesbelkada May 5, 2023

I added a "mock" attention mask here so that pipeline won't fail complaining that the attention mask is output by the tokenizer but not used by the model. As we want to add attention mask support anyway in the future, I thought this was the simplest solution for now. To reproduce:

from transformers import pipeline

model_id = "ybelkada/rwkv-4-169m-pile"
prompt = "Hello"
pipe = pipeline("text-generation", model=model_id)
print(pipe(prompt, max_length=10))

Contributor

@younesbelkada younesbelkada left a comment

Thanks @sgugger for taking the lead on this! Learned a lot 🔥

@BlinkDL

BlinkDL commented May 9, 2023

@younesbelkada In README.md

The name should be "Bo Peng" (Peng is the surname) instead of "Peng Bo" :)

    forward_func = rwkv_cuda_kernel.forward_with_state_bf16
else:
    forward_func = rwkv_cuda_kernel.forward_with_state
# TODO: update CUDA kernel so it uses the initial state provided here.

It seems this TODO has been done?


@staticmethod
# g stands for grad
def backward(ctx, g_output):

Any plan to support gradients on states? It would make chaining WKVs during training possible, getting rid of the sequence-length limitation. It would also match the _with_state variant of the WKV forward.
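
For illustration, chaining could look roughly like this once gradients flow through the state (the RwkvLinearAttention.apply signature is the one discussed later in this thread; the chunking helper itself is hypothetical):

import torch

def chunked_wkv(time_decay, time_first, key, value, chunk_size=1024):
    # split a long sequence into chunks and carry the WKV state across them;
    # end-to-end training gradients are only correct if backward also produces g_state
    outputs, state = [], None
    for start in range(0, key.size(1), chunk_size):
        end = start + chunk_size
        output, state = RwkvLinearAttention.apply(
            time_decay, time_first, key[:, start:end], value[:, start:end], state, True
        )
        outputs.append(output)
    return torch.cat(outputs, dim=1)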

Contributor

I think this makes sense; however, I would advocate doing that in a follow-up PR, to at least unlock the model addition for users who already want to use the model with transformers.

May I help with that later?

Contributor

Sure! You are more than welcome to help us with that.

time_decay, time_first, key, value, output = ctx.saved_tensors
# The CUDA kernel will fill those tensors.
g_time_decay = torch.empty_like(
    time_decay,

@Blealtan Blealtan May 9, 2023

If I read it right, time_decay/time_first are of shape (C,) while the CUDA kernel requires gw and gu of shape (B, C) (see wkv_cuda.cu:140-141). That may cause VRAM overflow as well as wrong results. I didn't set up the environment to test it, but I suspect the current g_time_decay/g_time_first will unexpectedly become scalars after the summation in lines 192-193, which would confirm my guess.

@sgugger sgugger merged commit b4d4d6f into main May 9, 2023
@sgugger sgugger deleted the add_rwkv branch May 9, 2023 17:04
@YovaKem

YovaKem commented May 11, 2023

Hi @sgugger, thanks A TON for this merge! I am trying to train a new model of this type and am facing the following error:

Traceback (most recent call last):
  File "train.py", line 229, in <module>
    main(model_args, data_args, training_args)
  File "train.py", line 193, in main
    trainer.train()
  File "transformers/src/transformers/trainer.py", line 1664, in train
    return inner_training_loop(
  File "transformers/src/transformers/trainer.py", line 1940, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "transformers/src/transformers/trainer.py", line 2753, in training_step
    loss.backward()
  File ".conda/envs/rwkv-eval-3.9/lib/python3.9/site-packages/torch/_tensor.py", line 487, in backward
    torch.autograd.backward(
  File ".conda/envs/rwkv-eval-3.9/lib/python3.9/site-packages/torch/autograd/__init__.py", line 200, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File ".conda/envs/rwkv-eval-3.9/lib/python3.9/site-packages/torch/autograd/function.py", line 274, in apply
    return user_fn(self, *args)
TypeError: backward() takes 2 positional arguments but 3 were given

From what I can see, the backward function of RwkvLinearAttention does not take a g_state argument. Should gradients be computed for the state? I guess not. Any pointers on how I can resolve this would be very much appreciated!

@YovaKem

YovaKem commented May 12, 2023

I managed to get the code to run with some changes to the forward() and backward() functions:

class RwkvLinearAttention(torch.autograd.Function):
    @staticmethod
    def forward(ctx, time_decay, time_first, key, value, state=None, return_state=False):

        batch_size, seq_len, hidden_size = key.size()
        if seq_len > rwkv_cuda_kernel.max_seq_length:
            raise ValueError(
                f"Cannot process a batch with {seq_len} tokens at the same time, use a maximum of "
                f"{rwkv_cuda_kernel.max_seq_length} with this model."
            )
        if batch_size * hidden_size % min(hidden_size, 32) != 0:
            raise ValueError(
                f"The product of batch size ({batch_size}) and hidden size ({hidden_size}) needs to be a round "
                f"multiple of {min(hidden_size, 32)}."
            )

        ctx.input_dtype = key.dtype

        if (
            time_decay.device.type != "cuda"
            or time_first.device.type != "cuda"
            or key.device.type != "cuda"
            or value.device.type != "cuda"
        ):
            raise ValueError("Calling the CUDA kernel for wkv attention requires all tensors to be on CUDA devices.")

        time_decay = -torch.exp(time_decay.float().contiguous())
        if key.dtype == torch.float16:
            time_first = time_first.float()
            key = key.float()
            value = value.float()
        time_first = time_first.contiguous()
        key = key.contiguous()
        value = value.contiguous()
        # The CUDA kernel will fill this tensor.
        output = torch.empty_like(key, memory_format=torch.contiguous_format)
        if return_state or state is not None:
            if state is None:
                state = torch.zeros(
                    batch_size,
                    hidden_size,
                    3,
                    dtype=torch.float32,
                    device=key.device,
                    memory_format=torch.contiguous_format,
                )
                state[:, :, 2] -= 1e38
            else:
                state = torch.cat([s.unsqueeze(2) for s in state], dim=2).contiguous()

            if key.dtype == torch.bfloat16:
                forward_func = rwkv_cuda_kernel.forward_with_state_bf16
            else:
                forward_func = rwkv_cuda_kernel.forward_with_state
            forward_func(time_decay, time_first.to(key.dtype), key, value, output, state)
        else:
            forward_func = rwkv_cuda_kernel.forward_bf16 if key.dtype == torch.bfloat16 else rwkv_cuda_kernel.forward
            forward_func(time_decay, time_first.to(key.dtype), key, value, output)
        ctx.save_for_backward(time_decay, time_first, key, value, output)

        if state is not None:
            state = [s.squeeze(2) for s in torch.chunk(state, 3, dim=2)]

        return output.to(ctx.input_dtype), state

    @staticmethod
    def backward(ctx, g_output, g_state):
        input_dtype = ctx.input_dtype

        time_decay, time_first, key, value, output = ctx.saved_tensors
        # The CUDA kernel will fill those tensors.
        g_time_decay = torch.empty_like(
            time_decay,
            memory_format=torch.contiguous_format,
            dtype=torch.bfloat16 if input_dtype == torch.bfloat16 else torch.float32,
        )
        g_time_first = torch.empty_like(
            time_first,
            memory_format=torch.contiguous_format,
            dtype=torch.bfloat16 if input_dtype == torch.bfloat16 else torch.float32,
        )
        g_key = torch.empty_like(key, memory_format=torch.contiguous_format)
        g_value = torch.empty_like(value, memory_format=torch.contiguous_format)

        if input_dtype == torch.float16:
            g_output = g_output.float()
        backward_func = rwkv_cuda_kernel.backward_bf16 if input_dtype == torch.bfloat16 else rwkv_cuda_kernel.backward
        backward_func(
            time_decay,
            time_first.to(key.dtype),
            key,
            value,
            output,
            g_output.contiguous(),
            g_time_decay,
            g_time_first,
            g_key,
            g_value,
        )
        #g_time_decay = torch.sum(g_time_decay, dim=0)
        #g_time_first = torch.sum(g_time_first, dim=0)

        return (
            g_time_decay.to(input_dtype),
            g_time_first.to(input_dtype),
            g_key.to(input_dtype),
            g_value.to(input_dtype),
            None,
            None
        )

One problem I run into now is that although I'm trying to train a fairly small model (12 layers, 256 hidden size, 64 context size), I can only train with a very small batch size (16) on a 40GB A100 card. For comparison, a RoBERTa model of a similar size allows for a batch size of 256. This seems counterintuitive to me, but I might be wrong.

Another issue I observed is instability: in some cases, within the first 3 steps of training the loss goes from something normal like 10 to 90543067814198.3 and then to 0.0. This seems to happen more when bf16 training is disabled and at higher batch sizes when bf16 training is enabled.

@Blealtan

Blealtan commented May 13, 2023

@YovaKem Would you mind trying to change this

# The CUDA kernel will fill those tensors.
g_time_decay = torch.empty_like(
    time_decay,
    memory_format=torch.contiguous_format,
    dtype=torch.bfloat16 if input_dtype == torch.bfloat16 else torch.float32,
)
g_time_first = torch.empty_like(time_first, memory_format=torch.contiguous_format)

to

# The CUDA kernel will fill those tensors.
g_time_decay = torch.empty(
    key.shape[0], key.shape[2],
    memory_format=torch.contiguous_format,
    dtype=torch.bfloat16 if input_dtype == torch.bfloat16 else torch.float32,
)
g_time_first = torch.empty(k.shape[0], k.shape[2], memory_format=torch.contiguous_format)

I suspect there's an overflow in the current code, as mentioned above in the review comment, but I haven't tested it yet. The binary distribution on PyPI does not include the CUDA kernels XD

Also, the gradient of the state should be computed, but the current kernel is not doing it. I'll open a PR later, once I set up the environment.

@YovaKem

YovaKem commented May 13, 2023

Thanks @Blealtan! I guess the k there should be key? I added bf16 support for g_time_first (I get an error otherwise) and put the tensors on CUDA:

        # The CUDA kernel will fill those tensors.
        g_time_decay = torch.empty(
            key.shape[0], key.shape[2],
            memory_format=torch.contiguous_format,
            dtype=torch.bfloat16 if input_dtype == torch.bfloat16 else torch.float32,
        ).to(key.device)
        g_time_first = torch.empty(
            key.shape[0], key.shape[2],
            memory_format=torch.contiguous_format,
            dtype=torch.bfloat16 if input_dtype == torch.bfloat16 else torch.float32,
        ).to(key.device)

This seems to solve both the OOM issue and the instability!

One question re your comment on state gradients; I just saw this:

It will also match the _with_state variant of WKV forward.

In what cases is the _with_state variant used? As far as I can see, the model I'm training is not passing states at all during the forward step. Is that something that only becomes relevant at inference time, when the model is used like an RNN?

@lambdaofgod

Hey @sgugger, how did you prepare the models? Could you point us to how to convert the original .pth or .safetensors models to your format? Thanks!

PS
Awesome that RWKV joined transformers!

@amyeroberts
Collaborator

@lambdaofgod The logic used to convert the RWKV checkpoints from BlinkDL to HF format can be found in the conversion script.

@Blealtan

@YovaKem AFAIK, with_state is only used for inference right now (in existing non-transformers implementations throughout the RWKV community). With a proper implementation it would also allow more efficient training on long sequences, but that has not been implemented yet.

@sgugger
Collaborator Author

sgugger commented May 16, 2023

I have no idea why the CUDA kernels all disappeared from the package on PyPI (it's not just RWKV, but all models using custom kernels). I will investigate later today and post a patch release when I find a solution.

@sgugger
Collaborator Author

sgugger commented May 16, 2023

The custom kernels should be included in 4.29.2, sorry for the inconvenience. We added stronger checks to make sure they don't disappear again in a future release.

@Wednesday657

Hi, can I ask a simple question about the RWKV kernel? The RWKV model without the custom kernel uses a for loop here:

for current_index in range(seq_length):
    current_key = key[:, current_index].float()
    current_value = value[:, current_index]
    # wkv computation at time t
    max_for_output = torch.maximum(max_state, current_key + time_first)
    e1 = torch.exp(max_state - max_for_output)
    e2 = torch.exp(current_key + time_first - max_for_output)
    numerator = e1 * num_state + e2 * current_value
    denominator = e1 * den_state + e2
    output[:, current_index] = (numerator / denominator).to(output.dtype)
    # Update state for next iteration
    max_for_state = torch.maximum(max_state + time_decay, current_key)
    e1 = torch.exp(max_state + time_decay - max_for_state)
    e2 = torch.exp(current_key - max_for_state)
    num_state = e1 * num_state + e2 * current_value
    den_state = e1 * den_state + e2
    max_state = max_for_state

I am not familiar with CUDA kernels, so I am not sure whether the custom CUDA kernel still computes sequentially (just delivering a faster for loop), or whether it parallelizes the computation on the GPU?

@fullstackwebdev

Putting this here so it doesn't get lost.

I am trying to run Microsoft guidance (https://github.com/microsoft/guidance) on RWKV through transformers and I am getting an error:

AttributeError: 'RwkvCausalLMOutput' object has no attribute 'past_key_values'

which can be reproduced here: https://gist.github.com/fullstackwebdev/a6523374e6687825fcb92ca74048c12b

@younesbelkada
Contributor

@fullstackwebdev
I don't think the fix should go inside transformers, as it would mean we always output past_key_values=None, which is quite misleading: by design, RWKV does not rely on past_key_values for caching, since the tokens are processed one by one. I made guidance-ai/guidance#91, which fixed the issue in my local env.
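
For anyone hitting the same thing: RWKV carries its cache in state rather than past_key_values, so manual token-by-token generation looks roughly like this (a sketch assuming the state / use_cache API added in this PR, greedy decoding only):

import torch

def greedy_generate(model, input_ids, max_new_tokens=20):
    # the first pass over the prompt produces the recurrent state
    outputs = model(input_ids=input_ids, use_cache=True)
    state = outputs.state
    generated = input_ids
    for _ in range(max_new_tokens):
        next_token = outputs.logits[:, -1].argmax(dim=-1, keepdim=True)
        generated = torch.cat([generated, next_token], dim=-1)
        # feed only the new token plus the carried state, instead of past_key_values
        outputs = model(input_ids=next_token, state=state, use_cache=True)
        state = outputs.state
    return generated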

@Blealtan Blealtan mentioned this pull request May 24, 2023
gojiteji pushed a commit to gojiteji/transformers that referenced this pull request Jun 5, 2023
* First draft of RWKV-4

* Add support for generate

* Style post-rebase

* Properly use state

* Write doc

* Fix doc

* More math

* Add model to README, dummies and clean config

* Fix init

* multiple fixes:

- fix common tests
- fix configuration default values
- add CI test for checking state computation
- fix some CI tests

* correct tokenizer

* some tweaks

- fix config docstring
- fix failing tests

* fix CI tests

- add output_attention / output_hidden_states
- override test_initialization
- fix failing CIs

* fix conversion script

- fix sharded case
- add new arguments

* add slow tests + more fixes on conversion script

* add another test

* final fixes

* change single name variable

* add mock attention mask for pipeline to work

* correct eos token id

* fix nits

* add checkpoints

* Apply suggestions from code review

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* add `tie_word_embeddings` in docstring

* change tensor name

* fix final nits

* Trigger CI

---------

Co-authored-by: younesbelkada <younesbelkada@gmail.com>
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
novice03 pushed a commit to novice03/transformers that referenced this pull request Jun 23, 2023
@amyeroberts amyeroberts mentioned this pull request Aug 2, 2023
Successfully merging this pull request may close these issues: RWKV4neo, Add RWKV2 (fast).