[converter] add autograd func #293

Merged · 3 commits · merged Apr 9, 2024
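For context, a minimal sketch of the kind of model this change targets: a custom `torch.autograd.Function` that shows up in the traced TorchScript graph as a `prim::PythonOp` node. The function, module, and file names below are illustrative and not taken from the PR; the converter call follows the project's usual `TFLiteConverter` pattern.

```python
import torch
from tinynn.converter import TFLiteConverter


class ScaledReLU(torch.autograd.Function):
    """Hypothetical custom autograd function; the converter in this PR expects
    its traced prim::PythonOp node to carry a 'Subgraph' attribute."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x.clamp(min=0) * 2.0

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return grad_output * (x > 0).to(grad_output.dtype) * 2.0


class Model(torch.nn.Module):
    def forward(self, x):
        return ScaledReLU.apply(x)


if __name__ == '__main__':
    model = Model().eval()
    dummy_input = torch.randn(1, 3, 224, 224)

    with torch.no_grad():
        # Conversion sketch; before this PR the converter had no handlers for
        # prim::PythonOp / prim::Param / prim::Return.
        converter = TFLiteConverter(model, dummy_input, tflite_path='autograd_func.tflite')
        converter.convert()
```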
docs/op_matrix.md (3 additions, 0 deletions)
@@ -14,6 +14,9 @@ Operators that are implemented in Python
| `prim::ListConstruct` | |
| `prim::ListUnpack` | |
| `prim::NumToTensor` | |
| `prim::Param` | |
| `prim::PythonOp` | |
| `prim::Return` | |
| `prim::TupleConstruct` | |

## ATen Operators
tinynn/converter/base.py (24 additions, 1 deletion)
@@ -389,6 +389,8 @@ def unsupported_operations(self, unique=True) -> typing.List[str]:
def init_operations(self):
log.debug('Initialize operators...')
node_queue = collections.deque(self.graph.nodes())
scope_map = {}
current_scope = None
while node_queue:
node = node_queue.popleft()

@@ -399,6 +401,7 @@ def init_operations(self):
converter = converter_type(
node,
self.tensor_map,
current_scope,
not self.strict_symmetric_check,
self.q_type,
self.hybrid_q_type,
@@ -429,7 +432,18 @@ def init_operations(self):
else:
converter_type = NoTrackOperator
converter = converter_type(
node, self.tensor_map, not self.strict_symmetric_check, self.q_type, self.hybrid_q_type
node,
self.tensor_map,
current_scope,
not self.strict_symmetric_check,
self.q_type,
self.hybrid_q_type,
self.map_bilstm_to_lstm,
self.enable_mtk_ops,
self.hybrid_asymmetric_inputs,
self.unroll_rnn,
self.separated_rnn_gate_calc,
self.conv_transpose_with_bias,
)
if k != 'prim::Constant':
log.debug(f'{k} {converter.input_names} -> {converter.output_names} {converter_type.__name__}')
@@ -451,6 +465,15 @@ def init_operations(self):
if len(new_nodes) > 0:
node_queue.extendleft(reversed(new_nodes))

if k == 'prim::PythonOp':
s = node.scopeName()
scope_map.setdefault(s, 0)
scope_map[s] += 1
current_scope = f'{s}_{scope_map[s]}'
converter.prepare_scope_tensors(node, attrs, args, self.common_graph, current_scope)
elif k == 'prim::Return':
current_scope = None

assert len(output_tensors) == len(outputs)
for t, name in zip(output_tensors, outputs):
self.tensor_map[name] = t
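The scope bookkeeping added above can be read in isolation: each `prim::PythonOp` derives a fresh scope from `node.scopeName()` plus a per-name counter, and the matching `prim::Return` clears it. A small standalone sketch of that counter logic (scope names are made up):

```python
# Standalone sketch of the scope counter added to init_operations; real scope
# names come from node.scopeName() on the prim::PythonOp node.
scope_map = {}
current_scope = None


def enter_python_op(scope_name):
    # Mirrors the prim::PythonOp branch: count occurrences per scope name so
    # repeated calls to the same function get distinct scopes.
    global current_scope
    scope_map.setdefault(scope_name, 0)
    scope_map[scope_name] += 1
    current_scope = f'{scope_name}_{scope_map[scope_name]}'
    return current_scope


def leave_python_op():
    # Mirrors the prim::Return branch: back to the unscoped outer graph.
    global current_scope
    current_scope = None


print(enter_python_op('MyFunc'))  # MyFunc_1
leave_python_op()
print(enter_python_op('MyFunc'))  # MyFunc_2
leave_python_op()
```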
tinynn/converter/operators/torch/__init__.py (3 additions, 0 deletions)
@@ -18,6 +18,9 @@
"prim::If": PrimIfConverter,
"aten::__getitem__": PrimGetItemConverter,
"aten::len": PrimLenConverter,
"prim::Param": PrimParamConverter,
"prim::PythonOp": PrimPythonOpConverter,
"prim::Return": PrimReturnConverter,
# aten
"aten::sign": AtenSignOperator,
"aten::t": ATenTOperator,
tinynn/converter/operators/torch/base.py (15 additions, 2 deletions)
@@ -21,6 +21,7 @@ def __init__(
self,
node,
tensor_map,
scope_name,
asymmetric=True,
q_type=np.uint8,
hybrid_q_type=np.int8,
@@ -31,6 +32,7 @@ def __init__(
separated_rnn_gate_calc=False,
conv_transpose_with_bias=True,
) -> None:
self.scope_name = scope_name
self.input_names = self.get_input_names(node)
self.output_names = self.get_output_names(node)
self.input_tensors = self.get_input_tensors(tensor_map)
@@ -53,11 +55,20 @@ def __init__(
def parse(self, node, attrs, args, graph_converter):
pass

def get_tensor_name(self, tensor_name, scope_name=None):
if scope_name is None:
scope_name = self.scope_name

if scope_name:
return f'{scope_name}_{tensor_name}'
else:
return tensor_name

def get_input_names(self, node):
return [x.debugName() for x in list(node.inputs())]
return [self.get_tensor_name(x.debugName()) for x in list(node.inputs())]

def get_output_names(self, node):
return [x.debugName() for x in list(node.outputs())]
return [self.get_tensor_name(x.debugName()) for x in list(node.outputs())]

def get_input_tensors(self, tensor_map):
input_tensors = []
@@ -677,6 +688,8 @@ def get_prop_from_node(node, prop, assert_type=None, return_type=False):
v = getattr(node, vk)(prop)
elif vk == 's':
v = getattr(node, vk)(prop)
elif vk == 'g':
v = getattr(node, vk)(prop)
elif vk == 't':
v = getattr(node, vk)(prop)
if v.dtype == torch.float64:
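The scope prefix is applied through `get_tensor_name`, so tensors recorded inside two invocations of the same autograd function end up with distinct names. A quick standalone sketch of the naming rule (the debug names and scopes are made up):

```python
def get_tensor_name(tensor_name, scope_name=None):
    # Simplified copy of OperatorConverter.get_tensor_name: prefix with the
    # active scope, or pass the name through unchanged in the outer graph.
    if scope_name:
        return f'{scope_name}_{tensor_name}'
    return tensor_name


print(get_tensor_name('input.1'))              # input.1 (outer graph, no scope)
print(get_tensor_name('input.1', 'MyFunc_1'))  # MyFunc_1_input.1
print(get_tensor_name('input.1', 'MyFunc_2'))  # MyFunc_2_input.1
```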
tinynn/converter/operators/torch/prim.py (61 additions, 0 deletions)
@@ -163,3 +163,64 @@ def parse(self, node, attrs, args, graph_converter):
)
else:
graph_converter.add_operator(tfl.SplitOperator([dim_tensor, input_tensor], outputs, chunks))


class PrimPythonOpConverter(PrimOperatorConverter):
def parse(self, node, attrs, args, graph_converter):
subgraph = attrs['Subgraph'][0]

param_node = subgraph.param_node()
return_node = subgraph.return_node()

self.output_tensors.append(node.pyobj()(*self.input_tensors, *node.scalar_args()))
self.output_nodes.append(param_node)
self.output_nodes.extend(subgraph.nodes())
self.output_nodes.append(return_node)

def prepare_scope_tensors(self, node, attrs, args, graph_converter, scope_name):
subgraph = attrs['Subgraph'][0]

# input tensors
param_node = subgraph.param_node()
input_tensors = [self.find_or_create_input(i, graph_converter) for i in range(len(self.input_tensors))]
subgraph_input_names = [self.get_tensor_name(x.debugName(), scope_name) for x in param_node.outputs()]

for name, t in zip(subgraph_input_names, input_tensors):
graph_converter.constant_mapping[name] = t

# output tensors
return_node = subgraph.return_node()
subgraph_output_names = [self.get_tensor_name(x.debugName(), scope_name) for x in return_node.inputs()]
output_tensors = self.to_tfl_tensors(self.output_names, self.output_tensors)

for name, t in zip(subgraph_output_names, output_tensors):
graph_converter.constant_mapping[name] = t


class PrimReturnConverter(PrimOperatorConverter):
def parse(self, node, attrs, args, graph_converter):
for i, name in enumerate(self.input_names):
assert name in graph_converter.constant_mapping
if name in graph_converter.tensor_map:
input_tensor = self.find_or_create_input(i, graph_converter)
output_tensor = graph_converter.constant_mapping[name]

inputs = [input_tensor, self.create_attr_tensor(input_tensor.shape, name=f'{name}_return_attr')]
outputs = [output_tensor]

graph_converter.add_operator(tfl.ReshapeOperator(inputs, outputs, input_tensor.shape))


class PrimParamConverter(PrimOperatorConverter):
def parse(self, node, attrs, args, graph_converter):
for i, name in enumerate(self.output_names):
assert name in graph_converter.constant_mapping
input_tensor = graph_converter.constant_mapping[name]
output_tensor = self.to_tfl_tensors([name], [input_tensor.tensor])[0]
self.output_tensors.append(torch.from_numpy(input_tensor.tensor))

if input_tensor.name in graph_converter.tensor_map:
inputs = [input_tensor, self.create_attr_tensor(input_tensor.shape, name=f'{name}_return_attr')]
outputs = [output_tensor]

graph_converter.add_operator(tfl.ReshapeOperator(inputs, outputs, input_tensor.shape))
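Taken together, the three converters hand tensors across the subgraph boundary through `graph_converter.constant_mapping`: `prepare_scope_tensors` binds the outer input/output tensors to scoped subgraph names, `prim::Param` picks the inputs up inside the subgraph, and `prim::Return` wires the subgraph result back onto the outer output. A rough dataflow sketch with plain dicts and strings standing in for the real graph objects (all names hypothetical):

```python
# Rough dataflow sketch; strings and a dict stand in for TFLite tensors and
# the converter's graph state.
constant_mapping = {}
scope = 'MyFunc_1'

# 1) prim::PythonOp (prepare_scope_tensors): bind outer tensors to the scoped
#    names that the subgraph's Param/Return nodes will look up.
constant_mapping[f'{scope}_x.1'] = 'outer_input_tensor'
constant_mapping[f'{scope}_out.1'] = 'outer_output_tensor'

# 2) prim::Param (PrimParamConverter): fetch the outer inputs under their
#    scoped names and expose them to the ops inside the subgraph.
subgraph_input = constant_mapping[f'{scope}_x.1']

# 3) ... the ops recorded in the subgraph run on subgraph_input ...
subgraph_result = f'result_computed_from_{subgraph_input}'

# 4) prim::Return (PrimReturnConverter): connect the subgraph result back to
#    the outer output tensor (the PR does this with a ReshapeOperator).
print(subgraph_result, '->', constant_mapping[f'{scope}_out.1'])
```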