Skip to content

Commit

Permalink
add forced numpy and tensor backend
Browse files Browse the repository at this point in the history
  • Loading branch information
yxlao committed Jul 3, 2024
1 parent d792485 commit efb65cb
Showing 1 changed file with 134 additions and 28 deletions.
162 changes: 134 additions & 28 deletions camtools/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,8 +277,10 @@ def wrapper(*args, **kwargs):

def tensor_numpy_backend(func):
"""
Decorator to automatically convert input tensors to numpy arrays if they are
annotated as such using jaxtyping, regardless of their original backend.
Run this function by first converting its input tensors to numpy arrays.
Only jaxtyping-annotated tensors will be processed. This wrapper shall be
used if the internal implementation is numpy-only or if we expect to return
numpy arrays.
Behavior:
1. Only converts arguments that are annotated explicitly with a jaxtyping
Expand All @@ -300,49 +302,153 @@ def tensor_numpy_backend(func):
will be converted to numpy arrays if not already in that format.
"""

def _convert_to_numpy(item):
def _convert_to_numpy(arg):
"""
Recursively convert tensors to numpy arrays based on specified type.
Convert an argument to a numpy array if it is a tensor or a list of
tensor-like values.
"""
if isinstance(arg, np.ndarray):
return arg
elif is_torch_available() and isinstance(arg, torch.Tensor):
return arg.cpu().numpy()
elif isinstance(arg, list):
return np.array(arg)
else:
raise TypeError(f"Unsupported type {type(arg)} for conversion to numpy.")

Only handles list, numpy, and torch tensor types. Other types are
returned as is.
def _apply_conversion(args, kwargs, type_hints):
"""
if isinstance(item, np.ndarray):
return item
Apply numpy conversion to arguments based on their type annotations.
"""
new_args = []
for arg, hint in zip(args, type_hints):
if inspect.isclass(hint) and issubclass(hint, jaxtyping.AbstractArray):
new_args.append(_convert_to_numpy(arg))
else:
new_args.append(arg)

if is_torch_available() and isinstance(item, torch.Tensor):
return item.detach().cpu().numpy()
new_kwargs = {}
for key, arg in kwargs.items():
hint = type_hints.get(key)
if inspect.isclass(hint) and issubclass(hint, jaxtyping.AbstractArray):
new_kwargs[key] = _convert_to_numpy(arg)
else:
new_kwargs[key] = arg

return new_args, new_kwargs

@wraps(func)
def wrapper(*args, **kwargs):
sig = signature(func)
bound_args = sig.bind(*args, **kwargs)
bound_args.apply_defaults()

if isinstance(item, list):
return [_convert_to_numpy(i) for i in item]
type_hints = typing.get_type_hints(func)
new_args, new_kwargs = _apply_conversion(
bound_args.args, bound_args.kwargs, type_hints
)

raise TypeError(f"Unsupported type {type(item)} for conversion to numpy.")
# Manage backend and suppress warnings
stashed_backend = ivy.current_backend()
ivy.set_backend("numpy")

with warnings.catch_warnings():
warnings.simplefilter("ignore", category=UserWarning)
with ivy.ArrayMode(False):
result = func(*new_args, **new_kwargs)

# Reset backend after function execution
ivy.set_backend(stashed_backend)

return result

return wrapper


def tensor_torch_backend(func):
"""
Run this function by first converting its input tensors to torch tensors.
Only jaxtyping-annotated tensors will be processed. This wrapper shall be
used if the internal implementation is torch-only or if we expect to return
torch tensors.
Behavior:
1. Only converts arguments that are annotated explicitly with a jaxtyping
tensor type. If the type hint is a container of tensors, the conversion
will not be performed.
2. Supports conversion of lists into torch tensors if they are intended to be
tensors, according to the function's type annotations.
3. The conversion is applied to top-level arguments and does not recursively
convert tensors within nested custom types (e.g., custom classes
containing tensors).
4. This decorator is particularly useful for functions requiring consistent
tensor handling specifically with torch, ensuring compatibility and
simplifying operations that depend on torch's functionality.
Note:
- The decorator inspects type annotations and applies conversions where
specified.
- Lists of tensors or tensors within lists annotated as tensors
will be converted to torch tensors if not already in that format.
"""

def _convert_to_torch(arg):
"""
Convert an argument to a torch tensor if it is a tensor or a list of
tensor-like values.
"""
if isinstance(arg, torch.Tensor):
return arg
elif isinstance(arg, np.ndarray):
return torch.from_numpy(arg)
elif isinstance(arg, list):
return torch.tensor(arg)
else:
raise TypeError(f"Unsupported type {type(arg)} for conversion to torch.")

def _apply_conversion(args, kwargs, type_hints):
"""
Apply torch conversion to arguments based on their type annotations.
"""
new_args = []
for arg, hint in zip(args, type_hints):
if inspect.isclass(hint) and issubclass(hint, jaxtyping.AbstractArray):
new_args.append(_convert_to_torch(arg))
else:
new_args.append(arg)

new_kwargs = {}
for key, arg in kwargs.items():
hint = type_hints.get(key)
if inspect.isclass(hint) and issubclass(hint, jaxtyping.AbstractArray):
new_kwargs[key] = _convert_to_torch(arg)
else:
new_kwargs[key] = arg

return new_args, new_kwargs

@wraps(func)
def wrapper(*args, **kwargs):
sig = signature(func)
bound_args = sig.bind(*args, **kwargs)
bound_args.apply_defaults()

type_hints = typing.get_type_hints(func)
new_args, new_kwargs = _apply_conversion(
bound_args.args, bound_args.kwargs, type_hints
)

# Manage backend and suppress warnings
stashed_backend = ivy.current_backend()
ivy.set_backend("torch")

with warnings.catch_warnings():
warnings.simplefilter("ignore", category=UserWarning)
with ivy.ArrayMode(False):
# - Convert list -> numpy array if the type hint is a tensor
# - Convert torch -> numpy array if the type hint is a tensor and
# the input is a torch tensor
for arg_name, arg in bound_args.arguments.items():
if (
arg_name in typing.get_type_hints(func)
and inspect.isclass(typing.get_type_hints(func)[arg_name])
and issubclass(
typing.get_type_hints(func)[arg_name],
jaxtyping.AbstractArray,
)
):
bound_args.arguments[arg_name] = _convert_to_numpy(arg)
result = func(*new_args, **new_kwargs)

# Call the function
result = func(*bound_args.args, **bound_args.kwargs)
# Reset backend after function execution
ivy.set_backend(stashed_backend)

return result

Expand Down

0 comments on commit efb65cb

Please sign in to comment.