[converter] fix missing_outputs_as_constants
peterjc123 committed Apr 8, 2024
commit 1f04964 (1 parent: 6b62859)
Showing 3 changed files with 49 additions and 23 deletions.
tests/converter_op_test.py (19 additions, 0 deletions)
@@ -81,6 +81,25 @@ def u8_to_s8(t):
 
 
 class ConverterOPTester(unittest.TestCase):
+    def test_missing_outputs_as_constants(self):
+        class TestModel(nn.Module):
+            def forward(self, x):
+                y = x.relu()
+                return y, torch.zeros_like(x)
+
+        model = TestModel()
+        model.eval()
+
+        dummy_input = torch.randn(2, 10)
+        model_path = get_model_path()
+
+        converter = TFLiteConverter(model, dummy_input, model_path, missing_outputs_as_constants=True)
+        converter.convert()
+
+        dummy_output = model(dummy_input)
+        tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
+        assert_close(dummy_output, tfl_output)
+
     def test_sign(self):
         dummy_input = torch.randn(9, 1, 10, dtype=torch.float32)
 
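The new test exercises the option end to end through the project's tfl_run_model helper. For readers who want to inspect the generated file directly, here is a minimal sketch using the stock TFLite interpreter; the model path 'out.tflite' and the (2, 10) float32 input shape are assumptions mirroring the test above, not part of the commit.

# Sanity-check sketch with the stock TFLite interpreter (assumed file name
# and input shape; adjust to your converted model).
import numpy as np
import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path='out.tflite')
interpreter.allocate_tensors()

inp = interpreter.get_input_details()[0]
interpreter.set_tensor(inp['index'], np.random.randn(2, 10).astype(np.float32))
interpreter.invoke()

# With missing_outputs_as_constants=True, the output corresponding to
# torch.zeros_like(x) is served from a constant buffer and should be all zeros.
for out in interpreter.get_output_details():
    print(out['name'], interpreter.get_tensor(out['index']))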
tinynn/converter/base.py (18 additions, 1 deletion)
@@ -119,7 +119,7 @@ def __init__(
         self.graph = None
         self.tensor_map = {}
         self.tensor_map_copies = {}
-        self.common_graph = CommonGraph(missing_outputs_as_constants)
+        self.common_graph = CommonGraph()
 
         if type(dummy_input) in (tuple, list):
             self.dummy_input = dummy_input
@@ -166,6 +166,7 @@ def __init__(
         self.hybrid_gen_single_op_models = hybrid_gen_single_op_models
         self.hybrid_config = hybrid_config
         self.group_tensors = group_tensors
+        self.missing_outputs_as_constants = missing_outputs_as_constants
 
         if quantize_target_type == 'uint8':
             self.q_type = np.uint8
@@ -526,6 +527,22 @@ def convert(self):
         versioner = OPVersioner(self.common_graph)
         versioner.process()
 
+        if self.missing_outputs_as_constants:
+            tensors = []
+            for output_name in self.common_graph.outputs:
+                if output_name not in self.common_graph.tensor_map:
+                    tensors.append(
+                        Tensor(
+                            self.tensor_map[output_name],
+                            output_name,
+                            has_buffer=True,
+                            asymmetric=not self.strict_symmetric_check,
+                            q_type=self.q_type,
+                        )
+                    )
+            self.common_graph.add_nodes(tensors, ExtendedOperator.CONSTANT_NODE)
+            self.common_graph.add_outputs([t.name for t in tensors])
+
         self.common_graph.convert(self.tflite_path)
 
         log.info(f'Generated model saved to {self.tflite_path}')
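The heart of the fix sits in convert(): declared outputs that no operator produces (such as the torch.zeros_like(x) output above) are wrapped as constant tensors with buffers just before serialization, rather than being patched up inside graph.py. Below is a toy sketch of that pattern using plain dictionaries and NumPy arrays instead of the converter's CommonGraph/Tensor types; all names are illustrative only.

# Toy illustration of the pattern; not the converter's real data structures.
import numpy as np

def materialize_missing_outputs(declared_outputs, produced, captured):
    """Back declared outputs that no op produces with constant buffers."""
    constants = {}
    for name in declared_outputs:
        if name not in produced:
            # The value captured while tracing the dummy input becomes a
            # constant that serves this graph output directly.
            constants[name] = np.asarray(captured[name])
    return constants

produced = {'relu_out': None}  # produced by an op in the graph
captured = {'zeros_out': np.zeros((2, 10), dtype=np.float32)}
print(materialize_missing_outputs(['relu_out', 'zeros_out'], produced, captured))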
tinynn/converter/operators/graph.py (12 additions, 22 deletions)
@@ -24,9 +24,8 @@ class CommonGraph(object):
     input_transpose: typing.List[bool]
     output_transpose: typing.Union[typing.List[typing.Optional[bool]], typing.Optional[bool]]
     node_op_counter: int
-    missing_outputs_as_constants: bool
 
-    def __init__(self, missing_outputs_as_constants: bool) -> None:
+    def __init__(self) -> None:
         self.graph = ig.Graph(directed=True)
         self.tensor_map = dict()
         self.tensor_node_map = dict()
@@ -40,8 +39,6 @@ def __init__(self, missing_outputs_as_constants: bool) -> None:
         self.transform_store = {}
         self.constant_mapping = {}
 
-        self.missing_outputs_as_constants = missing_outputs_as_constants
-
     def add_transform_store(self, tensor_name: str, transform_name: str, new_tensor_name: str):
         self.transform_store.setdefault(tensor_name, {})
         self.transform_store[tensor_name][transform_name] = new_tensor_name
@@ -640,24 +637,17 @@ def collect_tensor_buffers(
         missing_inputs = [name for name, _ in filter(lambda x: x[1] < 0, zip(inputs, input_idx))]
         missing_outputs = [name for name, _ in filter(lambda x: x[1] < 0, zip(outputs, output_idx))]
 
-        if not self.missing_outputs_as_constants:
-            assert len(missing_outputs) == 0, f'Some output nodes are missing: {missing_outputs}'
-
-        missing_vars_dict = {
-            'input': (missing_inputs, inputs, input_idx),
-            'output': (missing_outputs, outputs, output_idx),
-        }
-
-        for key, (missing_vars, var_indices, out_indices) in missing_vars_dict.items():
-            if len(missing_vars) != 0:
-                warnings.warn(f'Some {key} nodes are missing: {missing_vars}, will try to add them into graph')
-                for name in missing_vars:
-                    tensor = self.tensor_map[name]
-                    tensor.index = tensor_idx
-                    tensor_idx += 1
-                    tensors.append(tensor)
-                    item_idx = var_indices.index(name)
-                    out_indices[item_idx] = tensor.index
+        assert len(missing_outputs) == 0, f'Some output nodes are missing: {missing_outputs}'
+
+        if len(missing_inputs) != 0:
+            warnings.warn(f'Some input nodes are missing: {missing_inputs}, will try to add them into graph')
+            for name in missing_inputs:
+                tensor = self.tensor_map[name]
+                tensor.index = tensor_idx
+                tensor_idx += 1
+                tensors.append(tensor)
+                item_idx = inputs.index(name)
+                input_idx[item_idx] = tensor.index
 
         return tensors, buffers, input_idx, output_idx
