
[v1.x] Onnx Support for Transformer #20048

Merged 7 commits on Mar 23, 2021
Changes from 4 commits
109 changes: 104 additions & 5 deletions python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
@@ -2064,14 +2064,80 @@ def convert_broadcast_lesser(node, **kwargs):
"""Map MXNet's broadcast_lesser operator attributes to onnx's Less operator
and return the created node.
"""
return create_basic_op_node('Less', node, kwargs)
from onnx.helper import make_node, make_tensor
name, input_nodes, _ = get_inputs(node, kwargs)
input_dtypes = get_input_dtypes(node, kwargs)

dtype = input_dtypes[0]
dtype_t = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[dtype]

nodes = [
make_node('Less', [input_nodes[0], input_nodes[1]], [name+'_lt']),
make_node('Cast', [name+'_lt'], [name], to=dtype_t)
]

return nodes


@mx_op.register("broadcast_lesser_equal")
def convert_broadcast_lesser(node, **kwargs):
"""Map MXNet's broadcast_lesser operator attributes to onnx's Less operator
and return the created node.
"""
from onnx.helper import make_node, make_tensor
name, input_nodes, _ = get_inputs(node, kwargs)
input_dtypes = get_input_dtypes(node, kwargs)

dtype = input_dtypes[0]
dtype_t = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[dtype]

nodes = [
make_node('LessOrEqual', [input_nodes[0], input_nodes[1]], [name+'_lt']),
make_node('Cast', [name+'_lt'], [name], to=dtype_t)
]

return nodes


@mx_op.register("broadcast_greater_equal")
def convert_broadcast_lesser(node, **kwargs):
"""Map MXNet's broadcast_lesser operator attributes to onnx's Less operator
Zha0q1 marked this conversation as resolved.
Show resolved Hide resolved
and return the created node.
"""
from onnx.helper import make_node, make_tensor
name, input_nodes, _ = get_inputs(node, kwargs)
input_dtypes = get_input_dtypes(node, kwargs)

dtype = input_dtypes[0]
dtype_t = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[dtype]

nodes = [
make_node('GreaterOrEqual', [input_nodes[0], input_nodes[1]], [name+'_lt']),
make_node('Cast', [name+'_lt'], [name], to=dtype_t)
]

return nodes


@mx_op.register("broadcast_greater")
def convert_broadcast_greater(node, **kwargs):
"""Map MXNet's broadcast_greater operator attributes to onnx's Greater operator
and return the created node.
"""
return create_basic_op_node('Greater', node, kwargs)
from onnx.helper import make_node, make_tensor
name, input_nodes, _ = get_inputs(node, kwargs)
input_dtypes = get_input_dtypes(node, kwargs)

dtype = input_dtypes[0]
dtype_t = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[dtype]

nodes = [
make_node('Greater', [input_nodes[0], input_nodes[1]], [name+'_gt']),
make_node('Cast', [name+'_gt'], [name], to=dtype_t)
]

return nodes
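
The four comparison converters above all follow one pattern: ONNX's Less/LessOrEqual/Greater/GreaterOrEqual operators produce bool tensors, while the corresponding MXNet broadcast_* operators return 0/1 values in the input dtype, hence the trailing Cast node in each converter. A minimal numpy sketch of what the Less + Cast pair computes, assuming float32 inputs:

import numpy as np

a = np.array([1.0, 2.0, 3.0], dtype=np.float32)
b = np.array([2.0, 2.0, 2.0], dtype=np.float32)

# ONNX Less yields a bool tensor; the Cast node restores the input dtype
out = np.less(a, b).astype(np.float32)
print(out)  # [1. 0. 0.]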


@mx_op.register("broadcast_equal")
def convert_broadcast_equal(node, **kwargs):
@@ -2498,7 +2564,6 @@ def convert_layer_norm(node, **kwargs):
    axes = int(attrs.get('axis', -1))
    eps = attrs.get('eps', 9.99999975e-06)

    create_tensor([axes], name+"_axes", kwargs["initializer"])
    create_tensor([axes+1], name+"_axes+1", kwargs["initializer"])
    create_const_scalar_node(name+'_0_s', np.int64(0), kwargs)
@@ -2519,7 +2584,11 @@ def convert_layer_norm(node, **kwargs):
    if axes == -1:
        nodes += [
            make_node("Mul", [name+"_div0_out", input_nodes[1]], [name+"_mul0_out"]),
            # make_node("Add", [name+"_mul0_out", input_nodes[2]], [name])
            # The Add operator above triggers a NaN issue in onnxruntime;
            # as a workaround we use Neg + Sub, since a - (-b) == a + b.
            make_node('Neg', [input_nodes[2]], [name+'_neg']),
            make_node("Sub", [name+"_mul0_out", name+'_neg'], [name])
        ]
    else:
        nodes += [
@@ -4301,7 +4370,7 @@ def convert_RNN(node, **kwargs):

@mx_op.register('_rnn_param_concat')
def convert_rnn_param_concat(node, **kwargs):
    """Map MXNet's _rnn_param_concat operator
    """
    from onnx.helper import make_node
    name, input_nodes, attrs = get_inputs(node, kwargs)
@@ -4313,3 +4382,33 @@ def convert_rnn_param_concat(node, **kwargs):
    ]

    return nodes


@mx_op.register('_contrib_div_sqrt_dim')
def convert_contrib_div_sqrt_dim(node, **kwargs):
    """Map MXNet's _contrib_div_sqrt_dim operator
    """
    from onnx.helper import make_node
    name, input_nodes, _ = get_inputs(node, kwargs)
    input_dtypes = get_input_dtypes(node, kwargs)

    dtype = input_dtypes[0]
    dtype_t = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[dtype]

    create_tensor([0], name+'_0', kwargs['initializer'])
    create_tensor([1], name+'_1', kwargs['initializer'])
    create_tensor([1], name+'_1_f', kwargs['initializer'], dtype=dtype)
    nodes = [
        make_node('Shape', [input_nodes[0]], [name+'_shape']),
        make_node('Shape', [name+'_shape'], [name+'_dim']),
        make_node('Sub', [name+'_dim', name+'_1'], [name+'_dim_m1']),
        make_node('Slice', [name+'_shape', name+'_dim_m1', name+'_dim', name+'_0'], [name+'_c_']),
        make_node('Cast', [name+'_c_'], [name+'_c'], to=dtype_t),
        make_node('Sqrt', [name+'_c'], [name+'_c_sqrt']),
        make_node('Div', [name+'_1_f', name+'_c_sqrt'], [name+'_1_over_c_sqrt']),
        make_node('Mul', [input_nodes[0], name+'_1_over_c_sqrt'], [name])
    ]

    return nodes
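
For reference, _contrib_div_sqrt_dim rescales its input by 1/sqrt(C), where C is the size of the last axis; the Shape/Sub/Slice chain above extracts C at runtime rather than baking it into the graph. A short numpy sketch of the expected result, assuming a float32 input:

import numpy as np

x = np.random.uniform(-1, 1, size=(7, 16, 512)).astype(np.float32)

# the ONNX graph computes x * (1 / sqrt(C)) with C taken from the last axis
c = x.shape[-1]
expected = x * (np.float32(1.0) / np.sqrt(np.float32(c)))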

167 changes: 167 additions & 0 deletions tests/python-pytest/onnx/test_onnxruntime.py
@@ -988,3 +988,170 @@ def test_ernie_inference_onnxruntime(tmp_path, model_name):

    finally:
        shutil.rmtree(tmp_path)


@with_seed()
@pytest.mark.parametrize('model_name', ['transformer_en_de_512'])
def test_transformer_pretrained_inference_onnxruntime(tmp_path, model_name):
    tmp_path = str(tmp_path)
    try:
        import gluonnlp as nlp
        dataset = 'WMT2014'
        ctx = mx.cpu(0)
        model, _, _ = nlp.model.get_model(
            name=model_name,
            ctx=ctx,
            pretrained=True,
            dataset_name=dataset)

        model.hybridize(static_alloc=False)

        batch = 7
        seq_length = 16
        C_in = 512
Contributor: What are C_in and C_out? Should we also test when C_in != C_out?

Contributor Author: You can refer to https://github.com/dmlc/gluon-nlp/blob/v0.10.x/src/gluonnlp/model/transformer.py for C_in and C_out. Those are defined in the pretrained model, so we need to set them to the same values as in the pretrained model.

        C_out = 512
        src = mx.nd.random.uniform(0, 36794, shape=(batch, seq_length), dtype='float32')
Contributor: Curious, doesn't src need to be an int type?

Contributor Author: No, it's float in the original MXNet model too. This shouldn't matter, I think, because the operator will apply ceiling/flooring.
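
A quick sketch of that point, as a hypothetical check assuming MXNet's Embedding truncates float indices as described above:

import mxnet as mx

# hypothetical check: float indices and their truncated integer
# counterparts should select the same embedding rows
embed = mx.gluon.nn.Embedding(input_dim=36794, output_dim=512)
embed.initialize()
idx_float = mx.nd.array([1.7, 3.2], dtype='float32')
idx_trunc = mx.nd.array([1.0, 3.0], dtype='float32')
assert (embed(idx_float).asnumpy() == embed(idx_trunc).asnumpy()).all()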

        step_input = mx.nd.random.uniform(0, 36794, shape=(batch,), dtype='float32')
        src_valid_length = mx.nd.array([seq_length] * batch, dtype='float32')

        encoder_outputs, encoder_additional_outputs = model.encode(src,
                                                                   valid_length=src_valid_length)

        decoder_states = model.decoder.init_state_from_encoder(encoder_outputs, src_valid_length)

        step_output, states, additional_outputs = model.decode_step(step_input, decoder_states)

        # skip export of 'decoder' as it's used for training only
        for component in ['encoder', 'one_step_ahead_decoder', 'src_embed', 'tgt_embed',
                          'tgt_proj']:
            prefix = "%s/%s" % (tmp_path, component)
            component = getattr(model, component)
            component.export(prefix)
            sym_file = "%s-symbol.json" % prefix
            params_file = "%s-0000.params" % prefix
            onnx_file = "%s.onnx" % prefix

        def export_to_onnx(prefix, input_shapes, input_types, **kwargs):
            sym_file = "%s-symbol.json" % prefix
            params_file = "%s-0000.params" % prefix
            onnx_file = "%s.onnx" % prefix
            return mx.contrib.onnx.export_model(sym_file, params_file, input_shapes, input_types,
                                                onnx_file, **kwargs)

        def onnx_runtime_predict(onnx_file, onnx_inputs):
            ses_opt = onnxruntime.SessionOptions()
            ses_opt.log_severity_level = 3
            session = onnxruntime.InferenceSession(onnx_file, ses_opt)
            input_dict = dict((session.get_inputs()[i].name, onnx_inputs[i].asnumpy())
                              for i in range(len(onnx_inputs)))
            return session.run(None, input_dict)

        def verify_encoder():
            inputs = mx.nd.random.uniform(-1, 1, shape=(batch, seq_length, C_in), dtype='float32')
            valid_length = mx.nd.array([seq_length] * batch, dtype='float32')
            pred = model.encoder(inputs, valid_length=valid_length)

            prefix = "%s/encoder" % tmp_path
            input_shapes = [(batch, seq_length, C_in), (batch,)]
            input_types = [np.float32, np.float32]
            onnx_file = export_to_onnx(prefix, input_shapes, input_types)
            onnx_inputs = [inputs, valid_length]
            pred_onx = onnx_runtime_predict(onnx_file, onnx_inputs)

            assert_almost_equal(pred[0], pred_onx[0])

        def verify_src_embed():
            src = mx.nd.random.uniform(0, 36794, shape=(batch, seq_length), dtype='float32')
            pred = model.src_embed(src)

            prefix = "%s/src_embed" % tmp_path
            input_shapes = [(batch, seq_length)]
            input_types = [np.float32]
            onnx_file = export_to_onnx(prefix, input_shapes, input_types)
            onnx_inputs = [src]
            pred_onx = onnx_runtime_predict(onnx_file, onnx_inputs)

            assert_almost_equal(pred, pred_onx[0])

        def verify_tgt_embed():
            tgt = mx.nd.random.uniform(0, 36794, shape=(batch, seq_length), dtype='float32')
            pred = model.tgt_embed(tgt)

            prefix = "%s/tgt_embed" % tmp_path
            input_shapes = [(batch, seq_length)]
            input_types = [np.float32]
            onnx_file = export_to_onnx(prefix, input_shapes, input_types)
            onnx_inputs = [tgt]
            pred_onx = onnx_runtime_predict(onnx_file, onnx_inputs)

            assert_almost_equal(pred, pred_onx[0])

        def verify_tgt_proj():
            decoder_out = mx.nd.random.uniform(0, 512, shape=(batch, seq_length, C_out),
                                               dtype='float32')
            pred = model.tgt_proj(decoder_out)

            prefix = "%s/tgt_proj" % tmp_path
            input_shapes = [(batch, seq_length, C_out)]
            input_types = [np.float32]
            onnx_file = export_to_onnx(prefix, input_shapes, input_types)
            onnx_inputs = [decoder_out]
            pred_onx = onnx_runtime_predict(onnx_file, onnx_inputs)

            assert_almost_equal(pred, pred_onx[0], rtol=1.e-04, atol=1.5e-03)

        def verify_one_step_ahead_decoder():
Contributor Author: @sxjscience would you help take a quick look at this function? Thanks!

prefix = "%s/one_step_ahead_decoder" %tmp_path

# the input data order
perm = [2, 0, 1]
Contributor: Could we put the correct order when instantiating the list instead of using perm?

Contributor Author: I used a perm list so that the actual in_shapes and in_types lists can keep the same order as the inputs passed to the native model; it's just that the converted ONNX model takes them in a different order somehow. I think this is more consistent, what do you think?

            input_shapes = [(batch, seq_length, C_in), (batch, seq_length, C_out),
                            (batch, seq_length)]
            input_shapes = [input_shapes[i] for i in perm]
            dynamic_input_shapes = [(batch, 'seq_length', C_in), (batch, 'seq_length', C_out),
                                    (batch, 'seq_length')]
            dynamic_input_shapes = [dynamic_input_shapes[i] for i in perm]
            input_types = [np.float32, np.float32, np.float32]
            # do a dynamic export
            onnx_file = export_to_onnx(prefix, input_shapes, input_types, dynamic=True,
                                       dynamic_input_shapes=dynamic_input_shapes)

            # step 0
            step_input = mx.nd.random.uniform(-1, 1, shape=(batch, C_in), dtype='float32')
            # mxnet
            pred, step_states, _ = model.one_step_ahead_decoder(step_input, decoder_states)
            # onnx
            # note that we need to expand the sequence axis just like in here:
            # https://github.com/dmlc/gluon-nlp/blob/v0.10.x/src/gluonnlp/model/transformer.py#L831
            input_onx = mx.nd.expand_dims(step_input, axis=1)
            onnx_inputs = [input_onx, decoder_states[0], decoder_states[1]]
            onnx_inputs = [onnx_inputs[i] for i in perm]
            pred_onx = onnx_runtime_predict(onnx_file, onnx_inputs)

            assert_almost_equal(pred, pred_onx[0])

            # step >= 1
            for i in range(20):
                step_input = mx.nd.random.uniform(-10*i, 10*i, shape=(batch, C_in), dtype='float32')
                # mxnet
                pred, step_states, _ = model.one_step_ahead_decoder(step_input, step_states)
                # onnx
                # note that we need to concat the step_input with the previous inputs,
                # just like in here:
                # https://github.com/dmlc/gluon-nlp/blob/v0.10.x/src/gluonnlp/model/transformer.py#L828
                input_onx = mx.nd.concat(input_onx, mx.nd.expand_dims(step_input, axis=1), dim=1)
                onnx_inputs = [input_onx, decoder_states[0], decoder_states[1]]
                onnx_inputs = [onnx_inputs[i] for i in perm]
                pred_onx = onnx_runtime_predict(onnx_file, onnx_inputs)

                assert_almost_equal(pred, pred_onx[0])

        verify_encoder()
        verify_src_embed()
        verify_tgt_embed()
        verify_tgt_proj()
        verify_one_step_ahead_decoder()

    finally:
        shutil.rmtree(tmp_path)