From 4b7d9f83a688e6b8bb9dd598d2d36909582f1f3a Mon Sep 17 00:00:00 2001
From: Dustin Tran
Date: Fri, 3 Nov 2017 17:21:43 -0700
Subject: [PATCH] rebased commit (#500)

Switch the AIR and DMM examples to load data through the observations
package: multi-MNIST now comes from observations.multi_mnist (replacing
the generator script examples/air/multi_mnist.py, deleted here), and the
JSB chorales come from observations.jsb_chorales (replacing the manual
download in polyphonic_data_loader.py). Add observations>=0.1.4 to
setup.py and update the AIR and DMM notebooks to match.
---
 examples/air/main.py                   | 13 ++--
 examples/air/multi_mnist.py            | 86 --------------------------
 examples/dmm/dmm.py                    |  4 +-
 examples/dmm/polyphonic_data_loader.py | 30 +++------
 setup.py                               |  1 +
 tutorial/source/air.ipynb              |  8 ++-
 tutorial/source/dmm.ipynb              |  2 +-
 7 files changed, 22 insertions(+), 122 deletions(-)
 delete mode 100644 examples/air/multi_mnist.py

diff --git a/examples/air/main.py b/examples/air/main.py
index 8d447016b5..df0afba69b 100644
--- a/examples/air/main.py
+++ b/examples/air/main.py
@@ -8,11 +8,10 @@
 import math
 import os
-import sys
 import time
 import argparse
 
 from functools import partial
-from subprocess import check_call
+from observations import multi_mnist
 
 import numpy as np
 import torch
@@ -106,13 +105,9 @@
 pyro.set_rng_seed(args.seed)
 
 # Load data.
-infile = './data/multi_mnist_train_uint8.npz'
-if not os.path.exists(infile):
-    print('Running multi_mnist.py to generate dataset at {}...'.format(infile))
-    multi_mnist_py = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'multi_mnist.py')
-    check_call([sys.executable, multi_mnist_py])
-    print('Finished running multi_mnist.py.')
-X_np = np.load(infile)['x'].astype(np.float32)
+inpath = './data'
+(X_np, _), _ = multi_mnist(inpath, max_digits=2, canvas_size=50, seed=42)
+X_np = X_np.astype(np.float32)
 X_np /= 255.0
 X = Variable(torch.from_numpy(X_np))
 X_size = X.size(0)
diff --git a/examples/air/multi_mnist.py b/examples/air/multi_mnist.py
deleted file mode 100644
index e0afaa160c..0000000000
--- a/examples/air/multi_mnist.py
+++ /dev/null
@@ -1,86 +0,0 @@
-"""
-This script generates a dataset similar to the Multi-MNIST dataset
-described in [1].
-
-[1] Eslami, SM Ali, et al. "Attend, infer, repeat: Fast scene
-understanding with generative models." Advances in Neural Information
-Processing Systems. 2016.
-"""
-
-import os
-import numpy as np
-import torch
-import torchvision.datasets as dset
-from scipy.misc import imresize
-
-
-def sample_one(canvas_size, mnist):
-    i = np.random.randint(mnist['digits'].shape[0])
-    digit = mnist['digits'][i]
-    label = mnist['labels'][i]
-    scale = 0.1 * np.random.randn() + 1.3
-    resized = imresize(digit, 1. / scale)
-    w = resized.shape[0]
-    assert w == resized.shape[1]
-    padding = canvas_size - w
-    pad_l = np.random.randint(0, padding)
-    pad_r = np.random.randint(0, padding)
-    pad_width = ((pad_l, padding - pad_l), (pad_r, padding - pad_r))
-    positioned = np.pad(resized, pad_width, 'constant', constant_values=0)
-    return positioned, label
-
-
-def sample_multi(num_digits, canvas_size, mnist):
-    canvas = np.zeros((canvas_size, canvas_size))
-    labels = []
-    for _ in range(num_digits):
-        positioned_digit, label = sample_one(canvas_size, mnist)
-        canvas += positioned_digit
-        labels.append(label)
-    # Crude check for overlapping digits.
-    if np.max(canvas) > 255:
-        return sample_multi(num_digits, canvas_size, mnist)
-    else:
-        return canvas, labels
-
-
-def mk_dataset(n, mnist, max_digits, canvas_size):
-    x = []
-    y = []
-    for _ in range(n):
-        num_digits = np.random.randint(max_digits + 1)
-        canvas, labels = sample_multi(num_digits, canvas_size, mnist)
-        x.append(canvas)
-        y.append(labels)
-    return np.array(x, dtype=np.uint8), y
-
-
-def load_mnist():
-    loader = torch.utils.data.DataLoader(
-        dset.MNIST(
-            root='./data',
-            train=True,
-            download=True))
-    return {
-        'digits': loader.dataset.train_data.cpu().numpy(),
-        'labels': loader.dataset.train_labels
-    }
-
-
-# Generate the training set and dump it to disk. (Note, this will
-# always generate the same data, else error out.)
-def main():
-    outfile = './data/multi_mnist_train_uint8.npz'
-    if os.path.exists(outfile):
-        print('Output file "{}" already exists. Quiting...'.format(outfile))
-        return
-    np.random.seed(681307)
-    mnist = load_mnist()
-    x, y = mk_dataset(60000, mnist, 2, 50)
-    assert x.sum() == 884438093, 'Did not generate expected data.'
-    with open(outfile, 'wb') as f:
-        np.savez_compressed(f, x=x, y=y)
-
-
-if __name__ == "__main__":
-    main()
diff --git a/examples/dmm/dmm.py b/examples/dmm/dmm.py
index 84b4121157..9c7b40a1af 100644
--- a/examples/dmm/dmm.py
+++ b/examples/dmm/dmm.py
@@ -24,7 +24,7 @@
 from pyro.distributions.transformed_distribution import TransformedDistribution
 import six.moves.cPickle as pickle
 import polyphonic_data_loader as poly
-from os.path import join, dirname, exists
+from os.path import exists
 import argparse
 import time
 from util import get_logger
@@ -258,7 +258,7 @@ def main(args):
     log = get_logger(args.log)
     log(args)
 
-    jsb_file_loc = join(dirname(__file__), "jsb_processed.pkl")
+    jsb_file_loc = "./data/jsb_processed.pkl"
     # ingest training/validation/test data from disk
     data = pickle.load(open(jsb_file_loc, "rb"))
     training_seq_lengths = data['train']['sequence_lengths']
diff --git a/examples/dmm/polyphonic_data_loader.py b/examples/dmm/polyphonic_data_loader.py
index 6b32ed4c26..aeaea9ee2d 100644
--- a/examples/dmm/polyphonic_data_loader.py
+++ b/examples/dmm/polyphonic_data_loader.py
@@ -17,34 +17,23 @@
 import torch.nn as nn
 from torch.autograd import Variable
 import numpy as np
-from os.path import join, dirname, exists, abspath
-from six.moves.urllib.request import urlretrieve
+from observations import jsb_chorales
+from os.path import join, exists
 import six.moves.cPickle as pickle
 from pyro.util import ng_zeros
 
 
-# this function downloads the raw data if it hasn't been already
-def download_if_absent(saveas, url):
-
-    if not exists(saveas):
-        print("Couldn't find polyphonic music data at {}".format(saveas))
-        print("downloading polyphonic music data from %s..."
-              % url)
-        urlretrieve(url, saveas)
-
 
 # this function processes the raw data; in particular it unsparsifies it
-def process_data(output="jsb_processed.pkl", rawdata="jsb_raw.pkl",
-                 T_max=160, min_note=21, note_range=88):
-
+def process_data(base_path, filename, T_max=160, min_note=21, note_range=88):
+    output = join(base_path, filename)
     if exists(output):
         return
     print("processing raw polyphonic music data...")
-    data = pickle.load(open(rawdata, "rb"))
+    data = jsb_chorales(base_path)
     processed_dataset = {}
-    for split in ['train', 'valid', 'test']:
+    for split, data_split in zip(['train', 'test', 'valid'], data):
         processed_dataset[split] = {}
-        data_split = data[split]
         n_seqs = len(data_split)
         processed_dataset[split]['sequence_lengths'] = np.zeros((n_seqs), dtype=np.int32)
         processed_dataset[split]['sequences'] = np.zeros((n_seqs, T_max, note_range))
@@ -61,11 +50,8 @@ def process_data(output="jsb_processed.pkl", rawdata="jsb_raw.pkl",
 
 
 # this logic will be initiated upon import
-base_loc = dirname(abspath(__file__))
-raw_file = join(base_loc, "jsb_raw.pkl")
-out_file = join(base_loc, "jsb_processed.pkl")
-download_if_absent(raw_file, "http://www-etud.iro.umontreal.ca/~boulanni/JSB%20Chorales.pickle")
-process_data(output=out_file, rawdata=raw_file)
+base_path = './data'
+process_data(base_path, "jsb_processed.pkl")
 
 
 # this function takes a numpy mini-batch and reverses each sequence
diff --git a/setup.py b/setup.py
index e481f11863..dfd5694e38 100644
--- a/setup.py
+++ b/setup.py
@@ -16,6 +16,7 @@
         'cloudpickle>=0.3.1',
         'graphviz>=0.8',
         'networkx>=2.0.0',
+        'observations>=0.1.4',
         'torch',
         'six>=1.10.0',
     ],
diff --git a/tutorial/source/air.ipynb b/tutorial/source/air.ipynb
index bdd96121d2..31344636c3 100644
--- a/tutorial/source/air.ipynb
+++ b/tutorial/source/air.ipynb
@@ -26,6 +26,7 @@
    "source": [
     "%pylab inline\n",
     "from collections import namedtuple\n",
+    "from observations import multi_mnist\n",
     "import pyro\n",
     "import pyro.optim as optim\n",
     "from pyro.infer import SVI\n",
@@ -55,8 +56,11 @@
    },
    "outputs": [],
    "source": [
-    "fn = '../../examples/air/data/multi_mnist_train_uint8.npz'\n",
-    "mnist = Variable(torch.from_numpy(np.load(fn)['x'].astype(np.float32) / 255.))\n",
+    "inpath = '../../examples/air/data'\n",
+    "(X_np, _), _ = multi_mnist(inpath, max_digits=2, canvas_size=50, seed=42)\n",
+    "X_np = X_np.astype(np.float32)\n",
+    "X_np /= 255.0\n",
+    "mnist = Variable(torch.from_numpy(X_np))\n",
     "def show_images(imgs):\n",
     "    figure(figsize=(12,4))\n",
     "    for i, img in enumerate(imgs):\n",
diff --git a/tutorial/source/dmm.ipynb b/tutorial/source/dmm.ipynb
index c35a5995ee..cbc4ecd895 100644
--- a/tutorial/source/dmm.ipynb
+++ b/tutorial/source/dmm.ipynb
@@ -596,7 +596,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "jsb_file_loc = join(dirname(__file__), \"jsb_processed.pkl\")\n",
+    "jsb_file_loc = \"./data/jsb_processed.pkl\"\n",
     "data = pickle.load(open(jsb_file_loc, \"rb\"))\n",
     "training_seq_lengths = data['train']['sequence_lengths']\n",
     "training_data_sequences = data['train']['sequences']\n",
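
Note for readers (not part of the patch): the sketch below shows, under stated
assumptions, how the observations-based loading adopted above behaves when run
on its own. The calls and keyword arguments are copied from the patch; the
'./data' cache directory and the print statements are illustrative, and the
claim that observations fetches the data on first use and reuses the local
copy afterwards is an assumption inferred from the download logic this patch
deletes.

# sketch.py -- minimal, self-contained exercise of the new data path.
# Assumes observations>=0.1.4 (as pinned in setup.py) is installed.
import numpy as np
from observations import jsb_chorales, multi_mnist

# Multi-MNIST, as loaded by examples/air/main.py: the first returned pair
# is the training split (images, labels); images are 50x50 canvases.
(X_np, _), _ = multi_mnist('./data', max_digits=2, canvas_size=50, seed=42)
X_np = X_np.astype(np.float32)
X_np /= 255.0  # scale pixel values into [0, 1], exactly as main.py does
print(X_np.shape)

# JSB chorales, as consumed by the new process_data(): the splits come back
# in (train, test, valid) order, which is why process_data() zips them
# against ['train', 'test', 'valid'].
train, test, valid = jsb_chorales('./data')
print(len(train), len(test), len(valid))

Delegating download and caching to observations is what lets the patch drop
both the check_call bootstrap in main.py and the urlretrieve helper in
polyphonic_data_loader.py, along with the deleted generator script and its
brittle x.sum() checksum.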