feat: add async model average algorithm #110

Merged

merged 34 commits on Aug 23, 2021

Changes from 11 commits

Commits (34):
27411a9  chore(release): v0.6.2 (NOBLES5E, Jul 2, 2021)
1fa5f1f  Merge branch 'master' of github.com:BaguaSys/bagua (NOBLES5E, Jul 2, 2021)
a7e334a  Merge branch 'master' of github.com:BaguaSys/bagua (NOBLES5E, Jul 7, 2021)
69e8bf2  update (NOBLES5E, Jul 7, 2021)
de05ee6  Merge branch 'master' into feat/async-model-average (wangraying, Jul 24, 2021)
3e65319  update (wangraying, Jul 25, 2021)
ee40323  add async (wangraying, Aug 6, 2021)
adcb1e3  add comment (wangraying, Aug 6, 2021)
cb6ad87  Update bagua/torch_api/bucket.py (wangraying, Aug 6, 2021)
5b5f793  update (wangraying, Aug 6, 2021)
7196a3d  update (wangraying, Aug 6, 2021)
891fa9a  Merge branch 'master' into feat/async-model-average (wangraying, Aug 6, 2021)
7625f6f  Apply suggestions from code review (wangraying, Aug 6, 2021)
b4285c8  Apply suggestions from code review (wangraying, Aug 6, 2021)
6c3274b  . (wangraying, Aug 6, 2021)
9460589  fix (wangraying, Aug 9, 2021)
e97e17d  . (wangraying, Aug 10, 2021)
42a4143  Apply suggestions from code review (wangraying, Aug 12, 2021)
edc56ec  little change (wangraying, Aug 13, 2021)
1acd2d5  Merge branch 'feat/async-model-average' of https://github.com/BaguaSy… (wangraying, Aug 13, 2021)
05ab250  .. (wangraying, Aug 13, 2021)
e3379ee  Update async_model_average.py (NOBLES5E, Aug 19, 2021)
c3676f7  Merge branch 'master' into feat/async-model-average (wangraying, Aug 19, 2021)
1c40edc  fix pytype (wangraying, Aug 19, 2021)
f550bcd  add (wangraying, Aug 20, 2021)
8dc05a2  fmt (wangraying, Aug 20, 2021)
28c85c3  fmt (wangraying, Aug 20, 2021)
af18148  . (wangraying, Aug 20, 2021)
1c33cf9  Update async_model_average.py (NOBLES5E, Aug 23, 2021)
7246854  Update async_model_average.py (NOBLES5E, Aug 23, 2021)
7dd0526  Update bucket.py (NOBLES5E, Aug 23, 2021)
27fa44b  Update async_model_average.py (NOBLES5E, Aug 23, 2021)
e8446d8  Apply suggestions from code review (wangraying, Aug 23, 2021)
cf7b15f  uu (wangraying, Aug 23, 2021)
106 changes: 106 additions & 0 deletions bagua/torch_api/algorithms/async_model_average.py
@@ -0,0 +1,106 @@
#!/usr/bin/env python3
import torch
from bagua.torch_api.bucket import BaguaBucket
from bagua.torch_api.distributed import BaguaModule
from bagua.torch_api.algorithms import Algorithm
from bagua.torch_api.env import get_world_size
from typing import List
from bagua.torch_api.tensor import BaguaTensor
import threading
import time


class AsyncModelAverageAlgorithm(Algorithm):
    def __init__(
        self, peer_selection_mode: str = "all", sync_interval_ms: int = 500,
    ):
        """
        Create an instance of the
        `AsyncModelAverage <https://bagua-tutorials.kwai-seattle.com/algorithms/async-model-average.html>`_
        algorithm.

        The current implementation is experimental and has some restrictions on the training scenario.
        Since with an asynchronous algorithm each worker can be at a different iteration, the current implementation
        assumes the data come in an endless stream and there is no concept of an "epoch".

        Call :func:`barrier` to stop async training; it cancels all unfinished communication operations.

        Args:
            peer_selection_mode (str): The way a worker communicates with its peers. Currently only "all" is supported,
                which means all workers' weights are averaged in each communication step.
            sync_interval_ms (int): Number of milliseconds between two model synchronizations.
        """

        self.peer_selection_mode = peer_selection_mode
        self.sync_interval_ms = sync_interval_ms
        self.stop_event = threading.Event()

    def init_tensors(self, bagua_module: BaguaModule) -> List[BaguaTensor]:
        parameters = bagua_module.bagua_build_params()
        self.tensors = [
            param.ensure_bagua_tensor(name, bagua_module.bagua_module_name)
            for name, param in parameters.__reversed__()
        ]
        return self.tensors

    def init_forward_pre_hook(self, bagua_module: BaguaModule):
        def hook(input):
            if not hasattr(self, "worker"):
                self.worker = threading.Thread(
                    target=self.run_async_loop, args=[bagua_module]
                )
                self.worker.start()

        return hook

    def init_backward_hook(self, bagua_module: BaguaModule):
        def hook(parameter_name, parameter):
            pass

        return hook

    def init_post_backward_hook(self, bagua_module: BaguaModule):
        def hook():
            pass

        return hook

    def init_operations(
        self, bagua_module: BaguaModule, bucket: BaguaBucket,
    ):
        bucket.clear_ops()
        bucket.append_asynchronous_model_average_op(
            peer_selection_mode=self.peer_selection_mode,
        )

    def barrier(self, bagua_module: BaguaModule, stop_grace_period_secs=5):
        """
        Gracefully stop all workers.

        Args:
            bagua_module: A PyTorch module initialized by the ``with_bagua(...)`` method.
            stop_grace_period_secs: Number of seconds a worker waits before aborting its unfinished communication operations.
        """
        self.stop_event.set()
        time.sleep(stop_grace_period_secs)
        bagua_module._bagua_backend.global_communicator.abort()

        if hasattr(self, "worker"):
            self.worker.join()
        else:
            raise RuntimeError(
                "Cannot barrier the model since the background communication thread has not started"
            )

    def run_async_loop(self, bagua_module: BaguaModule):
        step = 0
        while not self.stop_event.is_set():
            if bagua_module.training:
                for bucket in bagua_module.bagua_buckets:
                    for tensor in bucket.tensors:
                        tensor.bagua_mark_communication_ready()

                bagua_module._bagua_backend.wait_pending_comm_ops()

            time.sleep(self.sync_interval_ms / 1000)
            step += 1
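
For context, here is a minimal usage sketch of how the algorithm added in this file is expected to be wired into a training script through Bagua's with_bagua() API. The toy model, synthetic data, and hyperparameters below are illustrative assumptions rather than part of this pull request, and the script assumes it is started with Bagua's distributed launcher (one process per GPU).

# Usage sketch (assumed setup: one process per GPU, started by Bagua's launcher).
import torch
import bagua.torch_api as bagua
from bagua.torch_api.algorithms.async_model_average import AsyncModelAverageAlgorithm

bagua.init_process_group()
torch.cuda.set_device(bagua.get_local_rank())

# Toy model and optimizer; any torch.nn.Module is handled the same way.
model = torch.nn.Linear(16, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

algorithm = AsyncModelAverageAlgorithm(peer_selection_mode="all", sync_interval_ms=500)
model = model.with_bagua([optimizer], algorithm)

for _ in range(1000):  # data treated as an endless stream, no epoch boundary
    x = torch.randn(32, 16).cuda()
    loss = model(x).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# Stop the background synchronization thread and abort unfinished communication.
algorithm.barrier(model)

In this sketch the first forward pass starts the background thread registered in init_forward_pre_hook, and barrier() sets the stop event, aborts pending operations after the grace period, and joins the thread.
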
21 changes: 21 additions & 0 deletions bagua/torch_api/bucket.py
@@ -330,6 +330,27 @@ def append_low_precision_decentralized_synchronous_op(

        return self

    def append_asynchronous_model_average_op(self, peer_selection_mode: str):
        """
        Append an asynchronous model average operation to a bucket. This operation enables model averaging
        to run concurrently with gradient computation and gradient updates.

        The operations will be executed by the Bagua backend in the order they are appended,
        once all the tensors within the bucket are marked ready.

        Args:
            peer_selection_mode (str): The way a worker communicates with its peers. Currently only "all" is supported,
                which means all workers' weights are averaged in each communication step.
        """

        self.backend_bucket.append_decentralized_asynchronous_op(
            self._bagua_backend.global_communicator,
            None,
            peer_selection_mode=peer_selection_mode,
            torch_stream=torch.cuda.current_stream().cuda_stream,
        )
        return self

    def clear_ops(self) -> BaguaBucket:
        """
        Clear the previously appended operations.
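
As a usage note, the new bucket method is meant to be called from an algorithm's init_operations hook, exactly as async_model_average.py above does. Since both clear_ops() and append_asynchronous_model_average_op() return the bucket itself, the two calls can also be chained; a hedged sketch of such a hook follows (the enclosing algorithm class is assumed, not shown).

# Sketch: registering the async model average op on a bucket from an
# algorithm's init_operations hook; both calls return the bucket, so they chain.
def init_operations(self, bagua_module: BaguaModule, bucket: BaguaBucket):
    bucket.clear_ops().append_asynchronous_model_average_op(
        peer_selection_mode=self.peer_selection_mode,
    )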