feat: add async model average algorithm (#110)
1 parent 2658a51 · commit 96c7bd5
Showing 4 changed files with 218 additions and 0 deletions.
@@ -0,0 +1,117 @@
#!/usr/bin/env python3
from bagua.torch_api.bucket import BaguaBucket
from bagua.torch_api.distributed import BaguaModule
from bagua.torch_api.algorithms import Algorithm
from typing import List
from bagua.torch_api.tensor import BaguaTensor
import threading
import time
import logging
import os


def check_nccl_proto():
    # TODO: remove nccl proto check
    proto_str = os.environ.get("NCCL_PROTO", "")
    if (
        proto_str == ""
        or ("^" not in proto_str and "LL128" in proto_str)  # noqa: W503
        or ("^" in proto_str and "LL128" not in proto_str)  # noqa: W503
    ):
        logging.warning(
            "`LL128` proto for NCCL backend is not stable for async algorithms. Set `NCCL_PROTO=^LL128` to exclude it."
        )  # TODO: remove this after https://github.com/NVIDIA/nccl/issues/549 gets solved


class AsyncModelAverageAlgorithm(Algorithm):
    def __init__(
        self, peer_selection_mode: str = "all", sync_interval_ms: int = 500,
    ):
        """
        Create an instance of the
        `AsyncModelAverage <https://bagua-tutorials.kwai-seattle.com/algorithms/async-model-average.html>`_
        algorithm.

        The asynchronous implementation is experimental and imposes some restrictions.
        With such an asynchronous algorithm, the number of iterations on each worker is
        different, so the current implementation assumes that the dataset is an endless
        stream and all workers continuously synchronize with each other.

        Users should call :func:`abort` to manually stop the algorithm's continuous
        synchronization process.

        Args:
            peer_selection_mode (str): The way workers communicate with each other.
                Currently only "all" is supported, meaning all workers' weights are
                synchronized during each communication.
            sync_interval_ms (int): Number of milliseconds between model synchronizations.
        """
        self.peer_selection_mode = peer_selection_mode
        self.sync_interval_ms = sync_interval_ms
        self.stop_event = threading.Event()
        check_nccl_proto()

    def init_tensors(self, bagua_module: BaguaModule) -> List[BaguaTensor]:
        parameters = bagua_module.bagua_build_params()
        self.tensors = [
            param.ensure_bagua_tensor(name, bagua_module.bagua_module_name)
            for name, param in reversed(parameters)
        ]
        return self.tensors

    def init_forward_pre_hook(self, bagua_module: BaguaModule):
        def hook(input):
            # Lazily start the background synchronization thread on the first
            # forward pass.
            if not hasattr(self, "worker"):
                self.worker = threading.Thread(
                    target=self.run_async_loop, args=[bagua_module]
                )
                self.worker.start()

        return hook

    def init_backward_hook(self, bagua_module: BaguaModule):
        def hook(parameter_name, parameter):
            pass

        return hook

    def init_post_backward_hook(self, bagua_module: BaguaModule):
        def hook():
            pass

        return hook

    def init_operations(
        self, bagua_module: BaguaModule, bucket: BaguaBucket,
    ):
        bucket.clear_ops()
        bucket.append_asynchronous_model_average_op(
            peer_selection_mode=self.peer_selection_mode,
        )

    def abort(self, bagua_module: BaguaModule, grace_period_seconds=5):
        """
        Gracefully stop all workers.

        Args:
            bagua_module: A PyTorch module initialized by the ``with_bagua(...)`` method.
            grace_period_seconds: Number of seconds a worker will wait before aborting
                its unfinished communication operations.
        """
        assert (
            self.worker.is_alive()  # pytype: disable=attribute-error
        ), "cannot abort since the asynchronous communication thread is not started"
        self.stop_event.set()
        time.sleep(grace_period_seconds)
        bagua_module._bagua_backend.global_communicator.abort()

        self.worker.join()  # pytype: disable=attribute-error

    def run_async_loop(self, bagua_module: BaguaModule):
        while not self.stop_event.is_set():
            if bagua_module.training:
                # Mark every tensor ready so the asynchronous model average op
                # runs over the whole model, then wait for the communication to
                # finish before sleeping until the next synchronization.
                for bucket in bagua_module.bagua_buckets:
                    for tensor in bucket.tensors:
                        tensor.bagua_mark_communication_ready_without_synchronization()

                bagua_module._bagua_backend.wait_pending_comm_ops()

            time.sleep(self.sync_interval_ms / 1000)
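
As a side note on the `check_nccl_proto` warning above: the workaround it recommends is to exclude the `LL128` protocol before any NCCL communicator is created. Below is a minimal sketch; setting the variable from Python is an assumption that only holds if it runs before `bagua.init_process_group()`, and exporting `NCCL_PROTO=^LL128` in the launch environment works just as well.

import os

# Assumption: this runs before bagua.init_process_group(), i.e. before any
# NCCL communicator exists, so NCCL still reads the variable at init time.
os.environ.setdefault("NCCL_PROTO", "^LL128")  # exclude the unstable LL128 proto

import bagua.torch_api as bagua

bagua.init_process_group()
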
@@ -0,0 +1,80 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from tests.internal.common_utils import find_free_port
import unittest
import torch.multiprocessing as mp
import os
import bagua.torch_api as bagua


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(2, 10, bias=False)
        self.fc2 = nn.Linear(10, 50, bias=True)
        self.fc3 = nn.Linear(50, 4, bias=False)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return F.softmax(x, dim=1)


def run_model(rank):
    # initialize subprocess env
    os.environ["RANK"] = str(rank)
    os.environ["LOCAL_RANK"] = str(rank)

    # init bagua distributed process group
    torch.cuda.set_device(rank)
    bagua.init_process_group()

    # construct model and optimizer, etc.
    model = Net().cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    loss_fn = nn.MSELoss()

    # wrap model
    algorithm = bagua.algorithms.async_model_average.AsyncModelAverageAlgorithm()
    model = model.with_bagua([optimizer], algorithm)

    for _ in range(10):
        data = torch.randn(4, 2).cuda()
        target = torch.randn(4, 4).cuda()

        optimizer.zero_grad()
        output = model(data)
        loss = loss_fn(output, target)

        loss.backward()
        optimizer.step()

    # stop the background synchronization thread before the process exits
    algorithm.abort(model)


class TestAsyncModelAverage(unittest.TestCase):
    def test_algorithm(self):
        if not torch.cuda.is_available():
            self.skipTest("skip tests since cuda is not available")

        nprocs = torch.cuda.device_count()
        os.environ["WORLD_SIZE"] = str(nprocs)
        os.environ["LOCAL_WORLD_SIZE"] = str(nprocs)
        os.environ["MASTER_ADDR"] = "127.0.0.1"
        os.environ["MASTER_PORT"] = str(find_free_port())
        os.environ["BAGUA_SERVICE_PORT"] = str(find_free_port())

        mp.spawn(
            run_model,
            nprocs=nprocs,
        )


if __name__ == "__main__":
    unittest.main()
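
The test above constructs the algorithm with its defaults. For completeness, a hedged sketch of the two knobs the diff exposes; the values are illustrative assumptions, not tuned recommendations:

import bagua.torch_api as bagua

# Synchronize weights once per second instead of the default 500 ms; "all" is
# the only peer_selection_mode this commit supports.
algorithm = bagua.algorithms.async_model_average.AsyncModelAverageAlgorithm(
    peer_selection_mode="all",
    sync_interval_ms=1000,
)

# After training (and after at least one forward pass, which lazily starts the
# background thread), a longer grace period gives in-flight communication more
# time before it is aborted:
# algorithm.abort(model, grace_period_seconds=10)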