Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: improve how device switch is handled between the metric device and the input tensors device #3043

Merged
merged 24 commits into from
Aug 25, 2023

Conversation

MarcBresson
Copy link
Contributor

@MarcBresson MarcBresson commented Aug 23, 2023

@vfdev-5 I investigated on the weird code, and as it turns out, kernel could never be of size > 2 (it is only computed in _uniform() or _gaussian() which both output 2 dim tensors).

I wrote this little fix with a warning if the update tensors are not on the device device than the metric.

Do you know if calling .to(device) on a tensor that is already on the device will cause slow downs ?

Check list:

  • New tests are added (if a new feature is added)
  • New doc strings: description and/or example code are in RST format
  • Documentation is updated (if required)

@github-actions github-actions bot added the module: metrics Metrics module label Aug 23, 2023
@vfdev-5
Copy link
Collaborator

vfdev-5 commented Aug 23, 2023

Thanks a lot the PR @MarcBresson !

I investigated on the weird code, and as it turns out, kernel could never be of size > 2 (it is only computed in _uniform() or _gaussian() which both output 2 dim tensors).

Yes, that's correct and as the code overwrites self._kernel:

            self._kernel = self._kernel.expand(channel, 1, -1, -1).to(device=y_pred.device)

and usually the data has consistently the same device, the actual code is working.

Concerning this PR's code, I think it would be good to follow below logic:

kernel y, y_pred what to do
cpu cpu nothing
cuda cuda nothing
cpu cuda set kernel.to("cuda") for a temp var and use cuda for F.conv2d op
cuda cpu set y_pred, y to cuda

What do you think ?

Do you know if calling .to(device) on a tensor that is already on the device will cause slow downs ?

There was a slow down in early pytorch version, we should measure that now. Please use torch utils benchmark to get some numbers : https://github.com/vfdev-5/pth-inductor-dev/blob/eb01fa071a2337c7037e8a7e961b2147c5fc8b42/perf_flip.py#L52-L61

@MarcBresson
Copy link
Contributor Author

import torch
import torch.utils.benchmark as benchmark

def rand_to_device(shape: tuple, created_on_device, to_device):
    on_device = torch.rand(shape, device=created_on_device)

    if to_device is not None:
        t = on_device.to(device=to_device)

results = []
min_run_time = 5
shape = (12, 4, 256, 256)
available_devices = [torch.device("cuda:0"), torch.device("cpu")]
for from_device in available_devices:
    for to_device in available_devices + [None]:
        print(f"{from_device} to {to_device} measurements")
        results.append(
            benchmark.Timer(
                stmt=f"fn({shape}, created_on_device, to_device)",
                globals={
                    "fn": rand_to_device,
                    "created_on_device": from_device,
                    "to_device": to_device,
                },
                num_threads=torch.get_num_threads(),
                label=f"{from_device} to {to_device} measurements"
            ).blocked_autorange(min_run_time=min_run_time)
        )

I had an error when calling compare = benchmark.Compare(results); compare.print() so I just printed out the objects:

results
>>> cuda:0 to cuda:0 measurements
  Median: 113.61 us
  IQR:    0.70 us (113.34 to 114.04)
  430 measurements, 100 runs per measurement, 6 threads
>>> cuda:0 to cpu measurements
  Median: 4.95 ms
  IQR:    0.14 ms (4.89 to 5.03)
  11 measurements, 100 runs per measurement, 6 threads
>>> cuda:0 to None measurements
  Median: 110.68 us
  IQR:    0.30 us (110.45 to 110.75)
  46 measurements, 1000 runs per measurement, 6 threads
>>> cpu to cuda:0 measurements
  Median: 23.79 ms
  IQR:    6.15 ms (22.10 to 28.24)
  20 measurements, 10 runs per measurement, 6 threads
  WARNING: Interquartile range is 25.8% of the median measurement.
           This suggests significant environmental influence.
>>> cpu to cpu measurements
  Median: 21.80 ms
  IQR:    3.26 ms (21.11 to 24.37)
  21 measurements, 10 runs per measurement, 6 threads
  WARNING: Interquartile range is 14.9% of the median measurement.
           This could indicate system fluctuation.
>>> cpu to None measurements
  Median: 20.86 ms
  IQR:    5.49 ms (19.88 to 25.37)
  21 measurements, 10 runs per measurement, 6 threads
  WARNING: Interquartile range is 26.3% of the median measurement.
           This suggests significant environmental influence.

It seems like there is a really slight slow down (we must compare cuda:0 to cuda:0 with cuda:0 to None).

@MarcBresson
Copy link
Contributor Author

I think that is a great suggestion, always performing the computation on GPU if either one of the kernel or input tensor in on it.

If either one of the metric device or the update input device
is a GPU, this commit will put the other one on GPU.
@vfdev-5
Copy link
Collaborator

vfdev-5 commented Aug 24, 2023

Here is how you can make it. I changed a bit the perf code to measure .to op only.

import torch
import torch.utils.benchmark as benchmark


results = []
min_run_time = 5
shape = (12, 4, 256, 256)


available_devices = [torch.device("cuda"), torch.device("cpu")]
for shape in [(12, 4, 256, 256), (8, 3, 512, 512)]:
    for from_device in available_devices:
        data = torch.rand(shape, device=from_device)
        for to_device in available_devices:
            print(f"{shape} -> {from_device} to {to_device} measurements")
            results.append(
                benchmark.Timer(
                    stmt=f"data.to(to_device, non_blocking=False)",
                    globals={
                        "data": data,
                        "to_device": to_device,
                    },
                    description=f"{from_device} to {to_device}",
                    num_threads=torch.get_num_threads(),
                    label="Device to device",
                    sub_label=f"{tuple(shape)}",
                ).blocked_autorange(min_run_time=min_run_time)
            )

compare = benchmark.Compare(results)
compare.print()

On my infra it gives:

[---------------------------------- Device to device ---------------------------------]                                                        │| Processes:                                                                            |
                         |  cuda to cuda  |  cuda to cpu  |  cpu to cuda  |  cpu to cpu                                                        │|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
6 threads: ----------------------------------------------------------------------------                                                        │|        ID   ID                                                             Usage      |
      (12, 4, 256, 256)  |     844.8      |   1046673.6   |   1165780.2   |    731.3                                                           │|=======================================================================================|
      (8, 3, 512, 512)   |     868.6      |   2018511.1   |   2439948.9   |    721.6                                                           │|  No running processes found                                                           |
                                                                                                                                               │+---------------------------------------------------------------------------------------+
Times are in nanoseconds (ns). 

I agree that it is not a big deal but we also have to think that in the real application .to in the code can be repeated N times, where N depends on the dataset times number of epochs, thus overall slowdown may be visible.

@MarcBresson
Copy link
Contributor Author

MarcBresson commented Aug 24, 2023

You should include a None element in the to_device array. That has the effect of just creating the tensor and not trying to move it on a device. This is the thing that allows us to see if moving a tensor to a device it is already takes time or not.

With this None element, I have:

[-------------------------------------------------- Device to device --------------------------------------------------]
                         |  cuda          |  cuda to cuda  |  cuda to cpu  |  cpu          |  cpu to cuda  |  cpu to cpu
6 threads: ------------------------------------------------------------------------------  -----------------------------
      (12, 4, 256, 256)  |     250.1      |     716.2      |    4818370.0  |     294.5     |   3130725.0   |    668.6   
      (8, 3, 512, 512)   |     255.5      |     813.2      |   11758880.0  |     260.3     |   6170050.0   |    668.9   

Times are in nanoseconds (ns).

We can see that creating an element on cuda then moving it to the same cuda device takes ~800ns while not trying to move it only takes ~250ns. The figures there are strangely very different from what I had on my first run #3043 (comment)...

--- EDIT ---
my mistake, i havn't seen the modification of the function that we benchmark. Now I wonder what .to(None) does ahahah.

@vfdev-5
Copy link
Collaborator

vfdev-5 commented Aug 24, 2023

Now I wonder what .to(None) does ahahah.

according to https://pytorch.org/docs/stable/generated/torch.Tensor.to.html#torch.Tensor.to
it should go as

torch.to(device=None, dtype=None, non_blocking=False, copy=False, memory_format=torch.preserve_format) → Tensor

and probably do nothing. So, we are observing just calling a python function for 250 ns (which is reasonable)

@MarcBresson
Copy link
Contributor Author

(there are bugs, I am writing test and will correct them)

The comparison with self._device was not possible because it
can be created with `torch.device("cuda")` which is not equal
to `torch.device("cuda:0")` which is the device of a tensor
created with `torch.device("cuda")`. This change will have
a bigger performance hit when self._kernel is not on the same
device as y_pred as it will need to be moved onto y_pred's
device every time update() is called.
@MarcBresson
Copy link
Contributor Author

I fixed everything. One thing to note is the (assumed) performance hit when self._kernel is not on the same device as y_pred as it will need to be moved onto y_pred's device every time update() is called.

ignite/metrics/ssim.py Outdated Show resolved Hide resolved
ignite/metrics/ssim.py Outdated Show resolved Hide resolved
@MarcBresson MarcBresson changed the title refactor: remove outdated code and issue a warning if two tensors are on separate devices. feat: improve how device switch is handled between the metric device and the input tensors device Aug 24, 2023
ignite/metrics/ssim.py Outdated Show resolved Hide resolved
ignite/metrics/ssim.py Outdated Show resolved Hide resolved
@MarcBresson
Copy link
Contributor Author

MarcBresson commented Aug 25, 2023

I am writing new tests for the variable channel size, will push soon will all the changes that you suggested

@MarcBresson
Copy link
Contributor Author

oh no conflicts, what have I done ?

ignite/metrics/ssim.py Outdated Show resolved Hide resolved
ignite/metrics/ssim.py Outdated Show resolved Hide resolved
Copy link
Collaborator

@vfdev-5 vfdev-5 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks @MarcBresson !

@vfdev-5 vfdev-5 merged commit 11a1fba into pytorch:master Aug 25, 2023
13 of 17 checks passed
@MarcBresson MarcBresson deleted the refactor-_update branch August 25, 2023 13:53
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
module: metrics Metrics module
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants