From 0d2c11e98eaabe1878a56ca687736dcdce2e115c Mon Sep 17 00:00:00 2001
From: Wenbing Li
Date: Fri, 15 May 2020 12:15:35 -0700
Subject: [PATCH 1/6] add tf.where and logical ops support.

---
 keras2onnx/_builtin.py  | 29 ++++++++++++-------------
 keras2onnx/_tf_ops.py   | 47 +++++++++++++++++++++++++++++++++--------
 keras2onnx/main.py      |  9 ++++++--
 keras2onnx/topology.py  |  9 +++++---
 tests/test_tf2_keras.py | 21 +++++++++++++++---
 5 files changed, 82 insertions(+), 33 deletions(-)

diff --git a/keras2onnx/_builtin.py b/keras2onnx/_builtin.py
index c6a8f76b..e0e57ac3 100644
--- a/keras2onnx/_builtin.py
+++ b/keras2onnx/_builtin.py
@@ -639,8 +639,6 @@ def convert_tf_conv2d(scope, operator, container):
 
 @converter_func(TYPES.Einsum)
 def convert_tf_einsum(scope, operator, container):
-    if operator.target_opset < 12:
-        raise ValueError("Einsum op is not supported until opset 12")
     oopb = OnnxOperatorBuilder(container, scope)
     node = operator.raw_operator
     equation_str = node.get_attr('equation').decode("utf-8")
@@ -648,7 +646,8 @@ def convert_tf_einsum(scope, operator, container):
                               operator.input_full_names,
                               operator.output_full_names,
                               name=operator.full_name,
-                              equation=equation_str)
+                              equation=equation_str,
+                              op_version=12)
 
 
 @converter_func(TYPES.ExpandDims)
@@ -2001,19 +2000,17 @@ def convert_tf_variable_v2(scope, operator, container):
 
 @converter_func(TYPES.Where)
 def convert_tf_where(scope, operator, container):
-    if operator.target_opset < 9:
-        raise ValueError("Where op is not supported for opset < 9")
-    else:
-        oopb = OnnxOperatorBuilder(container, scope)
-        node = operator.raw_operator
-        where_node = oopb.add_node('NonZero',
-                                   operator.inputs[0].full_name,
-                                   operator.inputs[0].full_name + '_non_zero')
-        oopb.apply_op_with_output("apply_transpose",
-                                  where_node,
-                                  operator.output_full_names,
-                                  name=operator.full_name + '_transpose',
-                                  perm=list(reversed(range(len(node.outputs[0].shape)))))
+    oopb = OnnxOperatorBuilder(container, scope)
+    node = operator.raw_operator
+    where_node = oopb.add_node('NonZero',
+                               operator.inputs[0].full_name,
+                               operator.inputs[0].full_name + '_non_zero',
+                               op_version=9)
+    oopb.apply_op_with_output("apply_transpose",
+                              where_node,
+                              operator.output_full_names,
+                              name=operator.full_name + '_transpose',
+                              perm=list(reversed(range(len(node.outputs[0].shape)))))
 
 
 @converter_func(TYPES.ZerosLike)
diff --git a/keras2onnx/_tf_ops.py b/keras2onnx/_tf_ops.py
index ee42b6e6..fd49f661 100644
--- a/keras2onnx/_tf_ops.py
+++ b/keras2onnx/_tf_ops.py
@@ -5,6 +5,7 @@
 from onnxconverter_common.oopb import OnnxOperatorBuilder
 from .funcbook import converter_func
 from ._tf_utils import tf_attrs_to_onnx as _to_onnx_attrs
+from ._tf_utils import cal_tensor_shape as _cal_tensor_shape
 
 
 def _random_converter(scope, operator, container):
@@ -28,13 +29,40 @@ def _random_converter(scope, operator, container):
 
 
 @converter_func(
-    "RandomNormal",
+    'RandomNormal',
     'RandomStandardNormal',
-    "RandomUniform")
+    'RandomUniform')
 def convert_tf_random_standard_normal(scope, operator, container):
     _random_converter(scope, operator, container)
 
 
+@converter_func('Select', 'SelectV2')
+def convert_tf_select(scope, operator, container):
+    tf_op = operator.raw_operator
+    shape_i0 = _cal_tensor_shape(tf_op.inputs[0])
+    target_shape = _cal_tensor_shape(tf_op.inputs[1])
+    if len(target_shape) == 0:
+        target_shape = _cal_tensor_shape(tf_op.inputs[2])
+    input0 = operator.input_full_names[0]
+    with OnnxOperatorBuilder(container, scope).as_default(operator.full_name) as oopb:  # type: OnnxOperatorBuilder
+        if len(shape_i0) == 1 and len(target_shape) > 1:
+            input0 = oopb.unsqueeze(input0, axes=list(range(len(target_shape)))[1:])
+        oopb.add_node("Where", [input0] + operator.input_full_names[1:],
+                      outputs=operator.output_full_names,
+                      op_version=9)
+
+
+@converter_func('LogicalNot', 'LogicalAnd', 'LogicalOr')
+def convert_tf_logical_ops(scope, operator, container):
+    onnx_type = operator.type[len('Logical'):]
+    oopb = OnnxOperatorBuilder(container, scope)
+    oopb.add_node(onnx_type,
+                  operator.input_full_names,
+                  name=operator.full_name,
+                  outputs=operator.output_full_names,
+                  op_version=1)
+
+
 def pass_thru_converter(scope, operator, container):
     """
     This converter is to copy the original graph node with its def into a ONNX node format.
@@ -42,10 +70,11 @@ def pass_thru_converter(scope, operator, container):
     tf_op = operator.raw_operator
     attrs = _to_onnx_attrs(tf_op)
 
-    container.add_node(operator.type,
-                       operator.input_full_names,
-                       operator.output_full_names,
-                       name=operator.full_name,
-                       op_domain='ai.onnx.contrib',
-                       op_version=1,
-                       **attrs)
+    oopb = OnnxOperatorBuilder(container, scope)
+    oopb.add_node(operator.type,
+                  operator.input_full_names,
+                  name=operator.full_name,
+                  outputs=operator.output_full_names,
+                  op_domain='ai.onnx.contrib',
+                  op_version=1,
+                  **attrs)
diff --git a/keras2onnx/main.py b/keras2onnx/main.py
index d060d8bd..daf1a6cc 100644
--- a/keras2onnx/main.py
+++ b/keras2onnx/main.py
@@ -48,8 +48,13 @@ def convert_keras(model, name=None, doc_string='', target_opset=None,
         print(model.summary())
     name = name or model.name
 
-    target_opset = target_opset or get_maximum_opset_supported()
-
+    cvt_default_opset = get_maximum_opset_supported()
+    if target_opset is None:
+        target_opset = cvt_default_opset
+    elif target_opset > cvt_default_opset:
+        raise RuntimeError(
+            "The opset {} conversion is not supported yet, the current maximum opset version supported is {}.".format(
+                target_opset, cvt_default_opset))
     input_names = []
     output_names = []
     output_dict = {}
diff --git a/keras2onnx/topology.py b/keras2onnx/topology.py
index 797cb09f..750e2417 100644
--- a/keras2onnx/topology.py
+++ b/keras2onnx/topology.py
@@ -4,7 +4,7 @@
 # license information.
 ###############################################################################
 from onnxconverter_common.onnx_ex import make_model_ex
-from .common import k2o_logger
+from .common import utils, k2o_logger
 from .common import OnnxObjectContainer, Variable, InterimContext
 from .common.data_types import TensorType, Int64Type, FloatType, StringType
 from .funcbook import get_converter
@@ -227,7 +227,8 @@ def _remove_unused_nodes(nodes, inputs, outputs):
             if in_ in output_dict:
                 node_inputs.append(output_dict[in_])
             else:
-                assert in_ == '' or in_ in input_dict
+                assert in_ == '' or in_ in input_dict, \
+                    "{} is disconnected, check the parsing log for more details.".format(in_)
 
     return [nd_ for nd_ in nodes if id(nd_) in nodes_to_keep]
 
@@ -375,5 +376,7 @@ def convert_topology(topology, model_name, doc_string, target_opset, channel_fir
 
     # Create model
     onnx_model = make_model_ex(graph, container.node_domain_version_pair_sets,
-                               target_opset, doc_string=doc_string)
+                               target_opset, doc_string=doc_string,
+                               producer_name=utils.get_producer(),
+                               domain=utils.get_domain())
     return onnx_model
diff --git a/tests/test_tf2_keras.py b/tests/test_tf2_keras.py
index 554c9b9c..df907954 100644
--- a/tests/test_tf2_keras.py
+++ b/tests/test_tf2_keras.py
@@ -50,9 +50,9 @@ def call(self, inputs, **kwargs):
         return output
 
 
-class DummyModel(tf.keras.Model):
+class SimpleWrapperModel(tf.keras.Model):
     def __init__(self, func):
-        super(DummyModel, self).__init__()
+        super(SimpleWrapperModel, self).__init__()
         self.func = func
 
     def call(self, inputs, **kwargs):
@@ -88,7 +88,7 @@ def op_func(arg_inputs):
         x = x - tf.cast(tf.expand_dims(r, axis=0), tf.float32)
         return x
 
-    dm = DummyModel(op_func)
+    dm = SimpleWrapperModel(op_func)
     inputs = [tf.random.normal((3, 2, 20)), tf.random.normal((3, 2, 20))]
     expected = dm.predict(inputs)
     oxml = keras2onnx.convert_keras(dm)
@@ -186,3 +186,18 @@ def test_auto_encoder(runner):
     # The random generator is not same between different engiens.
     import onnx
     onnx.checker.check_model(oxml)
+
+
+def test_tf_where(runner):
+    def _tf_where(input_0):
+        a = tf.where(True, input_0, [0, 1, 2, 5, 7])
+        b = tf.where([True], tf.expand_dims(input_0, axis=0), tf.expand_dims([0, 1, 2, 5, 7], axis=0))
+        c = tf.logical_or(tf.cast(a, tf.bool), tf.cast(b, tf.bool))
+        return c
+
+    swm = SimpleWrapperModel(_tf_where)
+    const_in = [np.array([2, 4, 6, 8, 10])]
+    expected = swm(const_in)
+    swm._set_inputs(const_in)
+    oxml = keras2onnx.convert_keras(swm, debug_mode=True)
+    assert runner('where_test', oxml, const_in, expected)

From 16daada454547318b5784d128f05740a4329b548 Mon Sep 17 00:00:00 2001
From: Wenbing Li
Date: Fri, 15 May 2020 13:09:06 -0700
Subject: [PATCH 2/6] The converter test version.

---
 tests/test_layers.py | 13 +++++++++----
 1 file changed, 9 insertions(+), 4 deletions(-)

diff --git a/tests/test_layers.py b/tests/test_layers.py
index c46f6e7f..7bb486d5 100644
--- a/tests/test_layers.py
+++ b/tests/test_layers.py
@@ -1278,6 +1278,8 @@ def test_LeakyReLU(advanced_activation_runner):
     advanced_activation_runner(layer, data)
 
 
+@pytest.mark.skipif(get_maximum_opset_supported() < 8,
+                    reason="ThresholdedReLU needs ONNX opset 8")
 def test_ThresholdedReLU(advanced_activation_runner):
     data = _asarray(-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5)
     layer = advanced_activations.ThresholdedReLU(theta=1.0, input_shape=(data.size,))
@@ -1626,28 +1628,30 @@ def test_LSTM(runner):
     expected = model.predict(data)
     assert runner(onnx_model.graph.name, onnx_model, data, expected)
 
+
 @pytest.mark.skipif((is_tensorflow_older_than('1.14.0') or (not is_tf_keras)),
                     reason="keras LSTM does not have time_major attribute")
 def test_LSTM_time_major_return_seq_true(runner):
     inputs1 = keras.Input(shape=(3, 5))
     data = np.random.rand(1, 3, 5).astype(np.float32)
     # Transpose input to be time major
-    input_transposed = tf.transpose(inputs1, perm=[1,0,2])
+    input_transposed = tf.transpose(inputs1, perm=[1, 0, 2])
     lstm1, state_h, state_c = LSTM(units=2, time_major=True, return_state=True,
                                    return_sequences=True)(input_transposed)
-    lstm1_trans = tf.transpose(lstm1, perm=[1,0,2])
+    lstm1_trans = tf.transpose(lstm1, perm=[1, 0, 2])
     model = keras.Model(inputs=inputs1, outputs=[lstm1_trans, state_h, state_c])
     onnx_model = keras2onnx.convert_keras(model, model.name)
     expected = model.predict(data)
     assert runner(onnx_model.graph.name, onnx_model, data, expected)
 
-@pytest.mark.skipif((is_tensorflow_older_than('1.14.0') or (not is_tf_keras)) ,
+
+@pytest.mark.skipif((is_tensorflow_older_than('1.14.0') or (not is_tf_keras)),
                     reason="keras LSTM does not have time_major attribute")
 def test_LSTM_time_major_return_seq_false(runner):
     inputs1 = keras.Input(shape=(3, 5))
     data = np.random.rand(1, 3, 5).astype(np.float32)
     # Transpose input to be time major
-    input_transposed = tf.transpose(inputs1, perm=[1,0,2])
+    input_transposed = tf.transpose(inputs1, perm=[1, 0, 2])
     lstm1, state_h, state_c = LSTM(units=2, time_major=True, return_state=True,
                                    return_sequences=False)(input_transposed)
     model = keras.Model(inputs=inputs1, outputs=[lstm1, state_h, state_c])
@@ -1655,6 +1659,7 @@ def test_LSTM_time_major_return_seq_false(runner):
     expected = model.predict(data)
     assert runner(onnx_model.graph.name, onnx_model, data, expected)
 
+
 def test_LSTM_with_bias(runner):
     inputs1 = keras.Input(shape=(1, 1))
     cls = LSTM(units=1, return_state=True, return_sequences=True)

From 076126a63f6e5a5923f7a38ee1a035937858c79f Mon Sep 17 00:00:00 2001
From: Wenbing Li
Date: Fri, 15 May 2020 13:14:11 -0700
Subject: [PATCH 3/6] avoid hitting tf 2.0 bug

---
 tests/test_tf2_keras.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/tests/test_tf2_keras.py b/tests/test_tf2_keras.py
index df907954..29ead56c 100644
--- a/tests/test_tf2_keras.py
+++ b/tests/test_tf2_keras.py
@@ -7,6 +7,7 @@
 import keras2onnx
 import numpy as np
 import tensorflow as tf
+from keras2onnx.proto import is_tensorflow_older_than
 
 if (not keras2onnx.proto.is_tf_keras) or (not keras2onnx.proto.tfcompat.is_tf2):
     pytest.skip("Tensorflow 2.0 only tests.", allow_module_level=True)
@@ -188,6 +189,7 @@ def test_auto_encoder(runner):
     onnx.checker.check_model(oxml)
 
 
+@pytest.mark.skipif((is_tensorflow_older_than('2.1.0')), 'tf 2.0 has several bug on the following code.')
 def test_tf_where(runner):
     def _tf_where(input_0):
         a = tf.where(True, input_0, [0, 1, 2, 5, 7])

From 3d0cc46d6515a85d9cff417e80d41644fb7825f7 Mon Sep 17 00:00:00 2001
From: Wenbing Li
Date: Fri, 15 May 2020 13:15:10 -0700
Subject: [PATCH 4/6] one more fix

---
 tests/test_tf2_keras.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/tests/test_tf2_keras.py b/tests/test_tf2_keras.py
index 29ead56c..5ef8acbb 100644
--- a/tests/test_tf2_keras.py
+++ b/tests/test_tf2_keras.py
@@ -189,7 +189,8 @@ def test_auto_encoder(runner):
     onnx.checker.check_model(oxml)
 
 
-@pytest.mark.skipif((is_tensorflow_older_than('2.1.0')), 'tf 2.0 has several bug on the following code.')
+@pytest.mark.skipif((is_tensorflow_older_than('2.1.0')),
+                    reason='tf 2.0 has several bug on the following code.')
 def test_tf_where(runner):
     def _tf_where(input_0):
         a = tf.where(True, input_0, [0, 1, 2, 5, 7])

From 2f258871915a9e0ee851424461e860c0d02dfb6f Mon Sep 17 00:00:00 2001
From: Wenbing Li
Date: Fri, 15 May 2020 13:15:44 -0700
Subject: [PATCH 5/6] clean up

---
 tests/test_tf2_keras.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/test_tf2_keras.py b/tests/test_tf2_keras.py
index 5ef8acbb..786d158d 100644
--- a/tests/test_tf2_keras.py
+++ b/tests/test_tf2_keras.py
@@ -202,5 +202,5 @@ def _tf_where(input_0):
     const_in = [np.array([2, 4, 6, 8, 10])]
     expected = swm(const_in)
     swm._set_inputs(const_in)
-    oxml = keras2onnx.convert_keras(swm, debug_mode=True)
+    oxml = keras2onnx.convert_keras(swm)
     assert runner('where_test', oxml, const_in, expected)

From b1e14651c5ddd8e1aefd8fd434939b028092bf82 Mon Sep 17 00:00:00 2001
From: Wenbing Li
Date: Fri, 15 May 2020 13:34:16 -0700
Subject: [PATCH 6/6] really???

---
 tests/test_tf2_keras.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/tests/test_tf2_keras.py b/tests/test_tf2_keras.py
index 786d158d..c2ba6223 100644
--- a/tests/test_tf2_keras.py
+++ b/tests/test_tf2_keras.py
@@ -189,8 +189,6 @@ def test_auto_encoder(runner):
     onnx.checker.check_model(oxml)
 
 
-@pytest.mark.skipif((is_tensorflow_older_than('2.1.0')),
-                    reason='tf 2.0 has several bug on the following code.')
 def test_tf_where(runner):
     def _tf_where(input_0):
         a = tf.where(True, input_0, [0, 1, 2, 5, 7])
@@ -199,7 +197,7 @@ def _tf_where(input_0):
         return c
 
     swm = SimpleWrapperModel(_tf_where)
-    const_in = [np.array([2, 4, 6, 8, 10])]
+    const_in = [np.array([2, 4, 6, 8, 10]).astype(np.int32)]
     expected = swm(const_in)
     swm._set_inputs(const_in)
     oxml = keras2onnx.convert_keras(swm)
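
A minimal usage sketch of the conversion path these patches enable, mirroring the SimpleWrapperModel / test_tf_where flow from tests/test_tf2_keras.py. The _WhereModel wrapper, the x > 5 condition, and the onnxruntime calls are illustrative assumptions rather than code taken from the patches; it assumes TF 2.x, keras2onnx with these changes, and onnxruntime are installed.

# Hedged sketch: convert a tf.keras model that uses tf.where / tf.logical_or
# (the ops whose converters are added above) and check it with onnxruntime.
import numpy as np
import tensorflow as tf
import keras2onnx
import onnxruntime as ort


class _WhereModel(tf.keras.Model):  # hypothetical wrapper, mirrors SimpleWrapperModel
    def call(self, inputs, **kwargs):
        x = inputs[0] if isinstance(inputs, (list, tuple)) else inputs
        a = tf.where(x > 5, x, tf.zeros_like(x))                      # Select/SelectV2 -> ONNX Where
        b = tf.logical_or(tf.cast(a, tf.bool), tf.cast(x, tf.bool))   # LogicalOr -> ONNX Or
        return tf.cast(b, tf.int32)


model = _WhereModel()
data = [np.array([2, 4, 6, 8, 10]).astype(np.int32)]
expected = model(data)    # run once eagerly so the subclassed model gets built
model._set_inputs(data)   # same private helper the test relies on

oxml = keras2onnx.convert_keras(model)  # target_opset defaults to the converter's maximum
sess = ort.InferenceSession(oxml.SerializeToString())
feed = {sess.get_inputs()[0].name: data[0]}
print(sess.run(None, feed)[0])
print(np.asarray(expected))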