Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Converter] Support GRU operator conversion with separated_rnn_gate_calc=False #323

Merged
merged 7 commits into from
Jun 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/FAQ.md
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ You may also try out static quantization for LSTMs when you have PyTorch 1.13+.
#### What if my model runs slower when dynamic quantization is enabled?
Please refer to [dynamic_with_selection.py](../examples/converter/dynamic_with_selection.py) for selective dynamic quantization.

#### I need LSTMs with separated gate calculation when `unroll_rnn=True`.
#### I need LSTM/GRUs with separated gate calculation when `unroll_rnn=True`.
Please set `separated_rnn_gate_calc=True`.

#### How to add state inputs and outputs for LSTMs/GRUs/RNNs with `unroll_rnn=True`?
Expand Down
2 changes: 1 addition & 1 deletion docs/FAQ_zh-CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ Note: 这些状态变量都是二维的,维度为`[batch_size, hidden_size或
#### 我的模型开了动态量化变得更慢了?
请参考 [dynamic_with_selection.py](../examples/converter/dynamic_with_selection.py) 选择性的开启动态量化。

#### 在设置了`unroll_rnn=True`后,LSTM中多个门的计算被融合了。有没有办法分开?
#### 在设置了`unroll_rnn=True`后,LSTM/GRU中多个门的计算被融合了。有没有办法分开?
尝试设置`separated_rnn_gate_calc=True`。

#### 在`unroll_rnn=True`的情况下,怎么为包含LSTM、RNN和GRU的网络添加状态输入输出?
Expand Down
176 changes: 176 additions & 0 deletions tests/converter_op_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2952,6 +2952,30 @@ def forward(self, x):
tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
assert_close(dummy_output, tfl_output)

def test_gru_unroll_unseparated(self):
dummy_input = torch.randn(9, 1, 10, dtype=torch.float32)

class Model(nn.Module):
def __init__(self) -> None:
super().__init__()
self.gru = nn.GRU(10, 20)

def forward(self, x):
return self.gru(x)[0]

model = Model()
model.eval()

model_path = get_model_path()
converter = TFLiteConverter(
model, dummy_input, model_path, nchw_transpose=False, unroll_rnn=True, separated_rnn_gate_calc=False
)
converter.convert()

dummy_output = model(dummy_input)
tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
assert_close(dummy_output, tfl_output)

def test_gru_batch_first_unroll_separated(self):
dummy_input = torch.randn(1, 9, 10, dtype=torch.float32)

Expand All @@ -2976,6 +3000,30 @@ def forward(self, x):
tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
assert_close(dummy_output, tfl_output, check_stride=False)

def test_gru_batch_first_unroll_unseparated(self):
dummy_input = torch.randn(1, 9, 10, dtype=torch.float32)

class Model(nn.Module):
def __init__(self) -> None:
super().__init__()
self.gru = nn.GRU(10, 20, batch_first=True)

def forward(self, x):
return self.gru(x)[0]

model = Model()
model.eval()

model_path = get_model_path()
converter = TFLiteConverter(
model, dummy_input, model_path, nchw_transpose=False, unroll_rnn=True, separated_rnn_gate_calc=False
)
converter.convert()

dummy_output = model(dummy_input)
tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
assert_close(dummy_output, tfl_output, check_stride=False)

def test_gru_with_state_tensor_unroll_separated(self):
dummy_input = [
torch.randn(9, 1, 10, dtype=torch.float32),
Expand Down Expand Up @@ -3004,6 +3052,34 @@ def forward(self, x, hx):
tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
assert_close(dummy_output, tfl_output)

def test_gru_with_state_tensor_unroll_unseparated(self):
dummy_input = [
torch.randn(9, 1, 10, dtype=torch.float32),
torch.randn(1, 1, 20, dtype=torch.float32),
]

class Model(nn.Module):
def __init__(self) -> None:
super().__init__()
self.gru = nn.GRU(10, 20)

def forward(self, x, hx):
gru, hx = self.gru(x, hx)
return gru, hx

model = Model()
model.eval()

model_path = get_model_path()
converter = TFLiteConverter(
model, dummy_input, model_path, nchw_transpose=False, unroll_rnn=True, separated_rnn_gate_calc=False
)
converter.convert()

dummy_output = model(*dummy_input)
tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
assert_close(dummy_output, tfl_output)

def test_gru_multi_layer_unroll_separated(self):
dummy_input = torch.randn(9, 1, 10, dtype=torch.float32)

Expand All @@ -3028,6 +3104,30 @@ def forward(self, x):
tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
assert_close(dummy_output, tfl_output)

def test_gru_multi_layer_unroll_unseparated(self):
dummy_input = torch.randn(9, 1, 10, dtype=torch.float32)

class Model(nn.Module):
def __init__(self) -> None:
super().__init__()
self.gru = nn.GRU(10, 20, 2)

def forward(self, x):
return self.gru(x)[0]

model = Model()
model.eval()

model_path = get_model_path()
converter = TFLiteConverter(
model, dummy_input, model_path, nchw_transpose=False, unroll_rnn=True, separated_rnn_gate_calc=False
)
converter.convert()

dummy_output = model(dummy_input)
tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
assert_close(dummy_output, tfl_output)

def test_gru_multi_layer_with_state_tensor_unroll_separated(self):
dummy_input = [
torch.randn(9, 1, 10, dtype=torch.float32),
Expand Down Expand Up @@ -3056,6 +3156,34 @@ def forward(self, x, hx):
tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
assert_close(dummy_output, tfl_output)

def test_gru_multi_layer_with_state_tensor_unroll_unseparated(self):
dummy_input = [
torch.randn(9, 1, 10, dtype=torch.float32),
torch.randn(2, 1, 20, dtype=torch.float32),
]

class Model(nn.Module):
def __init__(self) -> None:
super().__init__()
self.gru = nn.GRU(10, 20, 2)

def forward(self, x, hx):
gru, hx = self.gru(x, hx)
return gru, hx

model = Model()
model.eval()

model_path = get_model_path()
converter = TFLiteConverter(
model, dummy_input, model_path, nchw_transpose=False, unroll_rnn=True, separated_rnn_gate_calc=False
)
converter.convert()

dummy_output = model(*dummy_input)
tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
assert_close(dummy_output, tfl_output)

def test_bigru(self):
dummy_input = torch.randn(9, 1, 10, dtype=torch.float32)

Expand Down Expand Up @@ -3278,6 +3406,30 @@ def forward(self, x):
tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
assert_close(dummy_output, tfl_output)

def test_bigru_unroll_unseparated(self):
dummy_input = torch.randn(9, 1, 10, dtype=torch.float32)

class Model(nn.Module):
def __init__(self) -> None:
super().__init__()
self.gru = nn.GRU(10, 20, bidirectional=True)

def forward(self, x):
return self.gru(x)[0]

model = Model()
model.eval()

model_path = get_model_path()
converter = TFLiteConverter(
model, dummy_input, model_path, nchw_transpose=False, unroll_rnn=True, separated_rnn_gate_calc=False
)
converter.convert()

dummy_output = model(dummy_input)
tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
assert_close(dummy_output, tfl_output)

def test_bigru_multi_layer_unroll_separated(self):
dummy_input = torch.randn(9, 1, 10, dtype=torch.float32)

Expand All @@ -3302,6 +3454,30 @@ def forward(self, x):
tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
assert_close(dummy_output, tfl_output)

def test_bigru_multi_layer_unroll_unseparated(self):
dummy_input = torch.randn(9, 1, 10, dtype=torch.float32)

class Model(nn.Module):
def __init__(self) -> None:
super().__init__()
self.gru = nn.GRU(10, 20, 2, bidirectional=True)

def forward(self, x):
return self.gru(x)[0]

model = Model()
model.eval()

model_path = get_model_path()
converter = TFLiteConverter(
model, dummy_input, model_path, nchw_transpose=False, unroll_rnn=True, separated_rnn_gate_calc=False
)
converter.convert()

dummy_output = model(dummy_input)
tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
assert_close(dummy_output, tfl_output)

def test_lstm(self):
dummy_input = torch.randn(9, 1, 10, dtype=torch.float32)

Expand Down
Loading
Loading