Skip to content

Commit

Permalink
Merge branch 'lucidrains-master' into feature/allow-multple-backward-…
Browse files Browse the repository at this point in the history
…passes
  • Loading branch information
RobinBruegger committed Jan 30, 2020
2 parents b9b7291 + e18cdc2 commit 69e4864
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 12 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ class PartiallyReversibleNet(nn.Module):
blocks.append(rv.ReversibleBlock(f_func, g_func))

#pack all reversible blocks into a reversible sequence
self.sequence = rv.ReversibleSequence(nn.ModuleList(blocks))
#eagerly discard variables to save even more memory, if you are sure you will not call backward more than once
self.sequence = rv.ReversibleSequence(nn.ModuleList(blocks), eagerly_discard_variables = True)

#non-reversible convolution to get to 10 channels (one for each label)
self.conv2 = nn.Conv2d(32, 10, 3)
Expand Down
28 changes: 17 additions & 11 deletions revtorch/revtorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def forward(self, x):

return torch.cat([y1, y2], dim=self.split_along_dim)

def backward_pass(self, y, dy):
def backward_pass(self, y, dy, retain_graph):
"""
Performs the backward pass of the reversible block.
Expand All @@ -61,6 +61,7 @@ def backward_pass(self, y, dy):
:param y: Outputs of the reversible block
:param dy: Derivatives of the outputs
:param retain_graph: Whether to retain the graph on intercepted backwards
:return: A tuple of (block input, block input derivatives). The block inputs are the same shape as the block outputs.
"""

Expand All @@ -86,7 +87,7 @@ def backward_pass(self, y, dy):

# Use autograd framework to differentiate the calculation. The
# derivatives of the parameters of G are set as a side effect
gy1.backward(dy2)
gy1.backward(dy2, retain_graph = retain_graph)

with torch.no_grad():
x2 = y2 - gy1 # Restore first input of forward()
Expand All @@ -106,7 +107,7 @@ def backward_pass(self, y, dy):

# Use autograd framework to differentiate the calculation. The
# derivatives of the parameters of F are set as a side effect
fx2.backward(dx1)
fx2.backward(dx1, retain_graph = retain_graph)

with torch.no_grad():
x1 = y1 - fx2 # Restore second input of forward()
Expand All @@ -131,7 +132,7 @@ class _ReversibleModuleFunction(torch.autograd.function.Function):
'''

@staticmethod
def forward(ctx, x, reversible_blocks):
def forward(ctx, x, reversible_blocks, eagerly_discard_variables):
'''
Performs the forward pass of a reversible sequence within the autograd framework
:param ctx: autograd context
Expand All @@ -145,6 +146,7 @@ def forward(ctx, x, reversible_blocks):
x = block(x)
ctx.y = x.detach() #not using ctx.save_for_backward(x) saves us memory by being able to free ctx.y earlier in the backward pass
ctx.reversible_blocks = reversible_blocks
ctx.eagerly_discard_variables = eagerly_discard_variables
return x

@staticmethod
Expand All @@ -156,11 +158,13 @@ def backward(ctx, dy):
:return: derivatives of the inputs
'''
y = ctx.y
del ctx.y
if ctx.eagerly_discard_variables:
del ctx.y
for i in range(len(ctx.reversible_blocks) - 1, -1, -1):
y, dy = ctx.reversible_blocks[i].backward_pass(y, dy)
del ctx.reversible_blocks
return dy, None
y, dy = ctx.reversible_blocks[i].backward_pass(y, dy, ctx.multiple_backwards)
if ctx.eagerly_discard_variables:
del ctx.reversible_blocks
return dy, None, None

class ReversibleSequence(nn.Module):
'''
Expand All @@ -173,21 +177,23 @@ class ReversibleSequence(nn.Module):
Arguments:
reversible_blocks (nn.ModuleList): A ModuleList that exclusively contains instances of ReversibleBlock
which are to be used in the reversible sequence.
eagerly_discard_variables (bool): Whether the module should eagerly discard the output and not retain the graph for the individual backward calls on the reversible blocks, for further memory savings
'''

def __init__(self, reversible_blocks):
def __init__(self, reversible_blocks, eagerly_discard_variables = True):
super(ReversibleSequence, self).__init__()
assert (isinstance(reversible_blocks, nn.ModuleList))
for block in reversible_blocks:
assert(isinstance(block, ReversibleBlock))

self.reversible_blocks = reversible_blocks
self.eagerly_discard_variables = eagerly_discard_variables

def forward(self, x):
'''
Forward pass of a reversible sequence
:param x: Input tensor
:return: Output tensor
'''
x = _ReversibleModuleFunction.apply(x, self.reversible_blocks)
return x
x = _ReversibleModuleFunction.apply(x, self.reversible_blocks, self.eagerly_discard_variables)
return x

0 comments on commit 69e4864

Please sign in to comment.