From 2725e4eb0d3c466fb2ebddb22f271646cbc588f2 Mon Sep 17 00:00:00 2001 From: Philipp Holl Date: Sat, 19 Oct 2024 14:47:48 +0200 Subject: [PATCH] dataclasses can now declare fields 'variables', 'values' --- phiml/math/_magic_ops.py | 28 +++++++++++++++++++++++----- phiml/math/magic.py | 6 +++++- 2 files changed, 28 insertions(+), 6 deletions(-) diff --git a/phiml/math/_magic_ops.py b/phiml/math/_magic_ops.py index 044cffc..caaf05d 100644 --- a/phiml/math/_magic_ops.py +++ b/phiml/math/_magic_ops.py @@ -2,7 +2,7 @@ import warnings from functools import partial from numbers import Number -from typing import TypeVar, Tuple, Set, Dict, Union, Optional, Sequence, Any +from typing import TypeVar, Tuple, Set, Dict, Union, Optional, Sequence, Any, get_origin, List, Iterable, get_args import dataclasses @@ -707,7 +707,11 @@ def variable_attributes(obj) -> Tuple[str, ...]: assert isinstance(result, tuple), f"__value_attrs__ must return Tuple[str,...] but got '{type(result)}' from '{type(obj)}'" return result elif dataclasses.is_dataclass(obj): - return tuple([f.name for f in dataclasses.fields(obj)]) + if hasattr(obj, 'variables'): + result = obj.variables + assert isinstance(result, tuple), f"__value_attrs__ must return Tuple[str,...] but got '{type(result)}' from '{type(obj)}'" + return result + return all_attributes(obj) else: raise ValueError(f"Not a PhiTreeNode: {type(obj).__name__}") @@ -718,7 +722,11 @@ def value_attributes(obj) -> Tuple[str, ...]: assert isinstance(result, tuple), f"__value_attrs__ must return Tuple[str,...] but got '{type(result)}' from '{type(obj)}'" return result if dataclasses.is_dataclass(obj): - return tuple([f.name for f in dataclasses.fields(obj)]) + if hasattr(obj, 'values'): + result = obj.values + assert isinstance(result, tuple), f"dataclass.values must return Tuple[str,...] but got '{type(result)}' from '{type(obj)}'" + return result + return all_attributes(obj) raise ValueError(f"{type(obj).__name__} must implement '__value_attrs__()' or be a dataclass to be used with value functions.") @@ -731,7 +739,7 @@ def variable_values(obj) -> Tuple[str, ...]: return obj.__value_attrs__() # this takes care of dataclasses as well -def all_attributes(obj, assert_any=False) -> Sequence[str]: +def all_attributes(obj, assert_any=False) -> Tuple[str, ...]: if hasattr(obj, '__all_attrs__'): result = obj.__all_attrs__() assert isinstance(result, tuple), f"__value_attrs__ must return Tuple[str,...] but got '{type(result)}' from '{type(obj)}'" @@ -744,12 +752,22 @@ def all_attributes(obj, assert_any=False) -> Sequence[str]: if hasattr(obj, '__value_attrs__'): result.update(obj.__value_attrs__()) if dataclasses.is_dataclass(obj) and not hasattr(obj, '__variable_attrs__') and not hasattr(obj, '__value_attrs__'): - result.update([f.name for f in dataclasses.fields(obj)]) + result.update([f.name for f in dataclasses.fields(obj) if _is_child_field(f)]) if assert_any: assert result, f"{type(obj).__name__} is not a valid tree node because it has no tensor-like attributes." return tuple(sorted(result)) +def _is_child_field(field: dataclasses.Field): + origin_type = get_origin(field.type) + if origin_type in {list, List, tuple, Tuple, set, Set, Iterable}: + args = get_args(field.type) # The arguments passed to the generic (e.g., List[int] -> (int,)) + primitives = [a for a in args if a is not Ellipsis] if args else [field.type] + else: + primitives = [field.type] + return any(p not in (str,) for p in primitives) + + def replace(obj: PhiTreeNodeType, **updates) -> PhiTreeNodeType: """ Creates a copy of the given `phiml.math.magic.PhiTreeNode` with updated values as specified in `updates`. diff --git a/phiml/math/magic.py b/phiml/math/magic.py index d628ef6..fe6a129 100644 --- a/phiml/math/magic.py +++ b/phiml/math/magic.py @@ -372,6 +372,8 @@ def __value_attrs__(self) -> Tuple[str, ...]: - `jacobian()` - `custom_gradient()` + Dataclasses may instead declare the field `values: Tuple[str,...]` + Returns: `tuple` of `str` attributes. Calling `getattr(self, attr)` must return a `Tensor` or `PhiTreeNode` for all returned attributes. @@ -390,6 +392,8 @@ def __variable_attrs__(self) -> Tuple[str, ...]: - `jit_compile_linear()` - `stop_gradient()` + Dataclasses may instead declare the field `variables: Tuple[str,...]` + Returns: `tuple` of `str` attributes. Calling `getattr(self, attr)` must return a `Tensor` or `PhiTreeNode` for all returned attributes. @@ -404,7 +408,7 @@ def __all_attrs__(self) -> Tuple[str, ...]: The returned attributes are used to extract tensors for serializing and un-serializing the object. All names returned by `__variable_attrs__` and `__value_attrs__` must be included in this list. - If not implemented, the union of `__variable_attrs__` and `__value_attrs__` will be used instead. + If not implemented, the union of `__variable_attrs__` and `__value_attrs__` will be used instead, and dataclasses default to all fields possibly containing data. Returns: `tuple` of `str` attributes.