From f2dc061e35fb93a4f5cc7d8dd5186f4a6e09c3b8 Mon Sep 17 00:00:00 2001 From: zha0q1 Date: Fri, 19 Mar 2021 00:51:38 +0000 Subject: [PATCH 1/6] add ops --- .../contrib/onnx/mx2onnx/_op_translations.py | 109 +++++++++++++++++- 1 file changed, 104 insertions(+), 5 deletions(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index 462564a85071..43b95249121e 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -2064,14 +2064,80 @@ def convert_broadcast_lesser(node, **kwargs): """Map MXNet's broadcast_lesser operator attributes to onnx's Less operator and return the created node. """ - return create_basic_op_node('Less', node, kwargs) + from onnx.helper import make_node, make_tensor + name, input_nodes, _ = get_inputs(node, kwargs) + input_dtypes = get_input_dtypes(node, kwargs) + + dtype = input_dtypes[0] + dtype_t = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[dtype] + + nodes = [ + make_node('Less', [input_nodes[0], input_nodes[1]], [name+'_lt']), + make_node('Cast', [name+'_lt'], [name], to=dtype_t) + ] + + return nodes + + +@mx_op.register("broadcast_lesser_equal") +def convert_broadcast_lesser(node, **kwargs): + """Map MXNet's broadcast_lesser operator attributes to onnx's Less operator + and return the created node. + """ + from onnx.helper import make_node, make_tensor + name, input_nodes, _ = get_inputs(node, kwargs) + input_dtypes = get_input_dtypes(node, kwargs) + + dtype = input_dtypes[0] + dtype_t = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[dtype] + + nodes = [ + make_node('LessOrEqual', [input_nodes[0], input_nodes[1]], [name+'_lt']), + make_node('Cast', [name+'_lt'], [name], to=dtype_t) + ] + + return nodes + + +@mx_op.register("broadcast_greater_equal") +def convert_broadcast_lesser(node, **kwargs): + """Map MXNet's broadcast_lesser operator attributes to onnx's Less operator + and return the created node. + """ + from onnx.helper import make_node, make_tensor + name, input_nodes, _ = get_inputs(node, kwargs) + input_dtypes = get_input_dtypes(node, kwargs) + + dtype = input_dtypes[0] + dtype_t = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[dtype] + + nodes = [ + make_node('GreaterOrEqual', [input_nodes[0], input_nodes[1]], [name+'_lt']), + make_node('Cast', [name+'_lt'], [name], to=dtype_t) + ] + + return nodes + @mx_op.register("broadcast_greater") def convert_broadcast_greater(node, **kwargs): """Map MXNet's broadcast_greater operator attributes to onnx's Greater operator and return the created node. """ - return create_basic_op_node('Greater', node, kwargs) + from onnx.helper import make_node, make_tensor + name, input_nodes, _ = get_inputs(node, kwargs) + input_dtypes = get_input_dtypes(node, kwargs) + + dtype = input_dtypes[0] + dtype_t = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[dtype] + + nodes = [ + make_node('Greater', [input_nodes[0], input_nodes[1]], [name+'_gt']), + make_node('Cast', [name+'_gt'], [name], to=dtype_t) + ] + + return nodes + @mx_op.register("broadcast_equal") def convert_broadcast_equal(node, **kwargs): @@ -2498,7 +2564,6 @@ def convert_layer_norm(node, **kwargs): axes = int(attrs.get('axis', -1)) eps = attrs.get('eps', 9.99999975e-06) - create_tensor([axes], name+"_axes", kwargs["initializer"]) create_tensor([axes+1], name+"_axes+1", kwargs["initializer"]) create_const_scalar_node(name+'_0_s', np.int64(0), kwargs) @@ -2519,7 +2584,11 @@ def convert_layer_norm(node, **kwargs): if axes == -1: nodes += [ make_node("Mul", [name+"_div0_out", input_nodes[1]], [name+"_mul0_out"]), - make_node("Add", [name+"_mul0_out", input_nodes[2]], [name], name=name) + # make_node("Add", [name+"_mul0_out", input_nodes[2]], [name]) + # the Add operator triggers a weird NaN issue in onnxruntime + # a workaround is to use Neg + Sub + make_node('Neg', [input_nodes[2]], [name+'_neg']), + make_node("Sub", [name+"_mul0_out", name+'_neg'], [name]) ] else: nodes += [ @@ -4301,7 +4370,7 @@ def convert_RNN(node, **kwargs): @mx_op.register('_rnn_param_concat') def convert_rnn_param_concat(node, **kwargs): - """Map MXNet’s _rnn_param_concat operator + """Map MXNet's _rnn_param_concat operator """ from onnx.helper import make_node name, input_nodes, attrs = get_inputs(node, kwargs) @@ -4313,3 +4382,33 @@ def convert_rnn_param_concat(node, **kwargs): ] return nodes + + +@mx_op.register('_contrib_div_sqrt_dim') +def convert_contrib_div_sqrt_dim(node, **kwargs): + """Map MXNet's _contrib_div_sqrt_dim operator + """ + from onnx.helper import make_node + from onnx import TensorProto + name, input_nodes, _ = get_inputs(node, kwargs) + input_dtypes = get_input_dtypes(node, kwargs) + + dtype = input_dtypes[0] + dtype_t = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[dtype] + + create_tensor([0], name+'_0', kwargs['initializer']) + create_tensor([1], name+'_1', kwargs['initializer']) + create_tensor([1], name+'_1_f', kwargs['initializer'], dtype=dtype) + nodes = [ + make_node('Shape', [input_nodes[0]], [name+'_shape']), + make_node('Shape', [name+'_shape'], [name+'_dim']), + make_node('Sub', [name+'_dim', name+'_1'], [name+'_dim_m1']), + make_node('Slice', [name+'_shape', name+'_dim_m1', name+'_dim', name+'_0'], [name+'_c_']), + make_node('Cast', [name+'_c_'], [name+'_c'], to=dtype_t), + make_node('Sqrt', [name+'_c'], [name+'_c_sqrt']), + make_node('Div', [name+'_1_f', name+'_c_sqrt'], [name+'_1_over_c_sqrt']), + make_node('Mul', [input_nodes[0], name+'_1_over_c_sqrt'], [name]) + ] + + return nodes + From 8aabc6817ba4c44c2fc470813d07ce1de06eb9cc Mon Sep 17 00:00:00 2001 From: zha0q1 Date: Fri, 19 Mar 2021 00:55:55 +0000 Subject: [PATCH 2/6] add transformer test --- tests/python-pytest/onnx/test_onnxruntime.py | 167 +++++++++++++++++++ 1 file changed, 167 insertions(+) diff --git a/tests/python-pytest/onnx/test_onnxruntime.py b/tests/python-pytest/onnx/test_onnxruntime.py index e2a8329dd45d..c068d74e940b 100644 --- a/tests/python-pytest/onnx/test_onnxruntime.py +++ b/tests/python-pytest/onnx/test_onnxruntime.py @@ -988,3 +988,170 @@ def test_ernie_inference_onnxruntime(tmp_path, model_name): finally: shutil.rmtree(tmp_path) + + +@with_seed() +@pytest.mark.parametrize('model_name', ['transformer_en_de_512']) +def test_transformer_pretrained_inference_onnxruntime(tmp_path, model): + tmp_path = str(tmp_path) + try: + import gluonnlp as nlp + dataset = 'WMT2014' + ctx = mx.cpu(0) + model, _, _ = nlp.model.get_model( + name=model, + ctx=ctx, + pretrained=True, + dataset_name=dataset) + + model.hybridize(static_alloc=False) + + batch = 7 + seq_length = 16 + C_in = 512 + C_out = 512 + src = mx.nd.random.uniform(0, 36794, shape=(batch, seq_length), dtype='float32') + step_input = mx.nd.random.uniform(0, 36794, shape=(batch,), dtype='float32') + src_valid_length = mx.nd.array([seq_length] * batch, dtype='float32') + + encoder_outputs, encoder_additional_outputs = model.encode(src, + valid_length=src_valid_length) + + decoder_states = model.decoder.init_state_from_encoder(encoder_outputs, src_valid_length) + + step_output, states, additional_outputs = model.decode_step(step_input, decoder_states) + + # skip export of 'decoder' as it's used for training only + for component in ['encoder', 'one_step_ahead_decoder', 'src_embed', 'tgt_embed', + 'tgt_proj']: + + prefix = "%s/%s" %(tmp_path, component) + component = getattr(model, component) + component.export(prefix) + sym_file = "%s-symbol.json" % prefix + params_file = "%s-0000.params" % prefix + onnx_file = "%s.onnx" % prefix + + def export_to_onnx(prefix, input_shapes, input_types, **kwargs): + sym_file = "%s-symbol.json" % prefix + params_file = "%s-0000.params" % prefix + onnx_file = "%s.onnx" % prefix + return mx.contrib.onnx.export_model(sym_file, params_file, input_shapes, input_types, + onnx_file, **kwargs) + + def onnx_runtime_predict(onnx_file, onnx_inputs): + ses_opt = onnxruntime.SessionOptions() + ses_opt.log_severity_level = 3 + session = onnxruntime.InferenceSession(onnx_file, ses_opt) + input_dict = dict((session.get_inputs()[i].name, onnx_inputs[i].asnumpy()) + for i in range(len(onnx_inputs))) + return session.run(None, input_dict) + + def verify_encoder(): + inputs = mx.nd.random.uniform(-1, 1, shape=(batch, seq_length, C_in), dtype='float32') + valid_length = mx.nd.array([seq_length] * batch, dtype='float32') + pred = model.encoder(inputs, valid_length=valid_length) + + prefix = "%s/encoder" %tmp_path + input_shapes = [(batch, seq_length, C_in), (batch,)] + input_types = [np.float32, np.float32] + onnx_file = export_to_onnx(prefix, input_shapes, input_types) + onnx_inputs = [inputs, valid_length] + pred_onx = onnx_runtime_predict(onnx_file, onnx_inputs) + + assert_almost_equal(pred[0], pred_onx[0]) + + def verify_src_embed(): + src = mx.nd.random.uniform(0, 36794, shape=(batch, seq_length), dtype='float32') + pred = model.src_embed(src) + + prefix = "%s/src_embed" %tmp_path + input_shapes = [(batch, seq_length)] + input_types = [np.float32] + onnx_file = export_to_onnx(prefix, input_shapes, input_types) + onnx_inputs = [src] + pred_onx = onnx_runtime_predict(onnx_file, onnx_inputs) + + assert_almost_equal(pred, pred_onx[0]) + + def verify_tgt_embed(): + tgt = mx.nd.random.uniform(0, 36794, shape=(batch, seq_length), dtype='float32') + pred = model.tgt_embed(tgt) + + prefix = "%s/tgt_embed" %tmp_path + input_shapes = [(batch, seq_length)] + input_types = [np.float32] + onnx_file = export_to_onnx(prefix, input_shapes, input_types) + onnx_inputs = [tgt] + pred_onx = onnx_runtime_predict(onnx_file, onnx_inputs) + + assert_almost_equal(pred, pred_onx[0]) + + def verify_tgt_proj(): + decoder_out = mx.nd.random.uniform(0, 512, shape=(batch, seq_length, C_out), + dtype='float32') + pred = model.tgt_proj(decoder_out) + + prefix = "%s/tgt_proj" %tmp_path + input_shapes = [(batch, seq_length, C_out)] + input_types = [np.float32] + onnx_file = export_to_onnx(prefix, input_shapes, input_types) + onnx_inputs = [decoder_out] + pred_onx = onnx_runtime_predict(onnx_file, onnx_inputs) + + assert_almost_equal(pred, pred_onx[0], rtol=1.e-04, atol=1.5e-03) + + def verify_one_step_ahead_decoder(): + prefix = "%s/one_step_ahead_decoder" %tmp_path + + # the input data order + perm = [2, 0, 1] + input_shapes = [(batch, seq_length, C_in), (batch, seq_length, C_out), + (batch, seq_length)] + input_shapes = [input_shapes[i] for i in perm] + dynamic_input_shapes = [(batch, 'seq_length', C_in), (batch, 'seq_length', C_out), + (batch, 'seq_length')] + dynamic_input_shapes = [dynamic_input_shapes[i] for i in perm] + input_types = [np.float32, np.float32, np.float32] + # do a dynamic export + onnx_file = export_to_onnx(prefix, input_shapes, input_types, dynamic=True, + dynamic_input_shapes=dynamic_input_shapes) + + # step 0 + step_input = mx.nd.random.uniform(-1, 1, shape=(batch, C_in), dtype='float32') + # mxnet + pred, step_states, _ = model.one_step_ahead_decoder(step_input, decoder_states) + # onnx + # note that we need to expand the sequence axis just like in here: + # https://github.com/dmlc/gluon-nlp/blob/v0.10.x/src/gluonnlp/model/transformer.py#L831 + input_onx = mx.nd.expand_dims(step_input, axis=1) + onnx_inputs = [input_onx, decoder_states[0], decoder_states[1]] + onnx_inputs = [onnx_inputs[i] for i in perm] + pred_onx = onnx_runtime_predict(onnx_file, onnx_inputs) + + assert_almost_equal(pred, pred_onx[0]) + + # step >= 1 + for i in range(20): + step_input = mx.nd.random.uniform(-10*i, 10*i, shape=(batch, C_in), dtype='float32') + # mxnet + pred, step_states, _ = model.one_step_ahead_decoder(step_input, step_states) + # onnx + # note that we need to concat the step_input with the previous inpus + # just like in here: + # https://github.com/dmlc/gluon-nlp/blob/v0.10.x/src/gluonnlp/model/transformer.py#L828 + input_onx = mx.nd.concat(input_onx, mx.nd.expand_dims(step_input, axis=1), dim=1) + onnx_inputs = [input_onx, decoder_states[0], decoder_states[1]] + onnx_inputs = [onnx_inputs[i] for i in perm] + pred_onx = onnx_runtime_predict(onnx_file, onnx_inputs) + + assert_almost_equal(pred, pred_onx[0]) + + verify_encoder() + verify_src_embed() + verify_tgt_embed() + verify_tgt_proj() + verify_one_step_ahead_decoder() + + finally: + shutil.rmtree(tmp_path) From 0abc36925b657361d1917aa99f166676dfbb50dd Mon Sep 17 00:00:00 2001 From: zha0q1 Date: Fri, 19 Mar 2021 00:58:02 +0000 Subject: [PATCH 3/6] fix test --- tests/python-pytest/onnx/test_onnxruntime.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python-pytest/onnx/test_onnxruntime.py b/tests/python-pytest/onnx/test_onnxruntime.py index c068d74e940b..249bba6fc341 100644 --- a/tests/python-pytest/onnx/test_onnxruntime.py +++ b/tests/python-pytest/onnx/test_onnxruntime.py @@ -992,14 +992,14 @@ def test_ernie_inference_onnxruntime(tmp_path, model_name): @with_seed() @pytest.mark.parametrize('model_name', ['transformer_en_de_512']) -def test_transformer_pretrained_inference_onnxruntime(tmp_path, model): +def test_transformer_pretrained_inference_onnxruntime(tmp_path, model_name): tmp_path = str(tmp_path) try: import gluonnlp as nlp dataset = 'WMT2014' ctx = mx.cpu(0) model, _, _ = nlp.model.get_model( - name=model, + name=model_name, ctx=ctx, pretrained=True, dataset_name=dataset) From 6454f79516f8c29c59def9c772b7a05e261a4501 Mon Sep 17 00:00:00 2001 From: zha0q1 Date: Fri, 19 Mar 2021 18:12:19 +0000 Subject: [PATCH 4/6] add unit test --- .../contrib/onnx/mx2onnx/_op_translations.py | 4 +-- tests/python-pytest/onnx/test_onnxruntime.py | 2 +- tests/python-pytest/onnx/test_operators.py | 26 +++++++++++++++++++ 3 files changed, 29 insertions(+), 3 deletions(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index 43b95249121e..84c599be786a 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -2112,8 +2112,8 @@ def convert_broadcast_lesser(node, **kwargs): dtype_t = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[dtype] nodes = [ - make_node('GreaterOrEqual', [input_nodes[0], input_nodes[1]], [name+'_lt']), - make_node('Cast', [name+'_lt'], [name], to=dtype_t) + make_node('GreaterOrEqual', [input_nodes[0], input_nodes[1]], [name+'_gt']), + make_node('Cast', [name+'_gt'], [name], to=dtype_t) ] return nodes diff --git a/tests/python-pytest/onnx/test_onnxruntime.py b/tests/python-pytest/onnx/test_onnxruntime.py index 249bba6fc341..ca4c0fed3fd9 100644 --- a/tests/python-pytest/onnx/test_onnxruntime.py +++ b/tests/python-pytest/onnx/test_onnxruntime.py @@ -723,7 +723,7 @@ def test_distilbert_inference_onnxruntime(tmp_path, model_name): @with_seed() @pytest.mark.parametrize('model_name', [('standard_lstm_lm_200', 200), ('standard_lstm_lm_650', 650), ('standard_lstm_lm_1500', 1500)]) -@pytest.mark.parametrize('seq_length', [16, 32]) +@pytest.mark.parametrize('seq_length', [64, 128]) def test_standard_rnn_lstm_pretrained_inference_onnxruntime(tmp_path, model_name, seq_length): try: import gluonnlp as nlp diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py index 36b687a0eadc..220f259beb46 100644 --- a/tests/python-pytest/onnx/test_operators.py +++ b/tests/python-pytest/onnx/test_operators.py @@ -1234,3 +1234,29 @@ def test_onnx_export_RNN(tmp_path, dtype, state_size, input_size, num_layers, ba state = mx.nd.random.uniform(-1, 1, [num_layers, batch_size, state_size], dtype=dtype) cell = mx.nd.random.uniform(-1, 1, [num_layers, batch_size, state_size], dtype=dtype) op_export_test('rnn', M, [x, param, state, cell], tmp_path) + + +@pytest.mark.parametrize('dtype', ['float16', 'float32', 'int32', 'int64']) +@pytest.mark.parametrize('shapes', [((3, 3, 3), (1, 3)), ((4, 5, 6, 7), (6, 7))]) +def test_onnx_export_broadcast_lesser_equal(tmp_path, dtype, shapes): + A = mx.nd.random.uniform(0, 5, shapes[0]).astype('int32').astype(dtype) + B = mx.nd.random.uniform(0, 5, shapes[1]).astype('int32').astype(dtype) + M = def_model('broadcast_lesser_equal') + op_export_test('broadcast_lesser_equal', M, [A, B], tmp_path) + + +@pytest.mark.parametrize('dtype', ['float16', 'float32', 'int32', 'int64']) +@pytest.mark.parametrize('shapes', [((3, 3, 3), (1, 3)), ((4, 5, 6, 7), (6, 7))]) +def test_onnx_export_broadcast_greater_equal(tmp_path, dtype, shapes): + A = mx.nd.random.uniform(0, 5, shapes[0]).astype('int32').astype(dtype) + B = mx.nd.random.uniform(0, 5, shapes[1]).astype('int32').astype(dtype) + M = def_model('broadcast_greater_equal') + op_export_test('broadcast_greater_equal', M, [A, B], tmp_path) + + +@pytest.mark.parametrize('dtype', ['float16', 'float32', 'float64']) +@pytest.mark.parametrize('shape', [(3, 4, 5), (6, 7), (8,)]) +def test_onnx_export_contrib_div_sqrt_dim(tmp_path, dtype, shape): + A = mx.nd.random.uniform(-100, 100, shape).astype(dtype) + M = def_model('contrib.div_sqrt_dim') + op_export_test('contrib_div_sqrt_dim', M, [A], tmp_path) From b7f5f24a0348eda4452bd8c17fbbfc11dc68f77c Mon Sep 17 00:00:00 2001 From: zha0q1 Date: Fri, 19 Mar 2021 20:19:36 +0000 Subject: [PATCH 5/6] fix sanity --- .../contrib/onnx/mx2onnx/_op_translations.py | 20 ++++++++----------- 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index 84c599be786a..0903376c11d9 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -2064,7 +2064,7 @@ def convert_broadcast_lesser(node, **kwargs): """Map MXNet's broadcast_lesser operator attributes to onnx's Less operator and return the created node. """ - from onnx.helper import make_node, make_tensor + from onnx.helper import make_node name, input_nodes, _ = get_inputs(node, kwargs) input_dtypes = get_input_dtypes(node, kwargs) @@ -2080,11 +2080,10 @@ def convert_broadcast_lesser(node, **kwargs): @mx_op.register("broadcast_lesser_equal") -def convert_broadcast_lesser(node, **kwargs): - """Map MXNet's broadcast_lesser operator attributes to onnx's Less operator - and return the created node. +def convert_broadcast_lesser_equal(node, **kwargs): + """Map MXNet's broadcast_lesser_equal operator """ - from onnx.helper import make_node, make_tensor + from onnx.helper import make_node name, input_nodes, _ = get_inputs(node, kwargs) input_dtypes = get_input_dtypes(node, kwargs) @@ -2100,11 +2099,10 @@ def convert_broadcast_lesser(node, **kwargs): @mx_op.register("broadcast_greater_equal") -def convert_broadcast_lesser(node, **kwargs): - """Map MXNet's broadcast_lesser operator attributes to onnx's Less operator - and return the created node. +def convert_broadcast_greater_equal(node, **kwargs): + """Map MXNet's broadcast_greater_equal operator """ - from onnx.helper import make_node, make_tensor + from onnx.helper import make_node name, input_nodes, _ = get_inputs(node, kwargs) input_dtypes = get_input_dtypes(node, kwargs) @@ -2124,7 +2122,7 @@ def convert_broadcast_greater(node, **kwargs): """Map MXNet's broadcast_greater operator attributes to onnx's Greater operator and return the created node. """ - from onnx.helper import make_node, make_tensor + from onnx.helper import make_node name, input_nodes, _ = get_inputs(node, kwargs) input_dtypes = get_input_dtypes(node, kwargs) @@ -4389,7 +4387,6 @@ def convert_contrib_div_sqrt_dim(node, **kwargs): """Map MXNet's _contrib_div_sqrt_dim operator """ from onnx.helper import make_node - from onnx import TensorProto name, input_nodes, _ = get_inputs(node, kwargs) input_dtypes = get_input_dtypes(node, kwargs) @@ -4411,4 +4408,3 @@ def convert_contrib_div_sqrt_dim(node, **kwargs): ] return nodes - From 9826564fac33ba71d6ba67b04a302ee4a37ae75c Mon Sep 17 00:00:00 2001 From: zha0q1 Date: Fri, 19 Mar 2021 22:16:44 +0000 Subject: [PATCH 6/6] add to ci --- ci/docker/runtime_functions.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/ci/docker/runtime_functions.sh b/ci/docker/runtime_functions.sh index a3484fc04176..9bfc8418b6ad 100755 --- a/ci/docker/runtime_functions.sh +++ b/ci/docker/runtime_functions.sh @@ -1274,6 +1274,7 @@ integrationtest_ubuntu_cpu_onnx() { pytest $COV_ARG --verbose tests/python-pytest/onnx/test_onnxruntime.py::test_action_recognition_model_inference_onnxruntime[inceptionv3_kinetics400] pytest $COV_ARG --verbose tests/python-pytest/onnx/test_onnxruntime.py::test_dynamic_shape_bert_inference_onnxruntime pytest $COV_ARG --verbose tests/python-pytest/onnx/test_onnxruntime.py::test_dynamic_shape_cv_inference_onnxruntime + pytest $COV_ARG --verbose tests/python-pytest/onnx/test_onnxruntime.py::test_transformer_pretrained_inference_onnxruntime } integrationtest_ubuntu_gpu_python() {