Inference using the CUDA EP returns nan #15752
Comments
How about running the FP32 model with the CUDA EP?
The CPU EP runs the model in FP32, so that path is fine. Based on dumping node outputs, SimplifiedLayerNormalization seems to have an issue in FP16. You can add it to op_block_list.
Is this an actual ONNX op, or some CUDA kernel that results from fusion? I can't find this op in my graph or in https://github.com/onnx/onnx/blob/main/docs/Operators.md. Following @wangyems's advice, I was able to convert to FP16 and run inference with the CUDA EP using the following op_block_list:

    FP16_BAD_OPS = [
        "Add",
        "MatMul",
        "Mul",
        "Pow",
        "ReduceMean",
        "Sqrt",
    ]

Removing any of these ops from the list results in NaN or all-zero output (I uploaded a new model with these ops blocked to the Google Drive). However, I'm still getting all zeros from the TRT EP even with these ops blocked.
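For reference, a minimal sketch of how such a block list can be passed to onnxruntime.transformers.float16.convert_float_to_float16; the model paths are illustrative, and the right set of ops to block depends on the model:

```python
# Minimal sketch: keep overflow-prone ops in FP32 during FP16 conversion.
# Paths and the block list below are illustrative, not prescriptive.
import onnx
from onnxruntime.transformers.float16 import convert_float_to_float16

FP16_BAD_OPS = ["Add", "MatMul", "Mul", "Pow", "ReduceMean", "Sqrt"]

model = onnx.load("t5_encoder_fp32.onnx")
fp16_model = convert_float_to_float16(
    model,
    keep_io_types=True,          # keep graph inputs/outputs in FP32
    op_block_list=FP16_BAD_OPS,  # these ops stay in FP32
)
onnx.save(fp16_model, "t5_encoder_fp16.onnx")
```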
The op comes from fusion; you need to run fusion before converting to FP16. BTW, we have scripts that can help export T5 to FP16 or use it in beam search. For example, this is the op_block_list we used: onnxruntime/onnxruntime/python/tools/transformers/models/t5/t5_helper.py, lines 153 to 157 at abdd4f5.
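A sketch of the fusion-first workflow described above, assuming the onnxruntime.transformers.optimizer API with the T5 model type; the paths and the single-entry block list are illustrative rather than the exact list from t5_helper.py:

```python
# Sketch: run transformer fusions first, then convert the optimized graph to
# FP16 while blocking the fused normalization op. Paths are illustrative.
from onnxruntime.transformers import optimizer

opt_model = optimizer.optimize_model(
    "t5_encoder_fp32.onnx",
    model_type="t5",   # enables T5 fusions such as SimplifiedLayerNormalization
    use_gpu=True,
)
opt_model.convert_float_to_float16(
    keep_io_types=True,
    op_block_list=["SimplifiedLayerNormalization"],  # example entry only
)
opt_model.save_model_to_file("t5_encoder_fp16.onnx")
```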
Thanks @tianleiwu! I will definitely take a look. Do you have any clue about what might be wrong with the TRT EP?
@omera-deci, for TRT you need to use the raw FP32 ONNX model. TRT will convert it to FP16 internally.
@tianleiwu I just tried giving the TRT EP the FP32 model. If I don't enable FP16, everything works smoothly, but once I enable FP16 the output is all zeros again. I've uploaded the FP32 model to the drive, along with a new script to reproduce. I guess some layers overflow in TRT as well; is there any way I can block their conversion the same way I did with ONNX?
@omera-deci, you can follow https://github.com/NVIDIA/TensorRT/blob/release/8.6/demo/HuggingFace/T5 to export the T5 ONNX models and run them with the TRT EP. I did not see any special settings there, so the ONNX export might be the key. You can run those scripts, take the resulting ONNX models, and run them with the TRT EP. You will need to build from source to support TRT 8.6 and to use some new features (like trt_layer_norm_fp32_fallback and explicit input profiles). See the following doc for details:
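A sketch of how those TRT EP options can be set through provider options, assuming an ONNX Runtime build with TensorRT 8.6 support; the model path is illustrative:

```python
# Sketch: enable TRT FP16 but keep LayerNorm in FP32 via provider options.
# Assumes an ORT build with TensorRT 8.6 support; the model path is illustrative.
import onnxruntime as ort

providers = [
    (
        "TensorrtExecutionProvider",
        {
            "trt_fp16_enable": True,               # let TRT convert to FP16 internally
            "trt_layer_norm_fp32_fallback": True,  # fall back to FP32 for LayerNorm
        },
    ),
    "CUDAExecutionProvider",
]
sess = ort.InferenceSession("t5_encoder_fp32.onnx", providers=providers)
```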
Hello, I converted an FP32 model to FP16 and hit a similar problem when running inference with ONNX Runtime, but I don't know what FP16_BAD_OPS is or where it is defined.
@changdong1687, see the example script: onnxruntime/onnxruntime/python/tools/transformers/models/t5/t5_helper.py, lines 152 to 217 at 2580d93. You can define your own op_block_list for a model.
OK, got it, thank you!
Describe the issue
I have an ONNX model (a T5 encoder that I exported from PyTorch and then converted to FP16 using onnxruntime.transformers.float16.convert_float_to_float16). When I use this model in an inference session with the CPU EP it works flawlessly, but running the same model in a session with the CUDA EP returns all NaNs as output.

Edit: I tried the TRT EP as well and it also fails (returns all zeros).
I'm aware of #9629, #831, and #11384, but they all seem either very model-specific or return NaNs on the CPU EP as well, which is not my case.
To reproduce
I wrote a small snippet to reproduce (I hope the issue is not my reliance on the NVIDIA pip libraries). The ONNX model can be downloaded from here: https://drive.google.com/drive/folders/1AMNI_cRYn31owMstIvdsW4IcOcRAYvC_?usp=share_link
I'm using CUDA 11.7 on Ubuntu 22.04.
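The original reproduction snippet is not included above; the following is a minimal, hypothetical sketch of the comparison it describes. The model path and the input names ("input_ids", "attention_mask") are assumptions about the exported T5 encoder:

```python
# Hypothetical reproduction sketch: run the same FP16 model on the CPU EP and
# the CUDA EP and check the outputs for NaNs. Model path and input names are
# assumptions, not taken from the original snippet.
import numpy as np
import onnxruntime as ort

inputs = {
    "input_ids": np.random.randint(0, 32000, size=(1, 16), dtype=np.int64),
    "attention_mask": np.ones((1, 16), dtype=np.int64),
}

for providers in (["CPUExecutionProvider"], ["CUDAExecutionProvider"]):
    sess = ort.InferenceSession("t5_encoder_fp16.onnx", providers=providers)
    out = sess.run(None, inputs)[0]
    print(providers[0], "contains NaN:", bool(np.isnan(out).any()))
```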
Urgency
No response
Platform
Linux
OS Version
Ubuntu 22.04
ONNX Runtime Installation
Released Package
ONNX Runtime Version or Commit ID
1.14.1
ONNX Runtime API
Python
Architecture
X64
Execution Provider
Default CPU, CUDA
Execution Provider Library Version
CUDA 11.7
TRT 8.5.3.1