Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add async model average algorithm #110

Merged
merged 34 commits into from
Aug 23, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
27411a9
chore(release): v0.6.2
NOBLES5E Jul 2, 2021
1fa5f1f
Merge branch 'master' of github.com:BaguaSys/bagua
NOBLES5E Jul 2, 2021
a7e334a
Merge branch 'master' of github.com:BaguaSys/bagua
NOBLES5E Jul 7, 2021
69e8bf2
update
NOBLES5E Jul 7, 2021
de05ee6
Merge branch 'master' into feat/async-model-average
wangraying Jul 24, 2021
3e65319
update
wangraying Jul 25, 2021
ee40323
add async
wangraying Aug 6, 2021
adcb1e3
add comment
wangraying Aug 6, 2021
cb6ad87
Update bagua/torch_api/bucket.py
wangraying Aug 6, 2021
5b5f793
update
wangraying Aug 6, 2021
7196a3d
update
wangraying Aug 6, 2021
891fa9a
Merge branch 'master' into feat/async-model-average
wangraying Aug 6, 2021
7625f6f
Apply suggestions from code review
wangraying Aug 6, 2021
b4285c8
Apply suggestions from code review
wangraying Aug 6, 2021
6c3274b
.
wangraying Aug 6, 2021
9460589
fix
wangraying Aug 9, 2021
e97e17d
.
wangraying Aug 10, 2021
42a4143
Apply suggestions from code review
wangraying Aug 12, 2021
edc56ec
little change
wangraying Aug 13, 2021
1acd2d5
Merge branch 'feat/async-model-average' of https://github.com/BaguaSy…
wangraying Aug 13, 2021
05ab250
..
wangraying Aug 13, 2021
e3379ee
Update async_model_average.py
NOBLES5E Aug 19, 2021
c3676f7
Merge branch 'master' into feat/async-model-average
wangraying Aug 19, 2021
1c40edc
fix pytype
wangraying Aug 19, 2021
f550bcd
add
wangraying Aug 20, 2021
8dc05a2
fmt
wangraying Aug 20, 2021
28c85c3
fmt
wangraying Aug 20, 2021
af18148
.
wangraying Aug 20, 2021
1c33cf9
Update async_model_average.py
NOBLES5E Aug 23, 2021
7246854
Update async_model_average.py
NOBLES5E Aug 23, 2021
7dd0526
Update bucket.py
NOBLES5E Aug 23, 2021
27fa44b
Update async_model_average.py
NOBLES5E Aug 23, 2021
e8446d8
Apply suggestions from code review
wangraying Aug 23, 2021
cf7b15f
uu
wangraying Aug 23, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions bagua/torch_api/algorithms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@

from .base import Algorithm # noqa: F401
from . import bytegrad, decentralized, gradient_allreduce # noqa: F401
from . import q_adam, async_model_average # noqa: F401
117 changes: 117 additions & 0 deletions bagua/torch_api/algorithms/async_model_average.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
#!/usr/bin/env python3
from bagua.torch_api.bucket import BaguaBucket
from bagua.torch_api.distributed import BaguaModule
from bagua.torch_api.algorithms import Algorithm
from typing import List
from bagua.torch_api.tensor import BaguaTensor
import threading
import time
import logging
import os


def check_nccl_proto():
    """Warn when the ``NCCL_PROTO`` environment variable leaves ``LL128`` enabled.

    ``LL128`` is treated as enabled when ``NCCL_PROTO`` is unset or empty,
    when it names ``LL128`` without a ``^`` in the value, or when the value
    contains ``^`` but does not name ``LL128``. In any of these cases a
    warning is logged through the root logger; nothing is raised and the
    environment is not modified.
    """
    # TODO: remove this check after https://github.com/NVIDIA/nccl/issues/549 gets solved
    proto_str = os.environ.get("NCCL_PROTO", "")
    has_caret = "^" in proto_str
    names_ll128 = "LL128" in proto_str

    ll128_enabled = (
        proto_str == ""
        or (not has_caret and names_ll128)
        or (has_caret and not names_ll128)
    )
    if ll128_enabled:
        # `logging.warn` is a deprecated alias that newer Python versions no
        # longer provide; `logging.warning` is the supported spelling.
        logging.warning(
            "`LL128` proto for NCCL backend is not stable for async algorithms. Set `NCCL_PROTO=^LL128` to exclude it."
        )


class AsyncModelAverageAlgorithm(Algorithm):
    """Algorithm that averages models asynchronously from a background thread.

    The first forward pass starts a worker thread (see
    :meth:`init_forward_pre_hook`) that runs :meth:`run_async_loop` until
    :meth:`abort` is called.
    """

    def __init__(
        self,
        peer_selection_mode: str = "all",
        sync_interval_ms: int = 500,
    ):
        """
        Create an instance of the
        `AsyncModelAverage <https://bagua-tutorials.kwai-seattle.com/algorithms/async-model-average.html>`_
        algorithm.

        The asynchronous implementation is experimental, and imposes some restrictions.
        With such an asynchronous algorithm, the number of iterations on each worker is different. Therefore
        the current implementation assumes that the dataset is an endless stream, and all workers continuously
        synchronize between each other.

        Users should call :func:`abort` to manually stop the algorithm's continuous synchronization process.

        Args:
            peer_selection_mode (str): How workers communicate with each other. Currently "all" is supported.
                "all" means all workers' weights are synchronized during each communication.
            sync_interval_ms (int): Number of milliseconds between model synchronizations.
        """

        self.peer_selection_mode = peer_selection_mode
        self.sync_interval_ms = sync_interval_ms
        # Set by `abort` to ask the background loop to exit.
        self.stop_event = threading.Event()
        check_nccl_proto()

    def init_tensors(self, bagua_module: BaguaModule) -> List[BaguaTensor]:
        """Register the module's parameters as Bagua tensors, in reversed order.

        The resulting list is also stored on ``self.tensors``.
        """
        parameters = bagua_module.bagua_build_params()
        self.tensors = [
            param.ensure_bagua_tensor(name, bagua_module.bagua_module_name)
            for name, param in reversed(parameters)
        ]
        return self.tensors

    def init_forward_pre_hook(self, bagua_module: BaguaModule):
        """Return a forward pre-hook that lazily starts the communication thread."""

        def hook(input):
            # Only the first invocation creates and starts the worker; the
            # `worker` attribute doubles as the "already started" marker.
            if not hasattr(self, "worker"):
                self.worker = threading.Thread(
                    target=self.run_async_loop, args=[bagua_module]
                )
                self.worker.start()

        return hook

    def init_backward_hook(self, bagua_module: BaguaModule):
        """Return a per-parameter backward hook that does nothing."""

        def hook(parameter_name, parameter):
            pass

        return hook

    def init_post_backward_hook(self, bagua_module: BaguaModule):
        """Return a post-backward hook that does nothing."""

        def hook():
            pass

        return hook

    def init_operations(
        self,
        bagua_module: BaguaModule,
        bucket: BaguaBucket,
    ):
        """Replace the bucket's operations with a single asynchronous model average op."""
        bucket.clear_ops()
        bucket.append_asynchronous_model_average_op(
            peer_selection_mode=self.peer_selection_mode,
        )

    def abort(self, bagua_module: BaguaModule, grace_period_seconds=5):
        """
        Gracefully stop all workers.

        Sets the stop event, sleeps for the grace period, aborts the global
        communicator and finally joins the communication thread.

        Args:
            bagua_module: A PyTorch module initialized by ``with_bagua(...)`` method.
            grace_period_seconds: Number of seconds a worker will wait before aborting its unfinished communication operations.

        Raises:
            AssertionError: If the communication thread was never started or is no longer alive.
        """
        # `worker` only exists after the first forward pass; look it up
        # defensively so the intended message is reported instead of an
        # AttributeError when `abort` is called too early.
        worker = getattr(self, "worker", None)
        assert (
            worker is not None and worker.is_alive()
        ), "cannot abort since the asynchronous communication thread is not started"
        self.stop_event.set()
        time.sleep(grace_period_seconds)
        bagua_module._bagua_backend.global_communicator.abort()

        worker.join()

    def run_async_loop(self, bagua_module: BaguaModule):
        """Body of the communication thread; loops until the stop event is set.

        While the module is in training mode, each iteration marks every
        tensor of every bucket as communication-ready and waits for the
        pending communication operations, then pauses for
        ``sync_interval_ms`` milliseconds.
        """
        while not self.stop_event.is_set():
            if bagua_module.training:
                for bucket in bagua_module.bagua_buckets:
                    for tensor in bucket.tensors:
                        tensor.bagua_mark_communication_ready_without_synchronization()

                bagua_module._bagua_backend.wait_pending_comm_ops()

            # Wait on the event rather than sleeping so that a stop request
            # ends the pause immediately instead of after the full interval.
            self.stop_event.wait(self.sync_interval_ms / 1000)
20 changes: 20 additions & 0 deletions bagua/torch_api/bucket.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,26 @@ def append_low_precision_decentralized_synchronous_op(

return self

def append_asynchronous_model_average_op(self, peer_selection_mode: str):
    """
    Append an asynchronous model average operation to a bucket. This operation will enable
    continuous model averaging between workers while training a model.

    The operations will be executed by the Bagua backend in the order they are appended
    when all the tensors within the bucket are marked ready.

    Args:
        peer_selection_mode (str): How workers communicate with each other. Currently "all" is supported.
            "all" means all workers' weights are averaged during each communication.

    Returns:
        This bucket, so that calls can be chained.
    """
    communicator = self._bagua_backend.global_communicator
    # Raw handle of the CUDA stream that is current at call time.
    stream_handle = torch.cuda.current_stream().cuda_stream

    self.backend_bucket.append_decentralized_asynchronous_op(
        communicator,
        None,
        peer_selection_mode=peer_selection_mode,
        torch_stream=stream_handle,
    )
    return self

def clear_ops(self) -> BaguaBucket:
"""
Clear the previously appended operations.
Expand Down
80 changes: 80 additions & 0 deletions tests/torch_api/test_async_model_average.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from tests.internal.common_utils import find_free_port
import unittest
import torch.multiprocessing as mp
import os
import bagua.torch_api as bagua


class Net(nn.Module):
    """Three-layer fully connected network (2 -> 10 -> 50 -> 4) with a softmax output."""

    def __init__(self):
        super().__init__()
        # Only the middle layer carries a bias term.
        self.fc1 = nn.Linear(2, 10, bias=False)
        self.fc2 = nn.Linear(10, 50, bias=True)
        self.fc3 = nn.Linear(50, 4, bias=False)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Apply ReLU after each of the first two layers, then softmax over dim 1 of the last."""
        hidden = x
        for layer in (self.fc1, self.fc2):
            hidden = self.relu(layer(hidden))
        logits = self.fc3(hidden)
        return F.softmax(logits, dim=1)


def run_model(rank):
    """Train ``Net`` for ten steps on CUDA device ``rank`` with the async model average algorithm, then abort it."""
    # Per-process environment, set before the process group is initialized.
    rank_str = str(rank)
    os.environ["RANK"] = rank_str
    os.environ["LOCAL_RANK"] = rank_str

    # Bind this process to its device and join the Bagua process group.
    torch.cuda.set_device(rank)
    bagua.init_process_group()

    # Model, optimizer and loss.
    net = Net().cuda()
    sgd = torch.optim.SGD(net.parameters(), lr=0.01)
    criterion = nn.MSELoss()

    # Wrap the model with the algorithm under test.
    algorithm = bagua.algorithms.async_model_average.AsyncModelAverageAlgorithm()
    net = net.with_bagua([sgd], algorithm)

    for _ in range(10):
        inputs = torch.randn(4, 2).cuda()
        targets = torch.randn(4, 4).cuda()

        sgd.zero_grad()
        criterion(net(inputs), targets).backward()
        sgd.step()

    # Stop the background synchronization started by the first forward pass.
    algorithm.abort(net)


class TestAsyncModelAverage(unittest.TestCase):
    """Smoke test that runs the async model average algorithm on every visible CUDA device."""

    def test_algorithm(self):
        """Spawn one ``run_model`` process per CUDA device; skipped when CUDA is unavailable."""
        if not torch.cuda.is_available():
            # Report the case as skipped rather than letting it silently
            # count as a pass on machines without a GPU.
            self.skipTest("skip tests since cuda is not available")

        nprocs = torch.cuda.device_count()

        # Environment read by the spawned processes.
        os.environ["WORLD_SIZE"] = str(nprocs)
        os.environ["LOCAL_WORLD_SIZE"] = str(nprocs)
        os.environ["MASTER_ADDR"] = "127.0.0.1"
        os.environ["MASTER_PORT"] = str(find_free_port())
        os.environ["BAGUA_SERVICE_PORT"] = str(find_free_port())

        mp.spawn(
            run_model,
            nprocs=nprocs,
        )


# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()