Skip to content

Commit

Permalink
refactor: use static raw pointers for global imports (metaopt#16)
Browse files Browse the repository at this point in the history
  • Loading branch information
XuehaiPan committed Nov 25, 2022
1 parent d33cc88 commit a5d93c3
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 81 deletions.
120 changes: 66 additions & 54 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
<!-- markdownlint-disable html -->

# OpTree

![Python 3.7+](https://img.shields.io/badge/Python-3.7%2B-brightgreen)
Expand All @@ -19,9 +21,9 @@ Optimized PyTree Utilities.
- [PyTrees](#pytrees)
- [Tree Nodes and Leaves](#tree-nodes-and-leaves)
- [Built-in PyTree Node Types](#built-in-pytree-node-types)
- [Registering a Custom Container-like Type as Non-leaf Nodes](#registering-a-custom-container-like-type-as-non-leaf-nodes)
- [Registering a Container-like Custom Type as Non-leaf Nodes](#registering-a-container-like-custom-type-as-non-leaf-nodes)
- [Notes about the PyTree Type Registry](#notes-about-the-pytree-type-registry)
- [`None` is non-leaf Node vs. `None` is Leaf](#none-is-non-leaf-node-vs-none-is-leaf)
- [`None` is Non-leaf Node vs. `None` is Leaf](#none-is-non-leaf-node-vs-none-is-leaf)
- [Key Ordering for Dictionaries](#key-ordering-for-dictionaries)
- [Benchmark](#benchmark)
- [Tree Flatten](#tree-flatten)
Expand Down Expand Up @@ -115,12 +117,12 @@ OpTree out-of-box supports the following Python container types in the registry:
which are considered non-leaf nodes in the tree.
Python objects that the type is not registered will be treated as leaf nodes.
The registration lookup uses the `is` operator to determine whether the type is matched.
So subclasses will need to explicitly register in the registration, otherwise, an object of that type will be considered as a leaf.
The `NoneType` is a special case discussed in section [`None` is non-leaf Node vs. `None` is Leaf](#none-is-non-leaf-node-vs-none-is-leaf).
So subclasses will need to explicitly register in the registry, otherwise, an object of that type will be considered as a leaf.
The [`NoneType`](https://docs.python.org/3/library/constants.html#None) is a special case discussed in section [`None` is non-leaf Node vs. `None` is Leaf](#none-is-non-leaf-node-vs-none-is-leaf).

#### Registering a Custom Container-like Type as Non-leaf Nodes
#### Registering a Container-like Custom Type as Non-leaf Nodes

A container-like Python type can be registered in the container registry with a pair of functions that specify:
A container-like Python type can be registered in the type registry with a pair of functions that specify:

- `flatten_func(container) -> (children, metadata, entries)`: convert an instance of the container type to a `(children, metadata, entries)` triple, where `children` is an iterable of subtrees and `entries` is an iterable of path entries of the container (e.g., indices or keys).
- `unflatten_func(metadata, children) -> container`: convert such a pair back to an instance of the container type.
Expand All @@ -131,30 +133,32 @@ The `entries` can be omitted (only returns a pair) or is optional to implement (

```python
# Registry a Python type with lambda functions
>>> register_pytree_node(
... set,
... # (set) -> (children, metadata, None)
... lambda s: (sorted(s), None, None),
... # (metadata, children) -> (set)
... lambda _, children: set(children),
... namespace='set',
... )
register_pytree_node(
set,
# (set) -> (children, metadata, None)
lambda s: (sorted(s), None, None),
# (metadata, children) -> (set)
lambda _, children: set(children),
namespace='set',
)

# Register a Python type into a namespace
>>> import torch
>>> register_pytree_node(
... torch.Tensor,
... # (tensor) -> (children, metadata)
... flatten_func=lambda tensor: (
... (tensor.cpu().numpy(),),
... dict(dtype=tensor.dtype, device=tensor.device, requires_grad=tensor.requires_grad),
... ),
... # (metadata, children) -> tensor
... unflatten_func=lambda metadata, children: torch.tensor(children[0], **metadata),
... namespace='torch2numpy',
... )
<class 'torch.Tensor'>
import torch

register_pytree_node(
torch.Tensor,
# (tensor) -> (children, metadata)
flatten_func=lambda tensor: (
(tensor.cpu().numpy(),),
dict(dtype=tensor.dtype, device=tensor.device, requires_grad=tensor.requires_grad),
),
# (metadata, children) -> tensor
unflatten_func=lambda metadata, children: torch.tensor(children[0], **metadata),
namespace='torch2numpy',
)
```

```python
>>> tree = {'weight': torch.ones(size=(1, 2)).cuda(), 'bias': torch.zeros(size=(2,))}
>>> tree
{'weight': tensor([[1., 1.]], device='cuda:0'), 'bias': tensor([0., 0.])}
Expand Down Expand Up @@ -189,22 +193,24 @@ The `entries` can be omitted (only returns a pair) or is optional to implement (
Users can also extend the pytree registry by decorating the custom class and defining an instance method `tree_flatten` and a class method `tree_unflatten`.

```python
>>> from collections import UserDict
...
... @optree.register_pytree_node_class(namespace='mydict')
... class MyDict(UserDict):
... def tree_flatten(self): # -> (children, metadata, entries)
... reversed_keys = sorted(self.keys(), reverse=True)
... return (
... [self[key] for key in reversed_keys], # children
... reversed_keys, # metadata
... reversed_keys, # entries
... )
...
... @classmethod
... def tree_unflatten(cls, metadata, children):
... return cls(zip(metadata, children))
from collections import UserDict

@optree.register_pytree_node_class(namespace='mydict')
class MyDict(UserDict):
def tree_flatten(self): # -> (children, metadata, entries)
reversed_keys = sorted(self.keys(), reverse=True)
return (
[self[key] for key in reversed_keys], # children
reversed_keys, # metadata
reversed_keys, # entries
)

@classmethod
def tree_unflatten(cls, metadata, children):
return cls(zip(metadata, children))
```

```python
>>> tree = MyDict(b=4, a=(2, 3), c=MyDict({'d': 5, 'f': 6}))

# Flatten without specifying the namespace
Expand Down Expand Up @@ -240,9 +246,9 @@ There are several key attributes of the pytree type registry:
prevent accidental collisions between different libraries that may register the same type.
```

2. **The elements in the type registry are immutable.** Users either cannot register the same type twice in the same namespace (i.e., update the type registry). Nor cannot remove a type from the type registry. To update the behavior of an already registered type, simply register it again with another `namespace`.
2. **The elements in the type registry are immutable.** Users can neither register the same type twice in the same namespace (i.e., update the type registry), nor remove a type from the type registry. To update the behavior of an already registered type, simply register it again with another `namespace`.

3. **Users cannot modify the behavior of already registered built-in types** listed [Built-in PyTree Node Types](#built-in-pytree-node-types), such as key order sorting for `dict` and `collections.defaultdict`.
3. **Users cannot modify the behavior of already registered built-in types** listed in [Built-in PyTree Node Types](#built-in-pytree-node-types), such as key order sorting for `dict` and `collections.defaultdict`.

4. **Inherited subclasses are not implicitly registered.** The registration lookup uses `type(obj) is registered_type` rather than `isinstance(obj, registered_type)`. Users need to register the subclasses explicitly. To register all subclasses, it is easy to implement with [`metaclass`](https://docs.python.org/3/reference/datamodel.html#metaclasses) or [`__init_subclass__`](https://docs.python.org/3/reference/datamodel.html#customizing-class-creation), for example:

Expand Down Expand Up @@ -271,23 +277,28 @@ There are several key attributes of the pytree type registry:
# Subclasses will be automatically registered in namespace 'mydict'
class MyAnotherDict(MyDict):
pass
```

tree = MyDict(b=4, a=(2, 3), c=MyAnotherDict({'d': 5, 'f': 6}))
optree.tree_flatten_with_path(tree, namespace='mydict')
# (
# [('c', 'f'), ('c', 'd'), ('b',), ('a', 0), ('a', 1)],
# [6, 5, 4, 2, 3],
# PyTreeSpec(CustomTreeNode(MyDict[['c', 'b', 'a']], [CustomTreeNode(MyAnotherDict[['f', 'd']], [*, *]), *, (*, *)]), namespace='mydict')
# )
```python
>>> tree = MyDict(b=4, a=(2, 3), c=MyAnotherDict({'d': 5, 'f': 6}))
>>> optree.tree_flatten_with_path(tree, namespace='mydict')
(
[('c', 'f'), ('c', 'd'), ('b',), ('a', 0), ('a', 1)],
[6, 5, 4, 2, 3],
PyTreeSpec(
CustomTreeNode(MyDict[['c', 'b', 'a']], [CustomTreeNode(MyAnotherDict[['f', 'd']], [*, *]), *, (*, *)]),
namespace='mydict'
)
)
```

### `None` is non-leaf Node vs. `None` is Leaf
### `None` is Non-leaf Node vs. `None` is Leaf

The [`None`](https://docs.python.org/3/library/constants.html#None) object is a special object in the Python language.
It serves some of the same purposes as `null` (a pointer does not point to anything) in other programming languages, which denotes a variable is empty or marks default parameters.
However, the `None` object is a singleton object rather than a pointer.
It may also serve as a sentinel value.
In addition, if a function has returned without any return value, it also implicitly returns the `None` object.
In addition, if a function has returned without any return value or the return statement is omitted, the function will also implicitly return the `None` object.

By default, the `None` object is considered a non-leaf node in the tree with arity 0, i.e., _**a non-leaf node that has no children**_.
This is like the behavior of an empty tuple.
Expand Down Expand Up @@ -315,7 +326,7 @@ Otherwise, the `None` object will remain in the tree specification (structure).
>>> import torch

>>> linear = torch.nn.Linear(in_features=3, out_features=2, bias=False)
>>> linear._parameters
>>> linear._parameters # a container has None
OrderedDict([
('weight', Parameter containing:
tensor([[-0.6677, 0.5209, 0.3295],
Expand All @@ -332,6 +343,7 @@ OrderedDict([

>>> optree.tree_map(torch.zeros_like, linear._parameters, none_is_leaf=True)
TypeError: zeros_like(): argument 'input' (position 1) must be Tensor, not NoneType

>>> optree.tree_map(lambda t: torch.zeros_like(t) if t is not None else 0, linear._parameters, none_is_leaf=True)
OrderedDict([
('weight', tensor([[0., 0., 0.],
Expand Down
55 changes: 28 additions & 27 deletions include/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,33 +46,34 @@ namespace py = pybind11;
using size_t = py::size_t;
using ssize_t = py::ssize_t;

#define PyCollectionsModule (*ImportCollections())
#define PyOrderedDictTypeObject (*ImportOrderedDict())
#define PyDefaultDictTypeObject (*ImportDefaultDict())
#define PyDequeTypeObject (*ImportDeque())

inline py::module_* ImportCollections() {
static auto collectionsUptr = std::make_unique<py::module_>(
py::reinterpret_borrow<py::module_>(py::module_::import("collections").release()));
return collectionsUptr.get();
}

inline py::object* ImportOrderedDict() {
static auto OrderedDictUptr = std::make_unique<py::object>(
py::reinterpret_borrow<py::object>(py::getattr(PyCollectionsModule, "OrderedDict")));
return OrderedDictUptr.get();
}

inline py::object* ImportDefaultDict() {
static auto defaultdictUptr = std::make_unique<py::object>(
py::reinterpret_borrow<py::object>(py::getattr(PyCollectionsModule, "defaultdict")));
return defaultdictUptr.get();
}

inline py::object* ImportDeque() {
static auto dequeUptr = std::make_unique<py::object>(
py::reinterpret_borrow<py::object>(py::getattr(PyCollectionsModule, "deque")));
return dequeUptr.get();
#define PyCollectionsModule (ImportCollections())
#define PyOrderedDictTypeObject (ImportOrderedDict())
#define PyDefaultDictTypeObject (ImportDefaultDict())
#define PyDequeTypeObject (ImportDeque())

inline const py::module_& ImportCollections() {
// NOTE: Use raw pointers to leak the memory intentionally to avoid py::object deallocation and
// garbage collection
static const py::module_* ptr = new py::module_{py::module_::import("collections")};
return *ptr;
}
inline const py::object& ImportOrderedDict() {
// NOTE: Use raw pointers to leak the memory intentionally to avoid py::object deallocation and
// garbage collection
static const py::object* ptr = new py::object{py::getattr(PyCollectionsModule, "OrderedDict")};
return *ptr;
}
inline const py::object& ImportDefaultDict() {
// NOTE: Use raw pointers to leak the memory intentionally to avoid py::object deallocation and
// garbage collection
static const py::object* ptr = new py::object{py::getattr(PyCollectionsModule, "defaultdict")};
return *ptr;
}
inline const py::object& ImportDeque() {
// NOTE: Use raw pointers to leak the memory intentionally to avoid py::object deallocation and
// garbage collection
static const py::object* ptr = new py::object{py::getattr(PyCollectionsModule, "deque")};
return *ptr;
}

template <typename T>
Expand Down

0 comments on commit a5d93c3

Please sign in to comment.