[Pytorch] add aten::rnn_tanh, aten::rnn_relu #12017

Merged 17 commits on Jul 7, 2022
40 changes: 40 additions & 0 deletions python/tvm/relay/frontend/common.py
@@ -686,6 +686,46 @@ def unbind(data, axis=0):
return _expr.TupleWrapper(_expr.Tuple(ret), selections)


def rnn_cell(
input_seqs, hidden_state, w_inp, w_hid, b_inp=None, b_hid=None, backwards=False, act=_op.tanh
):
"""
Common implementation of RNN cell for all frontends of TVM

Parameters
----------
input_seqs : List[relay.Expr]
The sequence of input tensors
Input tensors should be 2d until issue #8412 is resolved
Shape = (batch, feature_size)
hidden_state : relay.Expr
Hidden state. shape = (batch_size, hidden_size)
w_inp, w_hid : relay.Expr
weight matrices. shape = (hidden_size, feature_size), (hidden_size, hidden_size)
b_inp, b_hid : relay.Expr
bias vectors. shape = (hidden_size,)
backwards : bool
Flag for reverse pass of RNN
act : relay.op
activation function. It is tanh by default.

Returns
-------
result : List[relay.Expr], relay.Expr
The sequence of computed outputs and the final hidden state
"""
outputs_list = []
for x_t in input_seqs if not backwards else reversed(input_seqs):
xwt = _op.nn.dense(x_t, w_inp)
hwt = _op.nn.dense(hidden_state, w_hid)
if b_inp is not None and b_hid is not None:
xwt += b_inp
hwt += b_hid
hidden_state = act(xwt + hwt)
outputs_list.append(hidden_state) # [seq_num, (batch, hidden_size)]
return outputs_list, hidden_state
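
A minimal NumPy sketch of what one time step of rnn_cell computes (illustrative only, not part of the diff; rnn_cell_step is a hypothetical name, and _op.nn.dense(x, w) in the Relay code corresponds to x @ w.T here):

import numpy as np

def rnn_cell_step(x_t, h_prev, w_inp, w_hid, b_inp=None, b_hid=None, act=np.tanh):
    # x_t: (batch, feature_size), h_prev: (batch, hidden_size)
    xwt = x_t @ w_inp.T        # (batch, hidden_size)
    hwt = h_prev @ w_hid.T     # (batch, hidden_size)
    if b_inp is not None and b_hid is not None:
        xwt = xwt + b_inp
        hwt = hwt + b_hid
    return act(xwt + hwt)      # next hidden state, (batch, hidden_size)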


def gru_cell(
input_seqs,
hidden_state,
189 changes: 188 additions & 1 deletion python/tvm/relay/frontend/pytorch.py
Expand Up @@ -40,7 +40,7 @@
from ..prelude import Prelude, StaticTensorArrayOps
from ..ty import Any, TensorType, TupleType
from . import qnn_torch
from .common import AttrCvt, get_relay_op, gru_cell, logger
from .common import AttrCvt, get_relay_op, gru_cell, logger, rnn_cell
from .common import infer_shape as _infer_shape
from .common import infer_value as _infer_value
from .common import infer_value_simulated as _infer_value_simulated
@@ -2630,6 +2630,191 @@ def flip(self, inputs, input_types):
axis = inputs[1]
return _op.transform.reverse(data, axis=axis[0])

def bidir_rnn_cell(self, input_seqs, weights_dicts, act=_op.tanh):
"""
Bidirectional RNN cell
"""
seq_len = len(input_seqs)
forward_outputs, fw_H_t = rnn_cell(input_seqs, **weights_dicts[0], backwards=False, act=act)

reverse_outputs, rev_H_t = rnn_cell(input_seqs, **weights_dicts[1], backwards=True, act=act)

final_outputs = []
for i in range(seq_len):
final_outputs.append(
_op.concatenate([forward_outputs[i], reverse_outputs[seq_len - 1 - i]], axis=-1)
)

return final_outputs, _op.stack([fw_H_t, rev_H_t], axis=0)
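
# Illustrative note (not part of the diff): rnn_cell with backwards=True appends hidden
# states in reversed time order, so for seq_len == 3 the concatenation above pairs
# forward_outputs[0] with reverse_outputs[2]; both correspond to time step 0.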

def rnn_layers(self, input_data, layer_weights_dicts, bidirectional, act, dropout_p=0.0):
"""
Method iterates over layers for a stacked RNN
"""
layers_num = len(layer_weights_dicts)
# split input sequence to samples set
input_seqs = unbind(input_data, 0) # [seq_num, (batch, feature_size)]
output_hiddens = []
for i in range(layers_num):
weights_dicts = layer_weights_dicts[i]
# input_seqs shape = [seq_num, (batch, feature_size)] for the first layer or
# [seq_num, (batch, num_directions * hidden_size)] for subsequent layers
if bidirectional:
input_seqs, H_t = self.bidir_rnn_cell(input_seqs, weights_dicts, act=act)
else:
input_seqs, H_t = rnn_cell(input_seqs, **weights_dicts[0], act=act)

output_hiddens.append(H_t)

# TODO (yuanfz98): in pytorch implementation train is also checked
# see https://github.com/pytorch/pytorch/blob/70c8daf43946b53af6493d058899ef952d27d339
# /aten/src/ATen/native/RNN.cpp#L1054
if dropout_p != 0 and i < layers_num - 1:
# for input in input_seqs:
# input = _op.dropout(input, dropout_p)
raise NotImplementedError("Dropout for RNN has not been supported yet!")
output_hiddens = (
_op.concatenate(output_hiddens, 0) if bidirectional else _op.stack(output_hiddens, 0)
)
return _op.stack(input_seqs, 0), output_hiddens

def rnn(self, inputs, input_types, nonlinearity):
"""
Description of RNN in pytorch:
https://pytorch.org/docs/stable/generated/torch.nn.RNN.html#torch.nn.RNN
Description of inputs:
https://github.com/pytorch/pytorch/blob/736fb7d22cc948b739db2c35aeb5ad4d19aea4f4/torch/overrides.py#L937
"""
# TODO (yuanfz98): support dropout
assert len(inputs) == 9, "Input of size 9 is expected"
# Unpack inputs. Optional inputs that are not provided will be None.
_X = inputs[0]
# _X shape (seq_num, batch, feature_size) or (batch, seq_num, feature_size)

hidden_state = inputs[1]
# Hidden state shape (hidden_layers_num, batch, hidden_size)

_weights = inputs[2]
# Wi layer[0] shape (hidden_size, feature_size)
# Wh layer[0] shape (hidden_size, hidden_size)
# Bi layer[0] shape (hidden_size)
# Bh layer[0] shape (hidden_size)

# Wi layer[>0] shape (hidden_size, hidden_size * num_directions)
# Wh layer[>0] shape (hidden_size, hidden_size)
# Bi layer[>0] shape (hidden_size)
# Bh layer[>0] shape (hidden_size)
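
# Illustrative example (not part of the diff): for a 2-layer bidirectional RNN with biases,
# len(_weights) == 2 layers * 2 directions * 4 tensors == 16, ordered as
# [Wi, Wh, Bi, Bh] for layer0-forward, layer0-reverse, layer1-forward, layer1-reverse.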

# Scalar inputs
has_biases = inputs[3]
num_layers = inputs[4]
dropout_p = inputs[5]  # dropout probability; 0.0 means no dropout
# train = inputs[6]
bidirectional = inputs[7]
batch_first = inputs[8]

num_directions = 1
if bidirectional:
num_directions = 2

rsd = len(_weights) % num_layers
assert rsd == 0, "The number of weights must be a multiple of the number of layers!"
rsd = (len(_weights) / num_layers) % num_directions
assert (
rsd == 0
), "The number of weights in layer must be a multiple of the number of directions!"

weights_num = int(len(_weights) / num_layers / num_directions)
if has_biases:
assert weights_num == 4, "The number of weights per layer is expected to be 4"
else:
assert weights_num == 2, "The number of weights per layer is expected to be 2"
act = None
if nonlinearity == "tanh":
act = _op.tanh
elif nonlinearity == "relu":
act = _op.nn.relu
assert act, "The nonlinearity is unknown"
X = (
_op.transpose(_X, (1, 0, 2)) if batch_first else _X
) # always (seq_num, batch, feature_size)
# TODO (yuanfz98): Which data type should be used? from input or weights?
# Alternatively, _infer_type(X).checked_type.dtype can be used
X_dtype = input_types[0]
X_shape = _infer_shape(X) # (seq_num, batch, feature_size)

hidden_size = int(_infer_shape(_weights[0])[0])
batch_size = X_shape[1]

# Initialize hidden states if not provided.
layers_h = []
hidden_layers_num = num_directions * num_layers
if hidden_state is None:
h_0 = _op.zeros((batch_size, hidden_size), X_dtype)
for i in range(hidden_layers_num):
layers_h.append(h_0)
else:
layers_h = unbind(hidden_state, 0)

layer_weights_dicts = []
k = 0 # layer counter
if has_biases:
names = ["hidden_state", "w_inp", "w_hid", "b_inp", "b_hid"]
if bidirectional:
rsd = len(_weights) % (2 * weights_num)
assert rsd == 0, "got an incorrect number of RNN weights"
for i in range(0, len(_weights), 2 * weights_num):
fw_tensors = [layers_h[2 * k], *_weights[i : i + 4]]
fw_weights_dict = dict(zip(names, fw_tensors))
j = i + weights_num
rev_tensors = [layers_h[2 * k + 1], *_weights[j : j + 4]]
rev_weights_dict = dict(zip(names, rev_tensors))
layer_weights_dicts.append([fw_weights_dict, rev_weights_dict])
k += 1
else:
assert len(_weights) % weights_num == 0, "got an incorrect number of RNN weights"
for i in range(0, len(_weights), weights_num):
fw_tensors = [layers_h[k], *_weights[i : i + 4]]
fw_weights_dict = dict(zip(names, fw_tensors))
layer_weights_dicts.append([fw_weights_dict])
k += 1
else:
names = ["hidden_state", "w_inp", "w_hid"]
if bidirectional:
rsd = len(_weights) % (2 * weights_num)
assert rsd == 0, "got an incorrect number of RNN weights"
for i in range(0, len(_weights), 2 * weights_num):
fw_tensors = [layers_h[2 * k], *_weights[i : i + 2]]
fw_weights_dict = dict(zip(names, fw_tensors))
j = i + weights_num
rev_tensors = [layers_h[2 * k + 1], *_weights[j : j + 2]]
rev_weights_dict = dict(zip(names, rev_tensors))
layer_weights_dicts.append([fw_weights_dict, rev_weights_dict])
k += 1
else:
assert len(_weights) % weights_num == 0, "got an incorrect number of RNN weights"
for i in range(0, len(_weights), weights_num):
fw_tensors = [layers_h[k], *_weights[i : i + 2]]
fw_weights_dict = dict(zip(names, fw_tensors))
layer_weights_dicts.append([fw_weights_dict])
k += 1
assert (
len(layer_weights_dicts) == num_layers and k == num_layers
), "For a stacked RNN, the number of weight sets should equal the number of layers!"
output, out_hidden_state = self.rnn_layers(
X,
layer_weights_dicts,
bidirectional,
act,
dropout_p=dropout_p,
)

# output shape = (seq_num, batch, hidden_size) or
# (seq_num, batch, 2*hidden_size) for bidirectional
if batch_first:
output = _op.transpose(output, (1, 0, 2))

return (output, out_hidden_state)

def bidir_gru_cell(
self,
input_seqs,
@@ -3442,6 +3627,8 @@ def create_convert_map(self):
"aten::l1_loss": self.l1_loss,
"aten::mse_loss": self.mse_loss,
"aten::flip": self.flip,
"aten::rnn_tanh": functools.partial(self.rnn, nonlinearity="tanh"),
"aten::rnn_relu": functools.partial(self.rnn, nonlinearity="relu"),
"aten::gru": self.gru,
"aten::lstm": self.lstm,
"aten::all": functools.partial(self.all_any_common, _op.all),
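
For context, a hedged usage sketch (not part of the PR) of how the new aten::rnn_tanh / aten::rnn_relu mappings are exercised through the PyTorch frontend; the module configuration and the input name are illustrative:

import torch
from tvm import relay

# Trace a 1-layer ReLU RNN; its forward lowers to aten::rnn_relu,
# which the frontend now converts via the rnn() method added above.
model = torch.nn.RNN(input_size=8, hidden_size=16, num_layers=1, nonlinearity="relu").eval()
inp = torch.rand(2, 2, 8)  # (seq_len, batch, feature_size)
scripted = torch.jit.trace(model, inp)
mod, params = relay.frontend.from_pytorch(scripted, [("input", (2, 2, 8))])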
79 changes: 79 additions & 0 deletions tests/python/frontend/pytorch/test_rnns.py
@@ -40,6 +40,10 @@
seqs_length = 2
batch_size = 2

##RNN parameters
rnn_feature_size = 8
rnn_hidden_size = 16


class RNN_Model(nn.Module):
"""
@@ -93,6 +97,72 @@ def get_tvm_inputs(self, dtype):
raise NotImplementedError("subclasses must override get_tvm_inputs(dtype)!")


class RNN_Model_Impl(RNN_Model):
def __init__(
self,
seq_len=seqs_length,
batch_size=batch_size,
feature_size=rnn_feature_size,
hidden_size=rnn_hidden_size,
batch_first=False,
layer_num=1,
bidirectional=False,
use_bias=True,
rnd_weights_init=False,
nonlinearity="tanh",
dropout=0.0,
):
super().__init__()
# Shapes
self.shape = [seq_len, batch_size, feature_size]
if batch_first:
self.shape = [batch_size, seq_len, feature_size]
layers_num = 2 * layer_num if bidirectional else layer_num
self.h0_shape = [layers_num, batch_size, hidden_size]
# Dummy inputs
self.dummy_inputs = (torch.rand(self.shape), torch.zeros(self.h0_shape))

self.model = nn.RNN(
input_size=feature_size,
hidden_size=hidden_size,
num_layers=layer_num,
nonlinearity=nonlinearity,
bias=use_bias,
batch_first=batch_first,
dropout=dropout,
bidirectional=bidirectional,
)

if rnd_weights_init:
self.gen_rnd_weights()

def gen_rnd_weights(self):
super().gen_rnd_weights()

def get_dummy_inputs(self):
return self.dummy_inputs

def get_input_names(self):
return ["input", "h0"]

def get_shape_desc(self, frontend_type):
shape_desc = None
if frontend_type == "pt": # PyTorch
shape_desc = [("input", self.shape)]
elif frontend_type == "onnx": # ONNX
shape_desc = {
"input": self.shape,
"h0": self.h0_shape,
}
return shape_desc

def get_tvm_inputs(self, dtype):
return {
"input": tvm.nd.array(self.dummy_inputs[0].numpy().astype(dtype)),
"h0": tvm.nd.array(self.dummy_inputs[1].numpy().astype(dtype)),
}


class GRU_Model(RNN_Model):
def __init__(
self,
@@ -331,13 +401,19 @@ def get_model(
args["bidirectional"] = True
if "s" in rnn_mod:
args["layer_num"] = num_layers
if "tanh" in rnn_mod:
args["nonlinearity"] = "tanh"
if "relu" in rnn_mod:
args["nonlinearity"] = "relu"

if rnn_type == "GRU":
RNN_Model_selector = GRU_Model
elif rnn_type == "LSTM":
RNN_Model_selector = LSTM_Model
if "p" in rnn_mod:
args["proj_size"] = lstm_projection_size
elif rnn_type == "RNN":
RNN_Model_selector = RNN_Model_Impl

return RNN_Model_selector(**args)

@@ -425,6 +501,9 @@ def test_rnns():
for mod_type in ["uni", "s", "b", "sb"]:
check_rnn("LSTM", mod_type, target, dev)

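# Illustrative note: per get_model above, "uni" is the default unidirectional single-layer
# case, "s" selects a stacked (multi-layer) model, "b" bidirectional, "sb" both,
# and "tanh"/"relu" select the RNN nonlinearity.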
for mod_type in ["uni", "s", "b", "sb", "tanh", "relu"]:
check_rnn("RNN", mod_type, target, dev)


if __name__ == "__main__":
test_rnns()