[v1.x] Onnx Support for Transformer #20048
@@ -723,7 +723,7 @@ def test_distilbert_inference_onnxruntime(tmp_path, model_name):
@with_seed()
@pytest.mark.parametrize('model_name', [('standard_lstm_lm_200', 200), ('standard_lstm_lm_650', 650),
                                        ('standard_lstm_lm_1500', 1500)])
-@pytest.mark.parametrize('seq_length', [16, 32])
+@pytest.mark.parametrize('seq_length', [64, 128])
def test_standard_rnn_lstm_pretrained_inference_onnxruntime(tmp_path, model_name, seq_length):
    try:
        import gluonnlp as nlp
@@ -988,3 +988,170 @@ def test_ernie_inference_onnxruntime(tmp_path, model_name):
    finally:
        shutil.rmtree(tmp_path)
@with_seed()
@pytest.mark.parametrize('model_name', ['transformer_en_de_512'])
def test_transformer_pretrained_inference_onnxruntime(tmp_path, model_name):
    tmp_path = str(tmp_path)
    try:
        import gluonnlp as nlp
        dataset = 'WMT2014'
        ctx = mx.cpu(0)
        model, _, _ = nlp.model.get_model(
            name=model_name,
            ctx=ctx,
            pretrained=True,
            dataset_name=dataset)

        model.hybridize(static_alloc=False)

        batch = 7
        seq_length = 16
        C_in = 512
        C_out = 512
        src = mx.nd.random.uniform(0, 36794, shape=(batch, seq_length), dtype='float32')
Review comment: Curious, does this not need to be an integer type?
Reply: No, it's float in the original mxnet model too. This should not matter, I think, because the operator will apply ceiling/flooring.
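Side note (not part of the diff): a minimal standalone sketch of the flooring behavior mentioned above, assuming MXNet 1.x semantics where Embedding accepts float32 indices:

import mxnet as mx

# Hedged sketch: mx.nd.Embedding takes float32 indices in MXNet 1.x, and
# fractional ids appear to be truncated, so 3.0 and 3.7 should select row 3.
vocab, dim = 10, 4
weight = mx.nd.arange(vocab * dim).reshape((vocab, dim))
ids = mx.nd.array([3.0, 3.7], dtype='float32')
rows = mx.nd.Embedding(ids, weight, input_dim=vocab, output_dim=dim)
print(rows)  # both rows expected to equal weight[3] if indices are floored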
        step_input = mx.nd.random.uniform(0, 36794, shape=(batch,), dtype='float32')
        src_valid_length = mx.nd.array([seq_length] * batch, dtype='float32')

        encoder_outputs, encoder_additional_outputs = model.encode(src,
                                                                   valid_length=src_valid_length)

        decoder_states = model.decoder.init_state_from_encoder(encoder_outputs, src_valid_length)

        step_output, states, additional_outputs = model.decode_step(step_input, decoder_states)

        # skip export of 'decoder' as it's used for training only
        for component in ['encoder', 'one_step_ahead_decoder', 'src_embed', 'tgt_embed',
                          'tgt_proj']:
            prefix = "%s/%s" % (tmp_path, component)
            component = getattr(model, component)
            component.export(prefix)
            sym_file = "%s-symbol.json" % prefix
            params_file = "%s-0000.params" % prefix
            onnx_file = "%s.onnx" % prefix

        def export_to_onnx(prefix, input_shapes, input_types, **kwargs):
            sym_file = "%s-symbol.json" % prefix
            params_file = "%s-0000.params" % prefix
            onnx_file = "%s.onnx" % prefix
            return mx.contrib.onnx.export_model(sym_file, params_file, input_shapes, input_types,
                                                onnx_file, **kwargs)

        def onnx_runtime_predict(onnx_file, onnx_inputs):
            ses_opt = onnxruntime.SessionOptions()
            ses_opt.log_severity_level = 3
            session = onnxruntime.InferenceSession(onnx_file, ses_opt)
            input_dict = dict((session.get_inputs()[i].name, onnx_inputs[i].asnumpy())
                              for i in range(len(onnx_inputs)))
            return session.run(None, input_dict)

        def verify_encoder():
            inputs = mx.nd.random.uniform(-1, 1, shape=(batch, seq_length, C_in), dtype='float32')
            valid_length = mx.nd.array([seq_length] * batch, dtype='float32')
            pred = model.encoder(inputs, valid_length=valid_length)

            prefix = "%s/encoder" % tmp_path
            input_shapes = [(batch, seq_length, C_in), (batch,)]
            input_types = [np.float32, np.float32]
            onnx_file = export_to_onnx(prefix, input_shapes, input_types)
            onnx_inputs = [inputs, valid_length]
            pred_onx = onnx_runtime_predict(onnx_file, onnx_inputs)

            assert_almost_equal(pred[0], pred_onx[0])

        def verify_src_embed():
            src = mx.nd.random.uniform(0, 36794, shape=(batch, seq_length), dtype='float32')
            pred = model.src_embed(src)

            prefix = "%s/src_embed" % tmp_path
            input_shapes = [(batch, seq_length)]
            input_types = [np.float32]
            onnx_file = export_to_onnx(prefix, input_shapes, input_types)
            onnx_inputs = [src]
            pred_onx = onnx_runtime_predict(onnx_file, onnx_inputs)

            assert_almost_equal(pred, pred_onx[0])

        def verify_tgt_embed():
            tgt = mx.nd.random.uniform(0, 36794, shape=(batch, seq_length), dtype='float32')
            pred = model.tgt_embed(tgt)

            prefix = "%s/tgt_embed" % tmp_path
            input_shapes = [(batch, seq_length)]
            input_types = [np.float32]
            onnx_file = export_to_onnx(prefix, input_shapes, input_types)
            onnx_inputs = [tgt]
            pred_onx = onnx_runtime_predict(onnx_file, onnx_inputs)

            assert_almost_equal(pred, pred_onx[0])

        def verify_tgt_proj():
            decoder_out = mx.nd.random.uniform(0, 512, shape=(batch, seq_length, C_out),
                                               dtype='float32')
            pred = model.tgt_proj(decoder_out)

            prefix = "%s/tgt_proj" % tmp_path
            input_shapes = [(batch, seq_length, C_out)]
            input_types = [np.float32]
            onnx_file = export_to_onnx(prefix, input_shapes, input_types)
            onnx_inputs = [decoder_out]
            pred_onx = onnx_runtime_predict(onnx_file, onnx_inputs)

            assert_almost_equal(pred, pred_onx[0], rtol=1.e-04, atol=1.5e-03)
        def verify_one_step_ahead_decoder():

Review comment: @sxjscience would you help take a quick look at this func? Thanks!
prefix = "%s/one_step_ahead_decoder" %tmp_path | ||
|
||
# the input data order | ||
perm = [2, 0, 1] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could we put the correct order when instantiating the list instead of using perm? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I used a perm list so that the actual in_shapes an in_types list can have the same order as passed in the native model. It's just the converted onnx takes them in a different order some how. I think this is more consistent, what do you think? |
||
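Side note (not part of the diff): the perm trick keeps the lists written in the model's natural argument order and applies a single reorder to match the exported graph's input order. A tiny illustration with hypothetical labels:

# Hypothetical labels, for illustration only.
natural_order = ['step_input', 'decoder_state_0', 'decoder_state_1']
perm = [2, 0, 1]
print([natural_order[i] for i in perm])
# -> ['decoder_state_1', 'step_input', 'decoder_state_0']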
            input_shapes = [(batch, seq_length, C_in), (batch, seq_length, C_out),
                            (batch, seq_length)]
            input_shapes = [input_shapes[i] for i in perm]
            dynamic_input_shapes = [(batch, 'seq_length', C_in), (batch, 'seq_length', C_out),
                                    (batch, 'seq_length')]
            dynamic_input_shapes = [dynamic_input_shapes[i] for i in perm]
            input_types = [np.float32, np.float32, np.float32]
            # do a dynamic export: string entries name dynamic axes,
            # so seq_length can vary at runtime
            onnx_file = export_to_onnx(prefix, input_shapes, input_types, dynamic=True,
                                       dynamic_input_shapes=dynamic_input_shapes)

            # step 0
            step_input = mx.nd.random.uniform(-1, 1, shape=(batch, C_in), dtype='float32')
            # mxnet
            pred, step_states, _ = model.one_step_ahead_decoder(step_input, decoder_states)
            # onnx
            # note that we need to expand the sequence axis just like in here:
            # https://github.com/dmlc/gluon-nlp/blob/v0.10.x/src/gluonnlp/model/transformer.py#L831
            input_onx = mx.nd.expand_dims(step_input, axis=1)
            onnx_inputs = [input_onx, decoder_states[0], decoder_states[1]]
            onnx_inputs = [onnx_inputs[i] for i in perm]
            pred_onx = onnx_runtime_predict(onnx_file, onnx_inputs)

            assert_almost_equal(pred, pred_onx[0])

            # step >= 1
            for i in range(20):
                step_input = mx.nd.random.uniform(-10*i, 10*i, shape=(batch, C_in), dtype='float32')
                # mxnet
                pred, step_states, _ = model.one_step_ahead_decoder(step_input, step_states)
                # onnx
                # note that we need to concat the step_input with the previous inputs,
                # just like in here:
                # https://github.com/dmlc/gluon-nlp/blob/v0.10.x/src/gluonnlp/model/transformer.py#L828
                input_onx = mx.nd.concat(input_onx, mx.nd.expand_dims(step_input, axis=1), dim=1)
                onnx_inputs = [input_onx, decoder_states[0], decoder_states[1]]
                onnx_inputs = [onnx_inputs[i] for i in perm]
                pred_onx = onnx_runtime_predict(onnx_file, onnx_inputs)

                assert_almost_equal(pred, pred_onx[0])

        verify_encoder()
        verify_src_embed()
        verify_tgt_embed()
        verify_tgt_proj()
        verify_one_step_ahead_decoder()

    finally:
        shutil.rmtree(tmp_path)
Review comment: What are C_in and C_out? Should we also test when C_in != C_out?
Reply: You can refer to https://github.com/dmlc/gluon-nlp/blob/v0.10.x/src/gluonnlp/model/transformer.py for C_in and C_out. They are defined in the pretrained model, so we need to set them to the same values as in the pretrained model.
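Side note (not part of the diff): for transformer_en_de_512 the hidden size is 512 on both the encoder and decoder side. A hedged sketch of reading the values off the loaded model instead of hard-coding them, assuming the gluon-nlp 0.10.x blocks keep their hidden size in a private _units attribute:

# Hedged sketch: _units is a gluon-nlp 0.10.x internal, so the hard-coded
# 512 in the test is the safer choice; fall back to it if the attribute moves.
C_in = getattr(model.encoder, '_units', 512)
C_out = getattr(model.decoder, '_units', 512)
assert C_in == C_out == 512  # holds for transformer_en_de_512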