diff --git a/olive/hardware/constants.py b/olive/hardware/constants.py
index 9146e3891..2856fd120 100644
--- a/olive/hardware/constants.py
+++ b/olive/hardware/constants.py
@@ -28,7 +28,7 @@
         "MIGraphXExecutionProvider",
         "TensorrtExecutionProvider",
         "OpenVINOExecutionProvider",
-        "JsExecutionProvider"
+        "JsExecutionProvider",
     ],
     "npu": ["QNNExecutionProvider"],
 }
diff --git a/olive/passes/onnx/float32_conversion.py b/olive/passes/onnx/float32_conversion.py
index 6c63b691f..211a5ce12 100644
--- a/olive/passes/onnx/float32_conversion.py
+++ b/olive/passes/onnx/float32_conversion.py
@@ -2,33 +2,36 @@
 # Copyright (c) Microsoft Corporation. All rights reserved.
 # Licensed under the MIT License.
 # --------------------------------------------------------------------------
+import logging
+import re
+from collections import defaultdict
 from pathlib import Path
-from typing import Any, Dict, List
+from typing import Any, Dict
 
-from collections import defaultdict
 import onnx
+
 from olive.hardware.accelerator import AcceleratorSpec
 from olive.model import ONNXModelHandler
 from olive.model.utils import resolve_onnx_path
 from olive.passes import Pass
 from olive.passes.onnx.common import get_external_data_config, model_proto_to_olive_model
 from olive.passes.pass_config import PassConfigParam
-import re
 
-class OnnxIOFloat16ToFloat32(Pass):
-    """Converts float16 model inputs/outputs to float32.
+logger = logging.getLogger(__name__)
+
 
-    """
+class OnnxIOFloat16ToFloat32(Pass):
+    """Converts float16 model inputs/outputs to float32."""
 
     @classmethod
     def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassConfigParam]:
         config = {
             "name_pattern": PassConfigParam(
-                type_=str, default_value="logits",
+                type_=str,
+                default_value="logits",
                 description=(
-                    "Only convert inputs/outputs whose name matches this pattern. By default"
-                    "looking for logits names"
-                )
+                    "Only convert inputs/outputs whose name matches this pattern. By default looking for logits names"
+                ),
             )
         }
         config.update(get_external_data_config())
@@ -40,7 +43,7 @@ def create_io_mapping(self, graph, i_map, o_map):
                 i_map[i].append(n)
         for n in graph.node:
             for o in n.output:
-                assert o not in o_map[o]
+                assert o not in o_map
                 o_map[o] = [n]
 
     def wrap_inputs(self, graph, i_map, names):
@@ -54,7 +57,7 @@ def wrap_inputs(self, graph, i_map, names):
             match = names.search(i.name)
             if not match:
                 continue
-            print(f"input {i.name} from fp32")
+            logger.debug("input %s from fp32", i.name)
             for n in i_map[i.name]:
                 for j, o in enumerate(n.input):
                     if o == i.name:
@@ -63,12 +66,11 @@
                 "Cast",
                 inputs=[i.name],
                 outputs=[i.name + "_fp16"],
-                to=onnx.TensorProto.FLOAT16,
+                to=onnx.TensorProto.FLOAT,
             )
             graph.node.insert(0, cast)
             i.type.tensor_type.elem_type = onnx.TensorProto.FLOAT
 
-
     def wrap_outputs(self, graph, i_map, o_map, names):
        # 1. find fp16 outputs
        # 2. rewrite all providers
@@ -80,7 +82,7 @@ def wrap_outputs(self, graph, i_map, o_map, names):
             match = names.search(o.name)
             if not match:
                 continue
-            print(f"output {o.name} to fp32")
+            logger.debug("output %s to fp32", o.name)
             for n in o_map[o.name]:
                 for j, i in enumerate(n.output):
                     if i == o.name:
diff --git a/test/unit_test/passes/onnx/test_float32_conversion.py b/test/unit_test/passes/onnx/test_float32_conversion.py
new file mode 100644
index 000000000..a0385aeee
--- /dev/null
+++ b/test/unit_test/passes/onnx/test_float32_conversion.py
@@ -0,0 +1,47 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+from olive.model.handler.onnx import ONNXModelHandler
+from olive.passes.olive_pass import create_pass_from_dict
+from olive.passes.onnx.float32_conversion import OnnxIOFloat16ToFloat32
+from test.unit_test.utils import get_onnx_model
+import onnx
+from onnx import helper
+from onnx import TensorProto
+
+
+def test_onnx_io_ft16_to_ft32_conversion(tmp_path):
+    # setup
+    node1 = helper.make_node(
+        'Add',
+        ['logits_A', 'logits_B'],
+        ['logits_C'],
+        name='add_node'
+    )
+
+    input_tensor_A = helper.make_tensor_value_info('logits_A', TensorProto.FLOAT16, [None])
+    input_tensor_B = helper.make_tensor_value_info('logits_B', TensorProto.FLOAT16, [None])
+    output_tensor_C = helper.make_tensor_value_info('logits_C', TensorProto.FLOAT16, [None])
+
+    graph = helper.make_graph(
+        [node1],
+        'example_graph',
+        [input_tensor_A, input_tensor_B],
+        [output_tensor_C]
+    )
+    onnx_model = helper.make_model(graph, producer_name='example_producer')
+    tmp_model_path = str(tmp_path / "model.onnx")
+    onnx.save(onnx_model, tmp_model_path)
+    input_model = ONNXModelHandler(model_path=tmp_model_path)
+    p = create_pass_from_dict(OnnxIOFloat16ToFloat32, None, disable_search=True)
+    output_folder = str(tmp_path / "onnx")
+
+    # execute
+    output_model = p.run(input_model, None, output_folder)
+
+    # assert
+    for input in output_model.get_graph().input:
+        assert input.type.tensor_type.elem_type == onnx.TensorProto.FLOAT
+    for output in output_model.get_graph().output:
+        assert output.type.tensor_type.elem_type == onnx.TensorProto.FLOAT
\ No newline at end of file
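For reference, and not part of the patch itself: a minimal sketch of how the new pass might be invoked outside the test harness, mirroring the unit test above. The "model.onnx" input path, the "converted" output folder, and the explicit "name_pattern" value are placeholders.

import onnx

from olive.model.handler.onnx import ONNXModelHandler
from olive.passes.olive_pass import create_pass_from_dict
from olive.passes.onnx.float32_conversion import OnnxIOFloat16ToFloat32

# Load an existing fp16 model (placeholder path) and run the pass; graph inputs
# and outputs whose names match the default "logits" pattern are declared
# float32, with Cast nodes inserted at the graph boundary.
input_model = ONNXModelHandler(model_path="model.onnx")
p = create_pass_from_dict(OnnxIOFloat16ToFloat32, {"name_pattern": "logits"}, disable_search=True)
output_model = p.run(input_model, None, "converted")

# elem_type 1 corresponds to onnx.TensorProto.FLOAT
for graph_input in output_model.get_graph().input:
    print(graph_input.name, graph_input.type.tensor_type.elem_type)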