
Commit

make dimension on which input is chunked and then concatted modifiable
lucidrains committed Jan 8, 2020
1 parent df55b91 commit 76b77da
Showing 1 changed file with 8 additions and 7 deletions.
revtorch/revtorch.py: 15 changes (8 additions & 7 deletions)
@@ -15,24 +15,25 @@ class ReversibleBlock(nn.Module):
         g_block (nn.Module): arbitrary subnetwork whose output shape is equal to its input shape
     '''

-    def __init__(self, f_block, g_block):
+    def __init__(self, f_block, g_block, dim=1):
         super(ReversibleBlock, self).__init__()
         self.f_block = f_block
         self.g_block = g_block
+        self.dim = dim

     def forward(self, x):
         """
         Performs the forward pass of the reversible block. Does not record any gradients.
         :param x: Input tensor. Must be splittable along dimension 1.
         :return: Output tensor of the same shape as the input tensor
         """
-        x1, x2 = torch.chunk(x, 2, dim=1)
+        x1, x2 = torch.chunk(x, 2, dim=self.dim)
         y1, y2 = None, None
         with torch.no_grad():
             y1 = x1 + self.f_block(x2)
             y2 = x2 + self.g_block(y1)

-        return torch.cat([y1, y2], dim=1)
+        return torch.cat([y1, y2], dim=self.dim)

     def backward_pass(self, y, dy):
         """
@@ -47,11 +47,11 @@ def backward_pass(self, y, dy):
         """

         # Split the arguments channel-wise
-        y1, y2 = torch.chunk(y, 2, dim=1)
+        y1, y2 = torch.chunk(y, 2, dim=self.dim)
         del y
         assert (not y1.requires_grad), "y1 must already be detached"
         assert (not y2.requires_grad), "y2 must already be detached"
-        dy1, dy2 = torch.chunk(dy, 2, dim=1)
+        dy1, dy2 = torch.chunk(dy, 2, dim=self.dim)
         del dy
         assert (not dy1.requires_grad), "dy1 must not require grad"
         assert (not dy2.requires_grad), "dy2 must not require grad"
@@ -100,8 +101,8 @@ def backward_pass(self, y, dy):
         x2.grad = None

         # Undo the channelwise split
-        x = torch.cat([x1, x2.detach()], dim=1)
-        dx = torch.cat([dx1, dx2], dim=1)
+        x = torch.cat([x1, x2.detach()], dim=self.dim)
+        dx = torch.cat([dx1, dx2], dim=self.dim)

         return x, dx
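A minimal usage sketch of the new argument. The import path is assumed from the file path shown above, and the two Linear layers are arbitrary stand-ins for f_block/g_block; only the ReversibleBlock(f_block, g_block, dim=...) signature comes from this commit. With dim=2, the block chunks and concatenates along the last axis of a (batch, seq, features) tensor instead of the default channel axis:

import torch
import torch.nn as nn
from revtorch.revtorch import ReversibleBlock  # import path assumed from the file above

# f and g must each map their input shape to the same output shape
f = nn.Linear(64, 64)
g = nn.Linear(64, 64)

# split/concat along the last dimension instead of the default dim=1
block = ReversibleBlock(f, g, dim=2)

x = torch.randn(8, 16, 128)  # (batch, seq, 2 * features)
y = block(x)                 # chunked into two (8, 16, 64) halves along dim 2
assert y.shape == x.shape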

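For reference, the identity backward_pass relies on: the coupling y1 = x1 + f(x2), y2 = x2 + g(y1) inverts exactly as x2 = y2 - g(y1) followed by x1 = y1 - f(x2), whatever dimension the halves are split on. A self-contained sketch in plain PyTorch, independent of the library; the dim variable mirrors the newly configurable chunk/concat dimension:

import torch
import torch.nn as nn

f, g = nn.Linear(32, 32), nn.Linear(32, 32)
dim = 2  # mirrors the configurable chunk/concat dimension

x = torch.randn(4, 10, 64)
x1, x2 = torch.chunk(x, 2, dim=dim)

with torch.no_grad():
    # forward coupling, as in ReversibleBlock.forward
    y1 = x1 + f(x2)
    y2 = x2 + g(y1)

    # exact inversion, as recomputed during backward_pass
    x2_rec = y2 - g(y1)
    x1_rec = y1 - f(x2_rec)

assert torch.allclose(torch.cat([x1_rec, x2_rec], dim=dim), x, atol=1e-6)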
