From b535e46f1663378659cda6ceec22d95ca48536d9 Mon Sep 17 00:00:00 2001 From: "Sevin F. Varoglu" Date: Thu, 26 May 2022 10:56:02 -0700 Subject: [PATCH] [ONNX] Add MeanVarianceNormalization op (#11444) * [ONNX] Add MeanVarianceNormalization op * Add pytest.main([__file__]) --- python/tvm/relay/frontend/onnx.py | 15 +++- tests/python/frontend/onnx/test_forward.py | 98 +--------------------- 2 files changed, 15 insertions(+), 98 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 1294852ba197..30e8188a8312 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -2195,6 +2195,19 @@ def _impl_v1(cls, inputs, attr, params): return _op.mean(concat, axis=0, keepdims=False) +class MeanVarianceNormalization(OnnxOpConverter): + """Operator converter for MeanVarianceNormalization.""" + + @classmethod + def _impl_v13(cls, inputs, attr, params): + axis = attr.get("axes", (0, 2, 3)) + data_mean = _op.mean(inputs[0], axis=axis, keepdims=True) + data_mean_squared = _op.power(data_mean, _expr.const(2, "float32")) + data_squared = _op.power(inputs[0], _expr.const(2, "float32")) + data_squared_mean = _op.mean(data_squared, axis=axis, keepdims=True) + return (inputs[0] - data_mean) / _op.sqrt(data_squared_mean - data_mean_squared) + + class HardSigmoid(OnnxOpConverter): """Operator converter for HardSigmoid.""" @@ -5072,7 +5085,7 @@ def _get_convert_map(opset): # 'GRUUnit' # 'ATen' # 'ImageScaler' - # 'MeanVarianceNormalization' + "MeanVarianceNormalization": MeanVarianceNormalization.get_converter(opset), # 'Crop' # 'Embedding' "Upsample": Upsample.get_converter(opset), diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index d6f96f0d0796..41123a254825 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -5140,7 +5140,6 @@ def verify_eyelike(indata): "test_maxpool_with_argmax_2d_precomputed_pads", "test_maxpool_with_argmax_2d_precomputed_strides", "test_maxunpool_export_with_output_shape", - "test_mvn", # This test fails llvm with a lowering error: "test_nllloss_NCd1d2d3_none_no_weight_negative_ii_expanded", "test_optional_has_element", @@ -6654,99 +6653,4 @@ def verify_LinearRegressor(a_shape, c_shape, i_shape, targets=1, batch=1): if __name__ == "__main__": - test_flatten() - test_reshape() - test_shape() - test_expand() - test_power() - test_squeeze() - test_unsqueeze() - test_slice() - test_floor() - test_ceil() - test_round() - test_isinf() - test_isnan() - test_clip() - test_clip_min_max_as_inputs() - test_onehot() - test_gemm() - test_matmul() - test_matmulinteger16() - test_gather() - test_gatherelements() - test_gather_nd() - test_scatter() - test_lrn() - test_instance_norm() - test_upsample_nearest() - test_upsample_bilinear() - test_forward_min() - test_forward_max() - test_forward_mean() - test_forward_hardsigmoid() - test_forward_arg_min_max() - test_softmax() - test_constantofshape() - test_all_reduce_funcs() - test_pad() - test_split() - test_binary_ops() - test_unary_ops() - test_leaky_relu() - test_elu() - test_selu() - test_prelu() - test_ThresholdedRelu() - test_LogSoftmax() - test_resnet() - test_inception() - test_densenet() - test_sign() - test_not() - test_and() - test_tile() - test_erf() - test_where() - test_or() - test_depth_to_space() - test_space_to_depth() - test_batch_norm() - test_batch_norm_dynamic_subgraph() - test_conv() - test_convtranspose() - test_unsqueeze_constant() - test_pooling() - test_lppool() - test_lstm() - test_gru() - test_resize() - test_nonzero() - test_topk() - test_mod() - test_xor() - test_max_roi_pool() - test_roi_align() - test_range() - test_loop() - test_size() - test_maxunpool() - test_softplus() - test_cumsum() - test_wrong_input() - test_aten() - test_index_put() - test_reverse_sequence() - test_eyelike() - test_qlinearconcat() - test_qlinearconv() - test_random_uniform() - test_convinteger() - test_batch_matmul() - test_use_nt_batch_matmul() - test_global_lppool() - test_scan() - test_random_uniform_like() - test_random_normal() - test_random_normal_like() - test_LinearRegressor() + pytest.main([__file__])