-
Notifications
You must be signed in to change notification settings - Fork 184
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Extend CoreML: ReshapeStatic/LoadConstantND (#430)
Signed-off-by: Islam <ibadreldin@icloud.com> Co-authored-by: Wenbing Li <10278425+wenbingl@users.noreply.github.com>
- Loading branch information
Showing
6 changed files
with
123 additions
and
0 deletions.
There are no files selected for viewing
21 changes: 21 additions & 0 deletions
21
onnxmltools/convert/coreml/operator_converters/neural_network/LoadConstantND.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
# ------------------------------------------------------------------------- | ||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# Licensed under the MIT License. See License.txt in the project root for | ||
# license information. | ||
# -------------------------------------------------------------------------- | ||
|
||
from .....proto import helper | ||
from .....proto import onnx_proto | ||
from ....common._registration import register_converter | ||
from ....common._apply_operation import apply_constant | ||
|
||
def convert_load_constant_nd(scope, operator, container): | ||
params = operator.raw_operator.loadConstantND | ||
constant_name = scope.get_unique_variable_name('constant') | ||
constant = helper.make_tensor(constant_name, onnx_proto.TensorProto.FLOAT, | ||
params.shape, params.data.floatValue) | ||
|
||
apply_constant(scope, operator.output_full_names, container, | ||
operator_name=operator.full_name, value=constant) | ||
|
||
register_converter('loadConstantND', convert_load_constant_nd) |
37 changes: 37 additions & 0 deletions
37
onnxmltools/convert/coreml/operator_converters/neural_network/ReshapeStatic.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
# ------------------------------------------------------------------------- | ||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# Licensed under the MIT License. See License.txt in the project root for | ||
# license information. | ||
# -------------------------------------------------------------------------- | ||
|
||
from ....common._apply_operation import apply_reshape | ||
from ....common._registration import register_converter | ||
|
||
|
||
def convert_reshape_static(scope, operator, container): | ||
from coremltools.proto.NeuralNetwork_pb2 import ReshapeLayerParams as Params | ||
|
||
params = operator.raw_operator.reshapeStatic | ||
|
||
# print(params) | ||
intra_variable_name = operator.inputs[0].full_name | ||
|
||
N = operator.inputs[0].type.shape[0] | ||
if N == 'None': | ||
N = -1 | ||
if len(params.targetShape) == 4: | ||
output_shape = [int(d) for d in params.targetShape] | ||
output_shape[0] = N # Overwrite bad default CoreML setting | ||
elif len(params.targetShape) == 3: | ||
output_shape = [N] + [int(d) for d in params.targetShape] | ||
elif len(params.targetShape) == 2: | ||
output_shape = [N] + [int(d) for d in params.targetShape] | ||
else: | ||
raise ValueError('The targeted shape of Reshape (name: %s) must be 3-element or 4-element array but got %s'\ | ||
% (operator.full_name, params.targetShape)) | ||
|
||
apply_reshape(scope=scope, input_name=intra_variable_name, output_name=operator.outputs[0].full_name, | ||
container=container, operator_name=operator.full_name, desired_shape=output_shape) | ||
|
||
|
||
register_converter('reshapeStatic', convert_reshape_static) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
30 changes: 30 additions & 0 deletions
30
onnxmltools/convert/coreml/shape_calculators/neural_network/LoadConstantND.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
#------------------------------------------------------------------------- | ||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# Licensed under the MIT License. See License.txt in the project root for | ||
# license information. | ||
#-------------------------------------------------------------------------- | ||
|
||
from ....common._registration import register_shape_calculator | ||
from ....common.data_types import TensorType, FloatTensorType | ||
from ....common.utils import check_input_and_output_numbers | ||
|
||
def calculate_load_constant_nd_output_shapes(operator): | ||
check_input_and_output_numbers(operator, input_count_range=None, output_count_range=1) | ||
|
||
output = operator.outputs[0] | ||
|
||
# CoreML's constant is always 3-D tensor, so we assume its shape is [C, H, W]. | ||
const_shape = operator.raw_operator.loadConstantND.shape | ||
# We convert [C, H, W] to [1, C, H, W] because our parsing code use [N, C, H, W] | ||
const_shape = [1] + [int(d) for d in const_shape] | ||
if output.type is None: | ||
# Use default type | ||
output.type = FloatTensorType(const_shape, doc_string=output.type.doc_string) | ||
else: | ||
if not isinstance(output.type, TensorType): | ||
raise RuntimeError('Type conflict detected. Output must be a tensor.') | ||
# If output type exists, we just modify its shape. | ||
output.type.shape = const_shape | ||
|
||
|
||
register_shape_calculator('loadConstantND', calculate_load_constant_nd_output_shapes) |
31 changes: 31 additions & 0 deletions
31
onnxmltools/convert/coreml/shape_calculators/neural_network/ReshapeStatic.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
# ------------------------------------------------------------------------- | ||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# Licensed under the MIT License. See License.txt in the project root for | ||
# license information. | ||
# -------------------------------------------------------------------------- | ||
|
||
from ....common._registration import register_shape_calculator | ||
from ....common.data_types import FloatTensorType | ||
from ....common.utils import check_input_and_output_numbers, check_input_and_output_types | ||
|
||
def calculate_reshape_static_output_shapes(operator): | ||
''' | ||
Allowed input/output patterns are | ||
1. [N, C, H, W] ---> [N, C', H', W'] | ||
Note that C*H*W should equal to C'*H'*W'. | ||
''' | ||
check_input_and_output_numbers(operator, input_count_range=1, output_count_range=1) | ||
check_input_and_output_types(operator, good_input_types=[FloatTensorType]) | ||
|
||
params = operator.raw_operator.reshapeStatic | ||
|
||
output_shape = list(int(i) for i in params.targetShape) | ||
|
||
if len(output_shape) == 3: | ||
output_shape = [operator.inputs[0].type.shape[0]] + output_shape | ||
|
||
operator.outputs[0].type.shape = output_shape | ||
|
||
|
||
register_shape_calculator('reshapeStatic', calculate_reshape_static_output_shapes) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters