Skip to content
This repository has been archived by the owner on Nov 25, 2022. It is now read-only.

Commit

Permalink
[Relay][Frontend][ONNX] Add LayerNormalization operator (apache#13074)
Browse files Browse the repository at this point in the history
* [Relay][Frontend][ONNX] Add LayerNormalization operator

* Include mean in variance to reduce the number of expressions if it already exists

* Fix lint

Co-authored-by: Ehsan M. Kermani <ehsanmok@users.noreply.github.com>
  • Loading branch information
2 people authored and xinetzone committed Nov 25, 2022
1 parent 4046693 commit 7fca675
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 23 deletions.
37 changes: 35 additions & 2 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,9 +366,11 @@ def flatten_to_nd(x, x_shape, nd=3):


def layer_norm(x, eps, gamma, beta):
"""Common function to handle layer norm"""
eps_dtype = infer_type(x).checked_type.dtype
"""A common function to handle layer norm.
Use LayerNormalization for the actual onnx op.
"""
eps_dtype = infer_type(x).checked_type.dtype
u, s = _op.mean_variance(x, axis=-1, keepdims=True)
output = _op.divide(
_op.subtract(x, u),
Expand Down Expand Up @@ -944,6 +946,36 @@ def _impl_v1(cls, inputs, attr, params):
return Gelu._impl_v1([inp], attr, params)


class LayerNormalization(OnnxOpConverter):
    """Operator converter for LayerNormalization from Microsoft onnxruntime contrib opset."""

    @classmethod
    def _impl_v17(cls, inputs, attr, params):
        # Unpack data, scale, and (optional) bias inputs.
        data, gamma, beta = inputs[0], inputs[1], inputs[2]
        eps = attr.get("epsilon", 1e-5)
        axis = attr.get("axis", -1)
        # Per the ONNX spec, the single int `axis` attribute (default -1)
        # selects the first normalized dimension: the reduction actually runs
        # over axes (axis, ..., rank(x) - 1), and mean/inv_stdev have shape
        # [d[0], ..., d[axis-1], 1, ..., 1].
        # https://github.com/onnx/onnx/blob/main/docs/Changelog.md#layernormalization-17
        rank = len(infer_shape(data))
        if axis < 0:
            axis += rank
        reduce_axes = tuple(range(axis, rank))
        dtype = infer_type(data).checked_type.dtype
        one = _op.const(1, dtype=dtype)
        eps_const = _op.const(eps, dtype=dtype)
        mean = _op.mean(data, reduce_axes, keepdims=True)
        # Hand the already-computed mean to variance so the graph does not
        # rebuild the same mean expression twice.
        variance = _op.variance(data, reduce_axes, keepdims=True, with_mean=mean)
        inv_stdev = _op.divide(one, _op.sqrt(_op.add(variance, eps_const)))
        normalized = _op.multiply(_op.subtract(data, mean), inv_stdev)
        result = _op.multiply(normalized, gamma)
        if beta is not None:
            result = _op.add(result, beta)

        # ONNX defines three outputs: (output, mean, inv_std_dev).
        return _expr.TupleWrapper(_expr.Tuple([result, mean, inv_stdev]), 3)


class EmbedLayerNormalization(OnnxOpConverter):
"""Operator converter for EmbedLayerNormalization from Microsoft onnxruntime contrib opset.
Expand Down Expand Up @@ -5336,6 +5368,7 @@ def _get_convert_map(opset):
"Elu": Elu.get_converter(opset),
"Gelu": Gelu.get_converter(opset),
"BiasGelu": BiasGelu.get_converter(opset),
"LayerNormalization": LayerNormalization.get_converter(opset),
# TODO: We need a better way to handle different domains, in case
# of name collisions. EmbedLayerNormalization, SkipLayerNormalization, and Attention
# are in the `com.microsoft` domain.
Expand Down
7 changes: 5 additions & 2 deletions python/tvm/relay/op/reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ def mean(data, axis=None, keepdims=False, exclude=False):
return _make.mean(data, axis, keepdims, exclude)


def variance(data, axis=None, keepdims=False, exclude=False, unbiased=False, with_mean=None):
    """Computes the variance of data over given axes.

    Parameters
    ----------
    data : relay.Expr
        The input data.

    axis : None or int or tuple of int
        Axis or axes along which the variance is computed. The default,
        axis=None, computes the variance over all elements of the input.

    keepdims : bool
        If this is set to True, the axes which are reduced are left in the
        result as dimensions with size one.

    exclude : bool
        If `exclude` is true, reduction will be performed on the axes that
        are NOT in axis instead.

    unbiased : bool
        If this is set to True, the unbiased estimation will be used.

    with_mean : Optional[relay.Expr]
        To compute variance given an already computed mean

    Returns
    -------
    result : relay.Expr
        The computed result.
    """
    # Normalize a scalar axis to the list form expected by the backend.
    if isinstance(axis, int):
        axis = [axis]
    # Reuse a caller-supplied mean when available; otherwise compute it here
    # (keepdims=True so it broadcasts against `data` in the variance op).
    m = with_mean if with_mean is not None else mean(data, axis, True, exclude)
    return _make._variance(data, m, axis, keepdims, exclude, unbiased)


Expand Down
19 changes: 0 additions & 19 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -5295,25 +5295,6 @@ def verify_eyelike(indata, dynamic=False):
"test_identity_sequence",
"test_if_opt",
"test_if_seq",
"test_layer_normalization_2d_axis0",
"test_layer_normalization_2d_axis1",
"test_layer_normalization_2d_axis_negative_1",
"test_layer_normalization_2d_axis_negative_2",
"test_layer_normalization_3d_axis0_epsilon",
"test_layer_normalization_3d_axis1_epsilon",
"test_layer_normalization_3d_axis2_epsilon",
"test_layer_normalization_3d_axis_negative_1_epsilon",
"test_layer_normalization_3d_axis_negative_2_epsilon",
"test_layer_normalization_3d_axis_negative_3_epsilon",
"test_layer_normalization_4d_axis0",
"test_layer_normalization_4d_axis1",
"test_layer_normalization_4d_axis2",
"test_layer_normalization_4d_axis3",
"test_layer_normalization_4d_axis_negative_1",
"test_layer_normalization_4d_axis_negative_2",
"test_layer_normalization_4d_axis_negative_3",
"test_layer_normalization_4d_axis_negative_4",
"test_layer_normalization_default_axis",
"test_loop11",
"test_loop13_seq",
"test_loop16_seq_none",
Expand Down

0 comments on commit 7fca675

Please sign in to comment.