Skip to content

Commit

Permalink
fix bugs in original code
Browse files Browse the repository at this point in the history
  • Loading branch information
hoedt committed Jan 14, 2021
1 parent 97a930c commit c96c2ea
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 33 deletions.
3 changes: 2 additions & 1 deletion experiments/sequential_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,8 @@
operation=args.operation,
use_cuda=args.cuda,
seed=args.seed,
mnist_digits=args.mnist_digits
mnist_digits=args.mnist_digits,
num_workers=2 # debugging
)
dataset_train = dataset.fork(seq_length=args.interpolation_length, subset='train').dataloader(shuffle=True)
# Seeds are from random.org
Expand Down
26 changes: 13 additions & 13 deletions experiments/simple_function_recurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
print(f' - max_iterations: {args.max_iterations}')

# Prepare logging
results_writer = stable_nalu.writer.ResultsWriter('simple_function_recurrent')
# results_writer = stable_nalu.writer.ResultsWriter('simple_function_recurrent')
summary_writer = stable_nalu.writer.SummaryWriter(
f'simple_function_recurrent/{args.layer_type.lower()}_{args.operation.lower()}_{args.seed}'
)
Expand Down Expand Up @@ -138,15 +138,15 @@ def test_model(dataloader):
print(f' - loss_valid_extra: {loss_valid_extra}')

# save results
results_writer.add({
'seed': args.seed,
'operation': args.operation,
'layer_type': args.layer_type,
'simple': args.simple,
'cuda': args.cuda,
'verbose': args.verbose,
'max_iterations': args.max_iterations,
'loss_train': loss_train,
'loss_valid_inter': loss_valid_inter,
'loss_valid_extra': loss_valid_extra
})
#results_writer.add({
# 'seed': args.seed,
# 'operation': args.operation,
# 'layer_type': args.layer_type,
# 'simple': args.simple,
# 'cuda': args.cuda,
# 'verbose': args.verbose,
# 'max_iterations': args.max_iterations,
# 'loss_train': loss_train,
# 'loss_valid_inter': loss_valid_inter,
# 'loss_valid_extra': loss_valid_extra
#})
20 changes: 2 additions & 18 deletions stable_nalu/dataset/sequential_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class ItemShape(NamedTuple):
class OPERATIONS:
@staticmethod
def sum(seq):
return OPERATIONS.sum(seq)
return np.sum(seq, keepdims=True).reshape(-1, 1)

@staticmethod
def cumsum(seq):
Expand Down Expand Up @@ -55,25 +55,9 @@ def __init__(self, operation,
self._rng = np.random.RandomState(seed)
self._mnist_digits = set(mnist_digits)

def is_cum_task():
if self._operation == OPERATIONS.sum:
return False
elif self._operation == OPERATIONS.cumsum:
return True
elif self._operation == OPERATIONS.prod:
return False
elif self._operation == OPERATIONS.cumprod:
return True
elif self._operation == OPERATIONS.div:
return False
elif self._operation == OPERATIONS.cumdiv:
return True
else:
raise ValueError('bad operation')

def get_item_shape(self):
if self._operation == OPERATIONS.sum:
return ItemShape((None, 28, 28), (None, 1))
return ItemShape((None, 28, 28), (1, 1))
elif self._operation == OPERATIONS.cumsum:
return ItemShape((None, 28, 28), (None, 1))
elif self._operation == OPERATIONS.prod:
Expand Down
9 changes: 9 additions & 0 deletions stable_nalu/writer/save_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,13 @@
def save_model(name, model):
    """Serialize *model* to disk under SAVE_DIR as '<name>.pth'.

    Creates any missing parent directories first. Before saving, the
    `writer` attribute is deleted from the model and every submodule so
    the pickled checkpoint does not capture writer/logging state
    (presumably the writer object is not picklable — TODO confirm).
    """
    save_file = path.join(SAVE_DIR, name) + '.pth'
    os.makedirs(path.dirname(save_file), exist_ok=True)

    def remove_writer(m):
        # remove writers from ExtendedTorchModules; modules that have no
        # `writer` attribute are left untouched
        try:
            del m.writer
        except AttributeError:
            pass

    # nn.Module.apply visits the model and all submodules recursively
    model.apply(remove_writer)
    # NOTE(review): saves the entire module object via pickle, not just
    # its state_dict — loading requires the class definitions on path
    torch.save(model, save_file)
2 changes: 1 addition & 1 deletion stable_nalu/writer/summary_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def add_tensor(self, name, matrix, verbose_only=True):
def add_histogram(self, name, tensor, verbose_only=True):
if torch.isnan(tensor).any():
print(f'nan detected in {self._namespace}/{name}')
tensor = torch.where(torch.isnan(tensor), torch.tensor(0, dtype=tensor.dtype), tensor)
# tensor = torch.where(torch.isnan(tensor), torch.tensor(0, dtype=tensor.dtype), tensor)
raise ValueError('nan detected')

if self.is_log_iteration() and self.is_logging_enabled() and self.is_verbose(verbose_only):
Expand Down

0 comments on commit c96c2ea

Please sign in to comment.