
Commit

rebased commit (#500)
dustinvtran authored and eb8680 committed Nov 4, 2017
1 parent f20b82c commit 4b7d9f8
Showing 7 changed files with 22 additions and 122 deletions.
13 changes: 4 additions & 9 deletions examples/air/main.py
@@ -8,11 +8,10 @@

 import math
 import os
-import sys
 import time
 import argparse
 from functools import partial
-from subprocess import check_call
+from observations import multi_mnist
 import numpy as np

 import torch
@@ -106,13 +105,9 @@
 pyro.set_rng_seed(args.seed)

 # Load data.
-infile = './data/multi_mnist_train_uint8.npz'
-if not os.path.exists(infile):
-    print('Running multi_mnist.py to generate dataset at {}...'.format(infile))
-    multi_mnist_py = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'multi_mnist.py')
-    check_call([sys.executable, multi_mnist_py])
-    print('Finished running multi_mnist.py.')
-X_np = np.load(infile)['x'].astype(np.float32)
+inpath = './data'
+(X_np, _), _ = multi_mnist(inpath, max_digits=2, canvas_size=50, seed=42)
+X_np = X_np.astype(np.float32)
 X_np /= 255.0
 X = Variable(torch.from_numpy(X_np))
 X_size = X.size(0)
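For context, a minimal standalone sketch of the new loading path, assuming observations>=0.1.4 is installed and that multi_mnist generates and caches the multi-MNIST arrays under the given directory on first call (the download/caching behaviour is the library's own, not spelled out in this diff):

# Hedged sketch mirroring the new code path in examples/air/main.py.
import numpy as np
import torch
from torch.autograd import Variable
from observations import multi_mnist

inpath = './data'
# The first returned element unpacks to (images, labels); only the training
# images are used, exactly as in the diff above.
(X_np, _), _ = multi_mnist(inpath, max_digits=2, canvas_size=50, seed=42)
X_np = X_np.astype(np.float32)
X_np /= 255.0                           # scale pixel values to [0, 1]
X = Variable(torch.from_numpy(X_np))    # Variable, matching the pre-0.4 PyTorch API used here
print(X.size(0), 'training canvases of size', X.size()[1:])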
86 changes: 0 additions & 86 deletions examples/air/multi_mnist.py

This file was deleted.

4 changes: 2 additions & 2 deletions examples/dmm/dmm.py
@@ -24,7 +24,7 @@
 from pyro.distributions.transformed_distribution import TransformedDistribution
 import six.moves.cPickle as pickle
 import polyphonic_data_loader as poly
-from os.path import join, dirname, exists
+from os.path import exists
 import argparse
 import time
 from util import get_logger
@@ -258,7 +258,7 @@ def main(args):
     log = get_logger(args.log)
     log(args)

-    jsb_file_loc = join(dirname(__file__), "jsb_processed.pkl")
+    jsb_file_loc = "./data/jsb_processed.pkl"
     # ingest training/validation/test data from disk
     data = pickle.load(open(jsb_file_loc, "rb"))
     training_seq_lengths = data['train']['sequence_lengths']
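For reference, a hedged sketch of the layout dmm.py expects in ./data/jsb_processed.pkl, inferred from the fields read above and from the processing loop in polyphonic_data_loader.py:

# Hedged sketch: inspect the processed pickle that dmm.py loads.
import six.moves.cPickle as pickle

with open("./data/jsb_processed.pkl", "rb") as f:
    data = pickle.load(f)

for split in ['train', 'valid', 'test']:
    lengths = data[split]['sequence_lengths']    # one int32 length per sequence
    sequences = data[split]['sequences']         # shape (n_seqs, T_max=160, note_range=88)
    print(split, sequences.shape, int(lengths.max()))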
30 changes: 8 additions & 22 deletions examples/dmm/polyphonic_data_loader.py
@@ -17,34 +17,23 @@
 import torch.nn as nn
 from torch.autograd import Variable
 import numpy as np
-from os.path import join, dirname, exists, abspath
-from six.moves.urllib.request import urlretrieve
+from observations import jsb_chorales
+from os.path import join, exists
 import six.moves.cPickle as pickle
 from pyro.util import ng_zeros


-# this function downloads the raw data if it hasn't been already
-def download_if_absent(saveas, url):
-
-    if not exists(saveas):
-        print("Couldn't find polyphonic music data at {}".format(saveas))
-        print("downloading polyphonic music data from %s..." % url)
-        urlretrieve(url, saveas)
-
-
 # this function processes the raw data; in particular it unsparsifies it
-def process_data(output="jsb_processed.pkl", rawdata="jsb_raw.pkl",
-                 T_max=160, min_note=21, note_range=88):
-
+def process_data(base_path, filename, T_max=160, min_note=21, note_range=88):
+    output = join(base_path, filename)
     if exists(output):
         return

     print("processing raw polyphonic music data...")
-    data = pickle.load(open(rawdata, "rb"))
+    data = jsb_chorales(base_path)
     processed_dataset = {}
-    for split in ['train', 'valid', 'test']:
+    for split, data_split in zip(['train', 'test', 'valid'], data):
         processed_dataset[split] = {}
-        data_split = data[split]
         n_seqs = len(data_split)
         processed_dataset[split]['sequence_lengths'] = np.zeros((n_seqs), dtype=np.int32)
         processed_dataset[split]['sequences'] = np.zeros((n_seqs, T_max, note_range))
@@ -61,11 +50,8 @@ def process_data(output="jsb_processed.pkl", rawdata="jsb_raw.pkl",


 # this logic will be initiated upon import
-base_loc = dirname(abspath(__file__))
-raw_file = join(base_loc, "jsb_raw.pkl")
-out_file = join(base_loc, "jsb_processed.pkl")
-download_if_absent(raw_file, "http://www-etud.iro.umontreal.ca/~boulanni/JSB%20Chorales.pickle")
-process_data(output=out_file, rawdata=raw_file)
+base_path = './data'
+process_data(base_path, "jsb_processed.pkl")


 # this function takes a numpy mini-batch and reverses each sequence
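A hedged sketch of what the module-level call above does, assuming (as the zip ordering implies) that observations.jsb_chorales returns the splits as a (train, test, valid) tuple and downloads/caches the raw chorales under base_path on first use:

# Hedged sketch of the new loading entry point.
from observations import jsb_chorales

base_path = './data'
train, test, valid = jsb_chorales(base_path)   # assumed (train, test, valid) ordering

# Each split is assumed to hold chorales as sequences of time steps, where a
# time step is the set of active MIDI pitches; process_data "unsparsifies"
# these into binary piano rolls using min_note=21 and note_range=88.
print(len(train), len(valid), len(test))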
1 change: 1 addition & 0 deletions setup.py
@@ -16,6 +16,7 @@
         'cloudpickle>=0.3.1',
         'graphviz>=0.8',
         'networkx>=2.0.0',
+        'observations>=0.1.4',
         'torch',
         'six>=1.10.0',
     ],
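A quick, hedged way to confirm the new dependency is installed at the pinned version (pkg_resources ships with setuptools, which setup.py already relies on):

# Hedged sketch: verify the observations dependency added above.
import pkg_resources
print(pkg_resources.get_distribution('observations').version)   # expect >= 0.1.4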
8 changes: 6 additions & 2 deletions tutorial/source/air.ipynb
@@ -26,6 +26,7 @@
 "source": [
 "%pylab inline\n",
 "from collections import namedtuple\n",
+"from observations import multi_mnist\n",
 "import pyro\n",
 "import pyro.optim as optim\n",
 "from pyro.infer import SVI\n",
@@ -55,8 +56,11 @@
 },
 "outputs": [],
 "source": [
-"fn = '../../examples/air/data/multi_mnist_train_uint8.npz'\n",
-"mnist = Variable(torch.from_numpy(np.load(fn)['x'].astype(np.float32) / 255.))\n",
+"inpath = '../../examples/air/data'\n",
+"(X_np, _), _ = multi_mnist(inpath, max_digits=2, canvas_size=50, seed=42)\n",
+"X_np = X_np.astype(np.float32)\n",
+"X_np /= 255.0\n",
+"mnist = Variable(torch.from_numpy(X_np))\n",
 "def show_images(imgs):\n",
 "    figure(figsize=(12,4))\n",
 "    for i, img in enumerate(imgs):\n",

kajbaf commented on Oct 5, 2018 (inline, on the new inpath line): there is no data at the specified path. The folder is empty.
2 changes: 1 addition & 1 deletion tutorial/source/dmm.ipynb
@@ -596,7 +596,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"jsb_file_loc = join(dirname(__file__), \"jsb_processed.pkl\")\n",
+"jsb_file_loc = \"./data/jsb_processed.pkl\"\n",
 "data = pickle.load(open(jsb_file_loc, \"rb\"))\n",
 "training_seq_lengths = data['train']['sequence_lengths']\n",
 "training_data_sequences = data['train']['sequences']\n",
