From c48e94f1b90454b6d36db84bdc08c5ac6389a8eb Mon Sep 17 00:00:00 2001 From: Yixing Lao Date: Fri, 28 Jun 2024 15:26:07 +0800 Subject: [PATCH 01/55] rename --- test/test_convert.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_convert.py b/test/test_convert.py index 2d28ce0..af9927a 100644 --- a/test/test_convert.py +++ b/test/test_convert.py @@ -5,7 +5,7 @@ np.set_printoptions(formatter={"float": "{: 0.2f}".format}) -def test_R_t_to_cameracenter(): +def test_R_t_to_C(): T = np.array( [ [0.132521, 0.00567408, 0.991163, 0.0228366], From 3ed192d38522733284ad69628f40be0713af0e0d Mon Sep 17 00:00:00 2001 From: Yixing Lao Date: Fri, 28 Jun 2024 15:26:39 +0800 Subject: [PATCH 02/55] add second test --- test/test_convert.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/test/test_convert.py b/test/test_convert.py index af9927a..e5ec195 100644 --- a/test/test_convert.py +++ b/test/test_convert.py @@ -23,6 +23,34 @@ def test_R_t_to_C(): ) +def R_t_to_C(R, t): + """ + Convert R, t to camera center + """ + t = t.reshape(-1, 3, 1) + R = R.reshape(-1, 3, 3) + C = -R.transpose(0, 2, 1) @ t + return C.squeeze() + + +def test_R_t_to_C_v2(): + T = np.array( + [ + [0.132521, 0.00567408, 0.991163, 0.0228366], + [-0.709094, -0.698155, 0.0988047, 0.535268], + [0.692546, -0.715923, -0.0884969, 16.0856], + [0, 0, 0, 1], + ] + ) + R = T[:3, :3] + t = T[:3, 3] + expected_camera_center = [-10.7635, 11.8896, 1.348] + camera_center = R_t_to_C(R, t) + np.testing.assert_allclose( + expected_camera_center, camera_center, rtol=1e-5, atol=1e-5 + ) + + def test_P_to_K_R_t(): def P_to_K_R_t_manual(P): """ From a56635555963fe45c4470dd4454f1f5e72aae5f1 Mon Sep 17 00:00:00 2001 From: Yixing Lao Date: Fri, 28 Jun 2024 17:40:05 +0800 Subject: [PATCH 03/55] use native backend --- camtools/__init__.py | 1 + camtools/backend.py | 41 +++++++++++++++++++++++++++++++++++++++ test/test_backend.py | 46 ++++++++++++++++++++++++++++++++++++++++++++ test/test_convert.py | 18 ----------------- 4 files changed, 88 insertions(+), 18 deletions(-) create mode 100644 camtools/backend.py create mode 100644 test/test_backend.py diff --git a/camtools/__init__.py b/camtools/__init__.py index 3e40db6..d29cda7 100644 --- a/camtools/__init__.py +++ b/camtools/__init__.py @@ -1,4 +1,5 @@ from . import artifact +from . import backend from . import camera from . import colmap from . import colormap diff --git a/camtools/backend.py b/camtools/backend.py new file mode 100644 index 0000000..9daedfc --- /dev/null +++ b/camtools/backend.py @@ -0,0 +1,41 @@ +from typing import Literal +from functools import wraps +import ivy + +_default_backend = "numpy" + + +def set_backend(backend: Literal["numpy", "torch"]) -> None: + """ + Set the default backend for camtools. + """ + global _default_backend + _default_backend = backend + + +def get_backend() -> str: + """ + Get the default backend for camtools. + """ + return _default_backend + + +def with_backend(func): + """ + Decorator to run a function with: 1) default camtools backend, 2) with + native backend array (setting array mode to False). + """ + + @wraps(func) + def wrapper(*args, **kwargs): + og_backend = ivy.current_backend() + ct_backend = get_backend() + ivy.set_backend(ct_backend) + try: + with ivy.ArrayMode(False): + result = func(*args, **kwargs) + finally: + ivy.set_backend(og_backend) + return result + + return wrapper diff --git a/test/test_backend.py b/test/test_backend.py new file mode 100644 index 0000000..d51f1ce --- /dev/null +++ b/test/test_backend.py @@ -0,0 +1,46 @@ +""" +Test basic usage of ivy and its interaction with numpy and torch. +""" + +import ivy +import numpy as np +import torch +import einops +import camtools as ct + + +@ct.backend.with_backend +def creation(): + zeros = ivy.zeros([2, 3]) + return zeros + + +def test_creation(): + zeros = creation() + import ipdb + + ipdb.set_trace() + pass + + # # Default backend + # zeros = ivy.zeros([2, 3]) + # assert zeros.backend == ivy.current_backend().backend + # assert zeros.backend == "numpy" + # assert zeros.dtype == ivy.float32 + # assert zeros.shape == (2, 3) + # zeros = zeros.to_native() + + # import ipdb + + # ipdb.set_trace() + # pass + + # # Explicit numpy + # ivy.set_backend("numpy") + # zeros = ivy.zeros([2, 3]) + # assert isinstance(zeros.data, np.ndarray) + + # pass + + +ivy.set_array_mode diff --git a/test/test_convert.py b/test/test_convert.py index e5ec195..f229a21 100644 --- a/test/test_convert.py +++ b/test/test_convert.py @@ -33,24 +33,6 @@ def R_t_to_C(R, t): return C.squeeze() -def test_R_t_to_C_v2(): - T = np.array( - [ - [0.132521, 0.00567408, 0.991163, 0.0228366], - [-0.709094, -0.698155, 0.0988047, 0.535268], - [0.692546, -0.715923, -0.0884969, 16.0856], - [0, 0, 0, 1], - ] - ) - R = T[:3, :3] - t = T[:3, 3] - expected_camera_center = [-10.7635, 11.8896, 1.348] - camera_center = R_t_to_C(R, t) - np.testing.assert_allclose( - expected_camera_center, camera_center, rtol=1e-5, atol=1e-5 - ) - - def test_P_to_K_R_t(): def P_to_K_R_t_manual(P): """ From dd2cc2d7ab8baf69d4d2031f52a93278a943b919 Mon Sep 17 00:00:00 2001 From: Yixing Lao Date: Fri, 28 Jun 2024 17:40:37 +0800 Subject: [PATCH 04/55] with native backend --- camtools/backend.py | 2 +- test/test_backend.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/camtools/backend.py b/camtools/backend.py index 9daedfc..689232f 100644 --- a/camtools/backend.py +++ b/camtools/backend.py @@ -20,7 +20,7 @@ def get_backend() -> str: return _default_backend -def with_backend(func): +def with_native_backend(func): """ Decorator to run a function with: 1) default camtools backend, 2) with native backend array (setting array mode to False). diff --git a/test/test_backend.py b/test/test_backend.py index d51f1ce..5c60e76 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -9,7 +9,7 @@ import camtools as ct -@ct.backend.with_backend +@ct.backend.with_native_backend def creation(): zeros = ivy.zeros([2, 3]) return zeros From 8f1d9ee80edc0751c68fa90949694b2c111f9fc3 Mon Sep 17 00:00:00 2001 From: Yixing Lao Date: Sat, 29 Jun 2024 21:55:59 +0800 Subject: [PATCH 05/55] avoid warning --- camtools/backend.py | 13 +++++++++++-- test/test_backend.py | 38 ++++++-------------------------------- 2 files changed, 17 insertions(+), 34 deletions(-) diff --git a/camtools/backend.py b/camtools/backend.py index 689232f..e06d139 100644 --- a/camtools/backend.py +++ b/camtools/backend.py @@ -1,6 +1,7 @@ from typing import Literal from functools import wraps import ivy +import warnings _default_backend = "numpy" @@ -32,8 +33,16 @@ def wrapper(*args, **kwargs): ct_backend = get_backend() ivy.set_backend(ct_backend) try: - with ivy.ArrayMode(False): - result = func(*args, **kwargs) + with warnings.catch_warnings(): + """ + Possible warning: + UserWarning: In the case of Compositional function, operators + might cause inconsistent behavior when array_mode is set to + False. + """ + warnings.simplefilter("ignore", category=UserWarning) + with ivy.ArrayMode(False): + result = func(*args, **kwargs) finally: ivy.set_backend(og_backend) return result diff --git a/test/test_backend.py b/test/test_backend.py index 5c60e76..d69ab27 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -9,38 +9,12 @@ import camtools as ct -@ct.backend.with_native_backend -def creation(): - zeros = ivy.zeros([2, 3]) - return zeros - - def test_creation(): - zeros = creation() - import ipdb - - ipdb.set_trace() - pass - - # # Default backend - # zeros = ivy.zeros([2, 3]) - # assert zeros.backend == ivy.current_backend().backend - # assert zeros.backend == "numpy" - # assert zeros.dtype == ivy.float32 - # assert zeros.shape == (2, 3) - # zeros = zeros.to_native() - - # import ipdb - - # ipdb.set_trace() - # pass - - # # Explicit numpy - # ivy.set_backend("numpy") - # zeros = ivy.zeros([2, 3]) - # assert isinstance(zeros.data, np.ndarray) - - # pass + @ct.backend.with_native_backend + def creation(): + zeros = ivy.zeros([2, 3]) + return zeros -ivy.set_array_mode + tensor = creation() + assert isinstance(tensor, np.ndarray) From ea5cf55c4adda2f9dd5f122c28f1947c6044dc22 Mon Sep 17 00:00:00 2001 From: Yixing Lao Date: Sat, 29 Jun 2024 22:13:41 +0800 Subject: [PATCH 06/55] add jax typing --- test/test_backend.py | 50 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/test/test_backend.py b/test/test_backend.py index d69ab27..91ee513 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -7,14 +7,64 @@ import torch import einops import camtools as ct +from jaxtyping import Array, Float, UInt8 def test_creation(): + """ + Test tensor creation. + """ @ct.backend.with_native_backend def creation(): zeros = ivy.zeros([2, 3]) return zeros + # Default backend is numpy + assert ct.backend.get_backend() == "numpy" tensor = creation() assert isinstance(tensor, np.ndarray) + assert tensor.shape == (2, 3) + assert tensor.dtype == np.float32 + + # Switch to torch backend + ct.backend.set_backend("torch") + assert ct.backend.get_backend() == "torch" + tensor = creation() + assert isinstance(tensor, torch.Tensor) + assert tensor.shape == (2, 3) + assert tensor.dtype == torch.float32 + ct.backend.set_backend("numpy") + + +def test_arguments(): + """ + Test taking arguments from functions. + """ + + @ct.backend.with_native_backend + def add(x, y): + return x + y + + # Default backend is numpy + assert ct.backend.get_backend() == "numpy" + src_x = np.ones([2, 3]) * 2 + src_y = np.ones([1, 3]) * 3 + dst_expected = np.ones([2, 3]) * 5 + dst = add(src_x, src_y) + np.testing.assert_allclose(dst, dst_expected, rtol=1e-5, atol=1e-5) + + # Mixed backend argument should raise error + src_x = np.ones([2, 3]) * 2 + src_y = torch.ones([1, 3]) * 3 + add(src_x, src_y) + + +def test_type_hint_arguments(): + """ + Test type hinting arguments. + """ + + @ct.backend.with_native_backend + def add(x: Float[Array, 2, 3], y: Float[Array, 1, 3]) -> Float[Array, 2, 3]: + return x + y From 042fd299990f49a7f6ded45cd16cc2f02403fbec Mon Sep 17 00:00:00 2001 From: Yixing Lao Date: Sun, 30 Jun 2024 18:29:35 +0800 Subject: [PATCH 07/55] expect raise --- camtools/__init__.py | 1 - test/test_backend.py | 56 +++++++++++++++++++++++++++++++++++++------- 2 files changed, 47 insertions(+), 10 deletions(-) diff --git a/camtools/__init__.py b/camtools/__init__.py index d29cda7..2b0c3e6 100644 --- a/camtools/__init__.py +++ b/camtools/__init__.py @@ -17,7 +17,6 @@ from . import transform from . import util - try: # Python >= 3.8 from importlib.metadata import version diff --git a/test/test_backend.py b/test/test_backend.py index 91ee513..321a3fa 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -7,7 +7,10 @@ import torch import einops import camtools as ct -from jaxtyping import Array, Float, UInt8 +from jaxtyping import Float, UInt8 +import typing +import pytest +from numpy.typing import NDArray def test_creation(): @@ -57,14 +60,49 @@ def add(x, y): # Mixed backend argument should raise error src_x = np.ones([2, 3]) * 2 src_y = torch.ones([1, 3]) * 3 - add(src_x, src_y) + with pytest.raises(TypeError): + add(src_x, src_y) -def test_type_hint_arguments(): - """ - Test type hinting arguments. - """ +# def test_type_hint_arguments(): +# """ +# Test type hinting arguments. +# """ - @ct.backend.with_native_backend - def add(x: Float[Array, 2, 3], y: Float[Array, 1, 3]) -> Float[Array, 2, 3]: - return x + y +# @ct.backend.with_native_backend +# def add( +# x: Float[np.ndarray, "2 3"], y: Float[np.ndarray, "1 3"] +# ) -> Float[np.ndarray, "2 3"]: +# # Extract type hints +# hints = typing.get_type_hints(add) +# x_hint = hints["x"] +# y_hint = hints["y"] + +# # Extract shapes from the type hints +# x_shape = einops.parse_shape(x, x_hint) +# y_shape = einops.parse_shape(y, y_hint) + +# # Verify the input types and shapes +# if not (isinstance(x, (np.ndarray, torch.Tensor)) and x.shape == x_shape): +# raise TypeError(f"x must be a tensor of shape {x_shape}") +# if not (isinstance(y, (np.ndarray, torch.Tensor)) and y.shape == y_shape): +# raise TypeError(f"y must be a tensor of shape {y_shape}") + +# return x + y + +# # Test with correct types and shapes using np.array directly marked as float32 +# x = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32) +# y = np.array([[1, 1, 1]], dtype=np.float32) +# result = add(x, y) +# expected = np.array([[2, 3, 4], [5, 6, 7]], dtype=np.float32) +# assert np.allclose(result, expected, atol=1e-5) + +# # Testing with incorrect shapes +# with pytest.raises(TypeError): +# x_wrong = np.array([[1, 2], [4, 5]], dtype=np.float32) +# add(x_wrong, y) + +# # Testing with incorrect types +# with pytest.raises(TypeError): +# x_wrong_type = [[1, 2, 3], [4, 5, 6]] +# add(x_wrong_type, y) From 56760ff73c3b8e94e9627aa9a67af711f583d795 Mon Sep 17 00:00:00 2001 From: Yixing Lao Date: Sun, 30 Jun 2024 18:40:06 +0800 Subject: [PATCH 08/55] first pasing type hints --- test/test_backend.py | 116 +++++++++++++++++++++++++++---------------- 1 file changed, 74 insertions(+), 42 deletions(-) diff --git a/test/test_backend.py b/test/test_backend.py index 321a3fa..298bb00 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -12,6 +12,8 @@ import pytest from numpy.typing import NDArray +from jaxtyping import Float, _array_types + def test_creation(): """ @@ -64,45 +66,75 @@ def add(x, y): add(src_x, src_y) -# def test_type_hint_arguments(): -# """ -# Test type hinting arguments. -# """ - -# @ct.backend.with_native_backend -# def add( -# x: Float[np.ndarray, "2 3"], y: Float[np.ndarray, "1 3"] -# ) -> Float[np.ndarray, "2 3"]: -# # Extract type hints -# hints = typing.get_type_hints(add) -# x_hint = hints["x"] -# y_hint = hints["y"] - -# # Extract shapes from the type hints -# x_shape = einops.parse_shape(x, x_hint) -# y_shape = einops.parse_shape(y, y_hint) - -# # Verify the input types and shapes -# if not (isinstance(x, (np.ndarray, torch.Tensor)) and x.shape == x_shape): -# raise TypeError(f"x must be a tensor of shape {x_shape}") -# if not (isinstance(y, (np.ndarray, torch.Tensor)) and y.shape == y_shape): -# raise TypeError(f"y must be a tensor of shape {y_shape}") - -# return x + y - -# # Test with correct types and shapes using np.array directly marked as float32 -# x = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32) -# y = np.array([[1, 1, 1]], dtype=np.float32) -# result = add(x, y) -# expected = np.array([[2, 3, 4], [5, 6, 7]], dtype=np.float32) -# assert np.allclose(result, expected, atol=1e-5) - -# # Testing with incorrect shapes -# with pytest.raises(TypeError): -# x_wrong = np.array([[1, 2], [4, 5]], dtype=np.float32) -# add(x_wrong, y) - -# # Testing with incorrect types -# with pytest.raises(TypeError): -# x_wrong_type = [[1, 2, 3], [4, 5, 6]] -# add(x_wrong_type, y) +def get_shape_from_hint(type_hint): + # This function assumes type hints are provided as 'Float[Array, "2 3"]' + # and extracts the shape part as a tuple of integers. + if hasattr(type_hint, "__args__") and type_hint.__args__: + shape_str = type_hint.__args__[1] # Access the shape string + return tuple(map(int, shape_str.split())) + return None + + +def test_type_hint_arguments(): + """ + Test type hinting arguments. + """ + + def add( + x: Float[np.ndarray, "2 3"], y: Float[np.ndarray, "1 3"] + ) -> Float[np.ndarray, "2 3"]: + # Extract type hints + hints = typing.get_type_hints(add) + x_hint = hints["x"] + y_hint = hints["y"] + + # Function to convert dims into shape tuples, handling named and fixed dimensions + def get_shape(dims): + shape = [] + for dim in dims: + if isinstance(dim, _array_types._FixedDim): + shape.append(dim.size) + elif isinstance(dim, _array_types._NamedDim): + shape.append( + None + ) # Use None or another placeholder for variable dimensions + return tuple(shape) + + # Obtain shapes from the type hints' dims attribute + x_shape = get_shape(x_hint.dims) + y_shape = get_shape(y_hint.dims) + + # Verify the input types and shapes + if not (isinstance(x, (np.ndarray, torch.Tensor))): + raise TypeError(f"x must be a tensor") + if not all( + x_dim == shape_dim or shape_dim is None + for x_dim, shape_dim in zip(x.shape, x_shape) + ): + raise TypeError(f"x must be a tensor of shape {x_shape}") + if not (isinstance(y, (np.ndarray, torch.Tensor))): + raise TypeError(f"y must be a tensor") + if not all( + y_dim == shape_dim or shape_dim is None + for y_dim, shape_dim in zip(y.shape, y_shape) + ): + raise TypeError(f"y must be a tensor of shape {y_shape}") + + return x + y + + # Test with correct types and shapes using np.array directly marked as float32 + x = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32) + y = np.array([[1, 1, 1]], dtype=np.float32) + result = add(x, y) + expected = np.array([[2, 3, 4], [5, 6, 7]], dtype=np.float32) + assert np.allclose(result, expected, atol=1e-5) + + # Testing with incorrect shapes + with pytest.raises(TypeError): + x_wrong = np.array([[1, 2], [4, 5]], dtype=np.float32) + add(x_wrong, y) + + # Testing with incorrect types + with pytest.raises(TypeError): + x_wrong_type = [[1, 2, 3], [4, 5, 6]] # not a NumPy array + add(x_wrong_type, y) From 1b72864abec5d0a822c7e3b184912cc4022d8acf Mon Sep 17 00:00:00 2001 From: Yixing Lao Date: Sun, 30 Jun 2024 18:41:15 +0800 Subject: [PATCH 09/55] simplify --- test/test_backend.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/test/test_backend.py b/test/test_backend.py index 298bb00..da68a12 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -80,24 +80,22 @@ def test_type_hint_arguments(): Test type hinting arguments. """ + @ct.backend.with_native_backend def add( - x: Float[np.ndarray, "2 3"], y: Float[np.ndarray, "1 3"] + x: Float[np.ndarray, "2 3"], + y: Float[np.ndarray, "1 3"], ) -> Float[np.ndarray, "2 3"]: - # Extract type hints hints = typing.get_type_hints(add) x_hint = hints["x"] y_hint = hints["y"] - # Function to convert dims into shape tuples, handling named and fixed dimensions def get_shape(dims): shape = [] for dim in dims: if isinstance(dim, _array_types._FixedDim): shape.append(dim.size) elif isinstance(dim, _array_types._NamedDim): - shape.append( - None - ) # Use None or another placeholder for variable dimensions + shape.append(None) return tuple(shape) # Obtain shapes from the type hints' dims attribute From 72436a06420b313fc6c1e1aba62d16fbfcfbda58 Mon Sep 17 00:00:00 2001 From: Yixing Lao Date: Sun, 30 Jun 2024 18:43:42 +0800 Subject: [PATCH 10/55] type check as decorator --- test/test_backend.py | 77 +++++++++++++++++++++++++------------------- 1 file changed, 43 insertions(+), 34 deletions(-) diff --git a/test/test_backend.py b/test/test_backend.py index da68a12..b954aaa 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -12,6 +12,7 @@ import pytest from numpy.typing import NDArray +from functools import wraps from jaxtyping import Float, _array_types @@ -75,52 +76,60 @@ def get_shape_from_hint(type_hint): return None +def check_shape_and_dtype(func): + """ + A decorator to enforce type and shape specifications as per type hints. + """ + + def get_shape(dims): + shape = [] + for dim in dims: + if isinstance(dim, _array_types._FixedDim): + shape.append(dim.size) + elif isinstance(dim, _array_types._NamedDim): + shape.append(None) + return tuple(shape) + + @wraps(func) + def wrapper(*args, **kwargs): + hints = typing.get_type_hints(func) + arg_names = func.__code__.co_varnames[: func.__code__.co_argcount] + + for arg_name, arg_value in zip(arg_names, args): + if arg_name in hints: + hint = hints[arg_name] + expected_shape = get_shape(hint.dims) + + if not (isinstance(arg_value, (np.ndarray, torch.Tensor))): + raise TypeError(f"{arg_name} must be a tensor") + + if not all( + actual_dim == expected_dim or expected_dim is None + for actual_dim, expected_dim in zip(arg_value.shape, expected_shape) + ): + raise TypeError( + f"{arg_name} must be a tensor of shape {expected_shape}" + ) + + return func(*args, **kwargs) + + return wrapper + + def test_type_hint_arguments(): """ Test type hinting arguments. """ @ct.backend.with_native_backend + @check_shape_and_dtype def add( x: Float[np.ndarray, "2 3"], y: Float[np.ndarray, "1 3"], ) -> Float[np.ndarray, "2 3"]: - hints = typing.get_type_hints(add) - x_hint = hints["x"] - y_hint = hints["y"] - - def get_shape(dims): - shape = [] - for dim in dims: - if isinstance(dim, _array_types._FixedDim): - shape.append(dim.size) - elif isinstance(dim, _array_types._NamedDim): - shape.append(None) - return tuple(shape) - - # Obtain shapes from the type hints' dims attribute - x_shape = get_shape(x_hint.dims) - y_shape = get_shape(y_hint.dims) - - # Verify the input types and shapes - if not (isinstance(x, (np.ndarray, torch.Tensor))): - raise TypeError(f"x must be a tensor") - if not all( - x_dim == shape_dim or shape_dim is None - for x_dim, shape_dim in zip(x.shape, x_shape) - ): - raise TypeError(f"x must be a tensor of shape {x_shape}") - if not (isinstance(y, (np.ndarray, torch.Tensor))): - raise TypeError(f"y must be a tensor") - if not all( - y_dim == shape_dim or shape_dim is None - for y_dim, shape_dim in zip(y.shape, y_shape) - ): - raise TypeError(f"y must be a tensor of shape {y_shape}") - return x + y - # Test with correct types and shapes using np.array directly marked as float32 + # Default backend is numpy x = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32) y = np.array([[1, 1, 1]], dtype=np.float32) result = add(x, y) From 5baf108dd0cdaa8afb27d2f4cbca705ebfff8031 Mon Sep 17 00:00:00 2001 From: Yixing Lao Date: Sun, 30 Jun 2024 18:45:40 +0800 Subject: [PATCH 11/55] move to sanity --- camtools/sanity.py | 52 ++++++++++++++++++++++++++++++++++++++++++++ test/test_backend.py | 19 +++++++--------- 2 files changed, 60 insertions(+), 11 deletions(-) diff --git a/camtools/sanity.py b/camtools/sanity.py index 4fe66f0..d3d44d6 100644 --- a/camtools/sanity.py +++ b/camtools/sanity.py @@ -1,4 +1,56 @@ import numpy as np +import numpy as np +import torch +import typing + +from functools import wraps +from jaxtyping import _array_types + + +def check_shape_and_dtype(func): + """ + A decorator to enforce type and shape specifications as per type hints. + """ + + def get_shape(dims): + shape = [] + for dim in dims: + if isinstance(dim, _array_types._FixedDim): + shape.append(dim.size) + elif isinstance(dim, _array_types._NamedDim): + shape.append(None) + return tuple(shape) + + @wraps(func) + def wrapper(*args, **kwargs): + hints = typing.get_type_hints(func) + arg_names = func.__code__.co_varnames[: func.__code__.co_argcount] + + for arg_name, arg_value in zip(arg_names, args): + if arg_name in hints: + hint = hints[arg_name] + expected_shape = get_shape(hint.dims) + + if not (isinstance(arg_value, (np.ndarray, torch.Tensor))): + raise TypeError(f"{arg_name} must be a tensor") + + if not all( + actual_dim == expected_dim or expected_dim is None + for ( + actual_dim, + expected_dim, + ) in zip( + arg_value.shape, + expected_shape, + ) + ): + raise TypeError( + f"{arg_name} must be a tensor of shape {expected_shape}" + ) + + return func(*args, **kwargs) + + return wrapper def assert_numpy(x, name=None): diff --git a/test/test_backend.py b/test/test_backend.py index b954aaa..a7a0a25 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -67,15 +67,6 @@ def add(x, y): add(src_x, src_y) -def get_shape_from_hint(type_hint): - # This function assumes type hints are provided as 'Float[Array, "2 3"]' - # and extracts the shape part as a tuple of integers. - if hasattr(type_hint, "__args__") and type_hint.__args__: - shape_str = type_hint.__args__[1] # Access the shape string - return tuple(map(int, shape_str.split())) - return None - - def check_shape_and_dtype(func): """ A decorator to enforce type and shape specifications as per type hints. @@ -105,7 +96,13 @@ def wrapper(*args, **kwargs): if not all( actual_dim == expected_dim or expected_dim is None - for actual_dim, expected_dim in zip(arg_value.shape, expected_shape) + for ( + actual_dim, + expected_dim, + ) in zip( + arg_value.shape, + expected_shape, + ) ): raise TypeError( f"{arg_name} must be a tensor of shape {expected_shape}" @@ -122,7 +119,7 @@ def test_type_hint_arguments(): """ @ct.backend.with_native_backend - @check_shape_and_dtype + @ct.sanity.check_shape_and_dtype def add( x: Float[np.ndarray, "2 3"], y: Float[np.ndarray, "1 3"], From 2296e6b80dbefb340bc4e4813dc8664c4105f79b Mon Sep 17 00:00:00 2001 From: Yixing Lao Date: Sun, 30 Jun 2024 19:00:15 +0800 Subject: [PATCH 12/55] check shape working --- camtools/sanity.py | 24 +++++++++++------------- test/test_backend.py | 4 ++-- 2 files changed, 13 insertions(+), 15 deletions(-) diff --git a/camtools/sanity.py b/camtools/sanity.py index d3d44d6..f05caea 100644 --- a/camtools/sanity.py +++ b/camtools/sanity.py @@ -5,6 +5,7 @@ from functools import wraps from jaxtyping import _array_types +from typing import Tuple, Union def check_shape_and_dtype(func): @@ -12,7 +13,9 @@ def check_shape_and_dtype(func): A decorator to enforce type and shape specifications as per type hints. """ - def get_shape(dims): + def get_shape( + dims: Tuple[Union[_array_types._FixedDim, _array_types._NamedDim], ...] + ) -> Tuple[Union[int, None], ...]: shape = [] for dim in dims: if isinstance(dim, _array_types._FixedDim): @@ -26,26 +29,21 @@ def wrapper(*args, **kwargs): hints = typing.get_type_hints(func) arg_names = func.__code__.co_varnames[: func.__code__.co_argcount] - for arg_name, arg_value in zip(arg_names, args): + for arg_name, arg in zip(arg_names, args): if arg_name in hints: hint = hints[arg_name] - expected_shape = get_shape(hint.dims) + gt_shape = get_shape(hint.dims) - if not (isinstance(arg_value, (np.ndarray, torch.Tensor))): + if not (isinstance(arg, (np.ndarray, torch.Tensor))): raise TypeError(f"{arg_name} must be a tensor") if not all( - actual_dim == expected_dim or expected_dim is None - for ( - actual_dim, - expected_dim, - ) in zip( - arg_value.shape, - expected_shape, - ) + arg_dim == gt_dim or gt_dim is None + for arg_dim, gt_dim in zip(arg.shape, gt_shape) ): raise TypeError( - f"{arg_name} must be a tensor of shape {expected_shape}" + f"{arg_name} must be a tensor of shape {gt_shape}, " + f"but got shape {arg.shape}." ) return func(*args, **kwargs) diff --git a/test/test_backend.py b/test/test_backend.py index a7a0a25..71e49c4 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -135,8 +135,8 @@ def add( # Testing with incorrect shapes with pytest.raises(TypeError): - x_wrong = np.array([[1, 2], [4, 5]], dtype=np.float32) - add(x_wrong, y) + y_wrong = np.array([[1, 1, 1, 1]], dtype=np.float32) + add(x, y_wrong) # Testing with incorrect types with pytest.raises(TypeError): From a582a36e165c2d5af22150949c0ccaac0d434c92 Mon Sep 17 00:00:00 2001 From: Yixing Lao Date: Sun, 30 Jun 2024 19:19:28 +0800 Subject: [PATCH 13/55] add new and old code --- camtools/sanity.py | 83 ++++++++++++++++++++++++++++++++++++++------ test/test_backend.py | 5 +++ 2 files changed, 77 insertions(+), 11 deletions(-) diff --git a/camtools/sanity.py b/camtools/sanity.py index f05caea..8461ca5 100644 --- a/camtools/sanity.py +++ b/camtools/sanity.py @@ -7,23 +7,26 @@ from jaxtyping import _array_types from typing import Tuple, Union +from typing import Union, Tuple, get_args + + +def get_shape( + dims: Tuple[Union[_array_types._FixedDim, _array_types._NamedDim], ...] +) -> Tuple[Union[int, None], ...]: + shape = [] + for dim in dims: + if isinstance(dim, _array_types._FixedDim): + shape.append(dim.size) + elif isinstance(dim, _array_types._NamedDim): + shape.append(None) + return tuple(shape) + def check_shape_and_dtype(func): """ A decorator to enforce type and shape specifications as per type hints. """ - def get_shape( - dims: Tuple[Union[_array_types._FixedDim, _array_types._NamedDim], ...] - ) -> Tuple[Union[int, None], ...]: - shape = [] - for dim in dims: - if isinstance(dim, _array_types._FixedDim): - shape.append(dim.size) - elif isinstance(dim, _array_types._NamedDim): - shape.append(None) - return tuple(shape) - @wraps(func) def wrapper(*args, **kwargs): hints = typing.get_type_hints(func) @@ -51,6 +54,64 @@ def wrapper(*args, **kwargs): return wrapper +# def get_shape( +# dims: Tuple[Union[_array_types._FixedDim, _array_types._NamedDim], ...] +# ) -> Tuple[Union[int, None], ...]: +# shape = [] +# for dim in dims: +# if isinstance(dim, _array_types._FixedDim): +# shape.append(dim.size) +# elif isinstance(dim, _array_types._NamedDim): +# shape.append(None) +# return tuple(shape) + + +# def check_shape_and_dtype(func): +# @wraps(func) +# def wrapper(*args, **kwargs): +# hints = typing.get_type_hints(func) +# for arg_name, arg in zip( +# func.__code__.co_varnames[: func.__code__.co_argcount], args +# ): +# if arg_name in hints: +# hint = hints[arg_name] +# if getattr(hint, "__origin__", None) is Union: +# possible_types = get_args(hint) +# else: +# possible_types = (hint,) + +# valid = False +# for possible_type in possible_types: + +# # We assume that jaxtyping types wrap around actual types +# if hasattr(possible_type, "__args__") and possible_type.__args__: +# actual_type = ( +# possible_type.__args__[0] +# if possible_type.__args__[0] in {np.ndarray, torch.Tensor} +# else None +# ) +# if actual_type and isinstance(arg, actual_type): +# gt_shape = ( +# get_shape(possible_type.__args__[1].dims) +# if len(possible_type.__args__) > 1 +# else None +# ) +# if gt_shape is None or all( +# arg_dim == gt_dim or gt_dim is None +# for arg_dim, gt_dim in zip(arg.shape, gt_shape) +# ): +# valid = True +# break + +# if not valid: +# raise TypeError( +# f"{arg_name} must match one of the specified types and shapes." +# ) +# return func(*args, **kwargs) + +# return wrapper + + def assert_numpy(x, name=None): if not isinstance(x, np.ndarray): maybe_name = f" {name}" if name is not None else "" diff --git a/test/test_backend.py b/test/test_backend.py index 71e49c4..92bd172 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -15,6 +15,8 @@ from functools import wraps from jaxtyping import Float, _array_types +from typing import Union + def test_creation(): """ @@ -113,6 +115,9 @@ def wrapper(*args, **kwargs): return wrapper +# Tensor = Union[np.ndarray, torch.Tensor] + + def test_type_hint_arguments(): """ Test type hinting arguments. From 9a66380df38d4ec1487cb3ccf61d84ded09e80a5 Mon Sep 17 00:00:00 2001 From: Yixing Lao Date: Sun, 30 Jun 2024 19:22:35 +0800 Subject: [PATCH 14/55] move out assert_tensor_hint --- camtools/sanity.py | 30 +++++++++++++++++------------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/camtools/sanity.py b/camtools/sanity.py index 8461ca5..75f0c05 100644 --- a/camtools/sanity.py +++ b/camtools/sanity.py @@ -22,6 +22,22 @@ def get_shape( return tuple(shape) +def assert_tensor_hint(hint, arg, arg_name): + gt_shape = get_shape(hint.dims) + + if not (isinstance(arg, (np.ndarray, torch.Tensor))): + raise TypeError(f"{arg_name} must be a tensor") + + if not all( + arg_dim == gt_dim or gt_dim is None + for arg_dim, gt_dim in zip(arg.shape, gt_shape) + ): + raise TypeError( + f"{arg_name} must be a tensor of shape {gt_shape}, " + f"but got shape {arg.shape}." + ) + + def check_shape_and_dtype(func): """ A decorator to enforce type and shape specifications as per type hints. @@ -35,19 +51,7 @@ def wrapper(*args, **kwargs): for arg_name, arg in zip(arg_names, args): if arg_name in hints: hint = hints[arg_name] - gt_shape = get_shape(hint.dims) - - if not (isinstance(arg, (np.ndarray, torch.Tensor))): - raise TypeError(f"{arg_name} must be a tensor") - - if not all( - arg_dim == gt_dim or gt_dim is None - for arg_dim, gt_dim in zip(arg.shape, gt_shape) - ): - raise TypeError( - f"{arg_name} must be a tensor of shape {gt_shape}, " - f"but got shape {arg.shape}." - ) + assert_tensor_hint(hint, arg, arg_name) return func(*args, **kwargs) From 167984b4f2d63b5c9984fa1e91e1c629839a0483 Mon Sep 17 00:00:00 2001 From: Yixing Lao Date: Sun, 30 Jun 2024 19:30:49 +0800 Subject: [PATCH 15/55] unpack hints --- camtools/sanity.py | 85 ++++++++++---------------------------------- test/test_backend.py | 8 ++--- 2 files changed, 23 insertions(+), 70 deletions(-) diff --git a/camtools/sanity.py b/camtools/sanity.py index 75f0c05..040ccaa 100644 --- a/camtools/sanity.py +++ b/camtools/sanity.py @@ -23,15 +23,26 @@ def get_shape( def assert_tensor_hint(hint, arg, arg_name): - gt_shape = get_shape(hint.dims) - if not (isinstance(arg, (np.ndarray, torch.Tensor))): - raise TypeError(f"{arg_name} must be a tensor") - - if not all( - arg_dim == gt_dim or gt_dim is None - for arg_dim, gt_dim in zip(arg.shape, gt_shape) - ): + if getattr(hint, "__origin__", None) is Union: + unpacked_hints = get_args(hint) + else: + unpacked_hints = (hint,) + + valid = False + + for unpacked_hint in unpacked_hints: + if not (isinstance(arg, (np.ndarray, torch.Tensor))): + raise TypeError(f"{arg_name} must be a tensor") + gt_shape = get_shape(unpacked_hint.dims) + if all( + arg_dim == gt_dim or gt_dim is None + for arg_dim, gt_dim in zip(arg.shape, gt_shape) + ): + valid = True + break + + if not valid: raise TypeError( f"{arg_name} must be a tensor of shape {gt_shape}, " f"but got shape {arg.shape}." @@ -58,64 +69,6 @@ def wrapper(*args, **kwargs): return wrapper -# def get_shape( -# dims: Tuple[Union[_array_types._FixedDim, _array_types._NamedDim], ...] -# ) -> Tuple[Union[int, None], ...]: -# shape = [] -# for dim in dims: -# if isinstance(dim, _array_types._FixedDim): -# shape.append(dim.size) -# elif isinstance(dim, _array_types._NamedDim): -# shape.append(None) -# return tuple(shape) - - -# def check_shape_and_dtype(func): -# @wraps(func) -# def wrapper(*args, **kwargs): -# hints = typing.get_type_hints(func) -# for arg_name, arg in zip( -# func.__code__.co_varnames[: func.__code__.co_argcount], args -# ): -# if arg_name in hints: -# hint = hints[arg_name] -# if getattr(hint, "__origin__", None) is Union: -# possible_types = get_args(hint) -# else: -# possible_types = (hint,) - -# valid = False -# for possible_type in possible_types: - -# # We assume that jaxtyping types wrap around actual types -# if hasattr(possible_type, "__args__") and possible_type.__args__: -# actual_type = ( -# possible_type.__args__[0] -# if possible_type.__args__[0] in {np.ndarray, torch.Tensor} -# else None -# ) -# if actual_type and isinstance(arg, actual_type): -# gt_shape = ( -# get_shape(possible_type.__args__[1].dims) -# if len(possible_type.__args__) > 1 -# else None -# ) -# if gt_shape is None or all( -# arg_dim == gt_dim or gt_dim is None -# for arg_dim, gt_dim in zip(arg.shape, gt_shape) -# ): -# valid = True -# break - -# if not valid: -# raise TypeError( -# f"{arg_name} must match one of the specified types and shapes." -# ) -# return func(*args, **kwargs) - -# return wrapper - - def assert_numpy(x, name=None): if not isinstance(x, np.ndarray): maybe_name = f" {name}" if name is not None else "" diff --git a/test/test_backend.py b/test/test_backend.py index 92bd172..2fb171f 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -115,7 +115,7 @@ def wrapper(*args, **kwargs): return wrapper -# Tensor = Union[np.ndarray, torch.Tensor] +Tensor = Union[np.ndarray, torch.Tensor] def test_type_hint_arguments(): @@ -126,9 +126,9 @@ def test_type_hint_arguments(): @ct.backend.with_native_backend @ct.sanity.check_shape_and_dtype def add( - x: Float[np.ndarray, "2 3"], - y: Float[np.ndarray, "1 3"], - ) -> Float[np.ndarray, "2 3"]: + x: Float[Tensor, "2 3"], + y: Float[Tensor, "1 3"], + ) -> Float[Tensor, "2 3"]: return x + y # Default backend is numpy From 7bb270e8195a53dd74d5a411f7ce994d651947fb Mon Sep 17 00:00:00 2001 From: Yixing Lao Date: Sun, 30 Jun 2024 23:19:00 +0800 Subject: [PATCH 16/55] check array types and shapes --- camtools/sanity.py | 32 ++++++++++++++++++++------------ 1 file changed, 20 insertions(+), 12 deletions(-) diff --git a/camtools/sanity.py b/camtools/sanity.py index 040ccaa..911c41c 100644 --- a/camtools/sanity.py +++ b/camtools/sanity.py @@ -29,20 +29,28 @@ def assert_tensor_hint(hint, arg, arg_name): else: unpacked_hints = (hint,) - valid = False + # Check array types (e.g. np.ndarray, torch.Tensor, ...) + valid_array_types = tuple( + unpacked_hint.array_type for unpacked_hint in unpacked_hints + ) + if not isinstance(arg, valid_array_types): + raise TypeError( + f"{arg_name} must be a tensor of type {valid_array_types}, " + f"but got type {type(arg)}." + ) + # Check shapes. + gt_shape = get_shape(unpacked_hints[0].dims) for unpacked_hint in unpacked_hints: - if not (isinstance(arg, (np.ndarray, torch.Tensor))): - raise TypeError(f"{arg_name} must be a tensor") - gt_shape = get_shape(unpacked_hint.dims) - if all( - arg_dim == gt_dim or gt_dim is None - for arg_dim, gt_dim in zip(arg.shape, gt_shape) - ): - valid = True - break - - if not valid: + if get_shape(unpacked_hint.dims) != gt_shape: + raise TypeError( + f"Internal error: all shapes in the Union must be the same, " + f"but got {gt_shape} and {get_shape(unpacked_hint.dims)}." + ) + if not all( + arg_dim == gt_dim or gt_dim is None + for arg_dim, gt_dim in zip(arg.shape, gt_shape) + ): raise TypeError( f"{arg_name} must be a tensor of shape {gt_shape}, " f"but got shape {arg.shape}." From dbc3c257c6a77491b158e81d9e16537f90d47282 Mon Sep 17 00:00:00 2001 From: Yixing Lao Date: Sun, 30 Jun 2024 23:33:04 +0800 Subject: [PATCH 17/55] check gt dtypes --- camtools/sanity.py | 40 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/camtools/sanity.py b/camtools/sanity.py index 911c41c..cbaa767 100644 --- a/camtools/sanity.py +++ b/camtools/sanity.py @@ -10,6 +10,37 @@ from typing import Union, Tuple, get_args +def dtype_to_str(dtype): + """ + Convert numpy or torch dtype to string + + - "bool" + - "bool_" + - "uint4" + - "uint8" + - "uint16" + - "uint32" + - "uint64" + - "int4" + - "int8" + - "int16" + - "int32" + - "int64" + - "bfloat16" + - "float16" + - "float32" + - "float64" + - "complex64" + - "complex128" + """ + if isinstance(dtype, np.dtype): + return dtype.name + elif isinstance(dtype, torch.dtype): + return str(dtype).split(".")[1] + else: + raise ValueError(f"Unknown dtype {dtype}.") + + def get_shape( dims: Tuple[Union[_array_types._FixedDim, _array_types._NamedDim], ...] ) -> Tuple[Union[int, None], ...]: @@ -56,6 +87,15 @@ def assert_tensor_hint(hint, arg, arg_name): f"but got shape {arg.shape}." ) + # Check dtype. + gt_dtypes = unpacked_hints[0].dtypes + for unpacked_hint in unpacked_hints: + if unpacked_hint.dtypes != gt_dtypes: + raise TypeError( + f"Internal error: all dtypes in the Union must be the same, " + f"but got {gt_dtypes} and {unpacked_hint.dtypes}." + ) + def check_shape_and_dtype(func): """ From 4c6803e185428a3bb6b918f472919b7328df99cb Mon Sep 17 00:00:00 2001 From: Yixing Lao Date: Mon, 1 Jul 2024 00:32:59 +0800 Subject: [PATCH 18/55] add dependencies --- pyproject.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index b8b2a01..c0745f8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,6 +16,9 @@ dependencies = [ "matplotlib>=3.3.4", "scikit-image>=0.16.2", "tqdm>=4.60.0", + "ivy", + "jaxtyping", + "einops", ] description = "CamTools: Camera Tools for Computer Vision." license = {text = "MIT"} From e902c06293ed958d8c900b1f5546ccc8950ea258 Mon Sep 17 00:00:00 2001 From: Yixing Lao Date: Mon, 1 Jul 2024 01:00:19 +0800 Subject: [PATCH 19/55] supress warning (temp) for numpy 2.0 --- test/test_backend.py | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/test/test_backend.py b/test/test_backend.py index 2fb171f..c8d65b8 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -2,20 +2,25 @@ Test basic usage of ivy and its interaction with numpy and torch. """ +import warnings + +warnings.filterwarnings( + "ignore", + category=DeprecationWarning, + message=".*numpy.core.numeric is deprecated.*", +) + +import typing +from functools import wraps +from typing import Union + import ivy import numpy as np -import torch -import einops -import camtools as ct -from jaxtyping import Float, UInt8 -import typing import pytest -from numpy.typing import NDArray - -from functools import wraps -from jaxtyping import Float, _array_types +import torch +from jaxtyping import Float, UInt8, _array_types -from typing import Union +import camtools as ct def test_creation(): From c5659d7806901be662fa752d2486346df3119b86 Mon Sep 17 00:00:00 2001 From: Yixing Lao Date: Mon, 1 Jul 2024 03:07:38 +0800 Subject: [PATCH 20/55] import ivy from backend --- camtools/backend.py | 12 ++++++++++-- test/test_backend.py | 10 +--------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/camtools/backend.py b/camtools/backend.py index e06d139..8dce3b6 100644 --- a/camtools/backend.py +++ b/camtools/backend.py @@ -1,7 +1,15 @@ -from typing import Literal +import warnings from functools import wraps +from typing import Literal + +# Internally use "from camtools.backend import ivy" to make sure ivy is imported +# after the warnings filter is set. +warnings.filterwarnings( + "ignore", + category=DeprecationWarning, + message=".*numpy.core.numeric is deprecated.*", +) import ivy -import warnings _default_backend = "numpy" diff --git a/test/test_backend.py b/test/test_backend.py index c8d65b8..de2ee4f 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -2,25 +2,17 @@ Test basic usage of ivy and its interaction with numpy and torch. """ -import warnings - -warnings.filterwarnings( - "ignore", - category=DeprecationWarning, - message=".*numpy.core.numeric is deprecated.*", -) - import typing from functools import wraps from typing import Union -import ivy import numpy as np import pytest import torch from jaxtyping import Float, UInt8, _array_types import camtools as ct +from camtools.backend import ivy def test_creation(): From abe343b005bfee04d1a3bbb91832446aa3e50c2f Mon Sep 17 00:00:00 2001 From: Yixing Lao Date: Mon, 1 Jul 2024 03:14:09 +0800 Subject: [PATCH 21/55] clean up --- test/test_backend.py | 46 -------------------------------------------- 1 file changed, 46 deletions(-) diff --git a/test/test_backend.py b/test/test_backend.py index de2ee4f..a923924 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -66,52 +66,6 @@ def add(x, y): add(src_x, src_y) -def check_shape_and_dtype(func): - """ - A decorator to enforce type and shape specifications as per type hints. - """ - - def get_shape(dims): - shape = [] - for dim in dims: - if isinstance(dim, _array_types._FixedDim): - shape.append(dim.size) - elif isinstance(dim, _array_types._NamedDim): - shape.append(None) - return tuple(shape) - - @wraps(func) - def wrapper(*args, **kwargs): - hints = typing.get_type_hints(func) - arg_names = func.__code__.co_varnames[: func.__code__.co_argcount] - - for arg_name, arg_value in zip(arg_names, args): - if arg_name in hints: - hint = hints[arg_name] - expected_shape = get_shape(hint.dims) - - if not (isinstance(arg_value, (np.ndarray, torch.Tensor))): - raise TypeError(f"{arg_name} must be a tensor") - - if not all( - actual_dim == expected_dim or expected_dim is None - for ( - actual_dim, - expected_dim, - ) in zip( - arg_value.shape, - expected_shape, - ) - ): - raise TypeError( - f"{arg_name} must be a tensor of shape {expected_shape}" - ) - - return func(*args, **kwargs) - - return wrapper - - Tensor = Union[np.ndarray, torch.Tensor] From ca0cc667575f0b588f186694cb9588694f51fc1d Mon Sep 17 00:00:00 2001 From: Yixing Lao Date: Mon, 1 Jul 2024 03:24:49 +0800 Subject: [PATCH 22/55] checks dtype --- camtools/sanity.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/camtools/sanity.py b/camtools/sanity.py index cbaa767..90caf1c 100644 --- a/camtools/sanity.py +++ b/camtools/sanity.py @@ -88,13 +88,18 @@ def assert_tensor_hint(hint, arg, arg_name): ) # Check dtype. - gt_dtypes = unpacked_hints[0].dtypes + gt_dtypes = unpacked_hints[0].dtypes # A tuple of dtype names (str) for unpacked_hint in unpacked_hints: if unpacked_hint.dtypes != gt_dtypes: raise TypeError( f"Internal error: all dtypes in the Union must be the same, " f"but got {gt_dtypes} and {unpacked_hint.dtypes}." ) + if dtype_to_str(arg.dtype) not in gt_dtypes: + raise TypeError( + f"{arg_name} must be a tensor of dtype {gt_dtypes}, " + f"but got dtype {dtype_to_str(arg.dtype)}." + ) def check_shape_and_dtype(func): From 96bc15f67fe2fd3523fa29e4dc5d88f81bd04019 Mon Sep 17 00:00:00 2001 From: Yixing Lao Date: Mon, 1 Jul 2024 03:27:48 +0800 Subject: [PATCH 23/55] rename --- camtools/backend.py | 7 +++++-- camtools/sanity.py | 18 +++++++++--------- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/camtools/backend.py b/camtools/backend.py index 8dce3b6..a6ebd69 100644 --- a/camtools/backend.py +++ b/camtools/backend.py @@ -31,8 +31,11 @@ def get_backend() -> str: def with_native_backend(func): """ - Decorator to run a function with: 1) default camtools backend, 2) with - native backend array (setting array mode to False). + Decorator to run a function with: + 1) default camtools backend + 2) with native backend array (setting array mode to False). + + Also, this converts lists to tensors. """ @wraps(func) diff --git a/camtools/sanity.py b/camtools/sanity.py index 90caf1c..17dd99c 100644 --- a/camtools/sanity.py +++ b/camtools/sanity.py @@ -10,7 +10,7 @@ from typing import Union, Tuple, get_args -def dtype_to_str(dtype): +def _dtype_to_str(dtype): """ Convert numpy or torch dtype to string @@ -41,7 +41,7 @@ def dtype_to_str(dtype): raise ValueError(f"Unknown dtype {dtype}.") -def get_shape( +def _shape_from_dims_str( dims: Tuple[Union[_array_types._FixedDim, _array_types._NamedDim], ...] ) -> Tuple[Union[int, None], ...]: shape = [] @@ -53,7 +53,7 @@ def get_shape( return tuple(shape) -def assert_tensor_hint(hint, arg, arg_name): +def _assert_tensor_hint(hint, arg, arg_name): if getattr(hint, "__origin__", None) is Union: unpacked_hints = get_args(hint) @@ -71,12 +71,12 @@ def assert_tensor_hint(hint, arg, arg_name): ) # Check shapes. - gt_shape = get_shape(unpacked_hints[0].dims) + gt_shape = _shape_from_dims_str(unpacked_hints[0].dims) for unpacked_hint in unpacked_hints: - if get_shape(unpacked_hint.dims) != gt_shape: + if _shape_from_dims_str(unpacked_hint.dims) != gt_shape: raise TypeError( f"Internal error: all shapes in the Union must be the same, " - f"but got {gt_shape} and {get_shape(unpacked_hint.dims)}." + f"but got {gt_shape} and {_shape_from_dims_str(unpacked_hint.dims)}." ) if not all( arg_dim == gt_dim or gt_dim is None @@ -95,10 +95,10 @@ def assert_tensor_hint(hint, arg, arg_name): f"Internal error: all dtypes in the Union must be the same, " f"but got {gt_dtypes} and {unpacked_hint.dtypes}." ) - if dtype_to_str(arg.dtype) not in gt_dtypes: + if _dtype_to_str(arg.dtype) not in gt_dtypes: raise TypeError( f"{arg_name} must be a tensor of dtype {gt_dtypes}, " - f"but got dtype {dtype_to_str(arg.dtype)}." + f"but got dtype {_dtype_to_str(arg.dtype)}." ) @@ -115,7 +115,7 @@ def wrapper(*args, **kwargs): for arg_name, arg in zip(arg_names, args): if arg_name in hints: hint = hints[arg_name] - assert_tensor_hint(hint, arg, arg_name) + _assert_tensor_hint(hint, arg, arg_name) return func(*args, **kwargs) From 9c33b2f536eee5816a121b188a4193c82ffb53e4 Mon Sep 17 00:00:00 2001 From: Yixing Lao Date: Mon, 1 Jul 2024 03:45:46 +0800 Subject: [PATCH 24/55] check typecheck hints --- camtools/sanity.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/camtools/sanity.py b/camtools/sanity.py index 17dd99c..f3e2cd9 100644 --- a/camtools/sanity.py +++ b/camtools/sanity.py @@ -8,6 +8,7 @@ from typing import Tuple, Union from typing import Union, Tuple, get_args +import jaxtyping def _dtype_to_str(dtype): @@ -54,12 +55,17 @@ def _shape_from_dims_str( def _assert_tensor_hint(hint, arg, arg_name): - + # Unpack Union types. if getattr(hint, "__origin__", None) is Union: unpacked_hints = get_args(hint) else: unpacked_hints = (hint,) + # If there exists one non jaxtyping hint, skip the check. + for unpacked_hint in unpacked_hints: + if not issubclass(unpacked_hint, jaxtyping.AbstractArray): + return + # Check array types (e.g. np.ndarray, torch.Tensor, ...) valid_array_types = tuple( unpacked_hint.array_type for unpacked_hint in unpacked_hints From f737d531aa110d814e4899cd49e3ac64596fd529 Mon Sep 17 00:00:00 2001 From: Yixing Lao Date: Mon, 1 Jul 2024 03:47:04 +0800 Subject: [PATCH 25/55] handle kwargs --- camtools/sanity.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/camtools/sanity.py b/camtools/sanity.py index f3e2cd9..43667bd 100644 --- a/camtools/sanity.py +++ b/camtools/sanity.py @@ -116,9 +116,12 @@ def check_shape_and_dtype(func): @wraps(func) def wrapper(*args, **kwargs): hints = typing.get_type_hints(func) - arg_names = func.__code__.co_varnames[: func.__code__.co_argcount] + all_args = { + **dict(zip(func.__code__.co_varnames[: func.__code__.co_argcount], args)), + **kwargs, + } - for arg_name, arg in zip(arg_names, args): + for arg_name, arg in all_args.items(): if arg_name in hints: hint = hints[arg_name] _assert_tensor_hint(hint, arg, arg_name) From 2284f9f79138aa7836fc097dacb25edd61e3c39d Mon Sep 17 00:00:00 2001 From: Yixing Lao Date: Mon, 1 Jul 2024 03:48:27 +0800 Subject: [PATCH 26/55] docs --- camtools/backend.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/camtools/backend.py b/camtools/backend.py index a6ebd69..aaaedfc 100644 --- a/camtools/backend.py +++ b/camtools/backend.py @@ -3,7 +3,8 @@ from typing import Literal # Internally use "from camtools.backend import ivy" to make sure ivy is imported -# after the warnings filter is set. +# after the warnings filter is set. This is a temporary workaround to suppress +# the deprecation warning from numpy 2.0. warnings.filterwarnings( "ignore", category=DeprecationWarning, From a7c00bc03b538872b70798dab69800abd49dd1e7 Mon Sep 17 00:00:00 2001 From: Yixing Lao Date: Mon, 1 Jul 2024 04:00:58 +0800 Subject: [PATCH 27/55] update docs --- camtools/backend.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/camtools/backend.py b/camtools/backend.py index aaaedfc..27ab79e 100644 --- a/camtools/backend.py +++ b/camtools/backend.py @@ -32,11 +32,9 @@ def get_backend() -> str: def with_native_backend(func): """ - Decorator to run a function with: - 1) default camtools backend - 2) with native backend array (setting array mode to False). - - Also, this converts lists to tensors. + 1. Enable default camtools backend + 2. Returning native backend array (setting array mode to False). + 3. Converts lists to tensors if the type hint is a tensor. """ @wraps(func) From 9f794034d8021fd59b5fe885a78bcbe7b3b3437a Mon Sep 17 00:00:00 2001 From: Yixing Lao Date: Mon, 1 Jul 2024 14:42:47 +0800 Subject: [PATCH 28/55] move to typing.py --- camtools/__init__.py | 1 + camtools/backend.py | 12 +++- camtools/sanity.py | 130 ---------------------------------------- camtools/typing.py | 140 +++++++++++++++++++++++++++++++++++++++++++ test/test_backend.py | 7 +-- 5 files changed, 154 insertions(+), 136 deletions(-) create mode 100644 camtools/typing.py diff --git a/camtools/__init__.py b/camtools/__init__.py index 2b0c3e6..14fef02 100644 --- a/camtools/__init__.py +++ b/camtools/__init__.py @@ -15,6 +15,7 @@ from . import sanity from . import solver from . import transform +from . import typing from . import util try: diff --git a/camtools/backend.py b/camtools/backend.py index 27ab79e..b3e7e00 100644 --- a/camtools/backend.py +++ b/camtools/backend.py @@ -1,5 +1,5 @@ import warnings -from functools import wraps +from functools import lru_cache, wraps from typing import Literal # Internally use "from camtools.backend import ivy" to make sure ivy is imported @@ -15,6 +15,16 @@ _default_backend = "numpy" +@lru_cache(maxsize=None) +def is_torch_available(): + try: + import torch + + return True + except ImportError: + return False + + def set_backend(backend: Literal["numpy", "torch"]) -> None: """ Set the default backend for camtools. diff --git a/camtools/sanity.py b/camtools/sanity.py index 43667bd..4fe66f0 100644 --- a/camtools/sanity.py +++ b/camtools/sanity.py @@ -1,134 +1,4 @@ import numpy as np -import numpy as np -import torch -import typing - -from functools import wraps -from jaxtyping import _array_types -from typing import Tuple, Union - -from typing import Union, Tuple, get_args -import jaxtyping - - -def _dtype_to_str(dtype): - """ - Convert numpy or torch dtype to string - - - "bool" - - "bool_" - - "uint4" - - "uint8" - - "uint16" - - "uint32" - - "uint64" - - "int4" - - "int8" - - "int16" - - "int32" - - "int64" - - "bfloat16" - - "float16" - - "float32" - - "float64" - - "complex64" - - "complex128" - """ - if isinstance(dtype, np.dtype): - return dtype.name - elif isinstance(dtype, torch.dtype): - return str(dtype).split(".")[1] - else: - raise ValueError(f"Unknown dtype {dtype}.") - - -def _shape_from_dims_str( - dims: Tuple[Union[_array_types._FixedDim, _array_types._NamedDim], ...] -) -> Tuple[Union[int, None], ...]: - shape = [] - for dim in dims: - if isinstance(dim, _array_types._FixedDim): - shape.append(dim.size) - elif isinstance(dim, _array_types._NamedDim): - shape.append(None) - return tuple(shape) - - -def _assert_tensor_hint(hint, arg, arg_name): - # Unpack Union types. - if getattr(hint, "__origin__", None) is Union: - unpacked_hints = get_args(hint) - else: - unpacked_hints = (hint,) - - # If there exists one non jaxtyping hint, skip the check. - for unpacked_hint in unpacked_hints: - if not issubclass(unpacked_hint, jaxtyping.AbstractArray): - return - - # Check array types (e.g. np.ndarray, torch.Tensor, ...) - valid_array_types = tuple( - unpacked_hint.array_type for unpacked_hint in unpacked_hints - ) - if not isinstance(arg, valid_array_types): - raise TypeError( - f"{arg_name} must be a tensor of type {valid_array_types}, " - f"but got type {type(arg)}." - ) - - # Check shapes. - gt_shape = _shape_from_dims_str(unpacked_hints[0].dims) - for unpacked_hint in unpacked_hints: - if _shape_from_dims_str(unpacked_hint.dims) != gt_shape: - raise TypeError( - f"Internal error: all shapes in the Union must be the same, " - f"but got {gt_shape} and {_shape_from_dims_str(unpacked_hint.dims)}." - ) - if not all( - arg_dim == gt_dim or gt_dim is None - for arg_dim, gt_dim in zip(arg.shape, gt_shape) - ): - raise TypeError( - f"{arg_name} must be a tensor of shape {gt_shape}, " - f"but got shape {arg.shape}." - ) - - # Check dtype. - gt_dtypes = unpacked_hints[0].dtypes # A tuple of dtype names (str) - for unpacked_hint in unpacked_hints: - if unpacked_hint.dtypes != gt_dtypes: - raise TypeError( - f"Internal error: all dtypes in the Union must be the same, " - f"but got {gt_dtypes} and {unpacked_hint.dtypes}." - ) - if _dtype_to_str(arg.dtype) not in gt_dtypes: - raise TypeError( - f"{arg_name} must be a tensor of dtype {gt_dtypes}, " - f"but got dtype {_dtype_to_str(arg.dtype)}." - ) - - -def check_shape_and_dtype(func): - """ - A decorator to enforce type and shape specifications as per type hints. - """ - - @wraps(func) - def wrapper(*args, **kwargs): - hints = typing.get_type_hints(func) - all_args = { - **dict(zip(func.__code__.co_varnames[: func.__code__.co_argcount], args)), - **kwargs, - } - - for arg_name, arg in all_args.items(): - if arg_name in hints: - hint = hints[arg_name] - _assert_tensor_hint(hint, arg, arg_name) - - return func(*args, **kwargs) - - return wrapper def assert_numpy(x, name=None): diff --git a/camtools/typing.py b/camtools/typing.py new file mode 100644 index 0000000..f1e9da5 --- /dev/null +++ b/camtools/typing.py @@ -0,0 +1,140 @@ +import typing +from functools import wraps +from typing import Tuple, Union, get_args + +import jaxtyping +import numpy as np +import torch +from jaxtyping import _array_types + +from . import backend + + +class Tensor: + """ + An abstract tensor type for type hinting only. Typically np.ndarray or + torch.Tensor is supported. + """ + + pass + + +def _dtype_to_str(dtype): + """ + Convert numpy or torch dtype to string + + - "bool" + - "bool_" + - "uint4" + - "uint8" + - "uint16" + - "uint32" + - "uint64" + - "int4" + - "int8" + - "int16" + - "int32" + - "int64" + - "bfloat16" + - "float16" + - "float32" + - "float64" + - "complex64" + - "complex128" + """ + if isinstance(dtype, np.dtype): + return dtype.name + elif isinstance(dtype, torch.dtype): + return str(dtype).split(".")[1] + else: + raise ValueError(f"Unknown dtype {dtype}.") + + +def _shape_from_dims_str( + dims: Tuple[Union[_array_types._FixedDim, _array_types._NamedDim], ...] +) -> Tuple[Union[int, None], ...]: + shape = [] + for dim in dims: + if isinstance(dim, _array_types._FixedDim): + shape.append(dim.size) + elif isinstance(dim, _array_types._NamedDim): + shape.append(None) + return tuple(shape) + + +def _assert_tensor_hint(hint, arg, arg_name): + # Unpack Union types. + if getattr(hint, "__origin__", None) is Union: + unpacked_hints = get_args(hint) + else: + unpacked_hints = (hint,) + + # If there exists one non jaxtyping hint, skip the check. + for unpacked_hint in unpacked_hints: + if not issubclass(unpacked_hint, jaxtyping.AbstractArray): + return + + # Check array types. + if backend.is_torch_available(): + valid_array_types = (np.ndarray, torch.Tensor) + else: + valid_array_types = (np.ndarray,) + if not isinstance(arg, valid_array_types): + raise TypeError( + f"{arg_name} must be a tensor of type {valid_array_types}, " + f"but got type {type(arg)}." + ) + + # Check shapes. + gt_shape = _shape_from_dims_str(unpacked_hints[0].dims) + for unpacked_hint in unpacked_hints: + if _shape_from_dims_str(unpacked_hint.dims) != gt_shape: + raise TypeError( + f"Internal error: all shapes in the Union must be the same, " + f"but got {gt_shape} and {_shape_from_dims_str(unpacked_hint.dims)}." + ) + if not all( + arg_dim == gt_dim or gt_dim is None + for arg_dim, gt_dim in zip(arg.shape, gt_shape) + ): + raise TypeError( + f"{arg_name} must be a tensor of shape {gt_shape}, " + f"but got shape {arg.shape}." + ) + + # Check dtype. + gt_dtypes = unpacked_hints[0].dtypes # A tuple of dtype names (str) + for unpacked_hint in unpacked_hints: + if unpacked_hint.dtypes != gt_dtypes: + raise TypeError( + f"Internal error: all dtypes in the Union must be the same, " + f"but got {gt_dtypes} and {unpacked_hint.dtypes}." + ) + if _dtype_to_str(arg.dtype) not in gt_dtypes: + raise TypeError( + f"{arg_name} must be a tensor of dtype {gt_dtypes}, " + f"but got dtype {_dtype_to_str(arg.dtype)}." + ) + + +def check_shape_and_dtype(func): + """ + A decorator to enforce type and shape specifications as per type hints. + """ + + @wraps(func) + def wrapper(*args, **kwargs): + hints = typing.get_type_hints(func) + all_args = { + **dict(zip(func.__code__.co_varnames[: func.__code__.co_argcount], args)), + **kwargs, + } + + for arg_name, arg in all_args.items(): + if arg_name in hints: + hint = hints[arg_name] + _assert_tensor_hint(hint, arg, arg_name) + + return func(*args, **kwargs) + + return wrapper diff --git a/test/test_backend.py b/test/test_backend.py index a923924..b22bbfb 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -12,7 +12,7 @@ from jaxtyping import Float, UInt8, _array_types import camtools as ct -from camtools.backend import ivy +from camtools.backend import ivy, Tensor def test_creation(): @@ -66,16 +66,13 @@ def add(x, y): add(src_x, src_y) -Tensor = Union[np.ndarray, torch.Tensor] - - def test_type_hint_arguments(): """ Test type hinting arguments. """ @ct.backend.with_native_backend - @ct.sanity.check_shape_and_dtype + @ct.typing.check_shape_and_dtype def add( x: Float[Tensor, "2 3"], y: Float[Tensor, "1 3"], From c491aa4a0e5d28d511c5cfc50cc3db4bd419fa14 Mon Sep 17 00:00:00 2001 From: Yixing Lao Date: Mon, 1 Jul 2024 14:45:16 +0800 Subject: [PATCH 29/55] remove union type checks --- camtools/typing.py | 27 ++++----------------------- test/test_backend.py | 3 ++- 2 files changed, 6 insertions(+), 24 deletions(-) diff --git a/camtools/typing.py b/camtools/typing.py index f1e9da5..d7fdc26 100644 --- a/camtools/typing.py +++ b/camtools/typing.py @@ -63,16 +63,9 @@ def _shape_from_dims_str( def _assert_tensor_hint(hint, arg, arg_name): - # Unpack Union types. - if getattr(hint, "__origin__", None) is Union: - unpacked_hints = get_args(hint) - else: - unpacked_hints = (hint,) - # If there exists one non jaxtyping hint, skip the check. - for unpacked_hint in unpacked_hints: - if not issubclass(unpacked_hint, jaxtyping.AbstractArray): - return + if not issubclass(hint, jaxtyping.AbstractArray): + return # Check array types. if backend.is_torch_available(): @@ -86,13 +79,7 @@ def _assert_tensor_hint(hint, arg, arg_name): ) # Check shapes. - gt_shape = _shape_from_dims_str(unpacked_hints[0].dims) - for unpacked_hint in unpacked_hints: - if _shape_from_dims_str(unpacked_hint.dims) != gt_shape: - raise TypeError( - f"Internal error: all shapes in the Union must be the same, " - f"but got {gt_shape} and {_shape_from_dims_str(unpacked_hint.dims)}." - ) + gt_shape = _shape_from_dims_str(hint.dims) if not all( arg_dim == gt_dim or gt_dim is None for arg_dim, gt_dim in zip(arg.shape, gt_shape) @@ -103,13 +90,7 @@ def _assert_tensor_hint(hint, arg, arg_name): ) # Check dtype. - gt_dtypes = unpacked_hints[0].dtypes # A tuple of dtype names (str) - for unpacked_hint in unpacked_hints: - if unpacked_hint.dtypes != gt_dtypes: - raise TypeError( - f"Internal error: all dtypes in the Union must be the same, " - f"but got {gt_dtypes} and {unpacked_hint.dtypes}." - ) + gt_dtypes = hint.dtypes if _dtype_to_str(arg.dtype) not in gt_dtypes: raise TypeError( f"{arg_name} must be a tensor of dtype {gt_dtypes}, " diff --git a/test/test_backend.py b/test/test_backend.py index b22bbfb..1ca5993 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -12,7 +12,8 @@ from jaxtyping import Float, UInt8, _array_types import camtools as ct -from camtools.backend import ivy, Tensor +from camtools.backend import ivy +from camtools.typing import Tensor def test_creation(): From 955f72c0c3f059e45c69087cc7743a59b8bc5c88 Mon Sep 17 00:00:00 2001 From: Yixing Lao Date: Mon, 1 Jul 2024 14:58:10 +0800 Subject: [PATCH 30/55] docs --- camtools/typing.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/camtools/typing.py b/camtools/typing.py index d7fdc26..76a3617 100644 --- a/camtools/typing.py +++ b/camtools/typing.py @@ -1,6 +1,6 @@ import typing from functools import wraps -from typing import Tuple, Union, get_args +from typing import Tuple, Union, Any import jaxtyping import numpy as np @@ -62,11 +62,17 @@ def _shape_from_dims_str( return tuple(shape) -def _assert_tensor_hint(hint, arg, arg_name): - # If there exists one non jaxtyping hint, skip the check. - if not issubclass(hint, jaxtyping.AbstractArray): - return - +def _assert_tensor_hint( + hint: jaxtyping.AbstractArray, + arg: Any, + arg_name: str, +): + """ + Args: + hint: A type hint for a tensor, must be javtyping.AbstractArray. + arg: An argument to check, typically a tensor. + arg_name: The name of the argument, for error messages. + """ # Check array types. if backend.is_torch_available(): valid_array_types = (np.ndarray, torch.Tensor) @@ -110,12 +116,10 @@ def wrapper(*args, **kwargs): **dict(zip(func.__code__.co_varnames[: func.__code__.co_argcount], args)), **kwargs, } - for arg_name, arg in all_args.items(): if arg_name in hints: hint = hints[arg_name] _assert_tensor_hint(hint, arg, arg_name) - return func(*args, **kwargs) return wrapper From d9b9b73a6717c29334235d0c4ff6d166513e5a09 Mon Sep 17 00:00:00 2001 From: Yixing Lao Date: Mon, 1 Jul 2024 15:06:23 +0800 Subject: [PATCH 31/55] better code inspection --- camtools/typing.py | 37 ++++++++++++++++++++++++++++++++----- 1 file changed, 32 insertions(+), 5 deletions(-) diff --git a/camtools/typing.py b/camtools/typing.py index 76a3617..91b07ad 100644 --- a/camtools/typing.py +++ b/camtools/typing.py @@ -7,6 +7,8 @@ import torch from jaxtyping import _array_types +from inspect import signature, Parameter + from . import backend @@ -111,15 +113,40 @@ def check_shape_and_dtype(func): @wraps(func) def wrapper(*args, **kwargs): - hints = typing.get_type_hints(func) - all_args = { + arg_name_to_hint = typing.get_type_hints(func) + arg_name_to_arg = { **dict(zip(func.__code__.co_varnames[: func.__code__.co_argcount], args)), **kwargs, } - for arg_name, arg in all_args.items(): - if arg_name in hints: - hint = hints[arg_name] + + for arg_name, arg in arg_name_to_arg.items(): + if arg_name in arg_name_to_hint: + hint = arg_name_to_hint[arg_name] + _assert_tensor_hint(hint, arg, arg_name) + return func(*args, **kwargs) + + return wrapper + + +def check_shape_and_dtype(func): + """ + A decorator to enforce type and shape specifications as per type hints. + """ + + @wraps(func) + def wrapper(*args, **kwargs): + sig = signature(func) + bound_args = sig.bind_partial(*args, **kwargs) + bound_args.apply_defaults() + + arg_name_to_arg = bound_args.arguments + arg_name_to_hint = typing.get_type_hints(func) + + for arg_name, arg in arg_name_to_arg.items(): + if arg_name in arg_name_to_hint: + hint = arg_name_to_hint[arg_name] _assert_tensor_hint(hint, arg, arg_name) + return func(*args, **kwargs) return wrapper From f5c70168cad9afe86625d49a9420e9013bd19e48 Mon Sep 17 00:00:00 2001 From: Yixing Lao Date: Mon, 1 Jul 2024 15:06:40 +0800 Subject: [PATCH 32/55] better code inspection --- camtools/typing.py | 22 ---------------------- 1 file changed, 22 deletions(-) diff --git a/camtools/typing.py b/camtools/typing.py index 91b07ad..c47dd71 100644 --- a/camtools/typing.py +++ b/camtools/typing.py @@ -106,28 +106,6 @@ def _assert_tensor_hint( ) -def check_shape_and_dtype(func): - """ - A decorator to enforce type and shape specifications as per type hints. - """ - - @wraps(func) - def wrapper(*args, **kwargs): - arg_name_to_hint = typing.get_type_hints(func) - arg_name_to_arg = { - **dict(zip(func.__code__.co_varnames[: func.__code__.co_argcount], args)), - **kwargs, - } - - for arg_name, arg in arg_name_to_arg.items(): - if arg_name in arg_name_to_hint: - hint = arg_name_to_hint[arg_name] - _assert_tensor_hint(hint, arg, arg_name) - return func(*args, **kwargs) - - return wrapper - - def check_shape_and_dtype(func): """ A decorator to enforce type and shape specifications as per type hints. From 4d17d932cf9313bc49c70bf28b1288f7961625db Mon Sep 17 00:00:00 2001 From: Yixing Lao Date: Mon, 1 Jul 2024 15:20:07 +0800 Subject: [PATCH 33/55] comments --- camtools/backend.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/camtools/backend.py b/camtools/backend.py index b3e7e00..7006655 100644 --- a/camtools/backend.py +++ b/camtools/backend.py @@ -1,6 +1,7 @@ import warnings from functools import lru_cache, wraps from typing import Literal +from inspect import signature, Parameter # Internally use "from camtools.backend import ivy" to make sure ivy is imported # after the warnings filter is set. This is a temporary workaround to suppress @@ -53,6 +54,8 @@ def wrapper(*args, **kwargs): ct_backend = get_backend() ivy.set_backend(ct_backend) try: + # If type hint is a tensor, convert (nested) list to tensor. + with warnings.catch_warnings(): """ Possible warning: From cd2bc926e13852ddeeed9b023ac2a639c57f8e60 Mon Sep 17 00:00:00 2001 From: Yixing Lao Date: Mon, 1 Jul 2024 15:23:05 +0800 Subject: [PATCH 34/55] issubclass(hint, jaxtyping.AbstractArray) --- camtools/typing.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/camtools/typing.py b/camtools/typing.py index c47dd71..54cc6d8 100644 --- a/camtools/typing.py +++ b/camtools/typing.py @@ -123,7 +123,8 @@ def wrapper(*args, **kwargs): for arg_name, arg in arg_name_to_arg.items(): if arg_name in arg_name_to_hint: hint = arg_name_to_hint[arg_name] - _assert_tensor_hint(hint, arg, arg_name) + if issubclass(hint, jaxtyping.AbstractArray): + _assert_tensor_hint(hint, arg, arg_name) return func(*args, **kwargs) From 512aec86d7a1e1063d17e424fefef47f2262c4f2 Mon Sep 17 00:00:00 2001 From: Yixing Lao Date: Mon, 1 Jul 2024 15:40:20 +0800 Subject: [PATCH 35/55] initial list->tensor auto convert, dtype check has issues --- camtools/backend.py | 35 +++++++++++++++++++++++------------ test/test_backend.py | 23 +++++++++++++++++++---- 2 files changed, 42 insertions(+), 16 deletions(-) diff --git a/camtools/backend.py b/camtools/backend.py index 7006655..f55287d 100644 --- a/camtools/backend.py +++ b/camtools/backend.py @@ -1,7 +1,11 @@ +import typing import warnings from functools import lru_cache, wraps +from inspect import signature from typing import Literal -from inspect import signature, Parameter +import jaxtyping + +import ivy # Internally use "from camtools.backend import ivy" to make sure ivy is imported # after the warnings filter is set. This is a temporary workaround to suppress @@ -43,8 +47,8 @@ def get_backend() -> str: def with_native_backend(func): """ - 1. Enable default camtools backend - 2. Returning native backend array (setting array mode to False). + 1. Enable default camtools backend. + 2. Return native backend array (setting array mode to False). 3. Converts lists to tensors if the type hint is a tensor. """ @@ -53,19 +57,26 @@ def wrapper(*args, **kwargs): og_backend = ivy.current_backend() ct_backend = get_backend() ivy.set_backend(ct_backend) - try: - # If type hint is a tensor, convert (nested) list to tensor. + # Unpack args and type hints + sig = signature(func) + bound_args = sig.bind(*args, **kwargs) + arg_name_to_hint = typing.get_type_hints(func) + + try: with warnings.catch_warnings(): - """ - Possible warning: - UserWarning: In the case of Compositional function, operators - might cause inconsistent behavior when array_mode is set to - False. - """ warnings.simplefilter("ignore", category=UserWarning) with ivy.ArrayMode(False): - result = func(*args, **kwargs) + # Convert list -> native tensor if the type hint is a tensor + for arg_name, arg in bound_args.arguments.items(): + if arg_name in arg_name_to_hint and issubclass( + arg_name_to_hint[arg_name], jaxtyping.AbstractArray + ): + if isinstance(arg, list): + bound_args.arguments[arg_name] = ivy.native_array(arg) + + # Call the function + result = func(*bound_args.args, **bound_args.kwargs) finally: ivy.set_backend(og_backend) return result diff --git a/test/test_backend.py b/test/test_backend.py index 1ca5993..fc5bef1 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -87,12 +87,27 @@ def add( expected = np.array([[2, 3, 4], [5, 6, 7]], dtype=np.float32) assert np.allclose(result, expected, atol=1e-5) - # Testing with incorrect shapes + # List can be converted to numpy automatically + x = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]] + result = add(x, y) + assert np.allclose(result, expected, atol=1e-5) + + # Incorrect shapes with pytest.raises(TypeError): y_wrong = np.array([[1, 1, 1, 1]], dtype=np.float32) add(x, y_wrong) - # Testing with incorrect types + # Incorrect shape with lists + with pytest.raises(TypeError): + y_wrong = [[1.0, 1.0, 1.0, 1.0]] + add(x, y_wrong) + + # Incorrect dtype with pytest.raises(TypeError): - x_wrong_type = [[1, 2, 3], [4, 5, 6]] # not a NumPy array - add(x_wrong_type, y) + y_wrong = np.array([[1, 1, 1]], dtype=np.int64) + add(x, y_wrong) + + # Incorrect dtype with lists + with pytest.raises(TypeError): + y_wrong = [[1, 1, 1]] + add(x, y_wrong) From 79ec94644d50f04301c24619ebe01aeb3a59b9a2 Mon Sep 17 00:00:00 2001 From: Yixing Lao Date: Mon, 1 Jul 2024 15:44:43 +0800 Subject: [PATCH 36/55] dtype test working --- test/test_backend.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/test_backend.py b/test/test_backend.py index fc5bef1..8f47084 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -93,21 +93,21 @@ def add( assert np.allclose(result, expected, atol=1e-5) # Incorrect shapes - with pytest.raises(TypeError): + with pytest.raises(TypeError, match=r".*but got shape.*"): y_wrong = np.array([[1, 1, 1, 1]], dtype=np.float32) add(x, y_wrong) # Incorrect shape with lists - with pytest.raises(TypeError): + with pytest.raises(TypeError, match=r".*but got shape.*"): y_wrong = [[1.0, 1.0, 1.0, 1.0]] add(x, y_wrong) # Incorrect dtype - with pytest.raises(TypeError): + with pytest.raises(TypeError, match=r".*but got dtype.*"): y_wrong = np.array([[1, 1, 1]], dtype=np.int64) add(x, y_wrong) # Incorrect dtype with lists - with pytest.raises(TypeError): + with pytest.raises(TypeError, match=r".*but got dtype.*"): y_wrong = [[1, 1, 1]] add(x, y_wrong) From 57b3d5c4b196f00bacdf025e38f89d71e2ede79a Mon Sep 17 00:00:00 2001 From: Yixing Lao Date: Mon, 1 Jul 2024 15:47:15 +0800 Subject: [PATCH 37/55] test typing --- test/{test_backend.py => test_typing.py} | 39 ++++++++++++++++++++++++ 1 file changed, 39 insertions(+) rename test/{test_backend.py => test_typing.py} (68%) diff --git a/test/test_backend.py b/test/test_typing.py similarity index 68% rename from test/test_backend.py rename to test/test_typing.py index 8f47084..c8746b8 100644 --- a/test/test_backend.py +++ b/test/test_typing.py @@ -85,11 +85,13 @@ def add( y = np.array([[1, 1, 1]], dtype=np.float32) result = add(x, y) expected = np.array([[2, 3, 4], [5, 6, 7]], dtype=np.float32) + assert isinstance(result, np.ndarray) assert np.allclose(result, expected, atol=1e-5) # List can be converted to numpy automatically x = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]] result = add(x, y) + assert isinstance(result, np.ndarray) assert np.allclose(result, expected, atol=1e-5) # Incorrect shapes @@ -111,3 +113,40 @@ def add( with pytest.raises(TypeError, match=r".*but got dtype.*"): y_wrong = [[1, 1, 1]] add(x, y_wrong) + + # With torch backend + ct.backend.set_backend("torch") + x = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float32) + y = torch.tensor([[1, 1, 1]], dtype=torch.float32) + result = add(x, y) + expected = torch.tensor([[2, 3, 4], [5, 6, 7]], dtype=torch.float32) + assert isinstance(result, torch.Tensor) + assert torch.allclose(result, expected, atol=1e-5) + + # List can be converted to torch automatically + x = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]] + result = add(x, y) + assert isinstance(result, torch.Tensor) + assert torch.allclose(result, expected, atol=1e-5) + + # Incorrect shapes + with pytest.raises(TypeError, match=r".*but got shape.*"): + y_wrong = torch.tensor([[1, 1, 1, 1]], dtype=torch.float32) + add(x, y_wrong) + + # Incorrect shape with lists + with pytest.raises(TypeError, match=r".*but got shape.*"): + y_wrong = [[1.0, 1.0, 1.0, 1.0]] + add(x, y_wrong) + + # Incorrect dtype + with pytest.raises(TypeError, match=r".*but got dtype.*"): + y_wrong = torch.tensor([[1, 1, 1]], dtype=torch.int64) + add(x, y_wrong) + + # Incorrect dtype with lists + with pytest.raises(TypeError, match=r".*but got dtype.*"): + y_wrong = [[1, 1, 1]] + add(x, y_wrong) + + ct.backend.set_backend("numpy") From abddd7e1fb388f9b347984d91938da1f8fc50fff Mon Sep 17 00:00:00 2001 From: Yixing Lao Date: Mon, 1 Jul 2024 15:50:15 +0800 Subject: [PATCH 38/55] format --- camtools/typing.py | 5 ++--- test/test_typing.py | 10 +--------- 2 files changed, 3 insertions(+), 12 deletions(-) diff --git a/camtools/typing.py b/camtools/typing.py index 54cc6d8..494e5d1 100644 --- a/camtools/typing.py +++ b/camtools/typing.py @@ -1,14 +1,13 @@ import typing from functools import wraps -from typing import Tuple, Union, Any +from inspect import signature +from typing import Any, Tuple, Union import jaxtyping import numpy as np import torch from jaxtyping import _array_types -from inspect import signature, Parameter - from . import backend diff --git a/test/test_typing.py b/test/test_typing.py index c8746b8..562397f 100644 --- a/test/test_typing.py +++ b/test/test_typing.py @@ -1,15 +1,7 @@ -""" -Test basic usage of ivy and its interaction with numpy and torch. -""" - -import typing -from functools import wraps -from typing import Union - import numpy as np import pytest import torch -from jaxtyping import Float, UInt8, _array_types +from jaxtyping import Float import camtools as ct from camtools.backend import ivy From 5a1b60c66dafcb631e295e45735babb0b6253021 Mon Sep 17 00:00:00 2001 From: Yixing Lao Date: Mon, 1 Jul 2024 15:54:13 +0800 Subject: [PATCH 39/55] up dependencies --- pyproject.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index c0745f8..bfd3400 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,9 +16,9 @@ dependencies = [ "matplotlib>=3.3.4", "scikit-image>=0.16.2", "tqdm>=4.60.0", - "ivy", - "jaxtyping", - "einops", + "ivy>=0.0.8.0", + "jaxtyping>=0.2.28", + "einops>=0.8.0", ] description = "CamTools: Camera Tools for Computer Vision." license = {text = "MIT"} From 2cc67503de5493f8a139a866fd327bb94610b485 Mon Sep 17 00:00:00 2001 From: Yixing Lao Date: Mon, 1 Jul 2024 15:56:26 +0800 Subject: [PATCH 40/55] undo change --- test/test_convert.py | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/test/test_convert.py b/test/test_convert.py index f229a21..2d28ce0 100644 --- a/test/test_convert.py +++ b/test/test_convert.py @@ -5,7 +5,7 @@ np.set_printoptions(formatter={"float": "{: 0.2f}".format}) -def test_R_t_to_C(): +def test_R_t_to_cameracenter(): T = np.array( [ [0.132521, 0.00567408, 0.991163, 0.0228366], @@ -23,16 +23,6 @@ def test_R_t_to_C(): ) -def R_t_to_C(R, t): - """ - Convert R, t to camera center - """ - t = t.reshape(-1, 3, 1) - R = R.reshape(-1, 3, 3) - C = -R.transpose(0, 2, 1) @ t - return C.squeeze() - - def test_P_to_K_R_t(): def P_to_K_R_t_manual(P): """ From 13d40b37740dedd48bbc560aeb4aa5cbf3be414f Mon Sep 17 00:00:00 2001 From: Yixing Lao Date: Mon, 1 Jul 2024 16:04:57 +0800 Subject: [PATCH 41/55] fix dependencies --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index bfd3400..61bf76f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,7 +17,7 @@ dependencies = [ "scikit-image>=0.16.2", "tqdm>=4.60.0", "ivy>=0.0.8.0", - "jaxtyping>=0.2.28", + "jaxtyping>=0.2.18", "einops>=0.8.0", ] description = "CamTools: Camera Tools for Computer Vision." From 542d3d1e1037bfdeff5b1aa979029ae733b152bd Mon Sep 17 00:00:00 2001 From: Yixing Lao Date: Mon, 1 Jul 2024 16:07:26 +0800 Subject: [PATCH 42/55] change version --- .github/workflows/unit_test.yml | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/unit_test.yml b/.github/workflows/unit_test.yml index 9b02d84..96a6fa0 100644 --- a/.github/workflows/unit_test.yml +++ b/.github/workflows/unit_test.yml @@ -13,7 +13,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.7", "3.8"] + python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"] steps: - uses: actions/checkout@v3 diff --git a/pyproject.toml b/pyproject.toml index 61bf76f..0df4ccb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,7 +17,7 @@ dependencies = [ "scikit-image>=0.16.2", "tqdm>=4.60.0", "ivy>=0.0.8.0", - "jaxtyping>=0.2.18", + "jaxtyping>=0.2.12", "einops>=0.8.0", ] description = "CamTools: Camera Tools for Computer Vision." From 2dd4b1259378bb1297e278ed015b8e7fb1c488dc Mon Sep 17 00:00:00 2001 From: Yixing Lao Date: Mon, 1 Jul 2024 16:09:21 +0800 Subject: [PATCH 43/55] torch --- camtools/typing.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/camtools/typing.py b/camtools/typing.py index 494e5d1..d526e8d 100644 --- a/camtools/typing.py +++ b/camtools/typing.py @@ -5,7 +5,6 @@ import jaxtyping import numpy as np -import torch from jaxtyping import _array_types from . import backend @@ -13,8 +12,8 @@ class Tensor: """ - An abstract tensor type for type hinting only. Typically np.ndarray or - torch.Tensor is supported. + An abstract tensor type for type hinting only. + Typically np.ndarray or torch.Tensor is supported. """ pass @@ -45,10 +44,14 @@ def _dtype_to_str(dtype): """ if isinstance(dtype, np.dtype): return dtype.name - elif isinstance(dtype, torch.dtype): - return str(dtype).split(".")[1] - else: - raise ValueError(f"Unknown dtype {dtype}.") + + if backend.is_torch_available(): + import torch + + if isinstance(dtype, torch.dtype): + return str(dtype).split(".")[1] + + return ValueError(f"Unknown dtype {dtype}.") def _shape_from_dims_str( @@ -76,6 +79,8 @@ def _assert_tensor_hint( """ # Check array types. if backend.is_torch_available(): + import torch + valid_array_types = (np.ndarray, torch.Tensor) else: valid_array_types = (np.ndarray,) From 098f9ddbe6aeef3258ec478d06f3205e32edf372 Mon Sep 17 00:00:00 2001 From: Yixing Lao Date: Mon, 1 Jul 2024 16:14:28 +0800 Subject: [PATCH 44/55] separate torch test with numpy test --- test/test_typing.py | 59 +++++++++++++++++++++++++++++++++------------ 1 file changed, 43 insertions(+), 16 deletions(-) diff --git a/test/test_typing.py b/test/test_typing.py index 562397f..aa8dc02 100644 --- a/test/test_typing.py +++ b/test/test_typing.py @@ -1,18 +1,13 @@ import numpy as np import pytest -import torch from jaxtyping import Float import camtools as ct -from camtools.backend import ivy +from camtools.backend import ivy, is_torch_available from camtools.typing import Tensor -def test_creation(): - """ - Test tensor creation. - """ - +def test_creation_numpy(): @ct.backend.with_native_backend def creation(): zeros = ivy.zeros([2, 3]) @@ -25,6 +20,16 @@ def creation(): assert tensor.shape == (2, 3) assert tensor.dtype == np.float32 + +@pytest.mark.skipif(not is_torch_available(), reason="Torch is not available") +def test_creation_torch(): + import torch + + @ct.backend.with_native_backend + def creation(): + zeros = ivy.zeros([2, 3]) + return zeros + # Switch to torch backend ct.backend.set_backend("torch") assert ct.backend.get_backend() == "torch" @@ -35,10 +40,23 @@ def creation(): ct.backend.set_backend("numpy") -def test_arguments(): - """ - Test taking arguments from functions. - """ +def test_arguments_numpy(): + @ct.backend.with_native_backend + def add(x, y): + return x + y + + # Default backend is numpy + assert ct.backend.get_backend() == "numpy" + src_x = np.ones([2, 3]) * 2 + src_y = np.ones([1, 3]) * 3 + dst_expected = np.ones([2, 3]) * 5 + dst = add(src_x, src_y) + np.testing.assert_allclose(dst, dst_expected, rtol=1e-5, atol=1e-5) + + +@pytest.mark.skipif(not is_torch_available(), reason="Torch is not available") +def test_arguments_torch(): + import torch @ct.backend.with_native_backend def add(x, y): @@ -59,11 +77,7 @@ def add(x, y): add(src_x, src_y) -def test_type_hint_arguments(): - """ - Test type hinting arguments. - """ - +def test_type_hint_arguments_numpy(): @ct.backend.with_native_backend @ct.typing.check_shape_and_dtype def add( @@ -106,6 +120,19 @@ def add( y_wrong = [[1, 1, 1]] add(x, y_wrong) + +@pytest.mark.skipif(not is_torch_available(), reason="Torch is not available") +def test_type_hint_arguments_torch(): + import torch + + @ct.backend.with_native_backend + @ct.typing.check_shape_and_dtype + def add( + x: Float[Tensor, "2 3"], + y: Float[Tensor, "1 3"], + ) -> Float[Tensor, "2 3"]: + return x + y + # With torch backend ct.backend.set_backend("torch") x = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float32) From dea9ffa183a294469f455c8883ee1e6dba1be3ad Mon Sep 17 00:00:00 2001 From: Yixing Lao Date: Mon, 1 Jul 2024 16:15:49 +0800 Subject: [PATCH 45/55] update tests --- .github/workflows/unit_test.yml | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/.github/workflows/unit_test.yml b/.github/workflows/unit_test.yml index 96a6fa0..5fbab30 100644 --- a/.github/workflows/unit_test.yml +++ b/.github/workflows/unit_test.yml @@ -25,6 +25,12 @@ jobs: run: | python -m pip install --upgrade pip pip install -e .[dev] - - name: Run unit tests + - name: Run unit tests (numpy only) + run: | + pytest + - name: Install dependencies with torch + run: | + pip install -e .[torch] + - name: Run unit tests (with torch) run: | pytest From 3e5e33b0cbd8e542fa9fe38544f70d2c84b0e1c1 Mon Sep 17 00:00:00 2001 From: Yixing Lao Date: Mon, 1 Jul 2024 16:20:46 +0800 Subject: [PATCH 46/55] reduce version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 0df4ccb..026d129 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,7 +18,7 @@ dependencies = [ "tqdm>=4.60.0", "ivy>=0.0.8.0", "jaxtyping>=0.2.12", - "einops>=0.8.0", + "einops>=0.6.1", ] description = "CamTools: Camera Tools for Computer Vision." license = {text = "MIT"} From cdd8be20ec8d066e99cc361197eb831247b2c344 Mon Sep 17 00:00:00 2001 From: Yixing Lao Date: Mon, 1 Jul 2024 17:57:22 +0800 Subject: [PATCH 47/55] test_named_dim_numpy and conditionally import jaxtyping internal modules --- .github/workflows/unit_test.yml | 2 +- camtools/typing.py | 29 +++++++++++++++++----- pyproject.toml | 2 +- test/test_typing.py | 43 +++++++++++++++++++++++++++++++++ 4 files changed, 68 insertions(+), 8 deletions(-) diff --git a/.github/workflows/unit_test.yml b/.github/workflows/unit_test.yml index 5fbab30..b1d81f4 100644 --- a/.github/workflows/unit_test.yml +++ b/.github/workflows/unit_test.yml @@ -13,7 +13,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"] + python-version: ["3.8", "3.9", "3.10", "3.11"] steps: - uses: actions/checkout@v3 diff --git a/camtools/typing.py b/camtools/typing.py index d526e8d..3ddb8a6 100644 --- a/camtools/typing.py +++ b/camtools/typing.py @@ -1,15 +1,29 @@ import typing +from distutils.version import LooseVersion from functools import wraps from inspect import signature from typing import Any, Tuple, Union import jaxtyping import numpy as np -from jaxtyping import _array_types +import pkg_resources from . import backend +try: + _jaxtyping_version = pkg_resources.get_distribution("jaxtyping").version + if LooseVersion(_jaxtyping_version) >= LooseVersion("0.2.20"): + from jaxtyping._array_types import _FixedDim, _NamedDim + else: + from jaxtyping.array_types import _FixedDim, _NamedDim +except ImportError: + raise ImportError( + f"Failed to import _FixedDim and _NamedDim. with " + f"jaxtyping version {_jaxtyping_version}." + ) + + class Tensor: """ An abstract tensor type for type hinting only. @@ -54,14 +68,17 @@ def _dtype_to_str(dtype): return ValueError(f"Unknown dtype {dtype}.") -def _shape_from_dims_str( - dims: Tuple[Union[_array_types._FixedDim, _array_types._NamedDim], ...] +def _shape_from_dims( + dims: Tuple[ + Union[_FixedDim, _NamedDim], + ..., + ] ) -> Tuple[Union[int, None], ...]: shape = [] for dim in dims: - if isinstance(dim, _array_types._FixedDim): + if isinstance(dim, _FixedDim): shape.append(dim.size) - elif isinstance(dim, _array_types._NamedDim): + elif isinstance(dim, _NamedDim): shape.append(None) return tuple(shape) @@ -91,7 +108,7 @@ def _assert_tensor_hint( ) # Check shapes. - gt_shape = _shape_from_dims_str(hint.dims) + gt_shape = _shape_from_dims(hint.dims) if not all( arg_dim == gt_dim or gt_dim is None for arg_dim, gt_dim in zip(arg.shape, gt_shape) diff --git a/pyproject.toml b/pyproject.toml index 026d129..bdd2ede 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,7 @@ description = "CamTools: Camera Tools for Computer Vision." license = {text = "MIT"} name = "camtools" readme = "README.md" -requires-python = ">=3.6.0" +requires-python = ">=3.8.0" version = "0.1.4" [project.scripts] diff --git a/test/test_typing.py b/test/test_typing.py index aa8dc02..ed74120 100644 --- a/test/test_typing.py +++ b/test/test_typing.py @@ -169,3 +169,46 @@ def add( add(x, y_wrong) ct.backend.set_backend("numpy") + + +def test_named_dim_numpy(): + @ct.backend.with_native_backend + @ct.typing.check_shape_and_dtype + def add( + x: Float[Tensor, "3"], + y: Float[Tensor, "n 3"], + ) -> Float[Tensor, "n 3"]: + return x + y + + # Fixed x tensor + x = np.array([1.0, 2.0, 3.0], dtype=np.float32) + + # Valid y tensor with shape (1, 3) + y = np.array([[4.0, 5.0, 6.0]], dtype=np.float32) + result = add(x, y) + expected = np.array([[5.0, 7.0, 9.0]], dtype=np.float32) + assert isinstance(result, np.ndarray) + assert np.allclose(result, expected, atol=1e-5) + + # Valid y tensor with shape (2, 3) + y = np.array([[4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], dtype=np.float32) + result = add(x, y) + expected = np.array([[5.0, 7.0, 9.0], [8.0, 10.0, 12.0]], dtype=np.float32) + assert isinstance(result, np.ndarray) + assert np.allclose(result, expected, atol=1e-5) + + # Test for a shape mismatch where y does not conform to "n 3" + with pytest.raises(TypeError, match=r".*but got shape \(3,\).*"): + y_wrong = np.array([4.0, 5.0, 6.0], dtype=np.float32) # Shape (3,) + add(x, y_wrong) + + # List inputs that should be automatically converted and work + y = [[4.0, 5.0, 6.0], [7.0, 8.0, 9.0]] + result = add(x, y) + assert isinstance(result, np.ndarray) + assert np.allclose(result, expected, atol=1e-5) + + # Incorrect dtype with lists, expect dtype error + with pytest.raises(TypeError, match=r".*but got dtype.*"): + y_wrong = [[4, 5, 6], [7, 8, 9]] # int type elements in list + add(x, y_wrong) From 2cae122b019fd304a26fb2be8a21370c7a588ab0 Mon Sep 17 00:00:00 2001 From: Yixing Lao Date: Mon, 1 Jul 2024 17:58:46 +0800 Subject: [PATCH 48/55] comments --- camtools/typing.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/camtools/typing.py b/camtools/typing.py index 3ddb8a6..65356e2 100644 --- a/camtools/typing.py +++ b/camtools/typing.py @@ -12,6 +12,9 @@ try: + # Try to import _FixedDim and _NamedDim from jaxtyping. This are internal + # classes where the APIs are not stable. There are no guarantees that these + # classes will be available in future versions of jaxtyping. _jaxtyping_version = pkg_resources.get_distribution("jaxtyping").version if LooseVersion(_jaxtyping_version) >= LooseVersion("0.2.20"): from jaxtyping._array_types import _FixedDim, _NamedDim From 7971605a1b77576a90c25362934c54ed35f1e0c7 Mon Sep 17 00:00:00 2001 From: Yixing Lao Date: Mon, 1 Jul 2024 18:02:15 +0800 Subject: [PATCH 49/55] numpy test working --- camtools/typing.py | 18 ++++++++++++++---- test/test_typing.py | 2 +- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/camtools/typing.py b/camtools/typing.py index 65356e2..8d94b24 100644 --- a/camtools/typing.py +++ b/camtools/typing.py @@ -86,6 +86,19 @@ def _shape_from_dims( return tuple(shape) +def _is_shape_compatible( + arg_shape: Tuple[Union[int, None], ...], + gt_shape: Tuple[Union[int, None], ...], +) -> bool: + if len(arg_shape) != len(gt_shape): + return False + + return all( + arg_dim == gt_dim or gt_dim is None + for arg_dim, gt_dim in zip(arg_shape, gt_shape) + ) + + def _assert_tensor_hint( hint: jaxtyping.AbstractArray, arg: Any, @@ -112,10 +125,7 @@ def _assert_tensor_hint( # Check shapes. gt_shape = _shape_from_dims(hint.dims) - if not all( - arg_dim == gt_dim or gt_dim is None - for arg_dim, gt_dim in zip(arg.shape, gt_shape) - ): + if not _is_shape_compatible(arg.shape, gt_shape): raise TypeError( f"{arg_name} must be a tensor of shape {gt_shape}, " f"but got shape {arg.shape}." diff --git a/test/test_typing.py b/test/test_typing.py index ed74120..45a4fa4 100644 --- a/test/test_typing.py +++ b/test/test_typing.py @@ -198,7 +198,7 @@ def add( assert np.allclose(result, expected, atol=1e-5) # Test for a shape mismatch where y does not conform to "n 3" - with pytest.raises(TypeError, match=r".*but got shape \(3,\).*"): + with pytest.raises(TypeError, match=r".*but got shape.*"): y_wrong = np.array([4.0, 5.0, 6.0], dtype=np.float32) # Shape (3,) add(x, y_wrong) From ce51202d4aa95c2f5974b76e14d51f5921438d61 Mon Sep 17 00:00:00 2001 From: Yixing Lao Date: Mon, 1 Jul 2024 18:04:36 +0800 Subject: [PATCH 50/55] add torch test --- test/test_typing.py | 52 ++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 49 insertions(+), 3 deletions(-) diff --git a/test/test_typing.py b/test/test_typing.py index 45a4fa4..cb16396 100644 --- a/test/test_typing.py +++ b/test/test_typing.py @@ -21,7 +21,7 @@ def creation(): assert tensor.dtype == np.float32 -@pytest.mark.skipif(not is_torch_available(), reason="Torch is not available") +@pytest.mark.skipif(not ct.backend.is_torch_available(), reason="Skip torch") def test_creation_torch(): import torch @@ -54,7 +54,7 @@ def add(x, y): np.testing.assert_allclose(dst, dst_expected, rtol=1e-5, atol=1e-5) -@pytest.mark.skipif(not is_torch_available(), reason="Torch is not available") +@pytest.mark.skipif(not ct.backend.is_torch_available(), reason="Skip torch") def test_arguments_torch(): import torch @@ -121,7 +121,7 @@ def add( add(x, y_wrong) -@pytest.mark.skipif(not is_torch_available(), reason="Torch is not available") +@pytest.mark.skipif(not ct.backend.is_torch_available(), reason="Skip torch") def test_type_hint_arguments_torch(): import torch @@ -212,3 +212,49 @@ def add( with pytest.raises(TypeError, match=r".*but got dtype.*"): y_wrong = [[4, 5, 6], [7, 8, 9]] # int type elements in list add(x, y_wrong) + + +@pytest.mark.skipif(not ct.backend.is_torch_available(), reason="Skip torch") +def test_named_dim_torch(): + import torch + + @ct.backend.with_native_backend + @ct.typing.check_shape_and_dtype + def add( + x: Float[Tensor, "3"], + y: Float[Tensor, "n 3"], + ) -> Float[Tensor, "n 3"]: + return x + y + + # Fixed x tensor for Torch + x = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32) + + # Valid y tensor with shape (1, 3) + y = torch.tensor([[4.0, 5.0, 6.0]], dtype=torch.float32) + result = add(x, y) + expected = torch.tensor([[5.0, 7.0, 9.0]], dtype=torch.float32) + assert isinstance(result, torch.Tensor) + assert torch.allclose(result, expected, atol=1e-5) + + # Valid y tensor with shape (2, 3) + y = torch.tensor([[4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], dtype=torch.float32) + result = add(x, y) + expected = torch.tensor([[5.0, 7.0, 9.0], [8.0, 10.0, 12.0]], dtype=torch.float32) + assert isinstance(result, torch.Tensor) + assert torch.allclose(result, expected, atol=1e-5) + + # Test for a shape mismatch where y does not conform to "n 3" + with pytest.raises(TypeError, match=r".*but got shape.*"): + y_wrong = torch.tensor([4.0, 5.0, 6.0], dtype=torch.float32) # Shape (3,) + add(x, y_wrong) + + # List inputs that should be automatically converted and work + y = [[4.0, 5.0, 6.0], [7.0, 8.0, 9.0]] + result = add(x, y) + assert isinstance(result, torch.Tensor) + assert torch.allclose(result, expected, atol=1e-5) + + # Incorrect dtype with lists, expect dtype error + with pytest.raises(TypeError, match=r".*but got dtype.*"): + y_wrong = [[4, 5, 6], [7, 8, 9]] # int type elements in list + add(x, y_wrong) From 3fb998e3c27b8ba35998f5ea8bec6837d2873a8a Mon Sep 17 00:00:00 2001 From: Yixing Lao Date: Mon, 1 Jul 2024 18:05:59 +0800 Subject: [PATCH 51/55] update version --- .github/workflows/pypi.yml | 2 +- .github/workflows/unit_test.yml | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/pypi.yml b/.github/workflows/pypi.yml index baa7130..9cf7920 100644 --- a/.github/workflows/pypi.yml +++ b/.github/workflows/pypi.yml @@ -13,7 +13,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.7", "3.8"] + python-version: ["3.8", "3.9", "3.10", "3.11"] steps: - name: Set up Python ${{ matrix.python-version }} diff --git a/.github/workflows/unit_test.yml b/.github/workflows/unit_test.yml index b1d81f4..1634eb1 100644 --- a/.github/workflows/unit_test.yml +++ b/.github/workflows/unit_test.yml @@ -25,12 +25,12 @@ jobs: run: | python -m pip install --upgrade pip pip install -e .[dev] - - name: Run unit tests (numpy only) + - name: Run unit tests (numpy) run: | pytest - name: Install dependencies with torch run: | pip install -e .[torch] - - name: Run unit tests (with torch) + - name: Run unit tests (numpy + torch) run: | pytest From dd0e8e99d869e231774fab14fbfd8ec710f7892a Mon Sep 17 00:00:00 2001 From: Yixing Lao Date: Mon, 1 Jul 2024 18:07:25 +0800 Subject: [PATCH 52/55] switch back after creation --- test/test_typing.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/test/test_typing.py b/test/test_typing.py index cb16396..b396480 100644 --- a/test/test_typing.py +++ b/test/test_typing.py @@ -39,6 +39,11 @@ def creation(): assert tensor.dtype == torch.float32 ct.backend.set_backend("numpy") + # Now, the default backend is numpy + assert ct.backend.get_backend() == "numpy" + tensor = creation() + assert isinstance(tensor, np.ndarray) + def test_arguments_numpy(): @ct.backend.with_native_backend From 557dae0745c4d18490d2525d7de7493ad036a515 Mon Sep 17 00:00:00 2001 From: Yixing Lao Date: Mon, 1 Jul 2024 18:17:50 +0800 Subject: [PATCH 53/55] remove einops dep --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index bdd2ede..dfa6d81 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,7 +18,6 @@ dependencies = [ "tqdm>=4.60.0", "ivy>=0.0.8.0", "jaxtyping>=0.2.12", - "einops>=0.6.1", ] description = "CamTools: Camera Tools for Computer Vision." license = {text = "MIT"} From 9d3b98a1f780d0c4f52140231d4d5555d57a8ea0 Mon Sep 17 00:00:00 2001 From: Yixing Lao Date: Mon, 1 Jul 2024 18:21:29 +0800 Subject: [PATCH 54/55] organize imports --- camtools/backend.py | 2 +- camtools/typing.py | 1 - test/test_typing.py | 2 +- 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/camtools/backend.py b/camtools/backend.py index f55287d..3a7d290 100644 --- a/camtools/backend.py +++ b/camtools/backend.py @@ -3,9 +3,9 @@ from functools import lru_cache, wraps from inspect import signature from typing import Literal -import jaxtyping import ivy +import jaxtyping # Internally use "from camtools.backend import ivy" to make sure ivy is imported # after the warnings filter is set. This is a temporary workaround to suppress diff --git a/camtools/typing.py b/camtools/typing.py index 8d94b24..b4fff79 100644 --- a/camtools/typing.py +++ b/camtools/typing.py @@ -10,7 +10,6 @@ from . import backend - try: # Try to import _FixedDim and _NamedDim from jaxtyping. This are internal # classes where the APIs are not stable. There are no guarantees that these diff --git a/test/test_typing.py b/test/test_typing.py index b396480..0b43c4b 100644 --- a/test/test_typing.py +++ b/test/test_typing.py @@ -3,7 +3,7 @@ from jaxtyping import Float import camtools as ct -from camtools.backend import ivy, is_torch_available +from camtools.backend import ivy from camtools.typing import Tensor From a99776869ea01af91cf5bb23725c1ee72fcd9ed1 Mon Sep 17 00:00:00 2001 From: Yixing Lao Date: Mon, 1 Jul 2024 18:25:30 +0800 Subject: [PATCH 55/55] build: python version support update to 3.8 - 3.11 --- .github/workflows/pypi.yml | 2 +- .github/workflows/unit_test.yml | 2 +- pyproject.toml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/pypi.yml b/.github/workflows/pypi.yml index baa7130..9cf7920 100644 --- a/.github/workflows/pypi.yml +++ b/.github/workflows/pypi.yml @@ -13,7 +13,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.7", "3.8"] + python-version: ["3.8", "3.9", "3.10", "3.11"] steps: - name: Set up Python ${{ matrix.python-version }} diff --git a/.github/workflows/unit_test.yml b/.github/workflows/unit_test.yml index 9b02d84..7028b80 100644 --- a/.github/workflows/unit_test.yml +++ b/.github/workflows/unit_test.yml @@ -13,7 +13,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.7", "3.8"] + python-version: ["3.8", "3.9", "3.10", "3.11"] steps: - uses: actions/checkout@v3 diff --git a/pyproject.toml b/pyproject.toml index b8b2a01..698c58c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,7 @@ description = "CamTools: Camera Tools for Computer Vision." license = {text = "MIT"} name = "camtools" readme = "README.md" -requires-python = ">=3.6.0" +requires-python = ">=3.8.0" version = "0.1.4" [project.scripts]