Skip to content

Commit

Permalink
[ONNX] Add MeanVarianceNormalization op (#11444)
Browse files Browse the repository at this point in the history
* [ONNX] Add MeanVarianceNormalization op

* Add pytest.main([__file__])
  • Loading branch information
sfvaroglu authored May 26, 2022
1 parent 52df2e8 commit b535e46
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 98 deletions.
15 changes: 14 additions & 1 deletion python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -2195,6 +2195,19 @@ def _impl_v1(cls, inputs, attr, params):
return _op.mean(concat, axis=0, keepdims=False)


class MeanVarianceNormalization(OnnxOpConverter):
"""Operator converter for MeanVarianceNormalization."""

@classmethod
def _impl_v13(cls, inputs, attr, params):
axis = attr.get("axes", (0, 2, 3))
data_mean = _op.mean(inputs[0], axis=axis, keepdims=True)
data_mean_squared = _op.power(data_mean, _expr.const(2, "float32"))
data_squared = _op.power(inputs[0], _expr.const(2, "float32"))
data_squared_mean = _op.mean(data_squared, axis=axis, keepdims=True)
return (inputs[0] - data_mean) / _op.sqrt(data_squared_mean - data_mean_squared)


class HardSigmoid(OnnxOpConverter):
"""Operator converter for HardSigmoid."""

Expand Down Expand Up @@ -5072,7 +5085,7 @@ def _get_convert_map(opset):
# 'GRUUnit'
# 'ATen'
# 'ImageScaler'
# 'MeanVarianceNormalization'
"MeanVarianceNormalization": MeanVarianceNormalization.get_converter(opset),
# 'Crop'
# 'Embedding'
"Upsample": Upsample.get_converter(opset),
Expand Down
98 changes: 1 addition & 97 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -5140,7 +5140,6 @@ def verify_eyelike(indata):
"test_maxpool_with_argmax_2d_precomputed_pads",
"test_maxpool_with_argmax_2d_precomputed_strides",
"test_maxunpool_export_with_output_shape",
"test_mvn",
# This test fails llvm with a lowering error:
"test_nllloss_NCd1d2d3_none_no_weight_negative_ii_expanded",
"test_optional_has_element",
Expand Down Expand Up @@ -6654,99 +6653,4 @@ def verify_LinearRegressor(a_shape, c_shape, i_shape, targets=1, batch=1):


if __name__ == "__main__":
test_flatten()
test_reshape()
test_shape()
test_expand()
test_power()
test_squeeze()
test_unsqueeze()
test_slice()
test_floor()
test_ceil()
test_round()
test_isinf()
test_isnan()
test_clip()
test_clip_min_max_as_inputs()
test_onehot()
test_gemm()
test_matmul()
test_matmulinteger16()
test_gather()
test_gatherelements()
test_gather_nd()
test_scatter()
test_lrn()
test_instance_norm()
test_upsample_nearest()
test_upsample_bilinear()
test_forward_min()
test_forward_max()
test_forward_mean()
test_forward_hardsigmoid()
test_forward_arg_min_max()
test_softmax()
test_constantofshape()
test_all_reduce_funcs()
test_pad()
test_split()
test_binary_ops()
test_unary_ops()
test_leaky_relu()
test_elu()
test_selu()
test_prelu()
test_ThresholdedRelu()
test_LogSoftmax()
test_resnet()
test_inception()
test_densenet()
test_sign()
test_not()
test_and()
test_tile()
test_erf()
test_where()
test_or()
test_depth_to_space()
test_space_to_depth()
test_batch_norm()
test_batch_norm_dynamic_subgraph()
test_conv()
test_convtranspose()
test_unsqueeze_constant()
test_pooling()
test_lppool()
test_lstm()
test_gru()
test_resize()
test_nonzero()
test_topk()
test_mod()
test_xor()
test_max_roi_pool()
test_roi_align()
test_range()
test_loop()
test_size()
test_maxunpool()
test_softplus()
test_cumsum()
test_wrong_input()
test_aten()
test_index_put()
test_reverse_sequence()
test_eyelike()
test_qlinearconcat()
test_qlinearconv()
test_random_uniform()
test_convinteger()
test_batch_matmul()
test_use_nt_batch_matmul()
test_global_lppool()
test_scan()
test_random_uniform_like()
test_random_normal()
test_random_normal_like()
test_LinearRegressor()
pytest.main([__file__])

0 comments on commit b535e46

Please sign in to comment.