Skip to content

Commit

Permalink
Clang format dist_utils.py and rpc/__init__.py (pytorch#56853)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#56853

ghstack-source-id: 127412640

Test Plan: N/A

Reviewed By: rohan-varma

Differential Revision: D27984669

fbshipit-source-id: 8e89ba0c53107622b3ca29ea296226e260b251df
  • Loading branch information
Yi Wang authored and Kushashwa Shrimali committed May 18, 2021
1 parent 7e75577 commit acbb5cd
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 29 deletions.
34 changes: 18 additions & 16 deletions torch/distributed/rpc/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import logging
import threading
import warnings

from typing import Generator, Tuple

import torch
import torch.distributed as dist

Expand All @@ -13,6 +13,7 @@
_init_counter = 0
_init_counter_lock = threading.Lock()


def is_available():
return hasattr(torch._C, "_rpc_init")

Expand All @@ -22,7 +23,7 @@ def is_available():


if is_available():
from . import api, backend_registry, functions
from torch._C._distributed_c10d import Store
from torch._C._distributed_rpc import (
_disable_jit_rref_pickle,
_enable_jit_rref_pickle,
Expand Down Expand Up @@ -61,16 +62,18 @@ def is_available():
_UNSET_RPC_TIMEOUT,
_DEFAULT_RPC_TIMEOUT_SEC,
) # noqa: F401
from torch._C._distributed_c10d import Store

from . import api, backend_registry, functions
from .api import * # noqa: F401,F403
from .options import TensorPipeRpcBackendOptions # noqa: F401
import numbers

import torch.distributed.autograd as dist_autograd

from .backend_registry import BackendType
from .options import TensorPipeRpcBackendOptions # noqa: F401
from .server_process_global_profiler import (
_server_process_global_profile,
)
import torch.distributed.autograd as dist_autograd

import numbers

rendezvous_iterator: Generator[Tuple[Store, int, int], None, None]

Expand Down Expand Up @@ -111,12 +114,14 @@ def init_rpc(
are available.
"""

if backend is not None and not isinstance(backend, backend_registry.BackendType):
raise TypeError(
"Argument backend must be a member of BackendType"
)
if backend is not None and not isinstance(
backend, backend_registry.BackendType
):
raise TypeError("Argument backend must be a member of BackendType")

if rpc_backend_options is not None and not isinstance(rpc_backend_options, RpcBackendOptions):
if rpc_backend_options is not None and not isinstance(
rpc_backend_options, RpcBackendOptions
):
raise TypeError(
"Argument rpc_backend_options must be an instance of RpcBackendOptions"
)
Expand Down Expand Up @@ -182,7 +187,7 @@ def init_rpc(
# Use a PrefixStore to distinguish multiple invocations.
with _init_counter_lock:
global _init_counter
store = dist.PrefixStore(str('rpc_prefix_{}'.format(_init_counter)), store)
store = dist.PrefixStore(str("rpc_prefix_{}".format(_init_counter)), store)
_init_counter += 1

# Initialize autograd before RPC since _init_rpc_backend guarantees all
Expand All @@ -197,7 +202,6 @@ def init_rpc(
# Initialize RPC.
_init_rpc_backend(backend, store, name, rank, world_size, rpc_backend_options)


def _validate_rpc_args(backend, store, name, rank, world_size, rpc_backend_options):
type_mapping = {
backend: backend_registry.BackendType,
Expand All @@ -215,7 +219,6 @@ def _validate_rpc_args(backend, store, name, rank, world_size, rpc_backend_optio
)
)


def _init_rpc_backend(
backend=BackendType.TENSORPIPE, # type: ignore[attr-defined]
store=None,
Expand All @@ -242,7 +245,6 @@ def _init_rpc_backend(

api._init_rpc_states(rpc_agent)


@api._require_initialized
def _get_debug_info():
info = _rref_context_get_debug_info()
Expand Down
44 changes: 31 additions & 13 deletions torch/testing/_internal/dist_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

import re
import sys
import time
Expand All @@ -24,22 +23,31 @@ def single_threaded_process_group_agent(f):
Forces ProcessGroupAgent to use only a single thread in the ThreadPool for
sending and processing requests.
"""

@wraps(f)
def wrapper(self, *args, **kwargs):
backend_type = self.rpc_backend
if backend_type == rpc.backend_registry.BackendType["PROCESS_GROUP"]:
self.rpc_backend_options = rpc.backend_registry.construct_rpc_backend_options(
self.rpc_backend,
init_method=self.init_method,
num_send_recv_threads=1,
self.rpc_backend_options = (
rpc.backend_registry.construct_rpc_backend_options(
self.rpc_backend,
init_method=self.init_method,
num_send_recv_threads=1,
)
)
return_value = f(self, *args, **kwargs)
return return_value

return wrapper


def dist_init(old_test_method=None, setup_rpc: bool = True, clean_shutdown: bool = True,
faulty_messages=None, messages_to_delay=None):
def dist_init(
old_test_method=None,
setup_rpc: bool = True,
clean_shutdown: bool = True,
faulty_messages=None,
messages_to_delay=None,
):
"""
We use this decorator for setting up and tearing down state since
MultiProcessTestCase runs each `test*` method in a separate process and
Expand Down Expand Up @@ -73,6 +81,7 @@ def new_test_method(self, *arg, **kwargs):
# Setting _ignore_rref_leak to make sure OwnerRRefs are properly deleted
# in tests.
import torch.distributed.rpc.api as api

api._ignore_rref_leak = False

self.worker_id = self.rank
Expand Down Expand Up @@ -101,15 +110,16 @@ def new_test_method(self, *arg, **kwargs):
def noop() -> None:
pass


def wait_until_node_failure(rank: int, expected_error_regex: str = ".*") -> str:
'''
"""
Loops until an RPC to the given rank fails. This is used to
indicate that the node has failed in unit tests.
Args:
rank (int): Rank of the node expected to fail
expected_error_regex (optional, str): Regex of exception message expected. Useful to ensure a specific failure
occurs, not just any.
'''
"""
while True:
try:
rpc.rpc_sync("worker{}".format(rank), noop, args=())
Expand All @@ -120,7 +130,7 @@ def wait_until_node_failure(rank: int, expected_error_regex: str = ".*") -> str:


def wait_until_pending_futures_and_users_flushed(timeout: int = 20) -> None:
'''
"""
The RRef protocol holds forkIds of rrefs in a map until those forks are
confirmed by the owner. The message confirming the fork may arrive after
our tests check whether this map is empty, which leads to failures and
Expand All @@ -129,7 +139,7 @@ def wait_until_pending_futures_and_users_flushed(timeout: int = 20) -> None:
loops until the map is empty, which means the messages have been received
as processed. Call this function before asserting the map returned by
_get_debug_info is empty.
'''
"""
start = time.time()
while True:
debug_info = _rref_context_get_debug_info()
Expand Down Expand Up @@ -157,7 +167,9 @@ def get_num_owners_and_forks() -> Tuple[str, str]:
return num_owners, num_forks


def wait_until_owners_and_forks_on_rank(num_owners: int, num_forks: int, rank: int, timeout: int = 20) -> None:
def wait_until_owners_and_forks_on_rank(
num_owners: int, num_forks: int, rank: int, timeout: int = 20
) -> None:
"""
Waits until timeout for num_forks and num_owners to exist on the rank. Used
to ensure proper deletion of RRefs in tests.
Expand All @@ -175,7 +187,11 @@ def wait_until_owners_and_forks_on_rank(num_owners: int, num_forks: int, rank: i
if time.time() - start > timeout:
raise ValueError(
"Timed out waiting {} sec for {} owners and {} forks on rank, had {} owners and {} forks".format(
timeout, num_owners, num_forks, num_owners_on_rank, num_forks_on_rank
timeout,
num_owners,
num_forks,
num_owners_on_rank,
num_forks_on_rank,
)
)

Expand All @@ -192,9 +208,11 @@ def initialize_pg(init_method, rank: int, world_size: int) -> None:
world_size=world_size,
)


def worker_name(rank: int) -> str:
return "worker{}".format(rank)


def get_function_event(function_events, partial_event_name):
"""
Returns the first event that matches partial_event_name in the provided
Expand Down

0 comments on commit acbb5cd

Please sign in to comment.