Skip to content

Commit

Permalink
Merge branch 'support_int32_output_data_type' into 'master'
Browse files Browse the repository at this point in the history
feat: Support the case of output data type is int32

See merge request applied-machine-learning/sysml/mace!1474
  • Loading branch information
lu229 committed Jan 13, 2022
2 parents 684ded8 + 09f352c commit 8c75d39
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 21 deletions.
7 changes: 7 additions & 0 deletions tools/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -899,6 +899,12 @@ def run_specify_abi(self, flags, configs, target_abi):
output_config, runtime, tuning)
if flags.validate:
log_file = ""
if YAMLKeyword.output_data_types in output_config:
output_data_types = output_config[
YAMLKeyword.output_data_types]
else:
output_data_types = output_infos[
YAMLKeyword.output_data_types]
if flags.layers != "-1":
log_file = log_dir + "/log.csv"
model_file_path, weight_file_path = \
Expand Down Expand Up @@ -954,6 +960,7 @@ def run_specify_abi(self, flags, configs, target_abi):
model_output_dir=model_output_dir,
input_data_types=input_infos[
YAMLKeyword.input_data_types],
output_data_types=output_data_types,
caffe_env=flags.caffe_env,
validation_threshold=model_config[
YAMLKeyword.validation_threshold][
Expand Down
20 changes: 17 additions & 3 deletions tools/python/layers_validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,8 @@ def convert(model_file, output_dir, layers):

output_tensors = []
output_shapes = []
output_data_types = []
output_data_formats = []
op_name = op.name
if str(op.name).startswith(MaceKeyword.mace_output_node_name):
continue
Expand All @@ -232,13 +234,25 @@ def convert(model_file, output_dir, layers):
output_tensors.append(str(op.output[i]))
output_shapes.append(
",".join([str(dim) for dim in op.output_shape[i].dims]))
if data_format == DataFormat.NONE.value:
output_data_formats.append(DataFormat.NONE.name)
elif data_format == DataFormat.NCHW.value:
output_data_formats.append(DataFormat.NCHW.name)
else:
output_data_formats.append(DataFormat.NHWC.name)
# modify output info
multi_net.output_tensor.append(op.output[i])
output_info = net.output_info.add()
output_info.name = op.output[i]
output_info.data_format = data_format
output_info.dims.extend(op.output_shape[i].dims)
output_info.data_type = mace_pb2.DT_FLOAT
data_type = ConverterUtil.get_arg(
op, MaceKeyword.mace_op_data_type_str)
output_info.data_type = data_type.i
if mace_pb2.DT_INT32 == data_type.i:
output_data_types.append('int32')
else:
output_data_types.append('float32')
if is_quantize:
output_info.scale = op.quantize_info[0].scale
output_info.zero_point = op.quantize_info[0].zero_point
Expand Down Expand Up @@ -283,8 +297,8 @@ def convert(model_file, output_dir, layers):
output_config = {ModelKeys.model_file_path: str(model_path),
ModelKeys.output_tensors: output_tensors,
ModelKeys.output_shapes: output_shapes,
"output_data_formats":
[DataFormat.NHWC.name] * len(output_shapes)}
"output_data_formats": output_data_formats,
ModelKeys.output_data_types: output_data_types}
output_configs[ModelKeys.subgraphs].append(output_config)

output_configs_path = output_dir + "outputs.yml"
Expand Down
4 changes: 3 additions & 1 deletion tools/sh_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,6 +641,7 @@ def validate_model(abi,
output_data_formats,
model_output_dir,
input_data_types,
output_data_types,
caffe_env,
input_file_name="model_input",
output_file_name="model_out",
Expand Down Expand Up @@ -668,7 +669,8 @@ def validate_model(abi,
":".join(input_shapes), ":".join(output_shapes),
",".join(input_data_formats), ",".join(output_data_formats),
",".join(input_nodes), ",".join(output_nodes),
validation_threshold, ",".join(input_data_types), backend,
validation_threshold, ",".join(input_data_types),
",".join(output_data_types), backend,
validation_outputs_data,
log_file)
elif platform == "caffe":
Expand Down
45 changes: 28 additions & 17 deletions tools/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,8 @@ def validate_tf_model(platform, device_type, model_file,
input_file, mace_out_file,
input_names, input_shapes, input_data_formats,
output_names, output_shapes, output_data_formats,
validation_threshold, input_data_types, log_file):
validation_threshold, input_data_types,
output_data_types, log_file):
import tensorflow as tf
if not os.path.isfile(model_file):
common.MaceLogger.error(
Expand Down Expand Up @@ -249,7 +250,8 @@ def validate_tf_model(platform, device_type, model_file,
for i in range(len(output_names)):
output_file_name = common.formatted_file_name(
mace_out_file, output_names[i])
mace_out_value = load_data(output_file_name)
mace_out_value = load_data(output_file_name,
output_data_types[i])
mace_out_value, real_out_shape, real_out_data_format = \
get_real_out_value_shape_df(platform,
mace_out_value,
Expand All @@ -265,7 +267,8 @@ def validate_pytorch_model(platform, device_type, model_file,
input_file, mace_out_file,
input_names, input_shapes, input_data_formats,
output_names, output_shapes, output_data_formats,
validation_threshold, input_data_types, log_file):
validation_threshold, input_data_types,
output_data_types, log_file):
import torch
loaded_model = torch.jit.load(model_file)
pytorch_inputs = []
Expand Down Expand Up @@ -293,7 +296,7 @@ def validate_pytorch_model(platform, device_type, model_file,
value = pytorch_outputs[i].numpy()
output_file_name = common.formatted_file_name(
mace_out_file, output_names[i])
mace_out_value = load_data(output_file_name)
mace_out_value = load_data(output_file_name, output_data_types[i])
mace_out_value, real_output_shape, real_output_data_format = \
get_real_out_value_shape_df(platform,
mace_out_value,
Expand Down Expand Up @@ -381,7 +384,7 @@ def validate_onnx_model(platform, device_type, model_file,
input_names, input_shapes, input_data_formats,
output_names, output_shapes, output_data_formats,
validation_threshold, input_data_types,
backend, log_file):
output_data_types, backend, log_file):
print("validate on onnxruntime.")
import onnx
import onnxruntime as onnxrt
Expand Down Expand Up @@ -420,7 +423,7 @@ def validate_onnx_model(platform, device_type, model_file,
value = output_values[i].flatten()
output_file_name = common.formatted_file_name(mace_out_file,
output_names[i])
mace_out_value = load_data(output_file_name)
mace_out_value = load_data(output_file_name, output_data_types[i])
mace_out_value, real_output_shape, real_output_data_format = \
get_real_out_value_shape_df(platform,
mace_out_value,
Expand All @@ -436,7 +439,7 @@ def validate_megengine_model(platform, device_type, model_file, input_file,
mace_out_file, input_names, input_shapes,
input_data_formats, output_names, output_shapes,
output_data_formats, validation_threshold,
input_data_types, log_file):
input_data_types, output_data_types, log_file):
import megengine._internal as mgb

if not os.path.isfile(model_file):
Expand Down Expand Up @@ -469,7 +472,7 @@ def validate_megengine_model(platform, device_type, model_file, input_file,
for i in range(len(output_names)):
output_file_name = \
common.formatted_file_name(mace_out_file, output_names[i])
mace_out_value = load_data(output_file_name)
mace_out_value = load_data(output_file_name, output_data_types[i])
mace_out_value, real_output_shape, real_output_data_format = \
get_real_out_value_shape_df(platform,
mace_out_value,
Expand All @@ -484,7 +487,8 @@ def validate_keras_model(platform, device_type, model_file,
input_file, mace_out_file,
input_names, input_shapes, input_data_formats,
output_names, output_shapes, output_data_formats,
validation_threshold, input_data_types, log_file):
validation_threshold, input_data_types,
output_data_types, log_file):
from tensorflow import keras
import tensorflow_model_optimization as tfmot

Expand Down Expand Up @@ -516,7 +520,7 @@ def validate_keras_model(platform, device_type, model_file,
for i in range(len(output_names)):
output_file_name = common.formatted_file_name(
mace_out_file, output_names[i])
mace_out_value = load_data(output_file_name)
mace_out_value = load_data(output_file_name, output_data_types[i])
mace_out_value, real_output_shape, real_output_data_format = \
get_real_out_value_shape_df(platform,
mace_out_value,
Expand All @@ -531,7 +535,7 @@ def validate_keras_model(platform, device_type, model_file,
def validate(platform, model_file, weight_file, input_file, mace_out_file,
device_type, input_shape, output_shape, input_data_format_str,
output_data_format_str, input_node, output_node,
validation_threshold, input_data_type, backend,
validation_threshold, input_data_type, output_data_type, backend,
validation_outputs_data, log_file):
input_names = [name for name in input_node.split(',')]
input_shape_strs = [shape for shape in input_shape.split(':')]
Expand All @@ -547,6 +551,11 @@ def validate(platform, model_file, weight_file, input_file, mace_out_file,
for data_type in input_data_type.split(',')]
else:
input_data_types = ['float32'] * len(input_names)
if output_data_type:
output_data_types = [data_type
for data_type in output_data_type.split(',')]
else:
output_data_types = ['float32'] * len(output_shapes)
output_names = [name for name in output_node.split(',')]
assert len(input_names) == len(input_shapes)
if not isinstance(validation_outputs_data, list):
Expand All @@ -569,14 +578,14 @@ def validate(platform, model_file, weight_file, input_file, mace_out_file,
input_names, input_shapes, input_data_formats,
output_names, output_shapes, output_data_formats,
validation_threshold, input_data_types,
log_file)
output_data_types, log_file)
elif platform == 'pytorch':
validate_pytorch_model(platform, device_type,
model_file, input_file, mace_out_file,
input_names, input_shapes, input_data_formats,
output_names, output_shapes,
output_data_formats, validation_threshold,
input_data_types, log_file)
input_data_types, output_data_types, log_file)
elif platform == 'caffe':
validate_caffe_model(platform, device_type, model_file,
input_file, mace_out_file, weight_file,
Expand All @@ -588,8 +597,8 @@ def validate(platform, model_file, weight_file, input_file, mace_out_file,
input_file, mace_out_file,
input_names, input_shapes, input_data_formats,
output_names, output_shapes, output_data_formats,
validation_threshold,
input_data_types, backend, log_file)
validation_threshold, input_data_types,
output_data_types, backend, log_file)
elif platform == 'megengine':
validate_megengine_model(platform, device_type, model_file,
input_file, mace_out_file,
Expand All @@ -598,7 +607,8 @@ def validate(platform, model_file, weight_file, input_file, mace_out_file,
output_names, output_shapes,
output_data_formats,
validation_threshold,
input_data_types, log_file)
input_data_types,
output_data_types, log_file)
elif platform == 'keras':
validate_keras_model(platform, device_type, model_file,
input_file, mace_out_file,
Expand All @@ -607,7 +617,8 @@ def validate(platform, model_file, weight_file, input_file, mace_out_file,
output_names, output_shapes,
output_data_formats,
validation_threshold,
input_data_types, log_file)
input_data_types,
output_data_types, log_file)


def parse_args():
Expand Down

0 comments on commit 8c75d39

Please sign in to comment.