-
Notifications
You must be signed in to change notification settings - Fork 0
/
postprocess.py
61 lines (53 loc) · 1.73 KB
/
postprocess.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

from utils import logger
def _error_frame(name, error):
    """Return one error curve as a long-format frame with columns name, error, x."""
    return pd.DataFrame({
        'name': name,
        'error': error,
        # x is simply the position of each error value within the curve.
        'x': np.arange(len(error)),
    })


def plot_cost(args, results):
    """Plot error curves of several estimators on a log y-axis and save the figure.

    Args:
        args: object with a ``save_fig`` attribute giving the output image path;
            missing parent directories are created.
        results: iterable of per-run results. For each ``res``, ``res[0]`` is
            plotted under the label ``'mc'``, ``res[1]`` under ``'rqmc'``, and
            ``res[2]`` is a mapping from label to error curve (presumably the
            adaptive-RQMC variants -- confirm with the caller).

    Curves from different runs that share a label end up in the same hue
    group; seaborn's line plot then aggregates rows sharing ``name`` and ``x``.
    Returns None. Does nothing when there is no curve to plot.
    """
    # Materialize once: the data is traversed three times below, which would
    # silently exhaust a generator after the first pass.
    results = list(results)

    # Keep the original row order (all 'mc', then all 'rqmc', then the named
    # curves) so seaborn's appearance-ordered hue legend is unchanged.
    frames = [_error_frame('mc', res[0]) for res in results]
    frames += [_error_frame('rqmc', res[1]) for res in results]
    frames += [
        _error_frame(name, error)
        for res in results
        for name, error in res[2].items()
    ]
    if not frames:
        # pd.concat([]) raises; mirror plot_learn and skip empty input.
        return

    Path(args.save_fig).parent.mkdir(parents=True, exist_ok=True)
    # ignore_index avoids duplicate row labels from the per-curve RangeIndexes.
    data = pd.concat(frames, ignore_index=True)
    plot = sns.relplot(x='x', y='error', kind='line', hue='name', data=data)
    plot.set(yscale='log')
    plt.savefig(args.save_fig)
    # Release the figure; otherwise repeated calls accumulate open figures.
    plt.close()
def plot_learn(args, full_results, *, mc_discard_threshold=3):
    """Plot negated learning curves from the selected runs and save the figure.

    Args:
        args: object with a ``save_fig`` attribute giving the output image path;
            missing parent directories are created.
        full_results: iterable of ``(res, info)`` pairs. ``res`` maps a curve
            name to a sequence of values; ``info['out']`` is a sized collection
            used to filter runs.
        mc_discard_threshold: minimum number of runs with an empty
            ``info['out']`` required to plot those runs. Below it, the runs
            whose ``info['out']`` has exactly one entry are plotted instead.

    Each value is plotted negated under the column ``cost`` (presumably the
    values are rewards/returns -- confirm). Returns None; nothing is plotted
    when no curve survives the selection.
    """
    # Materialize once: the selection below may traverse the input twice.
    full_results = list(full_results)
    Path(args.save_fig).parent.mkdir(parents=True, exist_ok=True)
    logger.info('ploting {}'.format(args.save_fig))

    results = [res for res, info in full_results if len(info['out']) == 0]
    if len(results) < mc_discard_threshold:
        # NOTE(review): this *replaces* the clean runs with the one-'out' runs
        # rather than adding to them (`<= 1`); confirm this is intended.
        results = [res for res, info in full_results if len(info['out']) == 1]

    frames = [
        pd.DataFrame({
            'name': name,
            # asarray lets plain Python lists be negated, not just arrays.
            'cost': -np.asarray(val),
            'x': np.arange(len(val)),
        })
        for res in results
        for name, val in res.items()
    ]
    if not frames:
        # No runs selected, or all selected runs are empty: nothing to plot.
        return

    data = pd.concat(frames, ignore_index=True)
    sns.relplot(x='x', y='cost', kind='line', hue='name', data=data)
    plt.savefig(args.save_fig)
    # Release the figure; otherwise repeated calls accumulate open figures.
    plt.close()