
perf: performance improvements for backend function wrappers #66

Merged: 37 commits merged into flex-backend from yixing/remove-ivy on Jul 9, 2024

Conversation

yxlao (Owner) commented on Jul 9, 2024

Key Changes

  • The notion of default backends is removed; the backend now defaults to numpy, and ct.backend.set_backend() and ct.backend.get_backend() are also removed. If the input arguments are all numpy arrays or contain no arrays, the internal computation and the output are numpy arrays. If the input arguments are all torch tensors, the internal computation and the output are torch tensors.
  • To convert the input arguments to the desired backend, decorate the function with @ct.backend.tensor_to_auto_backend, @ct.backend.tensor_to_numpy_backend, or @ct.backend.tensor_to_torch_backend (see the usage sketch after this list).
  • Use ct.backend.enable_tensor_check() and ct.backend.disable_tensor_check() to enable or disable tensor type checking (dtype and shape) globally. Tensor type checking is enabled by default.
  • The wrapper functions are carefully optimized to minimize the overhead of backend conversion and type checking (see the pattern sketch after this list):
    1. Signatures are resolved once per function rather than on every call.
    2. Critical helper functions are cached to reduce overhead.
    3. The signature argument binding mechanism is optimized.
    4. The type checker and the backend converter are fused into a single function.
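
A minimal usage sketch of the auto-backend decorator and the tensor-check toggles, assuming the camtools.backend API shown in the demos below (the function add and its inputs are illustrative, not from this PR):

import numpy as np
from jaxtyping import Float

import camtools as ct
from camtools.backend import Tensor, torch


@ct.backend.tensor_to_auto_backend
def add(
    x: Float[Tensor, "3"],
    y: Float[Tensor, "3"],
):
    # With all-numpy (or array-free) inputs, x and y arrive as np.ndarray;
    # with all-torch inputs, they arrive as torch.Tensor.
    return x + y


assert isinstance(add(np.ones(3), np.ones(3)), np.ndarray)
assert isinstance(add(torch.ones(3), torch.ones(3)), torch.Tensor)

# Tensor type checking (dtype and shape) is on by default and can be
# toggled globally:
ct.backend.disable_tensor_check()
ct.backend.enable_tensor_check()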

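The optimizations follow a common decorator pattern: do the expensive reflection once at decoration time and keep per-call work minimal. Below is a rough sketch of that idea, not the actual camtools implementation (tensor_to_numpy_backend_sketch and check_and_convert are hypothetical names):

import functools
import inspect

import numpy as np


def tensor_to_numpy_backend_sketch(func):
    # (1) Resolve the signature once per function, at decoration time,
    #     rather than calling inspect.signature() on every invocation.
    sig = inspect.signature(func)

    def check_and_convert(value):
        # (4) Fused pass: the real wrapper validates dtype/shape and
        #     converts to the target backend in a single traversal of
        #     the arguments, instead of checking and converting separately.
        if isinstance(value, (list, tuple)) or hasattr(value, "__array__"):
            return np.asarray(value)
        return value

    # (2) In the same spirit, expensive per-type helpers can be cached,
    #     e.g. with functools.lru_cache, so repeated calls stay cheap.

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # (3) Bind positional and keyword arguments in one pass against
        #     the pre-resolved signature.
        bound = sig.bind(*args, **kwargs)
        bound.apply_defaults()
        for name, value in bound.arguments.items():
            bound.arguments[name] = check_and_convert(value)
        return func(*bound.args, **bound.kwargs)

    return wrapper
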
Additional demo and benchmark code

demo.py

The @ct.backend.tensor_to_numpy_backend wrapper converts input np.ndarray, torch.Tensor, list, or tuple arguments to np.ndarray before passing them to the function. It also checks the dtype and shape of the input tensors.

import numpy as np
from jaxtyping import Float
import camtools as ct
from camtools.backend import Tensor, torch


@ct.backend.tensor_to_numpy_backend
def sum_xyz(
    x: Float[Tensor, "3"],
    y: Float[Tensor, "3"] = (2.0, 2.0, 2.0),
    z: Float[Tensor, "3"] = (3.0, 3.0, 3.0),
):
    assert isinstance(x, np.ndarray)
    assert isinstance(y, np.ndarray)
    assert isinstance(z, np.ndarray)
    return x + y + z


def main():
    x = np.array([1.0, 1.0, 1.0])
    y = torch.tensor([5.0, 5.0, 5.0])
    result = sum_xyz(x, y=y)
    expected_result = np.array([9.0, 9.0, 9.0])
    assert np.allclose(result, expected_result)
    print("Tests passed!")


if __name__ == "__main__":
    main()

bench_w_decorators.py

import numpy as np
from jaxtyping import Float

import camtools as ct
import cProfile
from camtools.backend import Tensor, torch

_array_repeat = 1000


def workload(x, y):
    x = np.repeat(x, _array_repeat)
    y = np.repeat(y, _array_repeat)
    return np.dot(x, y)


@ct.backend.tensor_to_numpy_backend
def _dot_with_decorator(
    x: Float[Tensor, "..."],
    y: Float[Tensor, "..."],
):
    return workload(x, y)


def run_with_decorator():
    x = torch.tensor([1.0, 2.0, 3.0])
    y = torch.tensor([4.0, 5.0, 6.0])
    result = _dot_with_decorator(x, y)
    assert result == 32.0 * _array_repeat


def main():
    # Warmup
    run_with_decorator()

    # Profile
    profiler = cProfile.Profile()
    profiler.enable()
    run_with_decorator()
    profiler.disable()
    profiler.dump_stats("w_decorator.prof")


if __name__ == "__main__":
    main()

You may then visualize the profiling results using snakeviz:

pip install snakeviz
snakeviz w_decorator.prof

bench_wo_decorators.py

import numpy as np
from jaxtyping import Float

import camtools as ct
import cProfile
from camtools.backend import Tensor, torch

_array_repeat = 1000


def workload(x, y):
    x = np.repeat(x, _array_repeat)
    y = np.repeat(y, _array_repeat)
    return np.dot(x, y)


def dot(
    x: Float[Tensor, "..."],
    y: Float[Tensor, "..."],
):
    x = x.detach().cpu().numpy()
    y = y.detach().cpu().numpy()
    return workload(x, y)


def run_without_decorator():
    x = torch.tensor([1.0, 2.0, 3.0])
    y = torch.tensor([4.0, 5.0, 6.0])
    result = dot(x, y)
    assert result == 32.0 * _array_repeat


def main():
    # Warmup
    run_without_decorator()

    # Profile
    profiler = cProfile.Profile()
    profiler.enable()
    run_without_decorator()
    profiler.disable()
    profiler.dump_stats("wo_decorator.prof")


if __name__ == "__main__":
    main()

You may then visualize the profiling results using snakeviz:

pip install snakeviz
snakeviz wo_decorator.prof

test_backend_bench.py

import numpy as np
import camtools as ct
import pytest
from jaxtyping import Float

from camtools.backend import Tensor, ivy, is_torch_available, torch

_workload_repeat = 100


@pytest.mark.skipif(not is_torch_available(), reason="Torch is not available")
def test_concat_torch_to_numpy_manual(benchmark):

    def concat(
        x: Float[Tensor, "..."],
        y: Float[Tensor, "..."],
    ):
        x = x.cpu().numpy()
        y = y.cpu().numpy()
        for _ in range(_workload_repeat):
            np.concatenate([x, y], axis=0)
        return np.concatenate([x, y], axis=0)

    x = torch.tensor([1.0, 2.0, 3.0])
    y = torch.tensor([4.0, 5.0, 6.0])
    result = benchmark(concat, x, y)
    assert isinstance(result, np.ndarray)
    assert np.array_equal(result, np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]))


@pytest.mark.skipif(not is_torch_available(), reason="Torch is not available")
def test_concat_torch_to_numpy_auto(benchmark):

    @ct.backend.tensor_to_numpy_backend
    def concat(
        x: Float[Tensor, "..."],
        y: Float[Tensor, "..."],
    ):
        for _ in range(_workload_repeat):
            np.concatenate([x, y], axis=0)
        return np.concatenate([x, y], axis=0)

    x = torch.tensor([1.0, 2.0, 3.0])
    y = torch.tensor([4.0, 5.0, 6.0])
    result = benchmark(concat, x, y)
    assert isinstance(result, np.ndarray)
    assert np.array_equal(result, np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]))

You may run the benchmark with:

pytest test/test_backend_bench.py -s
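
The benchmark fixture used above is provided by the pytest-benchmark plugin, so it must be installed first:

pip install pytest-benchmark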

yxlao merged commit 7216a8e into flex-backend on Jul 9, 2024 (9 checks passed).
yxlao deleted the yixing/remove-ivy branch on Jul 9, 2024.
yxlao added a commit that referenced this pull request on Jul 9, 2024:
This PR is a combination of:
* feat: automatic tensor backend and type checks (#64)
* perf: performance improvements for backend function wrappers (#66)
* feat: improved backend, union handling, tensor creation APIs (#68)