From d67ff5d6076cb132d6da6bac35ca3eccf7ae5498 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E4=BA=A6=E9=A9=B0?= Date: Thu, 28 Jul 2022 14:55:30 +0800 Subject: [PATCH 1/6] Add RNN operation for ONNX frontend. --- python/tvm/relay/frontend/onnx.py | 136 ++++++++++++++++++++- tests/python/frontend/onnx/test_forward.py | 133 +++++++++++++++++++- 2 files changed, 266 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 3b5bf9acfa42..4b237b36a244 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -53,6 +53,7 @@ infer_value, lstm_cell, new_var, + rnn_cell, shape_of, try_resolve_var_to_const, unbind, @@ -2723,7 +2724,7 @@ def expand_shape(in_shape, shape): class RNN(OnnxOpConverter): - """Operator converter for RNNs such as LSTM and GRU.""" + """Operator converter for RNNs such as RNN, LSTM and GRU.""" @classmethod def _activation_helper(cls, activation, alpha, beta): @@ -2756,6 +2757,138 @@ def _activation_needs_beta(cls, activation): ] return activation.decode("utf-8") in needs_beta + @classmethod + def bidir_rnn_cell( + cls, + input_seqs, + weight_dicts, + acts, + ): + """ + Bidirectional RNN cell + """ + seq_len = len(input_seqs) + forward_outputs, fw_H_t = rnn_cell( + input_seqs, + **weight_dicts[0], + act=acts[0], + ) + + reverse_outputs, rev_H_t = rnn_cell( + input_seqs, + **weight_dicts[1], + act=acts[1], + backwards=True, + ) + + final_outputs = [] + for i in range(seq_len): + final_outputs.append( + _op.stack([forward_outputs[i], reverse_outputs[seq_len - 1 - i]], axis=0) + ) + + return ( + _op.stack(final_outputs, axis=0), + _op.stack([fw_H_t, rev_H_t], axis=0), + ) + + @classmethod + def _impl_v7(cls, inputs, attr, params): + # Unpack inputs, note that if optional and not provided then value will be None. + X = inputs[0] + Wp = inputs[1] + Rp = inputs[2] + Bp = inputs[3] + # Sequence length currently unused as it can be inferred from shapes. 
+ # sequence_lens = inputs['sequence_lens'] + Hp_0 = inputs[5] + + num_directions = infer_shape(Wp)[0] + W_dtype = infer_type(Wp).checked_type.dtype + + if num_directions not in [1, 2]: + raise ValueError("num_directions must be either 1 or 2!") + + X_shape = infer_shape(X) + hidden_size = infer_shape(Rp)[-1] + batch_size = X_shape[1] + + if Hp_0 is None: + Hp_0 = _op.zeros((num_directions, batch_size, hidden_size), W_dtype) + + if "activations" in attr: + activations = attr["activations"] + if len(activations) != num_directions: + raise NotImplementedError( + "RNN assumes num_directions activation functions are provided" + ) + alpha_loc = 0 + alphas = attr.get("activation_alpha", []) + if isinstance(alphas, float): + alphas = [alphas] + beta_loc = 0 + betas = attr.get("activation_beta", []) + if isinstance(betas, float): + betas = [betas] + acts = [] + for i in range(num_directions): + alpha = None + beta = None + activation = activations[i] + if cls._activation_needs_alpha(activation) and len(alphas) > alpha_loc: + alpha = alphas[alpha_loc] + alpha_loc += 1 + if cls._activation_needs_beta(activation) and len(betas) > beta_loc: + beta = betas[beta_loc] + beta_loc += 1 + acts.append(cls._activation_helper(activation, alpha, beta)) + else: + acts = [_op.tanh, _op.tanh] + + # TODO (vvchernov): It can be replaced by _op.split if issue #8412 is resolved + X_steps = unbind(X, axis=0) + + H_ts = _op.split(Hp_0, num_directions) + Ws = _op.split(Wp, num_directions) + Rs = _op.split(Rp, num_directions) + + if Bp is not None: + Bs = _op.split(Bp, num_directions) + + weights_dicts = [] + for i in range(num_directions): + weights_dict = {} + + weights_dict["hidden_state"] = _op.squeeze(H_ts[i], axis=[0]) + + weights_dict["w_inp"] = _op.squeeze(Ws[i], axis=[0]) + weights_dict["w_hid"] = _op.squeeze(Rs[i], axis=[0]) + if Bp is not None: + Bi, Bh = _op.split(Bs[i], 2, -1) + weights_dict["b_inp"] = _op.squeeze(Bi, axis=[0]) + weights_dict["b_hid"] = _op.squeeze(Bh, axis=[0]) + weights_dicts.append(weights_dict) + + if num_directions == 2: + output, H = RNN.bidir_rnn_cell( + input_seqs=X_steps, + weight_dicts=weights_dicts, + acts=acts, + ) + else: + # outputs shape = [seqs_num, (batch_size, hidden_size)] + outputs, H = rnn_cell( + input_seqs=X_steps, + **weights_dicts[0], + act=acts[0], + ) + + # output shape = (seqs_num, num_directions, batch_size, hidden_size) + output = _op.expand_dims(_op.stack(outputs, axis=0), axis=1) + H = _op.expand_dims(H, axis=0) + + return _expr.TupleWrapper(_expr.Tuple((output, H)), 2) + class LSTM(RNN): """Operator converter for LSTM""" @@ -5287,6 +5420,7 @@ def _get_convert_map(opset): "Flatten": Flatten.get_converter(opset), "LRN": LRN.get_converter(opset), # Recurrent Layers + "RNN": RNN.get_converter(opset), "LSTM": LSTM.get_converter(opset), "GRU": GRU.get_converter(opset), # defs/vision diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index d68b76751184..345ab7ab7a69 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -3753,7 +3753,9 @@ def verify_rnn( target=None, dev=None, ): - if rnn_type == "LSTM": + if rnn_type == "RNN": + multiplier = 1 + elif rnn_type == "LSTM": multiplier = 4 elif rnn_type == "GRU": multiplier = 3 @@ -3876,6 +3878,133 @@ def register(name, shape, proto_type): ) +@tvm.testing.parametrize_targets +def test_rnn(target, dev): + # Set seed for test reproduction + np.random.seed(137) + for directions in [1, 2]: + # No bias. 
+ verify_rnn( + seq_length=2, + batch_size=1, + input_size=16, + hidden_size=32, + use_bias=False, + rnn_type="RNN", + directions=directions, + target=target, + dev=dev, + ) + # Non power of two. + verify_rnn( + seq_length=3, + batch_size=3, + input_size=16, + hidden_size=40, + use_bias=True, + rnn_type="RNN", + directions=directions, + target=target, + dev=dev, + ) + # Long sequence. + verify_rnn( + seq_length=8, + batch_size=1, + input_size=16, + hidden_size=32, + use_bias=True, + rnn_type="RNN", + directions=directions, + target=target, + dev=dev, + ) + # Large hidden. + verify_rnn( + seq_length=2, + batch_size=1, + input_size=16, + hidden_size=128, + use_bias=True, + rnn_type="RNN", + directions=directions, + target=target, + dev=dev, + ) + # Large input. + verify_rnn( + seq_length=2, + batch_size=1, + input_size=64, + hidden_size=32, + use_bias=True, + rnn_type="RNN", + directions=directions, + target=target, + dev=dev, + ) + + # Different activation testing. + # Default value hardsigmoid. + # TODO: onnxruntime <= v1.12.0 has wrong default value of all activation functions + # verify_rnn( + # seq_length=2, + # batch_size=1, + # input_size=16, + # hidden_size=32, + # use_bias=False, + # activations=["HardSigmoid", "Softsign"][0: directions], + # rnn_type="RNN", + # directions=directions, + # target=target, + # dev=dev, + # ) + # Multiple parametrized activations. + verify_rnn( + seq_length=2, + batch_size=1, + input_size=16, + hidden_size=32, + use_bias=False, + activations=["HardSigmoid", "LeakyRelu"][0: directions], + alphas=[2.0, 0.5][0: directions], + betas=[0.3, 0.0][0: directions], + rnn_type="RNN", + directions=directions, + target=target, + dev=dev, + ) + # All parametrized with new Affine activation. + verify_rnn( + seq_length=2, + batch_size=1, + input_size=16, + hidden_size=32, + use_bias=False, + activations=["HardSigmoid", "Affine"][0: directions], + alphas=[2.0, 0.8][0: directions], + betas=[0.3, 0.1][0: directions], + rnn_type="RNN", + directions=directions, + target=target, + dev=dev, + ) + + # Testing with initial state + verify_rnn( + seq_length=2, + batch_size=1, + input_size=16, + hidden_size=32, + use_bias=True, + use_initial_state=True, + rnn_type="RNN", + directions=directions, + target=target, + dev=dev, + ) + + @tvm.testing.parametrize_targets def test_lstm(target, dev): for directions in [1, 2]: @@ -5212,7 +5341,7 @@ def verify_eyelike(indata, dynamic=False): "test_reduce_sum_keepdims_random", "test_reduce_sum_negative_axes_keepdims_example", "test_reduce_sum_negative_axes_keepdims_random", - "test_rnn_seq_length", + "test_rnn_batchwise", "test_sequence_insert_at_back", "test_sequence_insert_at_front", "test_simple_rnn_batchwise", From ea38d39efeef374f99c981c50ed1d4326b9dde72 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E4=BA=A6=E9=A9=B0?= Date: Thu, 28 Jul 2022 15:36:02 +0800 Subject: [PATCH 2/6] link checks --- tests/python/frontend/onnx/test_forward.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 345ab7ab7a69..5323a6eb8e6d 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -3953,7 +3953,7 @@ def test_rnn(target, dev): # input_size=16, # hidden_size=32, # use_bias=False, - # activations=["HardSigmoid", "Softsign"][0: directions], + # activations=["HardSigmoid", "Softsign"][0:directions], # rnn_type="RNN", # directions=directions, # target=target, @@ -3966,9 +3966,9 @@ 
def test_rnn(target, dev): input_size=16, hidden_size=32, use_bias=False, - activations=["HardSigmoid", "LeakyRelu"][0: directions], - alphas=[2.0, 0.5][0: directions], - betas=[0.3, 0.0][0: directions], + activations=["HardSigmoid", "LeakyRelu"][0:directions], + alphas=[2.0, 0.5][0:directions], + betas=[0.3, 0.0][0:directions], rnn_type="RNN", directions=directions, target=target, @@ -3981,9 +3981,9 @@ def test_rnn(target, dev): input_size=16, hidden_size=32, use_bias=False, - activations=["HardSigmoid", "Affine"][0: directions], - alphas=[2.0, 0.8][0: directions], - betas=[0.3, 0.1][0: directions], + activations=["HardSigmoid", "Affine"][0:directions], + alphas=[2.0, 0.8][0:directions], + betas=[0.3, 0.1][0:directions], rnn_type="RNN", directions=directions, target=target, From c381641a87478d4aaed175ee56fa3601dec0aa55 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E4=BA=A6=E9=A9=B0?= Date: Tue, 2 Aug 2022 17:08:15 +0800 Subject: [PATCH 3/6] rm test_rnn_batchwise in unsupported_onnx_tests --- tests/python/frontend/onnx/test_forward.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 5323a6eb8e6d..5508e2ada35c 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -5341,7 +5341,6 @@ def verify_eyelike(indata, dynamic=False): "test_reduce_sum_keepdims_random", "test_reduce_sum_negative_axes_keepdims_example", "test_reduce_sum_negative_axes_keepdims_random", - "test_rnn_batchwise", "test_sequence_insert_at_back", "test_sequence_insert_at_front", "test_simple_rnn_batchwise", From 8f334c101e49f11e1f67b0fa6e7dec8d0c75df6b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E4=BA=A6=E9=A9=B0?= Date: Tue, 2 Aug 2022 17:08:34 +0800 Subject: [PATCH 4/6] merge similar codes to class methods --- python/tvm/relay/frontend/onnx.py | 226 ++++++++++-------------------- 1 file changed, 74 insertions(+), 152 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 4b237b36a244..ead055aa3125 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -2793,34 +2793,21 @@ def bidir_rnn_cell( ) @classmethod - def _impl_v7(cls, inputs, attr, params): - # Unpack inputs, note that if optional and not provided then value will be None. - X = inputs[0] - Wp = inputs[1] - Rp = inputs[2] - Bp = inputs[3] - # Sequence length currently unused as it can be inferred from shapes. 
- # sequence_lens = inputs['sequence_lens'] - Hp_0 = inputs[5] - - num_directions = infer_shape(Wp)[0] - W_dtype = infer_type(Wp).checked_type.dtype - - if num_directions not in [1, 2]: - raise ValueError("num_directions must be either 1 or 2!") - - X_shape = infer_shape(X) - hidden_size = infer_shape(Rp)[-1] - batch_size = X_shape[1] - - if Hp_0 is None: - Hp_0 = _op.zeros((num_directions, batch_size, hidden_size), W_dtype) + def _default_activations(cls, num_directions): + return [_op.tanh] * num_directions + @classmethod + def _get_activations(cls, attr, multiplier, num_directions, rnn_type): + """ + Activation functions + """ if "activations" in attr: activations = attr["activations"] - if len(activations) != num_directions: + if len(activations) != multiplier * num_directions: raise NotImplementedError( - "RNN assumes num_directions activation functions are provided" + "{} assumes {} * num_directions activation functions are provided".format( + rnn_type, multiplier + ) ) alpha_loc = 0 alphas = attr.get("activation_alpha", []) @@ -2831,7 +2818,7 @@ def _impl_v7(cls, inputs, attr, params): if isinstance(betas, float): betas = [betas] acts = [] - for i in range(num_directions): + for i in range(multiplier * num_directions): alpha = None beta = None activation = activations[i] @@ -2843,7 +2830,37 @@ def _impl_v7(cls, inputs, attr, params): beta_loc += 1 acts.append(cls._activation_helper(activation, alpha, beta)) else: - acts = [_op.tanh, _op.tanh] + acts = cls._default_activations(num_directions) + return acts + + @classmethod + def _inputs_helper(cls, inputs): + """ + Process inputs + """ + # Unpack inputs, note that if optional and not provided then value will be None. + X = inputs[0] + Wp = inputs[1] + Rp = inputs[2] + Bp = inputs[3] + # Sequence length currently unused as it can be inferred from shapes. + # sequence_lens = inputs['sequence_lens'] + Hp_0 = inputs[5] + + num_directions = infer_shape(Wp)[0] + W_dtype = infer_type(Wp).checked_type.dtype + + if num_directions not in [1, 2]: + raise ValueError("num_directions must be either 1 or 2!") + + X_shape = infer_shape(X) + hidden_size = infer_shape(Rp)[-1] + batch_size = X_shape[1] + + # Initialize state if not provided. + # Otherwise remove bidirectional axis. + if Hp_0 is None: + Hp_0 = _op.zeros((num_directions, batch_size, hidden_size), W_dtype) # TODO (vvchernov): It can be replaced by _op.split if issue #8412 is resolved X_steps = unbind(X, axis=0) @@ -2852,8 +2869,15 @@ def _impl_v7(cls, inputs, attr, params): Ws = _op.split(Wp, num_directions) Rs = _op.split(Rp, num_directions) + Bs = None if Bp is not None: Bs = _op.split(Bp, num_directions) + return X_steps, H_ts, Ws, Rs, Bs, num_directions + + @classmethod + def _impl_v7(cls, inputs, attr, params): + X_steps, H_ts, Ws, Rs, Bs, num_directions = cls._inputs_helper(inputs) + acts = cls._get_activations(attr, 1, num_directions, "RNN") weights_dicts = [] for i in range(num_directions): @@ -2863,7 +2887,7 @@ def _impl_v7(cls, inputs, attr, params): weights_dict["w_inp"] = _op.squeeze(Ws[i], axis=[0]) weights_dict["w_hid"] = _op.squeeze(Rs[i], axis=[0]) - if Bp is not None: + if Bs is not None: Bi, Bh = _op.split(Bs[i], 2, -1) weights_dict["b_inp"] = _op.squeeze(Bi, axis=[0]) weights_dict["b_hid"] = _op.squeeze(Bh, axis=[0]) @@ -2934,74 +2958,26 @@ def bidir_lstm_cell( ) @classmethod - def _impl_v7(cls, inputs, attr, params): - # Unpack inputs, note that if optional and not provided then value will be None. 
- X = inputs[0] - Wp = inputs[1] - Rp = inputs[2] - Bp = inputs[3] - # Sequence length currently unused as it can be inferred from shapes. - # sequence_lens = inputs['sequence_lens'] - Hp_0 = inputs[5] - Cp_0 = inputs[6] - Pp = inputs[7] - - num_directions = infer_shape(Wp)[0] - W_dtype = infer_type(Wp).checked_type.dtype + def _default_activations(cls, num_directions): + return [_op.sigmoid, _op.tanh, _op.tanh] * num_directions - if num_directions not in [1, 2]: - raise ValueError("num_directions must be either 1 or 2!") - - X_shape = infer_shape(X) - hidden_size = infer_shape(Rp)[-1] - batch_size = X_shape[1] + @classmethod + def _impl_v7(cls, inputs, attr, params): + X_steps, H_ts, Ws, Rs, Bs, num_directions = cls._inputs_helper(inputs) + acts = cls._get_activations(attr, 3, num_directions, "LSTM") - # Initialize state if not provided. - # Otherwise remove bidirectional axis. - if Hp_0 is None: - Hp_0 = _op.zeros((num_directions, batch_size, hidden_size), W_dtype) + # cell state + Cp_0 = inputs[6] if Cp_0 is None: - Cp_0 = _op.zeros((num_directions, batch_size, hidden_size), W_dtype) - - if "activations" in attr: - activations = attr["activations"] - if len(activations) != 3 * num_directions: - raise NotImplementedError( - f"LSTM assumes 3 * num_directions activation functions are provided" - ) - alpha_loc = 0 - alphas = attr.get("activation_alpha", []) - if isinstance(alphas, float): - alphas = [alphas] - beta_loc = 0 - betas = attr.get("activation_beta", []) - if isinstance(betas, float): - betas = [betas] - acts = [] - for i in range(3 * num_directions): - alpha = None - beta = None - activation = activations[i] - if cls._activation_needs_alpha(activation) and len(alphas) > alpha_loc: - alpha = alphas[alpha_loc] - alpha_loc += 1 - if cls._activation_needs_beta(activation) and len(betas) > beta_loc: - beta = betas[beta_loc] - beta_loc += 1 - acts.append(cls._activation_helper(activation, alpha, beta)) + C_ts = _expr.TupleWrapper( + _expr.Tuple([_op.zeros_like(H_ts[i]) for i in range(num_directions)]), + num_directions + ) else: - acts = [_op.sigmoid, _op.tanh, _op.tanh] * num_directions - - # TODO (vvchernov): It can be replaced by _op.split if issue #8412 is resolved - X_steps = unbind(X, axis=0) - - H_ts = _op.split(Hp_0, num_directions) - C_ts = _op.split(Cp_0, num_directions) - Ws = _op.split(Wp, num_directions) - Rs = _op.split(Rp, num_directions) + C_ts = _op.split(Cp_0, num_directions) - if Bp is not None: - Bs = _op.split(Bp, num_directions) + # peepholes + Pp = inputs[7] if Pp is not None: p_i, p_o, p_f = _op.split(Pp, 3, axis=1) @@ -3021,7 +2997,7 @@ def _impl_v7(cls, inputs, attr, params): weights_dict["w_inp"] = _op.concatenate([mati, matf, matc, mato], axis=0) mati, mato, matf, matc = _op.split(_op.squeeze(Rs[i], axis=[0]), 4) weights_dict["w_hid"] = _op.concatenate([mati, matf, matc, mato], axis=0) - if Bp is not None: + if Bs is not None: Bi, Bh = _op.split(Bs[i], 2, -1) mati, mato, matf, matc = _op.split(_op.squeeze(Bi, axis=[0]), 4) weights_dict["b_inp"] = _op.concatenate([mati, matf, matc, mato], axis=0) @@ -3097,70 +3073,16 @@ def bidir_gru_cell( _op.stack([fw_H_t, rev_H_t], axis=0), ) + @classmethod + def _default_activations(cls, num_directions): + return [_op.sigmoid, _op.tanh] * num_directions + @classmethod def _impl_v7(cls, inputs, attr, params): - # Unpack inputs, note that if optional and not provided then value will be None. 
- X = inputs[0] - Wp = inputs[1] - Rp = inputs[2] - Bp = inputs[3] - # Sequence length currently unused as it can be inferred from shapes. - # sequence_lens = inputs['sequence_lens'] - Hp_0 = inputs[5] + X_steps, H_ts, Ws, Rs, Bs, num_directions = cls._inputs_helper(inputs) + acts = cls._get_activations(attr, 2, num_directions, "GRU") linear_before_reset = attr.get("linear_before_reset", 0) - num_directions = infer_shape(Wp)[0] - W_dtype = infer_type(Wp).checked_type.dtype - - if num_directions not in [1, 2]: - raise ValueError("num_directions must be either 1 or 2!") - - X_shape = infer_shape(X) - hidden_size = infer_shape(Rp)[-1] - batch_size = X_shape[1] - - if Hp_0 is None: - Hp_0 = _op.zeros((num_directions, batch_size, hidden_size), W_dtype) - - if "activations" in attr: - activations = attr["activations"] - if len(activations) != 2 * num_directions: - raise NotImplementedError( - "GRU assumes 2 * num_directions activation functions are provided" - ) - alpha_loc = 0 - alphas = attr.get("activation_alpha", []) - if isinstance(alphas, float): - alphas = [alphas] - beta_loc = 0 - betas = attr.get("activation_beta", []) - if isinstance(betas, float): - betas = [betas] - acts = [] - for i in range(2 * num_directions): - alpha = None - beta = None - activation = activations[i] - if cls._activation_needs_alpha(activation) and len(alphas) > alpha_loc: - alpha = alphas[alpha_loc] - alpha_loc += 1 - if cls._activation_needs_beta(activation) and len(betas) > beta_loc: - beta = betas[beta_loc] - beta_loc += 1 - acts.append(cls._activation_helper(activation, alpha, beta)) - else: - acts = [_op.sigmoid, _op.tanh] * 2 - - # TODO (vvchernov): It can be replaced by _op.split if issue #8412 is resolved - X_steps = unbind(X, axis=0) - - H_ts = _op.split(Hp_0, num_directions) - Ws = _op.split(Wp, num_directions) - Rs = _op.split(Rp, num_directions) - - if Bp is not None: - Bs = _op.split(Bp, num_directions) - weights_dicts = [] for i in range(num_directions): weights_dict = {} @@ -3173,7 +3095,7 @@ def _impl_v7(cls, inputs, attr, params): weights_dict["w_inp"] = _op.concatenate([matr, matz, matn], axis=0) matz, matr, matn = _op.split(_op.squeeze(Rs[i], axis=[0]), 3) weights_dict["w_hid"] = _op.concatenate([matr, matz, matn], axis=0) - if Bp is not None: + if Bs is not None: Bi, Bh = _op.split(Bs[i], 2, -1) matz, matr, matn = _op.split(_op.squeeze(Bi, axis=[0]), 3) weights_dict["b_inp"] = _op.concatenate([matr, matz, matn], axis=0) From 6d82810b6e9de589270698f915b902310f40d037 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E4=BA=A6=E9=A9=B0?= Date: Tue, 2 Aug 2022 19:12:40 +0800 Subject: [PATCH 5/6] implement opset 14 and refactor test_forward --- python/tvm/relay/frontend/onnx.py | 50 ++- tests/python/frontend/onnx/test_forward.py | 451 ++++++--------------- 2 files changed, 157 insertions(+), 344 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index ead055aa3125..642d782a7870 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -2834,7 +2834,7 @@ def _get_activations(cls, attr, multiplier, num_directions, rnn_type): return acts @classmethod - def _inputs_helper(cls, inputs): + def _inputs_helper(cls, inputs, layout): """ Process inputs """ @@ -2848,19 +2848,22 @@ def _inputs_helper(cls, inputs): Hp_0 = inputs[5] num_directions = infer_shape(Wp)[0] - W_dtype = infer_type(Wp).checked_type.dtype if num_directions not in [1, 2]: raise ValueError("num_directions must be either 1 or 2!") - X_shape = infer_shape(X) - 
hidden_size = infer_shape(Rp)[-1] - batch_size = X_shape[1] + if layout == 1: + X = _op.transpose(X, axes=(1, 0)) # Initialize state if not provided. - # Otherwise remove bidirectional axis. if Hp_0 is None: + W_dtype = infer_type(Wp).checked_type.dtype + X_shape = infer_shape(X) + hidden_size = infer_shape(Rp)[-1] + batch_size = X_shape[1] Hp_0 = _op.zeros((num_directions, batch_size, hidden_size), W_dtype) + elif layout == 1: + Hp_0 = _op.transpose(Hp_0, axes=(1, 0)) # TODO (vvchernov): It can be replaced by _op.split if issue #8412 is resolved X_steps = unbind(X, axis=0) @@ -2875,8 +2878,8 @@ def _inputs_helper(cls, inputs): return X_steps, H_ts, Ws, Rs, Bs, num_directions @classmethod - def _impl_v7(cls, inputs, attr, params): - X_steps, H_ts, Ws, Rs, Bs, num_directions = cls._inputs_helper(inputs) + def _impl_common(cls, inputs, attr, layout): + X_steps, H_ts, Ws, Rs, Bs, num_directions = cls._inputs_helper(inputs, layout) acts = cls._get_activations(attr, 1, num_directions, "RNN") weights_dicts = [] @@ -2911,8 +2914,20 @@ def _impl_v7(cls, inputs, attr, params): output = _op.expand_dims(_op.stack(outputs, axis=0), axis=1) H = _op.expand_dims(H, axis=0) + if layout == 1: + output = _op.transpose(output, axes=(1, 0)) + H = _op.transpose(H, axes=(1, 0)) return _expr.TupleWrapper(_expr.Tuple((output, H)), 2) + @classmethod + def _impl_v7(cls, inputs, attr, params): + return cls._impl_common(inputs, attr, 0) + + @classmethod + def _impl_v14(cls, inputs, attr, params): + layout = attr.get("layout", 0) + return cls._impl_common(inputs, attr, layout) + class LSTM(RNN): """Operator converter for LSTM""" @@ -2962,8 +2977,8 @@ def _default_activations(cls, num_directions): return [_op.sigmoid, _op.tanh, _op.tanh] * num_directions @classmethod - def _impl_v7(cls, inputs, attr, params): - X_steps, H_ts, Ws, Rs, Bs, num_directions = cls._inputs_helper(inputs) + def _impl_common(cls, inputs, attr, layout): + X_steps, H_ts, Ws, Rs, Bs, num_directions = cls._inputs_helper(inputs, layout) acts = cls._get_activations(attr, 3, num_directions, "LSTM") # cell state @@ -2971,9 +2986,11 @@ def _impl_v7(cls, inputs, attr, params): if Cp_0 is None: C_ts = _expr.TupleWrapper( _expr.Tuple([_op.zeros_like(H_ts[i]) for i in range(num_directions)]), - num_directions + num_directions, ) else: + if layout == 1: + Cp_0 = _op.transpose(Cp_0, axes=(1, 0)) C_ts = _op.split(Cp_0, num_directions) # peepholes @@ -3030,6 +3047,10 @@ def _impl_v7(cls, inputs, attr, params): H = _op.expand_dims(H, axis=0) C = _op.expand_dims(C, axis=0) + if layout == 1: + output = _op.transpose(output, axes=(1, 0)) + H = _op.transpose(H, axes=(1, 0)) + C = _op.transpose(C, axes=(1, 0)) return _expr.TupleWrapper(_expr.Tuple((output, H, C)), 3) @@ -3078,8 +3099,8 @@ def _default_activations(cls, num_directions): return [_op.sigmoid, _op.tanh] * num_directions @classmethod - def _impl_v7(cls, inputs, attr, params): - X_steps, H_ts, Ws, Rs, Bs, num_directions = cls._inputs_helper(inputs) + def _impl_common(cls, inputs, attr, layout): + X_steps, H_ts, Ws, Rs, Bs, num_directions = cls._inputs_helper(inputs, layout) acts = cls._get_activations(attr, 2, num_directions, "GRU") linear_before_reset = attr.get("linear_before_reset", 0) @@ -3122,6 +3143,9 @@ def _impl_v7(cls, inputs, attr, params): output = _op.expand_dims(_op.stack(outputs, axis=0), axis=1) H = _op.expand_dims(H, axis=0) + if layout == 1: + output = _op.transpose(output, axes=(1, 0)) + H = _op.transpose(H, axes=(1, 0)) return _expr.TupleWrapper(_expr.Tuple((output, H)), 2) diff --git 
a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 5508e2ada35c..a0dcb71aefce 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -3748,6 +3748,7 @@ def verify_rnn( use_peep=False, linear_before_reset=False, directions=1, + layout=0, rtol=1e-5, atol=1e-5, target=None, @@ -3788,7 +3789,10 @@ def register(np_arr, name, shape=None): proto_type = dtype_map[np_arr.dtype.name] input_tensors.append(helper.make_tensor_value_info(name, proto_type, shape)) - x_np = np.random.uniform(size=(seq_length, batch_size, input_size)).astype("float32") + if layout == 1: + x_np = np.random.uniform(size=(batch_size, seq_length, input_size)).astype("float32") + else: + x_np = np.random.uniform(size=(seq_length, batch_size, input_size)).astype("float32") w_np = np.random.uniform(size=(directions, multiplier * hidden_size, input_size)).astype( "float32" ) @@ -3810,15 +3814,25 @@ def register(np_arr, name, shape=None): sequence_np = np.repeat(seq_length, batch_size).astype("int32") register(sequence_np, "sequence_lens") - initial_h_np = np.random.uniform(size=(directions, batch_size, hidden_size)).astype( - "float32" - ) + if layout == 1: + initial_h_np = np.random.uniform(size=(batch_size, directions, hidden_size)).astype( + "float32" + ) + else: + initial_h_np = np.random.uniform(size=(directions, batch_size, hidden_size)).astype( + "float32" + ) register(initial_h_np, "initial_h") if rnn_type == "LSTM": - initial_c_np = np.random.uniform(size=(directions, batch_size, hidden_size)).astype( - "float32" - ) + if layout == 1: + initial_c_np = np.random.uniform( + size=(batch_size, directions, hidden_size) + ).astype("float32") + else: + initial_c_np = np.random.uniform( + size=(directions, batch_size, hidden_size) + ).astype("float32") register(initial_c_np, "initial_c") if use_peep and rnn_type == "LSTM": @@ -3840,11 +3854,18 @@ def register(name, shape, proto_type): graph_outputs.append(helper.make_tensor_value_info(name, proto_type, list(shape))) output_shapes.append(list(shape)) - register("Y", [seq_length, directions, batch_size, hidden_size], TensorProto.FLOAT) - register("Y_h", [directions, batch_size, hidden_size], TensorProto.FLOAT) + if layout == 1: + register("Y", [directions, seq_length, batch_size, hidden_size], TensorProto.FLOAT) + register("Y_h", [batch_size, directions, hidden_size], TensorProto.FLOAT) + else: + register("Y", [seq_length, directions, batch_size, hidden_size], TensorProto.FLOAT) + register("Y_h", [directions, batch_size, hidden_size], TensorProto.FLOAT) if rnn_type == "LSTM": - register("Y_c", [directions, batch_size, hidden_size], TensorProto.FLOAT) + if layout == 1: + register("Y_c", [batch_size, directions, hidden_size], TensorProto.FLOAT) + else: + register("Y_c", [directions, batch_size, hidden_size], TensorProto.FLOAT) return output_names, graph_outputs, output_shapes @@ -3868,6 +3889,9 @@ def register(name, shape, proto_type): if linear_before_reset and rnn_type == "GRU": lbr_attr = helper.make_attribute("linear_before_reset", 1) rnn_node.attribute.append(lbr_attr) + if layout == 1: + layout_attr = helper.make_attribute("layout", 1) + rnn_node.attribute.append(layout_attr) graph = helper.make_graph([rnn_node], "rnn_test", inputs=input_tensors, outputs=graph_outputs) @@ -3878,135 +3902,13 @@ def register(name, shape, proto_type): ) -@tvm.testing.parametrize_targets -def test_rnn(target, dev): - # Set seed for test reproduction - np.random.seed(137) - for directions in [1, 2]: - # 
No bias. - verify_rnn( - seq_length=2, - batch_size=1, - input_size=16, - hidden_size=32, - use_bias=False, - rnn_type="RNN", - directions=directions, - target=target, - dev=dev, - ) - # Non power of two. - verify_rnn( - seq_length=3, - batch_size=3, - input_size=16, - hidden_size=40, - use_bias=True, - rnn_type="RNN", - directions=directions, - target=target, - dev=dev, - ) - # Long sequence. - verify_rnn( - seq_length=8, - batch_size=1, - input_size=16, - hidden_size=32, - use_bias=True, - rnn_type="RNN", - directions=directions, - target=target, - dev=dev, - ) - # Large hidden. - verify_rnn( - seq_length=2, - batch_size=1, - input_size=16, - hidden_size=128, - use_bias=True, - rnn_type="RNN", - directions=directions, - target=target, - dev=dev, - ) - # Large input. - verify_rnn( - seq_length=2, - batch_size=1, - input_size=64, - hidden_size=32, - use_bias=True, - rnn_type="RNN", - directions=directions, - target=target, - dev=dev, - ) - - # Different activation testing. - # Default value hardsigmoid. - # TODO: onnxruntime <= v1.12.0 has wrong default value of all activation functions - # verify_rnn( - # seq_length=2, - # batch_size=1, - # input_size=16, - # hidden_size=32, - # use_bias=False, - # activations=["HardSigmoid", "Softsign"][0:directions], - # rnn_type="RNN", - # directions=directions, - # target=target, - # dev=dev, - # ) - # Multiple parametrized activations. - verify_rnn( - seq_length=2, - batch_size=1, - input_size=16, - hidden_size=32, - use_bias=False, - activations=["HardSigmoid", "LeakyRelu"][0:directions], - alphas=[2.0, 0.5][0:directions], - betas=[0.3, 0.0][0:directions], - rnn_type="RNN", - directions=directions, - target=target, - dev=dev, - ) - # All parametrized with new Affine activation. - verify_rnn( - seq_length=2, - batch_size=1, - input_size=16, - hidden_size=32, - use_bias=False, - activations=["HardSigmoid", "Affine"][0:directions], - alphas=[2.0, 0.8][0:directions], - betas=[0.3, 0.1][0:directions], - rnn_type="RNN", - directions=directions, - target=target, - dev=dev, - ) - - # Testing with initial state - verify_rnn( - seq_length=2, - batch_size=1, - input_size=16, - hidden_size=32, - use_bias=True, - use_initial_state=True, - rnn_type="RNN", - directions=directions, - target=target, - dev=dev, - ) - +def verify_rnn_helper(target, dev, rnn_type): + num_activations = 1 + if rnn_type == "GRU": + num_activations = 2 + elif rnn_type == "LSTM": + num_activations = 3 -@tvm.testing.parametrize_targets -def test_lstm(target, dev): for directions in [1, 2]: # No bias. 
verify_rnn( @@ -4015,7 +3917,7 @@ def test_lstm(target, dev): input_size=16, hidden_size=32, use_bias=False, - rnn_type="LSTM", + rnn_type=rnn_type, directions=directions, target=target, dev=dev, @@ -4027,7 +3929,7 @@ def test_lstm(target, dev): input_size=16, hidden_size=32, use_bias=True, - rnn_type="LSTM", + rnn_type=rnn_type, directions=directions, target=target, dev=dev, @@ -4039,7 +3941,7 @@ def test_lstm(target, dev): input_size=16, hidden_size=40, use_bias=True, - rnn_type="LSTM", + rnn_type=rnn_type, directions=directions, target=target, dev=dev, @@ -4051,7 +3953,7 @@ def test_lstm(target, dev): input_size=16, hidden_size=32, use_bias=True, - rnn_type="LSTM", + rnn_type=rnn_type, directions=directions, target=target, dev=dev, @@ -4063,7 +3965,7 @@ def test_lstm(target, dev): input_size=16, hidden_size=128, use_bias=True, - rnn_type="LSTM", + rnn_type=rnn_type, directions=directions, target=target, dev=dev, @@ -4075,7 +3977,7 @@ def test_lstm(target, dev): input_size=64, hidden_size=32, use_bias=True, - rnn_type="LSTM", + rnn_type=rnn_type, directions=directions, target=target, dev=dev, @@ -4083,50 +3985,62 @@ def test_lstm(target, dev): # Different activation testing. # Default value hardsigmoid. - verify_rnn( - seq_length=2, - batch_size=1, - input_size=16, - hidden_size=32, - use_bias=False, - activations=["HardSigmoid", "Tanh", "Tanh"] * directions, - rnn_type="LSTM", - directions=directions, - target=target, - dev=dev, - ) + # TODO: onnxruntime <= v1.12.0 has wrong default value of all activation functions + if rnn_type != "RNN": + activations = ["HardSigmoid", "Tanh", "Tanh"][0:num_activations] * directions + verify_rnn( + seq_length=2, + batch_size=1, + input_size=16, + hidden_size=32, + use_bias=False, + activations=activations, + rnn_type=rnn_type, + directions=directions, + target=target, + dev=dev, + ) # Multiple parametrized activations. + activations = ["HardSigmoid", "LeakyRelu", "Tanh"][0:num_activations] * directions + alphas = [2.0, 0.5, 0.0][0:num_activations] * directions + betas = [0.3, 0.0, 0.0][0:num_activations] * directions verify_rnn( seq_length=2, batch_size=1, input_size=16, hidden_size=32, use_bias=False, - activations=["HardSigmoid", "LeakyRelu", "Tanh"] * directions, - alphas=[2.0, 0.5, 0.0] * directions, - betas=[0.3, 0.0, 0.0] * directions, - rnn_type="LSTM", + activations=activations, + alphas=alphas, + betas=betas, + rnn_type=rnn_type, directions=directions, target=target, dev=dev, ) # All parametrized with new Affine activation. 
+ activations = ["Affine", "LeakyRelu", "HardSigmoid"] + alphas = [0.8, 2.0, 0.5] + betas = [0.0, 0.3, 0.0] + activations = activations[0:num_activations] * directions + alphas = alphas[0:num_activations] * directions + betas = betas[0:num_activations] * directions verify_rnn( seq_length=2, batch_size=1, input_size=16, hidden_size=32, use_bias=False, - activations=["HardSigmoid", "LeakyRelu", "Affine"] * directions, - alphas=[2.0, 0.5, 0.8] * directions, - betas=[0.3, 0.1, 0.0] * directions, - rnn_type="LSTM", + activations=activations, + alphas=alphas, + betas=betas, + rnn_type=rnn_type, directions=directions, target=target, dev=dev, ) - # Testing with initial state and peepholes + # Testing with initial state verify_rnn( seq_length=2, batch_size=1, @@ -4134,182 +4048,57 @@ def test_lstm(target, dev): hidden_size=32, use_bias=True, use_initial_state=True, - rnn_type="LSTM", + rnn_type=rnn_type, directions=directions, target=target, dev=dev, ) - verify_rnn( - seq_length=2, - batch_size=1, - input_size=16, - hidden_size=32, - use_bias=True, - use_initial_state=True, - use_peep=True, - rnn_type="LSTM", - directions=directions, - target=target, - dev=dev, - ) + # Testing layout + # TODO: onnxruntime <= 1.12.0 doesn't support layout == 1 + # verify_rnn( + # seq_length=2, + # batch_size=1, + # input_size=16, + # hidden_size=32, + # use_bias=True, + # rnn_type="RNN", + # directions=directions, + # layout=1, + # target=target, + # dev=dev, + # ) + + # Testing with peepholes + if rnn_type == "LSTM": + verify_rnn( + seq_length=2, + batch_size=1, + input_size=16, + hidden_size=32, + use_bias=True, + use_initial_state=True, + use_peep=True, + rnn_type="LSTM", + directions=directions, + target=target, + dev=dev, + ) @tvm.testing.parametrize_targets -def test_gru(target, dev): - # Set seed for test reproduction - np.random.seed(137) - for directions in [1, 2]: - # No bias. - verify_rnn( - seq_length=2, - batch_size=1, - input_size=16, - hidden_size=32, - use_bias=False, - rnn_type="GRU", - directions=directions, - rtol=1e-6, - atol=1e-6, - target=target, - dev=dev, - ) - # large batch. linear before reset - verify_rnn( - seq_length=4, - batch_size=8, - input_size=16, - hidden_size=32, - use_bias=True, - rnn_type="GRU", - linear_before_reset=True, - directions=directions, - target=target, - dev=dev, - ) - # Non power of two. - verify_rnn( - seq_length=3, - batch_size=3, - input_size=16, - hidden_size=40, - use_bias=True, - rnn_type="GRU", - directions=directions, - rtol=1e-6, - atol=1e-6, - target=target, - dev=dev, - ) - # Long sequence. - verify_rnn( - seq_length=8, - batch_size=1, - input_size=16, - hidden_size=32, - use_bias=True, - rnn_type="GRU", - directions=directions, - rtol=1e-6, - atol=1e-6, - target=target, - dev=dev, - ) - # Large hidden. - verify_rnn( - seq_length=2, - batch_size=1, - input_size=16, - hidden_size=128, - use_bias=True, - rnn_type="GRU", - directions=directions, - rtol=1e-6, - atol=1e-6, - target=target, - dev=dev, - ) - # Large input. - verify_rnn( - seq_length=2, - batch_size=1, - input_size=64, - hidden_size=32, - use_bias=True, - rnn_type="GRU", - directions=directions, - rtol=1e-6, - atol=1e-6, - target=target, - dev=dev, - ) +def test_rnn(target, dev): + verify_rnn_helper(target, dev, "RNN") - # Different activation testing. - # Default value hardsigmoid. 
- verify_rnn( - seq_length=2, - batch_size=1, - input_size=16, - hidden_size=32, - use_bias=False, - activations=["HardSigmoid", "Softsign"] * directions, - rnn_type="GRU", - directions=directions, - rtol=1e-6, - atol=1e-6, - target=target, - dev=dev, - ) - # Multiple parametrized activations. - verify_rnn( - seq_length=2, - batch_size=1, - input_size=16, - hidden_size=32, - use_bias=False, - activations=["HardSigmoid", "LeakyRelu"] * directions, - alphas=[2.0, 0.5] * directions, - betas=[0.3, 0.0] * directions, - rnn_type="GRU", - directions=directions, - rtol=1e-8, - atol=1e-8, - target=target, - dev=dev, - ) - # All parametrized with new Affine activation. - verify_rnn( - seq_length=2, - batch_size=1, - input_size=16, - hidden_size=32, - use_bias=False, - activations=["HardSigmoid", "Affine"] * directions, - alphas=[2.0, 0.8] * directions, - betas=[0.3, 0.1] * directions, - rnn_type="GRU", - directions=directions, - rtol=1e-8, - atol=1e-8, - target=target, - dev=dev, - ) - # Testing with initial state - verify_rnn( - seq_length=2, - batch_size=1, - input_size=16, - hidden_size=32, - use_bias=True, - use_initial_state=True, - rnn_type="GRU", - directions=directions, - rtol=1e-6, - atol=1e-6, - target=target, - dev=dev, - ) +@tvm.testing.parametrize_targets +def test_lstm(target, dev): + verify_rnn_helper(target, dev, "LSTM") + + +@tvm.testing.parametrize_targets +def test_gru(target, dev): + verify_rnn_helper(target, dev, "GRU") @tvm.testing.parametrize_targets From 396e7dc9d4dda9f53894a0482c8f6c34c97bcad9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E4=BA=A6=E9=A9=B0?= Date: Wed, 3 Aug 2022 09:30:40 +0800 Subject: [PATCH 6/6] reformat verify_rnn_helper --- tests/python/frontend/onnx/test_forward.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index a0dcb71aefce..9d1817b7a310 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -4019,12 +4019,9 @@ def verify_rnn_helper(target, dev, rnn_type): dev=dev, ) # All parametrized with new Affine activation. - activations = ["Affine", "LeakyRelu", "HardSigmoid"] - alphas = [0.8, 2.0, 0.5] - betas = [0.0, 0.3, 0.0] - activations = activations[0:num_activations] * directions - alphas = alphas[0:num_activations] * directions - betas = betas[0:num_activations] * directions + activations = ["Affine", "LeakyRelu", "HardSigmoid"][0:num_activations] * directions + alphas = [0.8, 2.0, 0.5][0:num_activations] * directions + betas = [0.0, 0.3, 0.0][0:num_activations] * directions verify_rnn( seq_length=2, batch_size=1,
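
For reference, below is a minimal sketch (not part of the patch) of how the new ONNX `RNN` converter can be exercised end to end through the Relay frontend once this series is applied. The model file name `rnn_model.onnx`, the graph input name `X`, and the `(seq_length, batch_size, input_size) = (2, 1, 16)` shape are illustrative assumptions mirroring the shapes used in `test_rnn`; only public TVM APIs are used.

```python
import numpy as np
import onnx
import tvm
from tvm import relay
from tvm.contrib import graph_executor

# Hypothetical ONNX file containing a single RNN node; path, input name "X",
# and shape (seq_length=2, batch_size=1, input_size=16) are assumptions.
model = onnx.load("rnn_model.onnx")
mod, params = relay.frontend.from_onnx(model, {"X": (2, 1, 16)})

# Compile for CPU; the RNN node is lowered via the converter added in this series.
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target="llvm", params=params)

dev = tvm.cpu(0)
rt = graph_executor.GraphModule(lib["default"](dev))
rt.set_input("X", np.random.uniform(size=(2, 1, 16)).astype("float32"))
rt.run()

# Per the ONNX RNN spec (layout=0): Y then Y_h.
y = rt.get_output(0).numpy()    # (seq_length, num_directions, batch_size, hidden_size)
y_h = rt.get_output(1).numpy()  # (num_directions, batch_size, hidden_size)
```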