Inference using the CUDA EP returns nan #15752

Open
omera-nv opened this issue Apr 30, 2023 · 11 comments
Labels
ep:CUDA (issues related to the CUDA execution provider), ep:TensorRT (issues related to the TensorRT execution provider), model:transformer (issues related to a transformer model: BERT, GPT2, Hugging Face, Longformer, T5, etc.)

Comments


omera-nv commented Apr 30, 2023

Describe the issue

I have an ONNX model (a T5 encoder that I exported from PyTorch and then converted to FP16 using onnxruntime.transformers.float16.convert_float_to_float16). When I use this model in an inference session with the CPU EP it works flawlessly, but running the same model in a session that uses the CUDA EP returns all NaNs as output.
Edit: I tried the TRT EP and it fails as well (it returns all zeros).

I'm aware of #9629, #831 and #11384, but those issues all seem either very model-specific or return NaNs on the CPU EP as well, which is not my case.

To reproduce

I wrote this small snippet to reproduce the issue (I hope the problem is not my reliance on the NVIDIA pip libraries). The ONNX model can be downloaded from here: https://drive.google.com/drive/folders/1AMNI_cRYn31owMstIvdsW4IcOcRAYvC_?usp=share_link

pip install onnxruntime-gpu nvidia-cuda-runtime-cu11 nvidia-cufft-cu11 nvidia-curand-cu11 nvidia-cublas-cu11 nvidia-cudnn-cu11
#!/usr/bin/env python3
import ctypes
from pathlib import Path
import numpy as np


def print_cuda_ep_libs_versions():
    import importlib.metadata

    for lib in [
        "numpy",
        "onnxruntime-gpu",
        "tensorrt",
        "nvidia-cuda-runtime-cu11",
        "nvidia-cufft-cu11",
        "nvidia-curand-cu11",
        "nvidia-cublas-cu11",
        "nvidia-cudnn-cu11",
    ]:
        print(lib, importlib.metadata.version(lib))


def load_cuda_ep_native_deps():
    # Preload the CUDA/cuDNN shared libraries shipped in the NVIDIA pip wheels
    # so onnxruntime-gpu and TensorRT can resolve them at import time.
    import nvidia.cuda_runtime.lib
    import nvidia.cufft.lib
    import nvidia.curand.lib
    import nvidia.cublas.lib
    import nvidia.cudnn.lib

    load_native_lib(Path(nvidia.cuda_runtime.lib.__path__[0]) / "libcudart.so.11.0")
    load_native_lib(Path(nvidia.cufft.lib.__path__[0]) / "libcufft.so.10")
    load_native_lib(Path(nvidia.curand.lib.__path__[0]) / "libcurand.so.10")
    load_native_lib(Path(nvidia.cublas.lib.__path__[0]) / "libcublas.so.11")
    load_native_lib(Path(nvidia.cublas.lib.__path__[0]) / "libcublasLt.so.11")
    load_native_lib(Path(nvidia.cudnn.lib.__path__[0]) / "libcudnn.so.8")


def load_native_lib(library_path):
    ctypes.CDLL(library_path, mode=ctypes.RTLD_GLOBAL)


if __name__ == "__main__":
    print_cuda_ep_libs_versions()
    load_cuda_ep_native_deps()
    import tensorrt
    import onnxruntime as ort

    ort_inputs = {"input_ids": np.ones((1, 256), dtype=np.int64), "attention_mask": np.ones((1, 256), dtype=np.int64)}

    trt_sess = ort.InferenceSession(
        "t5_fp16_encoder.onnx",
        providers=[
            ("TensorrtExecutionProvider", {"trt_fp16_enable": True}),
            "CUDAExecutionProvider",
            "CPUExecutionProvider",
        ],
    )
    cuda_sess = ort.InferenceSession(
        "t5_fp16_encoder.onnx", providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
    )
    cpu_sess = ort.InferenceSession("t5_fp16_encoder.onnx", providers=["CPUExecutionProvider"])

    print("TRT:", trt_sess.run(None, ort_inputs))
    print("CUDA:", cuda_sess.run(None, ort_inputs))
    print("CPU:", cpu_sess.run(None, ort_inputs))

I'm using CUDA 11.7 on Ubuntu 22.04.

Urgency

No response

Platform

Linux

OS Version

Ubuntu 22.04

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

1.14.1

ONNX Runtime API

Python

Architecture

X64

Execution Provider

Default CPU, CUDA

Execution Provider Library Version

CUDA 11.7
TRT 8.5.3.1

github-actions bot added the ep:CUDA, model:transformer, and ep:TensorRT labels Apr 30, 2023

wangyems commented May 1, 2023

How about running the FP32 model with the CUDA EP?
If FP32 is good, then you can try a mixed-precision conversion by specifying an op_block_list (code example).
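
For reference, a minimal sketch of what that mixed-precision conversion could look like with the converter already used in this issue (onnxruntime.transformers.float16.convert_float_to_float16); the file names and the op list below are illustrative assumptions, not the final answer:

import onnx
from onnxruntime.transformers.float16 import convert_float_to_float16

# Load the FP32 T5 encoder exported from PyTorch (path is illustrative).
model = onnx.load("t5_fp32_encoder.onnx")

# Keep numerically sensitive ops in FP32; everything else is converted to FP16.
fp16_model = convert_float_to_float16(
    model,
    keep_io_types=True,                           # keep graph inputs/outputs in FP32
    op_block_list=["Pow", "ReduceMean", "Sqrt"],  # illustrative block list
)

onnx.save(fp16_model, "t5_fp16_encoder.onnx")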


tianleiwu commented May 2, 2023

The CPU EP runs the model in fp32, so it is fine. Based on dumping node outputs, it seems SimplifiedLayerNormalization has an issue in FP16. You can put it in the op_block_list.

SimplifiedLayerNormalization node: SimplifiedLayerNormalization_token_210
Input 0 Name: /model/block.6/layer.1/Add_output_0
 Shape: {1,256,512}
OrtMemoryInfo:[name:Cuda id:0 OrtMemType:0 OrtAllocatorType:1 Device:[DeviceType:1 MemoryType:0 DeviceId:0]]
-23.15625, 106.1875, -46.09375, ... , -136, -44.75, 5416
-23.1875, 106.1875, -46.125, ... , -136, -44.8125, 5416
-23.15625, 106.1875, -46.125, ... , -136, -44.75, 5416
...
-23.125, 106.1875, -46.125, ... , -136, -44.75, 5416
-23.15625, 106.125, -46.125, ... , -136, -44.75, 5416
-23.21875, 106.1875, -46.09375, ... , -136, -44.75, 5416

Input 1 Name: model.block.7.layer.0.layer_norm.weight
 Shape: {512}
OrtMemoryInfo:[name:Cuda id:0 OrtMemType:0 OrtAllocatorType:1 Device:[DeviceType:1 MemoryType:0 DeviceId:0]]
0.22058105, 0.18444824, 0.1887207, ... , 0.17089844, 0.18896484, 0.098571777

Placement: CUDAExecutionProvider
-----------
Output 0 Name: /model/block.7/layer.0/layer_norm/Mul_1_output_0
 Shape: {1,256,512}
OrtMemoryInfo:[name:Cuda id:0 OrtMemType:0 OrtAllocatorType:1 Device:[DeviceType:1 MemoryType:0 DeviceId:0]]
-0, 0, -0, ... , -0, -0, 0
-0, 0, -0, ... , -0, -0, 0
-0, 0, -0, ... , -0, -0, 0
...
-0, 0, -0, ... , -0, -0, 0
-0, 0, -0, ... , -0, -0, 0
-0, 0, -0, ... , -0, -0, 0

Min=-0,Max=-0,Zero=131072


omera-nv commented May 2, 2023

SimplifiedLayerNormalization

Is this an actual ONNX op, or some CUDA kernel that results from fusion? I can't find this op in my graph or in https://github.com/onnx/onnx/blob/main/docs/Operators.md.

Following @wangyems 's advice, I was able to convert to fp16 and run inference with CUDA EP using the following op_block_list:

FP16_BAD_OPS = [
    "Add",
    "MatMul",
    "Mul",
    "Pow",
    "ReduceMean",
    "Sqrt",
]

Removing any of these ops from the list results in NaN or all-zero output (I uploaded a new model with these ops blocked to the Google Drive). However, I'm still getting all zeros from the TRT EP even with these ops blocked.
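
As a sanity check, one way to confirm such a blocked-op FP16 model is to compare the CUDA EP output against the CPU FP32 baseline from the repro script; a minimal sketch (the FP32 file name is an assumption):

import numpy as np
import onnxruntime as ort

ort_inputs = {"input_ids": np.ones((1, 256), dtype=np.int64),
              "attention_mask": np.ones((1, 256), dtype=np.int64)}

cuda_out = ort.InferenceSession(
    "t5_fp16_encoder.onnx",
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
).run(None, ort_inputs)[0]
cpu_out = ort.InferenceSession(
    "t5_fp32_encoder.onnx",
    providers=["CPUExecutionProvider"],
).run(None, ort_inputs)[0]

print("any NaN:", np.isnan(cuda_out).any(), "all zero:", not cuda_out.any())
print("max abs diff vs FP32 CPU:", np.abs(cuda_out.astype(np.float32) - cpu_out).max())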


tianleiwu commented May 2, 2023

The op comes from fusion. You need to run fusion before converting to fp16.

BTW, we have scripts that can help export T5 to fp16 or use it in beam search:
https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/python/tools/transformers/models/t5/convert_to_onnx.py
https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/python/tools/transformers/convert_generation.py

For example,

python -m onnxruntime.transformers.models.t5.convert_to_onnx -m t5-small -o -p fp16 --use_gpu --separate_encoder_and_decoder_init

This is the op_block_list we used:

op_block_list: List[str] = [  # noqa: B006
    "SimplifiedLayerNormalization",
    "SkipSimplifiedLayerNormalization",
    "Relu",
    "Add",
]

omera-nv commented May 3, 2023

Thanks @tianleiwu! I'll definitely take a look. Do you have any clue what might be wrong with the TRT EP?


tianleiwu commented May 4, 2023

@omera-deci, for TRT you need to use the raw FP32 ONNX model. TRT will convert it to fp16 internally.
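
In other words, something along these lines (same provider options as the repro script, but pointing at the FP32 export; the file name is an assumption):

import onnxruntime as ort

# Feed TRT the raw FP32 model and let the EP do the FP16 conversion internally.
trt_sess = ort.InferenceSession(
    "t5_fp32_encoder.onnx",
    providers=[
        ("TensorrtExecutionProvider", {"trt_fp16_enable": True}),
        "CUDAExecutionProvider",
        "CPUExecutionProvider",
    ],
)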


omera-nv commented May 7, 2023

@tianleiwu I just tried to give the TRT EP the fp32 model. If I don't enable fp16, everything works smoothly, but once I enable fp16 the output is all zeros again. I've uploaded the fp32 model to the drive, along with a new script to reproduce. I guess some layers are overflowing in TRT as well. Is there any way to block their conversion the same way I did with ONNX?


tianleiwu commented May 7, 2023

@omera-deci, you can follow https://github.com/NVIDIA/TensorRT/blob/release/8.6/demo/HuggingFace/T5 to export ONNX for T5 and run it with the TRT EP. I did not see any special settings there, so the ONNX export might be the key. You can run those scripts, take the resulting ONNX models, and run them with the TRT EP.

You will need to build from source to support TRT 8.6 and to use some new features (like trt_layer_norm_fp32_fallback and explicit input profiles). See the following doc for details:
https://github.com/microsoft/onnxruntime/blob/fd080caf62db1b41463955286c49d6a582c6a45a/docs/execution-providers/TensorRT-ExecutionProvider.md
@chilo-ms for comments on fp16 in the TRT EP.
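
If it helps, a sketch of how those newer TRT EP options could be passed from Python once running a source build with TRT 8.6; the option names follow the linked doc, and the shape strings are assumptions for this model's two inputs:

import onnxruntime as ort

trt_options = {
    "trt_fp16_enable": True,
    # Fall back to FP32 for layer-norm computation to avoid FP16 overflow.
    "trt_layer_norm_fp32_fallback": True,
    # Explicit shape profiles for the two inputs (min/opt/max all fixed here).
    "trt_profile_min_shapes": "input_ids:1x256,attention_mask:1x256",
    "trt_profile_opt_shapes": "input_ids:1x256,attention_mask:1x256",
    "trt_profile_max_shapes": "input_ids:1x256,attention_mask:1x256",
}

sess = ort.InferenceSession(
    "t5_fp32_encoder.onnx",
    providers=[("TensorrtExecutionProvider", trt_options), "CUDAExecutionProvider", "CPUExecutionProvider"],
)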

@changdong1687

(quoting omera-nv's FP16_BAD_OPS comment above)

Hello, I converted an fp32 model to fp16 and hit a similar problem when running ONNX inference, but I don't know what FP16_BAD_OPS is or where it is defined.
Best wishes for your reply.

@tianleiwu

@changdong1687, see example script:

def auto_mixed_precision(
    onnx_model: OnnxModel,
    op_block_list: List[str] = [  # noqa: B006
        "SimplifiedLayerNormalization",
        "SkipSimplifiedLayerNormalization",
        "Relu",
        "Add",
    ],
):
    """Convert model to mixed precision.
    It detects whether the original model has fp16 precision weights, and sets parameters for float16 conversion automatically.
    Args:
        onnx_model (OnnxModel): optimized ONNX model
        op_block_list (List[str], optional): Defaults to ["SimplifiedLayerNormalization", "SkipSimplifiedLayerNormalization", "Relu", "Add"]
    Returns:
        parameters(dict): a dictionary of parameters used in float16 conversion
    """
    op_full_set = {node.op_type for node in onnx_model.nodes()}
    fp32_op_set = set(op_block_list)
    fp16_op_set = op_full_set.difference(fp32_op_set)
    logger.info(f"fp32 op: {fp32_op_set} fp16 op: {fp16_op_set}")

    # logits is the first output
    logits_output_name = onnx_model.graph().output[0].name

    # We use the weight in the last MatMul node to detect whether the model is stored with float16 weights from training.
    is_weight_fp16_precision = False
    output_name_to_node = onnx_model.output_name_to_node()
    assert logits_output_name in output_name_to_node
    node = output_name_to_node[logits_output_name]
    last_matmul_node = None
    if node.op_type == "MatMul":
        last_matmul_node = node
        logger.info(f"Found last MatMul node for logits: {node.name}")
        initializer = None
        for input in node.input:
            initializer = onnx_model.get_initializer(input)
            if initializer is not None:
                break

        # When the max difference of value after converting float to float16 is lower than a threshold (1e-6),
        # we can deduce that the weights are stored in float16 precision.
        max_diff = float_to_float16_max_diff(initializer)
        logger.debug(f"max diff of converting weights in last MatMul node {node.name}: {max_diff}")
        is_weight_fp16_precision = max_diff < 1e-6
    else:
        logger.warning(f"Failed to find MatMul node for logits. Found {node.op_type} of node {node.name}")

    keep_io_types = []
    node_block_list = []
    if (not is_weight_fp16_precision) and (last_matmul_node is not None):
        # When the original weight is float32 precision, keeping logits and the last MatMul in float32 can give better precision.
        keep_io_types = [logits_output_name]
        node_block_list = [last_matmul_node.name]

    parameters = {
        "keep_io_types": keep_io_types,
        "op_block_list": op_block_list,
        "node_block_list": node_block_list,
        "force_fp16_initializers": is_weight_fp16_precision,
    }

    logger.info(f"auto_mixed_precision parameters: {parameters}")
    onnx_model.convert_float_to_float16(use_symbolic_shape_infer=True, **parameters)

    return parameters

You can define your own op_block_list for your model.
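
For completeness, a rough example of how the helper above might be invoked, assuming the same imports as the original script (OnnxModel, logger, float_to_float16_max_diff are in scope) and illustrative file names:

import onnx
from onnxruntime.transformers.onnx_model import OnnxModel

onnx_model = OnnxModel(onnx.load("t5_encoder_optimized.onnx"))
parameters = auto_mixed_precision(onnx_model)  # uses the default op_block_list above
print(parameters)
onnx_model.save_model_to_file("t5_encoder_fp16_mixed.onnx")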

@changdong1687

(quoting tianleiwu's auto_mixed_precision example above)

Ok, got it, thank you!
