Merge pull request #9 from ixaxaar/hidden_layers
Implement Hidden layers, small enhancements, cleanups
Russi Chatterjee authored Oct 29, 2017
2 parents fc863a9 + 522a810 commit 9ebdb9c
Showing 8 changed files with 499 additions and 37 deletions.
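
The central change, visible in the dnc/dnc.py diff below, is a new num_hidden_layers argument: each DNC layer's controller becomes a stack of RNN/GRU/LSTM cells rather than a single cell. The first cell in the stack consumes the input concatenated with the r * w read vectors (nn_input_size); every later cell consumes the hidden state of the cell beneath it. A minimal sketch of that wiring, not the repository's exact code (which keeps plain lists and registers each cell via setattr):

import torch.nn as nn

class StackedController(nn.Module):
    # Illustrative only: one DNC layer's controller as a stack of LSTM cells.
    def __init__(self, nn_input_size, hidden_size, num_hidden_layers):
        super(StackedController, self).__init__()
        self.cells = nn.ModuleList([
            nn.LSTMCell(nn_input_size if h == 0 else hidden_size, hidden_size)
            for h in range(num_hidden_layers)
        ])

    def forward(self, x, states):
        # states: one (h, c) tuple per hidden layer
        new_states = []
        for cell, state in zip(self.cells, states):
            h, c = cell(x, state)
            x = h  # the cell above sees the hidden state of the cell below
            new_states.append((h, c))
        return x, new_states
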
11 changes: 11 additions & 0 deletions .travis.yml
@@ -0,0 +1,11 @@
language: python
python:
- "3.6"
# command to install dependencies
install:
- pip install http://download.pytorch.org/whl/cu75/torch-0.2.0.post3-cp36-cp36m-manylinux1_x86_64.whl
- pip install numpy
- pip install visdom
# command to run tests
script:
- pytest
74 changes: 43 additions & 31 deletions dnc/dnc.py
@@ -21,6 +21,7 @@ def __init__(
hidden_size,
rnn_type='lstm',
num_layers=1,
num_hidden_layers=2,
bias=True,
batch_first=True,
dropout=0,
@@ -41,6 +42,7 @@ def __init__(
self.hidden_size = hidden_size
self.rnn_type = rnn_type
self.num_layers = num_layers
self.num_hidden_layers = num_hidden_layers
self.bias = bias
self.batch_first = batch_first
self.dropout = dropout
@@ -57,25 +59,34 @@ def __init__(
self.w = self.cell_size
self.r = self.read_heads

# input size of layer 0
self.layer0_input_size = self.r * self.w + self.input_size
# input size of subsequent layers
self.layern_input_size = self.r * self.w + self.hidden_size
# input size
self.nn_input_size = self.r * self.w + self.input_size
self.nn_output_size = self.r * self.w + self.hidden_size

self.interface_size = (self.w * self.r) + (3 * self.w) + (5 * self.r) + 3
self.output_size = self.hidden_size

self.rnns = []
self.rnns = [[None] * self.num_hidden_layers] * self.num_layers
self.memories = []

for layer in range(self.num_layers):
# controllers for each layer
if self.rnn_type.lower() == 'rnn':
self.rnns.append(nn.RNNCell(self.layer0_input_size, self.output_size, bias=self.bias, nonlinearity=self.nonlinearity))
elif self.rnn_type.lower() == 'gru':
self.rnns.append(nn.GRUCell(self.layer0_input_size, self.output_size, bias=self.bias))
elif self.rnn_type.lower() == 'lstm':
self.rnns.append(nn.LSTMCell(self.layer0_input_size, self.output_size, bias=self.bias))
for hlayer in range(self.num_hidden_layers):
if self.rnn_type.lower() == 'rnn':
if hlayer == 0:
self.rnns[layer][hlayer] = nn.RNNCell(self.nn_input_size, self.output_size,bias=self.bias, nonlinearity=self.nonlinearity)
else:
self.rnns[layer][hlayer] = nn.RNNCell(self.output_size, self.output_size,bias=self.bias, nonlinearity=self.nonlinearity)
elif self.rnn_type.lower() == 'gru':
if hlayer == 0:
self.rnns[layer][hlayer] = nn.GRUCell(self.nn_input_size, self.output_size, bias=self.bias)
else:
self.rnns[layer][hlayer] = nn.GRUCell(self.output_size, self.output_size, bias=self.bias)
elif self.rnn_type.lower() == 'lstm':
if hlayer == 0:
self.rnns[layer][hlayer] = nn.LSTMCell(self.nn_input_size, self.output_size, bias=self.bias)
else:
self.rnns[layer][hlayer] = nn.LSTMCell(self.output_size, self.output_size, bias=self.bias)

# memories for each layer
if not self.share_memory:
@@ -104,19 +115,20 @@ def __init__(
)

for layer in range(self.num_layers):
setattr(self, 'rnn_layer_' + str(layer), self.rnns[layer])
for hlayer in range(self.num_hidden_layers):
setattr(self, 'rnn_layer_' + str(layer) + '_' + str(hlayer), self.rnns[layer][hlayer])
if not self.share_memory:
setattr(self, 'rnn_layer_memory_' + str(layer), self.memories[layer])
if self.share_memory:
setattr(self, 'rnn_layer_memory_shared', self.memories[0])

# final output layer
self.output_weights = nn.Linear(self.output_size, self.output_size)
self.mem_out = nn.Linear(self.layern_input_size, self.input_size)
self.mem_out = nn.Linear(self.nn_output_size, self.input_size)
self.dropout_layer = nn.Dropout(self.dropout)

if self.gpu_id != -1:
[x.cuda(self.gpu_id) for x in self.rnns]
[x.cuda(self.gpu_id) for y in self.rnns for x in y]
[x.cuda(self.gpu_id) for x in self.memories]
self.mem_out.cuda(self.gpu_id)

@@ -128,9 +140,11 @@ def _init_hidden(self, hx, batch_size, reset_experience):

# initialize hidden state of the controller RNN
if chx is None:
chx = cuda(T.zeros(self.num_layers, batch_size, self.output_size), gpu_id=self.gpu_id)
chx = cuda(T.zeros(batch_size, self.output_size), gpu_id=self.gpu_id)
if self.rnn_type.lower() == 'lstm':
chx = (chx, chx)
chx = [ [ (chx.clone(), chx.clone()) for h in range(self.num_hidden_layers) ] for l in range(self.num_layers) ]
else:
chx = [ [ chx.clone() for h in range(self.num_hidden_layers) ] for l in range(self.num_layers) ]

# Last read vectors
if last_read is None:
@@ -158,12 +172,19 @@ def _layer_forward(self, input, layer, hx=(None, None)):

for time in range(max_length):
# pass through controller
# print('input[time]', input[time].size(), self.layer0_input_size, self.layern_input_size)
chx = self.rnns[layer](input[time], chx)
layer_input = input[time]
hchx = []

for hlayer in range(self.num_hidden_layers):
h = self.rnns[layer][hlayer](layer_input, chx[hlayer])
layer_input = h[0] if self.rnn_type.lower() == 'lstm' else h
hchx.append(h)
chx = hchx

# the interface vector
ξ = chx[0] if self.rnn_type.lower() == 'lstm' else chx
ξ = layer_input
# the output
out = self.output_weights(chx[0]) if self.rnn_type.lower() == 'lstm' else self.output_weights(chx)
out = self.output_weights(layer_input)

# pass through memory
if self.share_memory:
@@ -205,10 +226,9 @@ def forward(self, input, hx=(None, None, None), reset_experience=False):
# outs = [input[:, x, :] for x in range(max_length)]
outs = [T.cat([input[:, x, :], last_read], 1) for x in range(max_length)]

# chx = [x[0] for x in controller_hidden] if self.rnn_type.lower() == 'lstm' else controller_hidden[0]
for layer in range(self.num_layers):
# this layer's hidden states
chx = [x[layer] for x in controller_hidden] if self.rnn_type.lower() == 'lstm' else controller_hidden[layer]
chx = controller_hidden[layer]

m = mem_hidden if self.share_memory else mem_hidden[layer]
# pass through controller
@@ -240,21 +260,13 @@ def forward(self, input, hx=(None, None, None), reset_experience=False):
if self.debug:
viz = T.cat(viz, 0).transpose(0, 1)

# final hidden values
if self.rnn_type.lower() == 'lstm':
h = T.stack([x[0] for x in chxs], 0)
c = T.stack([x[1] for x in chxs], 0)
controller_hidden = (h, c)
else:
controller_hidden = T.stack(chxs, 0)
controller_hidden = chxs

if not self.batch_first:
outputs = outputs.transpose(0, 1)
if is_packed:
outputs = pack(output, lengths)

# apply_dict(locals())

if self.debug:
return outputs, (controller_hidden, mem_hidden, read_vectors[-1]), viz
else:
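
With the constructor extended, a DNC with stacked hidden layers is built and stepped the same way as before; the controller hidden state returned from forward is now nested per layer and per hidden layer. A minimal usage sketch following the call pattern of the new tests below (argument values and the Variable wrapping are illustrative, matching the torch 0.2-era API this repository targets):

import torch as T
from torch.autograd import Variable as var
from dnc.dnc import DNC

rnn = DNC(
    input_size=6,
    hidden_size=64,
    rnn_type='lstm',
    num_layers=2,
    num_hidden_layers=2,  # new in this commit
    nr_cells=10,
    cell_size=8,
    read_heads=2,
    gpu_id=-1,
    debug=True,
)

x = var(T.randn(5, 12, 6))  # (batch, time, input_size); batch_first defaults to True
output, (chx, mhx, rv), viz = rnn(x, None)
# chx[layer][hidden_layer] holds that cell's state:
# an (h, c) tuple for LSTM controllers, a single tensor for RNN/GRU ones.
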
6 changes: 3 additions & 3 deletions dnc/memory.py
@@ -50,11 +50,11 @@ def reset(self, batch_size=1, hidden=None, erase=True):

if hidden is None:
return {
'memory': cuda(T.zeros(b, m, w).fill_(0), gpu_id=self.gpu_id),
'memory': cuda(T.zeros(b, m, w).fill_(δ), gpu_id=self.gpu_id),
'link_matrix': cuda(T.zeros(b, 1, m, m), gpu_id=self.gpu_id),
'precedence': cuda(T.zeros(b, 1, m), gpu_id=self.gpu_id),
'read_weights': cuda(T.zeros(b, r, m).fill_(0), gpu_id=self.gpu_id),
'write_weights': cuda(T.zeros(b, 1, m).fill_(0), gpu_id=self.gpu_id),
'read_weights': cuda(T.zeros(b, r, m).fill_(δ), gpu_id=self.gpu_id),
'write_weights': cuda(T.zeros(b, 1, m).fill_(δ), gpu_id=self.gpu_id),
'usage_vector': cuda(T.zeros(b, m), gpu_id=self.gpu_id)
}
else:
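
The dnc/memory.py change above seeds the memory matrix and the read/write weightings with the module's small constant δ instead of exact zeros, presumably so that content addressing on the very first step does not reduce to a 0/0 cosine similarity. A tiny illustration of that failure mode (δ's magnitude is an assumption here; the diff only shows the symbol):

import torch as T

delta = 1e-6  # assumed magnitude of the repository's δ constant

def cosine(a, b):
    return (a * b).sum() / (a.norm() * b.norm())

key = T.randn(4)
print(cosine(key, T.zeros(4)))               # 0 / 0 -> nan
print(cosine(key, T.zeros(4).fill_(delta)))  # tiny but finite
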
11 changes: 8 additions & 3 deletions tasks/copy_task.py
@@ -27,9 +27,10 @@
parser.add_argument('-input_size', type=int, default=6, help='dimension of input feature')
parser.add_argument('-rnn_type', type=str, default='lstm', help='type of recurrent cells to use for the controller')
parser.add_argument('-nhid', type=int, default=64, help='number of hidden units of the inner nn')
parser.add_argument('-dropout', type=float, default=0.3, help='controller dropout')
parser.add_argument('-dropout', type=float, default=0, help='controller dropout')

parser.add_argument('-nlayer', type=int, default=2, help='number of layers')
parser.add_argument('-nhlayer', type=int, default=2, help='number of hidden layers')
parser.add_argument('-lr', type=float, default=1e-2, help='initial learning rate')
parser.add_argument('-clip', type=float, default=0.5, help='gradient clipping')

@@ -110,14 +111,17 @@ def criterion(predictions, targets):
rnn = DNC(
input_size=args.input_size,
hidden_size=args.nhid,
rnn_type='lstm',
rnn_type=args.rnn_type,
num_layers=args.nlayer,
num_hidden_layers=args.nhlayer,
dropout=args.dropout,
nr_cells=mem_slot,
cell_size=mem_size,
read_heads=read_heads,
gpu_id=args.cuda,
debug=True
)
print(rnn)

if args.cuda != -1:
rnn = rnn.cuda(args.cuda)
@@ -147,6 +151,7 @@ def criterion(predictions, targets):
# apply_dict(locals())
loss.backward()

T.nn.utils.clip_grad_norm(rnn.parameters(), args.clip)
optimizer.step()
loss_value = loss.data[0]

@@ -166,7 +171,7 @@
xtickstep=10,
ytickstep=2,
title='Timestep: ' + str(epoch) + ', loss: ' + str(loss),
xlabel='mem_slot * time',
xlabel='mem_slot * layer',
ylabel='mem_size'
)
)
134 changes: 134 additions & 0 deletions test/test_gru.py
@@ -0,0 +1,134 @@
#!/usr/bin/env python3

import pytest
import numpy as np

import torch.nn as nn
import torch as T
from torch.autograd import Variable as var
import torch.nn.functional as F
from torch.nn.utils import clip_grad_norm
import torch.optim as optim
import numpy as np

import sys
import os
import math
import time
sys.path.append('./src/')
sys.path.insert(0, os.path.join('..', '..'))

from dnc.dnc import DNC
from test_utils import generate_data, criterion


def test_rnn_1():
T.manual_seed(1111)

input_size = 100
hidden_size = 100
rnn_type = 'gru'
num_layers = 1
num_hidden_layers = 1
dropout = 0
nr_cells = 1
cell_size = 1
read_heads = 1
gpu_id = -1
debug = True
lr = 0.001
sequence_max_length = 10
batch_size = 10
cuda = gpu_id
clip = 10
length = 10

rnn = DNC(
input_size=input_size,
hidden_size=hidden_size,
rnn_type=rnn_type,
num_layers=num_layers,
num_hidden_layers=num_hidden_layers,
dropout=dropout,
nr_cells=nr_cells,
cell_size=cell_size,
read_heads=read_heads,
gpu_id=gpu_id,
debug=debug
)

optimizer = optim.Adam(rnn.parameters(), lr=lr)
optimizer.zero_grad()

input_data, target_output = generate_data(batch_size, length, input_size, cuda)
target_output = target_output.transpose(0, 1).contiguous()

output, (chx, mhx, rv), v = rnn(input_data, None)
output = output.transpose(0, 1)

loss = criterion((output), target_output)
loss.backward()

T.nn.utils.clip_grad_norm(rnn.parameters(), clip)
optimizer.step()

assert target_output.size() == T.Size([21, 10, 100])
assert chx[0][0].size() == T.Size([10,100])
assert mhx['memory'].size() == T.Size([10,1,1])
assert rv.size() == T.Size([10,1])


def test_rnn_n():
T.manual_seed(1111)

input_size = 100
hidden_size = 100
rnn_type = 'gru'
num_layers = 3
num_hidden_layers = 5
dropout = 0.2
nr_cells = 12
cell_size = 17
read_heads = 3
gpu_id = -1
debug = True
lr = 0.001
sequence_max_length = 10
batch_size = 10
cuda = gpu_id
clip = 20
length = 13

rnn = DNC(
input_size=input_size,
hidden_size=hidden_size,
rnn_type=rnn_type,
num_layers=num_layers,
num_hidden_layers=num_hidden_layers,
dropout=dropout,
nr_cells=nr_cells,
cell_size=cell_size,
read_heads=read_heads,
gpu_id=gpu_id,
debug=debug
)

optimizer = optim.Adam(rnn.parameters(), lr=lr)
optimizer.zero_grad()

input_data, target_output = generate_data(batch_size, length, input_size, cuda)
target_output = target_output.transpose(0, 1).contiguous()

output, (chx, mhx, rv), v = rnn(input_data, None)
output = output.transpose(0, 1)

loss = criterion((output), target_output)
loss.backward()

T.nn.utils.clip_grad_norm(rnn.parameters(), clip)
optimizer.step()

assert target_output.size() == T.Size([27, 10, 100])
assert chx[1][2].size() == T.Size([10,100])
assert mhx['memory'].size() == T.Size([10,12,17])
assert rv.size() == T.Size([10,51])