-
Notifications
You must be signed in to change notification settings - Fork 1.8k
Support the List/Tuple Construct/Unpack operation for TorchModuleGraph #2609
Conversation
Signed-off-by: Ningxin <Ningxin.Zheng@microsoft.com>
Signed-off-by: Ningxin <Ningxin.Zheng@microsoft.com>
Signed-off-by: Ningxin <Ningxin.Zheng@microsoft.com>
@@ -199,6 +203,8 @@ def __init__(self, name, unique_name, node_type, op_type, node_cpps, inputs=None | |||
All the inputs of this node, each element is debugName of one input | |||
outputs: list of str | |||
All the outputs of this node, each element is debugName of one output | |||
key_node: torch._C.Node | |||
The key node of this NodePyGroup. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what is the meaning of key node?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Key nodes are the nodes that should not be merged into other nodes. In the past, we only take the aten::
nodes as the important(key) nodes. In this pr, we also take the list/tuple unpack nodes as the key nodes.
# the nodes that start with 'aten' are key function | ||
# nodes | ||
return True | ||
if node_cpp.kind() in [LIST_UNPACK_KIND, TUPLE_UNPACK_KIND]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
not construct type here?
if node_cpp.kind() in [LIST_UNPACK_KIND, TUPLE_UNPACK_KIND]: | ||
# We cannot merge the List/Tuple | ||
# Construct/Unpack func into other nodes, else it | ||
# may lead to a graph construction error. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what is this error like?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Take the shufflenet as an example:
#2581
for node in self.nodes_py.nodes_op: | ||
if node.op_type in [TUPLE_UNPACK_KIND, LIST_UNPACK_KIND]: | ||
unpack_cpp = node.key_node | ||
last_cpp = list(unpack_cpp.inputs())[0].node() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
index is 0, why call it last
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The last
node actually refers to the previous(last) visited node. In most scenarios, this last_cpp
is the corresponding construct node of the tuple/list.
# or list manunally. | ||
_logger.debug('List/Tuple Construct Node(cpp) %s', str(last_cpp)) | ||
_logger.debug('List/Tuple Unpack Node(cpp) %s', str(unpack_cpp)) | ||
assert len(list(unpack_cpp.outputs())) == len(list(last_cpp.inputs())) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
a little confused about this assert, i think the main reason is i don't understand what is last_cpp
Support the list/tuple construct/unpack operation for the TorchModuleGraph.
Fix the bug mentioned in #2581.
In the original version, we take the list/tuple construct/unpack operation nodes as unimportant nodes and merge them with the adjacent important nodes. However, merging the unpack nodes will lead to a graph construct error.