This repository has been archived by the owner on Oct 13, 2021. It is now read-only.

Support initial states for Bidirectional RNN #417

Merged · 14 commits · Apr 1, 2020
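As a minimal usage sketch of what this PR enables (assumptions: tf.keras, illustrative model and variable names; the state ordering [fw_h, fw_c, bw_h, bw_c] follows Keras's Bidirectional initial_state convention; before this change, this case hit the "Initial states for Bidirectional LSTM" NotImplemented path):

import keras2onnx
from tensorflow import keras

# Sketch only: a Bidirectional LSTM whose initial states are extra model inputs.
units, features = 8, 4
x = keras.Input((None, features))
states = [keras.Input((units,)) for _ in range(4)]  # [fw_h, fw_c, bw_h, bw_c]
y = keras.layers.Bidirectional(
    keras.layers.LSTM(units, return_sequences=True))(x, initial_state=states)
model = keras.Model([x] + states, y)

onnx_model = keras2onnx.convert_keras(model, model.name)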
1 change: 1 addition & 0 deletions keras2onnx/ke2onnx/gru.py
@@ -161,3 +161,4 @@ def convert_keras_gru(scope, operator, container, bidirectional=False):
**attrs)

simplernn.build_output(scope, operator, container, output_names, bidirectional)
simplernn.build_output_states(scope, operator, container, output_names, bidirectional)
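For context, a hedged sketch (tf.keras, illustrative names) of the GRU case this one-line addition routes through simplernn.build_output_states: a Bidirectional GRU with return_state=True returns one state per direction, which now maps onto the extra ONNX outputs.

from tensorflow import keras

x = keras.Input((None, 4))
# Bidirectional GRU returns [merged_output, forward_state, backward_state].
y, fw_h, bw_h = keras.layers.Bidirectional(
    keras.layers.GRU(8, return_state=True))(x)
model = keras.Model(x, [y, fw_h, bw_h])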
107 changes: 72 additions & 35 deletions keras2onnx/ke2onnx/lstm.py
@@ -8,9 +8,11 @@
from collections.abc import Iterable
from ..common import cvtfunc, name_func
from ..common.onnx_ops import (
apply_concat,
apply_identity,
apply_reshape,
apply_slice,
apply_split,
apply_squeeze,
apply_transpose,
OnnxOperatorBuilder
@@ -22,18 +24,6 @@
TensorProto = onnx_proto.TensorProto


def check_sequence_lengths(operator, container):
"""Raises an exception if the shape is expected to be static, but the sequence lenghts
are not provided. This only applies to opsets below 9.
"""
op = operator.raw_operator

_, seq_length, input_size = simplernn.extract_input_shape(op)
is_static_shape = seq_length is not None
if not is_static_shape and container.target_opset < 9:
raise ValueError('None seq_length is not supported in opset ' + str(container.target_opset))


def convert_ifco_to_iofc(tensor_ifco):
"""Returns a tensor in input (i), output (o), forget (f), cell (c) ordering. The
Keras ordering is ifco, while the ONNX ordering is iofc.
@@ -120,28 +110,43 @@ def build_parameters(scope, operator, container, bidirectional=False):
return tensor_w, tensor_r, tensor_b

def build_initial_states(scope, operator, container, bidirectional=False):
"""
"""Builds the initial hidden and cell states for the LSTM layer.
"""
_name = name_func(scope, operator)

initial_h = ''
initial_c = ''
initial_h = simplernn.build_initial_states(scope, operator, container, bidirectional)

# Determine if the cell states are set
has_c = (
(len(operator.inputs) > 1 and not bidirectional) or
(len(operator.inputs) > 3 and bidirectional)
)
if not has_c:
return initial_h, ''

op = operator.raw_operator
initial_c = _name('initial_c')

if bidirectional:
if len(operator.inputs) > 1:
# TODO: Add support for inputing initial states for Bidirectional LSTM
raise NotImplemented("Initial states for Bidirectional LSTM is not yet supported")
else:
forward_layer = op.forward_layer
hidden_size = forward_layer.units
desired_shape = [1, -1, hidden_size]

# Combine the forward and backward layers
forward_h = _name('initial_c_forward')
backward_h = _name('initial_c_backward')
apply_reshape(scope, operator.inputs[2].full_name, forward_h, container, desired_shape=desired_shape)
apply_reshape(scope, operator.inputs[4].full_name, backward_h, container, desired_shape=desired_shape)

initial_h = simplernn.build_initial_states(scope, operator, container)
apply_concat(scope, [forward_h, backward_h], initial_c, container)

else:
hidden_size = operator.raw_operator.units
desired_shape = [1, -1, hidden_size]

if len(operator.inputs) > 1:
# Add a reshape after initial_h, 2d -> 3d
hidden_size = operator.raw_operator.units
input_c = operator.inputs[2].full_name
initial_c = _name('initial_c')
apply_reshape(scope, input_c, initial_c, container,
desired_shape=[1, -1, hidden_size])
# Add a reshape after initial_c, 2d -> 3d
input_c = operator.inputs[2].full_name
apply_reshape(scope, input_c, initial_c, container, desired_shape=desired_shape)

return initial_h, initial_c

@@ -179,7 +184,7 @@ def build_attributes(scope, operator, container, bidirectional=False):
return attrs

def build_output(scope, operator, container, output_names, bidirectional=False):
"""
"""Builds the output operators for the LSTM layer.
"""
if bidirectional:
return simplernn.build_output(scope, operator, container, output_names[:-1], bidirectional)
@@ -189,7 +194,6 @@ def build_output(scope, operator, container, output_names, bidirectional=False):
op = operator.raw_operator
hidden_size = op.units
output_seq = op.return_sequences
output_state = op.return_state
_, seq_length, input_size = simplernn.extract_input_shape(op)
is_static_shape = seq_length is not None

@@ -229,10 +233,44 @@
else:
apply_reshape(scope, lstm_h, output_name, container, desired_shape=[-1, hidden_size])

if output_state:
# Output hidden and cell states
apply_reshape(scope, lstm_h, operator.outputs[1].full_name, container, desired_shape=[-1, hidden_size])
apply_reshape(scope, lstm_c, operator.outputs[2].full_name, container, desired_shape=[-1, hidden_size])

def build_output_states(scope, operator, container, output_names, bidirectional=False):
"""Builds the output hidden states for the LSTM layer.
"""
_, lstm_h, lstm_c = output_names
op = operator.raw_operator

if bidirectional:
forward_layer = op.forward_layer
output_state = forward_layer.return_state

if not output_state:
return

# Split lstm_h and lstm_c into forward and backward components
squeeze_names = []
output_names = [o.full_name for o in operator.outputs[1:]]
name_map = {lstm_h: output_names[::2], lstm_c: output_names[1::2]}

for state_name, outputs in name_map.items():
split_names = ['{}_{}'.format(state_name, d) for d in ('forward', 'backward')]

apply_split(scope, state_name, split_names, container)
squeeze_names.extend(list(zip(split_names, outputs)))

for split_name, output_name in squeeze_names:
apply_squeeze(scope, split_name, output_name, container)

else:
output_state = op.return_state

if not output_state:
return

output_h = operator.outputs[1].full_name
output_c = operator.outputs[2].full_name
apply_squeeze(scope, lstm_h, output_h, container)
apply_squeeze(scope, lstm_c, output_c, container)


def _calculate_keras_lstm_output_shapes(operator):
Expand All @@ -254,8 +292,6 @@ def convert_keras_lstm(scope, operator, container, bidirectional=False):
else:
output_seq = op.return_sequences

check_sequence_lengths(operator, container)

# Inputs
lstm_x = _name('X')
tensor_w, tensor_r, tensor_b = build_parameters(scope, operator, container, bidirectional)
@@ -292,3 +328,4 @@ def convert_keras_lstm(scope, operator, container, bidirectional=False):
**attrs)

build_output(scope, operator, container, output_names, bidirectional)
build_output_states(scope, operator, container, output_names, bidirectional)
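To make the new build_output_states concrete, a shape-only numpy sketch (an illustration, not code from this PR) of how the ONNX Y_h/Y_c state tensors are split per direction and squeezed into the four Keras state outputs:

import numpy as np

num_directions, batch, hidden = 2, 3, 8
Y_h = np.random.rand(num_directions, batch, hidden)  # ONNX LSTM hidden states
Y_c = np.random.rand(num_directions, batch, hidden)  # ONNX LSTM cell states

# Equivalent of apply_split followed by apply_squeeze on axis 0.
fw_h, bw_h = (np.squeeze(p, axis=0) for p in np.split(Y_h, 2, axis=0))
fw_c, bw_c = (np.squeeze(p, axis=0) for p in np.split(Y_c, 2, axis=0))

# Interleaved per direction to match operator.outputs[1:].
keras_states = [fw_h, fw_c, bw_h, bw_c]
assert all(s.shape == (batch, hidden) for s in keras_states)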
127 changes: 97 additions & 30 deletions keras2onnx/ke2onnx/simplernn.py
@@ -8,6 +8,7 @@
from ..common import name_func
from ..common.onnx_ops import (
apply_cast,
apply_concat,
apply_identity,
apply_reshape,
apply_slice,
@@ -170,12 +171,39 @@ def build_initial_states(scope, operator, container, bidirectional=False):
if len(operator.inputs) == 1:
return ''

# Add a reshape after initial_h, 2d -> 3d
hidden_size = operator.raw_operator.units
input_h = operator.inputs[1].full_name
initial_h = scope.get_unique_variable_name(operator.full_name + '_initial_h')
desired_shape = [2, -1, hidden_size] if bidirectional else [1, -1, hidden_size]
apply_reshape(scope, input_h, initial_h, container, desired_shape=desired_shape)
op = operator.raw_operator
_name = name_func(scope, operator)

initial_h = _name('initial_h')

if bidirectional:
forward_layer = op.forward_layer
hidden_size = forward_layer.units
desired_shape = [1, -1, hidden_size]

# Combine the forward and backward layers
forward_h = _name('initial_h_forward')
backward_h = _name('initial_h_backward')

# Handle LSTM initial hidden case to enable code reuse
if len(operator.inputs) > 4:
f, b = 1, 3
else:
f, b = 1, 2

apply_reshape(scope, operator.inputs[f].full_name, forward_h, container, desired_shape=desired_shape)
apply_reshape(scope, operator.inputs[b].full_name, backward_h, container, desired_shape=desired_shape)

apply_concat(scope, [forward_h, backward_h], initial_h, container)

else:
hidden_size = operator.raw_operator.units
desired_shape = [1, -1, hidden_size]

# Add a reshape after initial_h, 2d -> 3d
input_h = operator.inputs[1].full_name
apply_reshape(scope, input_h, initial_h, container, desired_shape=desired_shape)

return initial_h


@@ -220,28 +248,25 @@ def build_output(scope, operator, container, output_names, bidirectional=False):

op = operator.raw_operator
_, seq_length, input_size = extract_input_shape(op)
is_static_shape = seq_length is not None

_name = name_func(scope, operator)

oopb = OnnxOperatorBuilder(container, scope)

# Define seq_dim
if not is_static_shape:
input_name = operator.inputs[0].full_name
input_shape_tensor = oopb.add_node('Shape', [input_name], input_name + '_input_shape_tensor')

seq_dim = input_name + '_seq_dim'
apply_slice(scope, input_shape_tensor, seq_dim, container, [1], [2], axes=[0])

if bidirectional:
forward_layer = op.forward_layer

hidden_size = forward_layer.units
is_static_shape = seq_length is not None
output_seq = forward_layer.return_sequences
output_state = forward_layer.return_state
if output_state:
raise ValueError('Keras Bidirectional cannot return hidden and cell states')

oopb = OnnxOperatorBuilder(container, scope)

# Define seq_dim
if not is_static_shape:
input_name = operator.inputs[0].full_name
input_shape_tensor = oopb.add_node('Shape', [input_name], input_name + '_input_shape_tensor')

seq_dim = input_name + '_seq_dim'
apply_slice(scope, input_shape_tensor, seq_dim, container, [1], [2], axes=[0])

merge_concat = False
if hasattr(op, 'merge_mode'):
@@ -377,23 +402,64 @@ def build_output(scope, operator, container, output_names, bidirectional=False):
else:
hidden_size = op.units
output_seq = op.return_sequences
output_state = op.return_state

output_name = operator.outputs[0].full_name
tranposed_y = scope.get_unique_variable_name(operator.full_name + '_y_transposed')
transposed_y = scope.get_unique_variable_name(operator.full_name + '_y_transposed')

# Determine the source, transpose permutation, and output shape
if output_seq:
perm = [1, 0, 2] if container.target_opset <= 5 else [2, 0, 1, 3]
apply_transpose(scope, rnn_y, tranposed_y, container, perm=perm)
apply_reshape(scope, tranposed_y, output_name, container,
desired_shape=[-1, seq_length, hidden_size])
source = rnn_y
perm = [2, 0, 1, 3]
if is_static_shape:
desired_shape = [-1, seq_length, hidden_size]
elif container.target_opset < 5:
# Before Reshape-5 the sequence dimension cannot be passed in as an input
raise ValueError('At least opset 5 is required for output sequences')
else:
# Dynamically determine the output shape based on the sequence dimension
shape_values = [
('_a', oopb.int64, np.array([-1], dtype='int64')),
seq_dim,
('_b', oopb.int64, np.array([hidden_size], dtype='int64')),
]
shape_name = _name('_output_seq_shape')
desired_shape = oopb.add_node('Concat', shape_values, shape_name, axis=0)
else:
# Here we ingore ONNX RNN's first output because it's useless.
apply_transpose(scope, rnn_h, tranposed_y, container, perm=[1, 0, 2])
apply_reshape(scope, tranposed_y, output_name, container, desired_shape=[-1, hidden_size])
# Use the last hidden states directly
source = rnn_h
perm = [1, 0, 2]
desired_shape = [-1, hidden_size]

apply_transpose(scope, source, transposed_y, container, perm=perm)
apply_reshape(scope, transposed_y, output_name, container, desired_shape=desired_shape)


def build_output_states(scope, operator, container, output_names, bidirectional=False):
"""Builds the output hidden states for the RNN layer.
"""
_, rnn_h = output_names
op = operator.raw_operator

if bidirectional:
forward_layer = op.forward_layer
output_state = forward_layer.return_state

if output_state:
# Split rnn_h into forward and backward directions
output_names = [o.full_name for o in operator.outputs[1:]]
split_names = ['{}_{}'.format(rnn_h, d) for d in ('forward', 'backward')]

apply_split(scope, rnn_h, split_names, container)

for split_name, output_name in zip(split_names, output_names):
apply_squeeze(scope, split_name, output_name, container)

else:
output_state = op.return_state

if output_state:
apply_reshape(scope, rnn_h, operator.outputs[1].full_name, container, desired_shape=[-1, hidden_size])
output_h = operator.outputs[1].full_name
apply_squeeze(scope, rnn_h, output_h, container)


def convert_keras_simple_rnn(scope, operator, container, bidirectional=False):
@@ -440,3 +506,4 @@ def convert_keras_simple_rnn(scope, operator, container, bidirectional=False):
**attrs)

build_output(scope, operator, container, output_names, bidirectional)
build_output_states(scope, operator, container, output_names, bidirectional)
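Likewise, a shape-only numpy sketch (illustrative, not from the PR) of what the bidirectional branch of build_initial_states assembles: each 2-D Keras state is reshaped to [1, batch, hidden] and the two directions are concatenated into the [num_directions, batch, hidden] initial_h that the ONNX recurrent ops expect.

import numpy as np

batch, hidden = 3, 8
fw_h = np.random.rand(batch, hidden)  # forward initial state from Keras
bw_h = np.random.rand(batch, hidden)  # backward initial state from Keras

# Equivalent of the two apply_reshape calls followed by apply_concat.
initial_h = np.concatenate(
    [fw_h.reshape(1, -1, hidden), bw_h.reshape(1, -1, hidden)], axis=0)
assert initial_h.shape == (2, batch, hidden)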
12 changes: 12 additions & 0 deletions keras2onnx/parser.py
@@ -20,6 +20,9 @@
list_input_tensors, list_input_mask, list_output_mask,
list_output_tensors, list_input_shapes, list_output_shapes, on_parsing_keras_layer)

ALLOWED_SHARED_KERAS_TYPES = {
keras.layers.embeddings.Embedding,
}

def _find_node(nodes, name):
try:
@@ -570,6 +573,7 @@ def _parse_graph_core(graph, keras_node_dict, topology, top_scope, output_names)
for n_ in model_outputs:
q_overall.put_nowait(n_)

visited_layers = set()
visited = set() # since the output could be shared among the successor nodes.
inference_nodeset = _build_inference_nodeset(graph, model_outputs)
keras_nodeset = _build_keras_nodeset(inference_nodeset, keras_node_dict)
@@ -581,6 +585,14 @@ def _parse_graph_core(graph, keras_node_dict, topology, top_scope, output_names)
nodes = []
layer_key_, model_ = _parse_nodes(graph, inference_nodeset, input_nodes, keras_node_dict, keras_nodeset,
node, nodes, varset, visited, q_overall)

# Only parse Keras layers once (allow certain shared classes)
if layer_key_ in visited_layers:
if not type(layer_key_) in ALLOWED_SHARED_KERAS_TYPES:
continue
else:
visited_layers.add(layer_key_)

if not nodes: # already processed by the _parse_nodes
continue

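The parser change skips layers that have already been parsed while still allowing genuinely shared layer types such as Embedding. A hedged sketch (tf.keras, illustrative names) of the kind of model that exercises this path:

from tensorflow import keras

emb = keras.layers.Embedding(input_dim=1000, output_dim=16)  # one shared instance
a = keras.Input((None,), dtype='int32')
b = keras.Input((None,), dtype='int32')
# The same Embedding layer is reached twice during parsing; because Embedding is
# in ALLOWED_SHARED_KERAS_TYPES, it is not skipped the second time.
model = keras.Model([a, b], [emb(a), emb(b)])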