Skip to content

Commit

Permalink
Extend CoreML: ReshapeStatic/LoadConstantND (#430)
Browse files Browse the repository at this point in the history
Signed-off-by: Islam <ibadreldin@icloud.com>

Co-authored-by: Wenbing Li <10278425+wenbingl@users.noreply.github.com>
  • Loading branch information
ibadr and wenbingl authored Jan 7, 2021
1 parent a849358 commit 7631c1a
Show file tree
Hide file tree
Showing 6 changed files with 123 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------

from .....proto import helper
from .....proto import onnx_proto
from ....common._registration import register_converter
from ....common._apply_operation import apply_constant

def convert_load_constant_nd(scope, operator, container):
params = operator.raw_operator.loadConstantND
constant_name = scope.get_unique_variable_name('constant')
constant = helper.make_tensor(constant_name, onnx_proto.TensorProto.FLOAT,
params.shape, params.data.floatValue)

apply_constant(scope, operator.output_full_names, container,
operator_name=operator.full_name, value=constant)

register_converter('loadConstantND', convert_load_constant_nd)
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------

from ....common._apply_operation import apply_reshape
from ....common._registration import register_converter


def convert_reshape_static(scope, operator, container):
from coremltools.proto.NeuralNetwork_pb2 import ReshapeLayerParams as Params

params = operator.raw_operator.reshapeStatic

# print(params)
intra_variable_name = operator.inputs[0].full_name

N = operator.inputs[0].type.shape[0]
if N == 'None':
N = -1
if len(params.targetShape) == 4:
output_shape = [int(d) for d in params.targetShape]
output_shape[0] = N # Overwrite bad default CoreML setting
elif len(params.targetShape) == 3:
output_shape = [N] + [int(d) for d in params.targetShape]
elif len(params.targetShape) == 2:
output_shape = [N] + [int(d) for d in params.targetShape]
else:
raise ValueError('The targeted shape of Reshape (name: %s) must be 3-element or 4-element array but got %s'\
% (operator.full_name, params.targetShape))

apply_reshape(scope=scope, input_name=intra_variable_name, output_name=operator.outputs[0].full_name,
container=container, operator_name=operator.full_name, desired_shape=output_shape)


register_converter('reshapeStatic', convert_reshape_static)
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from . import InnerProduct
from . import L2Normalize
from . import LoadConstant
from . import LoadConstantND
from . import LRN
from . import LSTM
from . import Max
Expand All @@ -35,6 +36,7 @@
from . import Reduce
from . import ReorganizeData
from . import Reshape
from . import ReshapeStatic
from . import Scale
from . import SequenceRepeat
from . import SimpleRNN
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#-------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
#--------------------------------------------------------------------------

from ....common._registration import register_shape_calculator
from ....common.data_types import TensorType, FloatTensorType
from ....common.utils import check_input_and_output_numbers

def calculate_load_constant_nd_output_shapes(operator):
check_input_and_output_numbers(operator, input_count_range=None, output_count_range=1)

output = operator.outputs[0]

# CoreML's constant is always 3-D tensor, so we assume its shape is [C, H, W].
const_shape = operator.raw_operator.loadConstantND.shape
# We convert [C, H, W] to [1, C, H, W] because our parsing code use [N, C, H, W]
const_shape = [1] + [int(d) for d in const_shape]
if output.type is None:
# Use default type
output.type = FloatTensorType(const_shape, doc_string=output.type.doc_string)
else:
if not isinstance(output.type, TensorType):
raise RuntimeError('Type conflict detected. Output must be a tensor.')
# If output type exists, we just modify its shape.
output.type.shape = const_shape


register_shape_calculator('loadConstantND', calculate_load_constant_nd_output_shapes)
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------

from ....common._registration import register_shape_calculator
from ....common.data_types import FloatTensorType
from ....common.utils import check_input_and_output_numbers, check_input_and_output_types

def calculate_reshape_static_output_shapes(operator):
'''
Allowed input/output patterns are
1. [N, C, H, W] ---> [N, C', H', W']
Note that C*H*W should equal to C'*H'*W'.
'''
check_input_and_output_numbers(operator, input_count_range=1, output_count_range=1)
check_input_and_output_types(operator, good_input_types=[FloatTensorType])

params = operator.raw_operator.reshapeStatic

output_shape = list(int(i) for i in params.targetShape)

if len(output_shape) == 3:
output_shape = [operator.inputs[0].type.shape[0]] + output_shape

operator.outputs[0].type.shape = output_shape


register_shape_calculator('reshapeStatic', calculate_reshape_static_output_shapes)
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from . import IdentityFloat
from . import InnerProduct
from . import LoadConstant
from . import LoadConstantND
from . import LSTM
from . import Merge
from . import Pad
Expand All @@ -25,6 +26,7 @@
from . import Reduce
from . import ReorganizeData
from . import Reshape
from . import ReshapeStatic
from . import SequenceRepeat
from . import Slice
from . import Split
Expand Down

0 comments on commit 7631c1a

Please sign in to comment.