From 5818ae3725cbee2d2c77bdad015fcc062d195216 Mon Sep 17 00:00:00 2001
From: Colin Jermain
Date: Wed, 25 Mar 2020 22:42:38 -0400
Subject: [PATCH 01/12] Adding initial support for outputting states in RNN and GRU

---
 keras2onnx/ke2onnx/simplernn.py | 10 ++++++++--
 tests/test_layers.py | 24 ++++++++++++++++++++++++
 2 files changed, 32 insertions(+), 2 deletions(-)

diff --git a/keras2onnx/ke2onnx/simplernn.py b/keras2onnx/ke2onnx/simplernn.py
index 6a0261bd..6ba368b0 100644
--- a/keras2onnx/ke2onnx/simplernn.py
+++ b/keras2onnx/ke2onnx/simplernn.py
@@ -230,11 +230,17 @@ def build_output(scope, operator, container, output_names, bidirectional=False):
     is_static_shape = seq_length is not None
     output_seq = forward_layer.return_sequences
     output_state = forward_layer.return_state
-    if output_state:
-        raise ValueError('Keras Bidirectional cannot return hidden and cell states')
 
     oopb = OnnxOperatorBuilder(container, scope)
 
+    if output_state:
+        state_names = [o.full_name for o in operator.outputs[1:]]
+        intermediate_names = ['{}_{}'.format(rnn_h, i) for i, _ in enumerate(state_names)]
+
+        apply_split(scope, rnn_h, intermediate_names, container)
+        for intermediate_name, state_name in zip(intermediate_names, state_names):
+            apply_squeeze(scope, intermediate_name, state_name, container)
+
     # Define seq_dim
     if not is_static_shape:
         input_name = operator.inputs[0].full_name
diff --git a/tests/test_layers.py b/tests/test_layers.py
index d1cf8b17..27de79c4 100644
--- a/tests/test_layers.py
+++ b/tests/test_layers.py
@@ -1574,6 +1574,30 @@ def test_Bidirectional_with_bias(self):
         onnx_model = keras2onnx.convert_keras(model, model.name)
         self.assertTrue(run_onnx_runtime(onnx_model.graph.name, onnx_model, x, expected, self.model_files))
 
+    def test_Bidirectional_with_initial_states(self):
+        for rnn_class in [SimpleRNN, GRU]:
+            input1 = Input(shape=(None, 5))
+            states = Bidirectional(rnn_class(2, return_state=True))(input1)
+            model = Model(input1, states)
+
+            x = np.random.uniform(0.1, 1.0, size=(4, 3, 5)).astype(np.float32)
+            #inputs = [x, x]
+            inputs = [x]
+
+            expected = model.predict(inputs)
+            onnx_model = keras2onnx.convert_keras(model, model.name)
+            self.assertTrue(run_onnx_runtime(onnx_model.graph.name, onnx_model, inputs, expected, self.model_files))
+
+            #input2 = Input(shape=(None, 5))
+            #states = Bidirectional(rnn_class(2, return_state=True))(input1)[1:]
+            #out = Bidirectional(rnn_class(2, return_sequences=True))(input2, initial_state=states)
+            #model = Model([input1, input2], out)
+            #inputs = [x, x]
+
+            #expected = model.predict(inputs)
+            #onnx_model = keras2onnx.convert_keras(model, model.name)
+            #self.assertTrue(run_onnx_runtime(onnx_model.graph.name, onnx_model, inputs, expected, self.model_files))
+
     # Bidirectional LSTM with seq_length = None
     @unittest.skipIf(get_opset_number_from_onnx() < 9,
                      "None seq_length Bidirectional LSTM is not supported before opset 9.")

From 81ca2092382e0206bbd65eb3738166942577b16a Mon Sep 17 00:00:00 2001
From: Colin Jermain
Date: Sat, 28 Mar 2020 14:01:09 -0400
Subject: [PATCH 02/12] Adding initial state support for Bidirectional LSTMs

---
 keras2onnx/ke2onnx/gru.py | 1 +
 keras2onnx/ke2onnx/lstm.py | 49 ++++++++++++++++++++++++++++-----
 keras2onnx/ke2onnx/simplernn.py | 38 +++++++++++++++++--------
 tests/test_layers.py | 2 +-
 4 files changed, 71 insertions(+), 19 deletions(-)

diff --git a/keras2onnx/ke2onnx/gru.py b/keras2onnx/ke2onnx/gru.py
index 761b332f..0a090776 100644
--- a/keras2onnx/ke2onnx/gru.py
+++ b/keras2onnx/ke2onnx/gru.py
@@ -161,3 +161,4 @@ def convert_keras_gru(scope, operator, container, bidirectional=False):
                   **attrs)
 
     simplernn.build_output(scope, operator, container, output_names, bidirectional)
+    simplernn.build_output_states(scope, operator, container, output_names, bidirectional)
diff --git a/keras2onnx/ke2onnx/lstm.py b/keras2onnx/ke2onnx/lstm.py
index 80deab1f..ac92c47b 100644
--- a/keras2onnx/ke2onnx/lstm.py
+++ b/keras2onnx/ke2onnx/lstm.py
@@ -11,6 +11,7 @@
     apply_identity,
     apply_reshape,
     apply_slice,
+    apply_split,
     apply_squeeze,
     apply_transpose,
     OnnxOperatorBuilder
@@ -179,7 +180,7 @@ def build_attributes(scope, operator, container, bidirectional=False):
     return attrs
 
 def build_output(scope, operator, container, output_names, bidirectional=False):
-    """
+    """Builds the output operators for the LSTM layer.
     """
     if bidirectional:
         return simplernn.build_output(scope, operator, container, output_names[:-1], bidirectional)
@@ -189,7 +190,6 @@ def build_output(scope, operator, container, output_names, bidirectional=False):
     op = operator.raw_operator
     hidden_size = op.units
     output_seq = op.return_sequences
-    output_state = op.return_state
 
     _, seq_length, input_size = simplernn.extract_input_shape(op)
     is_static_shape = seq_length is not None
@@ -229,10 +229,44 @@ def build_output(scope, operator, container, output_names, bidirectional=False):
     else:
         apply_reshape(scope, lstm_h, output_name, container, desired_shape=[-1, hidden_size])
 
-    if output_state:
-        # Output hidden and cell states
-        apply_reshape(scope, lstm_h, operator.outputs[1].full_name, container, desired_shape=[-1, hidden_size])
-        apply_reshape(scope, lstm_c, operator.outputs[2].full_name, container, desired_shape=[-1, hidden_size])
+
+def build_output_states(scope, operator, container, output_names, bidirectional=False):
+    """Builds the output hidden states for the LSTM layer.
+    """
+    _, lstm_h, lstm_c = output_names
+    op = operator.raw_operator
+
+    if bidirectional:
+        forward_layer = op.forward_layer
+        output_state = forward_layer.return_state
+
+        if not output_state:
+            return
+
+        # Split lstm_h and lstm_c into forward and backward components
+        squeeze_names = []
+        output_names = [o.full_name for o in operator.outputs[1:]]
+        name_map = {lstm_h: output_names[::2], lstm_c: output_names[1::2]}
+
+        for state_name, outputs in name_map.items():
+            split_names = ['{}_{}'.format(state_name, d) for d in ('forward', 'backward')]
+
+            apply_split(scope, state_name, split_names, container)
+            squeeze_names.extend(list(zip(split_names, outputs)))
+
+        for split_name, output_name in squeeze_names:
+            apply_squeeze(scope, split_name, output_name, container)
+
+    else:
+        output_state = op.return_state
+
+        if not output_state:
+            return
+
+        output_h = operator.outputs[1].full_name
+        output_c = operator.outputs[2].full_name
+        apply_squeeze(scope, lstm_h, output_h, container)
+        apply_squeeze(scope, lstm_c, output_c, container)
 
 
 def _calculate_keras_lstm_output_shapes(operator):
@@ -254,7 +288,7 @@ def convert_keras_lstm(scope, operator, container, bidirectional=False):
     else:
         output_seq = op.return_sequences
 
-    check_sequence_lengths(operator, container)
+    #check_sequence_lengths(operator, container)
 
     # Inputs
     lstm_x = _name('X')
@@ -292,3 +326,4 @@ def convert_keras_lstm(scope, operator, container, bidirectional=False):
                   **attrs)
 
     build_output(scope, operator, container, output_names, bidirectional)
+    build_output_states(scope, operator, container, output_names, bidirectional)
diff --git a/keras2onnx/ke2onnx/simplernn.py b/keras2onnx/ke2onnx/simplernn.py
index 6ba368b0..3b7a7f00 100644
--- a/keras2onnx/ke2onnx/simplernn.py
+++ b/keras2onnx/ke2onnx/simplernn.py
@@ -229,18 +229,9 @@ def build_output(scope, operator, container, output_names, bidirectional=False):
     hidden_size = forward_layer.units
     is_static_shape = seq_length is not None
     output_seq = forward_layer.return_sequences
-    output_state = forward_layer.return_state
 
     oopb = OnnxOperatorBuilder(container, scope)
 
-    if output_state:
-        state_names = [o.full_name for o in operator.outputs[1:]]
-        intermediate_names = ['{}_{}'.format(rnn_h, i) for i, _ in enumerate(state_names)]
-
-        apply_split(scope, rnn_h, intermediate_names, container)
-        for intermediate_name, state_name in zip(intermediate_names, state_names):
-            apply_squeeze(scope, intermediate_name, state_name, container)
-
     # Define seq_dim
     if not is_static_shape:
         input_name = operator.inputs[0].full_name
@@ -383,7 +374,6 @@ def build_output(scope, operator, container, output_names, bidirectional=False):
     else:
         hidden_size = op.units
         output_seq = op.return_sequences
-        output_state = op.return_state
 
     output_name = operator.outputs[0].full_name
     tranposed_y = scope.get_unique_variable_name(operator.full_name + '_y_transposed')
@@ -398,8 +388,33 @@ def build_output(scope, operator, container, output_names, bidirectional=False):
         apply_transpose(scope, rnn_h, tranposed_y, container, perm=[1, 0, 2])
         apply_reshape(scope, tranposed_y, output_name, container, desired_shape=[-1, hidden_size])
 
+
+def build_output_states(scope, operator, container, output_names, bidirectional=False):
+    """Builds the output hidden states for the RNN layer.
+    """
+    _, rnn_h = output_names
+    op = operator.raw_operator
+
+    if bidirectional:
+        forward_layer = op.forward_layer
+        output_state = forward_layer.return_state
+
+        if output_state:
+            # Split rnn_h into forward and backward directions
+            output_names = [o.full_name for o in operator.outputs[1:]]
+            split_names = ['{}_{}'.format(rnn_h, d) for d in ('forward', 'backward')]
+
+            apply_split(scope, rnn_h, split_names, container)
+
+            for split_name, output_name in zip(split_names, output_names):
+                apply_squeeze(scope, split_name, output_name, container)
+
+    else:
+        output_state = op.return_state
+
     if output_state:
-        apply_reshape(scope, rnn_h, operator.outputs[1].full_name, container, desired_shape=[-1, hidden_size])
+        output_h = operator.outputs[1].full_name
+        apply_squeeze(scope, rnn_h, output_h, container)
 
 
 def convert_keras_simple_rnn(scope, operator, container, bidirectional=False):
@@ -446,3 +461,4 @@ def convert_keras_simple_rnn(scope, operator, container, bidirectional=False):
                   **attrs)
 
     build_output(scope, operator, container, output_names, bidirectional)
+    build_output_states(scope, operator, container, output_names, bidirectional)
diff --git a/tests/test_layers.py b/tests/test_layers.py
index 27de79c4..dceb33c0 100644
--- a/tests/test_layers.py
+++ b/tests/test_layers.py
@@ -1575,7 +1575,7 @@ def test_Bidirectional_with_bias(self):
         self.assertTrue(run_onnx_runtime(onnx_model.graph.name, onnx_model, x, expected, self.model_files))
 
     def test_Bidirectional_with_initial_states(self):
-        for rnn_class in [SimpleRNN, GRU]:
+        for rnn_class in [SimpleRNN, GRU, LSTM]:
             input1 = Input(shape=(None, 5))
             states = Bidirectional(rnn_class(2, return_state=True))(input1)
             model = Model(input1, states)

From d83f47c6ae6da4686053810d163fffccee2816a3 Mon Sep 17 00:00:00 2001
From: Colin Jermain
Date: Sat, 28 Mar 2020 17:11:35 -0400
Subject: [PATCH 03/12] Improving the output for RNN

---
 keras2onnx/ke2onnx/simplernn.py | 55 +++++++++++++++++++++------------
 1 file changed, 36 insertions(+), 19 deletions(-)

diff --git a/keras2onnx/ke2onnx/simplernn.py b/keras2onnx/ke2onnx/simplernn.py
index 3b7a7f00..66e341eb 100644
--- a/keras2onnx/ke2onnx/simplernn.py
+++ b/keras2onnx/ke2onnx/simplernn.py
@@ -220,26 +220,26 @@ def build_output(scope, operator, container, output_names, bidirectional=False):
     op = operator.raw_operator
 
     _, seq_length, input_size = extract_input_shape(op)
+    is_static_shape = seq_length is not None
 
     _name = name_func(scope, operator)
 
+    oopb = OnnxOperatorBuilder(container, scope)
+
+    # Define seq_dim
+    if not is_static_shape:
+        input_name = operator.inputs[0].full_name
+        input_shape_tensor = oopb.add_node('Shape', [input_name], input_name + '_input_shape_tensor')
+
+        seq_dim = input_name + '_seq_dim'
+        apply_slice(scope, input_shape_tensor, seq_dim, container, [1], [2], axes=[0])
+
     if bidirectional:
         forward_layer = op.forward_layer
         hidden_size = forward_layer.units
-        is_static_shape = seq_length is not None
         output_seq = forward_layer.return_sequences
 
-        oopb = OnnxOperatorBuilder(container, scope)
-
-        # Define seq_dim
-        if not is_static_shape:
-            input_name = operator.inputs[0].full_name
-            input_shape_tensor = oopb.add_node('Shape', [input_name], input_name + '_input_shape_tensor')
-
-            seq_dim = input_name + '_seq_dim'
-            apply_slice(scope, input_shape_tensor, seq_dim, container, [1], [2], axes=[0])
-
         merge_concat = False
         if hasattr(op, 'merge_mode'):
             if op.merge_mode not in ['concat', None]:
@@ -376,17 +376,34 @@ def build_output(scope, operator, container, output_names, bidirectional=False):
         output_seq = op.return_sequences
 
     output_name = operator.outputs[0].full_name
-    tranposed_y = scope.get_unique_variable_name(operator.full_name + '_y_transposed')
+    transposed_y = scope.get_unique_variable_name(operator.full_name + '_y_transposed')
 
+    # Determine the source, transpose permutation, and output shape
     if output_seq:
-        perm = [1, 0, 2] if container.target_opset <= 5 else [2, 0, 1, 3]
-        apply_transpose(scope, rnn_y, tranposed_y, container, perm=perm)
-        apply_reshape(scope, tranposed_y, output_name, container,
-                      desired_shape=[-1, seq_length, hidden_size])
+        source = rnn_y
+        perm = [2, 0, 1, 3]
+        if is_static_shape:
+            desired_shape = [-1, seq_length, hidden_size]
+        elif container.target_opset < 5:
+            # Before Reshape-5 the sequence dimension cannot be passed in as an input
+            raise ValueError('At least opset 5 is required for output sequences')
+        else:
+            # Dynamically determine the output shape based on the sequence dimension
+            shape_values = [
+                ('_a', oopb.int64, np.array([-1], dtype='int64')),
+                seq_dim,
+                ('_b', oopb.int64, np.array([hidden_size], dtype='int64')),
+            ]
+            shape_name = _name('_output_seq_shape')
+            desired_shape = oopb.add_node('Concat', shape_values, shape_name, axis=0)
     else:
-        # Here we ingore ONNX RNN's first output because it's useless.
-        apply_transpose(scope, rnn_h, tranposed_y, container, perm=[1, 0, 2])
-        apply_reshape(scope, tranposed_y, output_name, container, desired_shape=[-1, hidden_size])
+        # Use the last hidden states directly
+        source = rnn_h
+        perm = [1, 0, 2]
+        desired_shape = [-1, hidden_size]
+
+    apply_transpose(scope, source, transposed_y, container, perm=perm)
+    apply_reshape(scope, transposed_y, output_name, container, desired_shape=desired_shape)
 
 
 def build_output_states(scope, operator, container, output_names, bidirectional=False):

From 3e06b2c563af4ad649cfe33e0c49a82f227d8579 Mon Sep 17 00:00:00 2001
From: Colin Jermain
Date: Sat, 28 Mar 2020 17:48:13 -0400
Subject: [PATCH 04/12] Removing check_sequence_lengths given improved support above opset 5

---
 keras2onnx/ke2onnx/lstm.py | 14 --------------
 1 file changed, 14 deletions(-)

diff --git a/keras2onnx/ke2onnx/lstm.py b/keras2onnx/ke2onnx/lstm.py
index ac92c47b..d67354a2 100644
--- a/keras2onnx/ke2onnx/lstm.py
+++ b/keras2onnx/ke2onnx/lstm.py
@@ -23,18 +23,6 @@
 TensorProto = onnx_proto.TensorProto
 
 
-def check_sequence_lengths(operator, container):
-    """Raises an exception if the shape is expected to be static, but the sequence lenghts
-    are not provided. This only applies to opsets below 9.
-    """
-    op = operator.raw_operator
-
-    _, seq_length, input_size = simplernn.extract_input_shape(op)
-    is_static_shape = seq_length is not None
-    if not is_static_shape and container.target_opset < 9:
-        raise ValueError('None seq_length is not supported in opset ' + str(container.target_opset))
-
-
 def convert_ifco_to_iofc(tensor_ifco):
     """Returns a tensor in input (i), output (o), forget (f), cell (c) ordering.
     The Keras ordering is ifco, while the ONNX ordering is iofc.
     """

From 89fa5c3dd6b83378ce814bb711643bd02f9ec0e9 Mon Sep 17 00:00:00 2001
From: Colin Jermain
Date: Sat, 28 Mar 2020 18:25:10 -0400
Subject: [PATCH 05/12] Expanding Bidirectional sequence length unit test for opsets >= 5

---
 tests/test_layers.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/test_layers.py b/tests/test_layers.py
index dceb33c0..c687d1f7 100644
--- a/tests/test_layers.py
+++ b/tests/test_layers.py
@@ -1599,8 +1599,8 @@ def test_Bidirectional_with_initial_states(self):
             #self.assertTrue(run_onnx_runtime(onnx_model.graph.name, onnx_model, inputs, expected, self.model_files))
 
     # Bidirectional LSTM with seq_length = None
-    @unittest.skipIf(get_opset_number_from_onnx() < 9,
-                     "None seq_length Bidirectional LSTM is not supported before opset 9.")
+    @unittest.skipIf(get_opset_number_from_onnx() < 5,
+                     "None seq_length Bidirectional LSTM is not supported before opset 5.")
     def test_Bidirectional_seqlen_none(self):
         for rnn_class in [SimpleRNN, GRU, LSTM]:
             model = Sequential()

From 650de71e117d71f921eabad762086f5f4dc37dfc Mon Sep 17 00:00:00 2001
From: Colin Jermain
Date: Sat, 28 Mar 2020 20:28:32 -0400
Subject: [PATCH 06/12] Expanding LSTM sequence length unit test for opsets >= 5

---
 tests/test_layers.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/test_layers.py b/tests/test_layers.py
index c687d1f7..c7aba767 100644
--- a/tests/test_layers.py
+++ b/tests/test_layers.py
@@ -1501,8 +1501,8 @@ def test_LSTM_with_initializer(self):
             run_onnx_runtime(onnx_model.graph.name, onnx_model, {"inputs": x, 'state_h': sh, 'state_c': sc}, expected,
                              self.model_files))
 
-    @unittest.skipIf(get_opset_number_from_onnx() < 9,
-                     "None seq_length LSTM is not supported before opset 9.")
+    @unittest.skipIf(get_opset_number_from_onnx() < 5,
+                     "None seq_length LSTM is not supported before opset 5.")
     def test_LSTM_seqlen_none(self):
         lstm_dim = 2
         data = np.random.rand(1, 5, 1).astype(np.float32)

From 2ccb073e2d857f4fa6c6c08aa474a342e90ac184 Mon Sep 17 00:00:00 2001
From: Colin Jermain
Date: Sat, 28 Mar 2020 21:02:48 -0400
Subject: [PATCH 07/12] Fixing bug in graph parsing with multiple outputs and adding test to cover RNN initial state passing

---
 keras2onnx/parser.py | 12 ++++++++++++
 tests/test_layers.py | 16 ++++++++++++++++
 2 files changed, 28 insertions(+)

diff --git a/keras2onnx/parser.py b/keras2onnx/parser.py
index 7ba987ff..c90e8e08 100644
--- a/keras2onnx/parser.py
+++ b/keras2onnx/parser.py
@@ -20,6 +20,9 @@
     list_input_tensors, list_input_mask, list_output_mask,
     list_output_tensors, list_input_shapes, list_output_shapes, on_parsing_keras_layer)
 
+ALLOWED_SHARED_KERAS_TYPES = {
+    keras.layers.embeddings.Embedding,
+}
 
 def _find_node(nodes, name):
     try:
@@ -570,6 +573,7 @@ def _parse_graph_core(graph, keras_node_dict, topology, top_scope, output_names):
     for n_ in model_outputs:
         q_overall.put_nowait(n_)
 
+    visited_layers = set()
    visited = set()  # since the output could be shared among the successor nodes.
     inference_nodeset = _build_inference_nodeset(graph, model_outputs)
     keras_nodeset = _build_keras_nodeset(inference_nodeset, keras_node_dict)
@@ -581,6 +585,14 @@ def _parse_graph_core(graph, keras_node_dict, topology, top_scope, output_names):
         nodes = []
         layer_key_, model_ = _parse_nodes(graph, inference_nodeset, input_nodes, keras_node_dict, keras_nodeset,
                                           node, nodes, varset, visited, q_overall)
+
+        # Only parse Keras layers once (allow certain shared classes)
+        if layer_key_ in visited_layers:
+            if not type(layer_key_) in ALLOWED_SHARED_KERAS_TYPES:
+                continue
+        else:
+            visited_layers.add(layer_key_)
+
         if not nodes:  # already processed by the _parse_nodes
             continue
diff --git a/tests/test_layers.py b/tests/test_layers.py
index c7aba767..9f06995a 100644
--- a/tests/test_layers.py
+++ b/tests/test_layers.py
@@ -1614,6 +1614,22 @@ def test_Bidirectional_seqlen_none(self):
             expected = model.predict(x)
             self.assertTrue(run_onnx_runtime(onnx_model.graph.name, onnx_model, x, expected, self.model_files))
 
+    def test_rnn_state_passing(self):
+        for rnn_class in [SimpleRNN, GRU, LSTM]:
+            input1 = Input(shape=(None, 5))
+            input2 = Input(shape=(None, 5))
+
+            states = rnn_class(2, return_state=True)(input1)[1:]
+            out = rnn_class(2, return_sequences=True)(input2, initial_state=states)
+            model = Model([input1, input2], out)
+
+            x = np.random.uniform(0.1, 1.0, size=(4, 3, 5)).astype(np.float32)
+            inputs = [x, x]
+
+            expected = model.predict(inputs)
+            onnx_model = keras2onnx.convert_keras(model, model.name)
+            self.assertTrue(run_onnx_runtime(onnx_model.graph.name, onnx_model, inputs, expected, self.model_files))
+
     def test_seq_dynamic_batch_size(self):
         K.clear_session()
         data_dim = 4  # input_size

From a34c83d70d8c73a07f98704c55aa1dbf366050e1 Mon Sep 17 00:00:00 2001
From: Colin Jermain
Date: Sat, 28 Mar 2020 22:31:27 -0400
Subject: [PATCH 08/12] Adding support for Bidirectional LSTM initial states

---
 keras2onnx/ke2onnx/lstm.py | 46 ++++++++++++++++++++++-----------
 keras2onnx/ke2onnx/simplernn.py | 40 +++++++++++++++++++++++-----
 tests/test_layers.py | 17 ++++++------
 3 files changed, 73 insertions(+), 30 deletions(-)

diff --git a/keras2onnx/ke2onnx/lstm.py b/keras2onnx/ke2onnx/lstm.py
index d67354a2..6b16d148 100644
--- a/keras2onnx/ke2onnx/lstm.py
+++ b/keras2onnx/ke2onnx/lstm.py
@@ -8,6 +8,7 @@
 from collections.abc import Iterable
 from ..common import cvtfunc, name_func
 from ..common.onnx_ops import (
+    apply_concat,
     apply_identity,
     apply_reshape,
     apply_slice,
@@ -109,28 +110,43 @@ def build_parameters(scope, operator, container, bidirectional=False):
     return tensor_w, tensor_r, tensor_b
 
 def build_initial_states(scope, operator, container, bidirectional=False):
-    """
+    """Builds the initial hidden and cell states for the LSTM layer.
     """
     _name = name_func(scope, operator)
 
-    initial_h = ''
-    initial_c = ''
+    initial_h = simplernn.build_initial_states(scope, operator, container, bidirectional)
+
+    # Determine if the cell states are set
+    has_c = (
+        (len(operator.inputs) > 1 and not bidirectional) or
+        (len(operator.inputs) > 3 and bidirectional)
+    )
+    if not has_c:
+        return initial_h, ''
+
+    op = operator.raw_operator
+    initial_c = _name('initial_c')
 
     if bidirectional:
-        if len(operator.inputs) > 1:
-            # TODO: Add support for inputing initial states for Bidirectional LSTM
-            raise NotImplemented("Initial states for Bidirectional LSTM is not yet supported")
-    else:
+        forward_layer = op.forward_layer
+        hidden_size = forward_layer.units
+        desired_shape = [1, -1, hidden_size]
 
-        initial_h = simplernn.build_initial_states(scope, operator, container)
+        # Combine the forward and backward layers
+        forward_h = _name('initial_c_forward')
+        backward_h = _name('initial_c_backward')
+        apply_reshape(scope, operator.inputs[2].full_name, forward_h, container, desired_shape=desired_shape)
+        apply_reshape(scope, operator.inputs[4].full_name, backward_h, container, desired_shape=desired_shape)
 
-        if len(operator.inputs) > 1:
-            # Add a reshape after initial_h, 2d -> 3d
-            hidden_size = operator.raw_operator.units
-            input_c = operator.inputs[2].full_name
-            initial_c = _name('initial_c')
-            apply_reshape(scope, input_c, initial_c, container,
-                          desired_shape=[1, -1, hidden_size])
+        apply_concat(scope, [forward_h, backward_h], initial_c, container)
+
+    else:
+        hidden_size = operator.raw_operator.units
+        desired_shape = [1, -1, hidden_size]
+
+        # Add a reshape after initial_c, 2d -> 3d
+        input_c = operator.inputs[2].full_name
+        apply_reshape(scope, input_c, initial_c, container, desired_shape=desired_shape)
 
     return initial_h, initial_c
diff --git a/keras2onnx/ke2onnx/simplernn.py b/keras2onnx/ke2onnx/simplernn.py
index 66e341eb..628dcb2e 100644
--- a/keras2onnx/ke2onnx/simplernn.py
+++ b/keras2onnx/ke2onnx/simplernn.py
@@ -8,6 +8,7 @@
 from ..common import name_func
 from ..common.onnx_ops import (
     apply_cast,
+    apply_concat,
     apply_identity,
     apply_reshape,
     apply_slice,
@@ -170,12 +171,39 @@ def build_initial_states(scope, operator, container, bidirectional=False):
     if len(operator.inputs) == 1:
         return ''
 
-    # Add a reshape after initial_h, 2d -> 3d
-    hidden_size = operator.raw_operator.units
-    input_h = operator.inputs[1].full_name
-    initial_h = scope.get_unique_variable_name(operator.full_name + '_initial_h')
-    desired_shape = [2, -1, hidden_size] if bidirectional else [1, -1, hidden_size]
-    apply_reshape(scope, input_h, initial_h, container, desired_shape=desired_shape)
+    op = operator.raw_operator
+    _name = name_func(scope, operator)
+
+    initial_h = _name('initial_h')
+
+    if bidirectional:
+        forward_layer = op.forward_layer
+        hidden_size = forward_layer.units
+        desired_shape = [1, -1, hidden_size]
+
+        # Combine the forward and backward layers
+        forward_h = _name('initial_h_forward')
+        backward_h = _name('initial_h_backward')
+
+        # Handle LSTM initial hidden case to enable code reuse
+        if len(operator.inputs) > 4:
+            f, b = 1, 3
+        else:
+            f, b = 1, 2
+
+        apply_reshape(scope, operator.inputs[f].full_name, forward_h, container, desired_shape=desired_shape)
+        apply_reshape(scope, operator.inputs[b].full_name, backward_h, container, desired_shape=desired_shape)
+
+        apply_concat(scope, [forward_h, backward_h], initial_h, container)
+
+    else:
+        hidden_size = operator.raw_operator.units
+        desired_shape = [1, -1, hidden_size]
+
+        # Add a reshape after initial_h, 2d -> 3d
+        input_h = operator.inputs[1].full_name
+        apply_reshape(scope, input_h, initial_h, container, desired_shape=desired_shape)
+
     return initial_h
diff --git a/tests/test_layers.py b/tests/test_layers.py
index 9f06995a..9de02876 100644
--- a/tests/test_layers.py
+++ b/tests/test_layers.py
@@ -1581,22 +1581,21 @@ def test_Bidirectional_with_initial_states(self):
             model = Model(input1, states)
 
             x = np.random.uniform(0.1, 1.0, size=(4, 3, 5)).astype(np.float32)
-            #inputs = [x, x]
             inputs = [x]
 
             expected = model.predict(inputs)
             onnx_model = keras2onnx.convert_keras(model, model.name)
             self.assertTrue(run_onnx_runtime(onnx_model.graph.name, onnx_model, inputs, expected, self.model_files))
 
-            #input2 = Input(shape=(None, 5))
-            #states = Bidirectional(rnn_class(2, return_state=True))(input1)[1:]
-            #out = Bidirectional(rnn_class(2, return_sequences=True))(input2, initial_state=states)
-            #model = Model([input1, input2], out)
-            #inputs = [x, x]
+            input2 = Input(shape=(None, 5))
+            states = Bidirectional(rnn_class(2, return_state=True))(input1)[1:]
+            out = Bidirectional(rnn_class(2, return_sequences=True))(input2, initial_state=states)
+            model = Model([input1, input2], out)
+            inputs = [x, x]
 
-            #expected = model.predict(inputs)
-            #onnx_model = keras2onnx.convert_keras(model, model.name)
-            #self.assertTrue(run_onnx_runtime(onnx_model.graph.name, onnx_model, inputs, expected, self.model_files))
+            expected = model.predict(inputs)
+            onnx_model = keras2onnx.convert_keras(model, model.name)
+            self.assertTrue(run_onnx_runtime(onnx_model.graph.name, onnx_model, inputs, expected, self.model_files))
 
     # Bidirectional LSTM with seq_length = None
     @unittest.skipIf(get_opset_number_from_onnx() < 5,

From 9cc644685131631fb29c11a7ce3bde90cdc9f61f Mon Sep 17 00:00:00 2001
From: Colin Jermain
Date: Sun, 29 Mar 2020 20:17:45 -0400
Subject: [PATCH 09/12] Excluding TF2 from tests for RNN initial states

---
 tests/test_layers.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/tests/test_layers.py b/tests/test_layers.py
index 9de02876..4aff07ba 100644
--- a/tests/test_layers.py
+++ b/tests/test_layers.py
@@ -1574,6 +1574,7 @@ def test_Bidirectional_with_bias(self):
         onnx_model = keras2onnx.convert_keras(model, model.name)
         self.assertTrue(run_onnx_runtime(onnx_model.graph.name, onnx_model, x, expected, self.model_files))
 
+    @unittest.skipIf(is_tf2 and is_tf_keras, 'TODO')
     def test_Bidirectional_with_initial_states(self):
         for rnn_class in [SimpleRNN, GRU, LSTM]:
             input1 = Input(shape=(None, 5))
@@ -1613,6 +1614,7 @@ def test_Bidirectional_seqlen_none(self):
             expected = model.predict(x)
             self.assertTrue(run_onnx_runtime(onnx_model.graph.name, onnx_model, x, expected, self.model_files))
 
+    @unittest.skipIf(is_tf2, 'TODO')
     def test_rnn_state_passing(self):
         for rnn_class in [SimpleRNN, GRU, LSTM]:
             input1 = Input(shape=(None, 5))

From 7bf2b3ccd1fbfa37476c52f366058680d1ac252b Mon Sep 17 00:00:00 2001
From: Colin Jermain
Date: Sun, 29 Mar 2020 20:56:49 -0400
Subject: [PATCH 10/12] Adjusting tolerance to consider the initial states for test_rnn_state_passing

---
 tests/test_layers.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/tests/test_layers.py b/tests/test_layers.py
index 4aff07ba..b5015151 100644
--- a/tests/test_layers.py
+++ b/tests/test_layers.py
@@ -1614,7 +1614,6 @@ def test_Bidirectional_seqlen_none(self):
             expected = model.predict(x)
             self.assertTrue(run_onnx_runtime(onnx_model.graph.name, onnx_model, x, expected, self.model_files))
 
-    @unittest.skipIf(is_tf2, 'TODO')
     def test_rnn_state_passing(self):
         for rnn_class in [SimpleRNN, GRU, LSTM]:
             input1 = Input(shape=(None, 5))
@@ -1629,7 +1628,7 @@ def test_rnn_state_passing(self):
 
             expected = model.predict(inputs)
             onnx_model = keras2onnx.convert_keras(model, model.name)
-            self.assertTrue(run_onnx_runtime(onnx_model.graph.name, onnx_model, inputs, expected, self.model_files))
+            self.assertTrue(run_onnx_runtime(onnx_model.graph.name, onnx_model, inputs, expected, self.model_files, atol=1e-5))
 
     def test_seq_dynamic_batch_size(self):
         K.clear_session()

From 477a446dc88c4679d2c32c20876c6cd6b97ae7b9 Mon Sep 17 00:00:00 2001
From: Colin Jermain
Date: Sun, 29 Mar 2020 21:07:12 -0400
Subject: [PATCH 11/12] Reinstating skip for test_rnn_state_passing for TF2

---
 tests/test_layers.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tests/test_layers.py b/tests/test_layers.py
index b5015151..edf447f5 100644
--- a/tests/test_layers.py
+++ b/tests/test_layers.py
@@ -1614,6 +1614,7 @@ def test_Bidirectional_seqlen_none(self):
             expected = model.predict(x)
             self.assertTrue(run_onnx_runtime(onnx_model.graph.name, onnx_model, x, expected, self.model_files))
 
+    @unittest.skipIf(is_tf2, 'TODO')
     def test_rnn_state_passing(self):
         for rnn_class in [SimpleRNN, GRU, LSTM]:
             input1 = Input(shape=(None, 5))

From 5b3131bfc419f94600f92ff69b1416510330f85f Mon Sep 17 00:00:00 2001
From: Colin Jermain
Date: Sun, 29 Mar 2020 21:11:51 -0400
Subject: [PATCH 12/12] Adjusting tolerance to consider the initial states for test_Bidirectional_with_initial_states

---
 tests/test_layers.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/test_layers.py b/tests/test_layers.py
index edf447f5..ccb04c4b 100644
--- a/tests/test_layers.py
+++ b/tests/test_layers.py
@@ -1596,7 +1596,7 @@ def test_Bidirectional_with_initial_states(self):
 
             expected = model.predict(inputs)
             onnx_model = keras2onnx.convert_keras(model, model.name)
-            self.assertTrue(run_onnx_runtime(onnx_model.graph.name, onnx_model, inputs, expected, self.model_files))
+            self.assertTrue(run_onnx_runtime(onnx_model.graph.name, onnx_model, inputs, expected, self.model_files, atol=1e-5))
 
     # Bidirectional LSTM with seq_length = None
     @unittest.skipIf(get_opset_number_from_onnx() < 5,
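
For reference, the end-to-end behavior this series enables can be sketched outside the test suite as follows. This is a minimal example mirroring test_Bidirectional_with_initial_states: one Bidirectional LSTM returns its four states (forward/backward hidden and cell), which seed a second Bidirectional LSTM, and the converted model is checked against Keras. The keras import path and the direct onnxruntime session are assumptions of this sketch (the tests above use the suite's internal run_onnx_runtime helper instead).

import numpy as np
import onnxruntime
import keras2onnx
from keras.layers import Input, Bidirectional, LSTM
from keras.models import Model

input1 = Input(shape=(None, 5))
input2 = Input(shape=(None, 5))
# Drop the sequence output; keep [forward_h, forward_c, backward_h, backward_c]
states = Bidirectional(LSTM(2, return_state=True))(input1)[1:]
out = Bidirectional(LSTM(2, return_sequences=True))(input2, initial_state=states)
model = Model([input1, input2], out)

onnx_model = keras2onnx.convert_keras(model, model.name)

x = np.random.uniform(0.1, 1.0, size=(4, 3, 5)).astype(np.float32)
sess = onnxruntime.InferenceSession(onnx_model.SerializeToString())
feed = {i.name: x for i in sess.get_inputs()}
# The initial states flow through the converted graph, so the outputs should
# agree with Keras within the atol=1e-5 tolerance adopted by patches 10 and 12.
np.testing.assert_allclose(model.predict([x, x]), sess.run(None, feed)[0], atol=1e-5)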