Blank output in the Inference while using a customize trained T5 model #226

Open
gaurav21s opened this issue Jan 2, 2023 · 10 comments · May be fixed by #270
Assignees
Labels
bug Something isn't working

Comments

@gaurav21s

Hi,
When I use the standard T5 models like 't5-small' or 't5-base', everything works fine and I get output. But when I try the T5-base model I custom-trained on my own data, the output after the Kernl optimization is blank.

Can you please look into this, or do you have any idea what might cause it?

@jonathlela jonathlela added the bug Something isn't working label Jan 5, 2023
@jonathlela jonathlela self-assigned this Jan 5, 2023
@jonathlela
Collaborator

Hi @gaurav21s

By blank, do you mean all the outputs are zeroed tensors? Do you have any warning or error messages?

@gaurav21s
Author

gaurav21s commented Jan 5, 2023 via email

@gaurav21s
Author

gaurav21s commented Jan 5, 2023 via email

@wangjunhaoumich

wangjunhaoumich commented Jan 7, 2023

I can confirm the problem also exists in the end-to-end T5 tutorial: if you change t5-small to t5-base, the output of .generate in the last cell (after optimization) is all zeros.

For the environment, I used this repo's Dockerfile to create a VS Code dev container inside Windows WSL2 with an RTX 3090 (24 GB).

@jonathlela
Collaborator

jonathlela commented Jan 11, 2023

Thanks for your reports, I can confirm this bug and we're investigating it.

Simple code to reproduce it:

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

from kernl.model_optimization import optimize_model


model_name = "t5-base"

model = AutoModelForSeq2SeqLM.from_pretrained(pretrained_model_name_or_path=model_name).eval().cuda()
tokenizer = AutoTokenizer.from_pretrained(model_name)

input_ids = tokenizer(
    "translate English to French: The house in the woods is wonderful, can we buy it ?",
    return_tensors="pt",
    pad_to_multiple_of=8,
    padding=True,
).to("cuda")

optimize_model(model.encoder)
optimize_model(model.decoder)

with torch.inference_mode(), torch.cuda.amp.autocast(enabled=True, dtype=torch.float16, cache_enabled=True):
    output = model.generate(input_ids["input_ids"], min_length=22, max_length=22)
    print(output[0])
    print(tokenizer.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=True))

displays:

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       device='cuda:0')

Disabling replace_layer_norm_rms fixes the output.

@pommedeterresautee
Member

pommedeterresautee commented Jan 14, 2023

@jonathlela can you share a reproduction at the rmsnorm kernel level?

A simple analysis (local machine, RTX 3090) seems to show that the input of the rmsnorm kernel contains NaN values for the large T5 flavour.
It happens even WITHOUT Kernl optimizations.
With the base model, things seem to work (the output is correct).
@jonathlela to print intermediate values, don't forget to remove CUDA graphs, which don't like print() or raise Exception (it will segfault).

@gaurav21s how did you train T5? With fp16 or bf16/fp32?
If it was trained in BF16, I would not be surprised if some tensors fall outside the FP16 range.
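For reference, a minimal sketch of how one could flag NaN/Inf or fp16-overflowing activations on the unoptimized model using plain PyTorch forward hooks. This is not Kernl code; the model name and reporting logic are illustrative only:

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

model = AutoModelForSeq2SeqLM.from_pretrained("t5-large").eval().cuda()
tokenizer = AutoTokenizer.from_pretrained("t5-large")

fp16_max = torch.finfo(torch.float16).max  # 65504.0

def make_hook(name):
    def hook(module, inputs, output):
        # many HF submodules return tuples; inspect the first tensor if so
        t = output[0] if isinstance(output, tuple) else output
        if torch.is_tensor(t) and t.is_floating_point():
            if torch.isnan(t).any() or torch.isinf(t).any():
                print(f"{name}: NaN/Inf in output")
            elif t.abs().max().item() > fp16_max:
                print(f"{name}: max |value| {t.abs().max().item():.3e} exceeds fp16 range")
    return hook

for name, module in model.named_modules():
    module.register_forward_hook(make_hook(name))

inputs = tokenizer(
    "translate English to French: The house in the woods is wonderful, can we buy it ?",
    return_tensors="pt",
).to("cuda")

with torch.inference_mode(), torch.cuda.amp.autocast(enabled=True, dtype=torch.float16):
    model.generate(inputs["input_ids"], min_length=22, max_length=22)

Running this on the unoptimized model under fp16 autocast should point at the first layer whose output overflows or turns into NaN, which is where the rmsnorm input goes bad.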

T5 base output:

/home/geantvert/.local/share/virtualenvs/kernl/lib/python3.9/site-packages/transformers/models/t5/tokenization_t5_fast.py:155: FutureWarning: This tokenizer was incorrectly instantiated with a model max length of 512 which will be corrected in Transformers v5.
For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.
- Be aware that you SHOULD NOT rely on t5-base automatically truncating your input to 512 when padding/encoding.
- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.
- To avoid this warning, please instantiate this tokenizer with `model_max_length` set to your preferred value.
  warnings.warn(
tensor([    0,   325,  4053,   247,   110,  5912,   259, 27015,  1074,     6,
         3215,   106,     7,    18, 10529,     3,    40,    31, 13541,     3,
           58,     3], device='cuda:0')
La maison dans les bois est merveilleuse, pouvons-nous l'acheter? 

T5 large:

❯ python t5.py
Downloading: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1.21k/1.21k [00:00<00:00, 911kB/s]
/home/geantvert/.local/share/virtualenvs/kernl/lib/python3.9/site-packages/transformers/models/t5/tokenization_t5_fast.py:155: FutureWarning: This tokenizer was incorrectly instantiated with a model max length of 512 which will be corrected in Transformers v5.
For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.
- Be aware that you SHOULD NOT rely on t5-large automatically truncating your input to 512 when padding/encoding.
- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.
- To avoid this warning, please instantiate this tokenizer with `model_max_length` set to your preferred value.
  warnings.warn(
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
        - Avoid using `tokenizers` before the fork if possible
        - Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
        - Avoid using `tokenizers` before the fork if possible
        - Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
        - Avoid using `tokenizers` before the fork if possible
        - Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       device='cuda:0')

@Pavelrst

Pavelrst commented Jan 24, 2023

@jonathlela I have the same issue with an optimized t5-base. You mentioned that disabling replace_layer_norm_rms fixes the issue. How can replace_layer_norm_rms be disabled?

I've tried commenting out this line and re-running the code:

replace_layer_norm_rms(gm)
but still got empty output.
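For what it's worth, a minimal sketch of the kind of no-op patch one might try instead of editing the source. The import path kernl.optimizer.layer_norm is an assumption (it may differ in your Kernl version), and the patch has to land on whatever module the optimizer actually looks the function up from, before optimize_model is called; any cached compiled graphs should also be cleared before re-running:

# Hypothetical workaround sketch: neutralise the RMS layer-norm replacement.
# The module path below is an assumption, not confirmed in this thread.
import kernl.optimizer.layer_norm as ln_opt

def _no_op_replace_layer_norm_rms(gm):
    # leave the original T5 RMS layer norm untouched
    return gm

ln_opt.replace_layer_norm_rms = _no_op_replace_layer_norm_rms

# If the calling module did `from ... import replace_layer_norm_rms`,
# patch that module's attribute instead, then call optimize_model as usual.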

@jonathlela
Collaborator

T5 weights are in BF16, and Triton 2.0 does not fully support BF16. We're waiting for the fix to be propagated:
triton-lang/triton#1306
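For context, a quick plain-PyTorch illustration (not Kernl-specific) of why BF16-trained weights can misbehave under FP16: BF16 keeps the FP32 exponent range, while FP16 saturates far earlier.

import torch

print(torch.finfo(torch.bfloat16).max)  # ~3.39e38, same exponent range as fp32
print(torch.finfo(torch.float16).max)   # 65504.0, much narrower range

x = torch.tensor([1e5], dtype=torch.bfloat16)  # perfectly representable in bf16
print(x.to(torch.float16))                     # inf: overflows fp16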

@hchoi-moveworks

Hello @jonathlela

Now that triton-lang/triton#1306 is merged, would using the most recent OpenAI Triton with Kernl resolve this issue?

@hchoi-moveworks

Hello @jonathlela

Would there be other large models we could try with Kernl?

It seems the larger T5 variants do not work because of this issue, and the GPT models also seem to have errors: #146

Could we try a GPT model? Is the above issue fixed?
