[Pytorch] add aten::rnn_tanh, aten::rnn_relu (apache#12017)
* emptycommit 2nd try

* dev

* comments

* format

* format

Co-authored-by: yuanfz <42092999+FZYUAN-1@users.noreply.github.com>
2 people authored and Mikael Sevenier committed Jul 26, 2022
1 parent 0acda96 commit 067ea66
Showing 3 changed files with 307 additions and 1 deletion.
40 changes: 40 additions & 0 deletions python/tvm/relay/frontend/common.py
@@ -686,6 +686,46 @@ def unbind(data, axis=0):
return _expr.TupleWrapper(_expr.Tuple(ret), selections)


def rnn_cell(
input_seqs, hidden_state, w_inp, w_hid, b_inp=None, b_hid=None, backwards=False, act=_op.tanh
):
"""
Common implementation of RNN cell for all frontends of TVM
Parameters
----------
input_seqs : List[relay.Expr]
The sequence of input tensors
Input tensor should be 2d while issue #8412 is not resolved
Shape = (batch, feature_size)
hidden_state : relay.Expr
Hidden state. shape = (batch_size, hidden_size)
w_inp, w_hid: relay.Expr
weight matrices. shape = (hidden_size, feature_size), (hidden_size, feature_size)
b_inp, b_hid : relay.Expr
bias matrices. The same order of internal parts as for weights. shape = (1 * hidden_size)
backwards : bool
Flag for reverse pass of RNN
act : relay.op
activation function. It is tanh by default.
Returns
-------
result : List[relay.Expr], relay.Expr, relay.Expr
The sequence of computed result, final hidden and cell state
"""
outputs_list = []
for x_t in input_seqs if not backwards else reversed(input_seqs):
xwt = _op.nn.dense(x_t, w_inp)
hwt = _op.nn.dense(hidden_state, w_hid)
if b_inp is not None and b_hid is not None:
xwt += b_inp
hwt += b_hid
hidden_state = act(xwt + hwt)
outputs_list.append(hidden_state) # [seq_num, (batch, hidden_size)]
return outputs_list, hidden_state
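
For reference, each step above computes h_t = act(x_t · W_inpᵀ + b_inp + h_(t-1) · W_hidᵀ + b_hid), since _op.nn.dense multiplies by the transposed weight. A minimal NumPy sketch of one step (names and shapes are illustrative, not part of the patch):

import numpy as np

def rnn_step_numpy(x_t, h_prev, w_inp, w_hid, b_inp=None, b_hid=None, act=np.tanh):
    """One RNN step mirroring rnn_cell: dense(x, W) == x @ W.T."""
    xwt = x_t @ w_inp.T      # (batch, hidden_size)
    hwt = h_prev @ w_hid.T   # (batch, hidden_size)
    if b_inp is not None and b_hid is not None:
        xwt = xwt + b_inp
        hwt = hwt + b_hid
    return act(xwt + hwt)    # new hidden state, (batch, hidden_size)

# Example shapes: batch=2, feature_size=8, hidden_size=16
h = rnn_step_numpy(
    np.random.rand(2, 8),
    np.zeros((2, 16)),
    np.random.rand(16, 8),
    np.random.rand(16, 16),
)
print(h.shape)  # (2, 16)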


def gru_cell(
input_seqs,
hidden_state,
189 changes: 188 additions & 1 deletion python/tvm/relay/frontend/pytorch.py
@@ -40,7 +40,7 @@
from ..prelude import Prelude, StaticTensorArrayOps
from ..ty import Any, TensorType, TupleType
from . import qnn_torch
from .common import AttrCvt, get_relay_op, gru_cell, logger
from .common import AttrCvt, get_relay_op, gru_cell, logger, rnn_cell
from .common import infer_shape as _infer_shape
from .common import infer_value as _infer_value
from .common import infer_value_simulated as _infer_value_simulated
@@ -2630,6 +2630,191 @@ def flip(self, inputs, input_types):
axis = inputs[1]
return _op.transform.reverse(data, axis=axis[0])

def bidir_rnn_cell(self, input_seqs, weights_dicts, act=_op.tanh):
"""
Bidirectional RNN cell
"""
seq_len = len(input_seqs)
forward_outputs, fw_H_t = rnn_cell(input_seqs, **weights_dicts[0], backwards=False, act=act)

reverse_outputs, rev_H_t = rnn_cell(input_seqs, **weights_dicts[1], backwards=True, act=act)

final_outputs = []
for i in range(seq_len):
final_outputs.append(
_op.concatenate([forward_outputs[i], reverse_outputs[seq_len - 1 - i]], axis=-1)
)

return final_outputs, _op.stack([fw_H_t, rev_H_t], axis=0)
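
Note that the reverse pass returns its outputs in reversed time order, so the backward output for time step i sits at index seq_len - 1 - i; the loop above realigns the two directions before concatenating along the last axis. A purely illustrative index check:

# Illustrative only: realigning forward and backward outputs by time step.
seq_len = 3
fwd = ["f0", "f1", "f2"]  # forward outputs, in time order
rev = ["r2", "r1", "r0"]  # backward outputs, in reversed time order
aligned = [(fwd[i], rev[seq_len - 1 - i]) for i in range(seq_len)]
print(aligned)  # [('f0', 'r0'), ('f1', 'r1'), ('f2', 'r2')]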

def rnn_layers(self, input_data, layer_weights_dicts, bidirectional, act, dropout_p=0.0):
"""
Iterates over the layers of a stacked RNN
"""
layers_num = len(layer_weights_dicts)
# split input sequence to samples set
input_seqs = unbind(input_data, 0) # [seq_num, (batch, feature_size)]
output_hiddens = []
for i in range(layers_num):
weights_dicts = layer_weights_dicts[i]
# input_seqs shape = [seq_num, (batch, feature_size)] for the first layer, or
# [seq_num, (batch, num_directions * hidden_size)] for the subsequent ones
if bidirectional:
input_seqs, H_t = self.bidir_rnn_cell(input_seqs, weights_dicts, act=act)
else:
input_seqs, H_t = rnn_cell(input_seqs, **weights_dicts[0], act=act)

output_hiddens.append(H_t)

# TODO (yuanfz98): in pytorch implementation train is also checked
# see https://github.com/pytorch/pytorch/blob/70c8daf43946b53af6493d058899ef952d27d339
# /aten/src/ATen/native/RNN.cpp#L1054
if dropout_p != 0 and i < layers_num - 1:
# for input in input_seqs:
# input = _op.dropout(input, dropout_p)
raise NotImplementedError("Dropout for RNN has not been supported yet!")
output_hiddens = (
_op.concatenate(output_hiddens, 0) if bidirectional else _op.stack(output_hiddens, 0)
)
return _op.stack(input_seqs, 0), output_hiddens
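
In the bidirectional case each H_t coming out of bidir_rnn_cell is already stacked to (2, batch, hidden_size), so the per-layer hiddens are concatenated rather than stacked, which yields the (num_directions * num_layers, batch, hidden_size) layout PyTorch documents for h_n. A NumPy shape check with illustrative sizes:

import numpy as np

batch, hidden_size, num_layers = 2, 16, 3
# Bidirectional: each layer contributes a (2, batch, hidden_size) block.
bi_hiddens = [np.zeros((2, batch, hidden_size)) for _ in range(num_layers)]
print(np.concatenate(bi_hiddens, axis=0).shape)  # (6, 2, 16)
# Unidirectional: each layer contributes a (batch, hidden_size) state.
uni_hiddens = [np.zeros((batch, hidden_size)) for _ in range(num_layers)]
print(np.stack(uni_hiddens, axis=0).shape)  # (3, 2, 16)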

def rnn(self, inputs, input_types, nonlinearity):
"""
Description of RNN in pytorch:
https://pytorch.org/docs/stable/generated/torch.nn.RNN.html#torch.nn.RNN
Description of inputs:
https://github.com/pytorch/pytorch/blob/736fb7d22cc948b739db2c35aeb5ad4d19aea4f4/torch/overrides.py#L937
"""
# TODO (yuanfz98): support dropout
assert len(inputs) == 9, "Input of size 9 is expected"
# Unpack inputs, note that if optional and not provided then value will be None.
_X = inputs[0]
# _X shape (seq_num, batch, feature_size) or (batch, seq_num, feature_size)

hidden_state = inputs[1]
# Hidden state shape (hidden_layers_num, batch, hidden_size)

_weights = inputs[2]
# Wi layer[0] shape (hidden_size, feature_size)
# Wh layer[0] shape (hidden_size, hidden_size)
# Bi layer[0] shape (hidden_size)
# Bh layer[0] shape (hidden_size)

# Wi layer[>0] shape (hidden_size, hidden_size * num_directions)
# Wh layer[>0] shape (hidden_size, hidden_size)
# Bi layer[>0] shape (hidden_size)
# Bh layer[>0] shape (hidden_size)

# Scalar inputs
has_biases = inputs[3]
num_layers = inputs[4]
dropout_p = inputs[5] # dropout probability, if 0.0 it means there is no dropout
# train = inputs[6]
bidirectional = inputs[7]
batch_first = inputs[8]

num_directions = 1
if bidirectional:
num_directions = 2

rsd = len(_weights) % num_layers
assert rsd == 0, "The number of weights must be a multiple of the number of layers!"
rsd = (len(_weights) / num_layers) % num_directions
assert (
rsd == 0
), "The number of weights in layer must be a multiple of the number of directions!"

weights_num = int(len(_weights) / num_layers / num_directions)
if has_biases:
assert weights_num == 4, "The number of weights per layer is expected to be 4"
else:
assert weights_num == 2, "The number of weights per layer is expected to be 2"
act = None
if nonlinearity == "tanh":
act = _op.tanh
elif nonlinearity == "relu":
act = _op.nn.relu
assert act, "The nonlinearity is unknown"
X = (
_op.transpose(_X, (1, 0, 2)) if batch_first else _X
) # always (seq_num, batch, feature_size)
# TODO (yuanfz98): Which data type should be used, the input's or the weights'?
# Alternatively, _infer_type(X).checked_type.dtype could be used
X_dtype = input_types[0]
X_shape = _infer_shape(X) # (seq_num, batch, feature_size)

hidden_size = int(_infer_shape(_weights[0])[0])
batch_size = X_shape[1]

# Initialize hidden states if not provided.
layers_h = []
hidden_layers_num = num_directions * num_layers
if hidden_state is None:
h_0 = _op.zeros((batch_size, hidden_size), X_dtype)
for i in range(hidden_layers_num):
layers_h.append(h_0)
else:
layers_h = unbind(hidden_state, 0)

layer_weights_dicts = []
k = 0 # layer counter
if has_biases:
names = ["hidden_state", "w_inp", "w_hid", "b_inp", "b_hid"]
if bidirectional:
rsd = len(_weights) % (2 * weights_num)
assert rsd == 0, "got an incorrect number of RNN weights"
for i in range(0, len(_weights), 2 * weights_num):
fw_tensors = [layers_h[2 * k], *_weights[i : i + 4]]
fw_weights_dict = dict(zip(names, fw_tensors))
j = i + weights_num
rev_tensors = [layers_h[2 * k + 1], *_weights[j : j + 4]]
rev_weights_dict = dict(zip(names, rev_tensors))
layer_weights_dicts.append([fw_weights_dict, rev_weights_dict])
k += 1
else:
assert len(_weights) % weights_num == 0, "got an incorrect number of GRU weights"
for i in range(0, len(_weights), weights_num):
fw_tensors = [layers_h[k], *_weights[i : i + 4]]
fw_weights_dict = dict(zip(names, fw_tensors))
layer_weights_dicts.append([fw_weights_dict])
k += 1
else:
names = ["hidden_state", "w_inp", "w_hid"]
if bidirectional:
rsd = len(_weights) % (2 * weights_num)
assert rsd == 0, "got an incorrect number of RNN weights"
for i in range(0, len(_weights), 2 * weights_num):
fw_tensors = [layers_h[2 * k], *_weights[i : i + 2]]
fw_weights_dict = dict(zip(names, fw_tensors))
j = i + weights_num
rev_tensors = [layers_h[2 * k + 1], *_weights[j : j + 2]]
rev_weights_dict = dict(zip(names, rev_tensors))
layer_weights_dicts.append([fw_weights_dict, rev_weights_dict])
k += 1
else:
assert len(_weights) % weights_num == 0, "got an incorrect number of RNN weights"
for i in range(0, len(_weights), weights_num):
fw_tensors = [layers_h[k], *_weights[i : i + 2]]
fw_weights_dict = dict(zip(names, fw_tensors))
layer_weights_dicts.append([fw_weights_dict])
k += 1
assert (
len(layer_weights_dicts) == num_layers and k == num_layers
), "For stacked RNN number of weights sets should be the same as number of layers!"
output, out_hidden_state = self.rnn_layers(
X,
layer_weights_dicts,
bidirectional,
act,
dropout_p=dropout_p,
)

# output shape = (seq_num, batch, hidden_size) or
# (seq_num, batch, 2*hidden_size) for bidirectional
if batch_first:
output = _op.transpose(output, (1, 0, 2))

return (output, out_hidden_state)
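
The weight dictionaries built above follow PyTorch's flat parameter layout: [W_ih, W_hh, b_ih, b_hh] per direction per layer (biases dropped when has_biases is false), with the forward direction first. A standalone sketch of the same grouping in plain Python (illustrative helper, not part of the patch):

def group_rnn_params(flat_params, weights_num, bidirectional):
    """Group a flat torch parameter list into per-layer (forward[, reverse]) sets."""
    step = 2 * weights_num if bidirectional else weights_num
    assert len(flat_params) % step == 0, "got an incorrect number of RNN weights"
    layers = []
    for i in range(0, len(flat_params), step):
        fw = flat_params[i : i + weights_num]
        if bidirectional:
            rev = flat_params[i + weights_num : i + 2 * weights_num]
            layers.append((fw, rev))
        else:
            layers.append((fw,))
    return layers

# Two bidirectional layers without biases: 2 layers * 2 directions * 2 tensors.
print(group_rnn_params(list("abcdefgh"), weights_num=2, bidirectional=True))
# [(['a', 'b'], ['c', 'd']), (['e', 'f'], ['g', 'h'])]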

def bidir_gru_cell(
self,
input_seqs,
@@ -3442,6 +3627,8 @@ def create_convert_map(self):
"aten::l1_loss": self.l1_loss,
"aten::mse_loss": self.mse_loss,
"aten::flip": self.flip,
"aten::rnn_tanh": functools.partial(self.rnn, nonlinearity="tanh"),
"aten::rnn_relu": functools.partial(self.rnn, nonlinearity="relu"),
"aten::gru": self.gru,
"aten::lstm": self.lstm,
"aten::all": functools.partial(self.all_any_common, _op.all),
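
Both new aten ops are served by the single rnn converter above, with the nonlinearity baked in via functools.partial. A minimal sketch of that dispatch pattern (toy class and names, not the real converter):

import functools

class ToyConverter:
    def rnn(self, inputs, input_types, nonlinearity):
        return "rnn with " + nonlinearity

    def create_convert_map(self):
        return {
            "aten::rnn_tanh": functools.partial(self.rnn, nonlinearity="tanh"),
            "aten::rnn_relu": functools.partial(self.rnn, nonlinearity="relu"),
        }

convert_map = ToyConverter().create_convert_map()
print(convert_map["aten::rnn_tanh"]([], []))  # rnn with tanh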
79 changes: 79 additions & 0 deletions tests/python/frontend/pytorch/test_rnns.py
@@ -40,6 +40,10 @@
seqs_length = 2
batch_size = 2

## RNN parameters
rnn_feature_size = 8
rnn_hidden_size = 16


class RNN_Model(nn.Module):
"""
@@ -93,6 +97,72 @@ def get_tvm_inputs(self, dtype):
raise NotImplementedError("subclasses must override get_tvm_inputs(dtype)!")


class RNN_Model_Impl(RNN_Model):
def __init__(
self,
seq_len=seqs_length,
batch_size=batch_size,
feature_size=rnn_feature_size,
hidden_size=rnn_hidden_size,
batch_first=False,
layer_num=1,
bidirectional=False,
use_bias=True,
rnd_weights_init=False,
nonlinearity="tanh",
dropout=0.0,
):
super().__init__()
# Shapes
self.shape = [seq_len, batch_size, feature_size]
if batch_first:
self.shape = [batch_size, seq_len, feature_size]
layers_num = 2 * layer_num if bidirectional else layer_num
self.h0_shape = [layers_num, batch_size, hidden_size]
# Dummy inputs
self.dummy_inputs = (torch.rand(self.shape), torch.zeros(self.h0_shape))

self.model = nn.RNN(
input_size=feature_size,
hidden_size=hidden_size,
num_layers=layer_num,
nonlinearity=nonlinearity,
bias=use_bias,
batch_first=batch_first,
dropout=dropout,
bidirectional=bidirectional,
)

if rnd_weights_init:
self.gen_rnd_weights()

def gen_rnd_weights(self):
super().gen_rnd_weights()

def get_dummy_inputs(self):
return self.dummy_inputs

def get_input_names(self):
return ["input", "h0"]

def get_shape_desc(self, frontend_type):
shape_desc = None
if frontend_type == "pt": # PyTorch
shape_desc = [("input", self.shape)]
elif frontend_type == "onnx": # ONNX
shape_desc = {
"input": self.shape,
"h0": self.h0_shape,
}
return shape_desc

def get_tvm_inputs(self, dtype):
return {
"input": tvm.nd.array(self.dummy_inputs[0].numpy().astype(dtype)),
"h0": tvm.nd.array(self.dummy_inputs[1].numpy().astype(dtype)),
}
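
The wrapper can be exercised directly in PyTorch before going through TVM; a quick smoke run with the module-level defaults (output shapes as documented for torch.nn.RNN; this snippet is illustrative and not part of the test suite):

# Illustrative smoke check of the wrapper above, using the default sizes.
model = RNN_Model_Impl(bidirectional=True, nonlinearity="relu")
inp, h0 = model.get_dummy_inputs()
out, h_n = model.model(inp, h0)
print(out.shape)  # (seqs_length, batch_size, 2 * rnn_hidden_size) -> (2, 2, 32)
print(h_n.shape)  # (2, batch_size, rnn_hidden_size) -> (2, 2, 16)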


class GRU_Model(RNN_Model):
def __init__(
self,
@@ -331,13 +401,19 @@ def get_model(
args["bidirectional"] = True
if "s" in rnn_mod:
args["layer_num"] = num_layers
if "tanh" in rnn_mod:
args["nonlinearity"] = "tanh"
if "relu" in rnn_mod:
args["nonlinearity"] = "relu"

if rnn_type == "GRU":
RNN_Model_selector = GRU_Model
elif rnn_type == "LSTM":
RNN_Model_selector = LSTM_Model
if "p" in rnn_mod:
args["proj_size"] = lstm_projection_size
elif rnn_type == "RNN":
RNN_Model_selector = RNN_Model_Impl

return RNN_Model_selector(**args)

@@ -425,6 +501,9 @@ def test_rnns():
for mod_type in ["uni", "s", "b", "sb"]:
check_rnn("LSTM", mod_type, target, dev)

for mod_type in ["uni", "s", "b", "sb", "tanh", "relu"]:
check_rnn("RNN", mod_type, target, dev)


if __name__ == "__main__":
test_rnns()
