Skip to content
This repository has been archived by the owner on Oct 13, 2021. It is now read-only.

fixing the conv auto-pads #397

Merged
merged 3 commits into from
Feb 27, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ pip install git+https://github.com/microsoft/onnxconverter-common
pip install git+https://github.com/onnx/keras-onnx
```
Before running the converter, please notice that tensorflow has to be installed in your python environment,
you can choose **tensorflow** package(CPU version) or **tensorflow-gpu**(GPU version)
you can choose **tensorflow**/**tensorflow-cpu** package(CPU version) or **tensorflow-gpu**(GPU version)

# Notes
Keras2ONNX supports the new Keras subclassing model which was introduced in tensorflow 2.0 since the version **1.6.5**. Some typical subclassing models like [huggingface/transformers](https://github.com/huggingface/transformers) have been converted into ONNX and validated by ONNXRuntime.<br>
Expand Down
3 changes: 2 additions & 1 deletion keras2onnx/_builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import numpy as np
from typing import Union
from onnx import numpy_helper, mapping
from .common.utils import count_dynamic_dim
from .common.onnx_ops import apply_identity, apply_reshape, OnnxOperatorBuilder
from .funcbook import converter_func, set_converters
from .proto import keras
Expand Down Expand Up @@ -492,7 +493,7 @@ def convert_tf_depthwise_conv2d(scope, operator, container):
if node.get_attr('padding') == b'VALID':
attrs['auto_pad'] = 'VALID'
elif node.get_attr('padding') == b'SAME':
if input_shape.count(None) > 1:
if count_dynamic_dim(input_shape) > 1:
attrs['auto_pad'] = 'SAME_UPPER'
else:
attrs['auto_pad'] = 'NOTSET'
Expand Down
8 changes: 8 additions & 0 deletions keras2onnx/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,14 @@ def get_default_batch_size():
return 'N'


def count_dynamic_dim(shape):
num = 0
for s_ in shape:
if isinstance(s_, int) and s_ >= 0:
num += 1
return len(shape) - num


def get_producer():
"""
Internal helper function to return the producer
Expand Down
3 changes: 2 additions & 1 deletion keras2onnx/ke2onnx/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .activation import activation_map
from ..proto import keras
from ..proto import onnx_proto
from ..common.utils import count_dynamic_dim
from ..common.onnx_ops import (apply_identity, apply_pad, apply_softmax,
apply_transpose, apply_mul, apply_sigmoid)

Expand Down Expand Up @@ -144,7 +145,7 @@ def convert_keras_conv_core(scope, operator, container, is_transpose, n_dims, in
if op.padding == 'valid':
attrs['auto_pad'] = 'VALID'
elif op.padding == 'same':
if input_shape.count(None) > 1:
if count_dynamic_dim(input_shape) > 1:
if is_transpose:
attrs['auto_pad'] = 'SAME_LOWER' # the controversial def in onnx spec.
else:
Expand Down