Skip to content

Commit

Permalink
Add code to generate multi mnist dataset. (#1881)
Browse files Browse the repository at this point in the history
(Rather than depending on the deprecated observations library for
this.)
  • Loading branch information
null-a authored and neerajprad committed May 23, 2019
1 parent 9e9f8bc commit bd9e691
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 9 deletions.
5 changes: 2 additions & 3 deletions examples/air/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,12 @@
import numpy as np
import torch
import visdom
from observations import multi_mnist

import pyro
import pyro.contrib.examples.multi_mnist as multi_mnist
import pyro.optim as optim
import pyro.poutine as poutine
from air import AIR, latents_to_tensor

from pyro.contrib.examples.util import get_data_directory
from pyro.infer import SVI, JitTraceGraph_ELBO, TraceGraph_ELBO
from viz import draw_many, tensor_to_objs
Expand Down Expand Up @@ -113,7 +112,7 @@ def exp_decay(initial, final, begin, duration, t):

def load_data():
inpath = get_data_directory(__file__)
(X_np, Y), _ = multi_mnist(inpath, max_digits=2, canvas_size=50, seed=42)
X_np, Y = multi_mnist.load(inpath)
X_np = X_np.astype(np.float32)
X_np /= 255.0
X = torch.from_numpy(X_np)
Expand Down
91 changes: 91 additions & 0 deletions pyro/contrib/examples/multi_mnist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
"""
This script generates a dataset similar to the Multi-MNIST dataset
described in [1].
[1] Eslami, SM Ali, et al. "Attend, infer, repeat: Fast scene
understanding with generative models." Advances in Neural Information
Processing Systems. 2016.
"""

import os

import numpy as np
from PIL import Image

from pyro.contrib.examples.util import get_data_loader


def imresize(arr, size):
return np.array(Image.fromarray(arr).resize(size))


def sample_one(canvas_size, mnist):
i = np.random.randint(mnist['digits'].shape[0])
digit = mnist['digits'][i]
label = mnist['labels'][i].item()
scale = 0.1 * np.random.randn() + 1.3
new_size = tuple(int(s / scale) for s in digit.shape)
resized = imresize(digit, new_size)
w = resized.shape[0]
assert w == resized.shape[1]
padding = canvas_size - w
pad_l = np.random.randint(0, padding)
pad_r = np.random.randint(0, padding)
pad_width = ((pad_l, padding - pad_l), (pad_r, padding - pad_r))
positioned = np.pad(resized, pad_width, 'constant', constant_values=0)
return positioned, label


def sample_multi(num_digits, canvas_size, mnist):
canvas = np.zeros((canvas_size, canvas_size))
labels = []
for _ in range(num_digits):
positioned_digit, label = sample_one(canvas_size, mnist)
canvas += positioned_digit
labels.append(label)
# Crude check for overlapping digits.
if np.max(canvas) > 255:
return sample_multi(num_digits, canvas_size, mnist)
else:
return canvas, labels


def mk_dataset(n, mnist, max_digits, canvas_size):
x = []
y = []
for _ in range(n):
num_digits = np.random.randint(max_digits + 1)
canvas, labels = sample_multi(num_digits, canvas_size, mnist)
x.append(canvas)
y.append(labels)
return np.array(x, dtype=np.uint8), y


def load_mnist(root_path):
loader = get_data_loader('MNIST', root_path)
return {
'digits': loader.dataset.data.cpu().numpy(),
'labels': loader.dataset.targets
}


def load(root_path):
file_path = os.path.join(root_path, 'multi_mnist_uint8.npz')
if os.path.exists(file_path):
data = np.load(file_path)
return data['x'], data['y']
else:
# Set RNG to known state.
rng_state = np.random.get_state()
np.random.seed(681307)
mnist = load_mnist(root_path)
print('Generating multi-MNIST dataset...')
x, y = mk_dataset(60000, mnist, 2, 50)
# Revert RNG state.
np.random.set_state(rng_state)
# Crude checksum.
# assert x.sum() == 883114919, 'Did not generate the expected data.'
with open(file_path, 'wb') as f:
np.savez_compressed(f, x=x, y=y)
print('Done!')
return x, y
14 changes: 8 additions & 6 deletions tutorial/source/air.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,12 @@
"%pylab inline\n",
"import os\n",
"from collections import namedtuple\n",
"from observations import multi_mnist\n",
"import pyro\n",
"import pyro.optim as optim\n",
"from pyro.infer import SVI, TraceGraph_ELBO\n",
"import pyro.distributions as dist\n",
"import pyro.poutine as poutine\n",
"import pyro.contrib.examples.multi_mnist as multi_mnist\n",
"import torch\n",
"import torch.nn as nn\n",
"from torch.nn.functional import relu, sigmoid, softplus, grid_sample, affine_grid\n",
Expand Down Expand Up @@ -63,18 +63,20 @@
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAeMAAABvCAYAAADfcqgvAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAFEBJREFUeJzt3X10z+Ufx/Hn1+ZmTYxtNoStECtsKtF9utUpKqEbpcahU2ed6oRyDkkqR1mpIznRoqJQ6iRJOW5S6hRRqrmpjSG3KTOWbPv98fld76EpZd/v5/vdXo9/Vttsl8tne3+v9/W+3legrKwMERER8U8NvwcgIiJS3SkYi4iI+EzBWERExGcKxiIiIj5TMBYREfGZgrGIiIjPFIxFRER8pmAsIiLiMwVjERERn0WH8psFAgG1+/qXysrKAv/1z2q+/70TmW/QnP8XesZDS/MdWsc731oZi4iI+EzBWERExGcKxiIiIj5TMBYREfGZgrGIiIjPFIxFRER8pmAsIiLis5CeMxaRypWens6mTZsA+PXXX30ejUj4O/XUUwG4+OKLmTNnDgCHDh0iLS0NgE6dOgFQs2ZN9u/fD8Dq1avZsWMHABs3bqSkpKTSx6VgLBKBatWqBcDAgQN5/vnngSODcb169bj//vsBmDt3LitXrgz9IEXCUN26dQFITk4mJSUFgDvvvJMrrrgCgKZNmwLw559/sm/fPgDy8/M5cOAAALNmzSInJ6fSx6U0tYiIiM+0MhaJQC7VlpaWRmlp6V8+3rlzZ7p16wbA5MmTQzo2kXD2ww8/APDbb7/x3HPPAdClSxe++eYbACZNmgTAypUr2bx5MwAZGRkkJycDUFxcHJRxKRiLRKBzzz0XgKioKLZt2/aXj7dt25aTTjoJgN9//z2kYxMJZ4cOHQK8n5HzzjsPgEWLFnHPPfcAsGfPHvvc888/H4Ds7GyysrIAb9snGJSmFhER8ZlWxiIRqF27dgBER0ezd+9ee78r7GrWrBk1a9YEoKxMF+2I/J06depw8sknA1jRVkZGBuPGjQOgsLCQrVu3BnUMER2Mo6KiAIiNjSU62vurFBcXW06/or00karABdoaNY5MbsXExADQokULS18fOnTIUtYJCQmAt192eBAXqW7y8/NZvHgxAFdddRWPP/44AB988AEAWVlZpKamApCZmRn0EwlKU4uIiPgsIlfGgYB3V7NL1T3yyCM0btwYgK+++spe7Xz33XeWWjh48GDoByoSJH/++Sfg/Sy4rFBpaamtjGvWrGnVnyNHjqRt27YA9rnZ2dksWrQo1MMWCRtr167l4YcfBryfi969ewNw6aWXApCYmMjo0aMBLKYEU0QHY/eL5eeff7YuRCkpKTz77LMAbNiwwSrf5s6dy8aNGwHtoUnkW7FiBQDXXXcdffr0AbwtmgsuuADwqq1dc4MzzzzTUmy5ubkArFu3LtRDFgk7Libk5ORw9tlnAxAfHw94+8gbNmwAoKioKOhjUZpaRETEZ4FQrhIDgUClfrPDV8juvxs2bMjpp58OwNVXX22ND7Zt28bUqVMBmDdvXsScvSwrKwv81z9b2fPtWseddtpptG/fHvAO0K9atQqA7du3V+a388WJzDdU/pwfy2mnnQZ4z3Lz5s0B2Lt3L1u2bAEgNTWVL7/8EoCbb77ZnvdwzAqF0zNeHWi+y7kK6kcffZQmTZoA3vYmwKBBg+znqUePHuzates/fY/jnW+tjEVERHwWkXvGjnuV74pZwFsBuxXaypUrWbJkCeA11B85ciQA55xzjrVBKygoCMvVQqhFR0cTFxcHeEfG3Jy2b9+eQYMGAViZf1JSEklJSYC3Gv7xxx8BePHFF/noo4+AI/9NpPLl5eUB0KtXL+rVqwfA/v377Zzx+PHjWbhwIeAdYxKRv7r22msBuOaaa6wD1xdffAF4nbjcZStnnXWW/W4LlogOxsfigmthYSELFiwAYP369fTq1QuAfv36WfX1I488Uq0Lu9w51bPPPpunn34a8KoICwoKAK8HsrtGzJ1bLSgosNR0amoqiYmJgHfziUvxuDmV4HBn6L/99tsj3l+/fn0ANm3aFDFbMSJ+SElJsQCcm5vL999/D5T3nn7ppZfIzMwEoGPHjkEPxkpTi4iI+KxKroydQCBgK7+NGzcyYcIEwGt3Nnz4cAD69+/PmDFjgNCUr4ebli1bAjBixAg7FgNYGnrs2LGWXVi7di1Q3i7uaHFxcdVyDsOJ68xVu3ZtdaAT+Rvt27e3LZ7hw4ezc+fOIz6ekpJihcEuUxhMVTIYu32ztLQ0MjIyAK/phzubOXv2bNv/7Nmzp+0rf/LJJz6M1l/9+vUD4MILL+T1118HYMKECbbPuGnTJktT/xPtTfrPNfpISUmx699E5K+6dOliTaGWLVtmfSvcAqVfv362demuXQwmpalFRER8ViVXxp06dQK8FKvrpnLyySeTn58PeCu/iRMnAt7NHHfddRdQPVfGp5xyCuAVr7lq6hYtWlinpuNdFUt4cMV0TZs2DfotMyKRrHbt2nbZUGZmprVXdvGjUaNGPProowBBvyQCqmAwrlWrFldddRXgBZIhQ4YA3m027ojOkCFDGDp0KOBVWbvKanezTXUKQG+++Sbg/d27du0KeGX87rjS4sWLrWL3008/Bbx0tPYjw5M7UrZmzRrdyiTyN6ZNm0bPnj0BuP32262aetasWQB8/vnnfP311yEbj9LUIiIiPovodpgVadCgAdnZ2YB3z/GAAQMArwK4devWALz88su28isqKrIiL1fMFE5nZIPdus5V38bGxlqj9AEDBthcNWjQwCoOXapmwoQJfPzxx0DVq0CPlHaYVYnaM4aW5ju0jne+q1yauqioiJ9++gnAelSD1yTB3ez02WefccsttwDwyy+/sHv3bsDruFLduLTmb7/9ZnvmS5Yssd7H7du3p3///kD5XkpOTo7tpUyZMqXKBWQRkVBTmlpERMRnVW5lfPDgQTug3alTJ2sPuHfvXlsFbty40aqIGzZsSE5Ojn2OeC0yD7/3dunSpQC0adMG8KrUR40aBXjZBFcEpn7UkcU1xImLi7Ptij179nDw4EH7eExMDAAlJSXWJlBEKl+VC8aAVcUNGDCAvn37AjBjxgxq164NQLNmzewXUUFBAdOmTfNnoGGqa9eudixm9erV1pPavX3ggQesQUh2djY///wz4KX/JXK4BiGTJk2iVatWAEyePJnFixcD3iXrV199NQCrVq2yn5Pq2MNdJNiUphYREfFZlVwZr1mzBoD33nuPu+++G4DLL7/cDninpaVZKm7atGn2+eIpLi5mxIgRgHer1fr16wFs/goLC60/dcOGDa3YSyvjyBEVFcU111wDeDd2uXZ/w4YNIysrC/Aa5TRo0ACAcePG+TNQkWqiSgZjt7eVk5Nj/XnT0tJsT/OLL76wAPLhhx9qL+woa9eutT6tI0eO5J133gG8X84At912m+0f5+fnW29viRz16tWjW7dugLdPfO+99wLeFo47aXDrrbfaz8/cuXOVnhYJIqWpRUREfFYlV8bO7t27mT17NuA1t3Cv7AOBgN3spDOyf7V161YGDx4MwPTp06093OFcY5Ts7Gy2b98e0vHJievcuTOXXnop4F2i7s7g5+bmWlvY66+/nilTpgDw5Zdf+jNQkf87/Pd2cnKyFeTGxMTQsGFDAMvo7d+/39r4FhYW+jDaf08rYxEREZ9VuXaYVY3frevatGnDbbfdBmBntj/77DO7G7qgoIA//vjjRL9N2Kjq7TCbN28OwKuvvsqpp54KQN++ffn8888Bry7AHVurX78+AwcOBLw6gmD9rvD7Ga9uImG+3ao3OTnZ6lMSExPtCF6nTp2sTW9cXJzdT+/+XH5+Pn369AEI6WUPFam27TClcuXm5jJ8+PAT+ho1a9bksssuA7zir9jYWKC80G7mzJm8++67JzZQOS6ucDE+Pp7x48cD3lly92/y0EMPWcC+//77WbduHaCzxRJ8gYAXs5o3b24V/eeffz5JSUmAd7OcWxC4JjUAf/zxh/VF+OabbwDvtrm8vLyQjb0yKE0tIiLiM62Mj8EVCjRp0oSUlBTAezW2c+dOwFsx6kjU8alVqxYdOnQAoHv37tStWxcob5+ZkJBg8z1v3jw7wyyVb/Xq1QAMHjyY5cuXA14Ro7sDfODAgTzzzDMALF26VPdWS8jUqVMHgKFDh1rnxJiYGDtel5eXZwW36enpxMXFAfDOO+8wZswYALZs2QLAgQMHrJdEpFAwrkBUVJRdJzhq1Ci6dOkCeGkSl7YbOXKkna/duXMniYmJAOzYsQOAQ4cOhXrYYauoqIiFCxcCcMYZZ9CyZUsA2rVrB3jtN5s2bQp4t0ctWLDAn4FWA7/++isACxYssJaw6enpPPHEE/Z+16u9pKTEn0FKteRekKelpdne79dff83YsWMBWLFiBQcOHAC8AJyeng54lf6uBXIkU5paRETEZ1oZV6Bu3brWRrNr1672/tLSUlvVPfXUU3z11VcAbN682W6BckUDixYtYtmyZYC3Mqzu6T5X0XjHHXdYdeSMGTMA785kl15y510luKKiosjIyABgzJgxtnUwadIkWz2LhJIrEjxw4ICdDX777beZP38+4BV4uer+1q1b2wU1K1eu9GG0lU/BuALp6elccskl9v9u72H9+vWWjm7RogVNmjQBvCDtUn4uPX3HHXcwb948AIYPH27BRrCmEu6AfmlpKT/99BNAxFVARqpzzjmHiRMnAtCoUSMeeughAJYvX17tXziKP9x+8IwZM+jYsSPgtWR1KeiioiL69+8PePvL7gjeqlWrKvx6rjo7NjaW/fv3A4T1s600tYiIiM+0Mq5At27dbAUM5bcRvfzyy9ZQPykpyVZ2h3Pvq1OnDt27dwdgypQpWhn/X1xcnM1L48aNAe/ijjfeeAMor7CW4HBZiV69etG6dWvAS1O///77gIq2xD/u2Zs/fz4XXXQRAH369LHq/q1bt9oW17p165gzZw6ArXqP5s7OX3vttfZ8h3P7YwVjsBSz20Pr0qWLBdWioiJLh3z77bdHpDlc56k5c+bwyy+/AF56GrymCuIJBAJ2FV9mZiaZmZlAecpo+vTpdswmnNNIVYE7wnTTTTcxc+ZMACZOnKjjZBI2du3aZQG4Ro0a3HjjjQCcfvrpRwTsbdu2/e3Xcc1rLr/8ctsyDGdKU4uIiPhMK2PK7+m98847AW9lHBUVBcCyZctYtGgRAPv27bMVXIcOHazyd+zYsbYydv1+e/ToEbLxh7vExEQGDRoEQFZWlrWye+WVVwCvYEPp0eCpUaMG5513HoA1R8jLy+Oxxx4DvJWISLgoLS21fg6jRo2ibdu2gFd06Cquk5OTrRnTmjVrKvw6rimIexvuFIwpr7o7usoXvFS0q6betWsXzz77rH2u+0fetm2bpaUbNWoUsnGHu5iYGAD69+9ve+3x8fF89NFHALzwwgsA/P777/4MsJqoVasWV155JVB+2cd9991Hfn6+j6MSOTa3XZWQkGDH7g6/QvGGG26wxdCIESMqrDVxX6OwsDAieqsrTS0iIuIzrYwP41bI7u3R/w2wfft2AEaPHm0r6R07dljLTFftFwgErI91OFfwBZO7evHuu+8mISEB8M5qv/baawCW2o+EV62R7ODBg0yfPh3AtlzcdotIOHJZtR49ethNY5s3b2bt2rUAdO7cmQsuuADwTmVU1CzI9S4YN27cMSuuw4lWxiIiIj7TyvgwboV2+ErtWKs2dyEEeHvMrhuX29MoLi5m9uzZAGzYsCEo4w1nKSkpdiQhKSmJH3/8EYAnn3yS9957Dyg/GibBVVpaSm5uLoC9FQlnLpOWnp5ul0Zs3brVjiidccYZJCcnA9CqVasKV8YuI7lnz56IyL4pGP+DVq1aWTXf9u3bj6j6dYVeXbt2tbOz7sGZN2+eNbKIhBRJZcvIyLB5Ky4u5q233gLg/fffr5bzISLHr6KFUWpqql2t2KBBA3sx706+HM3dI3DLLbcwevRoILy3DJWmFhER8ZlWxpS3YFy/fj3gpTVcx6g2bdrYq6qhQ4daU/KoqCg7uzlq1Cg6dOgAlF8qsWTJEis2iIQUSWUrLCy0FfApp5zChRdeCHhnitXtSUT+jit+/fTTT61QKyEhwdLXRUVFrF69GsDOJB9tz549gJfejoTfwYFQDjIQCITljLiKabcH8eCDD9oVirGxsXYT0w8//GABNjo62nr7tm3b1tLXS5YsAWDw4MH2sJyIsrKywD9/VsX8nO/ExEQGDBgAwLBhwywAz5o1y1LW7lrFcNo7PpH5hvB9xsNZpD7jkSqS5js+Pp6srCzAa/rhzg4vXLjQelMXFBRU2EbXtTmuXbs2xcXFgD8Lo+Odb6WpRUREfKaVcQWaNm1qxVcdO3a0DjCHCwQC9iqrpKTEzm8OGzYM8C68roxLDyLpVezRWrRoAcDUqVNp1qwZ4LWm27x5M4BVQJaUlLB3714AJk+ezNKlS30YrUcr49CL5Gc8Emm+Q+t451vBuALR0dG2B9y3b1969+4NeGlsl9Let2+fNQD55JNPyMnJAWDFihVA5V1FF8k/OC5NFB8fT2pqKgA9e/a0KxRda8aysjJ74TJkyBBmzJjhw2hxY1EwDrFIfsYjkeY7tJSmFhERiRBaGR+DWwE3atTI7jkeNGiQtcB844037LaQvLw8u+ygsu/jrWqvYuvUqWM3Wu3evRvwsghXXHEF4J3PXrZsmW/j08o49KraMx7uNN+hpTR1FaEfnNBSMA49PeOhpfkOLaWpRUREIoSCsYiIiM8UjEVERHymYCwiIuIzBWMRERGfKRiLiIj4TMFYRETEZyE9ZywiIiJ/pZWxiIiIzxSMRUREfKZgLCIi4jMFYxEREZ8pGIuIiPhMwVhERMRnCsYiIiI+UzAWERHxmYKxiIiIzxSMRUREfKZgLCIi4jMFYxEREZ8pGIuIiPhMwVhERMRnCsYiIiI+UzAWERHxmYKxiIiIzxSMRUREfKZgLCIi4jMFYxEREZ8pGIuIiPhMwVhERMRnCsYiIiI++x9GRfeHJefNbwAAAABJRU5ErkJggg==\n",
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAeAAAABvCAYAAAA0RRMsAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAC7NJREFUeJzt3X9oldUfwPH3bEtDaqO5SpYRGTVKMCxDCpaRs18qRj+QKcsKUxdSwgjL0iR/tYp+/RNZVlhWUIZRRKwwIiPBlUmxrVxkac7KylZj2bb7/eNyn2Zfnb92n/Ps3vcLLtw9u8/O2eHc+7nn85xznoJUKoUkSYrXoNAVkCQpHxmAJUkKwAAsSVIABmBJkgIwAEuSFIABWJKkAAzAkiQFYACWJCkAA7AkSQEUxllYQUGB224doVQqVXC059reR+5Y2hts86NhH4+X7R2vvtrbEbAkSQEYgCVJCsAALElSAAZgSZICiHUSlnSkRo0aBcDkyZMZOXJk9Hzp0qUAPPXUU8HqJknHoiDO+wE7g+7I5fOMxdmzZ/PYY48BMHjw4P1+t3PnTgDOOOOMfi3TWdDxy+c+HoLtHS9nQUuSlDCOgBMuH7+trl69GoCpU6dSXFx8wNd0d3cD8O233wJQUVHRL2U7Ao5fPvbxkGzvePXV3gbghMu3N0tRURF///03AP/tmx9++CEA48ePj461t7cDUFJS0i/lG4Djl299vLdZs2Zx/vnnA3DXXXfFUmY+t3cIpqAlSUoYZ0EruBNOOIElS5YAMG3atOj47t27ue666wDYtGlTdDyTfgb44YcfYqql1H9WrFgBwMyZMykrKwPgu+++4/HHHw9ZLcXMFHTC5UO6qLKykg0bNkQ/t7S0AFBTU8PmzZuj40OHDgXgjz/+oKurC0h/gAG88sor/VIXU9Dxy4c+3lthYeEBL7NUVFSwbdu2rJefb+0dmiloSZISxhFwwuXyt9UFCxYAMGfOHEaMGBEdP+644w74+gkTJgDw3nvvuQ44h+RyH+9t3LhxAGzcuJFBg9Jjn56enuj3B+v3/S1f2jsp+mpvrwErmNraWgDKy8ujY3v37g1VHSmrPvvsMwCWLVvGwoULgXQKurGxMWS1FJApaEmSAnAErNhlZnqefvrpAOzZs4eqqioAtmzZcsjzBw0aREHBMWWKpdjt27cPgFWrVjFjxgwgfQklc1z5xwCsWD355JPMnTsX+HcG6Pr16w8r8GauofX09PDAAw9krY5SNj3yyCNs374dSAfgurq6wDVSKKagJUkKwBGwYpHZq3n69OnRDNDMns/z5s07rL8xefLk6LlpOw00mQzODTfcEL0HmpubaWtrC1ktBeQIWJKkABwBKxaZUW7vmyY88cQTAHR2dh7y/KqqKkaNGgWk74D01ltvZaGWUvZccMEFQHruQ2b97/XXXx9dD1YeSqVSsT2AlI8je+RCe5955pmp9vb2VHt7e6q7uzt6HMnfaGhoiM674447EtneSWrzgfTIh/aeNGlSqqOjI9XR0ZHq6upKjR49OjV69GjbOw8efbWnKWhJkgIwBa2smzhxIkOGDIl+/vHHHwEoLS0F0hOqMvf1hX8nbFVXV0c7BgF0dHQAsG7duqzXWepPO3bsYM+ePQAMHz6cL774InCNlATuBZ1wubJva+Z679y5c6M9bz///HMgvf3kjh07otfW1NQA+++T29nZyfr164F0YM4W94KOX6708b58//33+225Gte+zweSD+2dJN4NSZKkhHEEnHC59m21ubmZkSNHAkRrIf8rc7yzszO6b+q8efNYs2ZN1uvnCDh+udbHD2TNmjVMnDgRSF96KSwMd/UvH9o7SbwbkhKjoqKC22+/HYCTTz4ZgGHDhjF//vzoNffeey8AmzdvpqGhIf5KSv3s1Vdf5aKLLgJg165dgWujpDAFLUlSAKagE850UbxCpqC7u7sBmD17Ns8+++yxVGNAyZc+/vLLLwOwePFitm3bFqwe+dLeSdFXexuAE843S7ySEIBfe+21rM70Thr7eLxs73g5C1qSpIRxBHwIpaWlbN26FYDKykpaW1tjLd9vq/FKwggYwq4TjZt9PF62d7ycBX0MFi1axKpVqwBiD76SpNxlClqSpABMQR/E2WefDUBLS4vbxuURU9Dxs4/Hy/aOlynoo7Bo0aLQVZAk5TBT0JIkBTDgRsDjxo3j008/zWoZt912GzNmzADcNk6SlB2OgCVJCmDABODi4mKKi4t54YUXsl7WhAkTKCgooKCggKeffjrr5Um9ffDBB6GrICkGAyIFXVRUFK3FbWlpyXp51157bfT8999/z3p5Um9XXHFF6CpIisGAGQFLkpRLBsQ64F27dnHqqacCB7+Je39KpVL89ttvwL/3rA3FNXvxch3w/qqqqgCorq5m5syZQPr9AensUO/3R6bOPT09HMnnin08XrZ3vAbsOuCTTjoJSN+wPXMrrzj09PTw8MMPx1aeBEQz71966SX++usvAO68806A2G5POGbMGObPnw9ASUkJ11xzDQAdHR38+uuvABQUpD9PiouL9zv3/fffB6C+vp533303lvpKA5kpaEmSAkjsCHjt2rVMmzYNgLKyMvbs2ZP1MqdMmRI9f+edd7JentTbvn37AOjq6mLIkCEArFixAsj+CHjMmDEALFmyhPvvvx+ALVu29HlORUUFFRUVAGzYsIH77rsPgIaGhizWVModibsG/OCDDwJQV1dHV1cXACeeeOIh/3ZZWRmQ/hDbu3fvEdUrc25jYyMAy5cvT8zyI6/XxCvkNeCQ6uvrAWhqauL5558/rHMKCwujuRJNTU1cfPHFR1W2fTxetne8+mpvU9CSJAWQuBR0ZtIJwI033nhY5zzzzDNMnToVgD///JO7774bgNdff/2wzh8/fjwA5eXlANGaYylfZLJG9fX1XHbZZQAsWLCAtra2g56zevVqvvrqKwAmTZqU/UpKOSYRKehhw4YBsHHjxmjjiylTprB79+7/e+1ZZ53FQw89BMCVV14JQGlpKf/88w8AK1euZM6cOUB6FuehDB8+nHXr1gFEKbSkLAEB00Vxy9cUdMbQoUOjL7AzZsyIlhl98sknnHPOOUD6PQjwxhtvcNNNNx1zmfbxeNne8TIFLUlSwiRiBFxXVwek01+ZCVeZdZAZhYXpbHlra2s0Mq6pqQGgubmZwYMHA/DLL7+wdOlSgGik3Jerr76at99+e79jjoDzV76PgHs77bTT2LlzZ/RzZv1vZk3+LbfcEk2UPBb28XjZ3vFK/EYcmRmYjz766P8FXkjPjM6kxUpKSjjvvPMAmDVrFgC1tbU0NTVFv++9o9ChHH/88dHz2trao/sHpBzU1tYWLcebNGkSl1xyCUDWbwcq5QtT0JIkBZCIFPT27dsBGDt2LD/99FN0PLMVZWNjIyNHjgTS6ebM4v+MyspKPv7446zUOTTTRfEyBb2/TDbpm2++4Z577gHgzTff7Ncy7OPxsr3j5SQsSZISJhHXgEeMGAHApk2b+PLLL6PjY8eOBeCUU06JjpWXl3PrrbcC8PPPPwPk7OhXCiEz4bG2tpaPPvoIgJtvvpkLL7wwZLWknJOIAJwJsL3X7aZSKVpbW4F0gM7MaL700kuZPn06AIsXL465plJuq6ioiPZy/vrrr7n88suj3xmApf5lClqSpAASMQlLB+eEiXjl+ySs5557jquuugpIr5HfunUrAOeee260Dri5ublfy7SPx8v2jldf7W0ATjjfLPHK9wDc3d0dzXaur6/nxRdfBKC6upqioqKslGkfj5ftHS9nQUuSlDCOgBPOb6vxcgTcHa39bW9vjyZGrly5koULF2alTPt4vGzveCV+K0pJybB27Vqqq6uB9KqE5cuXA7Bs2bKQ1ZJykiloSZICMAWdcKaL4pXvKegQ7OPxsr3j5SQsSZISxgAsSVIABmBJkgIwAEuSFIABWJKkAAzAkiQFYACWJCkAA7AkSQHEuhGHJElKcwQsSVIABmBJkgIwAEuSFIABWJKkAAzAkiQFYACWJCkAA7AkSQEYgCVJCsAALElSAAZgSZICMABLkhSAAViSpAAMwJIkBWAAliQpAAOwJEkBGIAlSQrAACxJUgAGYEmSAjAAS5IUgAFYkqQADMCSJAVgAJYkKQADsCRJAfwPVUxC5Ncuf1kAAAAASUVORK5CYII=\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x1055f1f60>"
"<Figure size 576x144 with 5 Axes>"
]
},
"metadata": {},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"inpath = '../../examples/air/data'\n",
"(X_np, _), _ = multi_mnist(inpath, max_digits=2, canvas_size=50, seed=42)\n",
"inpath = '../../examples/air/.data'\n",
"X_np, _ = multi_mnist.load(inpath)\n",
"X_np = X_np.astype(np.float32)\n",
"X_np /= 255.0\n",
"mnist = torch.from_numpy(X_np)\n",
Expand Down

0 comments on commit bd9e691

Please sign in to comment.