Use observations and move data loading upstream #495

Closed
wants to merge 1 commit
13 changes: 4 additions & 9 deletions examples/air/main.py
@@ -8,11 +8,10 @@
 
 import math
 import os
-import sys
 import time
 import argparse
 from functools import partial
-from subprocess import check_call
+from observations import multi_mnist
 import numpy as np
 
 import torch
@@ -106,13 +105,9 @@
     pyro.set_rng_seed(args.seed)
 
     # Load data.
-    infile = './data/multi_mnist_train_uint8.npz'
-    if not os.path.exists(infile):
-        print('Running multi_mnist.py to generate dataset at {}...'.format(infile))
-        multi_mnist_py = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'multi_mnist.py')
-        check_call([sys.executable, multi_mnist_py])
-        print('Finished running multi_mnist.py.')
-    X_np = np.load(infile)['x'].astype(np.float32)
+    inpath = './data'
+    (X_np, _), _ = multi_mnist(inpath, max_digits=2, canvas_size=50, seed=42)
+    X_np = X_np.astype(np.float32)
     X_np /= 255.0
     X = Variable(torch.from_numpy(X_np))
     X_size = X.size(0)
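
Aside: a minimal standalone sketch of the new loading path, mirroring the call in main.py above. The ((images, labels), (images, labels)) return convention and the caching of the generated canvases under the given path are assumptions read off this diff, not documented guarantees of the observations package.

import numpy as np
from observations import multi_mnist

# Load (or generate and cache) the multi-digit MNIST canvases under ./data.
(X_np, _), _ = multi_mnist('./data', max_digits=2, canvas_size=50, seed=42)
X_np = X_np.astype(np.float32) / 255.0  # scale uint8 pixels into [0, 1]
print(X_np.shape)  # inspect the canvas array (exact shape depends on the loader)
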
86 changes: 0 additions & 86 deletions examples/air/multi_mnist.py

This file was deleted.

4 changes: 2 additions & 2 deletions examples/dmm/dmm.py
@@ -24,7 +24,7 @@
 from pyro.distributions.transformed_distribution import TransformedDistribution
 import six.moves.cPickle as pickle
 import polyphonic_data_loader as poly
-from os.path import join, dirname, exists
+from os.path import exists
 import argparse
 import time
 from util import get_logger
@@ -258,7 +258,7 @@ def main(args):
     log = get_logger(args.log)
     log(args)
 
-    jsb_file_loc = join(dirname(__file__), "jsb_processed.pkl")
+    jsb_file_loc = "./data/jsb_processed.pkl"
     # ingest training/validation/test data from disk
     data = pickle.load(open(jsb_file_loc, "rb"))
     training_seq_lengths = data['train']['sequence_lengths']
30 changes: 8 additions & 22 deletions examples/dmm/polyphonic_data_loader.py
@@ -17,34 +17,23 @@
 import torch.nn as nn
 from torch.autograd import Variable
 import numpy as np
-from os.path import join, dirname, exists, abspath
-from six.moves.urllib.request import urlretrieve
+from observations import jsb_chorales
+from os.path import join, exists
 import six.moves.cPickle as pickle
 from pyro.util import ng_zeros
 
 
-# this function downloads the raw data if it hasn't been already
-def download_if_absent(saveas, url):
-
-    if not exists(saveas):
-        print("Couldn't find polyphonic music data at {}".format(saveas))
-        print("downloading polyphonic music data from %s..." % url)
-        urlretrieve(url, saveas)
-
-
 # this function processes the raw data; in particular it unsparsifies it
-def process_data(output="jsb_processed.pkl", rawdata="jsb_raw.pkl",
-                 T_max=160, min_note=21, note_range=88):
-
+def process_data(base_path, filename, T_max=160, min_note=21, note_range=88):
+    output = join(base_path, filename)
     if exists(output):
         return
 
     print("processing raw polyphonic music data...")
-    data = pickle.load(open(rawdata, "rb"))
+    data = jsb_chorales(base_path)
     processed_dataset = {}
-    for split in ['train', 'valid', 'test']:
+    for split, data_split in zip(['train', 'test', 'valid'], data):
         processed_dataset[split] = {}
-        data_split = data[split]
         n_seqs = len(data_split)
         processed_dataset[split]['sequence_lengths'] = np.zeros((n_seqs), dtype=np.int32)
         processed_dataset[split]['sequences'] = np.zeros((n_seqs, T_max, note_range))
@@ -61,11 +50,8 @@ def process_data(output="jsb_processed.pkl", rawdata="jsb_raw.pkl",
 
 
 # this logic will be initiated upon import
-base_loc = dirname(abspath(__file__))
-raw_file = join(base_loc, "jsb_raw.pkl")
-out_file = join(base_loc, "jsb_processed.pkl")
-download_if_absent(raw_file, "http://www-etud.iro.umontreal.ca/~boulanni/JSB%20Chorales.pickle")
-process_data(output=out_file, rawdata=raw_file)
+base_path = './data'
+process_data(base_path, "jsb_processed.pkl")
 
 
 # this function takes a numpy mini-batch and reverses each sequence
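
Aside: the zip(['train', 'test', 'valid'], data) above relies on jsb_chorales returning the splits in (train, test, valid) order. A minimal sketch of reading the processed pickle back, with the key layout and array shapes taken from process_data and dmm.py in this diff; the relative path assumes running from examples/dmm:

import six.moves.cPickle as pickle

# Load the dict written by process_data above.
with open('./data/jsb_processed.pkl', 'rb') as f:
    data = pickle.load(f)

train_lens = data['train']['sequence_lengths']  # shape (n_seqs,), int32
train_seqs = data['train']['sequences']         # shape (n_seqs, T_max=160, 88)
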
1 change: 1 addition & 0 deletions setup.py
@@ -16,6 +16,7 @@
         'cloudpickle>=0.3.1',
         'graphviz>=0.8',
         'networkx>=2.0.0',
+        'observations>=0.1.4',
         'torch',
         'six>=1.10.0',
     ],
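
Note: existing development installs won't pick up the new dependency automatically; re-run pip install -e . from the checkout (or install observations directly with pip) before running the updated examples.
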
8 changes: 6 additions & 2 deletions tutorial/source/air.ipynb
@@ -26,6 +26,7 @@
    "source": [
     "%pylab inline\n",
     "from collections import namedtuple\n",
+    "from observations import multi_mnist\n",
     "import pyro\n",
     "import pyro.optim as optim\n",
     "from pyro.infer import SVI\n",
@@ -55,8 +56,11 @@
    },
    "outputs": [],
    "source": [
-    "fn = '../../examples/air/data/multi_mnist_train_uint8.npz'\n",
-    "mnist = Variable(torch.from_numpy(np.load(fn)['x'].astype(np.float32) / 255.))\n",
+    "inpath = '../../examples/air/data'\n",
+    "(X_np, _), _ = multi_mnist(inpath, max_digits=2, canvas_size=50, seed=42)\n",
+    "X_np = X_np.astype(np.float32)\n",
+    "X_np /= 255.0\n",
+    "X = Variable(torch.from_numpy(X_np))\n",
     "def show_images(imgs):\n",
     "    figure(figsize=(12,4))\n",
     "    for i, img in enumerate(imgs):\n",
2 changes: 1 addition & 1 deletion tutorial/source/dmm.ipynb
@@ -618,7 +618,7 @@
    },
    "outputs": [],
    "source": [
-    "jsb_file_loc = join(dirname(__file__), \"jsb_processed.pkl\")\n",
+    "jsb_file_loc = \"./data/jsb_processed.pkl\"\n",
     "data = pickle.load(open(jsb_file_loc, \"rb\"))\n",
     "training_seq_lengths = data['train']['sequence_lengths']\n",
     "training_data_sequences = data['train']['sequences']\n",