Skip to content

Commit

Permalink
feat: initial infrastructures for multi-backend support (#62)
Browse files Browse the repository at this point in the history
  • Loading branch information
yxlao committed Jul 3, 2024
1 parent 6eca218 commit 798362c
Show file tree
Hide file tree
Showing 6 changed files with 524 additions and 2 deletions.
8 changes: 7 additions & 1 deletion .github/workflows/unit_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install -e .[dev]
- name: Run unit tests
- name: Run unit tests (numpy)
run: |
pytest
- name: Install dependencies with torch
run: |
pip install -e .[torch]
- name: Run unit tests (numpy + torch)
run: |
pytest
3 changes: 2 additions & 1 deletion camtools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from . import artifact
from . import backend
from . import camera
from . import colmap
from . import colormap
Expand All @@ -14,9 +15,9 @@
from . import sanity
from . import solver
from . import transform
from . import typing
from . import util


try:
# Python >= 3.8
from importlib.metadata import version
Expand Down
84 changes: 84 additions & 0 deletions camtools/backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import typing
import warnings
from functools import lru_cache, wraps
from inspect import signature
from typing import Literal

import ivy
import jaxtyping

# Internally use "from camtools.backend import ivy" to make sure ivy is imported
# after the warnings filter is set. This is a temporary workaround to suppress
# the deprecation warning from numpy 2.0.
warnings.filterwarnings(
"ignore",
category=DeprecationWarning,
message=".*numpy.core.numeric is deprecated.*",
)
import ivy

_default_backend = "numpy"


@lru_cache(maxsize=None)
def is_torch_available():
try:
import torch

return True
except ImportError:
return False


def set_backend(backend: Literal["numpy", "torch"]) -> None:
"""
Set the default backend for camtools.
"""
global _default_backend
_default_backend = backend


def get_backend() -> str:
"""
Get the default backend for camtools.
"""
return _default_backend


def with_native_backend(func):
"""
1. Enable default camtools backend.
2. Return native backend array (setting array mode to False).
3. Converts lists to tensors if the type hint is a tensor.
"""

@wraps(func)
def wrapper(*args, **kwargs):
og_backend = ivy.current_backend()
ct_backend = get_backend()
ivy.set_backend(ct_backend)

# Unpack args and type hints
sig = signature(func)
bound_args = sig.bind(*args, **kwargs)
arg_name_to_hint = typing.get_type_hints(func)

try:
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=UserWarning)
with ivy.ArrayMode(False):
# Convert list -> native tensor if the type hint is a tensor
for arg_name, arg in bound_args.arguments.items():
if arg_name in arg_name_to_hint and issubclass(
arg_name_to_hint[arg_name], jaxtyping.AbstractArray
):
if isinstance(arg, list):
bound_args.arguments[arg_name] = ivy.native_array(arg)

# Call the function
result = func(*bound_args.args, **bound_args.kwargs)
finally:
ivy.set_backend(og_backend)
return result

return wrapper
164 changes: 164 additions & 0 deletions camtools/typing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
import typing
from distutils.version import LooseVersion
from functools import wraps
from inspect import signature
from typing import Any, Tuple, Union

import jaxtyping
import numpy as np
import pkg_resources

from . import backend

try:
# Try to import _FixedDim and _NamedDim from jaxtyping. This are internal
# classes where the APIs are not stable. There are no guarantees that these
# classes will be available in future versions of jaxtyping.
_jaxtyping_version = pkg_resources.get_distribution("jaxtyping").version
if LooseVersion(_jaxtyping_version) >= LooseVersion("0.2.20"):
from jaxtyping._array_types import _FixedDim, _NamedDim
else:
from jaxtyping.array_types import _FixedDim, _NamedDim
except ImportError:
raise ImportError(
f"Failed to import _FixedDim and _NamedDim. with "
f"jaxtyping version {_jaxtyping_version}."
)


class Tensor:
"""
An abstract tensor type for type hinting only.
Typically np.ndarray or torch.Tensor is supported.
"""

pass


def _dtype_to_str(dtype):
"""
Convert numpy or torch dtype to string
- "bool"
- "bool_"
- "uint4"
- "uint8"
- "uint16"
- "uint32"
- "uint64"
- "int4"
- "int8"
- "int16"
- "int32"
- "int64"
- "bfloat16"
- "float16"
- "float32"
- "float64"
- "complex64"
- "complex128"
"""
if isinstance(dtype, np.dtype):
return dtype.name

if backend.is_torch_available():
import torch

if isinstance(dtype, torch.dtype):
return str(dtype).split(".")[1]

return ValueError(f"Unknown dtype {dtype}.")


def _shape_from_dims(
dims: Tuple[
Union[_FixedDim, _NamedDim],
...,
]
) -> Tuple[Union[int, None], ...]:
shape = []
for dim in dims:
if isinstance(dim, _FixedDim):
shape.append(dim.size)
elif isinstance(dim, _NamedDim):
shape.append(None)
return tuple(shape)


def _is_shape_compatible(
arg_shape: Tuple[Union[int, None], ...],
gt_shape: Tuple[Union[int, None], ...],
) -> bool:
if len(arg_shape) != len(gt_shape):
return False

return all(
arg_dim == gt_dim or gt_dim is None
for arg_dim, gt_dim in zip(arg_shape, gt_shape)
)


def _assert_tensor_hint(
hint: jaxtyping.AbstractArray,
arg: Any,
arg_name: str,
):
"""
Args:
hint: A type hint for a tensor, must be javtyping.AbstractArray.
arg: An argument to check, typically a tensor.
arg_name: The name of the argument, for error messages.
"""
# Check array types.
if backend.is_torch_available():
import torch

valid_array_types = (np.ndarray, torch.Tensor)
else:
valid_array_types = (np.ndarray,)
if not isinstance(arg, valid_array_types):
raise TypeError(
f"{arg_name} must be a tensor of type {valid_array_types}, "
f"but got type {type(arg)}."
)

# Check shapes.
gt_shape = _shape_from_dims(hint.dims)
if not _is_shape_compatible(arg.shape, gt_shape):
raise TypeError(
f"{arg_name} must be a tensor of shape {gt_shape}, "
f"but got shape {arg.shape}."
)

# Check dtype.
gt_dtypes = hint.dtypes
if _dtype_to_str(arg.dtype) not in gt_dtypes:
raise TypeError(
f"{arg_name} must be a tensor of dtype {gt_dtypes}, "
f"but got dtype {_dtype_to_str(arg.dtype)}."
)


def check_shape_and_dtype(func):
"""
A decorator to enforce type and shape specifications as per type hints.
"""

@wraps(func)
def wrapper(*args, **kwargs):
sig = signature(func)
bound_args = sig.bind_partial(*args, **kwargs)
bound_args.apply_defaults()

arg_name_to_arg = bound_args.arguments
arg_name_to_hint = typing.get_type_hints(func)

for arg_name, arg in arg_name_to_arg.items():
if arg_name in arg_name_to_hint:
hint = arg_name_to_hint[arg_name]
if issubclass(hint, jaxtyping.AbstractArray):
_assert_tensor_hint(hint, arg, arg_name)

return func(*args, **kwargs)

return wrapper
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ dependencies = [
"matplotlib>=3.3.4",
"scikit-image>=0.16.2",
"tqdm>=4.60.0",
"ivy>=0.0.8.0",
"jaxtyping>=0.2.12",
]
description = "CamTools: Camera Tools for Computer Vision."
license = {text = "MIT"}
Expand Down
Loading

0 comments on commit 798362c

Please sign in to comment.