Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf: replace _uniform method to remove iteration on tensor #3042

Merged
merged 1 commit into from
Aug 24, 2023

Conversation

MarcBresson
Copy link
Contributor

@MarcBresson MarcBresson commented Aug 23, 2023

This new version is much quicker (granted, it will not save a lot of absolute time).

It avoids enumerating on a tensor, which is always slow.

def old_uniform(kernel_size: int):
    max, min = 2.5, -2.5
    ksize_half = (kernel_size - 1) * 0.5
    kernel = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
    for i, j in enumerate(kernel):
        if min <= j <= max:
            kernel[i] = 1 / (max - min)
        else:
            kernel[i] = 0

    return kernel.unsqueeze(dim=0)

def new_uniform(kernel_size):
    kernel = torch.zeros(kernel_size)

    start_uniform_index = max(kernel_size // 2 - 2, 0)
    end_uniform_index = min(kernel_size // 2 + 3, kernel_size)

    min_, max_  = -2.5, 2.5
    kernel[start_uniform_index:end_uniform_index] = 1 / (max_ - min_)

    return kernel.unsqueeze(dim=0)

Performance comparison

%timeit old_uniform(11)
>>> 354 µs ± 14.7 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

%timeit new_uniform(11)
>>> 13.6 µs ± 303 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

%timeit old_uniform(3)
>>> 123 µs ± 2.19 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

%timeit new_uniform(3)
>>> 11.6 µs ± 1.13 µs per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

Check for equality

for kernel_size in range(1, 101, 2):
    torch.testing.assert_close(old_uniform(kernel_size), new_uniform(kernel_size))

Check list:

  • New tests are added (if a new feature is added)
  • New doc strings: description and/or example code are in RST format
  • Documentation is updated (if required)

@github-actions github-actions bot added the module: metrics Metrics module label Aug 23, 2023
@vfdev-5
Copy link
Collaborator

vfdev-5 commented Aug 23, 2023

Nice improvements, @MarcBresson !

For equality tests I would use torch.testing.assert_close(old_uniform(kernel_size), new_uniform(kernel_size)).

I wonder how to see that

start_uniform_index = max(kernel_size // 2 - 2, 0)
end_uniform_index = min(kernel_size // 2 + 3, kernel_size)
kernel[start_uniform_index:end_uniform_index] = 1 / (max_ - min_)

is equivalent to

        if min <= j <= max:
            kernel[i] = 1 / (max - min)

?

@MarcBresson
Copy link
Contributor Author

It was hard to wrap my head around the possible decisions that led to this code.

Basically, the former code was creating a tensor that went from -kernel_size // 2 to kernel_size // 2. Then all the values of this tensor that were < to -2.5 or > to 2.5 were replaced by 0, and the other were set to 1 / (max - min).

The new code just put 0 everywhere then compute the indices that should be set to 1 / (max - min).

The only difference between the two codes is when the kernel size is an even number, but this raise an error beforehand.

Copy link
Collaborator

@vfdev-5 vfdev-5 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@MarcBresson thanks a lot for the perf improvement!
LGTM

@vfdev-5 vfdev-5 merged commit 178d82c into pytorch:master Aug 24, 2023
16 of 18 checks passed
@MarcBresson MarcBresson deleted the refactor-ssim-_uniform branch August 24, 2023 08:21
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
module: metrics Metrics module
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants