Update map()
* Add dims argument
* Add unwrap_scalars argument
* Move to _functional.py
* Fix stack() for strings
holl- committed Oct 5, 2023
1 parent 90ccda9 commit f76e7fe
Showing 6 changed files with 92 additions and 75 deletions.
phiml/math/__init__.py: 2 changes (1 addition, 1 deletion)
@@ -42,7 +42,6 @@
    choose_backend_t as choose_backend, all_available, convert, seed, to_device,
    native, numpy, reshaped_native, reshaped_tensor, reshaped_numpy, copy, native_call,
    print_ as print,
    map_ as map,
    slice_off,
    zeros, ones, fftfreq, random_normal, random_uniform, meshgrid, linspace, arange as range, range_tensor, # creation operators (use default backend)
    zeros_like, ones_like,
@@ -91,6 +90,7 @@
    iterate,
    identity,
    trace_check,
    map_ as map,
)

from ._optimize import solve_linear, solve_nonlinear, minimize, Solve, SolveInfo, ConvergenceException, NotConverged, Diverged, SolveTape, factor_ilu
phiml/math/_functional.py: 59 changes (55 additions, 4 deletions)
@@ -2,19 +2,19 @@
import types
import warnings
from functools import wraps, partial
from typing import Tuple, Callable, Dict, Generic, List, TypeVar, Any, Set, Union
from typing import Tuple, Callable, Dict, Generic, List, TypeVar, Any, Set, Union, Optional

import numpy as np

from . import _ops as math
from ._magic_ops import stack
from ._shape import EMPTY_SHAPE, Shape, spatial, instance, batch, channel
from ._magic_ops import stack, pack_dims, expand, unpack_dim
from ._shape import EMPTY_SHAPE, Shape, spatial, instance, batch, channel, merge_shapes
from ._sparse import SparseCoordinateTensor
from ._tensors import Tensor, disassemble_tree, assemble_tree, disassemble_tensors, assemble_tensors, variable_attributes, wrap, specs_equal, equality_by_shape_and_value
from ._trace import ShiftLinTracer, matrix_from_function, LinearTraceInProgress
from ..backend import Backend, NUMPY
from ..backend._backend import get_spatial_derivative_order, functional_derivative_evaluation, ML_LOGGER
from .magic import PhiTreeNode
from .magic import PhiTreeNode, Shapable

X = TypeVar('X')
Y = TypeVar('Y')
@@ -1103,6 +1103,57 @@ def iterate(f: Callable,
        raise ValueError(f"iterations must be an int or Shape but got {type(iterations)}")


def map_(function, *args, dims: Shape = None, range=range, unwrap_scalars=True, **kwargs) -> Union[None, Tensor, Tuple[Optional[Tensor], ...]]:
    """
    Calls `function` on slices of the arguments and returns the stacked result.

    Args:
        function: Function to be called on slices of `args` and `kwargs`.
            Must return one or multiple values that can be stacked.
            Returning `None` is allowed, but if any call returns `None` in some position, all calls must return `None` in that position.
        *args: Positional arguments for `function`.
            Values that are `phiml.math.magic.Sliceable` will be sliced along `dims`.
        dims: Dimensions which should be sliced.
            `function` is called once for each element in `dims`, i.e. `dims.volume` times.
            If `dims` is not specified, all dimensions from the `phiml.math.magic.Sliceable` values in `args` and `kwargs` will be mapped.
        range: Optional range function. Can be used to generate `tqdm` output by passing `trange`.
        unwrap_scalars: If `True`, passes the contents of scalar `Tensor`s instead of the tensor objects.
        **kwargs: Keyword arguments for `function`.
            Values that are `phiml.math.magic.Sliceable` will be sliced along `dims`.

    Returns:
        The results of the `function` calls, stacked along `dims`.
        If `function` returns multiple values, a `tuple` of stacked results is returned.
        Output positions in which all calls returned `None` are returned as `None`.
    """
    p_names = function_parameters(function)
    all_args = {**kwargs, **{p_names[i]: v for i, v in enumerate(args)}}
    sliceable_args = {k: v for k, v in all_args.items() if isinstance(v, Shapable)}
    extra_args = {k: v for k, v in all_args.items() if not isinstance(v, Shapable)}
    if dims is None:
        dims = merge_shapes(*sliceable_args.values())
    assert dims.volume > 0, f"map dims must have volume > 0 but got {dims}"
    results = []
    for _, idx in zip(range(dims.volume), dims.meshgrid()):
        args = {k: v[idx] for k, v in sliceable_args.items()}
        if unwrap_scalars:
            args = {k: v.native() if isinstance(v, Tensor) else v for k, v in args.items()}
        f_output = function(**args, **extra_args)
        results.append(f_output)
    if isinstance(results[0], tuple):
        stacked: List[Optional[Tensor]] = []
        for i in range(len(results[0])):
            if any(r[i] is None for r in results):
                assert all(r[i] is None for r in results), f"map function returned None for some elements, {results}"
                stacked.append(None)
            else:
                stacked.append(math.stack([r[i] for r in results], dims))
        return tuple(stacked)
    else:
        if any(r is None for r in results):
            assert all(r is None for r in results), f"map function returned None for some elements, {results}"
            return None
        return stack(results, dims)
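
Usage sketch, patterned on the new tests in test__functional.py (illustrative only, not part of the diff): `map` slices every `Sliceable` argument along `dims` and stacks the per-call results.

    from phiml import math
    from phiml.math import spatial, wrap

    x = wrap((0, 1), spatial('x'))
    y = wrap((2, 4), spatial('y'))
    # dims defaults to the merged shape (x=2, y=2), so f runs 4 times on scalar slices,
    # which unwrap_scalars=True (the default) passes as plain numbers.
    math.map(lambda a, b: a + b, x, y)  # Tensor with dims x, y
    # Restricting dims to x calls f only twice, each time with the full y tensor.
    math.map(lambda a, b: a + b, x, y, dims=x.shape, unwrap_scalars=False)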


def identity(x):
    """
    Identity function for one argument.
phiml/math/_magic_ops.py: 4 changes (2 additions, 2 deletions)
@@ -160,7 +160,7 @@ def stack(values: Union[tuple, list, dict], dim: Shape, expand_values=False, **k
        else:
            values = [expand(v, all_dims.without(shape(v))) for v in values]
    else:
        all_batch_dims = merge_shapes(*[batch(v) for v in values_], allow_varying_sizes=True)
        all_batch_dims = merge_shapes(*[shape(v).batch for v in values_], allow_varying_sizes=True)
        if isinstance(values, dict):
            values = {k: expand(v, all_batch_dims.without(shape(v))) for k, v in values.items()}
        else:
@@ -277,7 +277,7 @@ def concat(values: Union[tuple, list], dim: Union[str, Shape], expand_values=Fal
            assert dim in shape(v), f"dim must be present in the shapes of all values but got value {type(v).__name__} with shape {shape(v)}"
        for v in values[1:]:
            assert set(non_batch(v).names) == set(non_batch(values[0]).names), f"Concatenated values must have the same non-batch dimensions but got {non_batch(values[0])} and {non_batch(v)}"
        all_batch_dims = merge_shapes(*[batch(v) for v in values])
        all_batch_dims = merge_shapes(*[shape(v).batch for v in values])
        values = [expand(v, all_batch_dims) for v in values]
    # --- First try __concat__ ---
    for v in values:
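A plausible reading of the `batch(v)` → `shape(v).batch` change, which underlies the "Fix stack() for strings" item in the commit message: `batch` doubles as a dimension factory, so for a string element `batch('a')` constructs a new batch dimension named `a` instead of filtering that element's shape, while `shape(v).batch` always returns the value's actual batch dimensions.

    from phiml.math import batch, shape

    batch('a')        # dim factory: a new batch dimension named 'a'
    shape('a').batch  # assuming shape() treats a plain string as scalar: empty shape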
phiml/math/_ops.py: 44 changes (0 additions, 44 deletions)
@@ -507,50 +507,6 @@ def _includes_slice(s_dict: dict, dim: Shape, i: int):
    return i in indices


def map_(function, *values, range=range, **kwargs) -> Union[Tensor, None]:
    """
    Calls `function` on all elements of `values`.

    Args:
        function: Function to be called on single elements contained in `value`. Must return a value that can be stored in tensors.
        *values: `Tensors` containing positional arguments for `function`.
            Number of tensors must match `function` signature.
        range: Range function. Can be used to generate tqdm output by passing `trange`.
        **kwargs: Non-`Tensor` keyword arguments for `function`.
            Their shapes are not broadcast with the positional arguments.

    Returns:
        `Tensor` of same shape as `value`.
    """
    if not values:
        return function(**kwargs)
    values = [v if isinstance(v, Shapable) else wrap(v) for v in values]
    shape = merge_shapes(*[v.shape for v in values])
    flat = [pack_dims(expand(v, shape), shape, channel(flat=shape.volume)) for v in values]
    result = []
    results = None
    for _, items in zip(range(flat[0].flat.size_or_1), zip(*flat)):
        f_output = function(*items, **kwargs)
        if isinstance(f_output, tuple):
            if results is None:
                results = [[] for _ in f_output]
            for result_i, output_i in zip(results, f_output):
                result_i.append(output_i)
        else:
            result.append(f_output)
    if results is None:
        if any(r is None for r in result):
            assert all(r is None for r in result), f"map function returned None for some elements, {result}"
            return None
        return unpack_dim(stack(result, channel('_c')) if isinstance(result, Shapable) else wrap(result, channel('_c')), '_c', shape)
    else:
        for i, result_i in enumerate(results):
            if any(r is None for r in result_i):
                assert all(r is None for r in result_i), f"map function returned None for some elements at output index {i}, {result_i}"
                results[i] = None
        return tuple([unpack_dim(stack(result_i, channel('_c')) if isinstance(result_i, Shapable) else wrap(result_i, channel('_c')), '_c', shape) for result_i in results])
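
For contrast (inferred from the two implementations): this removed version packed all mapped dimensions into a temporary flat channel dim and treated keyword arguments as non-`Tensor` extras, whereas the new `_functional.map_` slices keyword arguments as well, as exercised by the new `test_map`:

    # Keyword arguments are now sliced too; the two calls below are equivalent.
    math.map(f, x, y)
    math.map(f, x=x, y=y)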


def _initialize(uniform_initializer, shapes: Tuple[Shape]) -> Tensor:
    shape = concat_shapes(*shapes)
    assert shape.well_defined, f"When creating a Tensor, shape needs to have definitive sizes but got {shape}"
tests/commit/math/test__functional.py: 35 changes (34 additions, 1 deletion)
@@ -5,7 +5,7 @@
from phiml import math
from phiml.backend import Backend
from phiml.backend._backend import init_installed_backends
from phiml.math import tensor, spatial, batch, channel
from phiml.math import tensor, spatial, batch, channel, wrap

BACKENDS = init_installed_backends()

@@ -224,3 +224,36 @@ def f(x, aux):
        f(x0, aux0)
        self.assertTrue(math.trace_check(f, x0, aux0)[0])
        self.assertTrue(math.trace_check(f, x=x0, aux=aux0)[0])

    def test_map(self):
        F_CALLS = []
        def f(x, y):
            F_CALLS.append((x, y))
            return x + y
        x = wrap((0, 1), spatial('x'))
        y = wrap((2, 4), spatial('y'))
        math.assert_close(wrap([(2, 4), (3, 5)], spatial('x,y')), math.map(f, x, y))
        self.assertEqual(4, len(F_CALLS), msg=F_CALLS)
        F_CALLS.clear()
        math.assert_close(wrap([(2, 4), (3, 5)], spatial('x,y')), math.map(f, x=x, y=y))
        self.assertEqual(4, len(F_CALLS), msg=F_CALLS)
        F_CALLS.clear()
        math.assert_close(wrap([(2, 4), (3, 5)], spatial('x,y')), math.map(f, x, y, dims=x.shape, unwrap_scalars=False))
        self.assertEqual(2, len(F_CALLS), msg=F_CALLS)
        F_CALLS.clear()

    def test_map_layout(self):
        l = math.layout('loss', math.EMPTY_SHAPE)
        a = math.layout([[0, 1], [2, 3]], spatial('x,y'))
        loss4 = math.map(lambda l, a: l, l, a)
        for l4 in loss4:
            self.assertEqual('loss', l4)

    def test_map_multi_output(self):
        def f(x, y):
            return x + y, x - y
        x = wrap((0, 1), spatial('x'))
        y = wrap((2, 4), spatial('y'))
        r_x, r_y = math.map(f, x, y)
        math.assert_close(wrap([(2, 4), (3, 5)], spatial('x,y')), r_x)
        math.assert_close(wrap([(-2, -4), (-1, -3)], spatial('x,y')), r_y)
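
One documented behavior the new tests leave uncovered, sketched here from the `None`-handling branch of `map_` (an illustration, not part of the diff): output positions in which every call returns `None` come back as `None`.

    x = wrap((0, 1), spatial('x'))
    doubled, nothing = math.map(lambda a: (2 * a, None), x)
    # doubled stacks to a Tensor along x; nothing is None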
tests/commit/math/test__ops.py: 23 changes (0 additions, 23 deletions)
@@ -789,29 +789,6 @@ def test_fit_hyperplane(self):
        assert_close(w, wrap([(0.8, -1), (-0.8, 0)], channel(y), channel(x)), abs_tolerance=1e-3)
        assert_close(b, wrap((1, 0), channel(y)), abs_tolerance=1e-3)

    def test_map(self):
        def f(x, y):
            return x + y
        x = wrap((0, 1), spatial('x'))
        y = wrap((2, 4), spatial('y'))
        math.assert_close(wrap([(2, 4), (3, 5)], spatial('x,y')), math.map(f, x, y))

    def test_map_layout(self):
        l = math.layout('loss', math.EMPTY_SHAPE)
        a = math.layout([[0, 1], [2, 3]], spatial('x,y'))
        loss4 = math.map(lambda l, a: l, l, a)
        for l4 in loss4:
            self.assertEqual('loss', l4)

    def test_map_multi_output(self):
        def f(x, y):
            return x + y, x - y
        x = wrap((0, 1), spatial('x'))
        y = wrap((2, 4), spatial('y'))
        r_x, r_y = math.map(f, x, y)
        math.assert_close(wrap([(2, 4), (3, 5)], spatial('x,y')), r_x)
        math.assert_close(wrap([(-2, -4), (-1, -3)], spatial('x,y')), r_y)

    def test_to_device(self):
        for backend in BACKENDS:
            with backend:
