-
Notifications
You must be signed in to change notification settings - Fork 2
/
demo_pacoh_nn.py
122 lines (94 loc) · 4.74 KB
/
demo_pacoh_nn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import numpy as np
import tensorflow as tf
class SinusoidEnv():
    """Meta-learning environment that samples sinusoid regression tasks.

    Each task is a function f(x) = amplitude * sin(x - x_shift) + 5.0, where
    amplitude ~ U(amp_low, amp_high) and x_shift ~ U(x_shift_low, x_shift_high).
    Inputs x are drawn uniformly from [x_low, x_high] and targets are corrupted
    with i.i.d. Gaussian noise of standard deviation `noise_std`.
    All randomness comes from a single seeded RandomState, so data generation
    is fully reproducible for a fixed seed.
    """

    def __init__(self, amp_low=2.0, amp_high=3.0, x_shift_low=-2.0, x_shift_high=2.0,
                 x_low=-4.0, x_high=4.0, noise_std=0.1, seed=234):
        # ranges for the per-task parameters (amplitude and phase shift)
        self.amp_low = amp_low
        self.amp_high = amp_high
        self.x_shift_low = x_shift_low
        self.x_shift_high = x_shift_high
        # input domain and observation noise level
        self.x_low = x_low
        self.x_high = x_high
        self.noise_std = noise_std
        self.random_state = np.random.RandomState(seed)

    def _sample_sinusoid_fn(self):
        """Draw one task: a sinusoid with random amplitude and phase shift."""
        amplitude = self.random_state.uniform(self.amp_low, self.amp_high)
        x_shift = self.random_state.uniform(self.x_shift_low, self.x_shift_high)
        return lambda x: amplitude * np.sin((x - x_shift)) + 5.0

    def generate_meta_train_data(self, n_tasks, n_samples):
        """Generate meta-training tasks.

        Args:
            n_tasks: number of tasks (datasets) to sample.
            n_samples: number of (x, y) points per task.

        Returns:
            List of n_tasks tuples (X, Y), each of shape (n_samples, 1).
        """
        meta_train_tuples = []
        for _ in range(n_tasks):
            f = self._sample_sinusoid_fn()
            X = self.random_state.uniform(self.x_low, self.x_high, size=(n_samples, 1))
            # evaluate the task function once (original code computed f(X) twice)
            F = f(X)
            Y = F + self.noise_std * self.random_state.normal(size=F.shape)
            meta_train_tuples.append((X, Y))
        return meta_train_tuples

    def generate_meta_test_data(self, n_tasks, n_samples_context, n_samples_test):
        """Generate meta-test tasks, each split into a context set and a test set.

        Args:
            n_tasks: number of tasks to sample.
            n_samples_context: number of context (adaptation) points per task.
            n_samples_test: number of held-out test points per task (must be > 0).

        Returns:
            List of n_tasks tuples (x_context, y_context, x_test, y_test).
        """
        assert n_samples_test > 0
        meta_test_tuples = []
        for _ in range(n_tasks):
            f = self._sample_sinusoid_fn()
            X = self.random_state.uniform(self.x_low, self.x_high, size=(n_samples_context + n_samples_test, 1))
            # evaluate the task function once (original code computed f(X) twice)
            F = f(X)
            Y = F + self.noise_std * self.random_state.normal(size=F.shape)
            meta_test_tuples.append(
                (X[:n_samples_context], Y[:n_samples_context], X[n_samples_context:], Y[n_samples_context:]))
        return meta_test_tuples
# Quiet down TensorFlow and Python warnings so the demo output stays readable.
tf.get_logger().setLevel('ERROR')  # only show TF errors, suppress info/warning spam
import warnings
warnings.filterwarnings("ignore")  # silence all Python warnings for this demo
from matplotlib import pyplot as plt
""" generate meta-learning data from Sinusoid environment """
env = SinusoidEnv()
meta_train_data = env.generate_meta_train_data(n_tasks=200, n_samples=5)
meta_test_data = env.generate_meta_test_data(n_tasks=20, n_samples_context=5, n_samples_test=200)
""" plot some of the meta-learning tasks """
for x_context, y_context, x_test, y_test in meta_test_data[:5]:
plt.scatter(x_test, y_test)
plt.title('Meta-Learning Tasks')
plt.xlabel('x')
plt.ylabel('y')
plt.show()
from pacoh_nn.pacoh_nn_regression import PACOH_NN_Regression
# NOTE(review): this instance appears to be dead code -- it is never used before
# an identical PACOH_NN_Regression is constructed again below (shadowing this
# binding) prior to meta_fit. Consider removing one of the two constructions.
pacoh_model = PACOH_NN_Regression(meta_train_data, random_seed=22, num_iter_meta_train=20000, num_iter_meta_test=3000,
                                  learn_likelihood=False, likelihood_std=0.1, hyper_prior_weight=1e-4)
""" Training a Standard Bayesian Neural Network """
from pacoh_nn.bnn import BayesianNeuralNetworkSVGD
fig, axes = plt.subplots(1, 2, figsize=(12.0, 4.0))
for i in range(2):
x_context, y_context, x_test, y_test = meta_test_data[i]
# setting up and fitting the BNN
bnn = BayesianNeuralNetworkSVGD(x_context, y_context, hidden_layer_sizes=(64, 64, 64, 64), prior_weight=0.001,
bandwidth=1000.0)
bnn.fit(x_val=x_test, y_val=y_test, num_iter_fit=500, log_period=500)
# plotting
x_plot = tf.range(-5, 5, 0.1)
bnn.plot_predictions(x_plot, ax=axes[i])
axes[i].scatter(x_test, y_test, color='blue', alpha=0.2, label="test data")
axes[i].scatter(x_context, y_context, color='red', label="train data")
axes[i].legend()
axes[i].set_xlabel('x')
axes[i].set_xlabel('y')
fig.show()
""" Meta-Training: Meta-Learning a BNN Prior with PACOH-NN """
from pacoh_nn.pacoh_nn_regression import PACOH_NN_Regression
pacoh_model = PACOH_NN_Regression(meta_train_data, random_seed=22, num_iter_meta_train=20000,
num_iter_meta_test=3000,
learn_likelihood=False, likelihood_std=0.1, hyper_prior_weight=1e-4)
pacoh_model.meta_fit(meta_val_data=meta_test_data[:10], eval_period=10000, log_period=1000,
plot_prior_during_training=True, plot_period=10000)
""" Meta-Testing: Posterior inference with the meta-learned PACOH-NN prior"""
fig, axes = plt.subplots(1, 2, figsize=(12.0, 4.0))
for i in range(2):
x_context, y_context, x_test, y_test = meta_test_data[i]
# plotting
x_plot = tf.range(-5, 5, 0.1)
pacoh_model.plot_posterior(x_context, y_context, x_plot, ax=axes[i])
axes[i].scatter(x_test, y_test, color='blue', alpha=0.2, label="test data")
axes[i].scatter(x_context, y_context, color='red', label="train data")
axes[i].legend()
axes[i].set_xlabel('x')
axes[i].set_xlabel('y')
fig.show()
# Evaluate the meta-learned model across all meta-test tasks and report
# the mean and standard deviation of each metric.
eval_metrics_mean, eval_metrics_std = pacoh_model.meta_eval_datasets(meta_test_data)
for metric_name, metric_mean in eval_metrics_mean.items():
    print("%s: %.4f +- %.4f" % (metric_name, metric_mean, eval_metrics_std[metric_name]))