[feat] multiprocess_pipe: add checkpoint support (#555)
msbaines authored Mar 29, 2021
1 parent 9a95065 commit 5e6a7a5
Showing 3 changed files with 24 additions and 7 deletions.
24 changes: 19 additions & 5 deletions fairscale/experimental/nn/multiprocess_pipe.py
@@ -11,6 +11,7 @@
from torch import Tensor
import torch.distributed.rpc as rpc
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

Tensors = Tuple[Tensor, ...]
TensorOrTensors = Union[Tensor, Tensors]
@@ -69,6 +70,12 @@ def _rcat(tensors: List) -> Tensor:
return torch.cat([t.local_value() for t in tensors])


def _rcheckpoint(rmodule: rpc.RRef, input_rref: rpc.RRef) -> TensorOrTensors:
module = rmodule.local_value()
input = module[0](input_rref) # calls _ToHere.forward
return checkpoint_sequential(module[1:], 1, input)


def _parameter_rrefs(module: rpc.RRef) -> List[rpc.RRef]:
return [rpc.RRef(p) for p in module.local_value().parameters()]

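_rcheckpoint runs on the worker that owns a partition: module[0] (the _ToHere step noted in the comment above) materializes the incoming RRef locally, and the remaining layers are executed through torch.utils.checkpoint.checkpoint_sequential, which avoids storing intermediate activations during forward and recomputes them during backward. A minimal local sketch of that behaviour, with no RPC and a plain nn.Sequential standing in for a partition (layer sizes are illustrative only):

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

# Stand-in for one pipeline partition; the sizes are arbitrary.
partition = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 4))
x = torch.randn(8, 4, requires_grad=True)

# segments=1 mirrors the call in _rcheckpoint: the whole partition is a single
# checkpointed segment, so its intermediate activations are recomputed when
# backward reaches it instead of being kept alive after forward.
y = checkpoint_sequential(partition, 1, x)
y.sum().backward()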
@@ -159,8 +166,8 @@ def __init__(

if type(chunks) is not int or chunks <= 0:
raise ValueError("number of chunks must be positive integer")
if checkpoint not in ["never"]:
raise ValueError("checkpoint is not yet implemented")
if checkpoint not in ["always", "except_last", "never"]:
raise ValueError("checkpoint is not one of 'always', 'except_last', or 'never'")
if deferred_batch_norm:
raise ValueError("deferred_batch_norm is not yet implemented")
if len(balance) != len(devices):
@@ -181,6 +188,9 @@ def __init__(
workers.append(worker)
rmodule.append(rlayer)

# The micro-batch index where the checkpointing stops.
self.checkpoint_stop = {"always": chunks, "except_last": chunks - 1, "never": 0}[checkpoint]

self.chunks = chunks
self.checkpoint = checkpoint
self.module = module
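The new checkpoint_stop value turns the textual mode into a per-micro-batch cutoff: a chunk with index i is checkpointed only while i < checkpoint_stop, which the forward loop below uses to choose between the checkpointed and the plain path. A small illustration for chunks=4 (the chunk count is only an example):

# Which of the 4 micro-batches are checkpointed under each mode.
chunks = 4
for mode in ("always", "except_last", "never"):
    stop = {"always": chunks, "except_last": chunks - 1, "never": 0}[mode]
    print(mode, [i < stop for i in range(chunks)])
# always [True, True, True, True]
# except_last [True, True, True, False]
# never [False, False, False, False]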
@@ -189,10 +199,14 @@

def forward(self, x: Tensor) -> rpc.RRef: # type: ignore
outputs = []
for chunk in x.chunk(self.chunks):
for i, chunk in enumerate(x.chunk(self.chunks)):
output = rpc.RRef(chunk)
for rlayer in self.rmodule:
output = rlayer.remote().forward(output)
if i < self.checkpoint_stop:
for rlayer in self.rmodule:
output = rpc.remote(rlayer.owner(), _rcheckpoint, args=(rlayer, output))
else:
for rlayer in self.rmodule:
output = rlayer.remote().forward(output)
outputs.append(output)
return rpc.remote(outputs[0].owner(), _rcat, args=(outputs,))

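With this change, forward splits the input into self.chunks micro-batches and, for every chunk whose index is below checkpoint_stop, dispatches _rcheckpoint to the worker owning each partition instead of calling the partition's forward directly; the per-chunk outputs are then re-concatenated by _rcat on the owner of the first output. A usage sketch based on the updated test (the module spec and sizes come from the test; the worker/device strings are assumptions, and an RPC group with matching workers must already be initialized, so this is not a standalone script):

import torch
import torch.nn as nn
from fairscale.experimental.nn.multiprocess_pipe import MultiProcessPipe

model = [("linear1", nn.Linear, (4, 4), {}), ("relu", nn.ReLU, (), {})]
pipe = MultiProcessPipe(
    model,
    balance=[1, 1],
    chunks=4,
    devices=["worker0/cpu", "worker1/cpu"],  # assumed "<worker>/<device>" format
    checkpoint="except_last",                # the new argument
)

x = torch.randn(8, 4)
x.requires_grad = True  # the test notes this is currently required
y = pipe(x).to_here()   # forward returns an rpc.RRef; fetch the result locally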
1 change: 1 addition & 0 deletions stubs/torch/utils/checkpoint.pyi
@@ -7,3 +7,4 @@ from torch.nn.modules.module import Module
def detach_variable(inputs: Tuple[Tensor,...]) -> Tuple[Tensor,...]: ...
def checkpoint(function: Module, *args, **kwargs): ...
def check_backward_validity(inputs: Iterable[Any]): ...
def checkpoint_sequential(function: Module, segments: int, *args, **kwargs): ...
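The stub presumably exists so that mypy can type-check the new checkpoint_sequential call site in multiprocess_pipe.py; it mirrors the runtime signature used there (a sequential module, a segment count, and the input tensors).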
6 changes: 4 additions & 2 deletions tests/experimental/nn/test_multiprocess_pipe.py
@@ -130,13 +130,15 @@ def forward_chunks(devices):

@rpc_test(world_size=2)
@pytest.mark.parametrize("devices", DEVICES)
def forward_multi(devices):
@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"])
def forward_multi(devices, checkpoint):
device = devices[0].split("/")[1]
torch.random.manual_seed(3)
torch.cuda.manual_seed_all(3)
x = torch.randn(8, 4).to(device)
x.requires_grad = True # TODO(msb) remove this limitation
model = [("linear1", nn.Linear, (4, 4), {}), ("relu", nn.ReLU, (), {})]
pipe = MultiProcessPipe(model, balance=[1, 1], chunks=4, devices=devices[:2])
pipe = MultiProcessPipe(model, balance=[1, 1], chunks=4, devices=devices[:2], checkpoint=checkpoint)
if BOUNCE_TENSORS:
y = pipe(x).remote().cpu().to_here()
else:
