Skip to content

Commit

Permalink
feat: add async model average algorithm (#110)
Browse files Browse the repository at this point in the history
  • Loading branch information
wangraying authored and NOBLES5E committed Aug 23, 2021
1 parent 2658a51 commit 96c7bd5
Show file tree
Hide file tree
Showing 4 changed files with 218 additions and 0 deletions.
1 change: 1 addition & 0 deletions bagua/torch_api/algorithms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@

from .base import Algorithm # noqa: F401
from . import bytegrad, decentralized, gradient_allreduce # noqa: F401
from . import q_adam, async_model_average # noqa: F401
117 changes: 117 additions & 0 deletions bagua/torch_api/algorithms/async_model_average.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
#!/usr/bin/env python3
from bagua.torch_api.bucket import BaguaBucket
from bagua.torch_api.distributed import BaguaModule
from bagua.torch_api.algorithms import Algorithm
from typing import List
from bagua.torch_api.tensor import BaguaTensor
import threading
import time
import logging
import os


def check_nccl_proto():
    """Warn when the NCCL ``LL128`` protocol may be in use.

    Reads the ``NCCL_PROTO`` environment variable and logs a warning when
    ``LL128`` is not explicitly excluded, i.e. when the variable is

    * unset or empty,
    * an include list (no ``^``) that mentions ``LL128``, or
    * an exclude list (contains ``^``) that does not mention ``LL128``.

    Returns:
        None. The only effect is the (optional) log record.
    """
    # TODO: remove this check after https://github.com/NVIDIA/nccl/issues/549 gets solved
    proto_str = os.environ.get("NCCL_PROTO", "")
    is_exclude_list = "^" in proto_str
    mentions_ll128 = "LL128" in proto_str

    # For a non-empty value, LL128 stays enabled exactly when "is an exclude
    # list" and "mentions LL128" disagree (included explicitly, or excluded
    # list that leaves it out).
    if proto_str == "" or is_exclude_list != mentions_ll128:
        # `logging.warn` is a deprecated alias; `logging.warning` is the
        # supported spelling.
        logging.warning(
            "`LL128` proto for NCCL backend is not stable for async algorithms. Set `NCCL_PROTO=^LL128` to exclude it."
        )


class AsyncModelAverageAlgorithm(Algorithm):
    """Asynchronous model averaging driven by a background thread.

    The first forward pass starts a worker thread that repeatedly marks all
    bucket tensors as ready for communication and waits for the pending
    communication operations, pausing ``sync_interval_ms`` milliseconds
    between rounds. :meth:`abort` stops that thread.
    """

    def __init__(
        self, peer_selection_mode: str = "all", sync_interval_ms: int = 500,
    ):
        """
        Create an instance of the
        `AsyncModelAverage <https://bagua-tutorials.kwai-seattle.com/algorithms/async-model-average.html>`_
        algorithm.

        The asynchronous implementation is experimental, and imposes some restrictions.
        With such asynchronous algorithm, the number of iterations on each worker are different. Therefore
        the current implementation assumes that the dataset is an endless stream, and all workers continuously
        synchronize between each other.

        Users should call :func:`abort` to manually stop the algorithm's continuous synchronization process.

        Args:
            peer_selection_mode (str): The way how workers communicate with each other. Currently "all" is supported.
                "all" means all workers' weights are synchronized during each communication.
            sync_interval_ms (int): Number of milliseconds between model synchronizations.
        """

        self.peer_selection_mode = peer_selection_mode
        self.sync_interval_ms = sync_interval_ms
        # Set by `abort` to ask the background loop to exit.
        self.stop_event = threading.Event()
        check_nccl_proto()

    def init_tensors(self, bagua_module: BaguaModule) -> List[BaguaTensor]:
        """Register the module's parameters as Bagua tensors.

        The parameters returned by ``bagua_build_params()`` are visited in
        reverse order; the resulting list is stored on ``self.tensors`` and
        returned.
        """
        parameters = bagua_module.bagua_build_params()
        self.tensors = [
            param.ensure_bagua_tensor(name, bagua_module.bagua_module_name)
            for name, param in reversed(parameters)
        ]
        return self.tensors

    def init_forward_pre_hook(self, bagua_module: BaguaModule):
        """Return a forward-pre hook that lazily starts the worker thread."""

        def hook(input):
            # The thread is created on the first call only; later calls see
            # the `worker` attribute and do nothing.
            if not hasattr(self, "worker"):
                self.worker = threading.Thread(
                    target=self.run_async_loop, args=[bagua_module]
                )
                self.worker.start()

        return hook

    def init_backward_hook(self, bagua_module: BaguaModule):
        """Return a no-op per-parameter backward hook.

        Communication is driven by the background thread, not by gradients.
        """

        def hook(parameter_name, parameter):
            pass

        return hook

    def init_post_backward_hook(self, bagua_module: BaguaModule):
        """Return a no-op post-backward hook."""

        def hook():
            pass

        return hook

    def init_operations(
        self, bagua_module: BaguaModule, bucket: BaguaBucket,
    ):
        """Replace the bucket's operations with one async model average op."""
        bucket.clear_ops()
        bucket.append_asynchronous_model_average_op(
            peer_selection_mode=self.peer_selection_mode,
        )

    def abort(self, bagua_module: BaguaModule, grace_period_seconds=5):
        """
        Gracefully stop all workers.

        Args:
            bagua_module: A PyTorch module initialized by ``with_bagua(...)`` method.
            grace_period_seconds: Number of seconds a worker will wait before aborting its unfinished communication operations.

        Raises:
            AssertionError: If the asynchronous communication thread has not
                been started (no forward pass has run yet) or is no longer alive.
        """
        # `worker` only exists after the first forward pass, so look it up
        # defensively instead of failing with an AttributeError.
        worker = getattr(self, "worker", None)
        if worker is None or not worker.is_alive():
            # Raised explicitly (same exception type as the former `assert`)
            # so the check is not stripped under `python -O`.
            raise AssertionError(
                "cannot abort since the asynchronous communication thread is not started"
            )

        self.stop_event.set()
        time.sleep(grace_period_seconds)
        bagua_module._bagua_backend.global_communicator.abort()

        worker.join()

    def run_async_loop(self, bagua_module: BaguaModule):
        """Body of the worker thread: synchronize until ``stop_event`` is set.

        While the module is in training mode, each round marks every tensor
        of every bucket as communication-ready and then waits for the pending
        communication operations. Rounds are separated by
        ``sync_interval_ms`` milliseconds.
        """
        while not self.stop_event.is_set():
            if bagua_module.training:
                for bucket in bagua_module.bagua_buckets:
                    for tensor in bucket.tensors:
                        tensor.bagua_mark_communication_ready_without_synchronization()

                bagua_module._bagua_backend.wait_pending_comm_ops()

            # Wait on the event rather than sleeping so that `abort` wakes
            # the loop immediately instead of after a full interval.
            self.stop_event.wait(self.sync_interval_ms / 1000)
20 changes: 20 additions & 0 deletions bagua/torch_api/bucket.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,26 @@ def append_low_precision_decentralized_synchronous_op(

return self

def append_asynchronous_model_average_op(self, peer_selection_mode: str):
    """
    Append an asynchronous model average operation to this bucket, enabling
    continuous model averaging between workers while a model is training.

    The operations will be executed by the Bagua backend in the order they are appended
    when all the tensors within the bucket are marked ready.

    Args:
        peer_selection_mode (str): The way how workers communicate with each other. Currently "all" is supported.
            "all" means all workers' weights are averaged during each communication.

    Returns:
        The bucket itself, so calls can be chained.
    """
    append_op = self.backend_bucket.append_decentralized_asynchronous_op
    communicator = self._bagua_backend.global_communicator
    # The op is bound to whichever CUDA stream is current at append time.
    cuda_stream = torch.cuda.current_stream().cuda_stream

    append_op(
        communicator,
        None,
        peer_selection_mode=peer_selection_mode,
        torch_stream=cuda_stream,
    )
    return self

def clear_ops(self) -> BaguaBucket:
"""
Clear the previously appended operations.
Expand Down
80 changes: 80 additions & 0 deletions tests/torch_api/test_async_model_average.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from tests.internal.common_utils import find_free_port
import unittest
import torch.multiprocessing as mp
import os
import bagua.torch_api as bagua


class Net(nn.Module):
    """Small three-layer fully connected network (2 -> 10 -> 50 -> 4).

    Only the middle layer has a bias. The output is a softmax over the four
    output features of each sample.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(2, 10, bias=False)
        self.fc2 = nn.Linear(10, 50, bias=True)
        self.fc3 = nn.Linear(50, 4, bias=False)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Apply fc1 and fc2 with ReLU, then fc3 and a softmax over dim 1."""
        hidden = x
        for layer in (self.fc1, self.fc2):
            hidden = self.relu(layer(hidden))
        logits = self.fc3(hidden)
        return F.softmax(logits, dim=1)


def run_model(rank):
    """Per-process entry point: train ``Net`` for 10 steps with async averaging.

    Sets the ``RANK`` / ``LOCAL_RANK`` environment variables for this
    subprocess, selects CUDA device ``rank``, initializes the Bagua process
    group, runs ten SGD steps on random data and finally aborts the
    algorithm's background synchronization.

    Args:
        rank: Index of this process; also used as the CUDA device index.
    """
    # Per-process environment expected by the Bagua process group setup.
    for env_key in ("RANK", "LOCAL_RANK"):
        os.environ[env_key] = str(rank)

    torch.cuda.set_device(rank)
    bagua.init_process_group()

    # Model, optimizer and loss.
    net = Net().cuda()
    sgd = torch.optim.SGD(net.parameters(), lr=0.01)
    criterion = nn.MSELoss()

    # Wrap the model with the asynchronous model average algorithm.
    async_algorithm = bagua.algorithms.async_model_average.AsyncModelAverageAlgorithm()
    net = net.with_bagua([sgd], async_algorithm)

    for _step in range(10):
        batch = torch.randn(4, 2).cuda()
        labels = torch.randn(4, 4).cuda()

        sgd.zero_grad()
        prediction = net(batch)
        criterion(prediction, labels).backward()
        sgd.step()

    # Stop the background synchronization thread before the process exits.
    async_algorithm.abort(net)


class TestAsyncModelAverage(unittest.TestCase):
    """Smoke test for the async model average algorithm across all local GPUs."""

    def test_algorithm(self):
        """Spawn one ``run_model`` process per visible CUDA device."""
        if not torch.cuda.is_available():
            # Report the test as skipped instead of printing and silently
            # passing, so test reports reflect that nothing was exercised.
            self.skipTest("skip tests since cuda is not available")

        nprocs = torch.cuda.device_count()
        # Environment shared by every spawned worker process.
        os.environ["WORLD_SIZE"] = str(nprocs)
        os.environ["LOCAL_WORLD_SIZE"] = str(nprocs)
        os.environ["MASTER_ADDR"] = "127.0.0.1"
        os.environ["MASTER_PORT"] = str(find_free_port())
        os.environ["BAGUA_SERVICE_PORT"] = str(find_free_port())

        mp.spawn(
            run_model,
            nprocs=nprocs,
        )


# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()

0 comments on commit 96c7bd5

Please sign in to comment.