Skip to content

Commit

Permalink
update compare script to add per topic analysis plot (#567)
Browse files Browse the repository at this point in the history
  • Loading branch information
Victor0118 authored and lintool committed Jan 25, 2019
1 parent 927e7ff commit 951090b
Showing 1 changed file with 23 additions and 0 deletions.
23 changes: 23 additions & 0 deletions src/main/python/compare_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@
import numpy as np
import scipy.stats
import statistics
import matplotlib.pyplot as plt
plt.switch_backend('agg')
from operator import itemgetter


def load_metrics(file):
Expand All @@ -46,6 +49,22 @@ def load_metrics(file):
return metrics


def plot(all_results, output_path="."):
fig, ax = plt.subplots(1, 1, figsize=(16, 3))
all_results.sort(key = itemgetter(1), reverse=True)
x = [_x+0.5 for _x in range(len(all_results))]
y = [float(ele[1]) for ele in all_results]
ax.bar(x, y, width=0.6, align='edge')
ax.set_xticks(x)
ax.set_xticklabels([int(ele[0]) for ele in all_results], {'fontsize': 5}, rotation='vertical')
ax.grid(True)
ax.set_title("Per-topic analysis on {}".format(metric))
ax.set_xlabel('Topics')
ax.set_ylabel('{} Diff'.format(metric))
ax.set_ylim(-0.4, 0.55)
output_fn = os.path.join(output_path, 'per_query_{}.eps'.format(metric))
plt.savefig(output_fn, bbox_inches='tight', format='eps')

if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--base", type=str, help='base run', required=True)
Expand All @@ -69,10 +88,12 @@ def load_metrics(file):
if "." in metric:
metric = "_".join(metric.split("."))

all_results = []
for key in base_metrics[metric]:
base_score = base_metrics[metric][key]
comp_score = comp_metrics[metric][key]
diff = comp_score - base_score
all_results.append((key, diff))
print(f'{key}\t{base_score:.4}\t{comp_score:.4}\t{diff:.4}')

# Extract the paired scores
Expand All @@ -83,3 +104,5 @@ def load_metrics(file):
print(f'base mean: {np.mean(a):.4}')
print(f'comp mean: {np.mean(b):.4}')
print(f't-statistic: {tstat:.6}, p-value: {pvalue:.6}')

plot(all_results)

0 comments on commit 951090b

Please sign in to comment.