
42 multi gpu #49

Merged
merged 7 commits into master from 42-multi-gpu
Feb 7, 2020

Conversation

ericspod
Member

@ericspod ericspod commented Feb 4, 2020

This adds a function for creating a supervised learner and evaluator in the same style as Ignite. An example notebook and unit tests are provided. The tests will not execute a multi-GPU test correctly if the host has only one GPU; a way to emulate multiple GPUs would be helpful.

@wyli
Contributor

wyli commented Feb 4, 2020

Since the GPU-based CI is in place, we could remove the non-GPU CI previously configured here:

- name: Test and coverage
  run: |
    ./runtests.sh --coverage

@anfeng anfeng closed this Feb 4, 2020
@anfeng anfeng reopened this Feb 4, 2020

if devices is None:
    devices = [torch.device('cuda:%i' % d) for d in range(torch.cuda.device_count())]
elif len(devices) == 0:

You should change this to "if" instead of "elif".

Consider the case where a machine doesn't have a GPU and devices was None in the call. Your statement at L52 will return an empty array, which will cause a problem at L59 with "devices[0]".

Contributor

Hi @anfeng and @ericspod ,

As this file is "multi_gpu_supervised_trainer.py", I think we should only support multi-GPU here.
Just assert something like:

if devices is None:
    devices = [torch.device('cuda:%i' % d) for d in range(torch.cuda.device_count())]
assert len(devices) > 1, 'must have more than 1 GPU device.'
... ...

What do you think?
Thanks.

Member Author

In the case where there are no GPUs but they were requested, the list will be empty and the failure will occur on line 59. Leaving an error to propagate elsewhere is sort of the Pythonic way, but instead we should raise an error. Assert isn't appropriate here, as we're essentially checking input rather than an internal property our own logic should enforce if correct. I've committed changes to reflect this. I've also added code to the multi-GPU unit test to suppress the warning about GPU memory imbalance; I'd say unit tests should be silent unless they fail, but maybe the warning should still be allowed or at least logged.

@Nic-Ma It's harmless to support only 1 GPU. If people have code parameterized by the number of GPUs to use, they'd have to choose which function to call based on a check of whether that number was 1 or not; it's easier if this isn't restricted in that way.
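The validation approach described above (raise rather than assert, empty list means CPU) could be sketched roughly as follows. This is a hypothetical illustration, not the PR's actual code: the `gpu_count` parameter stands in for `torch.cuda.device_count()` so the sketch runs without a GPU, and plain strings stand in for `torch.device` objects.

```python
# Hypothetical sketch of the validation discussed above. Raise ValueError for
# an impossible request instead of using assert: asserts are stripped under
# `python -O` and are meant for internal invariants, not user input.
def get_devices_spec(devices=None, gpu_count=0):
    """Resolve a device list; `gpu_count` stands in for torch.cuda.device_count()."""
    if devices is None:
        # None means "use all GPUs"; fail loudly if there are none
        devices = ["cuda:%i" % d for d in range(gpu_count)]
        if len(devices) == 0:
            raise ValueError("No GPU devices available")
    elif len(devices) == 0:
        # an explicitly empty list requests CPU computation
        devices = ["cpu"]
    return devices
```

With this shape, `get_devices_spec([])` yields `["cpu"]`, while `get_devices_spec(None)` on a GPU-less host raises immediately rather than failing later at `devices[0]`.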

Contributor

Hi @ericspod ,

Thanks for your explanation.
If you want to support both CPU and GPU/multi-GPU logic here, why not rename the function to a more general-purpose trainer? Maybe:

def create_supervised_trainer(devices):
    if devices is None:
        devices = [torch.device('cuda:%i' % d) for d in range(torch.cuda.device_count())]
    if len(devices) == 0:
        devices = [torch.device("cpu")]
    ... ...

And about your new util function get_devices_spec(): I don't think it's a good idea to use an empty list as the parameter for the CPU device, it's confusing. Maybe add a flag for CPU/GPU directly?
Thanks.

Contributor

@Nic-Ma What do you suggest for the parameter instead?

We could use strings like "multi-gpu", "gpu", "cpu", but if the user passes "cuda:0" what do we do? We could enforce only our defined names, but that seems too restrictive.

Contributor

Hi @madil90 and @ericspod ,

Thanks for your comments. I prefer this strategy:

def create_supervised_trainer(devices=None):
    if devices is None:
        devices = [torch.device('cuda:%i' % d) for d in range(torch.cuda.device_count())]
        if len(devices) == 0:
            devices = [torch.device("cpu")]
    else:
        pass  # use the devices parameter directly
    ... ...
# use cases:
trainer = create_supervised_trainer()  # automatically select devices
trainer = create_supervised_trainer(devices=[torch.device("cpu")])
trainer = create_supervised_trainer(devices=[torch.device("cuda:0")])
trainer = create_supervised_trainer(devices=[torch.device("cuda:0"), torch.device("cuda:1")])
  1. If the user passes something through the "devices" parameter, let's use it directly.
     We can add some sanity checks in a later version.
  2. If no devices are provided, we try to use all GPUs first; if no GPU is found, use the CPU instead.

What do you think?
Thanks.
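The fallback strategy proposed above can be written as a small runnable sketch for comparison with the raise-an-error alternative discussed earlier. Names here are illustrative only: `gpu_count` is a hypothetical stand-in for `torch.cuda.device_count()`, and strings stand in for `torch.device` objects.

```python
# Sketch of the auto-select-with-CPU-fallback strategy: None means "pick for
# me" (all GPUs, else CPU); an explicit list is used directly, unchecked.
def select_devices(devices=None, gpu_count=0):
    if devices is None:
        devices = ["cuda:%i" % d for d in range(gpu_count)]
        if len(devices) == 0:
            # no GPU found: silently fall back to CPU (point 2 above)
            devices = ["cpu"]
    # otherwise use the devices parameter directly (point 1 above)
    return devices
```

Under this sketch, `select_devices()` on a two-GPU host returns both GPUs, on a GPU-less host returns the CPU, and an explicit `["cuda:0"]` passes through untouched.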

Contributor

@Nic-Ma A little verbose, but seems fine. However, the last one becomes redundant: PyTorch will create a parallel context on all GPUs. The responsibility is on the user to select GPUs through CUDA_VISIBLE_DEVICES.
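For reference, CUDA_VISIBLE_DEVICES-based selection mentioned above happens in the environment rather than in the device list; a minimal sketch:

```python
import os

# CUDA_VISIBLE_DEVICES must be set before the process first initialises CUDA
# (i.e. before any torch.cuda call), so it is usually exported in the shell
# or set at the very top of the script. "0,2" exposes physical GPUs 0 and 2,
# which the process then sees as cuda:0 and cuda:1.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,2"
```

This is why an explicit multi-device list can be redundant: once the environment restricts visibility, "use all visible GPUs" already does the right thing.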

Contributor

Hi @madil90, yes, thanks for the reminder.
I just wanted to share some alternative proposals.
For the MVP version, I am OK with using @ericspod's method for the CPU device (empty array).

Hi @ericspod, I think you can review Adil's PR for your branch and decide on the final solution.
Thanks.

Member Author

For the name, I used a different name so as not to clash with the one from Ignite; it should be obvious which is being used when looking at source code.

Choosing to use CPU computation silently when no GPU is present is going to cause people to use the CPU when they didn't expect to. There should be a loud and clear error when something requested isn't possible.


if devices is None:
    devices = [torch.device('cuda:%i' % d) for d in range(torch.cuda.device_count())]
elif len(devices) == 0:
Contributor

@anfeng's comment for L53 applies here too. Perhaps we should move this logic to a util function?

Contributor

@Nic-Ma Nic-Ma left a comment

Added comments inline.
Thanks.

if devices is None:
    devices = [torch.device('cuda:%i' % d) for d in range(torch.cuda.device_count())]

if len(devices) == 0:
Contributor

If no GPU is found, we should default to CPU.

Member Author

As I mentioned elsewhere, defaulting to CPU like that will cause silent errors when people expect to use GPUs. If people want GPUs, they should get a loud and clear error that they can't get them; otherwise they'll think everything is fine, just very slow.

@ericspod ericspod requested a review from Nic-Ma February 6, 2020 15:19
    raise ValueError("No GPU devices available")

elif len(devices) == 0:
    devices = [torch.device("cpu")]
Contributor

@Nic-Ma Nic-Ma Feb 7, 2020

Just a suggestion to print a warning here when the CPU is used instead, since this code file is for "multi_gpu_trainer". What do you think?
People may not know that an empty devices list means the CPU device.
Others look good to me.
Thanks.
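The suggested warning could look like the following sketch. The function name is hypothetical; only the empty-list-means-CPU convention comes from this PR.

```python
import warnings

def resolve_cpu_fallback(devices):
    # Hypothetical fragment: in a file named multi_gpu_supervised_trainer.py,
    # an empty `devices` list meaning "run on CPU" is easy to miss, so warn
    # loudly before substituting the CPU device.
    if devices is not None and len(devices) == 0:
        warnings.warn("Empty `devices` list given; using CPU computation.")
        return ["cpu"]
    return devices
```

A `UserWarning` rather than a hard error keeps ericspod's distinction intact: the explicit empty list is a deliberate request for CPU, so it should proceed, just not silently.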

@wyli wyli merged commit 962cb11 into master Feb 7, 2020
@wyli wyli deleted the 42-multi-gpu branch April 6, 2020 13:35