-
Notifications
You must be signed in to change notification settings - Fork 0
/
run_plots.py
95 lines (77 loc) · 3.87 KB
/
run_plots.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import os
import warnings

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
def plot_stats(data_path, combinations, statistic='median', colorblind=True):
    """Plot population size and infection rate over time for Wolbachia effect combinations.

    Loads precomputed per-day statistics from a CSV file and draws two
    side-by-side panels: population size (left) and infection rate (right),
    one line (plus shaded uncertainty band) per requested combination.

    Args:
        data_path (str): Path to the CSV file with precomputed statistics. The
            file needs a 'combination' and a 'day' column plus the
            'pop_size_*' / 'infection_rate_*' statistic columns read by
            ``plot_combination``.
        combinations (list of str): Combination strings to plot.
        statistic (str): 'median' to plot the median with its CI band, 'mean'
            to plot the mean with a +/- SEM band.
        colorblind (bool): If True, draw with seaborn's colorblind-friendly
            palette; otherwise use matplotlib's default colour cycle.

    Returns:
        matplotlib.figure.Figure: The newly created figure.

    Raises:
        ValueError: If ``statistic`` is not 'median' or 'mean'.
    """
    # Validate before doing any I/O so a typo fails fast and loudly instead of
    # producing two empty panels.
    if statistic not in ('median', 'mean'):
        raise ValueError(f"statistic must be 'median' or 'mean', got {statistic!r}")

    df = pd.read_csv(data_path)

    # Requested combinations that are absent from the data would otherwise show
    # up as empty legend entries with no explanation.
    available = set(df['combination'].unique())
    missing = [comb for comb in combinations if comb not in available]
    if missing:
        warnings.warn(
            f"Combinations not found in {data_path}: {missing}", stacklevel=2
        )

    figure_handle, (ax_pop, ax_inf) = plt.subplots(1, 2, figsize=(12, 6))

    if colorblind:
        # Install the palette on each axis's property cycle so lines and bands
        # actually use it (the palette used to be computed and then discarded).
        palette = sns.color_palette("colorblind")
        for ax in (ax_pop, ax_inf):
            ax.set_prop_cycle(color=palette)

    # Population size panel.
    plot_combination(ax_pop, df, combinations, 'pop_size', statistic)
    ax_pop.set_title('Population Size Over Time')
    ax_pop.set_ylabel('Population Size')

    # Infection rate panel.
    plot_combination(ax_inf, df, combinations, 'infection_rate', statistic)
    ax_inf.set_title('Infection Rate Over Time')
    ax_inf.set_ylabel('Infection Rate')

    figure_handle.tight_layout()
    return figure_handle
def plot_combination(ax, df, combinations, column_name, statistic):
    """Plot one statistic over time for several combinations on a single axis.

    For every combination a central line is drawn together with a shaded band:

    * ``statistic='median'``: line from ``<prefix>_median``, band between
      ``<prefix>_ci_lower`` and ``<prefix>_ci_upper``.
    * ``statistic='mean'``: line from ``<prefix>_mean``, band of
      mean +/- ``<prefix>_sem``.

    ``<prefix>`` is ``column_name.lower()``.

    Args:
        ax (matplotlib.axes.Axes): The axis to plot on.
        df (pd.DataFrame): Data with 'combination' and 'day' columns plus the
            statistic columns listed above.
        combinations (list of str): Combinations to plot; each one becomes a
            legend entry.
        column_name (str): Column prefix such as 'pop_size' or
            'infection_rate'. It is also used verbatim in the axis title and
            y-label.
        statistic (str): 'median' or 'mean'.

    Raises:
        ValueError: If ``statistic`` is not 'median' or 'mean'.
    """
    # An unknown statistic used to fall through both branches and silently
    # leave the axis empty.
    if statistic not in ('median', 'mean'):
        raise ValueError(f"statistic must be 'median' or 'mean', got {statistic!r}")

    prefix = column_name.lower()
    for comb in combinations:
        # Sort by day so lines and bands are drawn left-to-right even when the
        # CSV rows are not in chronological order.
        comb_data = df[df['combination'] == comb].sort_values('day')
        days = comb_data['day']

        if statistic == 'median':
            centre = comb_data[f'{prefix}_median']
            lower = comb_data[f'{prefix}_ci_lower']
            upper = comb_data[f'{prefix}_ci_upper']
        else:  # 'mean'
            centre = comb_data[f'{prefix}_mean']
            sem = comb_data[f'{prefix}_sem']
            lower, upper = centre - sem, centre + sem

        (line,) = ax.plot(days, centre, label=comb)
        # Reuse the line's colour so each band is guaranteed to match its line,
        # rather than relying on two separate colour cycles staying in step.
        ax.fill_between(days, lower, upper, facecolor=line.get_color(), alpha=0.3)

    ax.set_title(f'{column_name} Over Time')
    ax.set_xlabel('Days')
    ax.set_ylabel(column_name)
    ax.legend()
def save_figure(figure_handle, combinations, save_path, file_format=('png', 'svg')):
    """Save a figure once per file format, named after the plotted combinations.

    The file name is the combination strings joined with '_', so
    ``['er', 'ci']`` saved as PNG becomes ``<save_path>/er_ci.png``.

    Args:
        figure_handle (matplotlib.figure.Figure): The figure to save.
        combinations (list of str): Combination strings used in the plot.
        save_path (str | os.PathLike): Directory to save into; it is created
            (including parents) if it does not exist yet.
        file_format (str or iterable of str): Format(s) to write. A single
            string such as 'png' is treated as one format. Defaults to
            ('png', 'svg').

    Returns:
        list of str: The paths written, in the order of ``file_format``.
    """
    # A bare string would otherwise be iterated character by character
    # ('png' -> 'p', 'n', 'g').
    formats = [file_format] if isinstance(file_format, str) else list(file_format)

    filename_base = '_'.join(combinations)

    # savefig does not create missing directories, so ensure it exists first.
    os.makedirs(save_path, exist_ok=True)

    written = []
    for fmt in formats:
        file_path = os.path.join(save_path, f'{filename_base}.{fmt}')
        figure_handle.savefig(file_path, format=fmt)
        written.append(file_path)
    return written
# Example usage: regenerate each comparison figure and save it as PNG + SVG.
save_path = './figures'  # Update this to your save path
data_path = 'wolbachia_stats.csv'  # Path to your statistics CSV file
combinations_to_plot = [
    ['er', 'ci', 'mk', 'eg', 'no_effects'],
    ['cimkereg', 'cimkeg', 'mkeg', 'mk', 'no_effects'],
    ['er', 'ci', 'cier', 'no_effects'],
    ['er', 'mk', 'mker', 'no_effects'],
    ['er', 'eg', 'ereg', 'no_effects'],
]

if __name__ == '__main__':
    # Guarded so that importing this module (e.g. to reuse plot_stats) does not
    # read the CSV, write figure files, or open plot windows.
    # savefig cannot write into a directory that does not exist.
    os.makedirs(save_path, exist_ok=True)
    for combi in combinations_to_plot:
        f_handle = plot_stats(data_path, combi, statistic='mean', colorblind=True)
        save_figure(f_handle, combi, save_path)
    plt.show()