
add c10d dynamic loading mechanism and unit test #28068

Closed
wants to merge 1 commit

Conversation

@ftian1 ftian1 (Contributor) commented Oct 16, 2019

PyTorch c10d originally supports only built-in backends, such as nccl/gloo/mpi. This patch extends c10d so that it can dynamically load third-party communication libraries that derive from the ProcessGroup base class.

The related RFC is #27955.

With this change, a user only needs to specify the name of a third-party c10d backend when invoking torch.distributed.init_process_group(); the proposed logic then tries to load the corresponding c10d backend C++ extension automatically. For how to develop a new third-party c10d backend as a C++ extension, please refer to test/cpp_extensions/cpp_c10d_extension.cpp.
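
A minimal usage sketch (reflecting the final form of this PR, where the extension is imported manually and registers itself via torch.distributed.Backend.register_backend; the module and backend names below are hypothetical):

    # Hypothetical example; "dummy_collectives" and the "dummy" backend name are illustrative.
    import torch
    import torch.distributed as dist

    import dummy_collectives  # importing the extension registers the "dummy" backend

    dist.init_process_group(backend="dummy", rank=0, world_size=1,
                            init_method="tcp://127.0.0.1:29500")

    t = torch.zeros(4)
    dist.all_reduce(t)  # dispatched to the third-party ProcessGroup implementation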

@mrshenli mrshenli (Contributor) left a comment

Thanks a lot for adding this!!! This will be very useful for other backends as well.

Inline review comments (resolved) on: test/cpp_extensions/cpp_c10d_extension.cpp, torch/distributed/distributed_c10d.py, test/test_c10d.py

@mshiryaev mshiryaev commented Oct 17, 2019

@ftian1, could you please also include the c10d-related headers in the PyTorch installation, so that an external ProcessGroup that depends on these headers can be built?

diff --git a/setup.py b/setup.py
index 7ada277..68c6cb1 100644
--- a/setup.py
+++ b/setup.py
@@ -813,6 +813,7 @@ if __name__ == '__main__':
                 'include/c10/cuda/impl/*.h',
                 'include/c10/hip/*.h',
                 'include/c10/hip/impl/*.h',
+                'include/c10d/*.hpp',
                 'include/caffe2/**/*.h',
                 'include/torch/*.h',
                 'include/torch/csrc/*.h', 
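
For context, a minimal sketch of how an out-of-tree backend could then build against the installed headers (a hypothetical setup.py; package and source names are illustrative). CppExtension automatically adds the PyTorch include directories, which would contain include/c10d/*.hpp after the change above:

    # Hypothetical out-of-tree backend setup.py sketch.
    from setuptools import setup
    from torch.utils.cpp_extension import BuildExtension, CppExtension

    setup(
        name="my_c10d_backend",
        ext_modules=[
            CppExtension(
                name="my_c10d_backend",
                sources=["src/process_group_my_backend.cpp"],
            ),
        ],
        cmdclass={"build_ext": BuildExtension},
    )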

@mrshenli mrshenli added the oncall: distributed label Oct 17, 2019

@ftian1 ftian1 (Contributor, Author) commented Oct 17, 2019

@mrshenli @mshiryaev please help check the patch again.

Inline review comments (resolved) on: torch/distributed/distributed_c10d.py, docs/source/distributed.rst, test/test_c10d.py, test/cpp_extensions/cpp_c10d_extension.cpp, test/cpp_extensions/cpp_c10d_extension.hpp

@ftian1 ftian1 (Contributor, Author) commented Oct 18, 2019

@mshiryaev @jgong5 @mrshenli I updated the patch according to your comments. Thanks.

@mrshenli mrshenli (Contributor) left a comment

Test failures are real:

Oct 18 04:57:19 test_broadcast_coalesced_gloo_cuda (__main__.CommTest) ... THCudaCheck FAIL file=/var/lib/jenkins/workspace/aten/src/THC/THCGeneral.cpp line=54 error=3 : initialization error
Oct 18 04:57:19 THCudaCheck FAIL file=/var/lib/jenkins/workspace/aten/src/THC/THCGeneral.cpp line=54 error=3 : initialization error
Oct 18 04:57:19 ERROR:root:Caught exception: 
Oct 18 04:57:19 Traceback (most recent call last):
Oct 18 04:57:19   File "/var/lib/jenkins/workspace/test/common_distributed.py", line 133, in wrapper
Oct 18 04:57:19     fn(self)
Oct 18 04:57:19   File "/var/lib/jenkins/workspace/test/common_distributed.py", line 35, in wrapper
Oct 18 04:57:19     return func(*args, **kwargs)
Oct 18 04:57:19   File "test_c10d.py", line 3198, in test_broadcast_coalesced_gloo_cuda
Oct 18 04:57:19     self._test_broadcast_coalesced(process_group, device)
Oct 18 04:57:19   File "test_c10d.py", line 3161, in _test_broadcast_coalesced
Oct 18 04:57:19     target = torch.arange(60, dtype=half, device=device).chunk(5)
Oct 18 04:57:19   File "/opt/conda/lib/python3.6/site-packages/torch/cuda/__init__.py", line 198, in _lazy_init
Oct 18 04:57:19     torch._C._cuda_init()
Oct 18 04:57:19 RuntimeError: cuda runtime error (3) : initialization error at /var/lib/jenkins/workspace/aten/src/THC/THCGeneral.cpp:54

I asked @gchanan offline about the above error. One potential cause is that the new cpp_extension code did not properly initialize CUDA or triggered some error, which could lead to failures in subsequent tests.

Oct 18 04:57:19 test_collective_apis (__main__.BackendLoadingTest) ... ok
Oct 18 04:57:19 test_broadcast_coalesced_gloo_cpu (__main__.CommTest) ... ok
Oct 18 04:57:19 test_broadcast_coalesced_gloo_cuda (__main__.CommTest) ... THCudaCheck FAIL file=/var/lib/jenkins/workspace/aten/src/THC/THCGeneral.cpp line=54 error=3 : initialization error

To debug this, setting CUDA_LAUNCH_BLOCKING and cudaGetLastError() might help.
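
A minimal sketch of the first suggestion (assuming the variable is set before any CUDA work, so launches report errors synchronously):

    # Illustrative debugging sketch; set before CUDA is initialized.
    import os
    os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

    import torch
    torch.zeros(1, device="cuda")  # failures now surface at the offending call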

Inline review comments (resolved) on: docs/source/distributed.rst, setup.py, test/cpp_extensions/cpp_c10d_extension.hpp

@mrshenli (Contributor)

@yf225 do you know what is the appropriate way to initialize cuda in cpp extensions?

@jerryzh168 jerryzh168 added the triaged label Oct 18, 2019

@ftian1 ftian1 (Contributor, Author) commented Oct 23, 2019

@mrshenli @yf225
I reproduced the "RuntimeError: cuda runtime error (3) : initialization error". In fact, it can be triggered simply by adding "from torch.utils.cpp_extension import load" to the original test_c10d.py, without any of my changes. It's very strange, and the solutions I found online have not been helpful...

I will continue to look into it, but if you have any suggestions I would appreciate them.

@yf225 yf225 (Contributor) commented Oct 23, 2019

@ftian1 I feel it might be worthwhile to look at the docs for torch.utils.cpp_extension.load:

def load(name,
         sources,
         extra_cflags=None,
         extra_cuda_cflags=None,
         extra_ldflags=None,
         extra_include_paths=None,
         build_directory=None,
         verbose=False,
         with_cuda=None,
         is_python_module=True):
    '''
    Loads a PyTorch C++ extension just-in-time (JIT).

    To load an extension, a Ninja build file is emitted, which is used to
    compile the given sources into a dynamic library. This library is
    subsequently loaded into the current Python process as a module and
    returned from this function, ready for use.

    By default, the directory to which the build file is emitted and the
    resulting library compiled to is ``<tmp>/torch_extensions/<name>``, where
    ``<tmp>`` is the temporary folder on the current platform and ``<name>``
    the name of the extension. This location can be overridden in two ways.
    First, if the ``TORCH_EXTENSIONS_DIR`` environment variable is set, it
    replaces ``<tmp>/torch_extensions`` and all extensions will be compiled
    into subfolders of this directory. Second, if the ``build_directory``
    argument to this function is supplied, it overrides the entire path, i.e.
    the library will be compiled into that folder directly.

    To compile the sources, the default system compiler (``c++``) is used,
    which can be overridden by setting the ``CXX`` environment variable. To pass
    additional arguments to the compilation process, ``extra_cflags`` or
    ``extra_ldflags`` can be provided. For example, to compile your extension
    with optimizations, pass ``extra_cflags=['-O3']``. You can also use
    ``extra_cflags`` to pass further include directories.

    CUDA support with mixed compilation is provided. Simply pass CUDA source
    files (``.cu`` or ``.cuh``) along with other sources. Such files will be
    detected and compiled with nvcc rather than the C++ compiler. This includes
    passing the CUDA lib64 directory as a library directory, and linking
    ``cudart``. You can pass additional flags to nvcc via
    ``extra_cuda_cflags``, just like with ``extra_cflags`` for C++. Various
    heuristics for finding the CUDA install directory are used, which usually
    work fine. If not, setting the ``CUDA_HOME`` environment variable is the
    safest option.

    Args:
        name: The name of the extension to build. This MUST be the same as the
            name of the pybind11 module!
        sources: A list of relative or absolute paths to C++ source files.
        extra_cflags: optional list of compiler flags to forward to the build.
        extra_cuda_cflags: optional list of compiler flags to forward to nvcc
            when building CUDA sources.
        extra_ldflags: optional list of linker flags to forward to the build.
        extra_include_paths: optional list of include directories to forward
            to the build.
        build_directory: optional path to use as build workspace.
        verbose: If ``True``, turns on verbose logging of load steps.
        with_cuda: Determines whether CUDA headers and libraries are added to
            the build. If set to ``None`` (default), this value is
            automatically determined based on the existence of ``.cu`` or
            ``.cuh`` in ``sources``. Set it to ``True`` to force CUDA headers
            and libraries to be included.
        is_python_module: If ``True`` (default), imports the produced shared
            library as a Python module. If ``False``, loads it into the process
            as a plain dynamic library.

    Returns:
        If ``is_python_module`` is ``True``, returns the loaded PyTorch
        extension as a Python module. If ``is_python_module`` is ``False``,
        returns nothing (the shared library is loaded into the process as a
        side effect).

    Example:
        >>> from torch.utils.cpp_extension import load
        >>> module = load(
                name='extension',
                sources=['extension.cpp', 'extension_kernel.cu'],
                extra_cflags=['-O2'],
                verbose=True)
    '''
In particular, see the CUDA-related parts above, and also the examples in test/test_cpp_extensions.py regarding CUDA-related compilation.

@ftian1 ftian1 (Contributor, Author) commented Oct 30, 2019

@mrshenli @yf225
I followed your suggestion and dug into why test/test_cpp_extensions.py is fine with "from cpp_extension import load" while test_c10d.py is not when we add the same line.

The root cause is that the test classes in test_c10d.py derive from MultiProcessTestCase and invoke self._fork_processes() to start new processes. From what I found online, we have to spawn rather than fork new processes to avoid the cuda_init() error (see the sketch below). The awkward part is that _spawn_processes() only works with Python 3; Python 2 asserts in _spawn_processes() in test/common_distributed.py.

Could you give some comments on this? Shall we switch all fork() calls to spawn() and only run the c10d test under Python 3 by adding it to the Python 2 blacklist?
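
For reference, a minimal sketch (not part of this PR) of the spawn start method that sidesteps the CUDA re-initialization error in forked children:

    # Illustrative sketch: each spawned worker initializes CUDA in a fresh process,
    # avoiding the "cuda runtime error (3) : initialization error" seen with fork().
    import torch
    import torch.multiprocessing as mp

    def worker(rank, world_size):
        torch.zeros(1, device="cuda")  # safe: CUDA was not touched before the spawn

    if __name__ == "__main__":
        world_size = 2
        mp.spawn(worker, args=(world_size,), nprocs=world_size)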

@pietern pietern (Contributor) commented Oct 30, 2019

@ftian1 Would it be possible to delay calling load until you're actually running the test, instead of calling it when the file is imported? Then it won't run until after the processes have been forked/spawned, and you shouldn't run into this issue.

We'll keep Python 2 support around until we have officially deprecated it on January 1st (see #23795).
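
A minimal sketch of that suggestion (test and extension names are taken from this PR's test, the rest is hypothetical):

    # Illustrative sketch: the cpp_extension import is deferred into the test body,
    # so importing the test module itself has no JIT-build side effects.
    import unittest

    class BackendLoadingTest(unittest.TestCase):
        def test_collective_apis(self):
            from torch.utils.cpp_extension import load  # imported only when the test runs
            ext = load(name="c10dtest",
                       sources=["cpp_extensions/cpp_c10d_extension.cpp"],
                       verbose=True)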

@ddkalamk (Contributor)

@ftian1 Could you please also add a new init_method that explicitly allows the backend to discover rank and world_size internally inside the custom backend, similar to what the MPI backend does? I propose adding a new "auto" init_method that allows this, like I am doing in ddkalamk@84fddd8.

@kostmo kostmo (Member) commented Oct 31, 2019

💊 CircleCI build failures summary and remediations

As of commit afe3c0b (more details on the Dr. CI page):


  • 1/1 failures introduced in this PR

🕵️ 1 new failure recognized by patterns

The following build failures do not appear to be due to upstream breakages:

See CircleCI build pytorch_linux_xenial_py3_6_gcc5_4_build (1/1)

Step: "Build" (full log | pattern match details) <confirmed not flaky by 2 failures>

Automatic merge failed; fix conflicts and then commit the result.
CONFLICT (add/add): Merge conflict in .circleci/verbatim-sources/pytorch-build-params.yml 
Auto-merging .circleci/verbatim-sources/pytorch-build-params.yml 
CONFLICT (add/add): Merge conflict in .circleci/verbatim-sources/job-specs-custom.yml 
Auto-merging .circleci/verbatim-sources/job-specs-custom.yml 
CONFLICT (add/add): Merge conflict in .circleci/scripts/should_run_job.py 
Auto-merging .circleci/scripts/should_run_job.py 
CONFLICT (add/add): Merge conflict in .circleci/scripts/cpp_doc_push_script.sh 
Auto-merging .circleci/scripts/cpp_doc_push_script.sh 
CONFLICT (add/add): Merge conflict in .circleci/config.yml 
Auto-merging .circleci/config.yml 
Automatic merge failed; fix conflicts and then commit the result. 

This comment was automatically generated by Dr. CI.

@ftian1 ftian1 (Contributor, Author) commented Nov 1, 2019

@mrshenli Thanks for @pietern's suggestion. I modified the code so that the "import cpp_extension" is delayed and happens in a separate test process; it works locally. The latest pre-CI failures, such as "/temp/test_c10d file existed" or "ninja not exist", look like test environment issues rather than bugs in my code. Could you please help review again? Sorry for the long PR thread...

@mrshenli mrshenli (Contributor) commented Nov 1, 2019

Given the following failure, it seems that load has to be called in the parent process; otherwise all child processes will try to create the same file. If the problem is spawn vs fork, could you please try moving the test_collective_apis test to test_c10d_spawn.py?

Nov 01 10:52:23 ERROR:root:Caught exception: 
Nov 01 10:52:23 Traceback (most recent call last):
Nov 01 10:52:23   File "/var/lib/jenkins/workspace/test/common_distributed.py", line 133, in wrapper
Nov 01 10:52:23     fn(self)
Nov 01 10:52:23   File "test_c10d.py", line 3257, in test_collective_apis
Nov 01 10:52:23     self._load_c10d_test_backend()
Nov 01 10:52:23   File "test_c10d.py", line 3253, in _load_c10d_test_backend
Nov 01 10:52:23     sources=[src],
Nov 01 10:52:23   File "/opt/conda/lib/python3.6/site-packages/torch/utils/cpp_extension.py", line 658, in load
Nov 01 10:52:23     build_directory or _get_build_directory(name, verbose),
Nov 01 10:52:23   File "/opt/conda/lib/python3.6/site-packages/torch/utils/cpp_extension.py", line 1015, in _get_build_directory
Nov 01 10:52:23     os.makedirs(build_directory)
Nov 01 10:52:23   File "/opt/conda/lib/python3.6/os.py", line 220, in makedirs
Nov 01 10:52:23     mkdir(name, mode)
Nov 01 10:52:23 FileExistsError: [Errno 17] File exists: '/tmp/torch_extensions/c10dtest'
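
One possible workaround, sketched here only for illustration (not part of the PR): give each process its own build directory via the documented build_directory argument (or TORCH_EXTENSIONS_DIR), so concurrent JIT builds of the same extension do not collide:

    # Illustrative sketch: per-rank build directories avoid the FileExistsError
    # when several processes JIT-build the "c10dtest" extension at the same time.
    import os
    import tempfile
    from torch.utils.cpp_extension import load

    def load_c10d_test_backend(rank):
        build_dir = os.path.join(tempfile.gettempdir(), "torch_extensions",
                                 "c10dtest_rank%d" % rank)
        os.makedirs(build_dir, exist_ok=True)
        return load(name="c10dtest",
                    sources=["cpp_extensions/cpp_c10d_extension.cpp"],
                    build_directory=build_dir,
                    verbose=True)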

@mrshenli mrshenli (Contributor) commented Nov 1, 2019

@yf225

I wonder if the above failure is something we need to fix. Do we actually need some uuid/temp dir/name for extension loading? Otherwise, it seems loading would fail if multiple processes try to load the same extension.

Inline review comments (resolved) on: test/cpp_extensions/cpp_c10d_extension.cpp, docs/source/distributed.rst, test/cpp_extensions/cpp_c10d_extension.hpp, torch/distributed/distributed_c10d.py

@yf225 yf225 (Contributor) commented Nov 12, 2019


@mrshenli Yes I think we should use temp names for extensions loading. I will file an issue for it after reproducing it with a smaller example.

@gottbrath (Contributor)

@pritamdamania87 can you review and approve this, or provide clear guidance on what is needed to get it approved?

@ftian1 ftian1 (Contributor, Author) commented Feb 25, 2020

@VitalyFedyunin Hello, would you please help review the PR?

@VitalyFedyunin (Contributor)

I will pass it to @mrshenli as he is the most experienced in distributed.

@mrshenli mrshenli (Contributor) left a comment

I apologize that I dropped the ball on this, and sorry about the long delay on this PR. Let's expedite the landing process for it.

The current version looks great to me. I have three main comments:

  1. Ideally we should have a well-designed and stable PG interface before we land this PR. But as I am not sure when that will come, let's not block landing this PR on it. If we need to change the PG interface in the future, we will modify this as well. However, if your extension lives outside the PyTorch repo, there might be conflicts. How bad would it be if we break your extension? @ftian1

Also, @agolynski please drop notes if you have comments on this.

  2. Looks like the following will be THE interface for the PG constructor. This might not be sufficient for future PG implementations if they require additional configuration, but we should be able to address that by adding something like backend_options later when the requirement emerges. For now, is there any reason for adding group_name here? NCCL and Gloo do not require a name but do require a timeout. Shall we at least pass the timeout to the PG as well, so that implementations can respect it if they would like to?

            pg = getattr(Backend, backend.upper())(
                prefix_store,
                rank,
                world_size,
                group_name)

  3. Looks like the test requires Ninja. Could you please explain why we need that?

Hey @yf225, do we have ninja installed in all our CI test envs?


When manually importing this backend and invoking :func:`torch.distributed.init_process_group`
with the corresponding backend name, the `torch.distributed` package runs on the new backend.

Contributor:

@agolynski since we are opening this up, we need some text clearly describing what APIs a 3rd-party backend should implement. Could you please prepare a PR for that after we land this one?

Contributor Author:

sure.

const std::string& groupName);

static void ProcessGroupTestConstructor() __attribute__((constructor)) {
py::object module = py::module::import("torch.distributed");

Contributor:

Who is holding the GIL for these functions?

Contributor Author:

According to the pybind11 docs, when calling a C++ function from Python, the GIL is always held.

@@ -30,6 +31,12 @@

skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")

CPP_EXTENSIONS_WARNING = """
Ninja (https://ninja-build.org) must be available to run C++ extensions tests,

Contributor:

Curious, why do we need Ninja for this test?

Contributor Author:

It's because building a torch C++ extension needs Ninja. This is copied from test/run_test.py (you can see the same thing there).

return std::make_shared<ProcessGroupTest::WorkTest>();
}

std::shared_ptr<ProcessGroup::Work> ProcessGroupTest::allreduce_coalesced(

Contributor:

@agolynski do you have plans to make this allreduce_base as well? Let's not further delay this PR. Shall we land this now, and when you add the change and docs for allreduce_base, could you please modify this accordingly if necessary?

Contributor Author:

I will add allreduce_base in the refreshed PR.

Contributor:

What is allreduce_base, given that allreduce already expects just a single tensor?

Contributor:

I will leave it to @agolynski to comment on the plan for allreduce/allgather_base API. :)

prefix_store,
rank,
world_size,
group_name)

Contributor:

Does this mean all 3rd-party backends must use the same signature? What about timeout? And why do we need to pass in group_name here?

Contributor Author:

@mrshenli I understand your concern. In my first local version I added timeout and no group_name, but I later removed it because timeout is not necessary for our c10d backend.

I will add timeout back for all 3rd-party backends.
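
For illustration, a hedged sketch of what the registered creator could look like once timeout is passed through (the "test" backend name, the torch_test_collectives module, and ProcessGroupTest are placeholders for whatever the extension exposes):

    # Illustrative sketch only; the exact creator signature is what this PR is deciding.
    from datetime import timedelta
    import torch.distributed as dist
    # hypothetical: the C++ extension module exposing a ProcessGroup subclass
    import torch_test_collectives

    def create_test_backend(prefix_store, rank, world_size, timeout=timedelta(minutes=30)):
        # forwarding timeout lets the implementation honor per-group deadlines
        return torch_test_collectives.ProcessGroupTest(prefix_store, rank, world_size, timeout)

    dist.Backend.register_backend("test", create_test_backend)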

@mrshenli (Contributor)

Hey @agolynski, could you please check if the API proposed in this PR looks OK to you? Thanks!

@ftian1 ftian1 (Contributor, Author) commented Mar 10, 2020

@mrshenli @agolynski I updated the PR according to your comments; please let me know if it needs further updates. Many thanks for your patience.

@mrshenli mrshenli (Contributor) left a comment

Thanks a lot for putting this together! Overall, it LGTM! I left some minor comments on documentation.

@agolynski will comment on the c10d API.

@@ -395,6 +395,24 @@ of 16
.. autofunction:: all_gather_multigpu


Third-party backends
--------------------

Contributor:

Shall we mark this as experimental for now? You can add something like:

.. warning::
  The third-party backend API is experimental and subject to change.

Contributor Author:

will fix it

return value

@classmethod
def register_backend(cls, name, func):
"""Registers a new backend.

Contributor:

Does it show properly in the doc? Could you please try the following and paste the screenshot here? Thanks!

cd docs/
pip install -r requirements.txt
make html

You can check the result by using this gist to serve your docs and paste a screenshot in this PR.

Contributor Author:

@mrshenli, make html fails with the error below:

FileNotFoundError: [Errno 2] No such file or directory: 'katex': 'katex'

But luckily distributed.html was generated. I've pasted the screenshot for your reference:
https://user-images.githubusercontent.com/16394660/76596980-70036280-653a-11ea-8d57-f36ee54894c2.png



Arguments:
name (str): Backend name matching with the one in `init_process_group()`.
func (function): Function handler that instantiates the backend.

Contributor:

As we expect the function to take certain arguments (prefix_store, rank, world_size, and timeout), can we explain those args here as well?

Ideally, we might want to make this function a first-class citizen as well, with a clear signature and documentation. But since we can mark this feature as experimental, we should have some room to change it in the future when necessary.

Contributor Author:

will add comments here

@mrshenli mrshenli (Contributor) left a comment

Test failure is real:

Mar 10 01:34:51 Traceback (most recent call last):
Mar 10 01:34:51   File "test_determination.py", line 97, in test_torch_file
Mar 10 01:34:51     "test_determination",
Mar 10 01:34:51 AssertionError: Lists differ: ['distributed/test_distributed', 'test_cpp_[75 chars]ion'] != ['test_cpp_extensions_aot_ninja', 'test_cpp[43 chars]ion']
Mar 10 01:34:51 

PyTorch c10d originally supports only built-in backends, such as nccl/gloo/mpi. This patch extends c10d to support third-party communication libraries that derive from the ProcessGroup base class.

The related RFC is pytorch#27955.

With this change, a user just needs to manually import the backend and specify its name when invoking torch.distributed.init_process_group(). The proposed logic checks whether the backend has been registered through torch.distributed.Backend.register_backend(). For how to develop a new third-party backend as a C++ extension, please refer to test/cpp_extensions/cpp_c10d_extension.cpp.

@ftian1 ftian1 (Contributor, Author) commented Mar 13, 2020

@mrshenli I refined the doc/code according to your comments. Please let me know if it needs further updates. Thanks.

@ddkalamk (Contributor)

@mrshenli @agolynski

May I request your time to make progress on this PR? As I mentioned earlier, this PR and #32361 are critical for our work with @srinivas212 and @dmudiger on distributed DLRM training on CPUs.
Also, with a small change, this PR would enable building the MPI backend for use with the binary distribution of PyTorch, saving the pain of building all of PyTorch from source just to use the MPI backend.

@mrshenli (Contributor)

Sorry about the delay. @agolynski is working on the last pass of the review and if everything looks good, he will land this soon.

Thanks for contributing!

@facebook-github-bot facebook-github-bot left a comment

@agolynski has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@esaliya esaliya (Contributor) commented Aug 4, 2021

@ftian1 I am trying out a third-party backend using a custom extension as implemented in this PR. However, new_group() does not pass the list of ranks in the sub-group to the registered method in the extension. Could you suggest how to get it? Here's the function call I am referring to:

pg = getattr(Backend, backend.upper())(

@ftian1 ftian1 (Contributor, Author) commented Aug 4, 2021

@esaliya Could you please refer to this example? https://github.com/intel/torch-ccl/blob/master/torch_ccl/csrc/init.cpp

It's a backend we added for this PR.

@esaliya esaliya (Contributor) commented Aug 4, 2021

Thanks @ftian1. I looked at it, but it too doesn't have a way to get the list of ranks when creating a new_group(). For example,

dist.init_process_group('ccl', rank=3, world_size=8)
dist.new_group(ranks=[0,1,2])

in the second line, the list of ranks, [0,1,2], is passed to https://pytorch.org/docs/stable/_modules/torch/distributed/distributed_c10d.html#new_group, but it is not passed to the ccl backend here:

pg = getattr(Backend, backend.upper())(

Any thoughts on this?

@ftian1 ftian1 (Contributor, Author) commented Aug 5, 2021

OK, I understand. When I added this dynamic loading mechanism, I didn't take sub-group/group ranks into account.

So if your backend needs that capability, you will have to contribute a PR to add it.

@esaliya esaliya (Contributor) commented Aug 5, 2021

Thanks for confirming this.

Labels: oncall: distributed, open source, triaged
Projects: None yet
Development

Successfully merging this pull request may close these issues.

[RFC][distributed] RFC: c10d ProcessGroup extension and C++ API change