diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index d5a135f8..f8a4026a 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -85,7 +85,7 @@ jobs: strategy: matrix: os: [ubuntu-latest, windows-latest, macos-latest] - python-version: ["3.6", "3.8", "3.9", "3.10", "3.11"] + python-version: ["3.8", "3.9", "3.10", "3.11"] fail-fast: false timeout-minutes: 30 steps: diff --git a/.pylintrc b/.pylintrc index 1296d79e..d66f9c94 100644 --- a/.pylintrc +++ b/.pylintrc @@ -53,7 +53,7 @@ persistent=yes # Minimum Python version to use for version dependent checks. Will default to # the version used to run pylint. -py-version=3.6 +py-version=3.7 # Discover python modules and packages in the file system subtree. recursive=no diff --git a/README.md b/README.md index a6d7ffd7..976d3be9 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # OpTree -![Python 3.6+](https://img.shields.io/badge/Python-3.6%2B-brightgreen) +![Python 3.7+](https://img.shields.io/badge/Python-3.7%2B-brightgreen) [![PyPI](https://img.shields.io/pypi/v/optree?logo=pypi)](https://pypi.org/project/optree) ![GitHub Workflow Status](https://img.shields.io/github/workflow/status/metaopt/optree/Build?label=build&logo=github) ![GitHub Workflow Status](https://img.shields.io/github/workflow/status/metaopt/optree/Tests?label=tests&logo=github) @@ -61,7 +61,7 @@ cd optree pip3 install . ``` -Compiling from the source requires Python 3.6+, a compiler (`gcc` / `clang` / `icc` / `cl.exe`) supports C++20 and a `cmake` installation. +Compiling from the source requires Python 3.7+, a compiler (`gcc` / `clang` / `icc` / `cl.exe`) supports C++20 and a `cmake` installation. -------------------------------------------------------------------------------- @@ -122,17 +122,19 @@ The `NoneType` is a special case discussed in section [`None` is non-leaf Node v A container-like Python type can be registered in the container registry with a pair of functions that specify: -- `flatten_func(container) -> (children, metadata)`: convert an instance of the container type to a `(children, metadata)` pair, where `children` is an iterable of subtrees. +- `flatten_func(container) -> (children, metadata, entries)`: convert an instance of the container type to a `(children, metadata, entries)` triple, where `children` is an iterable of subtrees and `entries` is an iterable of path entries of the container (e.g., indices or keys). - `unflatten_func(metadata, children) -> container`: convert such a pair back to an instance of the container type. The `metadata` is some necessary data apart from the children to reconstruct the container, e.g., the keys of the dictionary (the children are values). +The `entries` can be omitted (only returns a pair) or is optional to implement (returns `None`). If so, use `range(len(children))` (i.e., flat indices) as path entries of the current node. The function signature can be `flatten_func(container) -> (children, metadata)` or `flatten_func(container) -> (children, metadata, None)`. + ```python >>> import torch >>> optree.register_pytree_node( ... torch.Tensor, -... flatten_func=lambda tensor: ( +... flatten_func=lambda tensor: ( # -> (children, metadata) ... (tensor.cpu().numpy(),), ... dict(dtype=tensor.dtype, device=tensor.device, requires_grad=tensor.requires_grad), ... ), @@ -154,6 +156,9 @@ The `metadata` is some necessary data apart from the children to reconstruct the }) ) +>>> optree.tree_paths(tree) # entries are not defined and use `range(len(children))` +[('bias', 0), ('weight', 0)] + >>> optree.tree_unflatten(treespec, leaves) {'bias': tensor([0., 0.]), 'weight': tensor([[1., 1.]], device='cuda:0')} ``` @@ -165,16 +170,24 @@ Users can also extend the pytree registry by decorating the custom class and def ... ... @optree.register_pytree_node_class ... class MyDict(UserDict): -... def tree_flatten(self): +... def tree_flatten(self): # -> (children, metadata, entries) ... reversed_keys = sorted(self.keys(), reverse=True) -... return [self[key] for key in reversed_keys], reversed_keys +... return ( +... [self[key] for key in reversed_keys], # children +... reversed_keys, # metadata +... reversed_keys, # entries +... ) ... ... @classmethod ... def tree_unflatten(cls, metadata, children): ... return MyDict(zip(metadata, children)) ->>> optree.tree_flatten(MyDict(b=2, a=1, c=3)) -([3, 2, 1], PyTreeSpec(CustomTreeNode(MyDict[['c', 'b', 'a']], [*, *, *]))) +>>> optree.tree_flatten_with_path(MyDict(b=4, a=(2, 3), c=MyDict({'d': 5, 'f': 6}))) +( + [('c', 'f'), ('c', 'd'), ('b',), ('a', 0), ('a', 1)], + [6, 5, 4, 2, 3], + PyTreeSpec(CustomTreeNode(MyDict[['c', 'b', 'a']], [CustomTreeNode(MyDict[['f', 'd']], [*, *]), *, (*, *)])) +) ``` #### Limitations of the PyTree Type Registry @@ -195,7 +208,7 @@ It may also serve as a sentinel value. In addition, if a function has returned without any return value, it also implicitly returns the `None` object. By default, the `None` object is considered a non-leaf node in the tree with arity 0, i.e., _**a non-leaf node that has no children**_. -This is slightly different than the definition of a non-leaf node as discussed above. +This is like the behavior of an empty tuple. While flattening a tree, it will remain in the tree structure definitions rather than in the leaves list. ```python @@ -203,13 +216,13 @@ While flattening a tree, it will remain in the tree structure definitions rather >>> optree.tree_flatten(tree) ([1, 2, 3, 4, 5], PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': None, 'd': *})) >>> optree.tree_flatten(tree, none_is_leaf=True) -([1, 2, 3, 4, None, 5], PyTreeSpec(NoneIsLeaf, {'a': *, 'b': (*, [*, *]), 'c': *, 'd': *})) +([1, 2, 3, 4, None, 5], PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': *, 'd': *}, NoneIsLeaf)) >>> optree.tree_flatten(1) ([1], PyTreeSpec(*)) >>> optree.tree_flatten(None) ([], PyTreeSpec(None)) >>> optree.tree_flatten(None, none_is_leaf=True) -([None], PyTreeSpec(NoneIsLeaf, *)) +([None], PyTreeSpec(*, NoneIsLeaf)) ``` OpTree provides a keyword argument `none_is_leaf` to determine whether to consider the `None` object as a leaf, like other opaque objects. diff --git a/include/registry.h b/include/registry.h index dce3f312..b89f0679 100644 --- a/include/registry.h +++ b/include/registry.h @@ -52,7 +52,7 @@ class PyTreeTypeRegistry { // The following values are populated for custom types. // The Python type object, used to identify the type. py::object type; - // A function with signature: object -> (iterable, metadata) + // A function with signature: object -> (iterable, metadata, entries) py::function to_iterable; // A function with signature: (metadata, iterable) -> object py::function from_iterable; diff --git a/include/treespec.h b/include/treespec.h index e6b9d616..8380fef0 100644 --- a/include/treespec.h +++ b/include/treespec.h @@ -28,6 +28,7 @@ limitations under the License. #include #include #include +#include #include #include @@ -64,6 +65,26 @@ class PyTreeSpec { const std::optional &leaf_predicate = std::nullopt, const bool &none_is_leaf = false); + // Flattens a PyTree into a list of leaves with a list of paths and a PyTreeSpec. + // Returns references to the flattened objects, which might be temporary objects in the case of + // custom PyType handlers. + static std::tuple, std::vector, std::unique_ptr> + FlattenWithPath(const py::handle &tree, + const std::optional &leaf_predicate = std::nullopt, + const bool &none_is_leaf = false); + + // Recursive helper used to implement FlattenWithPath(). + void FlattenIntoWithPath(const py::handle &handle, + std::vector &leaves, // NOLINT + std::vector &paths, // NOLINT + const std::optional &leaf_predicate = std::nullopt, + const bool &none_is_leaf = false); + void FlattenIntoWithPath(const py::handle &handle, + absl::InlinedVector &leaves, // NOLINT + absl::InlinedVector &paths, // NOLINT + const std::optional &leaf_predicate = std::nullopt, + const bool &none_is_leaf = false); + // Flattens a PyTree up to this PyTreeSpec. 'this' must be a tree prefix of the tree-structure // of 'x'. For example, if we flatten a value [(1, (2, 3)), {"foo": 4}] with a PyTreeSpec [(*, // *), *], the result is the list of leaves [1, (2, 3), {"foo": 4}]. @@ -146,6 +167,13 @@ class PyTreeSpec { // For a Custom type, contains the auxiliary data returned by the `to_iterable` function. py::object node_data; + // The tuple of path entries. + // This is optional, if not specified, `range(arity)` is used. + // For a sequence, contains the index of the element. + // For a mapping, contains the key of the element. + // For a Custom type, contains the path entries returned by the `to_iterable` function. + py::object node_entries; + // Custom type registration. Must be null for non-custom types. const PyTreeTypeRegistry::Registration *custom = nullptr; @@ -176,6 +204,14 @@ class PyTreeSpec { Span &leaves, // NOLINT const std::optional &leaf_predicate); + template + void FlattenIntoWithPathImpl(const py::handle &handle, + Span &leaves, // NOLINT + Span &paths, // NOLINT + Stack &stack, // NOLINT + const ssize_t &depth, + const std::optional &leaf_predicate); + py::list FlattenUpToImpl(const py::handle &full_tree) const; template diff --git a/include/utils.h b/include/utils.h index 104d6d2f..952a81b1 100644 --- a/include/utils.h +++ b/include/utils.h @@ -28,6 +28,7 @@ limitations under the License. #include #include #include +#include #include #include diff --git a/optree/_C.pyi b/optree/_C.pyi index 093699a6..0306d7a2 100644 --- a/optree/_C.pyi +++ b/optree/_C.pyi @@ -16,10 +16,10 @@ # pylint: disable=all # isort: off -from typing import TYPE_CHECKING, Callable, Iterable, List, Optional, Sequence, Tuple, Type +from typing import TYPE_CHECKING, Any, Callable, Iterable, List, Optional, Sequence, Tuple, Type if TYPE_CHECKING: - from optree.typing import MetaData, Children, CustomTreeNode, PyTree, T, U + from optree.typing import Children, CustomTreeNode, MetaData, PyTree, T, U version: int @@ -28,6 +28,11 @@ def flatten( leaf_predicate: Optional[Callable[[T], bool]] = None, node_is_leaf: bool = False, ) -> Tuple[List[T], 'PyTreeSpec']: ... +def flatten_with_path( + tree: PyTree[T], + leaf_predicate: Optional[Callable[[T], bool]] = None, + node_is_leaf: bool = False, +) -> Tuple[List[Tuple[Any, ...]], List[T], 'PyTreeSpec']: ... def all_leaves(iterable: Iterable[T], node_is_leaf: bool = False) -> bool: ... def leaf(node_is_leaf: bool = False) -> 'PyTreeSpec': ... def none(node_is_leaf: bool = False) -> 'PyTreeSpec': ... @@ -53,7 +58,7 @@ class PyTreeSpec: def __hash__(self) -> int: ... def register_node( - type: Type[CustomTreeNode[T]], + cls: Type[CustomTreeNode[T]], to_iterable: Callable[[CustomTreeNode[T]], Tuple[Children[T], MetaData]], from_iterable: Callable[[MetaData, Children[T]], CustomTreeNode[T]], ) -> None: ... diff --git a/optree/__init__.py b/optree/__init__.py index 79b12a23..8dbaa22e 100644 --- a/optree/__init__.py +++ b/optree/__init__.py @@ -22,9 +22,13 @@ tree_all, tree_any, tree_flatten, + tree_flatten_with_path, tree_leaves, tree_map, tree_map_, + tree_map_with_path, + tree_map_with_path_, + tree_paths, tree_reduce, tree_replace_nones, tree_structure, @@ -60,12 +64,16 @@ 'NONE_IS_NODE', 'NONE_IS_LEAF', 'tree_flatten', + 'tree_flatten_with_path', 'tree_unflatten', 'tree_leaves', 'tree_structure', + 'tree_paths', 'all_leaves', 'tree_map', 'tree_map_', + 'tree_map_with_path', + 'tree_map_with_path_', 'tree_reduce', 'tree_transpose', 'tree_replace_nones', diff --git a/optree/ops.py b/optree/ops.py index 54600c02..9239da4e 100644 --- a/optree/ops.py +++ b/optree/ops.py @@ -14,6 +14,8 @@ # ============================================================================== """OpTree: Optimized PyTree Utilities.""" +# pylint: disable=too-many-lines + import difflib import functools import textwrap @@ -49,12 +51,16 @@ 'NONE_IS_NODE', 'NONE_IS_LEAF', 'tree_flatten', + 'tree_flatten_with_path', 'tree_unflatten', 'tree_leaves', 'tree_structure', + 'tree_paths', 'all_leaves', 'tree_map', 'tree_map_', + 'tree_map_with_path', + 'tree_map_with_path_', 'tree_reduce', 'tree_transpose', 'tree_replace_nones', @@ -82,6 +88,8 @@ def tree_flatten( ) -> Tuple[List[T], PyTreeSpec]: """Flattens a pytree. + See also :func:`tree_flatten_with_path`. + The flattening order (i.e., the order of elements in the output list) is deterministic, corresponding to a left-to-right depth-first tree traversal. @@ -89,20 +97,23 @@ def tree_flatten( >>> tree_flatten(tree) ([1, 2, 3, 4, 5], PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': None, 'd': *})) >>> tree_flatten(tree, none_is_leaf=True) - ([1, 2, 3, 4, None, 5], PyTreeSpec(NoneIsLeaf, {'a': *, 'b': (*, [*, *]), 'c': *, 'd': *})) + ([1, 2, 3, 4, None, 5], PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': *, 'd': *}, NoneIsLeaf)) >>> tree_flatten(1) ([1], PyTreeSpec(*)) >>> tree_flatten(None) ([], PyTreeSpec(None)) >>> tree_flatten(None, none_is_leaf=True) - ([None], PyTreeSpec(NoneIsLeaf, *)) + ([None], PyTreeSpec(*, NoneIsLeaf)) For unordered dictionaries, :class:`dict` and :class:`collections.defaultdict`, the order is dependent on the **sorted** keys in the dictionary. Please use :class:`collections.OrderedDict` if you want to keep the keys in the insertion order. >>> tree = OrderedDict([('b', (2, [3, 4])), ('a', 1), ('c', None), ('d', 5)]) + >>> tree_flatten(tree) ([2, 3, 4, 1, 5], PyTreeSpec(OrderedDict([('b', (*, [*, *])), ('a', *), ('c', None), ('d', *)]))) + >>> tree_flatten(tree, none_is_leaf=True) + ([2, 3, 4, 1, None, 5], PyTreeSpec(OrderedDict([('b', (*, [*, *])), ('a', *), ('c', *), ('d', *)]), NoneIsLeaf)) Args: tree: A pytree to flatten. @@ -115,12 +126,81 @@ def tree_flatten( in the leaves list. (default: :data:`False`) Returns: - A pair where the first element is a list of leaf values and the second element is a treespec - representing the structure of the pytree. + A pair ``(leaves, treespec)`` where the first element is a list of leaf values and the + second element is a treespec representing the structure of the pytree. """ return _C.flatten(tree, is_leaf, none_is_leaf) +def tree_flatten_with_path( + tree: PyTree[T], + is_leaf: Optional[Callable[[T], bool]] = None, + *, + none_is_leaf: bool = False, +) -> Tuple[List[Tuple[Any, ...]], List[T], PyTreeSpec]: + """Flattens a pytree and additionally records the paths. + + See also :func:`tree_flatten` and :func:`tree_paths`. + + The flattening order (i.e., the order of elements in the output list) is deterministic, + corresponding to a left-to-right depth-first tree traversal. + + >>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5} + >>> tree_flatten_with_path(tree) + ( + [('a',), ('b', 0), ('b', 1, 0), ('b', 1, 1), ('d',)], + [1, 2, 3, 4, 5], + PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': None, 'd': *}) + ) + >>> tree_flatten_with_path(tree, none_is_leaf=True) + ( + [('a',), ('b', 0), ('b', 1, 0), ('b', 1, 1), ('c',), ('d',)], + [1, 2, 3, 4, None, 5], + PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': *, 'd': *}, NoneIsLeaf) + ) + >>> tree_flatten_with_path(1) + ([()], [1], PyTreeSpec(*)) + >>> tree_flatten_with_path(None) + ([], [], PyTreeSpec(None)) + >>> tree_flatten_with_path(None, none_is_leaf=True) + ([()], [None], PyTreeSpec(*, NoneIsLeaf)) + + For unordered dictionaries, :class:`dict` and :class:`collections.defaultdict`, the order is + dependent on the **sorted** keys in the dictionary. Please use :class:`collections.OrderedDict` + if you want to keep the keys in the insertion order. + + >>> tree = OrderedDict([('b', (2, [3, 4])), ('a', 1), ('c', None), ('d', 5)]) + >>> tree_flatten_with_path(tree) + ( + [('b', 0), ('b', 1, 0), ('b', 1, 1), ('a',), ('d',)], + [2, 3, 4, 1, 5], + PyTreeSpec(OrderedDict([('b', (*, [*, *])), ('a', *), ('c', None), ('d', *)])) + ) + >>> tree_flatten_with_path(tree, none_is_leaf=True) + ( + [('b', 0), ('b', 1, 0), ('b', 1, 1), ('a',), ('c',), ('d',)], + [2, 3, 4, 1, None, 5], + PyTreeSpec(OrderedDict([('b', (*, [*, *])), ('a', *), ('c', *), ('d', *)]), NoneIsLeaf) + ) + + Args: + tree: A pytree to flatten. + is_leaf: An optionally specified function that will be called at each flattening step. It + should return a boolean, with :data:`True` stopping the traversal and the whole subtree + being treated as a leaf, and :data:`False` indicating the flattening should traverse the + current object. + none_is_leaf: Whether to treat :data:`None` as a leaf. If :data:`False`, :data:`None` is a + non-leaf node with arity 0. Thus :data:`None` is contained in the treespec rather than + in the leaves list. (default: :data:`False`) + + Returns: + A triple ``(paths, leaves, treespec)``. The first element is a list of the paths to the leaf + values, while each path is a tuple of the index or keys. The second element is a list of + leaf values and the last element is a treespec representing the structure of the pytree. + """ + return _C.flatten_with_path(tree, is_leaf, none_is_leaf) + + def tree_unflatten(treespec: PyTreeSpec, leaves: Iterable[T]) -> PyTree[T]: """Reconstructs a pytree from the treespec and the leaves. @@ -155,7 +235,7 @@ def tree_leaves( >>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5} >>> tree_leaves(tree) [1, 2, 3, 4, 5] - >>> tree_leaves(tree) + >>> tree_leaves(tree, none_is_leaf=True) [1, 2, 3, 4, None, 5] >>> tree_leaves(1) [1] @@ -194,13 +274,13 @@ def tree_structure( >>> tree_structure(tree) PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': None, 'd': *}) >>> tree_structure(tree, none_is_leaf=True) - PyTreeSpec(NoneIsLeaf, {'a': *, 'b': (*, [*, *]), 'c': *, 'd': *}) + PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': *, 'd': *}, NoneIsLeaf) >>> tree_structure(1) PyTreeSpec(*) >>> tree_structure(None) PyTreeSpec(None) >>> tree_structure(None, none_is_leaf=True) - PyTreeSpec(NoneIsLeaf, *) + PyTreeSpec(*, NoneIsLeaf) Args: tree: A pytree to flatten. @@ -218,6 +298,44 @@ def tree_structure( return _C.flatten(tree, is_leaf, none_is_leaf)[1] +def tree_paths( + tree: PyTree[T], + is_leaf: Optional[Callable[[T], bool]] = None, + *, + none_is_leaf: bool = False, +) -> List[Tuple[Any, ...]]: + """Gets the path entries to the leaves of a pytree. + + See also :func:`tree_flatten` and :func:`tree_flatten_with_path`. + + >>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5} + >>> tree_paths(tree) + [('a',), ('b', 0), ('b', 1, 0), ('b', 1, 1), ('d',)] + >>> tree_paths(tree, none_is_leaf=True) + [('a',), ('b', 0), ('b', 1, 0), ('b', 1, 1), ('c',), ('d',)] + >>> tree_paths(1) + [()] + >>> tree_paths(None) + [] + >>> tree_paths(None, none_is_leaf=True) + [()] + + Args: + tree: A pytree to flatten. + is_leaf: An optionally specified function that will be called at each flattening step. It + should return a boolean, with :data:`True` stopping the traversal and the whole subtree + being treated as a leaf, and :data:`False` indicating the flattening should traverse the + current object. + none_is_leaf: Whether to treat :data:`None` as a leaf. If :data:`False`, :data:`None` is a + non-leaf node with arity 0. Thus :data:`None` is contained in the treespec rather than + in the leaves list. (default: :data:`False`) + + Returns: + A list of the paths to the leaf values, while each path is a tuple of the index or keys. + """ + return _C.flatten_with_path(tree, is_leaf, none_is_leaf)[0] + + def all_leaves( iterable: Iterable[T], is_leaf: Optional[Callable[[T], bool]] = None, @@ -278,7 +396,7 @@ def tree_map( ) -> PyTree[U]: """Maps a multi-input function over pytree args to produce a new pytree. - See also :func:`tree_map_`, :func:`tree_flatten`, :func:`tree_leaves`, and :func:`tree_unflatten`. + See also :func:`tree_map_`, :func:`tree_map_with_path`, and :func:`tree_map_with_path_`. >>> tree_map(lambda x: x + 1, {'x': 7, 'y': (42, 64)}) {'x': 8, 'y': (43, 65)} @@ -331,7 +449,7 @@ def tree_map_( ) -> PyTree[T]: """Likes :func:`tree_map`, but does an inplace call on each leaf and returns the original tree. - See also :func:`tree_map`. + See also :func:`tree_map`, :func:`tree_map_with_path`, and :func:`tree_map_with_path_`. Args: func: A function that takes ``1 + len(rests)`` arguments, to be applied at the corresponding @@ -361,6 +479,93 @@ def tree_map_( return tree +def tree_map_with_path( + func: Callable[..., U], + tree: PyTree[T], + *rests: PyTree[S], + is_leaf: Optional[Callable[[T], bool]] = None, + none_is_leaf: bool = False, +) -> PyTree[U]: + """Maps a multi-input function over pytree args to produce a new pytree. + + See also :func:`tree_map`, :func:`tree_map_`, and :func:`tree_map_with_path_`. + + >>> tree_map_with_path(lambda p, x: (len(p), x), {'x': 7, 'y': (42, 64)}) + {'x': (1, 7), 'y': ((2, 42), (2, 64))} + >>> tree_map_with_path(lambda p, x: x + len(p), {'x': 7, 'y': (42, 64), 'z': None}) + {'x': 8, 'y': (44, 66), 'z': None} + >>> tree_map_with_path(lambda p, x: p, {'x': 7, 'y': (42, 64), 'z': {1.5: None}}) + {'x': ('x',), 'y': (('y', 0), ('y', 1)), 'z': {1.5: None}} + >>> tree_map_with_path(lambda p, x: p, {'x': 7, 'y': (42, 64), 'z': {1.5: None}}, none_is_leaf=True) + {'x': ('x',), 'y': (('y', 0), ('y', 1)), 'z': {1.5: ('z', 1.5)}} + + Args: + func: A function that takes ``2 + len(rests)`` arguments, to be applied at the corresponding + leaves of the pytrees with extra paths. + tree: A pytree to be mapped over, with each leaf providing the second positional argument + and the corresponding path providing the first positional argument to function ``func``. + rests: A tuple of pytrees, each of which has the same structure as ``tree`` or has ``tree`` + as a prefix. + is_leaf: An optionally specified function that will be called at each flattening step. It + should return a boolean, with :data:`True` stopping the traversal and the whole subtree + being treated as a leaf, and :data:`False` indicating the flattening should traverse the + current object. + none_is_leaf: Whether to treat :data:`None` as a leaf. If :data:`False`, :data:`None` is a + non-leaf node with arity 0. Thus :data:`None` is contained in the treespec rather than + in the leaves list and :data:`None` will be remain in the result pytree. (default: + :data:`False`) + + Returns: + A new pytree with the same structure as ``tree`` but with the value at each leaf given by + ``func(p, x, *xs)`` where ``(p, x)`` are the path and value at the corresponding leaf in + ``tree`` and ``xs`` is the tuple of values at corresponding nodes in ``rests``. + """ + paths, leaves, treespec = tree_flatten_with_path(tree, is_leaf, none_is_leaf=none_is_leaf) + flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests] + flat_results = map(func, paths, *flat_args) + return treespec.unflatten(flat_results) + + +def tree_map_with_path_( + func: Callable[..., Any], + tree: PyTree[T], + *rests: PyTree[S], + is_leaf: Optional[Callable[[T], bool]] = None, + none_is_leaf: bool = False, +) -> PyTree[T]: + """Likes :func:`tree_map_with_path_`, but does an inplace call on each leaf and returns the original tree. + + See also :func:`tree_map`, :func:`tree_map_`, and :func:`tree_map_with_path`. + + Args: + func: A function that takes ``2 + len(rests)`` arguments, to be applied at the corresponding + leaves of the pytrees with extra paths. + tree: A pytree to be mapped over, with each leaf providing the second positional argument + and the corresponding path providing the first positional argument to function ``func``. + rests: A tuple of pytrees, each of which has the same structure as ``tree`` or has ``tree`` + as a prefix. + is_leaf: An optionally specified function that will be called at each flattening step. It + should return a boolean, with :data:`True` stopping the traversal and the whole subtree + being treated as a leaf, and :data:`False` indicating the flattening should traverse the + current object. + none_is_leaf: Whether to treat :data:`None` as a leaf. If :data:`False`, :data:`None` is a + non-leaf node with arity 0. Thus :data:`None` is contained in the treespec rather than + in the leaves list and :data:`None` will be remain in the result pytree. (default: + :data:`False`) + + Returns: + The original ``tree`` with the value at each leaf is given by the side-effect of function + ``func(p, x, *xs)`` (not the return value) where ``(p, x)`` are the path and value at the + corresponding leaf in ``tree`` and ``xs`` is the tuple of values at values at corresponding + nodes in ``rests``. + """ + paths, leaves, treespec = tree_flatten_with_path(tree, is_leaf, none_is_leaf=none_is_leaf) + flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests] + flat_results = map(func, paths, *flat_args) + deque(flat_results, maxlen=0) # consume and exhaust the iterable + return tree + + __INITIAL_MISSING: T = object() # type: ignore[valid-type] @@ -626,7 +831,7 @@ def treespec_leaf(*, none_is_leaf: bool = False) -> PyTreeSpec: >>> treespec_leaf() PyTreeSpec(*) >>> treespec_leaf(none_is_leaf=True) - PyTreeSpec(NoneIsLeaf, *) + PyTreeSpec(*, NoneIsLeaf) >>> treespec_leaf(none_is_leaf=False) == treespec_leaf(none_is_leaf=True) False >>> treespec_leaf() == tree_structure(1) @@ -657,7 +862,7 @@ def treespec_none(*, none_is_leaf: bool = False) -> PyTreeSpec: >>> treespec_none() PyTreeSpec(None) >>> treespec_none(none_is_leaf=True) - PyTreeSpec(NoneIsLeaf, *) + PyTreeSpec(*, NoneIsLeaf) >>> treespec_none(none_is_leaf=False) == treespec_none(none_is_leaf=True) False >>> treespec_none() == tree_structure(None) @@ -733,21 +938,31 @@ def add_leaves(x: T, subtree: PyTree[S]) -> None: def flatten_one_level( tree: PyTree[T], *, none_is_leaf: bool = False -) -> Tuple[Children[T], MetaData]: +) -> Tuple[Children[T], MetaData, Tuple[Any, ...]]: """Flattens the pytree one level, returning a tuple of children and auxiliary data.""" if tree is None: if none_is_leaf: # type: ignore[unreachable] raise ValueError('Cannot flatten None') - return (), None + return (), None, () node_type = type(tree) handler = register_pytree_node.get(node_type) # type: ignore[attr-defined] if handler: - children, metadata = handler.to_iter(tree) - return list(children), metadata + flattened = handler.to_iterable(tree) + if len(flattened) == 2: + flattened = (*flattened, None) + elif len(flattened) != 3: + raise ValueError('PyTree to_iterables must return a 2- or 3-tuple.') + children, metadata, entries = flattened + children = list(children) + if entries is None: + entries = tuple(range(len(children))) + else: + entries = tuple(entries) + return children, metadata, entries if is_namedtuple(tree): - return list(cast(NamedTuple, tree)), node_type + return list(cast(NamedTuple, tree)), node_type, tuple(range(len(cast(NamedTuple, tree)))) raise ValueError(f'Cannot tree-flatten type: {node_type}.') @@ -801,8 +1016,8 @@ def _prefix_error( # Or they may disagree if their roots have different numbers of children (note that because both # prefix_tree and full_tree have the same type at this point, and because prefix_tree is not a # leaf, each can be flattened once): - prefix_tree_children, prefix_tree_meta = flatten_one_level(prefix_tree) - full_tree_children, full_tree_meta = flatten_one_level(full_tree) + prefix_tree_children, prefix_tree_meta, _ = flatten_one_level(prefix_tree) + full_tree_children, full_tree_meta, _ = flatten_one_level(full_tree) if len(prefix_tree_children) != len(full_tree_children): yield lambda name: ValueError( f'pytree structure error: different numbers of pytree children at key path\n' diff --git a/optree/registry.py b/optree/registry.py index 4831a687..01ce9424 100644 --- a/optree/registry.py +++ b/optree/registry.py @@ -22,7 +22,7 @@ import optree._C as _C from optree.typing import KT, VT, Children, CustomTreeNode, DefaultDict, MetaData from optree.typing import OrderedDict as GenericOrderedDict -from optree.typing import PyTree, PyTreeSpec, T +from optree.typing import PyTree, T from optree.utils import safe_zip, unzip2 @@ -39,33 +39,35 @@ PyTreeNodeRegistryEntry = NamedTuple( 'PyTreeNodeRegistryEntry', [ - ('to_iter', Callable[[CustomTreeNode[T]], Tuple[Children[T], MetaData]]), - ('from_iter', Callable[[MetaData, Children[T]], CustomTreeNode[T]]), + ('to_iterable', Callable[[CustomTreeNode[T]], Tuple[Children[T], MetaData]]), + ('from_iterable', Callable[[MetaData, Children[T]], CustomTreeNode[T]]), ], ) def register_pytree_node( - type: Type[CustomTreeNode[T]], # pylint: disable=redefined-builtin + cls: Type[CustomTreeNode[T]], flatten_func: Callable[[CustomTreeNode[T]], Tuple[Children[T], MetaData]], unflatten_func: Callable[[MetaData, Children[T]], CustomTreeNode[T]], ) -> Type[CustomTreeNode[T]]: """Extends the set of types that are considered internal nodes in pytrees. Args: - type: A Python type to treat as an internal pytree node. - flatten_func: A function to be used during flattening, taking a value of type ``type`` and - returning a pair, with (1) an iterable for the children to be flattened recursively, and - (2) some hashable auxiliary data to be stored in the treespec and to be passed to the - ``unflatten_func``. + cls: A Python type to treat as an internal pytree node. + flatten_func: A function to be used during flattening, taking an instance of ``cls`` and + returning a triple or optionally a pair, with (1) an iterable for the children to be + flattened recursively, and (2) some hashable auxiliary data to be stored in the treespec + and to be passed to the ``unflatten_func``, and (3) (optional) an iterable for the tree + path entries to the corresponding children. If the entries are not provided or given by + :data:`None`, then `range(len(children))` will be used. unflatten_func: A function taking two arguments: the auxiliary data that was returned by ``flatten_func`` and stored in the treespec, and the unflattened children. The function - should return an instance of ``type``. + should return an instance of ``cls``. """ - _C.register_node(type, flatten_func, unflatten_func) - CustomTreeNode.register(type) # pylint: disable=no-member - _nodetype_registry[type] = PyTreeNodeRegistryEntry(flatten_func, unflatten_func) - return type + _C.register_node(cls, flatten_func, unflatten_func) + CustomTreeNode.register(cls) # pylint: disable=no-member + _nodetype_registry[cls] = PyTreeNodeRegistryEntry(flatten_func, unflatten_func) + return cls def register_pytree_node_class(cls: Type[CustomTreeNode[T]]) -> Type[CustomTreeNode[T]]: @@ -125,23 +127,27 @@ def _sorted_keys(dct: Dict[KT, VT]) -> Iterable[KT]: # pragma: no cover return dct # fallback to insertion order -def _dict_flatten(dct: Dict[KT, VT]) -> Tuple[Tuple[VT, ...], Tuple[KT, ...]]: # pragma: no cover +def _dict_flatten( + dct: Dict[KT, VT] +) -> Tuple[Tuple[VT, ...], Tuple[KT, ...], Tuple[KT, ...]]: # pragma: no cover keys, values = unzip2(_sorted_items(dct.items())) - return values, keys + return values, keys, keys def _ordereddict_flatten( dct: GenericOrderedDict[KT, VT] -) -> Tuple[Tuple[VT, ...], Tuple[KT, ...]]: # pragma: no cover +) -> Tuple[Tuple[VT, ...], Tuple[KT, ...], Tuple[KT, ...]]: # pragma: no cover keys, values = unzip2(dct.items()) - return values, keys + return values, keys, keys def _defaultdict_flatten( dct: DefaultDict[KT, VT] -) -> Tuple[Tuple[VT, ...], Tuple[Optional[Callable[[], VT]], Tuple[KT, ...]]]: # pragma: no cover - values, keys = _dict_flatten(dct) - return values, (dct.default_factory, keys) +) -> Tuple[ + Tuple[VT, ...], Tuple[Optional[Callable[[], VT]], Tuple[KT, ...]], Tuple[KT, ...] +]: # pragma: no cover + values, keys, entries = _dict_flatten(dct) + return values, (dct.default_factory, keys), entries # pylint: disable=all @@ -341,11 +347,11 @@ def pprint(self) -> str: def register_keypaths( - type: Type[CustomTreeNode[T]], # pylint: disable=redefined-builtin + cls: Type[CustomTreeNode[T]], handler: KeyPathHandler, ) -> KeyPathHandler: """Registers a key path handler for a custom pytree node type.""" - _keypath_registry[type] = handler + _keypath_registry[cls] = handler return handler diff --git a/optree/typing.py b/optree/typing.py index 60c7409f..b5321ba4 100644 --- a/optree/typing.py +++ b/optree/typing.py @@ -22,6 +22,7 @@ DefaultDict, Deque, Dict, + ForwardRef, Generic, Hashable, Iterable, @@ -47,13 +48,6 @@ from typing import NamedTuple # type: ignore[assignment] -try: - # Python 3.6 - from typing import _ForwardRef as ForwardRef # type: ignore[attr-defined] -except ImportError: - from typing import ForwardRef - - __all__ = [ 'PyTreeSpec', 'PyTreeDef', @@ -98,7 +92,14 @@ class CustomTreeNode(Protocol[T]): """The abstract base class for custom pytree nodes.""" - def tree_flatten(self) -> Tuple[Children[T], MetaData]: + def tree_flatten( + self, + ) -> Union[ + # Use `range(num_children)` as path entries + Tuple[Children[T], MetaData], + # With optionally implemented path entries + Tuple[Children[T], MetaData, Optional[Iterable[Any]]], + ]: """Flattens the custom pytree node into children and auxiliary data.""" @classmethod @@ -144,7 +145,7 @@ class PyTree(Generic[T]): # pylint: disable=too-few-public-methods @_tp_cache def __class_getitem__(cls, item: Union[T, Tuple[T], Tuple[T, Optional[str]]]) -> TypeAlias: - """Instantiate a PyTree type with the given type.""" + """Instantiates a PyTree type with the given type.""" if not isinstance(item, tuple): item = (item, None) if len(item) != 2: @@ -167,7 +168,7 @@ def __class_getitem__(cls, item: Union[T, Tuple[T], Tuple[T, Optional[str]]]) -> return param # PyTree[PyTree[T]] -> PyTree[T] if name is not None: - recurse_ref = name + recurse_ref = ForwardRef(name) elif isinstance(param, TypeVar): recurse_ref = ForwardRef(f'{cls.__name__}[{param.__name__}]') elif isinstance(param, type): @@ -226,7 +227,7 @@ class PyTreeTypeVar: @_tp_cache def __new__(cls, name: str, param: Type) -> TypeAlias: - """Instantiate a PyTree type variable with the given name and parameter.""" + """Instantiates a PyTree type variable with the given name and parameter.""" if not isinstance(name, str): raise TypeError(f'{cls.__name__} only supports a string of type name. Got {name!r}.') return PyTree[param, name] # type: ignore[misc,valid-type] @@ -245,5 +246,5 @@ def __deepcopy__(self, memo): def is_namedtuple(obj: object) -> bool: - """Return whether the object is a namedtuple.""" + """Returns whether the object is a namedtuple.""" return isinstance(obj, tuple) and hasattr(obj, '_fields') diff --git a/pyproject.toml b/pyproject.toml index 440d7139..66b1d5aa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ build-backend = "setuptools.build_meta" name = "optree" description = "Optimized PyTree Utilities." readme = "README.md" -requires-python = ">= 3.6" +requires-python = ">= 3.7" authors = [ { name = "OpTree Contributors" }, { name = "Xuehai Pan", email = "XuehaiPan@pku.edu.cn" }, @@ -25,7 +25,6 @@ classifiers = [ "Development Status :: 4 - Beta", "License :: OSI Approved :: Apache Software License", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.6", "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", @@ -92,7 +91,7 @@ test-command = """make -C "{project}" test PYTHON=python""" safe = true line-length = 100 skip-string-normalization = true -target-version = ["py36", "py37", "py38", "py39", "py310", "py311"] +target-version = ["py37", "py38", "py39", "py310", "py311"] [tool.isort] profile = "black" @@ -104,7 +103,7 @@ lines_after_imports = 2 multi_line_output = 3 [tool.mypy] -python_version = 3.6 +python_version = 3.7 pretty = true show_error_codes = true show_error_context = true diff --git a/src/optree.cpp b/src/optree.cpp index bfefe298..1312c703 100644 --- a/src/optree.cpp +++ b/src/optree.cpp @@ -33,7 +33,7 @@ void BuildModule(py::module& mod) { // NOLINT &PyTreeTypeRegistry::Register, "Register a Python type. Extends the set of types that are considered internal nodes " "in pytrees.", - py::arg("type"), + py::arg("cls"), py::arg("to_iterable"), py::arg("from_iterable")); @@ -43,6 +43,12 @@ void BuildModule(py::module& mod) { // NOLINT py::arg("tree"), py::arg("leaf_predicate") = std::nullopt, py::arg("none_is_leaf") = false); + mod.def("flatten_with_path", + &PyTreeSpec::FlattenWithPath, + "Flattens a pytree and additionally records the paths.", + py::arg("tree"), + py::arg("leaf_predicate") = std::nullopt, + py::arg("none_is_leaf") = false); mod.def("all_leaves", &PyTreeSpec::AllLeaves, "Tests whether all elements in the given iterable are all leaves.", diff --git a/src/treespec/flatten.cpp b/src/treespec/flatten.cpp index 13d36852..5e95a58a 100644 --- a/src/treespec/flatten.cpp +++ b/src/treespec/flatten.cpp @@ -86,8 +86,8 @@ void PyTreeSpec::FlattenIntoImpl(const py::handle& handle, py::tuple tuple = py::reinterpret_borrow(handle); node.arity = GET_SIZE(tuple); node.node_data = py::reinterpret_borrow(tuple.get_type()); - for (const py::handle& entry : tuple) { - recurse(entry); + for (ssize_t i = 0; i < node.arity; ++i) { + recurse(GET_ITEM_HANDLE(tuple, i)); } break; } @@ -104,9 +104,10 @@ void PyTreeSpec::FlattenIntoImpl(const py::handle& handle, case PyTreeKind::Custom: { py::tuple out = py::cast(node.custom->to_iterable(handle)); - if (GET_SIZE(out) != 2) [[unlikely]] { + const size_t num_out = GET_SIZE(out); + if (num_out != 2 && num_out != 3) [[unlikely]] { // NOLINT throw std::runtime_error( - "PyTree custom to_iterable function should return a pair."); + "PyTree custom to_iterable function should return a pair or a triple."); } node.arity = 0; node.node_data = GET_ITEM_BORROW(out, 1); @@ -115,6 +116,20 @@ void PyTreeSpec::FlattenIntoImpl(const py::handle& handle, ++node.arity; recurse(child); } + if (num_out == 3) [[likely]] { // NOLINT + py::object node_entries = GET_ITEM_BORROW(out, 2); + if (!node_entries.is_none()) [[likely]] { // NOLINT + node.node_entries = py::cast(std::move(node_entries)); + const ssize_t num_entries = GET_SIZE(node.node_entries); + if (num_entries != node.arity) [[unlikely]] { + throw std::runtime_error(absl::StrFormat( + "PyTree custom to_iterable function returned inconsistent number " + "of children (%ld) and number of entries (%ld).", + node.arity, + num_entries)); + } + } + } break; } @@ -160,6 +175,192 @@ void PyTreeSpec::FlattenInto(const py::handle& handle, return std::make_pair(std::move(leaves), std::move(treespec)); } +template +void PyTreeSpec::FlattenIntoWithPathImpl(const py::handle& handle, + Span& leaves, + Span& paths, + Stack& stack, + const ssize_t& depth, + const std::optional& leaf_predicate) { + Node node; + ssize_t start_num_nodes = traversal.size(); + ssize_t start_num_leaves = leaves.size(); + if (leaf_predicate && (*leaf_predicate)(handle).cast()) [[unlikely]] { + leaves.emplace_back(py::reinterpret_borrow(handle)); + } else [[likely]] { // NOLINT + node.kind = GetKind(handle, &node.custom); + auto recurse = [this, &leaf_predicate, &leaves, &paths, &stack, &depth](py::handle child, + py::handle entry) { + stack.emplace_back(entry); + FlattenIntoWithPathImpl( + child, leaves, paths, stack, depth + 1, leaf_predicate); + stack.pop_back(); + }; + switch (node.kind) { + case PyTreeKind::None: + if (!NoneIsLeaf) break; + case PyTreeKind::Leaf: { + py::tuple path = py::tuple{depth}; + for (ssize_t d = 0; d < depth; ++d) { + SET_ITEM(path, d, stack[d]); + } + leaves.emplace_back(py::reinterpret_borrow(handle)); + paths.emplace_back(std::move(path)); + break; + } + + case PyTreeKind::Tuple: { + node.arity = GET_SIZE(handle); + for (ssize_t i = 0; i < node.arity; ++i) { + recurse(GET_ITEM_HANDLE(handle, i), py::int_(i)); + } + break; + } + + case PyTreeKind::List: { + node.arity = GET_SIZE(handle); + for (ssize_t i = 0; i < node.arity; ++i) { + recurse(GET_ITEM_HANDLE(handle, i), py::int_(i)); + } + break; + } + + case PyTreeKind::Dict: + case PyTreeKind::OrderedDict: + case PyTreeKind::DefaultDict: { + py::dict dict = py::reinterpret_borrow(handle); + py::list keys; + if (node.kind == PyTreeKind::OrderedDict) [[unlikely]] { + keys = DictKeys(dict); + } else [[likely]] { // NOLINT + keys = SortedDictKeys(dict); + } + for (const py::handle& key : keys) { + recurse(dict[key], key); + } + node.arity = GET_SIZE(dict); + if (node.kind == PyTreeKind::DefaultDict) [[unlikely]] { + node.node_data = + py::make_tuple(py::getattr(handle, "default_factory"), std::move(keys)); + } else [[likely]] { // NOLINT + node.node_data = std::move(keys); + } + break; + } + + case PyTreeKind::NamedTuple: { + py::tuple tuple = py::reinterpret_borrow(handle); + node.arity = GET_SIZE(tuple); + node.node_data = py::reinterpret_borrow(tuple.get_type()); + for (ssize_t i = 0; i < node.arity; ++i) { + recurse(GET_ITEM_HANDLE(tuple, i), py::int_(i)); + } + break; + } + + case PyTreeKind::Deque: { + py::list list = handle.cast(); + node.arity = GET_SIZE(list); + node.node_data = py::getattr(handle, "maxlen"); + for (ssize_t i = 0; i < node.arity; ++i) { + recurse(GET_ITEM_HANDLE(list, i), py::int_(i)); + } + break; + } + + case PyTreeKind::Custom: { + py::tuple out = py::cast(node.custom->to_iterable(handle)); + const size_t num_out = GET_SIZE(out); + if (num_out != 2 && num_out != 3) [[unlikely]] { + throw std::runtime_error( + "PyTree custom to_iterable function should return a pair or a triple."); + } + node.arity = 0; + node.node_data = GET_ITEM_BORROW(out, 1); + py::object node_entries; + if (num_out == 3) [[likely]] { // NOLINT + node_entries = GET_ITEM_BORROW(out, 2); + } else [[unlikely]] { // NOLINT + node_entries = py::none(); + } + if (node_entries.is_none()) [[unlikely]] { // NOLINT + for (const py::handle& child : + py::cast(GET_ITEM_BORROW(out, 0))) { + recurse(child, py::int_(node.arity++)); + } + } else [[likely]] { // NOLINT + node.node_entries = py::cast(std::move(node_entries)); + node.arity = GET_SIZE(node.node_entries); + ssize_t num_children = 0; + for (const py::handle& child : + py::cast(GET_ITEM_BORROW(out, 0))) { + if (num_children >= node.arity) [[unlikely]] { + throw std::runtime_error( + "PyTree custom to_iterable function returned too many children " + "than number of entries."); + } + recurse(child, + GET_ITEM_BORROW(node.node_entries, num_children++)); + } + if (num_children != node.arity) [[unlikely]] { + throw std::runtime_error(absl::StrFormat( + "PyTree custom to_iterable function returned inconsistent " + "number " + "of children (%ld) and number of entries (%ld).", + num_children, + node.arity)); + } + } + break; + } + + default: + throw std::logic_error("Unreachable code."); + } + } + node.num_nodes = traversal.size() - start_num_nodes + 1; + node.num_leaves = leaves.size() - start_num_leaves; + traversal.emplace_back(std::move(node)); +} + +void PyTreeSpec::FlattenIntoWithPath(const py::handle& handle, + absl::InlinedVector& leaves, + absl::InlinedVector& paths, + const std::optional& leaf_predicate, + const bool& none_is_leaf) { + absl::InlinedVector stack; + if (none_is_leaf) [[unlikely]] { + FlattenIntoWithPathImpl(handle, leaves, paths, stack, 0, leaf_predicate); + } else [[likely]] { // NOLINT + FlattenIntoWithPathImpl(handle, leaves, paths, stack, 0, leaf_predicate); + } +} + +void PyTreeSpec::FlattenIntoWithPath(const py::handle& handle, + std::vector& leaves, + std::vector& paths, + const std::optional& leaf_predicate, + const bool& none_is_leaf) { + std::vector stack; + if (none_is_leaf) [[unlikely]] { + FlattenIntoWithPathImpl(handle, leaves, paths, stack, 0, leaf_predicate); + } else [[likely]] { // NOLINT + FlattenIntoWithPathImpl(handle, leaves, paths, stack, 0, leaf_predicate); + } +} + +/*static*/ std::tuple, std::vector, std::unique_ptr> +PyTreeSpec::FlattenWithPath(const py::handle& tree, + const std::optional& leaf_predicate, + const bool& none_is_leaf) { + std::vector leaves; + std::vector paths; + auto treespec = std::make_unique(); + treespec->none_is_leaf = none_is_leaf; + treespec->FlattenIntoWithPath(tree, leaves, paths, leaf_predicate, none_is_leaf); + return std::make_tuple(std::move(paths), std::move(leaves), std::move(treespec)); +} + py::list PyTreeSpec::FlattenUpToImpl(const py::handle& full_tree) const { const ssize_t num_leaves = PyTreeSpec::num_leaves(); @@ -201,8 +402,8 @@ py::list PyTreeSpec::FlattenUpToImpl(const py::handle& full_tree) const { node.arity, py::repr(object))); } - for (const py::handle& entry : tuple) { - agenda.emplace_back(py::reinterpret_borrow(entry)); + for (ssize_t i = 0; i < node.arity; ++i) { + agenda.emplace_back(GET_ITEM_BORROW(tuple, i)); } break; } @@ -217,8 +418,8 @@ py::list PyTreeSpec::FlattenUpToImpl(const py::handle& full_tree) const { node.arity, py::repr(object))); } - for (const py::handle& entry : list) { - agenda.emplace_back(py::reinterpret_borrow(entry)); + for (ssize_t i = 0; i < node.arity; ++i) { + agenda.emplace_back(GET_ITEM_BORROW(list, i)); } break; } @@ -265,8 +466,8 @@ py::list PyTreeSpec::FlattenUpToImpl(const py::handle& full_tree) const { py::repr(node.node_data), py::repr(object))); } - for (const py::handle& entry : tuple) { - agenda.emplace_back(py::reinterpret_borrow(entry)); + for (ssize_t i = 0; i < node.arity; ++i) { + agenda.emplace_back(GET_ITEM_BORROW(tuple, i)); } break; } @@ -307,8 +508,8 @@ py::list PyTreeSpec::FlattenUpToImpl(const py::handle& full_tree) const { node.arity, py::repr(object))); } - for (const py::handle& entry : list) { - agenda.emplace_back(py::reinterpret_borrow(entry)); + for (ssize_t i = 0; i < node.arity; ++i) { + agenda.emplace_back(GET_ITEM_BORROW(list, i)); } break; } @@ -327,9 +528,10 @@ py::list PyTreeSpec::FlattenUpToImpl(const py::handle& full_tree) const { py::repr(object))); } py::tuple out = py::cast(node.custom->to_iterable(object)); - if (GET_SIZE(out) != 2) [[unlikely]] { + const size_t num_out = GET_SIZE(out); + if (num_out != 2 && num_out != 3) [[unlikely]] { throw std::runtime_error( - "PyTree custom to_iterable function should return a pair."); + "PyTree custom to_iterable function should return a pair or a triple."); } if (node.node_data.not_equal(GET_ITEM_BORROW(out, 1))) [[unlikely]] { throw std::invalid_argument( @@ -339,10 +541,10 @@ py::list PyTreeSpec::FlattenUpToImpl(const py::handle& full_tree) const { py::repr(object))); } ssize_t arity = 0; - for (const py::handle& entry : + for (const py::handle& child : py::cast(GET_ITEM_BORROW(out, 0))) { ++arity; - agenda.emplace_back(py::reinterpret_borrow(entry)); + agenda.emplace_back(py::reinterpret_borrow(child)); } if (arity != node.arity) [[unlikely]] { throw std::invalid_argument( diff --git a/src/treespec/treespec.cpp b/src/treespec/treespec.cpp index 3ad17983..a720ba7d 100644 --- a/src/treespec/treespec.cpp +++ b/src/treespec/treespec.cpp @@ -58,7 +58,7 @@ bool PyTreeSpec::operator==(const PyTreeSpec& other) const { bool PyTreeSpec::operator!=(const PyTreeSpec& other) const { return !(*this == other); } std::unique_ptr PyTreeSpec::Compose(const PyTreeSpec& inner_treespec) const { - if (inner_treespec.none_is_leaf != none_is_leaf) { + if (inner_treespec.none_is_leaf != none_is_leaf) [[unlikely]] { // NOLINT throw std::invalid_argument("PyTreeSpecs must have the same none_is_leaf value."); } auto outer_treespec = std::make_unique(); @@ -225,7 +225,7 @@ std::vector> PyTreeSpec::Children() const { } case PyTreeKind::Custom: { - py::tuple tuple(node.arity); + py::tuple tuple{node.arity}; for (ssize_t i = 0; i < node.arity; ++i) { SET_ITEM(tuple, i, std::move(children[i])); } @@ -404,7 +404,7 @@ std::string PyTreeSpec::ToString() const { if (agenda.size() != 1) [[unlikely]] { throw std::logic_error("PyTreeSpec traversal did not yield a singleton."); } - return absl::StrCat("PyTreeSpec(", (none_is_leaf ? "NoneIsLeaf, " : ""), agenda.back(), ")"); + return absl::StrCat("PyTreeSpec(", agenda.back(), (none_is_leaf ? ", NoneIsLeaf" : ""), ")"); } py::object PyTreeSpec::ToPicklable() const { @@ -416,6 +416,7 @@ py::object PyTreeSpec::ToPicklable() const { py::make_tuple(static_cast(node.kind), node.arity, node.node_data ? node.node_data : py::none(), + node.node_entries ? node.node_entries : py::none(), node.custom != nullptr ? node.custom->type : py::none(), node.num_leaves, node.num_nodes)); @@ -434,7 +435,7 @@ py::object PyTreeSpec::ToPicklable() const { py::tuple node_states = py::reinterpret_borrow(state[0]); for (const auto& item : node_states) { auto t = item.cast(); - if (t.size() != 6) [[unlikely]] { + if (t.size() != 7) [[unlikely]] { throw std::runtime_error("Malformed pickled PyTreeSpec."); } Node& node = treespec.traversal.emplace_back(); @@ -460,25 +461,28 @@ py::object PyTreeSpec::ToPicklable() const { break; } if (node.kind == PyTreeKind::Custom) [[unlikely]] { // NOLINT - if (t[3].is_none()) [[unlikely]] { + if (!t[3].is_none()) [[unlikely]] { + node.node_entries = t[3].cast(); + } + if (t[4].is_none()) [[unlikely]] { node.custom = nullptr; } else [[likely]] { // NOLINT if (none_is_leaf) [[unlikely]] { - node.custom = PyTreeTypeRegistry::Lookup(t[3]); + node.custom = PyTreeTypeRegistry::Lookup(t[4]); } else [[likely]] { // NOLINT - node.custom = PyTreeTypeRegistry::Lookup(t[3]); + node.custom = PyTreeTypeRegistry::Lookup(t[4]); } } if (node.custom == nullptr) [[unlikely]] { throw std::runtime_error(absl::StrCat("Unknown custom type in pickled PyTreeSpec: ", - static_cast(py::repr(t[3])), + static_cast(py::repr(t[4])), ".")); } - } else if (!t[3].is_none()) [[unlikely]] { + } else if (!t[3].is_none() || !t[4].is_none()) [[unlikely]] { throw std::runtime_error("Malformed pickled PyTreeSpec."); } - node.num_leaves = t[4].cast(); - node.num_nodes = t[5].cast(); + node.num_leaves = t[5].cast(); + node.num_nodes = t[6].cast(); } return treespec; } diff --git a/tests/helpers.py b/tests/helpers.py index 7336a6d1..2a6934d6 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -16,7 +16,7 @@ # pylint: disable=missing-class-docstring,missing-function-docstring,invalid-name import itertools -from collections import OrderedDict, defaultdict, deque, namedtuple +from collections import OrderedDict, UserDict, defaultdict, deque, namedtuple import pytest @@ -113,9 +113,24 @@ def tree_flatten(self): def tree_unflatten(cls, metadata, children): if not optree.all_leaves(children): children, metadata = optree.tree_flatten(optree.tree_unflatten(metadata, children)) - return FlatCache(structured=None, leaves=children, treespec=metadata) + return cls(structured=None, leaves=children, treespec=metadata) +@optree.register_pytree_node_class +class MyDict(UserDict): + def tree_flatten(self): + reversed_keys = sorted(self.keys(), reverse=True) + return [self[key] for key in reversed_keys], reversed_keys, reversed_keys + + @classmethod + def tree_unflatten(cls, metadata, children): + return cls(zip(metadata, children)) + + def __repr__(self): + return f'{self.__class__.__name__}({super().__repr__()})' + + +# pylint: disable=line-too-long TREES = ( 1, None, @@ -135,6 +150,7 @@ def tree_unflatten(cls, metadata, children): OrderedDict([('foo', 34), ('baz', 101), ('something', deque([None, None, 3], maxlen=2))]), OrderedDict([('foo', 34), ('baz', 101), ('something', deque([None, 2, 3], maxlen=2))]), defaultdict(dict, [('foo', 34), ('baz', 101), ('something', -42)]), + MyDict([('baz', 101), ('foo', MyDict(a=1, b=2, c=None))]), CustomNamedTupleSubclass(foo='hello', bar=3.5), FlatCache(None), FlatCache(1), @@ -142,6 +158,64 @@ def tree_unflatten(cls, metadata, children): ) +TREE_PATHS_NONE_IS_NODE = [ + [()], + [], + [], + [(0,)], + [], + [], + [(0,), (1,)], + [(0, 0), (0, 1), (1, 0), (1, 1, 0), (1, 1, 2)], + [(0,)], + [(0,), (1, 0, 0), (1, 0, 1, 0), (1, 1, 'baz')], + [(0, 0)], + [(0,), (1,)], + [('a',), ('b',)], + [('foo',), ('baz',), ('something',)], + [('foo',), ('baz',), ('something', 1), ('something', 2)], + [('foo',), ('baz',), ('something', 1)], + [('foo',), ('baz',), ('something', 0), ('something', 1)], + [('baz',), ('foo',), ('something',)], + [('foo', 'b'), ('foo', 'a'), ('baz',)], + [(0,), (1,)], + [], + [(0,)], + [(0,), (1,)], +] + +TREE_PATHS_NONE_IS_LEAF = [ + [()], + [()], + [(0,)], + [(0,), (1,)], + [], + [], + [(0,), (1,)], + [(0, 0), (0, 1), (1, 0), (1, 1, 0), (1, 1, 1), (1, 1, 2)], + [(0,)], + [(0,), (1, 0, 0), (1, 0, 1, 0), (1, 0, 1, 1), (1, 1, 'baz')], + [(0, 0), (0, 1)], + [(0,), (1,)], + [('a',), ('b',)], + [('foo',), ('baz',), ('something',)], + [('foo',), ('baz',), ('something', 0), ('something', 1), ('something', 2)], + [('foo',), ('baz',), ('something', 0), ('something', 1)], + [('foo',), ('baz',), ('something', 0), ('something', 1)], + [('baz',), ('foo',), ('something',)], + [('foo', 'c'), ('foo', 'b'), ('foo', 'a'), ('baz',)], + [(0,), (1,)], + [], + [(0,)], + [(0,), (1,)], +] + +TREE_PATHS = { + optree.NONE_IS_NODE: TREE_PATHS_NONE_IS_NODE, + optree.NONE_IS_LEAF: TREE_PATHS_NONE_IS_LEAF, +} + + TREE_STRINGS_NONE_IS_NODE = ( 'PyTreeSpec(*)', 'PyTreeSpec(None)', @@ -161,6 +235,7 @@ def tree_unflatten(cls, metadata, children): "PyTreeSpec(OrderedDict([('foo', *), ('baz', *), ('something', deque([None, *], maxlen=2))]))", "PyTreeSpec(OrderedDict([('foo', *), ('baz', *), ('something', deque([*, *], maxlen=2))]))", "PyTreeSpec(defaultdict(, {'baz': *, 'foo': *, 'something': *}))", + "PyTreeSpec(CustomTreeNode(MyDict[['foo', 'baz']], [CustomTreeNode(MyDict[['c', 'b', 'a']], [None, *, *]), *]))", 'PyTreeSpec(CustomNamedTupleSubclass(foo=*, bar=*))', 'PyTreeSpec(CustomTreeNode(FlatCache[PyTreeSpec(None)], []))', 'PyTreeSpec(CustomTreeNode(FlatCache[PyTreeSpec(*)], [*]))', @@ -168,28 +243,29 @@ def tree_unflatten(cls, metadata, children): ) TREE_STRINGS_NONE_IS_LEAF = ( - 'PyTreeSpec(NoneIsLeaf, *)', - 'PyTreeSpec(NoneIsLeaf, *)', - 'PyTreeSpec(NoneIsLeaf, (*,))', - 'PyTreeSpec(NoneIsLeaf, (*, *))', - 'PyTreeSpec(NoneIsLeaf, ())', - 'PyTreeSpec(NoneIsLeaf, [()])', - 'PyTreeSpec(NoneIsLeaf, (*, *))', - 'PyTreeSpec(NoneIsLeaf, ((*, *), [*, (*, *, *)]))', - 'PyTreeSpec(NoneIsLeaf, [*])', - "PyTreeSpec(NoneIsLeaf, [*, CustomTuple(foo=(*, CustomTuple(foo=*, bar=*)), bar={'baz': *})])", - "PyTreeSpec(NoneIsLeaf, [CustomTreeNode(Vector3D[[4, 'foo']], [*, *])])", - 'PyTreeSpec(NoneIsLeaf, CustomTreeNode(Vector2D[None], [*, *]))', - "PyTreeSpec(NoneIsLeaf, {'a': *, 'b': *})", - "PyTreeSpec(NoneIsLeaf, OrderedDict([('foo', *), ('baz', *), ('something', *)]))", - "PyTreeSpec(NoneIsLeaf, OrderedDict([('foo', *), ('baz', *), ('something', deque([*, *, *]))]))", - "PyTreeSpec(NoneIsLeaf, OrderedDict([('foo', *), ('baz', *), ('something', deque([*, *], maxlen=2))]))", - "PyTreeSpec(NoneIsLeaf, OrderedDict([('foo', *), ('baz', *), ('something', deque([*, *], maxlen=2))]))", - "PyTreeSpec(NoneIsLeaf, defaultdict(, {'baz': *, 'foo': *, 'something': *}))", - 'PyTreeSpec(NoneIsLeaf, CustomNamedTupleSubclass(foo=*, bar=*))', - 'PyTreeSpec(NoneIsLeaf, CustomTreeNode(FlatCache[PyTreeSpec(None)], []))', - 'PyTreeSpec(NoneIsLeaf, CustomTreeNode(FlatCache[PyTreeSpec(*)], [*]))', - "PyTreeSpec(NoneIsLeaf, CustomTreeNode(FlatCache[PyTreeSpec({'a': [*, *]})], [*, *]))", + 'PyTreeSpec(*, NoneIsLeaf)', + 'PyTreeSpec(*, NoneIsLeaf)', + 'PyTreeSpec((*,), NoneIsLeaf)', + 'PyTreeSpec((*, *), NoneIsLeaf)', + 'PyTreeSpec((), NoneIsLeaf)', + 'PyTreeSpec([()], NoneIsLeaf)', + 'PyTreeSpec((*, *), NoneIsLeaf)', + 'PyTreeSpec(((*, *), [*, (*, *, *)]), NoneIsLeaf)', + 'PyTreeSpec([*], NoneIsLeaf)', + "PyTreeSpec([*, CustomTuple(foo=(*, CustomTuple(foo=*, bar=*)), bar={'baz': *})], NoneIsLeaf)", + "PyTreeSpec([CustomTreeNode(Vector3D[[4, 'foo']], [*, *])], NoneIsLeaf)", + 'PyTreeSpec(CustomTreeNode(Vector2D[None], [*, *]), NoneIsLeaf)', + "PyTreeSpec({'a': *, 'b': *}, NoneIsLeaf)", + "PyTreeSpec(OrderedDict([('foo', *), ('baz', *), ('something', *)]), NoneIsLeaf)", + "PyTreeSpec(OrderedDict([('foo', *), ('baz', *), ('something', deque([*, *, *]))]), NoneIsLeaf)", + "PyTreeSpec(OrderedDict([('foo', *), ('baz', *), ('something', deque([*, *], maxlen=2))]), NoneIsLeaf)", + "PyTreeSpec(OrderedDict([('foo', *), ('baz', *), ('something', deque([*, *], maxlen=2))]), NoneIsLeaf)", + "PyTreeSpec(defaultdict(, {'baz': *, 'foo': *, 'something': *}), NoneIsLeaf)", + "PyTreeSpec(CustomTreeNode(MyDict[['foo', 'baz']], [CustomTreeNode(MyDict[['c', 'b', 'a']], [*, *, *]), *]), NoneIsLeaf)", + 'PyTreeSpec(CustomNamedTupleSubclass(foo=*, bar=*), NoneIsLeaf)', + 'PyTreeSpec(CustomTreeNode(FlatCache[PyTreeSpec(None)], []), NoneIsLeaf)', + 'PyTreeSpec(CustomTreeNode(FlatCache[PyTreeSpec(*)], [*]), NoneIsLeaf)', + "PyTreeSpec(CustomTreeNode(FlatCache[PyTreeSpec({'a': [*, *]})], [*, *]), NoneIsLeaf)", ) TREE_STRINGS = { diff --git a/tests/test_ops.py b/tests/test_ops.py index 705a444e..35b62905 100644 --- a/tests/test_ops.py +++ b/tests/test_ops.py @@ -16,6 +16,7 @@ # pylint: disable=missing-function-docstring,invalid-name import functools +import itertools from collections import OrderedDict import pytest @@ -23,7 +24,7 @@ import optree # pylint: disable-next=wrong-import-order -from helpers import LEAVES, TREES, CustomTuple, FlatCache, parametrize +from helpers import LEAVES, TREE_PATHS, TREES, CustomTuple, FlatCache, parametrize def dummy_func(*args, **kwargs): # pylint: disable=unused-argument @@ -146,6 +147,26 @@ def test_structure_is_leaf(structure_fn): assert treespec.num_leaves == 3 +@parametrize( + data=list( + itertools.chain( + zip(TREES, TREE_PATHS[False], itertools.repeat(False)), + zip(TREES, TREE_PATHS[True], itertools.repeat(True)), + ) + ) +) +def test_paths(data): + tree, expected_paths, none_is_leaf = data + expected_leaves, expected_treespec = optree.tree_flatten(tree, none_is_leaf=none_is_leaf) + paths, leaves, treespec = optree.tree_flatten_with_path(tree, none_is_leaf=none_is_leaf) + assert len(paths) == len(leaves) + assert leaves == expected_leaves + assert treespec == expected_treespec + assert paths == expected_paths + paths = optree.tree_paths(tree, none_is_leaf=none_is_leaf) + assert paths == expected_paths + + @parametrize( tree=TREES, is_leaf=[ @@ -198,6 +219,16 @@ def test_tree_map(): assert out == (((1, [3]), (2, None), None), ((3, {'foo': 'bar'}), (4, 7), (5, [5, 6]))) +def test_tree_map_with_path(): + x = ((1, 2, None), (3, 4, 5)) + y = (([3], None, 4), ({'foo': 'bar'}, 7, [5, 6])) + out = optree.tree_map_with_path(lambda *xs: tuple(xs), x, y) + assert out == ( + (((0, 0), 1, [3]), ((0, 1), 2, None), None), + (((1, 0), 3, {'foo': 'bar'}), ((1, 1), 4, 7), ((1, 2), 5, [5, 6])), + ) + + def test_tree_map_none_is_leaf(): x = ((1, 2, None), (3, 4, 5)) y = (([3], None, 4), ({'foo': 'bar'}, 7, [5, 6])) @@ -205,6 +236,16 @@ def test_tree_map_none_is_leaf(): assert out == (((1, [3]), (2, None), (None, 4)), ((3, {'foo': 'bar'}), (4, 7), (5, [5, 6]))) +def test_tree_map_with_path_none_is_leaf(): + x = ((1, 2, None), (3, 4, 5)) + y = (([3], None, 4), ({'foo': 'bar'}, 7, [5, 6])) + out = optree.tree_map_with_path(lambda *xs: tuple(xs), x, y, none_is_leaf=True) + assert out == ( + (((0, 0), 1, [3]), ((0, 1), 2, None), ((0, 2), None, 4)), + (((1, 0), 3, {'foo': 'bar'}), ((1, 1), 4, 7), ((1, 2), 5, [5, 6])), + ) + + def test_tree_map_with_is_leaf_none(): x = ((1, 2, None), (3, 4, 5)) out = optree.tree_map(lambda *xs: tuple(xs), x, none_is_leaf=False) @@ -246,19 +287,36 @@ def fn(*xs): assert leaves == [(1, [3]), (2, None), (3, {'foo': 'bar'}), (4, 7), (5, [5, 6])] -def test_tree_map_ignore_return_none_is_leaf(): +def test_tree_map_with_path_ignore_return(): x = ((1, 2, None), (3, 4, 5)) y = (([3], None, 4), ({'foo': 'bar'}, 7, [5, 6])) - def fn(*xs): - leaves.append(xs) + def fn(p, *xs): + if p[1] >= 1: + leaves.append(xs) + return 0 + + leaves = [] + out = optree.tree_map_with_path_(fn, x, y) + assert out is x + assert x == ((1, 2, None), (3, 4, 5)) + assert leaves == [(2, None), (4, 7), (5, [5, 6])] + + +def test_tree_map_with_path_ignore_return_none_is_leaf(): + x = ((1, 2, None), (3, 4, 5)) + y = (([3], None, 4), ({'foo': 'bar'}, 7, [5, 6])) + + def fn(p, *xs): + if p[1] >= 1: + leaves.append(xs) return 0 leaves = [] - out = optree.tree_map_(fn, x, y, none_is_leaf=True) + out = optree.tree_map_with_path_(fn, x, y, none_is_leaf=True) assert out is x assert x == ((1, 2, None), (3, 4, 5)) - assert leaves == [(1, [3]), (2, None), (None, 4), (3, {'foo': 'bar'}), (4, 7), (5, [5, 6])] + assert leaves == [(2, None), (None, 4), (4, 7), (5, [5, 6])] def test_tree_map_inplace(): @@ -293,6 +351,38 @@ def fn_(x, y): assert x == ((Counter(4), Counter(2), None), (Counter(8), Counter(11), Counter(11))) +def test_tree_map_with_path_inplace(): + class Counter: + def __init__(self, start): + self.count = start + + def increment(self, n=1): + self.count += n + return self.count + + def __int__(self): + return self.count + + def __eq__(self, other): + return isinstance(other, Counter) and self.count == other.count + + def __repr__(self): + return f'Counter({self.count})' + + def __next__(self): + return self.increment() + + x = ((Counter(1), Counter(2), None), (Counter(3), Counter(4), Counter(5))) + y = ((3, 0, 4), (5, 7, 6)) + + def fn_(p, x, y): + x.increment(y * (1 + sum(p))) + + out = optree.tree_map_with_path_(fn_, x, y) + assert out is x + assert x == ((Counter(4), Counter(2), None), (Counter(13), Counter(25), Counter(29))) + + def test_tree_reduce(): assert optree.tree_reduce(lambda x, y: x + y, {'x': 1, 'y': (2, 3)}) == 6 assert optree.tree_reduce(lambda x, y: x + y, {'x': 1, 'y': (2, None), 'z': 3}) == 6