diff --git a/keras2onnx/ke2onnx/gru.py b/keras2onnx/ke2onnx/gru.py index 82c5893d..ba5c8938 100644 --- a/keras2onnx/ke2onnx/gru.py +++ b/keras2onnx/ke2onnx/gru.py @@ -50,7 +50,13 @@ def convert_keras_gru(scope, operator, container): gru_input_names.append('') # sequence lens - gru_input_names.append('') + uses_masking_layer = len(operator.input_masks) == 1 + if uses_masking_layer: + # Mask using sequence_lens input + sequence_lengths = scope.get_unique_variable_name(operator.full_name + '_seq_lens') + gru_input_names.append(sequence_lengths) + else: + gru_input_names.append('') # inital_h if len(operator.inputs) == 1: gru_input_names.append('') @@ -88,6 +94,12 @@ def convert_keras_gru(scope, operator, container): gru_h_name = scope.get_unique_variable_name('gru_h') gru_output_names = [gru_y_name, gru_h_name] oopb = OnnxOperatorBuilder(container, scope) + + if uses_masking_layer: + mask_cast = oopb.apply_cast(operator.input_masks[0].full_name, to=oopb.int32, name=operator.full_name + '_mask_cast') + oopb.add_node_with_output('ReduceSum', mask_cast, sequence_lengths, keepdims=False, axes=[-1], name=operator.full_name + '_mask_sum') + + oopb.apply_op_with_output('apply_gru', gru_input_names, gru_output_names, diff --git a/keras2onnx/ke2onnx/lstm.py b/keras2onnx/ke2onnx/lstm.py index 871e28fc..094d1f1f 100644 --- a/keras2onnx/ke2onnx/lstm.py +++ b/keras2onnx/ke2onnx/lstm.py @@ -92,7 +92,13 @@ def convert_keras_lstm(scope, operator, container): lstm_input_names.append('') # sequence_lens - lstm_input_names.append('') + uses_masking_layer = len(operator.input_masks) == 1 + if uses_masking_layer: + # Mask using sequence_lens input + sequence_lengths = scope.get_unique_variable_name(operator.full_name + '_seq_lens') + lstm_input_names.append(sequence_lengths) + else: + lstm_input_names.append('') # inital_h if len(operator.inputs) <= 1: lstm_input_names.append('') @@ -149,6 +155,11 @@ def convert_keras_lstm(scope, operator, container): lstm_output_names.append(lstm_c_name) oopb = OnnxOperatorBuilder(container, scope) + + if uses_masking_layer: + mask_cast = oopb.apply_cast(operator.input_masks[0].full_name, to=oopb.int32, name=operator.full_name + '_mask_cast') + oopb.add_node_with_output('ReduceSum', mask_cast, sequence_lengths, keepdims=False, axes=[-1], name=operator.full_name + '_mask_sum') + oopb.apply_op_with_output('apply_lstm', lstm_input_names, lstm_output_names, diff --git a/keras2onnx/ke2onnx/main.py b/keras2onnx/ke2onnx/main.py index 1ef9c870..65a8cf2f 100644 --- a/keras2onnx/ke2onnx/main.py +++ b/keras2onnx/ke2onnx/main.py @@ -79,7 +79,7 @@ def _apply_not_equal(oopb, target_opset, operator): k2o_logger().warning("On converting a model with opset < 11, " + "the masking layer result may be incorrect if the model input is in range (0, 1.0).") equal_input_0 = oopb.add_node('Cast', [operator.inputs[0].full_name], - operator.full_name + '_input_cast', to=6) + operator.full_name + '_input_cast', to=oopb.int32) equal_out = oopb.add_node('Equal', [equal_input_0, np.array([operator.mask_value], dtype='int32')], operator.full_name + 'mask') not_o = oopb.add_node('Not', equal_out, diff --git a/keras2onnx/ke2onnx/simplernn.py b/keras2onnx/ke2onnx/simplernn.py index 90b63385..a5c2f92d 100644 --- a/keras2onnx/ke2onnx/simplernn.py +++ b/keras2onnx/ke2onnx/simplernn.py @@ -48,7 +48,13 @@ def convert_keras_simple_rnn(scope, operator, container): rnn_input_names.append('') # sequence_lens is not able to be converted from input_length - rnn_input_names.append('') + uses_masking_layer = len(operator.input_masks) == 1 + if uses_masking_layer: + # Mask using sequence_lens input + sequence_lengths = scope.get_unique_variable_name(operator.full_name + '_seq_lens') + rnn_input_names.append(sequence_lengths) + else: + rnn_input_names.append('') # inital_h: none if len(operator.inputs) == 1: rnn_input_names.append('') @@ -77,6 +83,11 @@ def convert_keras_simple_rnn(scope, operator, container): rnn_output_names.append(rnn_y_name) rnn_output_names.append(rnn_h_name) oopb = OnnxOperatorBuilder(container, scope) + + if uses_masking_layer: + mask_cast = oopb.apply_cast(operator.input_masks[0].full_name, to=oopb.int32, name=operator.full_name + '_mask_cast') + oopb.add_node_with_output('ReduceSum', mask_cast, sequence_lengths, keepdims=False, axes=[-1], name=operator.full_name + '_mask_sum') + oopb.apply_op_with_output('apply_rnn', rnn_input_names, rnn_output_names, diff --git a/tests/test_layers.py b/tests/test_layers.py index b49da96f..c363bca1 100644 --- a/tests/test_layers.py +++ b/tests/test_layers.py @@ -1796,6 +1796,36 @@ def test_masking(self): expected = model.predict(x) self.assertTrue(run_onnx_runtime(onnx_model.graph.name, onnx_model, x, expected, self.model_files)) + @unittest.skipIf(is_tf2 and is_tf_keras, 'TODO') + def test_masking_bias(self): + for rnn_class in [LSTM, GRU, SimpleRNN]: + + timesteps, features = (3, 5) + model = Sequential([ + keras.layers.Masking(mask_value=0., input_shape=(timesteps, features)), + rnn_class(8, return_state=False, return_sequences=False, use_bias=True, name='rnn') + ]) + + x = np.random.uniform(100, 999, size=(2, 3, 5)).astype(np.float32) + # Fill one of the entries with all zeros except the first timestep + x[1, 1:, :] = 0 + + # Test with the default bias + expected = model.predict(x) + onnx_model = keras2onnx.convert_keras(model, model.name) + self.assertTrue(run_onnx_runtime(onnx_model.graph.name, onnx_model, x, expected, self.model_files)) + + # Set bias values to random floats + rnn_layer = model.get_layer('rnn') + weights = rnn_layer.get_weights() + weights[2] = np.random.uniform(size=weights[2].shape) + rnn_layer.set_weights(weights) + + # Test with random bias + expected = model.predict(x) + onnx_model = keras2onnx.convert_keras(model, model.name) + self.assertTrue(run_onnx_runtime(onnx_model.graph.name, onnx_model, x, expected, self.model_files)) + @unittest.skipIf(is_tf2 and is_tf_keras, 'TODO') def test_masking_value(self): timesteps, features = (3, 5)