Feature/allow multple backward passes #9

Merged · 3 commits · Jan 30, 2020
revtorch/revtorch.py (31 changes: 19 additions & 12 deletions)
@@ -52,7 +52,7 @@ def forward(self, x):
 
         return torch.cat([y1, y2], dim=self.split_along_dim)
 
-    def backward_pass(self, y, dy):
+    def backward_pass(self, y, dy, retain_graph):
         """
         Performs the backward pass of the reversible block.
 
@@ -61,6 +61,7 @@ def backward_pass(self, y, dy):
 
         :param y: Outputs of the reversible block
         :param dy: Derivatives of the outputs
+        :param retain_graph: Whether to retain the graph on intercepted backwards
         :return: A tuple of (block input, block input derivatives). The block inputs are the same shape as the block outptus.
         """
 
@@ -86,7 +87,7 @@ def backward_pass(self, y, dy):
 
             # Use autograd framework to differentiate the calculation. The
             # derivatives of the parameters of G are set as a side effect
-            gy1.backward(dy2)
+            gy1.backward(dy2, retain_graph = retain_graph)
 
         with torch.no_grad():
             x2 = y2 - gy1 # Restore first input of forward()
@@ -106,7 +107,7 @@ def backward_pass(self, y, dy):
 
             # Use autograd framework to differentiate the calculation. The
             # derivatives of the parameters of F are set as a side effec
-            fx2.backward(dx1)
+            fx2.backward(dx1, retain_graph = retain_graph)
 
         with torch.no_grad():
             x1 = y1 - fx2 # Restore second input of forward()
@@ -131,7 +132,7 @@ class _ReversibleModuleFunction(torch.autograd.function.Function):
     '''
 
     @staticmethod
-    def forward(ctx, x, reversible_blocks):
+    def forward(ctx, x, reversible_blocks, eagerly_discard_variables):
         '''
         Performs the forward pass of a reversible sequence within the autograd framework
         :param ctx: autograd context
@@ -145,6 +146,7 @@ def forward(ctx, x, reversible_blocks):
             x = block(x)
         ctx.y = x.detach() #not using ctx.save_for_backward(x) saves us memory by beeing able to free ctx.y earlier in the backward pass
         ctx.reversible_blocks = reversible_blocks
+        ctx.eagerly_discard_variables = eagerly_discard_variables
         return x
 
     @staticmethod
@@ -156,11 +158,13 @@ def backward(ctx, dy):
         :return: derivatives of the inputs
         '''
         y = ctx.y
-        del ctx.y
+        if ctx.eagerly_discard_variables:
+            del ctx.y
         for i in range(len(ctx.reversible_blocks) - 1, -1, -1):
-            y, dy = ctx.reversible_blocks[i].backward_pass(y, dy)
-        del ctx.reversible_blocks
-        return dy, None
+            y, dy = ctx.reversible_blocks[i].backward_pass(y, dy, not ctx.eagerly_discard_variables)
+        if ctx.eagerly_discard_variables:
+            del ctx.reversible_blocks
+        return dy, None, None
 
 class ReversibleSequence(nn.Module):
     '''
@@ -171,23 +175,26 @@ class ReversibleSequence(nn.Module):
     the reversible sequece to save memory.
 
     Arguments:
-        reversible_blocks (nn.ModuleList): A ModuleList that exclusivly contains instances of ReversibleBlock whic
+        reversible_blocks (nn.ModuleList): A ModuleList that exclusivly contains instances of ReversibleBlock
+            which are to be used in the reversible sequence.
+        eagerly_discard_variables (bool): If set to true backward() discards the variables requried for
+            calculating the gradient and therefore saves memory. Disable if you call backward() multiple times.
     '''
 
-    def __init__(self, reversible_blocks):
+    def __init__(self, reversible_blocks, eagerly_discard_variables = True):
         super(ReversibleSequence, self).__init__()
         assert (isinstance(reversible_blocks, nn.ModuleList))
         for block in reversible_blocks:
             assert(isinstance(block, ReversibleBlock))
 
         self.reversible_blocks = reversible_blocks
+        self.eagerly_discard_variables = eagerly_discard_variables
 
     def forward(self, x):
         '''
         Forward pass of a reversible sequence
         :param x: Input tensor
         :return: Output tensor
         '''
-        x = _ReversibleModuleFunction.apply(x, self.reversible_blocks)
-        return x
+        x = _ReversibleModuleFunction.apply(x, self.reversible_blocks, self.eagerly_discard_variables)
+        return x
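
For context, a minimal usage sketch of the new flag (not part of the PR itself). It assumes the package exposes ReversibleBlock and ReversibleSequence at the top level, that ReversibleBlock is built from two nn.Module halves which each see one half of the channel-split input (split_along_dim=1 by default), and that the layer sizes and losses below are purely illustrative.

import torch
import torch.nn as nn
import revtorch as rv

def half_block():
    # Each half sees 16 of the 32 input channels after the channel split.
    return nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16))

blocks = nn.ModuleList(
    [rv.ReversibleBlock(half_block(), half_block()) for _ in range(4)]
)

# Keep ctx.y, the block list and the intercepted graphs alive across backwards.
seq = rv.ReversibleSequence(blocks, eagerly_discard_variables=False)

x = torch.randn(8, 32, requires_grad=True)  # input must require grad so the custom autograd function is recorded
y = seq(x)

loss_a = y.pow(2).mean()
loss_b = y.abs().mean()

# Two backward passes through the same reversible sequence. With the default
# eagerly_discard_variables=True the second call would fail, because ctx.y and
# ctx.reversible_blocks are deleted and the intercepted backwards do not retain
# their graphs after the first pass.
loss_a.backward(retain_graph=True)
loss_b.backward()

Note that the outer retain_graph=True is still needed so PyTorch keeps the graph between x and y alive for the second backward(); the new flag only controls what the reversible sequence itself frees during its intercepted backward pass.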