diff --git a/README.md b/README.md index 6350f2f..02c2c68 100644 --- a/README.md +++ b/README.md @@ -77,8 +77,7 @@ class PartiallyReversibleNet(nn.Module): blocks.append(rv.ReversibleBlock(f_func, g_func)) #pack all reversible blocks into a reversible sequence - #eagerly discard variables to save even more memory, if you are sure you will not backward more than once - self.sequence = rv.ReversibleSequence(nn.ModuleList(blocks), eagerly_discard_variables = True) + self.sequence = rv.ReversibleSequence(nn.ModuleList(blocks)) #non-reversible convolution to get to 10 channels (one for each label) self.conv2 = nn.Conv2d(32, 10, 3) diff --git a/revtorch/revtorch.py b/revtorch/revtorch.py index 474e3f3..5785847 100644 --- a/revtorch/revtorch.py +++ b/revtorch/revtorch.py @@ -161,7 +161,7 @@ def backward(ctx, dy): if ctx.eagerly_discard_variables: del ctx.y for i in range(len(ctx.reversible_blocks) - 1, -1, -1): - y, dy = ctx.reversible_blocks[i].backward_pass(y, dy, ctx.multiple_backwards) + y, dy = ctx.reversible_blocks[i].backward_pass(y, dy, not ctx.eagerly_discard_variables) if ctx.eagerly_discard_variables: del ctx.reversible_blocks return dy, None, None @@ -175,9 +175,10 @@ class ReversibleSequence(nn.Module): the reversible sequece to save memory. Arguments: - reversible_blocks (nn.ModuleList): A ModuleList that exclusivly contains instances of ReversibleBlock whic + reversible_blocks (nn.ModuleList): A ModuleList that exclusively contains instances of ReversibleBlock which are to be used in the reversible sequence. - eagerly_discard_variables (bool): Should the module eagerly discard the output and not retain the graph for the individual backwards called on the reversible blocks, for further memory savings + eagerly_discard_variables (bool): If set to True backward() discards the variables required for + calculating the gradient and therefore saves memory. Disable if you call backward() multiple times. 
''' def __init__(self, reversible_blocks, eagerly_discard_variables = True):