From e18cdc2100e10b05d27435540c84d11088bc8206 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Wed, 29 Jan 2020 13:05:21 -0800
Subject: [PATCH 1/2] add eagerly_discard_variables flag for ReversibleSequence
 to be able to save the output of the sequence and to retain the graph for the
 individual backward passes for the reversible blocks

---
 README.md            |  3 ++-
 revtorch/revtorch.py | 28 +++++++++++++++++-----------
 2 files changed, 19 insertions(+), 12 deletions(-)

diff --git a/README.md b/README.md
index 02c2c68..6350f2f 100644
--- a/README.md
+++ b/README.md
@@ -77,7 +77,8 @@ class PartiallyReversibleNet(nn.Module):
             blocks.append(rv.ReversibleBlock(f_func, g_func))
 
         #pack all reversible blocks into a reversible sequence
-        self.sequence = rv.ReversibleSequence(nn.ModuleList(blocks))
+        #eagerly discard variables to save even more memory, if you are sure you will not backward more than once
+        self.sequence = rv.ReversibleSequence(nn.ModuleList(blocks), eagerly_discard_variables = True)
 
         #non-reversible convolution to get to 10 channels (one for each label)
         self.conv2 = nn.Conv2d(32, 10, 3)
diff --git a/revtorch/revtorch.py b/revtorch/revtorch.py
index 3e58cb5..474e3f3 100644
--- a/revtorch/revtorch.py
+++ b/revtorch/revtorch.py
@@ -52,7 +52,7 @@ def forward(self, x):
 
         return torch.cat([y1, y2], dim=self.split_along_dim)
 
-    def backward_pass(self, y, dy):
+    def backward_pass(self, y, dy, retain_graph):
         """
         Performs the backward pass of the reversible block.
 
@@ -61,6 +61,7 @@ def backward_pass(self, y, dy):
 
         :param y: Outputs of the reversible block
         :param dy: Derivatives of the outputs
+        :param retain_graph: Whether to retain the graph on intercepted backwards
         :return: A tuple of (block input, block input derivatives). The block inputs are the same shape as the block outptus.
         """
 
@@ -86,7 +87,7 @@ def backward_pass(self, y, dy):
 
             # Use autograd framework to differentiate the calculation. The
             # derivatives of the parameters of G are set as a side effect
-            gy1.backward(dy2)
+            gy1.backward(dy2, retain_graph = retain_graph)
 
         with torch.no_grad():
             x2 = y2 - gy1  # Restore first input of forward()
@@ -106,7 +107,7 @@ def backward_pass(self, y, dy):
 
             # Use autograd framework to differentiate the calculation. The
             # derivatives of the parameters of F are set as a side effec
-            fx2.backward(dx1)
+            fx2.backward(dx1, retain_graph = retain_graph)
 
         with torch.no_grad():
             x1 = y1 - fx2  # Restore second input of forward()
@@ -131,7 +132,7 @@ class _ReversibleModuleFunction(torch.autograd.function.Function):
     '''
 
     @staticmethod
-    def forward(ctx, x, reversible_blocks):
+    def forward(ctx, x, reversible_blocks, eagerly_discard_variables):
         '''
         Performs the forward pass of a reversible sequence within the autograd framework
         :param ctx: autograd context
@@ -145,6 +146,7 @@ def forward(ctx, x, reversible_blocks):
             x = block(x)
         ctx.y = x.detach() #not using ctx.save_for_backward(x) saves us memory by beeing able to free ctx.y earlier in the backward pass
         ctx.reversible_blocks = reversible_blocks
+        ctx.eagerly_discard_variables = eagerly_discard_variables
         return x
 
     @staticmethod
@@ -156,11 +158,13 @@ def backward(ctx, dy):
         :return: derivatives of the inputs
         '''
         y = ctx.y
-        del ctx.y
+        if ctx.eagerly_discard_variables:
+            del ctx.y
         for i in range(len(ctx.reversible_blocks) - 1, -1, -1):
-            y, dy = ctx.reversible_blocks[i].backward_pass(y, dy)
-        del ctx.reversible_blocks
-        return dy, None
+            y, dy = ctx.reversible_blocks[i].backward_pass(y, dy, ctx.multiple_backwards)
+        if ctx.eagerly_discard_variables:
+            del ctx.reversible_blocks
+        return dy, None, None
 
 class ReversibleSequence(nn.Module):
     '''
@@ -173,15 +177,17 @@ class ReversibleSequence(nn.Module):
 
     Arguments:
         reversible_blocks (nn.ModuleList): A ModuleList that exclusivly contains instances of ReversibleBlock whic
             which are to be used in the reversible sequence.
+        eagerly_discard_variables (bool): Should the module eagerly discard the output and not retain the graph for the individual backwards called on the reversible blocks, for further memory savings
     '''
 
-    def __init__(self, reversible_blocks):
+    def __init__(self, reversible_blocks, eagerly_discard_variables = True):
         super(ReversibleSequence, self).__init__()
         assert (isinstance(reversible_blocks, nn.ModuleList))
         for block in reversible_blocks:
             assert(isinstance(block, ReversibleBlock))
 
         self.reversible_blocks = reversible_blocks
+        self.eagerly_discard_variables = eagerly_discard_variables
 
     def forward(self, x):
         '''
@@ -189,5 +195,5 @@ def forward(self, x):
         :param x: Input tensor
         :return: Output tensor
         '''
-        x = _ReversibleModuleFunction.apply(x, self.reversible_blocks)
-        return x
\ No newline at end of file
+        x = _ReversibleModuleFunction.apply(x, self.reversible_blocks, self.eagerly_discard_variables)
+        return x

From 47450699e47c9a2d4bd2d4453c1dc1c7c3913363 Mon Sep 17 00:00:00 2001
From: Robin
Date: Thu, 30 Jan 2020 21:16:35 +0100
Subject: [PATCH 2/2] Integrated PR that allows for multiple backward passes

---
 README.md            | 3 +--
 revtorch/revtorch.py | 7 ++++---
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/README.md b/README.md
index 6350f2f..02c2c68 100644
--- a/README.md
+++ b/README.md
@@ -77,8 +77,7 @@ class PartiallyReversibleNet(nn.Module):
             blocks.append(rv.ReversibleBlock(f_func, g_func))
 
         #pack all reversible blocks into a reversible sequence
-        #eagerly discard variables to save even more memory, if you are sure you will not backward more than once
-        self.sequence = rv.ReversibleSequence(nn.ModuleList(blocks), eagerly_discard_variables = True)
+        self.sequence = rv.ReversibleSequence(nn.ModuleList(blocks))
 
         #non-reversible convolution to get to 10 channels (one for each label)
         self.conv2 = nn.Conv2d(32, 10, 3)
diff --git a/revtorch/revtorch.py b/revtorch/revtorch.py
index 474e3f3..5785847 100644
--- a/revtorch/revtorch.py
+++ b/revtorch/revtorch.py
@@ -161,7 +161,7 @@ def backward(ctx, dy):
         if ctx.eagerly_discard_variables:
             del ctx.y
         for i in range(len(ctx.reversible_blocks) - 1, -1, -1):
-            y, dy = ctx.reversible_blocks[i].backward_pass(y, dy, ctx.multiple_backwards)
+            y, dy = ctx.reversible_blocks[i].backward_pass(y, dy, not ctx.eagerly_discard_variables)
         if ctx.eagerly_discard_variables:
             del ctx.reversible_blocks
         return dy, None, None
@@ -175,9 +175,10 @@ class ReversibleSequence(nn.Module):
     the reversible sequece to save memory.
 
     Arguments:
-        reversible_blocks (nn.ModuleList): A ModuleList that exclusivly contains instances of ReversibleBlock whic
+        reversible_blocks (nn.ModuleList): A ModuleList that exclusivly contains instances of ReversibleBlock which
             are to be used in the reversible sequence.
-        eagerly_discard_variables (bool): Should the module eagerly discard the output and not retain the graph for the individual backwards called on the reversible blocks, for further memory savings
+        eagerly_discard_variables (bool): If set to true backward() discards the variables requried for
+            calculating the gradient and therefore saves memory. Disable if you call backward() multiple times.
     '''
 
     def __init__(self, reversible_blocks, eagerly_discard_variables = True):
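
A minimal usage sketch of the eagerly_discard_variables flag these two patches introduce; the snippet is not taken from the patches themselves. It assumes the package is importable as revtorch (imported as rv, following the README excerpt above) and relies on the additive-coupling structure visible in the diff, where each ReversibleBlock splits its input in half along the channel dimension, so f and g map 16 channels to 16 channels for a 32-channel input. Layer and tensor sizes below are illustrative only.

    import torch
    import torch.nn as nn
    import revtorch as rv

    def half_channel_conv():
        # f or g for one ReversibleBlock; must map a tensor to one of the same shape
        return nn.Sequential(nn.Conv2d(16, 16, 3, padding=1), nn.ReLU())

    blocks = nn.ModuleList(
        [rv.ReversibleBlock(half_channel_conv(), half_channel_conv()) for _ in range(4)]
    )

    # The default (eagerly_discard_variables=True) frees the stored sequence output and
    # does not retain the per-block graphs once backward() runs, which saves the most
    # memory but supports only a single backward pass. Setting it to False keeps them,
    # so the same graph can be backpropagated through more than once.
    sequence = rv.ReversibleSequence(blocks, eagerly_discard_variables=False)

    x = torch.randn(2, 32, 8, 8, requires_grad=True)
    loss = sequence(x).sum()
    loss.backward(retain_graph=True)  # first backward pass
    loss.backward()                   # second backward pass, possible because nothing was discarded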