feat(ops): add tree flatten and tree map functions with extra paths #11

Merged: 3 commits, Nov 15, 2022
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
@@ -85,7 +85,7 @@ jobs:
strategy:
matrix:
os: [ubuntu-latest, windows-latest, macos-latest]
python-version: ["3.6", "3.8", "3.9", "3.10", "3.11"]
python-version: ["3.8", "3.9", "3.10", "3.11"]
fail-fast: false
timeout-minutes: 30
steps:
2 changes: 1 addition & 1 deletion .pylintrc
@@ -53,7 +53,7 @@ persistent=yes

# Minimum Python version to use for version dependent checks. Will default to
# the version used to run pylint.
py-version=3.6
py-version=3.7

# Discover python modules and packages in the file system subtree.
recursive=no
35 changes: 24 additions & 11 deletions README.md
@@ -1,6 +1,6 @@
# OpTree

![Python 3.6+](https://img.shields.io/badge/Python-3.6%2B-brightgreen)
![Python 3.7+](https://img.shields.io/badge/Python-3.7%2B-brightgreen)
[![PyPI](https://img.shields.io/pypi/v/optree?logo=pypi)](https://pypi.org/project/optree)
![GitHub Workflow Status](https://img.shields.io/github/workflow/status/metaopt/optree/Build?label=build&logo=github)
![GitHub Workflow Status](https://img.shields.io/github/workflow/status/metaopt/optree/Tests?label=tests&logo=github)
@@ -61,7 +61,7 @@ cd optree
pip3 install .
```

Compiling from the source requires Python 3.6+, a compiler (`gcc` / `clang` / `icc` / `cl.exe`) supports C++20 and a `cmake` installation.
Compiling from the source requires Python 3.7+, a compiler (`gcc` / `clang` / `icc` / `cl.exe`) supports C++20 and a `cmake` installation.

--------------------------------------------------------------------------------

@@ -122,17 +122,19 @@ The `NoneType` is a special case discussed in section [`None` is non-leaf Node v

A container-like Python type can be registered in the container registry with a pair of functions that specify:

- `flatten_func(container) -> (children, metadata)`: convert an instance of the container type to a `(children, metadata)` pair, where `children` is an iterable of subtrees.
- `flatten_func(container) -> (children, metadata, entries)`: convert an instance of the container type to a `(children, metadata, entries)` triple, where `children` is an iterable of subtrees and `entries` is an iterable of path entries of the container (e.g., indices or keys).
- `unflatten_func(metadata, children) -> container`: convert such a pair back to an instance of the container type.

The `metadata` is some necessary data apart from the children to reconstruct the container, e.g., the keys of the dictionary (the children are values).

The `entries` can be omitted (only returns a pair) or is optional to implement (returns `None`). If so, use `range(len(children))` (i.e., flat indices) as path entries of the current node. The function signature can be `flatten_func(container) -> (children, metadata)` or `flatten_func(container) -> (children, metadata, None)`.

```python
>>> import torch

>>> optree.register_pytree_node(
... torch.Tensor,
... flatten_func=lambda tensor: (
... flatten_func=lambda tensor: ( # -> (children, metadata)
... (tensor.cpu().numpy(),),
... dict(dtype=tensor.dtype, device=tensor.device, requires_grad=tensor.requires_grad),
... ),
@@ -154,6 +156,9 @@ The `metadata` is some necessary data apart from the children to reconstruct the
})
)

>>> optree.tree_paths(tree) # entries are not defined and use `range(len(children))`
[('bias', 0), ('weight', 0)]

>>> optree.tree_unflatten(treespec, leaves)
{'bias': tensor([0., 0.]), 'weight': tensor([[1., 1.]], device='cuda:0')}
```
@@ -165,16 +170,24 @@ Users can also extend the pytree registry by decorating the custom class and def
...
... @optree.register_pytree_node_class
... class MyDict(UserDict):
... def tree_flatten(self):
... def tree_flatten(self): # -> (children, metadata, entries)
... reversed_keys = sorted(self.keys(), reverse=True)
... return [self[key] for key in reversed_keys], reversed_keys
... return (
... [self[key] for key in reversed_keys], # children
... reversed_keys, # metadata
... reversed_keys, # entries
... )
...
... @classmethod
... def tree_unflatten(cls, metadata, children):
... return MyDict(zip(metadata, children))

>>> optree.tree_flatten(MyDict(b=2, a=1, c=3))
([3, 2, 1], PyTreeSpec(CustomTreeNode(MyDict[['c', 'b', 'a']], [*, *, *])))
>>> optree.tree_flatten_with_path(MyDict(b=4, a=(2, 3), c=MyDict({'d': 5, 'f': 6})))
(
[('c', 'f'), ('c', 'd'), ('b',), ('a', 0), ('a', 1)],
[6, 5, 4, 2, 3],
PyTreeSpec(CustomTreeNode(MyDict[['c', 'b', 'a']], [CustomTreeNode(MyDict[['f', 'd']], [*, *]), *, (*, *)]))
)
```
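
The role of the optional `entries` value can be illustrated with a small pure-Python sketch. This is not OpTree's implementation (which lives in C++); `FLATTEN_FUNCS`, `normalize`, and `flatten_with_path` are hypothetical names introduced only for illustration. It assumes, as the README text in this diff describes, that a flatten function may return either a pair or a triple, and that missing entries fall back to `range(len(children))`.

```python
from typing import Any, Callable, Dict, List, Tuple

# Hypothetical registry (illustration only): type -> flatten function.
# Each function returns (children, metadata) or (children, metadata, entries).
FLATTEN_FUNCS: Dict[type, Callable[[Any], tuple]] = {
    # Lists return a pair, so flat indices are used as path entries.
    list: lambda xs: (xs, None),
    # Tuples return a triple with `None` entries, which means the same thing.
    tuple: lambda xs: (xs, None, None),
    # Dicts report their sorted keys as both metadata and path entries.
    dict: lambda d: ([d[k] for k in sorted(d)], sorted(d), sorted(d)),
}


def normalize(result: tuple) -> Tuple[list, Any, list]:
    """Fill in default path entries when the flatten function omits them."""
    if len(result) == 2:
        children, metadata = result
        entries = None
    else:
        children, metadata, entries = result
    children = list(children)
    if entries is None:
        entries = range(len(children))  # fall back to flat indices
    entries = list(entries)
    if len(entries) != len(children):
        raise ValueError('expected exactly one path entry per child')
    return children, metadata, entries


def flatten_with_path(tree: Any, prefix: tuple = ()) -> Tuple[List[tuple], List[Any]]:
    """Return (paths, leaves) in depth-first order."""
    flatten_func = FLATTEN_FUNCS.get(type(tree))
    if flatten_func is None:  # unregistered type: treat the object as a leaf
        return [prefix], [tree]
    children, _, entries = normalize(flatten_func(tree))
    paths: List[tuple] = []
    leaves: List[Any] = []
    for entry, child in zip(entries, children):
        child_paths, child_leaves = flatten_with_path(child, prefix + (entry,))
        paths.extend(child_paths)
        leaves.extend(child_leaves)
    return paths, leaves


paths, leaves = flatten_with_path({'b': 4, 'a': (2, 3), 'c': [5, {'d': 6}]})
print(paths)
print(leaves)
```

Each path is the tuple of entries collected from the root down to a leaf, so the path of a leaf nested in a dictionary inside a list mixes integer indices and keys. The sketch deliberately ignores `None` handling and leaf predicates.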

#### Limitations of the PyTree Type Registry
@@ -195,21 +208,21 @@ It may also serve as a sentinel value.
In addition, if a function has returned without any return value, it also implicitly returns the `None` object.

By default, the `None` object is considered a non-leaf node in the tree with arity 0, i.e., _**a non-leaf node that has no children**_.
This is slightly different than the definition of a non-leaf node as discussed above.
This is like the behavior of an empty tuple.
While flattening a tree, it will remain in the tree structure definitions rather than in the leaves list.

```python
>>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
>>> optree.tree_flatten(tree)
([1, 2, 3, 4, 5], PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': None, 'd': *}))
>>> optree.tree_flatten(tree, none_is_leaf=True)
([1, 2, 3, 4, None, 5], PyTreeSpec(NoneIsLeaf, {'a': *, 'b': (*, [*, *]), 'c': *, 'd': *}))
([1, 2, 3, 4, None, 5], PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': *, 'd': *}, NoneIsLeaf))
>>> optree.tree_flatten(1)
([1], PyTreeSpec(*))
>>> optree.tree_flatten(None)
([], PyTreeSpec(None))
>>> optree.tree_flatten(None, none_is_leaf=True)
([None], PyTreeSpec(NoneIsLeaf, *))
([None], PyTreeSpec(*, NoneIsLeaf))
```
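
The "non-leaf node with arity 0" behavior can be mimicked in a few lines of plain Python. This is only a sketch under simplifying assumptions: `flatten_leaves` is a hypothetical helper, it handles only `dict`, `list`, and `tuple`, and it returns leaves without a tree specification.

```python
from typing import Any, List


def flatten_leaves(tree: Any, none_is_leaf: bool = False) -> List[Any]:
    """Collect leaves depth-first; dictionary keys are visited in sorted order."""
    if tree is None:
        # As a non-leaf node with no children, `None` contributes no leaves,
        # just like an empty tuple. As a leaf, it is returned like any object.
        return [None] if none_is_leaf else []
    if isinstance(tree, dict):
        children = [tree[key] for key in sorted(tree)]
    elif isinstance(tree, (list, tuple)):
        children = list(tree)
    else:
        return [tree]  # any other object is an opaque leaf
    leaves: List[Any] = []
    for child in children:
        leaves.extend(flatten_leaves(child, none_is_leaf))
    return leaves


tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
print(flatten_leaves(tree))
print(flatten_leaves(tree, none_is_leaf=True))
```

With the default setting, `None` and `()` are indistinguishable in the leaves list because both are empty subtrees; the difference is recorded only in the tree structure.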

OpTree provides a keyword argument `none_is_leaf` to determine whether to consider the `None` object as a leaf, like other opaque objects.
2 changes: 1 addition & 1 deletion include/registry.h
@@ -52,7 +52,7 @@ class PyTreeTypeRegistry {
// The following values are populated for custom types.
// The Python type object, used to identify the type.
py::object type;
// A function with signature: object -> (iterable, metadata)
// A function with signature: object -> (iterable, metadata, entries)
py::function to_iterable;
// A function with signature: (metadata, iterable) -> object
py::function from_iterable;
36 changes: 36 additions & 0 deletions include/treespec.h
@@ -28,6 +28,7 @@ limitations under the License.
#include <optional>
#include <stdexcept>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

@@ -64,6 +65,26 @@ class PyTreeSpec {
const std::optional<py::function> &leaf_predicate = std::nullopt,
const bool &none_is_leaf = false);

// Flattens a PyTree into a list of leaves with a list of paths and a PyTreeSpec.
// Returns references to the flattened objects, which might be temporary objects in the case of
// custom PyType handlers.
static std::tuple<std::vector<py::object>, std::vector<py::object>, std::unique_ptr<PyTreeSpec>>
FlattenWithPath(const py::handle &tree,
const std::optional<py::function> &leaf_predicate = std::nullopt,
const bool &none_is_leaf = false);

// Recursive helper used to implement FlattenWithPath().
void FlattenIntoWithPath(const py::handle &handle,
std::vector<py::object> &leaves, // NOLINT
std::vector<py::object> &paths, // NOLINT
const std::optional<py::function> &leaf_predicate = std::nullopt,
const bool &none_is_leaf = false);
void FlattenIntoWithPath(const py::handle &handle,
absl::InlinedVector<py::object, 2> &leaves, // NOLINT
absl::InlinedVector<py::object, 2> &paths, // NOLINT
const std::optional<py::function> &leaf_predicate = std::nullopt,
const bool &none_is_leaf = false);

// Flattens a PyTree up to this PyTreeSpec. 'this' must be a tree prefix of the tree-structure
// of 'x'. For example, if we flatten a value [(1, (2, 3)), {"foo": 4}] with a PyTreeSpec [(*,
// *), *], the result is the list of leaves [1, (2, 3), {"foo": 4}].
@@ -146,6 +167,13 @@ class PyTreeSpec {
// For a Custom type, contains the auxiliary data returned by the `to_iterable` function.
py::object node_data;

// The tuple of path entries.
// This is optional, if not specified, `range(arity)` is used.
// For a sequence, contains the index of the element.
// For a mapping, contains the key of the element.
// For a Custom type, contains the path entries returned by the `to_iterable` function.
py::object node_entries;

// Custom type registration. Must be null for non-custom types.
const PyTreeTypeRegistry::Registration *custom = nullptr;

@@ -176,6 +204,14 @@ class PyTreeSpec {
Span &leaves, // NOLINT
const std::optional<py::function> &leaf_predicate);

template <bool NoneIsLeaf, typename Span, typename Stack>
void FlattenIntoWithPathImpl(const py::handle &handle,
Span &leaves, // NOLINT
Span &paths, // NOLINT
Stack &stack, // NOLINT
const ssize_t &depth,
const std::optional<py::function> &leaf_predicate);

py::list FlattenUpToImpl(const py::handle &full_tree) const;

template <bool NoneIsLeaf>
1 change: 1 addition & 0 deletions include/utils.h
@@ -28,6 +28,7 @@ limitations under the License.
#include <optional>
#include <stdexcept>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

11 changes: 8 additions & 3 deletions optree/_C.pyi
@@ -16,10 +16,10 @@
# pylint: disable=all
# isort: off

from typing import TYPE_CHECKING, Callable, Iterable, List, Optional, Sequence, Tuple, Type
from typing import TYPE_CHECKING, Any, Callable, Iterable, List, Optional, Sequence, Tuple, Type

if TYPE_CHECKING:
from optree.typing import MetaData, Children, CustomTreeNode, PyTree, T, U
from optree.typing import Children, CustomTreeNode, MetaData, PyTree, T, U

version: int

@@ -28,6 +28,11 @@ def flatten(
leaf_predicate: Optional[Callable[[T], bool]] = None,
node_is_leaf: bool = False,
) -> Tuple[List[T], 'PyTreeSpec']: ...
def flatten_with_path(
tree: PyTree[T],
leaf_predicate: Optional[Callable[[T], bool]] = None,
node_is_leaf: bool = False,
) -> Tuple[List[Tuple[Any, ...]], List[T], 'PyTreeSpec']: ...
def all_leaves(iterable: Iterable[T], node_is_leaf: bool = False) -> bool: ...
def leaf(node_is_leaf: bool = False) -> 'PyTreeSpec': ...
def none(node_is_leaf: bool = False) -> 'PyTreeSpec': ...
@@ -53,7 +58,7 @@ class PyTreeSpec:
def __hash__(self) -> int: ...

def register_node(
type: Type[CustomTreeNode[T]],
cls: Type[CustomTreeNode[T]],
to_iterable: Callable[[CustomTreeNode[T]], Tuple[Children[T], MetaData]],
from_iterable: Callable[[MetaData, Children[T]], CustomTreeNode[T]],
) -> None: ...
8 changes: 8 additions & 0 deletions optree/__init__.py
@@ -22,9 +22,13 @@
tree_all,
tree_any,
tree_flatten,
tree_flatten_with_path,
tree_leaves,
tree_map,
tree_map_,
tree_map_with_path,
tree_map_with_path_,
tree_paths,
tree_reduce,
tree_replace_nones,
tree_structure,
@@ -60,12 +64,16 @@
'NONE_IS_NODE',
'NONE_IS_LEAF',
'tree_flatten',
'tree_flatten_with_path',
'tree_unflatten',
'tree_leaves',
'tree_structure',
'tree_paths',
'all_leaves',
'tree_map',
'tree_map_',
'tree_map_with_path',
'tree_map_with_path_',
'tree_reduce',
'tree_transpose',
'tree_replace_nones',