Skip to content

Commit

Permalink
Add unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
xiaoyu-work committed May 13, 2024
1 parent 601f83b commit d7991ac
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 16 deletions.
2 changes: 1 addition & 1 deletion olive/hardware/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
"MIGraphXExecutionProvider",
"TensorrtExecutionProvider",
"OpenVINOExecutionProvider",
"JsExecutionProvider"
"JsExecutionProvider",
],
"npu": ["QNNExecutionProvider"],
}
32 changes: 17 additions & 15 deletions olive/passes/onnx/float32_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,33 +2,36 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import logging
import re
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, List
from typing import Any, Dict

from collections import defaultdict
import onnx

from olive.hardware.accelerator import AcceleratorSpec
from olive.model import ONNXModelHandler
from olive.model.utils import resolve_onnx_path
from olive.passes import Pass
from olive.passes.onnx.common import get_external_data_config, model_proto_to_olive_model
from olive.passes.pass_config import PassConfigParam
import re

class OnnxIOFloat16ToFloat32(Pass):
"""Converts float16 model inputs/outputs to float32.
logger = logging.getLogger(__name__)


"""
class OnnxIOFloat16ToFloat32(Pass):
"""Converts float16 model inputs/outputs to float32."""

@classmethod
def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassConfigParam]:
config = {
"name_pattern": PassConfigParam(
type_=str, default_value="logits",
type_=str,
default_value="logits",
description=(
"Only convert inputs/outputs whose name matches this pattern. By default"
"looking for logits names"
)
"Only convert inputs/outputs whose name matches this pattern. By default looking for logits names"
),
)
}
config.update(get_external_data_config())
Expand All @@ -40,7 +43,7 @@ def create_io_mapping(self, graph, i_map, o_map):
i_map[i].append(n)
for n in graph.node:
for o in n.output:
assert o not in o_map[o]
assert o not in o_map
o_map[o] = [n]

def wrap_inputs(self, graph, i_map, names):
Expand All @@ -54,7 +57,7 @@ def wrap_inputs(self, graph, i_map, names):
match = names.search(i.name)
if not match:
continue
print(f"input {i.name} from fp32")
logger.debug("input %s from fp32", i.name)
for n in i_map[i.name]:
for j, o in enumerate(n.input):
if o == i.name:
Expand All @@ -63,12 +66,11 @@ def wrap_inputs(self, graph, i_map, names):
"Cast",
inputs=[i.name],
outputs=[i.name + "_fp16"],
to=onnx.TensorProto.FLOAT16,
to=onnx.TensorProto.FLOAT,
)
graph.node.insert(0, cast)
i.type.tensor_type.elem_type = onnx.TensorProto.FLOAT


def wrap_outputs(self, graph, i_map, o_map, names):
# 1. find fp16 outputs
# 2. rewrite all providers
Expand All @@ -80,7 +82,7 @@ def wrap_outputs(self, graph, i_map, o_map, names):
match = names.search(o.name)
if not match:
continue
print(f"output {o.name} to fp32")
logger.debug("output %s to fp32", o.name)
for n in o_map[o.name]:
for j, i in enumerate(n.output):
if i == o.name:
Expand Down
47 changes: 47 additions & 0 deletions test/unit_test/passes/onnx/test_float32_conversion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# -------------------------------------------------------------------------

Check warning

Code scanning / lintrunner

RUFF/format Warning test

Run lintrunner -a to apply this patch.

Check warning

Code scanning / lintrunner

BLACK-ISORT/format Warning test

Run lintrunner -a to apply this patch.

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning test

Final newline expected
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from olive.model.handler.onnx import ONNXModelHandler
from olive.passes.olive_pass import create_pass_from_dict
from olive.passes.onnx.float32_conversion import OnnxIOFloat16ToFloat32
from test.unit_test.utils import get_onnx_model

Check notice

Code scanning / CodeQL

Unused import Note test

Import of 'get_onnx_model' is not used.

Check warning

Code scanning / lintrunner

RUFF/F401 Warning test

test.unit\_test.utils.get\_onnx\_model imported but unused.
See https://docs.astral.sh/ruff/rules/unused-import

Check notice

Code scanning / lintrunner

PYLINT/C0411 Note test

standard import "test.unit_test.utils.get_onnx_model" should be placed before first party imports "olive.model.handler.onnx.ONNXModelHandler", "olive.passes.olive_pass.create_pass_from_dict", "olive.passes.onnx.float32_conversion.OnnxIOFloat16ToFloat32" (wrong-import-order)
See wrong-import-order.

Check warning

Code scanning / lintrunner

PYLINT/W0611 Warning test

Unused get_onnx_model imported from test.unit_test.utils (unused-import)
See unused-import.
import onnx

Check notice

Code scanning / lintrunner

PYLINT/C0411 Note test

third party import "onnx" should be placed before first party imports "olive.model.handler.onnx.ONNXModelHandler", "olive.passes.olive_pass.create_pass_from_dict", "olive.passes.onnx.float32_conversion.OnnxIOFloat16ToFloat32" (wrong-import-order)
See wrong-import-order.
from onnx import helper

Check notice

Code scanning / lintrunner

PYLINT/C0411 Note test

third party import "onnx.helper" should be placed before first party imports "olive.model.handler.onnx.ONNXModelHandler", "olive.passes.olive_pass.create_pass_from_dict", "olive.passes.onnx.float32_conversion.OnnxIOFloat16ToFloat32" (wrong-import-order)
See wrong-import-order.
from onnx import TensorProto

Check notice

Code scanning / lintrunner

PYLINT/C0411 Note test

third party import "onnx.TensorProto" should be placed before first party imports "olive.model.handler.onnx.ONNXModelHandler", "olive.passes.olive_pass.create_pass_from_dict", "olive.passes.onnx.float32_conversion.OnnxIOFloat16ToFloat32" (wrong-import-order)
See wrong-import-order.


def test_onnx_io_ft16_to_ft32_conversion(tmp_path):
    """OnnxIOFloat16ToFloat32 should rewrite fp16 graph inputs/outputs to fp32.

    Builds a minimal one-node model whose input/output names contain "logits"
    so they match the pass's default ``name_pattern``, runs the pass, and
    checks every graph input/output now has elem_type FLOAT.
    """
    # setup: Add(logits_A, logits_B) -> logits_C, all declared as FLOAT16
    add_node = helper.make_node(
        "Add",
        ["logits_A", "logits_B"],
        ["logits_C"],
        name="add_node",
    )
    input_tensor_a = helper.make_tensor_value_info("logits_A", TensorProto.FLOAT16, [None])
    input_tensor_b = helper.make_tensor_value_info("logits_B", TensorProto.FLOAT16, [None])
    output_tensor_c = helper.make_tensor_value_info("logits_C", TensorProto.FLOAT16, [None])

    graph = helper.make_graph(
        [add_node],
        "example_graph",
        [input_tensor_a, input_tensor_b],
        [output_tensor_c],
    )
    onnx_model = helper.make_model(graph, producer_name="example_producer")
    tmp_model_path = str(tmp_path / "model.onnx")
    onnx.save(onnx_model, tmp_model_path)
    input_model = ONNXModelHandler(model_path=tmp_model_path)
    p = create_pass_from_dict(OnnxIOFloat16ToFloat32, None, disable_search=True)
    output_folder = str(tmp_path / "onnx")

    # execute
    output_model = p.run(input_model, None, output_folder)

    # assert: every matching input/output was converted to fp32
    # (avoid shadowing the builtin `input` in the loop variables)
    for graph_input in output_model.get_graph().input:
        assert graph_input.type.tensor_type.elem_type == onnx.TensorProto.FLOAT
    for graph_output in output_model.get_graph().output:
        assert graph_output.type.tensor_type.elem_type == onnx.TensorProto.FLOAT

Check warning

Code scanning / lintrunner

RUFF/W292 Warning test

0 comments on commit d7991ac

Please sign in to comment.