Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[v1.x] ONNX Supoort for MXNet _contrib_BilinearResize2D op (#19733)
Browse files Browse the repository at this point in the history
* initial

* resize

* _contrib_BilinearResize2D

* restore sanity
  • Loading branch information
Zha0q1 authored Jan 10, 2021
1 parent c9f111f commit 209a789
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 0 deletions.
57 changes: 57 additions & 0 deletions python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2698,6 +2698,63 @@ def convert_arange_like(node, **kwargs):

return nodes


@mx_op.register("_contrib_BilinearResize2D")
def convert_contrib_BilinearResize2D(node, **kwargs):
"""Map MXNet's contrib_BilinearResize2D operator attributes to onnx.
"""
from onnx.helper import make_node
from onnx import TensorProto
name, input_nodes, attrs = get_inputs(node, kwargs)

opset_version = kwargs['opset_version']
if opset_version < 11:
raise AttributeError("ONNX opset 11 or greater is required to export this operator")

height = int(attrs.get('height', 0))
width = int(attrs.get('width', 0))

scale_height = float(attrs.get('scale_height', 0))
scale_width = float(attrs.get('scale_width', 0))

if height * width == 0 and scale_height * scale_width == 0:
raise AttributeError('height, width or scale_height, scale_width cannot be 0')

mode = attrs.get('mode', 'size')
if mode != 'size':
raise NotImplementedError('contrib_BilinearResize2D with mode other than "size" is \
not supported')

nodes = [
create_tensor([], name+'_roi', kwargs['initializer'], dtype='float32'),
]

if scale_height == 0:
nodes += [
create_tensor([0], name+'_0', kwargs['initializer']),
create_tensor([2], name+'_2', kwargs['initializer']),
create_tensor([height, width], name+'_h_w', kwargs['initializer'], dtype='int64'),
make_node('Shape', [input_nodes[0]], [name+'_shape']),
make_node('Slice', [name+'_shape', name+'_0', name+'_2'], [name+'_shape_01']),
make_node('Concat', [name+'_shape_01', name+'_h_w'], [name+'_new_shape'], axis=0),
make_node('Cast', [name+'_shape'], [name+'_shape_f'], to=int(TensorProto.FLOAT)),
make_node('Cast', [name+'_new_shape'], [name+'_new_shape_f'],
to=int(TensorProto.FLOAT)),
make_node('Div', [name+'_new_shape_f', name+'_shape_f'], [name+'_scales']),
make_node('Resize', [input_nodes[0], name+'_roi', name+'_scales'], [name],
mode='linear', coordinate_transformation_mode='align_corners', name=name)
]
else:
nodes += [
create_tensor([1, 1, scale_height, scale_width], name+'_scales', kwargs['initializer'],
dtype='float32'),
make_node('Resize', [input_nodes[0], name+'_roi', name+'_scales'], [name],
mode='linear', coordinate_transformation_mode='align_corners', name=name)
]

return nodes


@mx_op.register("_arange")
def convert_arange(node, **kwargs):
"""Map MXNet's arange operator attributes to onnx's Range operator.
Expand Down
19 changes: 19 additions & 0 deletions tests/python-pytest/onnx/test_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,3 +351,22 @@ def test_onnx_export_softmax(tmp_path, dtype):
M4 = def_model('softmax', use_length=True, axis=1)
l4 = mx.nd.array([[2,0,3,1],[0,1,0,0]], dtype=int)
op_export_test('softmax_4', M4, [x, l4], tmp_path)


@pytest.mark.parametrize('dtype', ['float16', 'float32', 'float64', 'int32', 'int64'])
@pytest.mark.parametrize('params', [{'height': 7, 'width': 13},
{'height': 10, 'width': 16},
{'height': 3, 'width': 5},
{'height': 2, 'width': 4},
{'scale_height': 3, 'scale_width': 2},
{'scale_height': 1.7, 'scale_width': 2.3},
{'scale_height': 0.5, 'scale_width': 0.6},
{'scale_height': 0.8, 'scale_width': 0.1},
{'scale_height': 2.5, 'scale_width': 0.5},
{'scale_height': 3, 'scale_width': 0.00001},
])
def test_onnx_export_contrib_BilinearResize2D(tmp_path, dtype, params):
x = mx.nd.arange(0, 160).reshape((2, 2, 5, 8))
M = def_model('contrib.BilinearResize2D', **params)
op_export_test('contrib_BilinearResize2D', M, [x], tmp_path)

0 comments on commit 209a789

Please sign in to comment.