Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: automatic tensor backend and type checks #64

Merged
merged 32 commits into from
Jul 3, 2024

Conversation

yxlao
Copy link
Owner

@yxlao yxlao commented Jul 2, 2024

Design Goals

Goals for Multi-Backend Support

  1. All functions in camtools should be able to process and return both NumPy arrays and Torch tensors. This is achieved automatically by decorators.
  2. Torch should NOT be a compulsory dependency of camtools. The library can detect if Torch is installed and enable the corresponding backends.
  3. By default, the backend is set to NumPy. You can set the default backend to Torch by using set_backend().

Goals for Type Checking

  1. Functions in camtools should be type-hinted with the expected tensor shape and dtype.
  2. The shape and dtypes are automatically checked and enforced by decorators. This will replace most of the functions in ct.sanity.

Changes in This PR

  • New Dependencies: Introduces ivy, jaxtyping, and einops.
    • ivy: Interoperates between NumPy and Torch. Used internally only.
    • jaxtyping: Provides type hinting for the library.
  • Torch as an Optional Dependency: Torch remains optional.
  • Backend Management:
    • Use ct.backend.get_backend() and ct.backend.set_backend() to get and set the default backend for camtools. This change is global and affects all subsequent invocations of camtools functions.
    • @ct.backend.tensor_auto_backend tries to derive backend from input arguments. It also converts list inputs to native tensor format if the type hint is a tensor. See test/test_backend.py for examples.
    • @ct.backend.tensor_numpy_backend and @ct.backend.tensor_torch_backend wrap functions to enforce using NumPy and Torch backends by converting input tensors to NumPy or Torch, respectively.
  • Shape and Dtype Checking:
    • @ct.backend.tensor_type_check enforces shape and dtype checks based on the type hints for tensors. See test/test_backend.py for examples. This check will eventually replace most of the functions in ct.sanity.

Known Issues

Per our benchmark, the overhead of type-checking and backend conversion is small. However ivy introduces significant overheads. We may look for alternative methods, such as lightweight wrappers to wrap numpy/torch to replace ivy in the future.

@yxlao yxlao force-pushed the yixing/auto-backend branch from 5f5d42a to 8494b36 Compare July 3, 2024 07:50
@yxlao yxlao changed the base branch from main to flex-backend July 3, 2024 15:26
@yxlao yxlao changed the title feat: automatic backend selection and improved type checks feat: automatic tensor backend and type checks Jul 3, 2024
@yxlao yxlao merged commit b9d3d50 into flex-backend Jul 3, 2024
9 checks passed
@yxlao yxlao deleted the yixing/auto-backend branch July 3, 2024 15:37
yxlao added a commit that referenced this pull request Jul 9, 2024
This PR is a combination of:
* feat: automatic tensor backend and type checks (#64)
* perf: performance improvements for backend function wrappers (#66)
* feat: improved backend, union handling, tensor creation APIs (#68)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant