Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[wip] [v1.x] Onnx export support for slicechannel and box_nms (#19846)
Browse files Browse the repository at this point in the history
* slicechannel

* add warning to box_nms an unblock when id_idex!=-1

* Update _op_translations.py

* add 0,1
  • Loading branch information
Zha0q1 authored Feb 6, 2021
1 parent 12b297e commit da887f1
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 29 deletions.
57 changes: 28 additions & 29 deletions python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1778,44 +1778,39 @@ def convert_slice_axis(node, **kwargs):
return nodes


@mx_op.register("SliceChannel")
@mx_op.register('SliceChannel')
def convert_slice_channel(node, **kwargs):
"""Map MXNet's SliceChannel operator attributes to onnx's Squeeze or Split
operator based on squeeze_axis attribute
and return the created node.
"""
from onnx.helper import make_node
name, input_nodes, attrs = get_inputs(node, kwargs)

opset_version = kwargs['opset_version']
if opset_version < 11:
raise AttributeError('ONNX opset 11 or greater is required to export this operator')
num_outputs = int(attrs.get('num_outputs'))
axis = int(attrs.get('axis', 1))
squeeze_axis = attrs.get('squeeze_axis', 'False')

num_outputs = int(attrs.get("num_outputs"))
axis = int(attrs.get("axis", 1))
squeeze_axis = int(attrs.get("squeeze_axis", 0) in [1, 'True'])
create_tensor([axis], name+'_axis', kwargs['initializer'])
create_tensor([axis+1], name+'axis_p1', kwargs['initializer'])

if squeeze_axis == 1 and num_outputs == 1:
node = onnx.helper.make_node(
"Squeeze",
input_nodes,
[name],
axes=[axis],
name=name,
)
return [node]
elif squeeze_axis == 0 and num_outputs > 1:
node = onnx.helper.make_node(
"Split",
input_nodes,
[name+str(i) for i in range(num_outputs)],
axis=axis,
name=name
)
return [node]
nodes = []
if squeeze_axis in ['True', '1']:
nodes += [
make_node('Split', [input_nodes[0]], [name+str(i)+'_' for i in range(num_outputs)],
axis=axis)
]
for i in range(num_outputs):
nodes += [
make_node('Squeeze', [name+str(i)+'_'], [name+str(i)], axes=[axis])
]
else:
raise NotImplementedError("SliceChannel operator with num_outputs>1 and"
"squeeze_axis true is not implemented.")
nodes += [
make_node('Split', [input_nodes[0]], [name+str(i) for i in range(num_outputs)],
axis=axis)
]

return nodes

@mx_op.register("expand_dims")
def convert_expand_dims(node, **kwargs):
Expand Down Expand Up @@ -3089,6 +3084,7 @@ def convert_contrib_box_nms(node, **kwargs):
coord_start = int(attrs.get('coord_start', '2'))
score_index = int(attrs.get('score_index', '1'))
id_index = int(attrs.get('id_index', '-1'))
force_suppress = attrs.get('force_suppress', 'True')
background_id = int(attrs.get('background_id', '-1'))
in_format = attrs.get('in_format', 'corner')
out_format = attrs.get('out_format', 'corner')
Expand All @@ -3101,8 +3097,11 @@ def convert_contrib_box_nms(node, **kwargs):
if background_id != -1:
raise NotImplementedError('box_nms does not currently support background_id != -1')

if id_index != -1:
raise NotImplementedError('box_nms does not currently support id_index != -1')
if id_index != -1 or force_suppress == 'False':
logging.warning('box_nms: id_idex != -1 or/and force_suppress == False detected. '
'However, due to ONNX limitations, boxes of different categories will NOT '
'be exempted from suppression. This might lead to different behavior than '
'native MXNet')

nodes = [
create_tensor([coord_start], name+'_cs', kwargs['initializer']),
Expand Down
13 changes: 13 additions & 0 deletions tests/python-pytest/onnx/test_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -929,6 +929,19 @@ def test_onnx_export_convolution(tmp_path, dtype, shape, num_filter, num_group,
op_export_test('convolution', M, inputs, tmp_path)


@pytest.mark.parametrize('dtype', ['float16', 'float32'])
@pytest.mark.parametrize('num_outputs', [1, 3, 9])
@pytest.mark.parametrize('axis', [1, 2, -1, -2])
@pytest.mark.parametrize('squeeze_axis', [True, False, 0, 1])
def test_onnx_export_slice_channel(tmp_path, dtype, num_outputs, axis, squeeze_axis):
shape = (3, 9, 18)
if squeeze_axis and shape[axis] != num_outputs:
return
M = def_model('SliceChannel', num_outputs=num_outputs, axis=axis, squeeze_axis=squeeze_axis)
x = mx.random.uniform(0, 1, shape, dtype=dtype)
op_export_test('slice_channel', M, [x], tmp_path)


@pytest.mark.parametrize('dtype', ['float32', 'float64'])
@pytest.mark.parametrize('momentum', [0.9, 0.5, 0.1])
def test_onnx_export_batchnorm(tmp_path, dtype, momentum):
Expand Down

0 comments on commit da887f1

Please sign in to comment.