Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(treespec): update string representation for OrderedDict #133

Merged
merged 3 commits into from
Apr 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@ jobs:
submodules: "recursive"
fetch-depth: 1

- name: Set up Python 3.10
- name: Set up Python 3.12
uses: actions/setup-python@v5
with:
python-version: "3.10"
python-version: "3.12"
update-environment: true

- name: Upgrade pip
Expand Down
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed

-
- Update string representation for `OrderedDict` by [@XuehaiPan](https://github.com/XuehaiPan) in [#133](https://github.com/metaopt/optree/pull/133).

### Fixed

Expand Down
38 changes: 19 additions & 19 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -433,39 +433,39 @@ While flattening a tree, it will remain in the tree structure definitions rather
```

OpTree provides a keyword argument `none_is_leaf` to determine whether to consider the `None` object as a leaf, like other opaque objects.
If `none_is_leaf=True`, the `None` object will place in the leaves list.
If `none_is_leaf=True`, the `None` object will be placed in the leaves list.
Otherwise, the `None` object will remain in the tree specification (structure).

```python
>>> import torch

>>> linear = torch.nn.Linear(in_features=3, out_features=2, bias=False)
>>> linear._parameters # a container has None
OrderedDict([
('weight', Parameter containing:
tensor([[-0.6677, 0.5209, 0.3295],
[-0.4876, -0.3142, 0.1785]], requires_grad=True)),
('bias', None)
])
OrderedDict({
'weight': Parameter containing:
tensor([[-0.6677, 0.5209, 0.3295],
[-0.4876, -0.3142, 0.1785]], requires_grad=True),
'bias': None
})

>>> optree.tree_map(torch.zeros_like, linear._parameters)
OrderedDict([
('weight', tensor([[0., 0., 0.],
[0., 0., 0.]])),
('bias', None)
])
OrderedDict({
'weight': tensor([[0., 0., 0.],
[0., 0., 0.]]),
'bias': None
})

>>> optree.tree_map(torch.zeros_like, linear._parameters, none_is_leaf=True)
Traceback (most recent call last):
...
TypeError: zeros_like(): argument 'input' (position 1) must be Tensor, not NoneType

>>> optree.tree_map(lambda t: torch.zeros_like(t) if t is not None else 0, linear._parameters, none_is_leaf=True)
OrderedDict([
('weight', tensor([[0., 0., 0.],
[0., 0., 0.]])),
('bias', 0)
])
OrderedDict({
'weight': tensor([[0., 0., 0.],
[0., 0., 0.]]),
'bias': 0
})
```

### Key Ordering for Dictionaries
Expand All @@ -489,9 +489,9 @@ If users want to keep the values in the insertion order in pytree traversal, the
>>> OrderedDict([('a', [1, 2]), ('b', [3])]) == OrderedDict([('b', [3]), ('a', [1, 2])])
False
>>> optree.tree_flatten(OrderedDict([('a', [1, 2]), ('b', [3])]))
([1, 2, 3], PyTreeSpec(OrderedDict([('a', [*, *]), ('b', [*])])))
([1, 2, 3], PyTreeSpec(OrderedDict({'a': [*, *], 'b': [*]})))
>>> optree.tree_flatten(OrderedDict([('b', [3]), ('a', [1, 2])]))
([3, 1, 2], PyTreeSpec(OrderedDict([('b', [*]), ('a', [*, *])])))
([3, 1, 2], PyTreeSpec(OrderedDict({'b': [*], 'a': [*, *]})))
```

**Since OpTree v0.9.0, the key order of the reconstructed output dictionaries from `tree_unflatten` is guaranteed to be consistent with the key order of the input dictionaries in `tree_flatten`.**
Expand Down
36 changes: 24 additions & 12 deletions optree/integration/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,21 +127,33 @@ def tree_ravel(
... 'bias': jnp.arange(10, 11, dtype=jnp.float32).reshape((1,))
... },
... }
>>> tree
{'layer1': {'weight': Array([[0., 1., 2.],
[3., 4., 5.]], dtype=float32),
'bias': Array([6., 7.], dtype=float32)},
'layer2': {'weight': Array([[8., 9.]], dtype=float32),
'bias': Array([10.], dtype=float32)}}
>>> tree # doctest: +IGNORE_WHITESPACE
{
'layer1': {
'weight': Array([[0., 1., 2.],
[3., 4., 5.]], dtype=float32),
'bias': Array([6., 7.], dtype=float32)
},
'layer2': {
'weight': Array([[8., 9.]], dtype=float32),
'bias': Array([10.], dtype=float32)
}
}
>>> flat, unravel_func = tree_ravel(tree)
>>> flat
Array([ 6., 7., 0., 1., 2., 3., 4., 5., 10., 8., 9.], dtype=float32)
>>> unravel_func(flat)
{'layer1': {'weight': Array([[0., 1., 2.],
[3., 4., 5.]], dtype=float32),
'bias': Array([6., 7.], dtype=float32)},
'layer2': {'weight': Array([[8., 9.]], dtype=float32),
'bias': Array([10.], dtype=float32)}}
>>> unravel_func(flat) # doctest: +IGNORE_WHITESPACE
{
'layer1': {
'weight': Array([[0., 1., 2.],
[3., 4., 5.]], dtype=float32),
'bias': Array([6., 7.], dtype=float32)
},
'layer2': {
'weight': Array([[8., 9.]], dtype=float32),
'bias': Array([10.], dtype=float32)
}
}
Args:
tree (pytree): a pytree of arrays and scalars to ravel.
Expand Down
36 changes: 24 additions & 12 deletions optree/integration/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,21 +56,33 @@ def tree_ravel(
... 'bias': np.arange(10, 11, dtype=np.float32).reshape((1,))
... },
... }
>>> tree
{'layer1': {'weight': array([[0., 1., 2.],
[3., 4., 5.]], dtype=float32),
'bias': array([6., 7.], dtype=float32)},
'layer2': {'weight': array([[8., 9.]], dtype=float32),
'bias': array([10.], dtype=float32)}}
>>> tree # doctest: +IGNORE_WHITESPACE
{
'layer1': {
'weight': array([[0., 1., 2.],
[3., 4., 5.]], dtype=float32),
'bias': array([6., 7.], dtype=float32)
},
'layer2': {
'weight': array([[8., 9.]], dtype=float32),
'bias': array([10.], dtype=float32)
}
}
>>> flat, unravel_func = tree_ravel(tree)
>>> flat
array([ 6., 7., 0., 1., 2., 3., 4., 5., 10., 8., 9.], dtype=float32)
>>> unravel_func(flat)
{'layer1': {'weight': array([[0., 1., 2.],
[3., 4., 5.]], dtype=float32),
'bias': array([6., 7.], dtype=float32)},
'layer2': {'weight': array([[8., 9.]], dtype=float32),
'bias': array([10.], dtype=float32)}}
>>> unravel_func(flat) # doctest: +IGNORE_WHITESPACE
{
'layer1': {
'weight': array([[0., 1., 2.],
[3., 4., 5.]], dtype=float32),
'bias': array([6., 7.], dtype=float32)
},
'layer2': {
'weight': array([[8., 9.]], dtype=float32),
'bias': array([10.], dtype=float32)
}
}

Args:
tree (pytree): a pytree of arrays and scalars to ravel.
Expand Down
36 changes: 24 additions & 12 deletions optree/integration/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,21 +53,33 @@ def tree_ravel(
... 'bias': torch.arange(10, 11, dtype=torch.float64).reshape((1,))
... },
... }
>>> tree
{'layer1': {'weight': tensor([[0., 1., 2.],
[3., 4., 5.]], dtype=torch.float64),
'bias': tensor([6., 7.], dtype=torch.float64)},
'layer2': {'weight': tensor([[8., 9.]], dtype=torch.float64),
'bias': tensor([10.], dtype=torch.float64)}}
>>> tree # doctest: +IGNORE_WHITESPACE
{
'layer1': {
'weight': tensor([[0., 1., 2.],
[3., 4., 5.]], dtype=torch.float64),
'bias': tensor([6., 7.], dtype=torch.float64)
},
'layer2': {
'weight': tensor([[8., 9.]], dtype=torch.float64),
'bias': tensor([10.], dtype=torch.float64)
}
}
>>> flat, unravel_func = tree_ravel(tree)
>>> flat
tensor([ 6., 7., 0., 1., 2., 3., 4., 5., 10., 8., 9.], dtype=torch.float64)
>>> unravel_func(flat)
{'layer1': {'weight': tensor([[0., 1., 2.],
[3., 4., 5.]], dtype=torch.float64),
'bias': tensor([6., 7.], dtype=torch.float64)},
'layer2': {'weight': tensor([[8., 9.]], dtype=torch.float64),
'bias': tensor([10.], dtype=torch.float64)}}
>>> unravel_func(flat) # doctest: +IGNORE_WHITESPACE
{
'layer1': {
'weight': tensor([[0., 1., 2.],
[3., 4., 5.]], dtype=torch.float64),
'bias': tensor([6., 7.], dtype=torch.float64)
},
'layer2': {
'weight': tensor([[8., 9.]], dtype=torch.float64),
'bias': tensor([10.], dtype=torch.float64)
}
}
Args:
tree (pytree): a pytree of tensors to ravel.
Expand Down
32 changes: 16 additions & 16 deletions optree/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,12 +162,12 @@ def tree_flatten(
>>> tree_flatten(tree) # doctest: +IGNORE_WHITESPACE
(
[2, 3, 4, 1, 5],
PyTreeSpec(OrderedDict([('b', (*, [*, *])), ('a', *), ('c', None), ('d', *)]))
PyTreeSpec(OrderedDict({'b': (*, [*, *]), 'a': *, 'c': None, 'd': *}))
)
>>> tree_flatten(tree, none_is_leaf=True) # doctest: +IGNORE_WHITESPACE
(
[2, 3, 4, 1, None, 5],
PyTreeSpec(OrderedDict([('b', (*, [*, *])), ('a', *), ('c', *), ('d', *)]), NoneIsLeaf)
PyTreeSpec(OrderedDict({'b': (*, [*, *]), 'a': *, 'c': *, 'd': *}), NoneIsLeaf)
)

Args:
Expand Down Expand Up @@ -233,13 +233,13 @@ def tree_flatten_with_path(
(
[('b', 0), ('b', 1, 0), ('b', 1, 1), ('a',), ('d',)],
[2, 3, 4, 1, 5],
PyTreeSpec(OrderedDict([('b', (*, [*, *])), ('a', *), ('c', None), ('d', *)]))
PyTreeSpec(OrderedDict({'b': (*, [*, *]), 'a': *, 'c': None, 'd': *}))
)
>>> tree_flatten_with_path(tree, none_is_leaf=True) # doctest: +IGNORE_WHITESPACE
(
[('b', 0), ('b', 1, 0), ('b', 1, 1), ('a',), ('c',), ('d',)],
[2, 3, 4, 1, None, 5],
PyTreeSpec(OrderedDict([('b', (*, [*, *])), ('a', *), ('c', *), ('d', *)]), NoneIsLeaf)
PyTreeSpec(OrderedDict({'b': (*, [*, *]), 'a': *, 'c': *, 'd': *}), NoneIsLeaf)
)

Args:
Expand Down Expand Up @@ -951,7 +951,7 @@ def tree_transpose_map_with_path(
>>> tree_transpose_map_with_path( # doctest: +IGNORE_WHITESPACE
... lambda p, x: {'path': p, 'value': x},
... tree,
... inner_treespec=tree_structure({'path': 0, 'value': 0})),
... inner_treespec=tree_structure({'path': 0, 'value': 0}),
... )
{
'path': {'b': (('b', 0), [('b', 1, 0), ('b', 1, 1)]), 'a': ('a',), 'c': (('c', 0), ('c', 1))},
Expand Down Expand Up @@ -1694,7 +1694,7 @@ def tree_max(
>>> tree_max({})
Traceback (most recent call last):
...
ValueError: max() arg is an empty sequence
ValueError: max() iterable argument is empty
>>> tree_max({}, default=0)
0
>>> tree_max({'x': 0, 'y': (2, 1)})
Expand All @@ -1704,15 +1704,15 @@ def tree_max(
>>> tree_max({'a': None}) # `None` is a non-leaf node with arity 0 by default
Traceback (most recent call last):
...
ValueError: max() arg is an empty sequence
ValueError: max() iterable argument is empty
>>> tree_max({'a': None}, default=0) # `None` is a non-leaf node with arity 0 by default
0
>>> tree_max({'a': None}, none_is_leaf=True)
None
>>> tree_max(None) # `None` is a non-leaf node with arity 0 by default
Traceback (most recent call last):
...
ValueError: max() arg is an empty sequence
ValueError: max() iterable argument is empty
>>> tree_max(None, default=0)
0
>>> tree_max(None, none_is_leaf=True)
Expand Down Expand Up @@ -1789,7 +1789,7 @@ def tree_min(
>>> tree_min({})
Traceback (most recent call last):
...
ValueError: min() arg is an empty sequence
ValueError: min() iterable argument is empty
>>> tree_min({}, default=0)
0
>>> tree_min({'x': 0, 'y': (2, 1)})
Expand All @@ -1799,15 +1799,15 @@ def tree_min(
>>> tree_min({'a': None}) # `None` is a non-leaf node with arity 0 by default
Traceback (most recent call last):
...
ValueError: min() arg is an empty sequence
ValueError: min() iterable argument is empty
>>> tree_min({'a': None}, default=0) # `None` is a non-leaf node with arity 0 by default
0
>>> tree_min({'a': None}, none_is_leaf=True)
None
>>> tree_min(None) # `None` is a non-leaf node with arity 0 by default
Traceback (most recent call last):
...
ValueError: min() arg is an empty sequence
ValueError: min() iterable argument is empty
>>> tree_min(None, default=0)
0
>>> tree_min(None, none_is_leaf=True)
Expand Down Expand Up @@ -2455,15 +2455,15 @@ def treespec_ordereddict(
See also :func:`tree_structure`, :func:`treespec_leaf`, and :func:`treespec_none`.

>>> treespec_ordereddict({'a': treespec_leaf(), 'b': treespec_leaf()})
PyTreeSpec(OrderedDict([('a', *), ('b', *)]))
PyTreeSpec(OrderedDict({'a': *, 'b': *}))
>>> treespec_ordereddict([('b', treespec_leaf()), ('c', treespec_leaf()), ('a', treespec_none())])
PyTreeSpec(OrderedDict([('b', *), ('c', *), ('a', None)]))
PyTreeSpec(OrderedDict({'b': *, 'c': *, 'a': None}))
>>> treespec_ordereddict()
PyTreeSpec(OrderedDict([]))
PyTreeSpec(OrderedDict())
>>> treespec_ordereddict(a=treespec_leaf(), b=treespec_tuple([treespec_leaf(), treespec_leaf()]))
PyTreeSpec(OrderedDict([('a', *), ('b', (*, *))]))
PyTreeSpec(OrderedDict({'a': *, 'b': (*, *)}))
>>> treespec_ordereddict({'a': treespec_leaf(), 'b': tree_structure([1, 2])})
PyTreeSpec(OrderedDict([('a', *), ('b', [*, *])]))
PyTreeSpec(OrderedDict({'a': *, 'b': [*, *]}))
>>> treespec_ordereddict({'a': treespec_leaf(), 'b': tree_structure([1, 2], none_is_leaf=True)})
Traceback (most recent call last):
...
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ docs = [
]
benchmark = [
"jax[cpu] >= 0.4.6, < 0.5.0a0",
"torch >= 2.0, < 2.1.0a0",
"torch >= 2.0, < 2.3.0a0",
"torchvision",
"dm-tree >= 0.1, < 0.2.0a0",
"pandas",
Expand Down
Loading
Loading