diff --git a/tests/internal/multi_process_v2.py b/tests/internal/multi_process_v2.py
new file mode 100644
index 000000000..77a67983b
--- /dev/null
+++ b/tests/internal/multi_process_v2.py
@@ -0,0 +1,421 @@
+# This file is adapted from https://github.com/pytorch/pytorch/blob/v1.13.0/torch/testing/_internal/common_distributed.py
+import faulthandler
+import logging
+import multiprocessing
+import os
+import pickle
+import sys
+import tempfile
+import threading
+import time
+import traceback
+import types
+import unittest
+
+from enum import Enum
+from functools import wraps
+from typing import NamedTuple
+
+import torch
+import bagua.torch_api as bagua
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+class TestResult(NamedTuple):
+    exit_code: int
+    message: str
+
+
+TEST_SKIPS = {
+    "multi-gpu-1": TestResult(75, "Need at least 1 CUDA device"),
+    "multi-gpu-2": TestResult(77, "Need at least 2 CUDA devices"),
+    "multi-gpu-3": TestResult(80, "Need at least 3 CUDA devices"),
+    "multi-gpu-4": TestResult(81, "Need at least 4 CUDA devices"),
+    "multi-gpu-5": TestResult(82, "Need at least 5 CUDA devices"),
+    "multi-gpu-6": TestResult(83, "Need at least 6 CUDA devices"),
+    "multi-gpu-7": TestResult(84, "Need at least 7 CUDA devices"),
+    "multi-gpu-8": TestResult(85, "Need at least 8 CUDA devices"),
+    "generic": TestResult(
+        86, "Test skipped at subprocess level, look at subprocess log for skip reason"
+    ),
+}
+
+
+def make_success_result(msg: str):
+    return TestResult(0, msg)
+
+
+def make_error_result(msg: str):
+    return TestResult(255, msg)
+
+
+def skip_if_lt_x_gpu(x):
+    def decorator(func):
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            if torch.cuda.is_available() and torch.cuda.device_count() >= x:
+                return func(*args, **kwargs)
+            sys.exit(TEST_SKIPS[f"multi-gpu-{x}"].exit_code)
+
+        return wrapper
+
+    return decorator
+
+
+# [How does MultiProcessTestCase work?]
+# Each MultiProcessTestCase instance uses 1 + `world_size` processes; by
+# default `world_size` returns 4. Take `test_rpc_spawn.py`, which inherits
+# from this class, as an example. Its `setUp()` method calls into
+# `MultiProcessTestCase._spawn_processes()`, which spawns `world_size`
+# subprocesses. During the spawn, the main process passes the test name to
+# the subprocesses, and the name is acquired from self.id(). The subprocesses
+# then use the provided test function name to retrieve the function attribute
+# from the test instance and run it. The main process simply waits for all
+# subprocesses to join.
+
+
+class MultiProcessTestCase(unittest.TestCase):
+    MAIN_PROCESS_RANK = -1
+    # This exit code is used to indicate that the test code had an error and
+    # exited abnormally. There are certain tests that might use sys.exit() to
+    # simulate failures and in those cases, we can't have an exit code of 0,
+    # but we still want to ensure we didn't run into any other errors.
+    TEST_ERROR_EXIT_CODE = 10
+
+    def _get_timeout(self):
+        return 300
+
+    def _init_bagua_distributed(self):
+        logger.info("rank: {}, world_size: {}".format(self.rank, self.world_size))
+
+        torch.cuda.set_device(self.rank)
+        store = torch.distributed.FileStore(self.file_name, self.world_size)
+        bagua.init_process_group(
+            store,
+            rank=self.rank,
+            world_size=self.world_size,
+            local_world_size=self.world_size,
+        )
+
+    @property
+    def world_size(self) -> int:
+        return 4
+
+    def join_or_run(self, fn):
+        @wraps(fn)
+        def wrapper(self):
+            if self.rank == self.MAIN_PROCESS_RANK:
+                return self._join_processes(fn)
+            else:
+                return fn()
+
+        return types.MethodType(wrapper, self)
+
+    # The main process spawns N subprocesses that run the test.
+    # Constructor patches current instance test method to
+    # assume the role of the main process and join its subprocesses,
+    # or run the underlying test function.
+    def __init__(self, method_name: str = "runTest") -> None:
+        super().__init__(method_name)
+        fn = getattr(self, method_name)
+        setattr(self, method_name, self.join_or_run(fn))
+
+    def setUp(self) -> None:
+        super().setUp()
+        self.skip_return_code_checks = []  # type: ignore[var-annotated]
+        self.processes = []  # type: ignore[var-annotated]
+        self.rank = self.MAIN_PROCESS_RANK
+        self.file_name = tempfile.NamedTemporaryFile(delete=False).name
+        # pid to pipe consisting of error message from process.
+        self.pid_to_pipe = {}  # type: ignore[var-annotated]
+
+    def tearDown(self) -> None:
+        super().tearDown()
+        for p in self.processes:
+            p.terminate()
+        # Each Process instance holds a few open file descriptors. The unittest
+        # runner creates a new TestCase instance for each test method and keeps
+        # it alive until the end of the entire suite. We must thus reset the
+        # processes to prevent an effective file descriptor leak.
+        self.processes = []
+
+    def _check_result(self, test_id=None):
+        pass
+
+    def _current_test_name(self) -> str:
+        # self.id() == e.g. '__main__.TestDistributed.TestAdditive.test_get_rank'
+        return self.id().split(".")[-1]
+
+    def _start_processes(self, proc) -> None:
+        self.processes = []
+        for rank in range(int(self.world_size)):
+            parent_conn, child_conn = torch.multiprocessing.Pipe()
+            process = proc(
+                target=self.__class__._run,
+                name="process " + str(rank),
+                args=(rank, self._current_test_name(), self.file_name, child_conn),
+            )
+            process.start()
+            logger.info(f"Started process {rank} with pid {process.pid}")
+            self.pid_to_pipe[process.pid] = parent_conn
+            self.processes.append(process)
+
+    def _spawn_processes(self) -> None:
+        proc = torch.multiprocessing.get_context("spawn").Process
+        self._start_processes(proc)
+
+    class Event(Enum):
+        GET_TRACEBACK = 1
+
+    @staticmethod
+    def _event_listener(parent_pipe, signal_pipe, rank: int):
+        logger.info(f"Starting event listener thread for rank {rank}")
+        while True:
+            ready_pipes = multiprocessing.connection.wait([parent_pipe, signal_pipe])
+
+            if parent_pipe in ready_pipes:
+                if parent_pipe.closed:
+                    logger.info(
+                        f"Pipe closed for process {rank}, stopping event listener thread"
+                    )
+                    return
+
+                event = parent_pipe.recv()
+                logger.info(f"Received event {event} on process {rank}")
+
+                if event == MultiProcessTestCase.Event.GET_TRACEBACK:
+                    # Return traceback to the parent process.
+                    with tempfile.NamedTemporaryFile(mode="r+") as tmp_file:
+                        faulthandler.dump_traceback(tmp_file)
+                        # Flush buffers and seek to read from the beginning
+                        tmp_file.flush()
+                        tmp_file.seek(0)
+                        parent_pipe.send(make_error_result(tmp_file.read()))
+
+                    logger.info(f"Process {rank} sent traceback")
+
+            if signal_pipe in ready_pipes:
+                return
+
+    @classmethod
+    def _run(cls, rank: int, test_name: str, file_name: str, parent_pipe) -> None:
+        # Enable DDP + ReplicatedTensor
+        from torch.nn.parallel._replicated_tensor_ddp_utils import (
+            _set_ddp_with_replicated_tensor,
+        )
+
+        _set_ddp_with_replicated_tensor(True)
+
+        self = cls(test_name)
+
+        self.rank = rank
+        self.file_name = file_name
+        self.run_test(test_name, parent_pipe)
+
+    def run_test(self, test_name: str, parent_pipe) -> None:
+        # Start event listener thread.
+        signal_recv_pipe, signal_send_pipe = torch.multiprocessing.Pipe(duplex=False)
+        event_listener_thread = threading.Thread(
+            target=MultiProcessTestCase._event_listener,
+            args=(parent_pipe, signal_recv_pipe, self.rank),
+            daemon=True,
+        )
+        event_listener_thread.start()
+        if sys.platform != "win32" and sys.platform != "darwin":
+            # Register signal handler to dump stack traces on FATALs.
+            # Windows and MacOS do not support the signal handlers.
+            torch._C._set_print_stack_traces_on_fatal_signal(True)
+        # Show full C++ stacktraces when a Python error originating from C++ is raised.
+        os.environ["TORCH_SHOW_CPP_STACKTRACES"] = "1"
+
+        # self.id() == e.g. '__main__.TestDistributed.test_get_rank'
+        # We're retrieving a corresponding test and executing it.
+        ret = None
+        try:
+            ret = getattr(self, test_name)()
+        except unittest.SkipTest as se:
+            logger.info(
+                f"Process {self.rank} skipping test {test_name} for following reason: {str(se)}"
+            )
+            sys.exit(TEST_SKIPS["generic"].exit_code)
+        except Exception:
+            logger.error(
+                f"Caught exception: \n{traceback.format_exc()} exiting "
+                f"process {self.rank} with exit code: {MultiProcessTestCase.TEST_ERROR_EXIT_CODE}"
+            )
+            # Send error to parent process.
+            parent_pipe.send(make_error_result(traceback.format_exc()))
+            sys.exit(MultiProcessTestCase.TEST_ERROR_EXIT_CODE)
+        finally:
+            if ret is not None:
+                parent_pipe.send(make_success_result(pickle.dumps(ret)))
+
+            if signal_send_pipe is not None:
+                signal_send_pipe.send(None)
+
+            assert event_listener_thread is not None
+            event_listener_thread.join()
+            # Close pipe after done with test.
+            parent_pipe.close()
+
+    def _get_timedout_process_traceback(self) -> None:
+        pipes = []
+        for i, process in enumerate(self.processes):
+            if process.exitcode is None:
+                pipe = self.pid_to_pipe[process.pid]
+                try:
+                    pipe.send(MultiProcessTestCase.Event.GET_TRACEBACK)
+                    pipes.append((i, pipe))
+                except ConnectionError as e:
+                    logger.error(
+                        f"Encountered error while trying to get traceback for process {i}: {e}"
+                    )
+
+        # Wait for results.
+        for rank, pipe in pipes:
+            try:
+                # Wait for traceback
+                if pipe.poll(5):
+                    if pipe.closed:
+                        logger.info(
+                            f"Pipe closed for process {rank}, cannot retrieve traceback"
+                        )
+                        continue
+
+                    traceback = pipe.recv()
+                    logger.error(
+                        f"Process {rank} timed out with traceback: \n\n{traceback}"
+                    )
+                else:
+                    logger.error(
+                        f"Could not retrieve traceback for timed out process: {rank}"
+                    )
+            except ConnectionError as e:
+                logger.error(
+                    f"Encountered error while trying to get traceback for process {rank}: {e}"
+                )
+
+    def _join_processes(self, fn) -> None:
+        timeout = self._get_timeout()
+        start_time = time.time()
+        subprocess_error = False
+        try:
+            while True:
+                # check to see if any subprocess exited with an error early.
+                for i, p in enumerate(self.processes):
+                    # This is the exit code processes exit with if they
+                    # encountered an exception.
+                    if p.exitcode == MultiProcessTestCase.TEST_ERROR_EXIT_CODE:
+                        print(
+                            f"Process {i} terminated with exit code {p.exitcode}, terminating remaining processes."
+                        )
+                        active_children = torch.multiprocessing.active_children()
+                        for ac in active_children:
+                            ac.terminate()
+                        subprocess_error = True
+                        break
+                if subprocess_error:
+                    break
+
+                # All processes have joined cleanly if they all have a valid exitcode
+                if all([p.exitcode is not None for p in self.processes]):
+                    break
+
+                # Check if we should time out the test. If so, we terminate each process.
+                elapsed = time.time() - start_time
+                if elapsed > timeout:
+                    self._get_timedout_process_traceback()
+                    print(
+                        f"Timing out after {timeout} seconds and killing subprocesses."
+                    )
+                    for p in self.processes:
+                        p.terminate()
+                    break
+                # Sleep to avoid excessive busy polling.
+                time.sleep(0.1)
+
+            elapsed_time = time.time() - start_time
+
+            if fn in self.skip_return_code_checks:
+                self._check_no_test_errors(elapsed_time)
+            else:
+                self._check_return_codes(elapsed_time)
+        finally:
+            # Close all pipes
+            for pid, pipe in self.pid_to_pipe.items():
+                pipe.close()
+
+    def _check_no_test_errors(self, elapsed_time) -> None:
+        """
+        Checks that we didn't have any errors thrown in the child processes.
+        """
+        for i, p in enumerate(self.processes):
+            if p.exitcode is None:
+                raise RuntimeError(
+                    "Process {} timed out after {} seconds".format(i, elapsed_time)
+                )
+            self.assertNotEqual(self.TEST_ERROR_EXIT_CODE, p.exitcode)
+
+    def _check_return_codes(self, elapsed_time) -> None:
+        """
+        Checks that the return codes of all spawned processes match, and skips
+        tests if they returned a return code indicating a skipping condition.
+        """
+        first_process = self.processes[0]
+        # first, we check if there are errors in actual processes
+        # (via TEST_ERROR_EXIT_CODE), and raise an exception for those.
+        # the reason we do this is to attempt to raise a more helpful error
+        # message than "Process x terminated/timed out"
+        # TODO: we should pipe the exception of the failed subprocess here.
+        # Currently, the actual exception is displayed as a logging output.
+        errored_processes = [
+            (i, p)
+            for i, p in enumerate(self.processes)
+            if p.exitcode == MultiProcessTestCase.TEST_ERROR_EXIT_CODE
+        ]
+        if errored_processes:
+            error = ""
+            for i, process in errored_processes:
+                # Get error from pipe.
+                error_message = self.pid_to_pipe[process.pid].recv()
+                error += (
+                    "Process {} exited with error code {} and exception:\n{}\n".format(
+                        i, MultiProcessTestCase.TEST_ERROR_EXIT_CODE, error_message
+                    )
+                )

+            raise RuntimeError(error)
+        # If no process exited uncleanly, we check for timeouts, and then ensure
+        # each process exited cleanly.
+        for i, p in enumerate(self.processes):
+            if p.exitcode is None:
+                raise RuntimeError(
+                    "Process {} terminated or timed out after {} seconds".format(
+                        i, elapsed_time
+                    )
+                )
+            self.assertEqual(
+                p.exitcode,
+                first_process.exitcode,
+                msg="Expect process {} exit code to match Process 0 exit code of {}, but got {}".format(
+                    i, first_process.exitcode, p.exitcode
+                ),
+            )
+        for skip in TEST_SKIPS.values():
+            if first_process.exitcode == skip.exit_code:
+                raise unittest.SkipTest(skip.message)
+
+        self.assertEqual(
+            first_process.exitcode,
+            0,
+            msg="Expected zero exit code but got {} for pid: {}".format(
+                first_process.exitcode, first_process.pid
+            ),
+        )
+        self._check_result(self._current_test_name())
+
+    @property
+    def is_master(self) -> bool:
+        return self.rank == 0
diff --git a/tests/torch_api/data_parallel/test_async_model_average.py b/tests/torch_api/data_parallel/test_async_model_average.py
index 9ac3ebfd3..a37dfbcb9 100644
--- a/tests/torch_api/data_parallel/test_async_model_average.py
+++ b/tests/torch_api/data_parallel/test_async_model_average.py
@@ -1,14 +1,16 @@
+import logging
+import os
+import unittest
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from tests.internal.common_utils import find_free_port
-import unittest
-import multiprocessing
-import os
 import bagua.torch_api as bagua
-from tests import skip_if_cuda_not_available
-import logging
+
 from bagua.torch_api.data_parallel import DistributedDataParallel as DDP
+from tests.internal.multi_process_v2 import MultiProcessTestCase, skip_if_lt_x_gpu
+
+logger = logging.getLogger(__name__)


 class Net(nn.Module):
@@ -26,24 +28,10 @@ def forward(self, x):
         return F.softmax(x, dim=1)


-def run_model_wrapper(rank, env, fn, warmup_steps):
-    # initialize subprocess env
-    os.environ["WORLD_SIZE"] = env["WORLD_SIZE"]
-    os.environ["LOCAL_WORLD_SIZE"] = env["LOCAL_WORLD_SIZE"]
-    os.environ["MASTER_ADDR"] = env["MASTER_ADDR"]
-    os.environ["MASTER_PORT"] = env["MASTER_PORT"]
-    os.environ["BAGUA_SERVICE_PORT"] = env["BAGUA_SERVICE_PORT"]
-    os.environ["RANK"] = str(rank)
-    os.environ["LOCAL_RANK"] = str(rank)
-
-    # init bagua distributed process group
-    torch.cuda.set_device(rank)
-    bagua.init_process_group()
-
-    # construct model and optimizer, etc.
+def create_model_and_optimizer(warmup_steps):
+    # construct model and optimizer
     model = Net().cuda()
     optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
-    loss_fn = nn.MSELoss()

     # wrap model
     algorithm = bagua.algorithms.async_model_average.AsyncModelAverageAlgorithm(
@@ -52,84 +40,62 @@ def run_model_wrapper(rank, env, fn, warmup_steps):
     )
     ddp_model = DDP(model, optimizers=[optimizer], algorithm=algorithm)

-    fn(ddp_model, optimizer, loss_fn)
+    return ddp_model, optimizer


-def train_epoch(epoch, model, optimizer, loss_fn):
-    logging.debug("Training epoch {}".format(epoch))
+def train_epoch(epoch, model, optimizer):
+    logger.debug("Training epoch {}".format(epoch))
     for _ in range(10):
         data = torch.randn(4, 2).cuda()
         target = torch.randn(4, 4).cuda()

         optimizer.zero_grad()
         output = model(data)
-        loss = loss_fn(output, target)
+        loss = nn.MSELoss()(output, target)

         loss.backward()
         optimizer.step()


-def run_epochs(model, optimizer, loss_fn):
-    for epoch in range(5):
-        train_epoch(epoch, model, optimizer, loss_fn)
-    model.bagua_algorithm.abort(model)
+class TestAsyncModelAverage(MultiProcessTestCase):
+    def setUp(self):
+        super(TestAsyncModelAverage, self).setUp()
+        self._spawn_processes()

+    def tearDown(self):
+        super(TestAsyncModelAverage, self).tearDown()
+        try:
+            os.remove(self.file_name)
+        except OSError:
+            pass

-def run_multiple_aborts(model, optimizer, loss_fn):
-    for epoch in range(10):
-        model.bagua_algorithm.resume(model)
-        model.bagua_algorithm.resume(model)
-        train_epoch(epoch, model, optimizer, loss_fn)
-        model.bagua_algorithm.abort(model)
-        model.bagua_algorithm.abort(model)
-
+    @property
+    def world_size(self) -> int:
+        return 4

-class TestAsyncModelAverage(unittest.TestCase):
-    @skip_if_cuda_not_available()
+    @skip_if_lt_x_gpu(4)
     def test_algorithm(self):
-        nprocs = torch.cuda.device_count()
-        env = {
-            "WORLD_SIZE": str(nprocs),
-            "LOCAL_WORLD_SIZE": str(nprocs),
-            "MASTER_ADDR": "127.0.0.1",
-            "MASTER_PORT": str(find_free_port(8000, 8100)),
-            "BAGUA_SERVICE_PORT": str(find_free_port(9000, 9100)),
-        }
-
-        mp = multiprocessing.get_context("spawn")
-        processes = []
-        for i in range(nprocs):
-            p = mp.Process(target=run_model_wrapper, args=(i, env, run_epochs, 0))
-            p.start()
-            processes.append(p)
-
-        for p in processes:
-            p.join(timeout=60)
-            self.assertTrue(p.exitcode == 0)
-
-    @skip_if_cuda_not_available()
+        self._init_bagua_distributed()
+        model, optimizer = create_model_and_optimizer(warmup_steps=0)
+
+        for epoch in range(100):
+            train_epoch(epoch, model, optimizer)
+        model.bagua_algorithm.abort(model)
+
+    @skip_if_lt_x_gpu(2)
     def test_multiple_aborts(self):
-        nprocs = torch.cuda.device_count()
-        env = {
-            "WORLD_SIZE": str(nprocs),
-            "LOCAL_WORLD_SIZE": str(nprocs),
-            "MASTER_ADDR": "127.0.0.1",
-            "MASTER_PORT": str(find_free_port(8000, 8100)),
-            "BAGUA_SERVICE_PORT": str(find_free_port(9000, 9100)),
-        }
-
-        mp = multiprocessing.get_context("spawn")
-        processes = []
-        for i in range(nprocs):
-            p = mp.Process(
-                target=run_model_wrapper, args=(i, env, run_multiple_aborts, 10)
-            )
-            p.start()
-            processes.append(p)
-
-        for p in processes:
-            p.join(timeout=60)
-            self.assertTrue(p.exitcode == 0)
+        self._init_bagua_distributed()
+        model, optimizer = create_model_and_optimizer(warmup_steps=10)
+
+        for i in range(2):
+            model.bagua_algorithm.resume(model)
+            model.bagua_algorithm.abort(model)
+            model.bagua_algorithm.resume(model)
+            for epoch in range(100):
+                train_epoch(i * 100 + epoch, model, optimizer)
+
+            model.bagua_algorithm.abort(model)
+            model.bagua_algorithm.abort(model)


 if __name__ == "__main__":
diff --git a/tests/torch_api/data_parallel/test_gradient_allreduce.py b/tests/torch_api/data_parallel/test_gradient_allreduce.py
index 5bcfd15ae..3c1367916 100644
--- a/tests/torch_api/data_parallel/test_gradient_allreduce.py
+++ b/tests/torch_api/data_parallel/test_gradient_allreduce.py
@@ -1,14 +1,17 @@
+import logging
+import os
+import pickle
+import unittest
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from tests.internal.common_utils import find_free_port
-from tests.internal.multi_process import setup_bagua_env
-import unittest
-import multiprocessing
-from bagua.torch_api.utils import flatten
 import bagua.torch_api as bagua
-from tests import skip_if_cuda_not_available
+
 from bagua.torch_api.data_parallel import DistributedDataParallel as DDP
+from tests.internal.multi_process_v2 import MultiProcessTestCase, skip_if_lt_x_gpu
+
+logger = logging.getLogger(__name__)


 class Net(nn.Module):
@@ -26,31 +29,14 @@ def forward(self, x):
         return F.softmax(x, dim=1)


-def _init_bagua_env(rank, env):
-    # set deterministic
-    torch.backends.cudnn.benchmark = False
-    torch.backends.cudnn.deterministic = True
-    torch.manual_seed(rank)
-    # initialize subprocess env
-    setup_bagua_env(rank, env)
-
-
-def run_model(
-    rank,
-    nprocs,
-    hierarchical,
-    results,
-    env,
-):
-    _init_bagua_env(rank, env)
-
+def run_model(hierarchical):
     # construct model and optimizer, etc.
     model = Net().cuda()
     optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
     loss_fn = nn.MSELoss()

     def run_epochs(num_epochs):
-        for epoch in range(num_epochs):
+        for _ in range(num_epochs):
             data = torch.randn(4, 2).cuda()
             target = torch.randn(4, 4).cuda()

@@ -74,71 +60,58 @@ def run_epochs(num_epochs):

     run_epochs(10)

-    ret = results[rank]
+    flattened_weight = bagua.utils.flatten([param.data for param in model.parameters()])
+    weight_norm = flattened_weight.norm().item()
+    return weight_norm

-    ret._weight.copy_(flatten([param.data for param in model.parameters()]))

+class TestGradientAllReduce(MultiProcessTestCase):
+    def setUp(self):
+        super(TestGradientAllReduce, self).setUp()
+        self._spawn_processes()

-class Result(object):
-    def __init__(self):
-        model = Net()
-        self._weight = flatten(
-            [torch.zeros_like(param.data) for param in model.parameters()]
-        )
-
-
-class TestGradientAllReduce(unittest.TestCase):
-    def run_test_locally(
-        self,
-        nprocs,
-        hierarchical,
-    ):
-        env = {
-            "WORLD_SIZE": str(nprocs),
-            "LOCAL_WORLD_SIZE": str(nprocs),
-            "MASTER_ADDR": "127.0.0.1",
-            "MASTER_PORT": str(find_free_port(8000, 8100)),
-            "BAGUA_SERVICE_PORT": str(find_free_port(9000, 9100)),
-        }
-
-        mp = multiprocessing.get_context("spawn")
-        results = [Result() for _ in range(nprocs)]
-        processes = []
-        for i in range(nprocs):
-            p = mp.Process(
-                target=run_model,
-                args=(
-                    i,
-                    nprocs,
-                    hierarchical,
-                    results,
-                    env,
-                ),
-            )
-            p.start()
-            processes.append(p)
-
-        for p in processes:
-            p.join(timeout=60)
-            self.assertTrue(p.exitcode == 0)
-
-        for rank in range(nprocs):
-            peer_rank = (rank + 1) % nprocs
-            # all workers have equal weights
-            self.assertTrue(
-                torch.equal(
-                    results[rank]._weight,
-                    results[peer_rank]._weight,
-                )
-            )
-
-    @skip_if_cuda_not_available()
+    def tearDown(self):
+        super(TestGradientAllReduce, self).tearDown()
+        try:
+            os.remove(self.file_name)
+        except OSError:
+            pass
+
+    def _check_result(self, test_id=None):
+        result = None
+        for i, process in enumerate(self.processes):
+            _, msg = self.pid_to_pipe[process.pid].recv()
+            weight_norm = pickle.loads(msg)
+
+            logger.info("process {} result: {}".format(i, weight_norm))
+            if result is None:
+                result = weight_norm
+            else:
+                assert result == weight_norm
+
+    @property
+    def world_size(self) -> int:
+        return 4
+
+    @skip_if_lt_x_gpu(4)
     def test_algorithm(self):
-        nprocs = torch.cuda.device_count()
-        self.run_test_locally(
-            nprocs=nprocs,
-            hierarchical=False,
-        )
+        # set deterministic
+        torch.backends.cudnn.benchmark = False
+        torch.backends.cudnn.deterministic = True
+        torch.manual_seed(self.rank)
+
+        self._init_bagua_distributed()
+        return run_model(hierarchical=False)
+
+    @skip_if_lt_x_gpu(4)
+    def test_algorithm_hierarchical(self):
+        # set deterministic
+        torch.backends.cudnn.benchmark = False
+        torch.backends.cudnn.deterministic = True
+        torch.manual_seed(self.rank)
+
+        self._init_bagua_distributed()
+        return run_model(hierarchical=True)


 if __name__ == "__main__":
diff --git a/tests/torch_api/test_async_model_average.py b/tests/torch_api/test_async_model_average.py
index 6b2de46d2..538f50d9e 100644
--- a/tests/torch_api/test_async_model_average.py
+++ b/tests/torch_api/test_async_model_average.py
@@ -1,13 +1,15 @@
+import logging
+import os
+import unittest
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from tests.internal.common_utils import find_free_port
-import unittest
-import multiprocessing
-import os
 import bagua.torch_api as bagua
-from tests import skip_if_cuda_not_available
-import logging
+
+from tests.internal.multi_process_v2 import MultiProcessTestCase, skip_if_lt_x_gpu
+
+logger = logging.getLogger(__name__)


 class Net(nn.Module):
@@ -25,24 +27,10 @@ def forward(self, x):
         return F.softmax(x, dim=1)


-def run_model_wrapper(rank, env, fn, warmup_steps):
-    # initialize subprocess env
-    os.environ["WORLD_SIZE"] = env["WORLD_SIZE"]
-    os.environ["LOCAL_WORLD_SIZE"] = env["LOCAL_WORLD_SIZE"]
-    os.environ["MASTER_ADDR"] = env["MASTER_ADDR"]
-    os.environ["MASTER_PORT"] = env["MASTER_PORT"]
-    os.environ["BAGUA_SERVICE_PORT"] = env["BAGUA_SERVICE_PORT"]
-    os.environ["RANK"] = str(rank)
-    os.environ["LOCAL_RANK"] = str(rank)
-
-    # init bagua distributed process group
-    torch.cuda.set_device(rank)
-    bagua.init_process_group()
-
-    # construct model and optimizer, etc.
+def create_model_and_optimizer(warmup_steps):
+    # construct model and optimizer
     model = Net().cuda()
     optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
-    loss_fn = nn.MSELoss()

     # wrap model
     algorithm = bagua.algorithms.async_model_average.AsyncModelAverageAlgorithm(
@@ -51,87 +39,62 @@ def run_model_wrapper(rank, env, fn, warmup_steps):
     )
     model = model.with_bagua([optimizer], algorithm)

-    fn(model, optimizer, loss_fn)
+    return model, optimizer


-def train_epoch(epoch, model, optimizer, loss_fn):
-    logging.debug("Training epoch {}".format(epoch))
+def train_epoch(epoch, model, optimizer):
+    logger.debug("Training epoch {}".format(epoch))
     for _ in range(10):
         data = torch.randn(4, 2).cuda()
         target = torch.randn(4, 4).cuda()

         optimizer.zero_grad()
         output = model(data)
-        loss = loss_fn(output, target)
+        loss = nn.MSELoss()(output, target)

         loss.backward()
         optimizer.step()


-def run_epochs(model, optimizer, loss_fn):
-    for epoch in range(100):
-        train_epoch(epoch, model, optimizer, loss_fn)
-    model.bagua_algorithm.abort(model)
+class TestAsyncModelAverage(MultiProcessTestCase):
+    def setUp(self):
+        super(TestAsyncModelAverage, self).setUp()
+        self._spawn_processes()

+    def tearDown(self):
+        super(TestAsyncModelAverage, self).tearDown()
+        try:
+            os.remove(self.file_name)
+        except OSError:
+            pass

-def run_multiple_aborts(model, optimizer, loss_fn):
-    for epoch in range(2):
-        model.bagua_algorithm.resume(model)
-        model.bagua_algorithm.abort(model)
-        model.bagua_algorithm.resume(model)
-        for _ in range(100):
-            train_epoch(epoch, model, optimizer, loss_fn)
+    @property
+    def world_size(self) -> int:
+        return 4

-        model.bagua_algorithm.abort(model)
-        model.bagua_algorithm.abort(model)
+    @skip_if_lt_x_gpu(4)
+    def test_algorithm(self):
+        self._init_bagua_distributed()
+        model, optimizer = create_model_and_optimizer(warmup_steps=0)
+        for epoch in range(100):
+            train_epoch(epoch, model, optimizer)
+        model.bagua_algorithm.abort(model)

-class TestAsyncModelAverage(unittest.TestCase):
-    @skip_if_cuda_not_available()
-    def test_algorithm(self):
-        nprocs = torch.cuda.device_count()
-        env = {
-            "WORLD_SIZE": str(nprocs),
-            "LOCAL_WORLD_SIZE": str(nprocs),
-            "MASTER_ADDR": "127.0.0.1",
-            "MASTER_PORT": str(find_free_port(8000, 8100)),
-            "BAGUA_SERVICE_PORT": str(find_free_port(9000, 9100)),
-        }
-
-        mp = multiprocessing.get_context("spawn")
-        processes = []
-        for i in range(nprocs):
-            p = mp.Process(target=run_model_wrapper, args=(i, env, run_epochs, 0))
-            p.start()
-            processes.append(p)
-
-        for p in processes:
-            p.join(timeout=60)
-            self.assertTrue(p.exitcode == 0)
-
-    @skip_if_cuda_not_available()
+    @skip_if_lt_x_gpu(4)
     def test_multiple_aborts(self):
-        nprocs = torch.cuda.device_count()
-        env = {
-            "WORLD_SIZE": str(nprocs),
-            "LOCAL_WORLD_SIZE": str(nprocs),
-            "MASTER_ADDR": "127.0.0.1",
-            "MASTER_PORT": str(find_free_port(8000, 8100)),
-            "BAGUA_SERVICE_PORT": str(find_free_port(9000, 9100)),
-        }
-
-        mp = multiprocessing.get_context("spawn")
-        processes = []
-        for i in range(nprocs):
-            p = mp.Process(
-                target=run_model_wrapper, args=(i, env, run_multiple_aborts, 10)
-            )
-            p.start()
-            processes.append(p)
-
-        for p in processes:
-            p.join(timeout=60)
-            self.assertTrue(p.exitcode == 0)
+        self._init_bagua_distributed()
+        model, optimizer = create_model_and_optimizer(warmup_steps=10)
+
+        for i in range(2):
+            model.bagua_algorithm.resume(model)
+            model.bagua_algorithm.abort(model)
+            model.bagua_algorithm.resume(model)
+            for epoch in range(100):
+                train_epoch(i * 100 + epoch, model, optimizer)
+
+            model.bagua_algorithm.abort(model)
+            model.bagua_algorithm.abort(model)


 if __name__ == "__main__":
diff --git a/tests/torch_api/test_gradient_allreduce.py b/tests/torch_api/test_gradient_allreduce.py
index 452755a73..138ff3fb8 100644
--- a/tests/torch_api/test_gradient_allreduce.py
+++ b/tests/torch_api/test_gradient_allreduce.py
@@ -1,13 +1,16 @@
+import logging
+import os
+import pickle
+import unittest
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from tests.internal.common_utils import find_free_port
-from tests.internal.multi_process import setup_bagua_env
-import unittest
-import multiprocessing
-from bagua.torch_api.utils import flatten
 import bagua.torch_api as bagua
-from tests import skip_if_cuda_not_available
+
+from tests.internal.multi_process_v2 import MultiProcessTestCase, skip_if_lt_x_gpu
+
+logger = logging.getLogger(__name__)


 class Net(nn.Module):
@@ -25,31 +28,14 @@ def forward(self, x):
         return F.softmax(x, dim=1)


-def _init_bagua_env(rank, env):
-    # set deterministic
-    torch.backends.cudnn.benchmark = False
-    torch.backends.cudnn.deterministic = True
-    torch.manual_seed(rank)
-    # initialize subprocess env
-    setup_bagua_env(rank, env)
-
-
-def run_model(
-    rank,
-    nprocs,
-    hierarchical,
-    results,
-    env,
-):
-    _init_bagua_env(rank, env)
-
+def run_model(hierarchical):
     # construct model and optimizer, etc.
     model = Net().cuda()
     optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
     loss_fn = nn.MSELoss()

     def run_epochs(num_epochs):
-        for epoch in range(num_epochs):
+        for _ in range(num_epochs):
             data = torch.randn(4, 2).cuda()
             target = torch.randn(4, 4).cuda()

@@ -72,71 +58,58 @@ def run_epochs(num_epochs):

     run_epochs(10)

-    ret = results[rank]
+    flattened_weight = bagua.utils.flatten([param.data for param in model.parameters()])
+    weight_norm = flattened_weight.norm().item()
+    return weight_norm

-    ret._weight.copy_(flatten([param.data for param in model.parameters()]))

+class TestGradientAllReduce(MultiProcessTestCase):
+    def setUp(self):
+        super(TestGradientAllReduce, self).setUp()
+        self._spawn_processes()

-class Result(object):
-    def __init__(self):
-        model = Net()
-        self._weight = flatten(
-            [torch.zeros_like(param.data) for param in model.parameters()]
-        )
-
-
-class TestGradientAllReduce(unittest.TestCase):
-    def run_test_locally(
-        self,
-        nprocs,
-        hierarchical,
-    ):
-        env = {
-            "WORLD_SIZE": str(nprocs),
-            "LOCAL_WORLD_SIZE": str(nprocs),
-            "MASTER_ADDR": "127.0.0.1",
-            "MASTER_PORT": str(find_free_port(8000, 8100)),
-            "BAGUA_SERVICE_PORT": str(find_free_port(9000, 9100)),
-        }
-
-        mp = multiprocessing.get_context("spawn")
-        results = [Result() for _ in range(nprocs)]
-        processes = []
-        for i in range(nprocs):
-            p = mp.Process(
-                target=run_model,
-                args=(
-                    i,
-                    nprocs,
-                    hierarchical,
-                    results,
-                    env,
-                ),
-            )
-            p.start()
-            processes.append(p)
-
-        for p in processes:
-            p.join(timeout=60)
-            self.assertTrue(p.exitcode == 0)
-
-        for rank in range(nprocs):
-            peer_rank = (rank + 1) % nprocs
-            # all workers have equal weights
-            self.assertTrue(
-                torch.equal(
-                    results[rank]._weight,
-                    results[peer_rank]._weight,
-                )
-            )
-
-    @skip_if_cuda_not_available()
+    def tearDown(self):
+        super(TestGradientAllReduce, self).tearDown()
+        try:
+            os.remove(self.file_name)
+        except OSError:
+            pass
+
+    def _check_result(self, test_id=None):
+        result = None
+        for i, process in enumerate(self.processes):
+            _, msg = self.pid_to_pipe[process.pid].recv()
+            weight_norm = pickle.loads(msg)
+
+            logger.info("process {} result: {}".format(i, weight_norm))
+            if result is None:
+                result = weight_norm
+            else:
+                assert result == weight_norm
+
+    @property
+    def world_size(self) -> int:
+        return 4
+
+    @skip_if_lt_x_gpu(4)
     def test_algorithm(self):
-        nprocs = torch.cuda.device_count()
-        self.run_test_locally(
-            nprocs=nprocs,
-            hierarchical=False,
-        )
+        # set deterministic
+        torch.backends.cudnn.benchmark = False
+        torch.backends.cudnn.deterministic = True
+        torch.manual_seed(self.rank)
+
+        self._init_bagua_distributed()
+        return run_model(hierarchical=False)
+
+    @skip_if_lt_x_gpu(4)
+    def test_algorithm_hierarchical(self):
+        # set deterministic
+        torch.backends.cudnn.benchmark = False
+        torch.backends.cudnn.deterministic = True
+        torch.manual_seed(self.rank)
+
+        self._init_bagua_distributed()
+        return run_model(hierarchical=True)


 if __name__ == "__main__":