Skip to content

Commit

Permalink
chore: update PyTreeSpec representation
Browse files Browse the repository at this point in the history
  • Loading branch information
XuehaiPan committed Nov 15, 2022
1 parent 10e3fec commit f2125d2
Show file tree
Hide file tree
Showing 7 changed files with 51 additions and 51 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -216,13 +216,13 @@ While flattening a tree, it will remain in the tree structure definitions rather
>>> optree.tree_flatten(tree)
([1, 2, 3, 4, 5], PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': None, 'd': *}))
>>> optree.tree_flatten(tree, none_is_leaf=True)
([1, 2, 3, 4, None, 5], PyTreeSpec(NoneIsLeaf, {'a': *, 'b': (*, [*, *]), 'c': *, 'd': *}))
([1, 2, 3, 4, None, 5], PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': *, 'd': *}, NoneIsLeaf))
>>> optree.tree_flatten(1)
([1], PyTreeSpec(*))
>>> optree.tree_flatten(None)
([], PyTreeSpec(None))
>>> optree.tree_flatten(None, none_is_leaf=True)
([None], PyTreeSpec(NoneIsLeaf, *))
([None], PyTreeSpec(*, NoneIsLeaf))
```

OpTree provides a keyword argument `none_is_leaf` to determine whether to consider the `None` object as a leaf, like other opaque objects.
Expand Down
2 changes: 1 addition & 1 deletion optree/_C.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class PyTreeSpec:
def __hash__(self) -> int: ...

def register_node(
type: Type[CustomTreeNode[T]],
cls: Type[CustomTreeNode[T]],
to_iterable: Callable[[CustomTreeNode[T]], Tuple[Children[T], MetaData]],
from_iterable: Callable[[MetaData, Children[T]], CustomTreeNode[T]],
) -> None: ...
20 changes: 10 additions & 10 deletions optree/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,13 +97,13 @@ def tree_flatten(
>>> tree_flatten(tree)
([1, 2, 3, 4, 5], PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': None, 'd': *}))
>>> tree_flatten(tree, none_is_leaf=True)
([1, 2, 3, 4, None, 5], PyTreeSpec(NoneIsLeaf, {'a': *, 'b': (*, [*, *]), 'c': *, 'd': *}))
([1, 2, 3, 4, None, 5], PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': *, 'd': *}, NoneIsLeaf))
>>> tree_flatten(1)
([1], PyTreeSpec(*))
>>> tree_flatten(None)
([], PyTreeSpec(None))
>>> tree_flatten(None, none_is_leaf=True)
([None], PyTreeSpec(NoneIsLeaf, *))
([None], PyTreeSpec(*, NoneIsLeaf))
For unordered dictionaries, :class:`dict` and :class:`collections.defaultdict`, the order is
dependent on the **sorted** keys in the dictionary. Please use :class:`collections.OrderedDict`
Expand All @@ -113,7 +113,7 @@ def tree_flatten(
>>> tree_flatten(tree)
([2, 3, 4, 1, 5], PyTreeSpec(OrderedDict([('b', (*, [*, *])), ('a', *), ('c', None), ('d', *)])))
>>> tree_flatten(tree, none_is_leaf=True)
([2, 3, 4, 1, None, 5], PyTreeSpec(NoneIsLeaf, OrderedDict([('b', (*, [*, *])), ('a', *), ('c', *), ('d', *)])))
([2, 3, 4, 1, None, 5], PyTreeSpec(OrderedDict([('b', (*, [*, *])), ('a', *), ('c', *), ('d', *)]), NoneIsLeaf))
Args:
tree: A pytree to flatten.
Expand Down Expand Up @@ -156,14 +156,14 @@ def tree_flatten_with_path(
(
[('a',), ('b', 0), ('b', 1, 0), ('b', 1, 1), ('c',), ('d',)],
[1, 2, 3, 4, None, 5],
PyTreeSpec(NoneIsLeaf, {'a': *, 'b': (*, [*, *]), 'c': *, 'd': *})
PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': *, 'd': *}, NoneIsLeaf)
)
>>> tree_flatten_with_path(1)
([()], [1], PyTreeSpec(*))
>>> tree_flatten_with_path(None)
([], [], PyTreeSpec(None))
>>> tree_flatten_with_path(None, none_is_leaf=True)
([()], [None], PyTreeSpec(NoneIsLeaf, *))
([()], [None], PyTreeSpec(*, NoneIsLeaf))
For unordered dictionaries, :class:`dict` and :class:`collections.defaultdict`, the order is
dependent on the **sorted** keys in the dictionary. Please use :class:`collections.OrderedDict`
Expand All @@ -180,7 +180,7 @@ def tree_flatten_with_path(
(
[('b', 0), ('b', 1, 0), ('b', 1, 1), ('a',), ('c',), ('d',)],
[2, 3, 4, 1, None, 5],
PyTreeSpec(NoneIsLeaf, OrderedDict([('b', (*, [*, *])), ('a', *), ('c', *), ('d', *)]))
PyTreeSpec(OrderedDict([('b', (*, [*, *])), ('a', *), ('c', *), ('d', *)]), NoneIsLeaf)
)
Args:
Expand Down Expand Up @@ -274,13 +274,13 @@ def tree_structure(
>>> tree_structure(tree)
PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': None, 'd': *})
>>> tree_structure(tree, none_is_leaf=True)
PyTreeSpec(NoneIsLeaf, {'a': *, 'b': (*, [*, *]), 'c': *, 'd': *})
PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': *, 'd': *}, NoneIsLeaf)
>>> tree_structure(1)
PyTreeSpec(*)
>>> tree_structure(None)
PyTreeSpec(None)
>>> tree_structure(None, none_is_leaf=True)
PyTreeSpec(NoneIsLeaf, *)
PyTreeSpec(*, NoneIsLeaf)
Args:
tree: A pytree to flatten.
Expand Down Expand Up @@ -831,7 +831,7 @@ def treespec_leaf(*, none_is_leaf: bool = False) -> PyTreeSpec:
>>> treespec_leaf()
PyTreeSpec(*)
>>> treespec_leaf(none_is_leaf=True)
PyTreeSpec(NoneIsLeaf, *)
PyTreeSpec(*, NoneIsLeaf)
>>> treespec_leaf(none_is_leaf=False) == treespec_leaf(none_is_leaf=True)
False
>>> treespec_leaf() == tree_structure(1)
Expand Down Expand Up @@ -862,7 +862,7 @@ def treespec_none(*, none_is_leaf: bool = False) -> PyTreeSpec:
>>> treespec_none()
PyTreeSpec(None)
>>> treespec_none(none_is_leaf=True)
PyTreeSpec(NoneIsLeaf, *)
PyTreeSpec(*, NoneIsLeaf)
>>> treespec_none(none_is_leaf=False) == treespec_none(none_is_leaf=True)
False
>>> treespec_none() == tree_structure(None)
Expand Down
22 changes: 11 additions & 11 deletions optree/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import optree._C as _C
from optree.typing import KT, VT, Children, CustomTreeNode, DefaultDict, MetaData
from optree.typing import OrderedDict as GenericOrderedDict
from optree.typing import PyTree, PyTreeSpec, T
from optree.typing import PyTree, T
from optree.utils import safe_zip, unzip2


Expand All @@ -46,28 +46,28 @@


def register_pytree_node(
type: Type[CustomTreeNode[T]], # pylint: disable=redefined-builtin
cls: Type[CustomTreeNode[T]],
flatten_func: Callable[[CustomTreeNode[T]], Tuple[Children[T], MetaData]],
unflatten_func: Callable[[MetaData, Children[T]], CustomTreeNode[T]],
) -> Type[CustomTreeNode[T]]:
"""Extends the set of types that are considered internal nodes in pytrees.
Args:
type: A Python type to treat as an internal pytree node.
flatten_func: A function to be used during flattening, taking a value of type ``type`` and
cls: A Python type to treat as an internal pytree node.
flatten_func: A function to be used during flattening, taking an instance of ``cls`` and
returning a triple or optionally a pair, with (1) an iterable for the children to be
flattened recursively, and (2) some hashable auxiliary data to be stored in the treespec
and to be passed to the ``unflatten_func``, and (3) (optional) an iterable for the tree
path entries to the corresponding children. If the entries are not provided or given by
:data:`None`, then `range(len(children))` will be used.
unflatten_func: A function taking two arguments: the auxiliary data that was returned by
``flatten_func`` and stored in the treespec, and the unflattened children. The function
should return an instance of ``type``.
should return an instance of ``cls``.
"""
_C.register_node(type, flatten_func, unflatten_func)
CustomTreeNode.register(type) # pylint: disable=no-member
_nodetype_registry[type] = PyTreeNodeRegistryEntry(flatten_func, unflatten_func)
return type
_C.register_node(cls, flatten_func, unflatten_func)
CustomTreeNode.register(cls) # pylint: disable=no-member
_nodetype_registry[cls] = PyTreeNodeRegistryEntry(flatten_func, unflatten_func)
return cls


def register_pytree_node_class(cls: Type[CustomTreeNode[T]]) -> Type[CustomTreeNode[T]]:
Expand Down Expand Up @@ -347,11 +347,11 @@ def pprint(self) -> str:


def register_keypaths(
type: Type[CustomTreeNode[T]], # pylint: disable=redefined-builtin
cls: Type[CustomTreeNode[T]],
handler: KeyPathHandler,
) -> KeyPathHandler:
"""Registers a key path handler for a custom pytree node type."""
_keypath_registry[type] = handler
_keypath_registry[cls] = handler
return handler


Expand Down
2 changes: 1 addition & 1 deletion src/optree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ void BuildModule(py::module& mod) { // NOLINT
&PyTreeTypeRegistry::Register,
"Register a Python type. Extends the set of types that are considered internal nodes "
"in pytrees.",
py::arg("type"),
py::arg("cls"),
py::arg("to_iterable"),
py::arg("from_iterable"));

Expand Down
2 changes: 1 addition & 1 deletion src/treespec/treespec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,7 @@ std::string PyTreeSpec::ToString() const {
if (agenda.size() != 1) [[unlikely]] {
throw std::logic_error("PyTreeSpec traversal did not yield a singleton.");
}
return absl::StrCat("PyTreeSpec(", (none_is_leaf ? "NoneIsLeaf, " : ""), agenda.back(), ")");
return absl::StrCat("PyTreeSpec(", agenda.back(), (none_is_leaf ? ", NoneIsLeaf" : ""), ")");
}

py::object PyTreeSpec::ToPicklable() const {
Expand Down
50 changes: 25 additions & 25 deletions tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def tree_flatten(self):
def tree_unflatten(cls, metadata, children):
if not optree.all_leaves(children):
children, metadata = optree.tree_flatten(optree.tree_unflatten(metadata, children))
return FlatCache(structured=None, leaves=children, treespec=metadata)
return cls(structured=None, leaves=children, treespec=metadata)


@optree.register_pytree_node_class
Expand All @@ -124,7 +124,7 @@ def tree_flatten(self):

@classmethod
def tree_unflatten(cls, metadata, children):
return MyDict(zip(metadata, children))
return cls(zip(metadata, children))

def __repr__(self):
return f'{self.__class__.__name__}({super().__repr__()})'
Expand Down Expand Up @@ -243,29 +243,29 @@ def __repr__(self):
)

TREE_STRINGS_NONE_IS_LEAF = (
'PyTreeSpec(NoneIsLeaf, *)',
'PyTreeSpec(NoneIsLeaf, *)',
'PyTreeSpec(NoneIsLeaf, (*,))',
'PyTreeSpec(NoneIsLeaf, (*, *))',
'PyTreeSpec(NoneIsLeaf, ())',
'PyTreeSpec(NoneIsLeaf, [()])',
'PyTreeSpec(NoneIsLeaf, (*, *))',
'PyTreeSpec(NoneIsLeaf, ((*, *), [*, (*, *, *)]))',
'PyTreeSpec(NoneIsLeaf, [*])',
"PyTreeSpec(NoneIsLeaf, [*, CustomTuple(foo=(*, CustomTuple(foo=*, bar=*)), bar={'baz': *})])",
"PyTreeSpec(NoneIsLeaf, [CustomTreeNode(Vector3D[[4, 'foo']], [*, *])])",
'PyTreeSpec(NoneIsLeaf, CustomTreeNode(Vector2D[None], [*, *]))',
"PyTreeSpec(NoneIsLeaf, {'a': *, 'b': *})",
"PyTreeSpec(NoneIsLeaf, OrderedDict([('foo', *), ('baz', *), ('something', *)]))",
"PyTreeSpec(NoneIsLeaf, OrderedDict([('foo', *), ('baz', *), ('something', deque([*, *, *]))]))",
"PyTreeSpec(NoneIsLeaf, OrderedDict([('foo', *), ('baz', *), ('something', deque([*, *], maxlen=2))]))",
"PyTreeSpec(NoneIsLeaf, OrderedDict([('foo', *), ('baz', *), ('something', deque([*, *], maxlen=2))]))",
"PyTreeSpec(NoneIsLeaf, defaultdict(<class 'dict'>, {'baz': *, 'foo': *, 'something': *}))",
"PyTreeSpec(NoneIsLeaf, CustomTreeNode(MyDict[['foo', 'baz']], [CustomTreeNode(MyDict[['c', 'b', 'a']], [*, *, *]), *]))",
'PyTreeSpec(NoneIsLeaf, CustomNamedTupleSubclass(foo=*, bar=*))',
'PyTreeSpec(NoneIsLeaf, CustomTreeNode(FlatCache[PyTreeSpec(None)], []))',
'PyTreeSpec(NoneIsLeaf, CustomTreeNode(FlatCache[PyTreeSpec(*)], [*]))',
"PyTreeSpec(NoneIsLeaf, CustomTreeNode(FlatCache[PyTreeSpec({'a': [*, *]})], [*, *]))",
'PyTreeSpec(*, NoneIsLeaf)',
'PyTreeSpec(*, NoneIsLeaf)',
'PyTreeSpec((*,), NoneIsLeaf)',
'PyTreeSpec((*, *), NoneIsLeaf)',
'PyTreeSpec((), NoneIsLeaf)',
'PyTreeSpec([()], NoneIsLeaf)',
'PyTreeSpec((*, *), NoneIsLeaf)',
'PyTreeSpec(((*, *), [*, (*, *, *)]), NoneIsLeaf)',
'PyTreeSpec([*], NoneIsLeaf)',
"PyTreeSpec([*, CustomTuple(foo=(*, CustomTuple(foo=*, bar=*)), bar={'baz': *})], NoneIsLeaf)",
"PyTreeSpec([CustomTreeNode(Vector3D[[4, 'foo']], [*, *])], NoneIsLeaf)",
'PyTreeSpec(CustomTreeNode(Vector2D[None], [*, *]), NoneIsLeaf)',
"PyTreeSpec({'a': *, 'b': *}, NoneIsLeaf)",
"PyTreeSpec(OrderedDict([('foo', *), ('baz', *), ('something', *)]), NoneIsLeaf)",
"PyTreeSpec(OrderedDict([('foo', *), ('baz', *), ('something', deque([*, *, *]))]), NoneIsLeaf)",
"PyTreeSpec(OrderedDict([('foo', *), ('baz', *), ('something', deque([*, *], maxlen=2))]), NoneIsLeaf)",
"PyTreeSpec(OrderedDict([('foo', *), ('baz', *), ('something', deque([*, *], maxlen=2))]), NoneIsLeaf)",
"PyTreeSpec(defaultdict(<class 'dict'>, {'baz': *, 'foo': *, 'something': *}), NoneIsLeaf)",
"PyTreeSpec(CustomTreeNode(MyDict[['foo', 'baz']], [CustomTreeNode(MyDict[['c', 'b', 'a']], [*, *, *]), *]), NoneIsLeaf)",
'PyTreeSpec(CustomNamedTupleSubclass(foo=*, bar=*), NoneIsLeaf)',
'PyTreeSpec(CustomTreeNode(FlatCache[PyTreeSpec(None)], []), NoneIsLeaf)',
'PyTreeSpec(CustomTreeNode(FlatCache[PyTreeSpec(*)], [*]), NoneIsLeaf)',
"PyTreeSpec(CustomTreeNode(FlatCache[PyTreeSpec({'a': [*, *]})], [*, *]), NoneIsLeaf)",
)

TREE_STRINGS = {
Expand Down

0 comments on commit f2125d2

Please sign in to comment.