Merge pull request #9 from RobinBruegger/feature/allow-multple-backward-passes

Feature/allow multiple backward passes
RobinBruegger authored Jan 30, 2020
2 parents b9b7291 + 4745069 commit 8ad230e
Showing 1 changed file with 19 additions and 12 deletions.
31 changes: 19 additions & 12 deletions revtorch/revtorch.py
@@ -52,7 +52,7 @@ def forward(self, x):
 
         return torch.cat([y1, y2], dim=self.split_along_dim)
 
-    def backward_pass(self, y, dy):
+    def backward_pass(self, y, dy, retain_graph):
         """
         Performs the backward pass of the reversible block.
@@ -61,6 +61,7 @@ def backward_pass(self, y, dy):
         :param y: Outputs of the reversible block
         :param dy: Derivatives of the outputs
+        :param retain_graph: Whether to retain the graph on intercepted backwards
         :return: A tuple of (block input, block input derivatives). The block inputs are the same shape as the block outputs.
         """
 
@@ -86,7 +87,7 @@ def backward_pass(self, y, dy):
 
             # Use autograd framework to differentiate the calculation. The
             # derivatives of the parameters of G are set as a side effect
-            gy1.backward(dy2)
+            gy1.backward(dy2, retain_graph = retain_graph)
 
         with torch.no_grad():
             x2 = y2 - gy1 # Restore first input of forward()
@@ -106,7 +107,7 @@ def backward_pass(self, y, dy):
 
             # Use autograd framework to differentiate the calculation. The
             # derivatives of the parameters of F are set as a side effect
-            fx2.backward(dx1)
+            fx2.backward(dx1, retain_graph = retain_graph)
 
         with torch.no_grad():
             x1 = y1 - fx2 # Restore second input of forward()
@@ -131,7 +132,7 @@ class _ReversibleModuleFunction(torch.autograd.function.Function):
     '''
 
     @staticmethod
-    def forward(ctx, x, reversible_blocks):
+    def forward(ctx, x, reversible_blocks, eagerly_discard_variables):
         '''
         Performs the forward pass of a reversible sequence within the autograd framework
         :param ctx: autograd context
@@ -145,6 +146,7 @@ def forward(ctx, x, reversible_blocks):
             x = block(x)
         ctx.y = x.detach() # not using ctx.save_for_backward(x) saves us memory by being able to free ctx.y earlier in the backward pass
         ctx.reversible_blocks = reversible_blocks
+        ctx.eagerly_discard_variables = eagerly_discard_variables
         return x
 
     @staticmethod
@@ -156,11 +158,13 @@ def backward(ctx, dy):
         :return: derivatives of the inputs
         '''
         y = ctx.y
-        del ctx.y
+        if ctx.eagerly_discard_variables:
+            del ctx.y
         for i in range(len(ctx.reversible_blocks) - 1, -1, -1):
-            y, dy = ctx.reversible_blocks[i].backward_pass(y, dy)
-        del ctx.reversible_blocks
-        return dy, None
+            y, dy = ctx.reversible_blocks[i].backward_pass(y, dy, not ctx.eagerly_discard_variables)
+        if ctx.eagerly_discard_variables:
+            del ctx.reversible_blocks
+        return dy, None, None
 
 class ReversibleSequence(nn.Module):
     '''
@@ -171,23 +175,26 @@ class ReversibleSequence(nn.Module):
     the reversible sequence to save memory.
     Arguments:
-        reversible_blocks (nn.ModuleList): A ModuleList that exclusivly contains instances of ReversibleBlock whic
+        reversible_blocks (nn.ModuleList): A ModuleList that exclusively contains instances of ReversibleBlock
+            which are to be used in the reversible sequence.
+        eagerly_discard_variables (bool): If set to True, backward() discards the variables required for
+            calculating the gradient and therefore saves memory. Disable if you call backward() multiple times.
     '''
 
-    def __init__(self, reversible_blocks):
+    def __init__(self, reversible_blocks, eagerly_discard_variables = True):
         super(ReversibleSequence, self).__init__()
         assert (isinstance(reversible_blocks, nn.ModuleList))
         for block in reversible_blocks:
             assert(isinstance(block, ReversibleBlock))
 
         self.reversible_blocks = reversible_blocks
+        self.eagerly_discard_variables = eagerly_discard_variables
 
     def forward(self, x):
         '''
         Forward pass of a reversible sequence
         :param x: Input tensor
         :return: Output tensor
         '''
-        x = _ReversibleModuleFunction.apply(x, self.reversible_blocks)
-        return x
+        x = _ReversibleModuleFunction.apply(x, self.reversible_blocks, self.eagerly_discard_variables)
+        return x
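
For context, the sketch below is not part of the commit; it shows how the new option is meant to be used, assuming the ReversibleBlock(f_block, g_block) and ReversibleSequence constructors from the rest of revtorch.py. The sub-network sizes, block count, and losses are illustrative. Passing eagerly_discard_variables=False keeps the saved output and the intercepted graphs alive, so backward() can run through the reversible sequence more than once.

import torch
import torch.nn as nn
import revtorch as rv

def make_subnet():
    # Illustrative sub-network; any module mapping a half-channel tensor to the same shape works.
    return nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16))

# Each ReversibleBlock takes an F and a G sub-network operating on one half of the channels.
blocks = nn.ModuleList([rv.ReversibleBlock(make_subnet(), make_subnet()) for _ in range(4)])

# Disable eager discarding so the variables and graphs needed for a second backward are kept.
sequence = rv.ReversibleSequence(blocks, eagerly_discard_variables=False)

x = torch.randn(8, 32, requires_grad=True)  # 32 features are split into two halves of 16
y = sequence(x)

loss_a = y.pow(2).mean()
loss_b = y.abs().mean()

loss_a.backward(retain_graph=True)  # first backward pass
loss_b.backward()                   # second backward pass now succeeds; gradients accumulate

Note that the first backward() still needs retain_graph=True so that the part of the graph outside the reversible sequence is not freed before the second call.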
