Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: initial infrastructures for multi-backend support #62

Merged
merged 56 commits into from
Jul 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
56 commits
Select commit Hold shift + click to select a range
c48e94f
rename
yxlao Jun 28, 2024
3ed192d
add second test
yxlao Jun 28, 2024
a566355
use native backend
yxlao Jun 28, 2024
dd2cc2d
with native backend
yxlao Jun 28, 2024
8f1d9ee
avoid warning
yxlao Jun 29, 2024
ea5cf55
add jax typing
yxlao Jun 29, 2024
042fd29
expect raise
yxlao Jun 30, 2024
56760ff
first pasing type hints
yxlao Jun 30, 2024
1b72864
simplify
yxlao Jun 30, 2024
72436a0
type check as decorator
yxlao Jun 30, 2024
5baf108
move to sanity
yxlao Jun 30, 2024
2296e6b
check shape working
yxlao Jun 30, 2024
a582a36
add new and old code
yxlao Jun 30, 2024
9a66380
move out assert_tensor_hint
yxlao Jun 30, 2024
167984b
unpack hints
yxlao Jun 30, 2024
7bb270e
check array types and shapes
yxlao Jun 30, 2024
dbc3c25
check gt dtypes
yxlao Jun 30, 2024
4c6803e
add dependencies
yxlao Jun 30, 2024
e902c06
supress warning (temp) for numpy 2.0
yxlao Jun 30, 2024
c5659d7
import ivy from backend
yxlao Jun 30, 2024
abe343b
clean up
yxlao Jun 30, 2024
ca0cc66
checks dtype
yxlao Jun 30, 2024
96bc15f
rename
yxlao Jun 30, 2024
9c33b2f
check typecheck hints
yxlao Jun 30, 2024
f737d53
handle kwargs
yxlao Jun 30, 2024
2284f9f
docs
yxlao Jun 30, 2024
a7c00bc
update docs
yxlao Jun 30, 2024
9f79403
move to typing.py
yxlao Jul 1, 2024
c491aa4
remove union type checks
yxlao Jul 1, 2024
955f72c
docs
yxlao Jul 1, 2024
d9b9b73
better code inspection
yxlao Jul 1, 2024
f5c7016
better code inspection
yxlao Jul 1, 2024
4d17d93
comments
yxlao Jul 1, 2024
cd2bc92
issubclass(hint, jaxtyping.AbstractArray)
yxlao Jul 1, 2024
512aec8
initial list->tensor auto convert, dtype check has issues
yxlao Jul 1, 2024
79ec946
dtype test working
yxlao Jul 1, 2024
57b3d5c
test typing
yxlao Jul 1, 2024
abddd7e
format
yxlao Jul 1, 2024
5a1b60c
up dependencies
yxlao Jul 1, 2024
2cc6750
undo change
yxlao Jul 1, 2024
13d40b3
fix dependencies
yxlao Jul 1, 2024
542d3d1
change version
yxlao Jul 1, 2024
2dd4b12
torch
yxlao Jul 1, 2024
098f9dd
separate torch test with numpy test
yxlao Jul 1, 2024
dea9ffa
update tests
yxlao Jul 1, 2024
3e5e33b
reduce version
yxlao Jul 1, 2024
cdd8be2
test_named_dim_numpy and conditionally import jaxtyping internal modules
yxlao Jul 1, 2024
2cae122
comments
yxlao Jul 1, 2024
7971605
numpy test working
yxlao Jul 1, 2024
ce51202
add torch test
yxlao Jul 1, 2024
3fb998e
update version
yxlao Jul 1, 2024
dd0e8e9
switch back after creation
yxlao Jul 1, 2024
557dae0
remove einops dep
yxlao Jul 1, 2024
9d3b98a
organize imports
yxlao Jul 1, 2024
a997768
build: python version support update to 3.8 - 3.11
yxlao Jul 1, 2024
a8839db
Merge branch 'yixing/python-ver' into yixing/ivy
yxlao Jul 1, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/pypi.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.7", "3.8"]
python-version: ["3.8", "3.9", "3.10", "3.11"]

steps:
- name: Set up Python ${{ matrix.python-version }}
Expand Down
10 changes: 8 additions & 2 deletions .github/workflows/unit_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.7", "3.8"]
python-version: ["3.8", "3.9", "3.10", "3.11"]

steps:
- uses: actions/checkout@v3
Expand All @@ -25,6 +25,12 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install -e .[dev]
- name: Run unit tests
- name: Run unit tests (numpy)
run: |
pytest
- name: Install dependencies with torch
run: |
pip install -e .[torch]
- name: Run unit tests (numpy + torch)
run: |
pytest
3 changes: 2 additions & 1 deletion camtools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from . import artifact
from . import backend
from . import camera
from . import colmap
from . import colormap
Expand All @@ -14,9 +15,9 @@
from . import sanity
from . import solver
from . import transform
from . import typing
from . import util


try:
# Python >= 3.8
from importlib.metadata import version
Expand Down
84 changes: 84 additions & 0 deletions camtools/backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import typing
import warnings
from functools import lru_cache, wraps
from inspect import signature
from typing import Literal

import ivy
import jaxtyping

# Internally use "from camtools.backend import ivy" to make sure ivy is imported
# after the warnings filter is set. This is a temporary workaround to suppress
# the deprecation warning from numpy 2.0.
warnings.filterwarnings(
"ignore",
category=DeprecationWarning,
message=".*numpy.core.numeric is deprecated.*",
)
import ivy

_default_backend = "numpy"


@lru_cache(maxsize=None)
def is_torch_available():
try:
import torch

return True
except ImportError:
return False


def set_backend(backend: Literal["numpy", "torch"]) -> None:
"""
Set the default backend for camtools.
"""
global _default_backend
_default_backend = backend


def get_backend() -> str:
"""
Get the default backend for camtools.
"""
return _default_backend


def with_native_backend(func):
"""
1. Enable default camtools backend.
2. Return native backend array (setting array mode to False).
3. Converts lists to tensors if the type hint is a tensor.
"""

@wraps(func)
def wrapper(*args, **kwargs):
og_backend = ivy.current_backend()
ct_backend = get_backend()
ivy.set_backend(ct_backend)

# Unpack args and type hints
sig = signature(func)
bound_args = sig.bind(*args, **kwargs)
arg_name_to_hint = typing.get_type_hints(func)

try:
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=UserWarning)
with ivy.ArrayMode(False):
# Convert list -> native tensor if the type hint is a tensor
for arg_name, arg in bound_args.arguments.items():
if arg_name in arg_name_to_hint and issubclass(
arg_name_to_hint[arg_name], jaxtyping.AbstractArray
):
if isinstance(arg, list):
bound_args.arguments[arg_name] = ivy.native_array(arg)

# Call the function
result = func(*bound_args.args, **bound_args.kwargs)
finally:
ivy.set_backend(og_backend)
return result

return wrapper
164 changes: 164 additions & 0 deletions camtools/typing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
import typing
from distutils.version import LooseVersion
from functools import wraps
from inspect import signature
from typing import Any, Tuple, Union

import jaxtyping
import numpy as np
import pkg_resources

from . import backend

try:
# Try to import _FixedDim and _NamedDim from jaxtyping. This are internal
# classes where the APIs are not stable. There are no guarantees that these
# classes will be available in future versions of jaxtyping.
_jaxtyping_version = pkg_resources.get_distribution("jaxtyping").version
if LooseVersion(_jaxtyping_version) >= LooseVersion("0.2.20"):
from jaxtyping._array_types import _FixedDim, _NamedDim
else:
from jaxtyping.array_types import _FixedDim, _NamedDim
except ImportError:
raise ImportError(
f"Failed to import _FixedDim and _NamedDim. with "
f"jaxtyping version {_jaxtyping_version}."
)


class Tensor:
"""
An abstract tensor type for type hinting only.
Typically np.ndarray or torch.Tensor is supported.
"""

pass


def _dtype_to_str(dtype):
"""
Convert numpy or torch dtype to string

- "bool"
- "bool_"
- "uint4"
- "uint8"
- "uint16"
- "uint32"
- "uint64"
- "int4"
- "int8"
- "int16"
- "int32"
- "int64"
- "bfloat16"
- "float16"
- "float32"
- "float64"
- "complex64"
- "complex128"
"""
if isinstance(dtype, np.dtype):
return dtype.name

if backend.is_torch_available():
import torch

if isinstance(dtype, torch.dtype):
return str(dtype).split(".")[1]

return ValueError(f"Unknown dtype {dtype}.")


def _shape_from_dims(
dims: Tuple[
Union[_FixedDim, _NamedDim],
...,
]
) -> Tuple[Union[int, None], ...]:
shape = []
for dim in dims:
if isinstance(dim, _FixedDim):
shape.append(dim.size)
elif isinstance(dim, _NamedDim):
shape.append(None)
return tuple(shape)


def _is_shape_compatible(
arg_shape: Tuple[Union[int, None], ...],
gt_shape: Tuple[Union[int, None], ...],
) -> bool:
if len(arg_shape) != len(gt_shape):
return False

return all(
arg_dim == gt_dim or gt_dim is None
for arg_dim, gt_dim in zip(arg_shape, gt_shape)
)


def _assert_tensor_hint(
hint: jaxtyping.AbstractArray,
arg: Any,
arg_name: str,
):
"""
Args:
hint: A type hint for a tensor, must be javtyping.AbstractArray.
arg: An argument to check, typically a tensor.
arg_name: The name of the argument, for error messages.
"""
# Check array types.
if backend.is_torch_available():
import torch

valid_array_types = (np.ndarray, torch.Tensor)
else:
valid_array_types = (np.ndarray,)
if not isinstance(arg, valid_array_types):
raise TypeError(
f"{arg_name} must be a tensor of type {valid_array_types}, "
f"but got type {type(arg)}."
)

# Check shapes.
gt_shape = _shape_from_dims(hint.dims)
if not _is_shape_compatible(arg.shape, gt_shape):
raise TypeError(
f"{arg_name} must be a tensor of shape {gt_shape}, "
f"but got shape {arg.shape}."
)

# Check dtype.
gt_dtypes = hint.dtypes
if _dtype_to_str(arg.dtype) not in gt_dtypes:
raise TypeError(
f"{arg_name} must be a tensor of dtype {gt_dtypes}, "
f"but got dtype {_dtype_to_str(arg.dtype)}."
)


def check_shape_and_dtype(func):
"""
A decorator to enforce type and shape specifications as per type hints.
"""

@wraps(func)
def wrapper(*args, **kwargs):
sig = signature(func)
bound_args = sig.bind_partial(*args, **kwargs)
bound_args.apply_defaults()

arg_name_to_arg = bound_args.arguments
arg_name_to_hint = typing.get_type_hints(func)

for arg_name, arg in arg_name_to_arg.items():
if arg_name in arg_name_to_hint:
hint = arg_name_to_hint[arg_name]
if issubclass(hint, jaxtyping.AbstractArray):
_assert_tensor_hint(hint, arg, arg_name)

return func(*args, **kwargs)

return wrapper
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,14 @@ dependencies = [
"matplotlib>=3.3.4",
"scikit-image>=0.16.2",
"tqdm>=4.60.0",
"ivy>=0.0.8.0",
"jaxtyping>=0.2.12",
]
description = "CamTools: Camera Tools for Computer Vision."
license = {text = "MIT"}
name = "camtools"
readme = "README.md"
requires-python = ">=3.6.0"
requires-python = ">=3.8.0"
version = "0.1.4"

[project.scripts]
Expand Down
Loading
Loading