From f76e7fe66363fa5b330a4f4b04192e62fb413f4e Mon Sep 17 00:00:00 2001 From: holl- Date: Thu, 5 Oct 2023 14:31:59 +0200 Subject: [PATCH] Update map() * Add dims argument * Add unwrap_scalars argument * Move to _functional.py * Fix stack() for strings --- phiml/math/__init__.py | 2 +- phiml/math/_functional.py | 59 +++++++++++++++++++++++++-- phiml/math/_magic_ops.py | 4 +- phiml/math/_ops.py | 44 -------------------- tests/commit/math/test__functional.py | 35 +++++++++++++++- tests/commit/math/test__ops.py | 23 ----------- 6 files changed, 92 insertions(+), 75 deletions(-) diff --git a/phiml/math/__init__.py b/phiml/math/__init__.py index 04496334..1d62921a 100644 --- a/phiml/math/__init__.py +++ b/phiml/math/__init__.py @@ -42,7 +42,6 @@ choose_backend_t as choose_backend, all_available, convert, seed, to_device, native, numpy, reshaped_native, reshaped_tensor, reshaped_numpy, copy, native_call, print_ as print, - map_ as map, slice_off, zeros, ones, fftfreq, random_normal, random_uniform, meshgrid, linspace, arange as range, range_tensor, # creation operators (use default backend) zeros_like, ones_like, @@ -91,6 +90,7 @@ iterate, identity, trace_check, + map_ as map, ) from ._optimize import solve_linear, solve_nonlinear, minimize, Solve, SolveInfo, ConvergenceException, NotConverged, Diverged, SolveTape, factor_ilu diff --git a/phiml/math/_functional.py b/phiml/math/_functional.py index 87bb98e5..6a32b3aa 100644 --- a/phiml/math/_functional.py +++ b/phiml/math/_functional.py @@ -2,19 +2,19 @@ import types import warnings from functools import wraps, partial -from typing import Tuple, Callable, Dict, Generic, List, TypeVar, Any, Set, Union +from typing import Tuple, Callable, Dict, Generic, List, TypeVar, Any, Set, Union, Optional import numpy as np from . import _ops as math -from ._magic_ops import stack -from ._shape import EMPTY_SHAPE, Shape, spatial, instance, batch, channel +from ._magic_ops import stack, pack_dims, expand, unpack_dim +from ._shape import EMPTY_SHAPE, Shape, spatial, instance, batch, channel, merge_shapes from ._sparse import SparseCoordinateTensor from ._tensors import Tensor, disassemble_tree, assemble_tree, disassemble_tensors, assemble_tensors, variable_attributes, wrap, specs_equal, equality_by_shape_and_value from ._trace import ShiftLinTracer, matrix_from_function, LinearTraceInProgress from ..backend import Backend, NUMPY from ..backend._backend import get_spatial_derivative_order, functional_derivative_evaluation, ML_LOGGER -from .magic import PhiTreeNode +from .magic import PhiTreeNode, Shapable X = TypeVar('X') Y = TypeVar('Y') @@ -1103,6 +1103,57 @@ def iterate(f: Callable, raise ValueError(f"iterations must be an int or Shape but got {type(iterations)}") +def map_(function, *args, dims: Shape = None, range=range, unwrap_scalars=True, **kwargs) -> Union[None, Tensor, Tuple[Optional[Tensor]]]: + """ + Calls `function` on slices of the arguments and returns the stacked result. + + Args: + function: Function to be called on slices of `args` and `kwargs`. + Must return one or multiple values that can be stacked. + `None` may be returned but if any return value is `None`, all calls to `function` must return `None` in that position. + *args: Positional arguments for `function`. + Values that are `phiml.math.magic.Sliceable` will be sliced along `dims`. + **kwargs: Keyword arguments for `function`. + Values that are `phiml.math.magic.Sliceable` will be sliced along `dims`. + dims: Dimensions which should be sliced. + `function` is called once for each element in `dims`, i.e. `dims.volume` times. + If `dims` is not specified, all dimensions from the `phiml.math.magic.Sliceable` values in `args` and `kwargs` will be mapped. + range: Optional range function. Can be used to generate `tqdm` output by passing `trange`. + unwrap_scalars: If `True`, passes the contents of scalar `Tensor`s instead of the tensor objects. + + Returns: + `Tensor` of same shape as `value`. + """ + p_names = function_parameters(function) + all_args = {**kwargs, **{p_names[i]: v for i, v in enumerate(args)}} + sliceable_args = {k: v for k, v in all_args.items() if isinstance(v, Shapable)} + extra_args = {k: v for k, v in all_args.items() if not isinstance(v, Shapable)} + if dims is None: + dims = merge_shapes(*sliceable_args.values()) + assert dims.volume > 0, f"map dims must have volume > 0 but got {dims}" + results = [] + for _, idx in zip(range(dims.volume), dims.meshgrid()): + args = {k: v[idx] for k, v in sliceable_args.items()} + if unwrap_scalars: + args = {k: v.native() if isinstance(v, Tensor) else v for k, v in args.items()} + f_output = function(**args, **extra_args) + results.append(f_output) + if isinstance(results[0], tuple): + stacked: List[Optional[Tensor]] = [] + for i in range(len(results[0])): + if any(r[i] is None for r in results): + assert all(r[i] is None for r in results), f"map function returned None for some elements, {results}" + stacked.append(None) + else: + stacked.append(math.stack([r[i] for r in results], dims)) + return tuple(stacked) + else: + if any(r is None for r in results): + assert all(r is None for r in results), f"map function returned None for some elements, {results}" + return None + return stack(results, dims) + + def identity(x): """ Identity function for one argument. diff --git a/phiml/math/_magic_ops.py b/phiml/math/_magic_ops.py index ad1114d1..e80ae9f0 100644 --- a/phiml/math/_magic_ops.py +++ b/phiml/math/_magic_ops.py @@ -160,7 +160,7 @@ def stack(values: Union[tuple, list, dict], dim: Shape, expand_values=False, **k else: values = [expand(v, all_dims.without(shape(v))) for v in values] else: - all_batch_dims = merge_shapes(*[batch(v) for v in values_], allow_varying_sizes=True) + all_batch_dims = merge_shapes(*[shape(v).batch for v in values_], allow_varying_sizes=True) if isinstance(values, dict): values = {k: expand(v, all_batch_dims.without(shape(v))) for k, v in values.items()} else: @@ -277,7 +277,7 @@ def concat(values: Union[tuple, list], dim: Union[str, Shape], expand_values=Fal assert dim in shape(v), f"dim must be present in the shapes of all values bot got value {type(v).__name__} with shape {shape(v)}" for v in values[1:]: assert set(non_batch(v).names) == set(non_batch(values[0]).names), f"Concatenated values must have the same non-batch dimensions but got {non_batch(values[0])} and {non_batch(v)}" - all_batch_dims = merge_shapes(*[batch(v) for v in values]) + all_batch_dims = merge_shapes(*[shape(v).batch for v in values]) values = [expand(v, all_batch_dims) for v in values] # --- First try __concat__ --- for v in values: diff --git a/phiml/math/_ops.py b/phiml/math/_ops.py index 535b992b..d1d384c7 100644 --- a/phiml/math/_ops.py +++ b/phiml/math/_ops.py @@ -507,50 +507,6 @@ def _includes_slice(s_dict: dict, dim: Shape, i: int): return i in indices -def map_(function, *values, range=range, **kwargs) -> Union[Tensor, None]: - """ - Calls `function` on all elements of `values`. - - Args: - function: Function to be called on single elements contained in `value`. Must return a value that can be stored in tensors. - *values: `Tensors` containing positional arguments for `function`. - Number of tensors must match `function` signature. - range: Range function. Can be used to generate tqdm output by passing `trange`. - **kwargs: Non-`Tensor` keyword arguments for `function`. - Their shapes are not broadcast with the positional arguments. - - Returns: - `Tensor` of same shape as `value`. - """ - if not values: - return function(**kwargs) - values = [v if isinstance(v, Shapable) else wrap(v) for v in values] - shape = merge_shapes(*[v.shape for v in values]) - flat = [pack_dims(expand(v, shape), shape, channel(flat=shape.volume)) for v in values] - result = [] - results = None - for _, items in zip(range(flat[0].flat.size_or_1), zip(*flat)): - f_output = function(*items, **kwargs) - if isinstance(f_output, tuple): - if results is None: - results = [[] for _ in f_output] - for result_i, output_i in zip(results, f_output): - result_i.append(output_i) - else: - result.append(f_output) - if results is None: - if any(r is None for r in result): - assert all(r is None for r in result), f"map function returned None for some elements, {result}" - return None - return unpack_dim(stack(result, channel('_c')) if isinstance(result, Shapable) else wrap(result, channel('_c')), '_c', shape) - else: - for i, result_i in enumerate(results): - if any(r is None for r in result_i): - assert all(r is None for r in result_i), f"map function returned None for some elements at output index {i}, {result_i}" - results[i] = None - return tuple([unpack_dim(stack(result_i, channel('_c')) if isinstance(result_i, Shapable) else wrap(result_i, channel('_c')), '_c', shape) for result_i in results]) - - def _initialize(uniform_initializer, shapes: Tuple[Shape]) -> Tensor: shape = concat_shapes(*shapes) assert shape.well_defined, f"When creating a Tensor, shape needs to have definitive sizes but got {shape}" diff --git a/tests/commit/math/test__functional.py b/tests/commit/math/test__functional.py index fde06ff4..2fe315ca 100644 --- a/tests/commit/math/test__functional.py +++ b/tests/commit/math/test__functional.py @@ -5,7 +5,7 @@ from phiml import math from phiml.backend import Backend from phiml.backend._backend import init_installed_backends -from phiml.math import tensor, spatial, batch, channel +from phiml.math import tensor, spatial, batch, channel, wrap BACKENDS = init_installed_backends() @@ -224,3 +224,36 @@ def f(x, aux): f(x0, aux0) self.assertTrue(math.trace_check(f, x0, aux0)[0]) self.assertTrue(math.trace_check(f, x=x0, aux=aux0)[0]) + + def test_map(self): + F_CALLS = [] + def f(x, y): + F_CALLS.append((x, y)) + return x + y + x = wrap((0, 1), spatial('x')) + y = wrap((2, 4), spatial('y')) + math.assert_close(wrap([(2, 4), (3, 5)], spatial('x,y')), math.map(f, x, y)) + self.assertEqual(4, len(F_CALLS), msg=F_CALLS) + F_CALLS.clear() + math.assert_close(wrap([(2, 4), (3, 5)], spatial('x,y')), math.map(f, x=x, y=y)) + self.assertEqual(4, len(F_CALLS), msg=F_CALLS) + F_CALLS.clear() + math.assert_close(wrap([(2, 4), (3, 5)], spatial('x,y')), math.map(f, x, y, dims=x.shape, unwrap_scalars=False)) + self.assertEqual(2, len(F_CALLS), msg=F_CALLS) + F_CALLS.clear() + + def test_map_layout(self): + l = math.layout('loss', math.EMPTY_SHAPE) + a = math.layout([[0, 1], [2, 3]], spatial('x,y')) + loss4 = math.map(lambda l, a: l, l, a) + for l4 in loss4: + self.assertEqual('loss', l4) + + def test_map_multi_output(self): + def f(x, y): + return x + y, x - y + x = wrap((0, 1), spatial('x')) + y = wrap((2, 4), spatial('y')) + r_x, r_y = math.map(f, x, y) + math.assert_close(wrap([(2, 4), (3, 5)], spatial('x,y')), r_x) + math.assert_close(wrap([(-2, -4), (-1, -3)], spatial('x,y')), r_y) diff --git a/tests/commit/math/test__ops.py b/tests/commit/math/test__ops.py index 6e5caad7..caf21aa7 100644 --- a/tests/commit/math/test__ops.py +++ b/tests/commit/math/test__ops.py @@ -789,29 +789,6 @@ def test_fit_hyperplane(self): assert_close(w, wrap([(0.8, -1), (-0.8, 0)], channel(y), channel(x)), abs_tolerance=1e-3) assert_close(b, wrap((1, 0), channel(y)), abs_tolerance=1e-3) - def test_map(self): - def f(x, y): - return x + y - x = wrap((0, 1), spatial('x')) - y = wrap((2, 4), spatial('y')) - math.assert_close(wrap([(2, 4), (3, 5)], spatial('x,y')), math.map(f, x, y)) - - def test_map_layout(self): - l = math.layout('loss', math.EMPTY_SHAPE) - a = math.layout([[0, 1], [2, 3]], spatial('x,y')) - loss4 = math.map(lambda l, a: l, l, a) - for l4 in loss4: - self.assertEqual('loss', l4) - - def test_map_multi_output(self): - def f(x, y): - return x + y, x - y - x = wrap((0, 1), spatial('x')) - y = wrap((2, 4), spatial('y')) - r_x, r_y = math.map(f, x, y) - math.assert_close(wrap([(2, 4), (3, 5)], spatial('x,y')), r_x) - math.assert_close(wrap([(-2, -4), (-1, -3)], spatial('x,y')), r_y) - def test_to_device(self): for backend in BACKENDS: with backend: