From 76b77da25b4bc8ec228ff5dde1aaee4d33848233 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Wed, 8 Jan 2020 13:14:44 -0800
Subject: [PATCH] make dimension on which input is chunked and then concatted
 modifiable

---
 revtorch/revtorch.py | 15 ++++++++-------
 1 file changed, 8 insertions(+), 7 deletions(-)

diff --git a/revtorch/revtorch.py b/revtorch/revtorch.py
index b9645a2..9ef7dd5 100644
--- a/revtorch/revtorch.py
+++ b/revtorch/revtorch.py
@@ -15,10 +15,11 @@ class ReversibleBlock(nn.Module):
         g_block (nn.Module): arbitrary subnetwork whos output shape is equal to its input shape
     '''

-    def __init__(self, f_block, g_block):
+    def __init__(self, f_block, g_block, dim=1):
         super(ReversibleBlock, self).__init__()
         self.f_block = f_block
         self.g_block = g_block
+        self.dim = dim

     def forward(self, x):
         """
@@ -26,13 +27,13 @@ def forward(self, x):
         :param x: Input tensor. Must be splittable along dimension 1.
         :return: Output tensor of the same shape as the input tensor
         """
-        x1, x2 = torch.chunk(x, 2, dim=1)
+        x1, x2 = torch.chunk(x, 2, dim=self.dim)
         y1, y2 = None, None
         with torch.no_grad():
             y1 = x1 + self.f_block(x2)
             y2 = x2 + self.g_block(y1)

-        return torch.cat([y1, y2], dim=1)
+        return torch.cat([y1, y2], dim=self.dim)

     def backward_pass(self, y, dy):
         """
@@ -47,11 +48,11 @@ def backward_pass(self, y, dy):
         """

         # Split the arguments channel-wise
-        y1, y2 = torch.chunk(y, 2, dim=1)
+        y1, y2 = torch.chunk(y, 2, dim=self.dim)
         del y
         assert (not y1.requires_grad), "y1 must already be detached"
         assert (not y2.requires_grad), "y2 must already be detached"
-        dy1, dy2 = torch.chunk(dy, 2, dim=1)
+        dy1, dy2 = torch.chunk(dy, 2, dim=self.dim)
         del dy
         assert (not dy1.requires_grad), "dy1 must not require grad"
         assert (not dy2.requires_grad), "dy2 must not require grad"
@@ -100,8 +101,8 @@ def backward_pass(self, y, dy):
             x2.grad = None

         # Undo the channelwise split
-        x = torch.cat([x1, x2.detach()], dim=1)
-        dx = torch.cat([dx1, dx2], dim=1)
+        x = torch.cat([x1, x2.detach()], dim=self.dim)
+        dx = torch.cat([dx1, dx2], dim=self.dim)

         return x, dx