From c96c2ea877d1a8d09c9345f9083ee4e4cca2ed9c Mon Sep 17 00:00:00 2001
From: Pieter-Jan Hoedt
Date: Thu, 14 Jan 2021 11:42:51 +0100
Subject: [PATCH] fix bugs in original code

Fixes an infinite recursion in OPERATIONS.sum, corrects the target
shape reported for the sum task, removes the unused is_cum_task
helper, disables the ResultsWriter usage in
simple_function_recurrent.py, strips writer references from submodules
before saving a model, and makes add_histogram raise on NaN instead of
zeroing it in place. Also passes num_workers=2 to the sequential MNIST
dataset for debugging.
---
 experiments/sequential_mnist.py          |  3 ++-
 experiments/simple_function_recurrent.py | 26 ++++++++++++------------
 stable_nalu/dataset/sequential_mnist.py  | 20 ++----------------
 stable_nalu/writer/save_model.py         |  9 ++++++++
 stable_nalu/writer/summary_writer.py     |  2 +-
 5 files changed, 27 insertions(+), 33 deletions(-)

diff --git a/experiments/sequential_mnist.py b/experiments/sequential_mnist.py
index a6e14a9..6089bc4 100644
--- a/experiments/sequential_mnist.py
+++ b/experiments/sequential_mnist.py
@@ -257,7 +257,8 @@
     operation=args.operation,
     use_cuda=args.cuda,
     seed=args.seed,
-    mnist_digits=args.mnist_digits
+    mnist_digits=args.mnist_digits,
+    num_workers=2  # debugging
 )
 dataset_train = dataset.fork(seq_length=args.interpolation_length, subset='train').dataloader(shuffle=True)
 # Seeds are from random.org
diff --git a/experiments/simple_function_recurrent.py b/experiments/simple_function_recurrent.py
index df9a05d..ed42e7c 100644
--- a/experiments/simple_function_recurrent.py
+++ b/experiments/simple_function_recurrent.py
@@ -56,7 +56,7 @@
 print(f' - max_iterations: {args.max_iterations}')
 
 # Prepear logging
-results_writer = stable_nalu.writer.ResultsWriter('simple_function_recurrent')
+# results_writer = stable_nalu.writer.ResultsWriter('simple_function_recurrent')
 summary_writer = stable_nalu.writer.SummaryWriter(
     f'simple_function_recurrent/{args.layer_type.lower()}_{args.operation.lower()}_{args.seed}'
 )
@@ -138,15 +138,15 @@ def test_model(dataloader):
 print(f' - loss_valid_extra: {loss_valid_extra}')
 
 # save results
-results_writer.add({
-    'seed': args.seed,
-    'operation': args.operation,
-    'layer_type': args.layer_type,
-    'simple': args.simple,
-    'cuda': args.cuda,
-    'verbose': args.verbose,
-    'max_iterations': args.max_iterations,
-    'loss_train': loss_train,
-    'loss_valid_inter': loss_valid_inter,
-    'loss_valid_extra': loss_valid_extra
-})
+#results_writer.add({
+#    'seed': args.seed,
+#    'operation': args.operation,
+#    'layer_type': args.layer_type,
+#    'simple': args.simple,
+#    'cuda': args.cuda,
+#    'verbose': args.verbose,
+#    'max_iterations': args.max_iterations,
+#    'loss_train': loss_train,
+#    'loss_valid_inter': loss_valid_inter,
+#    'loss_valid_extra': loss_valid_extra
+#})
diff --git a/stable_nalu/dataset/sequential_mnist.py b/stable_nalu/dataset/sequential_mnist.py
index 480ebd6..457d0a6 100644
--- a/stable_nalu/dataset/sequential_mnist.py
+++ b/stable_nalu/dataset/sequential_mnist.py
@@ -16,7 +16,7 @@ class ItemShape(NamedTuple):
 class OPERATIONS:
     @staticmethod
     def sum(seq):
-        return OPERATIONS.sum(seq)
+        return np.sum(seq, keepdims=True).reshape(-1, 1)
 
     @staticmethod
     def cumsum(seq):
@@ -55,25 +55,9 @@ def __init__(self, operation,
         self._rng = np.random.RandomState(seed)
         self._mnist_digits = set(mnist_digits)
 
-        def is_cum_task():
-            if self._operation == OPERATIONS.sum:
-                return False
-            elif self._operation == OPERATIONS.cumsum:
-                return True
-            elif self._operation == OPERATIONS.prod:
-                return False
-            elif self._operation == OPERATIONS.cumprod:
-                return True
-            elif self._operation == OPERATIONS.div:
-                return False
-            elif self._operation == OPERATIONS.cumdiv:
-                return True
-            else:
-                raise ValueError('bad operation')
-
     def get_item_shape(self):
         if self._operation == OPERATIONS.sum:
-            return ItemShape((None, 28, 28), (None, 1))
+            return ItemShape((None, 28, 28), (1, 1))
         elif self._operation == OPERATIONS.cumsum:
             return ItemShape((None, 28, 28), (None, 1))
         elif self._operation == OPERATIONS.prod:
diff --git a/stable_nalu/writer/save_model.py b/stable_nalu/writer/save_model.py
index a26cffe..21c292e 100644
--- a/stable_nalu/writer/save_model.py
+++ b/stable_nalu/writer/save_model.py
@@ -13,4 +13,13 @@ def save_model(name, model):
 
     save_file = path.join(SAVE_DIR, name) + '.pth'
     os.makedirs(path.dirname(save_file), exist_ok=True)
+
+    def remove_writer(m):
+        # remove writers from ExtendedTorchModules
+        try:
+            del m.writer
+        except AttributeError:
+            pass
+
+    model.apply(remove_writer)
     torch.save(model, save_file)
diff --git a/stable_nalu/writer/summary_writer.py b/stable_nalu/writer/summary_writer.py
index e946f59..fdd0216 100644
--- a/stable_nalu/writer/summary_writer.py
+++ b/stable_nalu/writer/summary_writer.py
@@ -86,7 +86,7 @@ def add_tensor(self, name, matrix, verbose_only=True):
     def add_histogram(self, name, tensor, verbose_only=True):
         if torch.isnan(tensor).any():
             print(f'nan detected in {self._namespace}/{name}')
-            tensor = torch.where(torch.isnan(tensor), torch.tensor(0, dtype=tensor.dtype), tensor)
+#            tensor = torch.where(torch.isnan(tensor), torch.tensor(0, dtype=tensor.dtype), tensor)
             raise ValueError('nan detected')
 
         if self.is_log_iteration() and self.is_logging_enabled() and self.is_verbose(verbose_only):
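
Note (not part of the patch): a minimal standalone sketch of the behaviour the
OPERATIONS.sum fix restores, using only numpy. seq_sum is a hypothetical
stand-in for the patched static method, not code from the repository:

    import numpy as np

    # Stand-in for the patched OPERATIONS.sum: the old body called itself
    # (infinite recursion); the fix sums the whole digit sequence into a
    # single (1, 1) target, matching the corrected ItemShape for the sum task.
    def seq_sum(seq):
        return np.sum(seq, keepdims=True).reshape(-1, 1)

    seq = np.array([3.0, 1.0, 4.0])
    assert seq_sum(seq).shape == (1, 1)
    assert seq_sum(seq)[0, 0] == 8.0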