Skip to content

Commit

Permalink
Update RunONNXModel.py (llvm#1154)
Browse files Browse the repository at this point in the history
Added an option to save the onnx model (--save_onnx) and an option to verify every operation output when using onnxruntime (--verify_all_ops).

Signed-off-by: Tung D. Le <tungld@gmail.com>
Co-authored-by: Alexandre Eichenberger <alexe@us.ibm.com>
  • Loading branch information
tungld and AlexandreEichenberger authored Feb 3, 2022
1 parent 15a5c0b commit fb59d2c
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 9 deletions.
14 changes: 11 additions & 3 deletions docs/DebuggingNumericalError.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,18 @@ reference inputs and outputs in protobuf.

```bash
$ python ../utils/RunONNXModel.py --help
usage: RunONNXModel.py [-h] [--print_input] [--print_output]
[--save_so PATH | --load_so PATH] [--save_data PATH]
usage: RunONNXModel.py [-h]
[--print_input]
[--print_output]
[--save_onnx PATH]
[--save_so PATH | --load_so PATH]
[--save_data PATH]
[--data_folder DATA_FOLDER | --shape_info SHAPE_INFO]
[--compile_args COMPILE_ARGS]
[--verify {onnxruntime,ref}]
[--compile_using_input_shape] [--rtol RTOL]
[--verify_all_ops]
[--compile_using_input_shape]
[--rtol RTOL]
[--atol ATOL]
model_path

Expand All @@ -48,6 +54,7 @@ optional arguments:
-h, --help show this help message and exit
--print_input Print out inputs
--print_output Print out inference outputs produced by onnx-mlir
--save_onnx PATH File path to save the onnx model
--save_so PATH File path to save the generated shared library of the
model
--load_so PATH File path to load a generated shared library for
Expand All @@ -68,6 +75,7 @@ optional arguments:
--verify {onnxruntime,ref}
Verify the output by using onnxruntime or reference
inputs/outputs. By default, no verification
--verify_all_ops Verify all operation outputs when using onnxruntime.
--compile_using_input_shape
Compile the model by using the shape info obtained from
the inputs in the data folder. Must set --data_folder
Expand Down
21 changes: 15 additions & 6 deletions utils/RunONNXModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@
parser.add_argument('--print_output',
action='store_true',
help="Print out inference outputs produced by onnx-mlir")
parser.add_argument('--save_onnx',
metavar='PATH',
type=str,
help="File path to save the onnx model")
lib_group.add_argument('--save_so',
metavar='PATH',
type=str,
Expand Down Expand Up @@ -55,6 +59,9 @@
choices=['onnxruntime', 'ref'],
help="Verify the output by using onnxruntime or reference"
" inputs/outputs. By default, no verification")
parser.add_argument('--verify_all_ops',
action='store_true',
help="Verify all operation outputs when using onnxruntime.")
parser.add_argument(
'--compile_using_input_shape',
action='store_true',
Expand Down Expand Up @@ -113,15 +120,11 @@ def execute_commands(cmds):


def extend_model_output(model, intermediate_outputs):
# onnx-mlir doesn't care about manually specified output types & shapes.
DUMMY_TENSOR_TYPE = onnx.TensorProto.FLOAT

while (len(model.graph.output)):
model.graph.output.pop()

for output_name in intermediate_outputs:
output_value_info = onnx.helper.make_tensor_value_info(
output_name, DUMMY_TENSOR_TYPE, None)
output_value_info = onnx.helper.make_empty_tensor_value_info(output_name)
model.graph.output.extend([output_value_info])
return model

Expand Down Expand Up @@ -228,12 +231,18 @@ def main():
# If using onnxruntime for verification, we can verify every operation output.
output_names = [o.name for o in model.graph.output]
output_names = list(OrderedDict.fromkeys(output_names))
if (args.verify and args.verify == "onnxruntime"):
if (args.verify and args.verify == "onnxruntime" and args.verify_all_ops):
print("Extending the onnx model to check every node output ...\n")
output_names = sum([[n for n in node.output if n != '']
for node in model.graph.node], [])
output_names = list(OrderedDict.fromkeys(output_names))
model = extend_model_output(model, output_names)

    # Save the possibly-extended onnx model if required.
if (args.save_onnx):
print("Saving the onnx model to ", args.save_onnx, "\n")
onnx.save(model, args.save_onnx)

# Compile, run, and verify.
with tempfile.TemporaryDirectory() as temp_dir:
print("Temporary directory has been created at {}".format(temp_dir))
Expand Down

0 comments on commit fb59d2c

Please sign in to comment.