Commit

feat(ops): add tree flatten and tree map functions with extra paths (#11)
XuehaiPan authored Nov 15, 2022
1 parent e8298b1 commit f2b31a4
Showing 17 changed files with 794 additions and 132 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
@@ -85,7 +85,7 @@ jobs:
strategy:
matrix:
os: [ubuntu-latest, windows-latest, macos-latest]
python-version: ["3.6", "3.8", "3.9", "3.10", "3.11"]
python-version: ["3.8", "3.9", "3.10", "3.11"]
fail-fast: false
timeout-minutes: 30
steps:
2 changes: 1 addition & 1 deletion .pylintrc
@@ -53,7 +53,7 @@ persistent=yes

# Minimum Python version to use for version dependent checks. Will default to
# the version used to run pylint.
py-version=3.6
py-version=3.7

# Discover python modules and packages in the file system subtree.
recursive=no
35 changes: 24 additions & 11 deletions README.md
@@ -1,6 +1,6 @@
# OpTree

![Python 3.6+](https://img.shields.io/badge/Python-3.6%2B-brightgreen)
![Python 3.7+](https://img.shields.io/badge/Python-3.7%2B-brightgreen)
[![PyPI](https://img.shields.io/pypi/v/optree?logo=pypi)](https://pypi.org/project/optree)
![GitHub Workflow Status](https://img.shields.io/github/workflow/status/metaopt/optree/Build?label=build&logo=github)
![GitHub Workflow Status](https://img.shields.io/github/workflow/status/metaopt/optree/Tests?label=tests&logo=github)
@@ -61,7 +61,7 @@ cd optree
pip3 install .
```

Compiling from the source requires Python 3.6+, a compiler (`gcc` / `clang` / `icc` / `cl.exe`) supports C++20 and a `cmake` installation.
Compiling from the source requires Python 3.7+, a compiler (`gcc` / `clang` / `icc` / `cl.exe`) supports C++20 and a `cmake` installation.

--------------------------------------------------------------------------------

@@ -122,17 +122,19 @@ The `NoneType` is a special case discussed in section [`None` is non-leaf Node v

A container-like Python type can be registered in the container registry with a pair of functions that specify:

- `flatten_func(container) -> (children, metadata)`: convert an instance of the container type to a `(children, metadata)` pair, where `children` is an iterable of subtrees.
- `flatten_func(container) -> (children, metadata, entries)`: convert an instance of the container type to a `(children, metadata, entries)` triple, where `children` is an iterable of subtrees and `entries` is an iterable of path entries of the container (e.g., indices or keys).
- `unflatten_func(metadata, children) -> container`: convert such a pair back to an instance of the container type.

The `metadata` is some necessary data apart from the children to reconstruct the container, e.g., the keys of the dictionary (the children are values).

The `entries` can be omitted (only returns a pair) or is optional to implement (returns `None`). If so, use `range(len(children))` (i.e., flat indices) as path entries of the current node. The function signature can be `flatten_func(container) -> (children, metadata)` or `flatten_func(container) -> (children, metadata, None)`.

```python
>>> import torch

>>> optree.register_pytree_node(
... torch.Tensor,
... flatten_func=lambda tensor: (
... flatten_func=lambda tensor: ( # -> (children, metadata)
... (tensor.cpu().numpy(),),
... dict(dtype=tensor.dtype, device=tensor.device, requires_grad=tensor.requires_grad),
... ),
@@ -154,6 +156,9 @@ The `metadata` is some necessary data apart from the children to reconstruct the
})
)

>>> optree.tree_paths(tree) # entries are not defined and use `range(len(children))`
[('bias', 0), ('weight', 0)]

>>> optree.tree_unflatten(treespec, leaves)
{'bias': tensor([0., 0.]), 'weight': tensor([[1., 1.]], device='cuda:0')}
```
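The fallback rule for `entries` described above can be sketched in plain Python. This is only an illustration, not OpTree's implementation (which lives in the C++ extension); the helper name `normalize_flatten_result` is hypothetical and does not exist in the library.

```python
def normalize_flatten_result(result):
    """Hypothetical helper: turn a flatten_func result into a full triple.

    Accepts either (children, metadata) or (children, metadata, entries),
    where entries may be None. Missing entries fall back to flat indices.
    """
    if len(result) == 2:
        children, metadata = result
        entries = None
    else:
        children, metadata, entries = result
    children = list(children)
    if entries is None:
        # Same default as described in the README: range(len(children)).
        entries = range(len(children))
    return children, metadata, tuple(entries)


pair_result = normalize_flatten_result((["x", "y"], "meta"))
none_result = normalize_flatten_result((["x", "y"], "meta", None))
keyed_result = normalize_flatten_result((["x", "y"], ["k1", "k2"], ["k1", "k2"]))
```

Under these assumptions, a pair and a triple ending in `None` both produce the entries `(0, 1)`, while explicitly supplied entries are kept unchanged.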
@@ -165,16 +170,24 @@ Users can also extend the pytree registry by decorating the custom class and def
...
... @optree.register_pytree_node_class
... class MyDict(UserDict):
... def tree_flatten(self):
... def tree_flatten(self): # -> (children, metadata, entries)
... reversed_keys = sorted(self.keys(), reverse=True)
... return [self[key] for key in reversed_keys], reversed_keys
... return (
... [self[key] for key in reversed_keys], # children
... reversed_keys, # metadata
... reversed_keys, # entries
... )
...
... @classmethod
... def tree_unflatten(cls, metadata, children):
... return MyDict(zip(metadata, children))

>>> optree.tree_flatten(MyDict(b=2, a=1, c=3))
([3, 2, 1], PyTreeSpec(CustomTreeNode(MyDict[['c', 'b', 'a']], [*, *, *])))
>>> optree.tree_flatten_with_path(MyDict(b=4, a=(2, 3), c=MyDict({'d': 5, 'f': 6})))
(
[('c', 'f'), ('c', 'd'), ('b',), ('a', 0), ('a', 1)],
[6, 5, 4, 2, 3],
PyTreeSpec(CustomTreeNode(MyDict[['c', 'b', 'a']], [CustomTreeNode(MyDict[['f', 'd']], [*, *]), *, (*, *)]))
)
```
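To make the relationship between `entries` and the resulting paths concrete, here is a toy, pure-Python flattening routine. It is a sketch under stated assumptions: it handles only `MyDict`, `dict`, `tuple`, and `list`, does not build a `PyTreeSpec`, and the function name `toy_flatten_with_path` is hypothetical rather than part of OpTree.

```python
from collections import UserDict


class MyDict(UserDict):
    def tree_flatten(self):  # -> (children, metadata, entries)
        reversed_keys = sorted(self.keys(), reverse=True)
        return [self[key] for key in reversed_keys], reversed_keys, reversed_keys


def toy_flatten_with_path(tree, path=()):
    """Yield (path, leaf) pairs; each path is the tuple of entries to the leaf."""
    if isinstance(tree, MyDict):
        children, _metadata, entries = tree.tree_flatten()
    elif isinstance(tree, dict):
        entries = sorted(tree)  # builtin dicts: keys in sorted order
        children = [tree[key] for key in entries]
    elif isinstance(tree, (tuple, list)):
        children, entries = list(tree), range(len(tree))
    else:
        yield path, tree  # anything else is treated as a leaf
        return
    for entry, child in zip(entries, children):
        yield from toy_flatten_with_path(child, path + (entry,))


tree = MyDict(b=4, a=(2, 3), c=MyDict({'d': 5, 'f': 6}))
pairs = list(toy_flatten_with_path(tree))
paths = [p for p, _ in pairs]
leaves = [leaf for _, leaf in pairs]
```

With this sketch, `paths` and `leaves` come out in the same order as the first two elements of the `tree_flatten_with_path` output shown in the README example.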

#### Limitations of the PyTree Type Registry
@@ -195,21 +208,21 @@ It may also serve as a sentinel value.
In addition, if a function has returned without any return value, it also implicitly returns the `None` object.

By default, the `None` object is considered a non-leaf node in the tree with arity 0, i.e., _**a non-leaf node that has no children**_.
This is slightly different than the definition of a non-leaf node as discussed above.
This is like the behavior of an empty tuple.
While flattening a tree, it will remain in the tree structure definitions rather than in the leaves list.

```python
>>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
>>> optree.tree_flatten(tree)
([1, 2, 3, 4, 5], PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': None, 'd': *}))
>>> optree.tree_flatten(tree, none_is_leaf=True)
([1, 2, 3, 4, None, 5], PyTreeSpec(NoneIsLeaf, {'a': *, 'b': (*, [*, *]), 'c': *, 'd': *}))
([1, 2, 3, 4, None, 5], PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': *, 'd': *}, NoneIsLeaf))
>>> optree.tree_flatten(1)
([1], PyTreeSpec(*))
>>> optree.tree_flatten(None)
([], PyTreeSpec(None))
>>> optree.tree_flatten(None, none_is_leaf=True)
([None], PyTreeSpec(NoneIsLeaf, *))
([None], PyTreeSpec(*, NoneIsLeaf))
```

OpTree provides a keyword argument `none_is_leaf` to determine whether to consider the `None` object as a leaf, like other opaque objects.
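The effect of `none_is_leaf` on the leaves list can be imitated with a few lines of plain Python. This is a simplified sketch, not the library's code: it supports only `dict`, `tuple`, and `list` containers, returns leaves without a tree spec, and the name `toy_flatten` is hypothetical.

```python
def toy_flatten(tree, none_is_leaf=False):
    """Collect leaves; None is a childless non-leaf node unless none_is_leaf=True."""
    leaves = []

    def visit(node):
        if node is None and not none_is_leaf:
            return  # arity-0 node: stays in the structure, adds no leaf
        if isinstance(node, dict):
            for key in sorted(node):  # dict children visited in sorted key order
                visit(node[key])
        elif isinstance(node, (tuple, list)):
            for child in node:
                visit(child)
        else:
            leaves.append(node)  # opaque object (or None when none_is_leaf=True)

    visit(tree)
    return leaves


tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
default_leaves = toy_flatten(tree)
leaf_none_leaves = toy_flatten(tree, none_is_leaf=True)
```

The two results mirror the leaves lists in the README example: `None` is absent by default and appears between `4` and `5` when `none_is_leaf=True`.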
2 changes: 1 addition & 1 deletion include/registry.h
@@ -52,7 +52,7 @@ class PyTreeTypeRegistry {
// The following values are populated for custom types.
// The Python type object, used to identify the type.
py::object type;
// A function with signature: object -> (iterable, metadata)
// A function with signature: object -> (iterable, metadata, entries)
py::function to_iterable;
// A function with signature: (metadata, iterable) -> object
py::function from_iterable;
36 changes: 36 additions & 0 deletions include/treespec.h
@@ -28,6 +28,7 @@ limitations under the License.
#include <optional>
#include <stdexcept>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

@@ -64,6 +65,26 @@ class PyTreeSpec {
const std::optional<py::function> &leaf_predicate = std::nullopt,
const bool &none_is_leaf = false);

// Flattens a PyTree into a list of leaves with a list of paths and a PyTreeSpec.
// Returns references to the flattened objects, which might be temporary objects in the case of
// custom PyType handlers.
static std::tuple<std::vector<py::object>, std::vector<py::object>, std::unique_ptr<PyTreeSpec>>
FlattenWithPath(const py::handle &tree,
const std::optional<py::function> &leaf_predicate = std::nullopt,
const bool &none_is_leaf = false);

// Recursive helper used to implement FlattenWithPath().
void FlattenIntoWithPath(const py::handle &handle,
std::vector<py::object> &leaves, // NOLINT
std::vector<py::object> &paths, // NOLINT
const std::optional<py::function> &leaf_predicate = std::nullopt,
const bool &none_is_leaf = false);
void FlattenIntoWithPath(const py::handle &handle,
absl::InlinedVector<py::object, 2> &leaves, // NOLINT
absl::InlinedVector<py::object, 2> &paths, // NOLINT
const std::optional<py::function> &leaf_predicate = std::nullopt,
const bool &none_is_leaf = false);

// Flattens a PyTree up to this PyTreeSpec. 'this' must be a tree prefix of the tree-structure
// of 'x'. For example, if we flatten a value [(1, (2, 3)), {"foo": 4}] with a PyTreeSpec [(*,
// *), *], the result is the list of leaves [1, (2, 3), {"foo": 4}].
@@ -146,6 +167,13 @@ class PyTreeSpec {
// For a Custom type, contains the auxiliary data returned by the `to_iterable` function.
py::object node_data;

// The tuple of path entries.
// This is optional, if not specified, `range(arity)` is used.
// For a sequence, contains the index of the element.
// For a mapping, contains the key of the element.
// For a Custom type, contains the path entries returned by the `to_iterable` function.
py::object node_entries;

// Custom type registration. Must be null for non-custom types.
const PyTreeTypeRegistry::Registration *custom = nullptr;

@@ -176,6 +204,14 @@ class PyTreeSpec {
Span &leaves, // NOLINT
const std::optional<py::function> &leaf_predicate);

template <bool NoneIsLeaf, typename Span, typename Stack>
void FlattenIntoWithPathImpl(const py::handle &handle,
Span &leaves, // NOLINT
Span &paths, // NOLINT
Stack &stack, // NOLINT
const ssize_t &depth,
const std::optional<py::function> &leaf_predicate);

py::list FlattenUpToImpl(const py::handle &full_tree) const;

template <bool NoneIsLeaf>
1 change: 1 addition & 0 deletions include/utils.h
@@ -28,6 +28,7 @@ limitations under the License.
#include <optional>
#include <stdexcept>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

11 changes: 8 additions & 3 deletions optree/_C.pyi
@@ -16,10 +16,10 @@
# pylint: disable=all
# isort: off

from typing import TYPE_CHECKING, Callable, Iterable, List, Optional, Sequence, Tuple, Type
from typing import TYPE_CHECKING, Any, Callable, Iterable, List, Optional, Sequence, Tuple, Type

if TYPE_CHECKING:
from optree.typing import MetaData, Children, CustomTreeNode, PyTree, T, U
from optree.typing import Children, CustomTreeNode, MetaData, PyTree, T, U

version: int

@@ -28,6 +28,11 @@ def flatten(
leaf_predicate: Optional[Callable[[T], bool]] = None,
node_is_leaf: bool = False,
) -> Tuple[List[T], 'PyTreeSpec']: ...
def flatten_with_path(
tree: PyTree[T],
leaf_predicate: Optional[Callable[[T], bool]] = None,
node_is_leaf: bool = False,
) -> Tuple[List[Tuple[Any, ...]], List[T], 'PyTreeSpec']: ...
def all_leaves(iterable: Iterable[T], node_is_leaf: bool = False) -> bool: ...
def leaf(node_is_leaf: bool = False) -> 'PyTreeSpec': ...
def none(node_is_leaf: bool = False) -> 'PyTreeSpec': ...
@@ -53,7 +58,7 @@ class PyTreeSpec:
def __hash__(self) -> int: ...

def register_node(
type: Type[CustomTreeNode[T]],
cls: Type[CustomTreeNode[T]],
to_iterable: Callable[[CustomTreeNode[T]], Tuple[Children[T], MetaData]],
from_iterable: Callable[[MetaData, Children[T]], CustomTreeNode[T]],
) -> None: ...
8 changes: 8 additions & 0 deletions optree/__init__.py
@@ -22,9 +22,13 @@
tree_all,
tree_any,
tree_flatten,
tree_flatten_with_path,
tree_leaves,
tree_map,
tree_map_,
tree_map_with_path,
tree_map_with_path_,
tree_paths,
tree_reduce,
tree_replace_nones,
tree_structure,
@@ -60,12 +64,16 @@
'NONE_IS_NODE',
'NONE_IS_LEAF',
'tree_flatten',
'tree_flatten_with_path',
'tree_unflatten',
'tree_leaves',
'tree_structure',
'tree_paths',
'all_leaves',
'tree_map',
'tree_map_',
'tree_map_with_path',
'tree_map_with_path_',
'tree_reduce',
'tree_transpose',
'tree_replace_nones',