Skip to content

Commit

Permalink
dataclasses can now declare fields 'variables', 'values'
Browse files Browse the repository at this point in the history
  • Loading branch information
holl- committed Oct 19, 2024
1 parent 50e69bc commit 2725e4e
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 6 deletions.
28 changes: 23 additions & 5 deletions phiml/math/_magic_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import warnings
from functools import partial
from numbers import Number
from typing import TypeVar, Tuple, Set, Dict, Union, Optional, Sequence, Any
from typing import TypeVar, Tuple, Set, Dict, Union, Optional, Sequence, Any, get_origin, List, Iterable, get_args

import dataclasses

Expand Down Expand Up @@ -707,7 +707,11 @@ def variable_attributes(obj) -> Tuple[str, ...]:
assert isinstance(result, tuple), f"__value_attrs__ must return Tuple[str,...] but got '{type(result)}' from '{type(obj)}'"
return result
elif dataclasses.is_dataclass(obj):
return tuple([f.name for f in dataclasses.fields(obj)])
if hasattr(obj, 'variables'):
result = obj.variables
assert isinstance(result, tuple), f"__value_attrs__ must return Tuple[str,...] but got '{type(result)}' from '{type(obj)}'"
return result
return all_attributes(obj)
else:
raise ValueError(f"Not a PhiTreeNode: {type(obj).__name__}")

Expand All @@ -718,7 +722,11 @@ def value_attributes(obj) -> Tuple[str, ...]:
assert isinstance(result, tuple), f"__value_attrs__ must return Tuple[str,...] but got '{type(result)}' from '{type(obj)}'"
return result
if dataclasses.is_dataclass(obj):
return tuple([f.name for f in dataclasses.fields(obj)])
if hasattr(obj, 'values'):
result = obj.values
assert isinstance(result, tuple), f"dataclass.values must return Tuple[str,...] but got '{type(result)}' from '{type(obj)}'"
return result
return all_attributes(obj)
raise ValueError(f"{type(obj).__name__} must implement '__value_attrs__()' or be a dataclass to be used with value functions.")


Expand All @@ -731,7 +739,7 @@ def variable_values(obj) -> Tuple[str, ...]:
return obj.__value_attrs__() # this takes care of dataclasses as well


def all_attributes(obj, assert_any=False) -> Sequence[str]:
def all_attributes(obj, assert_any=False) -> Tuple[str, ...]:
if hasattr(obj, '__all_attrs__'):
result = obj.__all_attrs__()
assert isinstance(result, tuple), f"__value_attrs__ must return Tuple[str,...] but got '{type(result)}' from '{type(obj)}'"
Expand All @@ -744,12 +752,22 @@ def all_attributes(obj, assert_any=False) -> Sequence[str]:
if hasattr(obj, '__value_attrs__'):
result.update(obj.__value_attrs__())
if dataclasses.is_dataclass(obj) and not hasattr(obj, '__variable_attrs__') and not hasattr(obj, '__value_attrs__'):
result.update([f.name for f in dataclasses.fields(obj)])
result.update([f.name for f in dataclasses.fields(obj) if _is_child_field(f)])
if assert_any:
assert result, f"{type(obj).__name__} is not a valid tree node because it has no tensor-like attributes."
return tuple(sorted(result))


def _is_child_field(field: dataclasses.Field):
origin_type = get_origin(field.type)
if origin_type in {list, List, tuple, Tuple, set, Set, Iterable}:
args = get_args(field.type) # The arguments passed to the generic (e.g., List[int] -> (int,))
primitives = [a for a in args if a is not Ellipsis] if args else [field.type]
else:
primitives = [field.type]
return any(p not in (str,) for p in primitives)


def replace(obj: PhiTreeNodeType, **updates) -> PhiTreeNodeType:
"""
Creates a copy of the given `phiml.math.magic.PhiTreeNode` with updated values as specified in `updates`.
Expand Down
6 changes: 5 additions & 1 deletion phiml/math/magic.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,8 @@ def __value_attrs__(self) -> Tuple[str, ...]:
- `jacobian()`
- `custom_gradient()`
Dataclasses may instead declare the field `values: Tuple[str,...]`
Returns:
`tuple` of `str` attributes.
Calling `getattr(self, attr)` must return a `Tensor` or `PhiTreeNode` for all returned attributes.
Expand All @@ -390,6 +392,8 @@ def __variable_attrs__(self) -> Tuple[str, ...]:
- `jit_compile_linear()`
- `stop_gradient()`
Dataclasses may instead declare the field `variables: Tuple[str,...]`
Returns:
`tuple` of `str` attributes.
Calling `getattr(self, attr)` must return a `Tensor` or `PhiTreeNode` for all returned attributes.
Expand All @@ -404,7 +408,7 @@ def __all_attrs__(self) -> Tuple[str, ...]:
The returned attributes are used to extract tensors for serializing and un-serializing the object.
All names returned by `__variable_attrs__` and `__value_attrs__` must be included in this list.
If not implemented, the union of `__variable_attrs__` and `__value_attrs__` will be used instead.
If not implemented, the union of `__variable_attrs__` and `__value_attrs__` will be used instead, and dataclasses default to all fields possibly containing data.
Returns:
`tuple` of `str` attributes.
Expand Down

0 comments on commit 2725e4e

Please sign in to comment.