Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: initial infrastructures for multi-backend support #62

Merged
merged 56 commits into from
Jul 1, 2024
Merged

Conversation

yxlao
Copy link
Owner

@yxlao yxlao commented Jul 1, 2024

This PR provides initial infrastructures for multi-backend support and type checking.

Design goals

Goals for multi-backend support

  1. All functions in camtools shall be able to process and return both numpy array and torch tensors. This is done automatically by decorators.
  2. Torch shall NOT be a compulsory dependency of camtools, camtools can detected if torch is installed, and enable the corresponding backends.
  3. By default, the backend is set to numpy. You may set the default backend to torch by set_backend().

Goals for type checking

  1. Functions in camtools shall be type hinted with the expected tensor shape and dtype.
  2. The shape and dtypes are automatically checked and enforced by decorators. This will replace most of the functions in ct.sanity.

Changes of this PR

  • Python version support is expanded to Python 3.8 through Python 3.11. Python 3.7 is removed.
  • Introduces dependencies ivy, jaxtyping, and einops.
    • ivy: interoperate between numpy and torch. Shall only be used internally.
    • jaxtyping: type hinting for the library.
  • Importantly, torch remains an optional dependency.
  • Use ct.backend.get_backend() and ct.backend.set_backend() and to set and get the default backend for camtools, this change is global and will affect all subsequent invocations to camtools functions.
  • @ct.backend.with_native_backend will ensure that the function will use camtools' default backend's native tensor format (np.ndarray or torch.Tensor). It will also convert list inputs to native tensor format if the type hint is a tensor. See test/test_typing.py for examples.
  • @ct.typing.check_shape_and_dtype will enforce shape and dtype checks based on the type hints for tensors. See test/test_typing.py for examples. This check will eventually replace most of the functions in ct.sanity.

Examples

@ct.backend.with_native_backend
@ct.typing.check_shape_and_dtype
def add(
    x: Float[Tensor, "3"],
    y: Float[Tensor, "n 3"],
) -> Float[Tensor, "n 3"]:
    return x + y

@ct.backend.with_native_backend
def creation():
    zeros = ivy.zeros([2, 3])
    return zeros

@yxlao yxlao marked this pull request as ready for review July 1, 2024 10:28
@yxlao yxlao merged commit 34fa74a into main Jul 1, 2024
9 checks passed
@yxlao yxlao deleted the yixing/ivy branch July 1, 2024 10:36
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant