Skip to content
This repository has been archived by the owner on Oct 13, 2021. It is now read-only.

add tf.where and logical ops supports. #490

Merged
merged 7 commits into from
May 16, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 13 additions & 16 deletions keras2onnx/_builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -639,16 +639,15 @@ def convert_tf_conv2d(scope, operator, container):

@converter_func(TYPES.Einsum)
def convert_tf_einsum(scope, operator, container):
if operator.target_opset < 12:
raise ValueError("Einsum op is not supported until opset 12")
oopb = OnnxOperatorBuilder(container, scope)
node = operator.raw_operator
equation_str = node.get_attr('equation').decode("utf-8")
oopb.add_node_with_output("Einsum",
operator.input_full_names,
operator.output_full_names,
name=operator.full_name,
equation=equation_str)
equation=equation_str,
op_version=12)


@converter_func(TYPES.ExpandDims)
Expand Down Expand Up @@ -2041,19 +2040,17 @@ def convert_tf_variable_v2(scope, operator, container):

@converter_func(TYPES.Where)
def convert_tf_where(scope, operator, container):
if operator.target_opset < 9:
raise ValueError("Where op is not supported for opset < 9")
else:
oopb = OnnxOperatorBuilder(container, scope)
node = operator.raw_operator
where_node = oopb.add_node('NonZero',
operator.inputs[0].full_name,
operator.inputs[0].full_name + '_non_zero')
oopb.apply_op_with_output("apply_transpose",
where_node,
operator.output_full_names,
name=operator.full_name + '_transpose',
perm=list(reversed(range(len(node.outputs[0].shape)))))
oopb = OnnxOperatorBuilder(container, scope)
node = operator.raw_operator
where_node = oopb.add_node('NonZero',
operator.inputs[0].full_name,
operator.inputs[0].full_name + '_non_zero',
op_version=9)
oopb.apply_op_with_output("apply_transpose",
where_node,
operator.output_full_names,
name=operator.full_name + '_transpose',
perm=list(reversed(range(len(node.outputs[0].shape)))))


@converter_func(TYPES.ZerosLike)
Expand Down
47 changes: 38 additions & 9 deletions keras2onnx/_tf_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from onnxconverter_common.oopb import OnnxOperatorBuilder
from .funcbook import converter_func
from ._tf_utils import tf_attrs_to_onnx as _to_onnx_attrs
from ._tf_utils import cal_tensor_shape as _cal_tensor_shape


def _random_converter(scope, operator, container):
Expand All @@ -28,24 +29,52 @@ def _random_converter(scope, operator, container):


@converter_func(
"RandomNormal",
'RandomNormal',
'RandomStandardNormal',
"RandomUniform")
'RandomUniform')
def convert_tf_random_standard_normal(scope, operator, container):
_random_converter(scope, operator, container)


@converter_func('Select', 'SelectV2')
def convert_tf_select(scope, operator, container):
tf_op = operator.raw_operator
shape_i0 = _cal_tensor_shape(tf_op.inputs[0])
target_shape = _cal_tensor_shape(tf_op.inputs[1])
if len(target_shape) == 0:
target_shape = _cal_tensor_shape(tf_op.inputs[2])
input0 = operator.input_full_names[0]
with OnnxOperatorBuilder(container, scope).as_default(operator.full_name) as oopb: # type: OnnxOperatorBuilder
if len(shape_i0) == 1 and len(target_shape) > 1:
input0 = oopb.unsqueeze(input0, axes=list(range(len(target_shape)))[1:])
oopb.add_node("Where", [input0] + operator.input_full_names[1:],
outputs=operator.output_full_names,
op_version=9)


@converter_func('LogicalNot', 'LogicalAnd', 'LogicalOr')
def convert_tf_logical_ops(scope, operator, container):
onnx_type = operator.type[len('Logical'):]
oopb = OnnxOperatorBuilder(container, scope)
oopb.add_node(onnx_type,
operator.input_full_names,
name=operator.full_name,
outputs=operator.output_full_names,
op_version=1)


def pass_thru_converter(scope, operator, container):
"""
This converter is to copy the original graph node with its def into a ONNX node format.
"""
tf_op = operator.raw_operator
attrs = _to_onnx_attrs(tf_op)

container.add_node(operator.type,
operator.input_full_names,
operator.output_full_names,
name=operator.full_name,
op_domain='ai.onnx.contrib',
op_version=1,
**attrs)
oopb = OnnxOperatorBuilder(container, scope)
oopb.add_node(operator.type,
operator.input_full_names,
name=operator.full_name,
outputs=operator.output_full_names,
op_domain='ai.onnx.contrib',
op_version=1,
**attrs)
9 changes: 7 additions & 2 deletions keras2onnx/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,13 @@ def convert_keras(model, name=None, doc_string='', target_opset=None,
print(model.summary())

name = name or model.name
target_opset = target_opset or get_maximum_opset_supported()

cvt_default_opset = get_maximum_opset_supported()
if target_opset is None:
target_opset = cvt_default_opset
elif target_opset > cvt_default_opset:
raise RuntimeError(
"The opset {} conversion not support yet, the current maximum opset version supported is {}.".format(
target_opset, cvt_default_opset))
input_names = []
output_names = []
output_dict = {}
Expand Down
9 changes: 6 additions & 3 deletions keras2onnx/topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# license information.
###############################################################################
from onnxconverter_common.onnx_ex import make_model_ex
from .common import k2o_logger
from .common import utils, k2o_logger
from .common import OnnxObjectContainer, Variable, InterimContext
from .common.data_types import TensorType, Int64Type, FloatType, StringType
from .funcbook import get_converter
Expand Down Expand Up @@ -227,7 +227,8 @@ def _remove_unused_nodes(nodes, inputs, outputs):
if in_ in output_dict:
node_inputs.append(output_dict[in_])
else:
assert in_ == '' or in_ in input_dict
assert in_ == '' or in_ in input_dict, \
"{} is disconnected, check the parsing log for more details.".format(in_)

return [nd_ for nd_ in nodes if id(nd_) in nodes_to_keep]

Expand Down Expand Up @@ -375,5 +376,7 @@ def convert_topology(topology, model_name, doc_string, target_opset, channel_fir
# Create model
onnx_model = make_model_ex(graph,
container.node_domain_version_pair_sets,
target_opset, doc_string=doc_string)
target_opset, doc_string=doc_string,
producer_name=utils.get_producer(),
domain=utils.get_domain())
return onnx_model
13 changes: 9 additions & 4 deletions tests/test_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1298,6 +1298,8 @@ def test_LeakyReLU(advanced_activation_runner):
advanced_activation_runner(layer, data)


@pytest.mark.skipif(get_maximum_opset_supported() < 8,
reason="ThresoldRelu needs ONNX opset 8")
def test_ThresholdedReLU(advanced_activation_runner):
data = _asarray(-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5)
layer = advanced_activations.ThresholdedReLU(theta=1.0, input_shape=(data.size,))
Expand Down Expand Up @@ -1646,35 +1648,38 @@ def test_LSTM(runner):
expected = model.predict(data)
assert runner(onnx_model.graph.name, onnx_model, data, expected)


@pytest.mark.skipif((is_tensorflow_older_than('1.14.0') or (not is_tf_keras)),
reason="keras LSTM does not have time_major attribute")
def test_LSTM_time_major_return_seq_true(runner):
inputs1 = keras.Input(shape=(3, 5))
data = np.random.rand(1, 3, 5).astype(np.float32)
# Transpose input to be time major
input_transposed = tf.transpose(inputs1, perm=[1,0,2])
input_transposed = tf.transpose(inputs1, perm=[1, 0, 2])
lstm1, state_h, state_c = LSTM(units=2, time_major=True, return_state=True,
return_sequences=True)(input_transposed)
lstm1_trans = tf.transpose(lstm1, perm=[1,0,2])
lstm1_trans = tf.transpose(lstm1, perm=[1, 0, 2])
model = keras.Model(inputs=inputs1, outputs=[lstm1_trans, state_h, state_c])
onnx_model = keras2onnx.convert_keras(model, model.name)
expected = model.predict(data)
assert runner(onnx_model.graph.name, onnx_model, data, expected)

@pytest.mark.skipif((is_tensorflow_older_than('1.14.0') or (not is_tf_keras)) ,

@pytest.mark.skipif((is_tensorflow_older_than('1.14.0') or (not is_tf_keras)),
reason="keras LSTM does not have time_major attribute")
def test_LSTM_time_major_return_seq_false(runner):
inputs1 = keras.Input(shape=(3, 5))
data = np.random.rand(1, 3, 5).astype(np.float32)
# Transpose input to be time major
input_transposed = tf.transpose(inputs1, perm=[1,0,2])
input_transposed = tf.transpose(inputs1, perm=[1, 0, 2])
lstm1, state_h, state_c = LSTM(units=2, time_major=True, return_state=True,
return_sequences=False)(input_transposed)
model = keras.Model(inputs=inputs1, outputs=[lstm1, state_h, state_c])
onnx_model = keras2onnx.convert_keras(model, model.name)
expected = model.predict(data)
assert runner(onnx_model.graph.name, onnx_model, data, expected)


def test_LSTM_with_bias(runner):
inputs1 = keras.Input(shape=(1, 1))
cls = LSTM(units=1, return_state=True, return_sequences=True)
Expand Down
22 changes: 19 additions & 3 deletions tests/test_tf2_keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import keras2onnx
import numpy as np
import tensorflow as tf
from keras2onnx.proto import is_tensorflow_older_than

if (not keras2onnx.proto.is_tf_keras) or (not keras2onnx.proto.tfcompat.is_tf2):
pytest.skip("Tensorflow 2.0 only tests.", allow_module_level=True)
Expand Down Expand Up @@ -50,9 +51,9 @@ def call(self, inputs, **kwargs):
return output


class DummyModel(tf.keras.Model):
class SimpleWrapperModel(tf.keras.Model):
def __init__(self, func):
super(DummyModel, self).__init__()
super(SimpleWrapperModel, self).__init__()
self.func = func

def call(self, inputs, **kwargs):
Expand Down Expand Up @@ -88,7 +89,7 @@ def op_func(arg_inputs):
x = x - tf.cast(tf.expand_dims(r, axis=0), tf.float32)
return x

dm = DummyModel(op_func)
dm = SimpleWrapperModel(op_func)
inputs = [tf.random.normal((3, 2, 20)), tf.random.normal((3, 2, 20))]
expected = dm.predict(inputs)
oxml = keras2onnx.convert_keras(dm)
Expand Down Expand Up @@ -186,3 +187,18 @@ def test_auto_encoder(runner):
# The random generator is not same between different engiens.
import onnx
onnx.checker.check_model(oxml)


def test_tf_where(runner):
def _tf_where(input_0):
a = tf.where(True, input_0, [0, 1, 2, 5, 7])
b = tf.where([True], tf.expand_dims(input_0, axis=0), tf.expand_dims([0, 1, 2, 5, 7], axis=0))
c = tf.logical_or(tf.cast(a, tf.bool), tf.cast(b, tf.bool))
return c

swm = SimpleWrapperModel(_tf_where)
const_in = [np.array([2, 4, 6, 8, 10]).astype(np.int32)]
expected = swm(const_in)
swm._set_inputs(const_in)
oxml = keras2onnx.convert_keras(swm)
assert runner('where_test', oxml, const_in, expected)