
[fp16] model generates NaN results on fp16, while it generates correct results on fp32 #11384

Closed
yetingqiaqia opened this issue Apr 28, 2022 · 9 comments

Comments

@yetingqiaqia

Describe the bug

Hi ORT team,

We use fp16 to accelerate model inference. It works fine for many models, but we ran into a NaN issue with a new fp16 model, while its fp32 version generates correct results. See below:

Result comparison (screenshots in the original issue): the fp32 model produces normal values, while the fp16 model outputs NaN.

Could you help check whether there is anything wrong with the fp16 model?

Urgency
The ONNX model was converted from a PyTorch model. Its fp32 inference is slower than PyTorch's, and we hope fp16 can accelerate it by about 3x once this NaN issue is resolved. Thanks.

System information

  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): 18.04
  • ONNX Runtime installed from (source or binary): binary
  • ONNX Runtime version: 1.10
  • Python version: 3.6
  • Visual Studio version (if applicable): N/A
  • GCC/Compiler version (if compiling from source): N/A
  • CUDA/cuDNN version: CUDA 11.5, CuDNN 8.3.1
  • GPU model and memory: V100, 16GB

BTW, the issue can be reproduced across different CUDA/cuDNN versions and GPU SKUs, so I don't think they matter.

To Reproduce

  • Test code:
def convert_float32_to_float16(fp32_model_path, fp16_model_path):
    from onnxmltools.utils.float16_converter import convert_float_to_float16_model_path
    from onnxmltools.utils import save_model

    new_onnx_model = convert_float_to_float16_model_path(fp32_model_path, keep_io_types=True)
    save_model(new_onnx_model, fp16_model_path)


def test(onnx_model_path):
    import numpy as np
    import time
    np.random.seed(123)
    #Load ort model
    import onnxruntime as ort
    sess_options = ort.SessionOptions()
    sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    sess_options.intra_op_num_threads = 0
    sess = ort.InferenceSession(onnx_model_path, sess_options, providers=['CUDAExecutionProvider'])
    input_name = sess.get_inputs()[0].name
    label_name = sess.get_outputs()[0].name

    data = 2*np.random.rand(8, 3, 384, 384).astype(np.float32)-1.0

    #warm-up run
    warm_up_start_stamp = time.time()
    onnx_outs = sess.run([label_name], {input_name: data})[0]
    print(f"onnx_outs of warm-up:", onnx_outs)
    print(f"It takes {time.time()-warm_up_start_stamp} to finish warm-up.\n")

    start_stamp = time.time()
    num_batches = 10
    for i in range(num_batches):
        print(f"batch id: {i}")
        data = 2*np.random.rand(8, 3, 384, 384).astype(np.float32)-1.0
        onnx_outs = sess.run([label_name], {input_name: data})[0]
        print(f"onnx_outs:", onnx_outs)
        print(f"{i}th batch finished successfully. ")
    print(f"It takes {time.time()-start_stamp} to finish {num_batches} batches.\n")


fp32_model_path = './ConvNext-XL/graph.onnx'
fp16_model_path = './ConvNext-XL-fp16/graph_fp16.onnx'
convert_float32_to_float16(fp32_model_path, fp16_model_path)

test(fp32_model_path)
test(fp16_model_path)
@yetingqiaqia
Author

Sorry, I had set the wrong permission on the shared file. It has now been changed to the correct permission. Please retry if you hit a permission issue when viewing the shared file. Thanks.

@garymm
Contributor

garymm commented Apr 28, 2022

@yetingqiaqia it's somewhat suspicious that the test always uses float32 as the input type. Can you try using float16 data as the input to the float16 model? If that fixes it, then I guess just close this issue.
Can you try using the CPU Execution Provider? If that fixes it, then the problem is with the CUDA Execution Provider.

If neither of those works, please file an issue in https://github.com/microsoft/onnxconverter-common and CC me and xiaowuhu.
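For reference, a minimal sketch of the two checks suggested above, assuming the fp16 model path and input shape from the repro script (note that with keep_io_types=True the graph inputs stay float32, so the float16-input check only applies if the model is reconverted with keep_io_types=False):

import numpy as np
import onnxruntime as ort

fp16_model_path = './ConvNext-XL-fp16/graph_fp16.onnx'
data = (2 * np.random.rand(8, 3, 384, 384) - 1.0).astype(np.float32)

# Check 1: run the fp16 model on the CPU execution provider. If the NaNs
# disappear here, the problem is likely in the CUDA execution provider.
sess_cpu = ort.InferenceSession(fp16_model_path, providers=['CPUExecutionProvider'])
input_name = sess_cpu.get_inputs()[0].name
label_name = sess_cpu.get_outputs()[0].name
out_cpu = sess_cpu.run([label_name], {input_name: data})[0]
print("NaNs on CPU EP:", np.isnan(out_cpu).any())

# Check 2 (only if the graph inputs are float16, i.e. the model was
# converted with keep_io_types=False): feed float16 data directly.
sess_gpu = ort.InferenceSession(fp16_model_path, providers=['CUDAExecutionProvider'])
if sess_gpu.get_inputs()[0].type == 'tensor(float16)':
    out_gpu = sess_gpu.run([label_name], {input_name: data.astype(np.float16)})[0]
    print("NaNs with fp16 inputs on CUDA EP:", np.isnan(out_gpu.astype(np.float32)).any())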

@BowenBao
Contributor

Hi @yetingqiaqia, please use this tool https://github.com/microsoft/onnxconverter-common/blob/master/onnxconverter_common/auto_mixed_precision.py to properly convert the fp32 model to an fp16 model.

@garymm
Contributor

garymm commented Apr 28, 2022

@BowenBao should we delete onnxmltools.utils.float16_converter?

@BowenBao
Contributor

BowenBao commented Apr 28, 2022

@BowenBao should we delete onnxmltools.utils.float16_converter?

Indeed, it feels more reasonable to replace the lower-level float16 APIs with the auto_mixed_precision API, or at least to add the latter there.

Created onnx/onnxmltools#543

@yetingqiaqia
Author

Thanks @BowenBao. The auto_mixed_precision API works.
I found two issues:

  1. The first issue is a minor one: I had to call auto_mixed_precision.auto_convert_mixed_precision() rather than the bare auto_convert_mixed_precision() shown in the comments of your file.

Example code:

def convert_float32_to_mixed_precision(fp32_model_path, mixed_precision_model_path):
    from onnxconverter_common import auto_mixed_precision
    import onnx

    model = onnx.load(fp32_model_path)

    import numpy as np
    np.random.seed(123)
    test_data = {"input_image": 2*np.random.rand(8, 3, 384, 384).astype(np.float32)-1.0}

    # Could also use rtol/atol attributes directly instead of this
    def validate(res1, res2):
        for r1, r2 in zip(res1, res2):
            if not np.allclose(r1, r2, rtol=0.01, atol=0.001):
                return False
        return True

    model_fp16 = auto_mixed_precision.auto_convert_mixed_precision(model, test_data, validate, keep_io_types=True)
    onnx.save(model_fp16, mixed_precision_model_path)

fp32_model_path = './ConvNext-XL/graph.onnx'
mixed_precision_model_path = './ConvNext-XL_mixed_precision/graph_mixed_precision.onnx'

print("Convert to mixed precision starts...")
convert_float32_to_mixed_precision(fp32_model_path, mixed_precision_model_path)
print("Conversion finished.")
  2. The second issue is a bigger concern: the auto_convert_mixed_precision() API is very slow, because it appears to scan a range of combinations to find the optimal one. In my case it made 52 attempts, which took 16 minutes to finish.
    This is OK for end users, since they only need to convert once, but it is an issue for us. We maintain a DL platform called AdsBrain that serves many users' models, and we run models with fp16 by default to shorten serving time, so we need the fp16 conversion to be fast. For example, on our platform we use graph_options=tf.GraphOptions(enable_bfloat16_sendrecv=True) for TensorFlow models, torch.cuda.amp for PyTorch, and convert_float_to_float16_model_path() for ONNX. For ONNX, if a user's model is fp32, it is converted to fp16; if that conversion is this slow, it will be a huge cost.

So, unless auto_convert_mixed_precision() can run faster, we will still prefer to call convert_float_to_float16_model_path() for ONNX fp32 models in our scenarios. That said, your new API still helps a lot: if any users hit this NaN issue in the future, we will ask them to convert their model to mixed precision with this API first and then run it on our platform (a rough sketch of that fallback flow is below).
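A rough, hypothetical sketch of that fallback flow (the helper name and the NaN check are illustrative, not an existing API; it reuses the conversion calls shown earlier in this thread):

import numpy as np
import onnx
import onnxruntime as ort
from onnxmltools.utils.float16_converter import convert_float_to_float16_model_path
from onnxmltools.utils import save_model
from onnxconverter_common import auto_mixed_precision

def convert_fp16_with_fallback(fp32_path, fp16_path, test_data):
    # Fast path: blanket fp16 conversion (takes seconds).
    save_model(convert_float_to_float16_model_path(fp32_path, keep_io_types=True), fp16_path)
    sess = ort.InferenceSession(fp16_path, providers=['CUDAExecutionProvider'])
    out = sess.run(None, test_data)[0]
    if not np.isnan(out).any():
        return fp16_path

    # Slow path: per-op mixed-precision search (takes minutes), used only
    # when the fast conversion produces NaNs.
    def validate(res1, res2):
        return all(np.allclose(r1, r2, rtol=0.01, atol=0.001) for r1, r2 in zip(res1, res2))

    model_mixed = auto_mixed_precision.auto_convert_mixed_precision(
        onnx.load(fp32_path), test_data, validate, keep_io_types=True)
    onnx.save(model_mixed, fp16_path)
    return fp16_path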
Thanks.

@yetingqiaqia
Author

Thanks @garymm. Using float32 as the input is intentional and shouldn't cause the NaN issue. Both conversion APIs, auto_convert_mixed_precision() and convert_float_to_float16_model_path(), take a parameter called keep_io_types=True. With this parameter enabled, the original I/O types are kept. This feature was requested by our AdsBrain users.
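A quick way to confirm this is to inspect the converted model's I/O types (a small sketch, using the model path from the repro script above):

import onnxruntime as ort

sess = ort.InferenceSession('./ConvNext-XL-fp16/graph_fp16.onnx',
                            providers=['CUDAExecutionProvider'])
# With keep_io_types=True, the graph I/O stays float32 even though the
# weights and activations are float16.
print(sess.get_inputs()[0].type)   # expected: tensor(float), not tensor(float16)
print(sess.get_outputs()[0].type)  # expected: tensor(float)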

As for deleting onnxmltools.utils.float16_converter, my personal opinion is that I wouldn't want it deleted unless auto_convert_mixed_precision() can run faster. I explained my concern in the previous comment.
Thanks.

@garymm
Contributor

garymm commented Apr 29, 2022

Let's continue the discussion of what to do about the two APIs in onnx/onnxmltools#543.
Closing this since the issue is resolved.

@garymm garymm closed this as completed Apr 29, 2022
@hariharans29
Member

hariharans29 commented Apr 29, 2022

Thanks @BowenBao and @garymm !
