gandlf_collectStats (forked from mlcommons/GaNDLF)
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import argparse
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from pathlib import Path
from GANDLF.cli import copyrightMessage


def plot_all(df_training, df_validation, df_testing, output_plot_dir):
"""
Plots training, validation, and testing data for loss and other metrics.
TODO: this function needs to be moved under utils and then called after every training epoch.
Args:
df_training (pd.DataFrame): DataFrame containing training data.
df_validation (pd.DataFrame): DataFrame containing validation data.
df_testing (pd.DataFrame): DataFrame containing testing data.
output_plot_dir (str): Directory to save the plots.
Returns:
tuple: Tuple containing the modified training, validation, and testing DataFrames.
"""
    # Identify columns whose row values contain "_" (e.g. per-label metrics);
    # these columns are excluded from the metric detection below
banned_cols = [
col
for col in df_training.columns
if any("_" in str(val) for val in df_training[col].values)
]
# Determine metrics from the column names by removing the "train_" prefix
metrics = [
col.replace("train_", "")
for col in df_training.columns
if "train_" in col and col not in banned_cols
]
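    # e.g. a "train_dice" column yields the metric name "dice", which is later
    # matched against the corresponding "valid_dice" / "test_dice" columns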
    # Splitting the values of the banned columns into multiple columns is currently disabled:
# for df in [df_training, df_validation, df_testing]:
# for col in banned_cols:
# if df[col].dtype == "object":
# split_cols = (
# df[col]
# .str.split("_", expand=True)
# .apply(pd.to_numeric, errors="coerce")
# )
# split_cols.columns = [f"{col}_{i}" for i in range(split_cols.shape[1])]
# df.drop(columns=col, inplace=True)
# df = pd.concat([df, split_cols], axis=1)
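    # NOTE: the draft above loses the split columns because `df` is rebound inside
    # the loop. A minimal sketch of a working variant (kept commented out so the
    # current behavior is unchanged) could collect the frames explicitly:
    # frames = {"train": df_training, "valid": df_validation, "test": df_testing}
    # for name, df in frames.items():
    #     if df is None:
    #         continue
    #     for col in banned_cols:
    #         if col in df.columns and df[col].dtype == "object":
    #             split_cols = (
    #                 df[col].str.split("_", expand=True).apply(pd.to_numeric, errors="coerce")
    #             )
    #             split_cols.columns = [f"{col}_{i}" for i in range(split_cols.shape[1])]
    #             frames[name] = df = pd.concat([df.drop(columns=col), split_cols], axis=1)
    # df_training, df_validation, df_testing = frames["train"], frames["valid"], frames["test"]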
# Check if any of the metrics is present in the column names of the dataframe
assert any(
any(metric in col for col in df_training.columns) for metric in metrics
), "None of the specified metrics is in the dataframe."
required_cols = ["epoch_no", "train_loss"]
# Check if the required columns are in the dataframe
assert all(
col in df_training.columns for col in required_cols
), "Not all required columns are in the dataframe."
epochs = len(df_training)
# Plot for loss
plt.figure(figsize=(12, 6))
if "train_loss" in df_training.columns:
sns.lineplot(data=df_training, x="epoch_no", y="train_loss", label="Training")
if "valid_loss" in df_validation.columns:
sns.lineplot(
data=df_validation, x="epoch_no", y="valid_loss", label="Validation"
)
if df_testing is not None and "test_loss" in df_testing.columns:
sns.lineplot(data=df_testing, x="epoch_no", y="test_loss", label="Testing")
plt.xlim(0, epochs - 1)
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Loss Plot")
plt.legend()
Path(output_plot_dir).mkdir(parents=True, exist_ok=True)
plt.savefig(os.path.join(output_plot_dir, "loss_plot.png"), dpi=300)
plt.close()
# Plot for other metrics
for metric in metrics:
metric_cols = [col for col in df_training.columns if metric in col]
for metric_col in metric_cols:
plt.figure(figsize=(12, 6))
if metric_col in df_training.columns:
sns.lineplot(
data=df_training,
x="epoch_no",
y=metric_col,
label=f"Training {metric_col}",
)
if metric_col.replace("train", "valid") in df_validation.columns:
sns.lineplot(
data=df_validation,
x="epoch_no",
y=metric_col.replace("train", "valid"),
label=f"Validation {metric_col}",
)
if (
df_testing is not None
and metric_col.replace("train", "test") in df_testing.columns
):
sns.lineplot(
data=df_testing,
x="epoch_no",
y=metric_col.replace("train", "test"),
label=f"Testing {metric_col}",
)
plt.xlim(0, epochs - 1)
plt.xlabel("Epoch")
plt.ylabel(metric.capitalize())
plt.title(f"{metric.capitalize()} Plot")
plt.legend()
plt.savefig(os.path.join(output_plot_dir, f"{metric}_plot.png"), dpi=300)
plt.close()
print("Plots saved successfully.")
return df_training, df_validation, df_testing
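

# A minimal usage sketch for plot_all (hypothetical paths), assuming GaNDLF-style
# CSV logs with an "epoch_no" column plus "train_*" / "valid_*" / "test_*" metrics:
#
#   df_tr = pd.read_csv("model_dir/logs_training.csv")
#   df_va = pd.read_csv("model_dir/logs_validation.csv")
#   plot_all(df_tr, df_va, None, "model_dir/plots")
#
# df_testing may be None when no testing logs are available.
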
if __name__ == "__main__":
parser = argparse.ArgumentParser(
prog="GANDLF_CollectStats",
formatter_class=argparse.RawTextHelpFormatter,
description="Collect statistics from different testing/validation combinations from output directory.\n\n"
+ copyrightMessage,
)
    parser.add_argument(
        "-m",
        "--modeldir",
        metavar="",
        type=str,
        required=True,
        help="Input directory which contains the training/validation/testing log files",
    )
    parser.add_argument(
        "-o",
        "--outputdir",
        metavar="",
        type=str,
        required=True,
        help="Output directory to save the stats and plots",
    )
args = parser.parse_args()
inputDir = os.path.normpath(args.modeldir)
outputDir = os.path.normpath(args.outputdir)
Path(outputDir).mkdir(parents=True, exist_ok=True)
    outputFile = os.path.join(outputDir, "data.csv")  # stats data file name
    # plot_all expects an output *directory*; the individual plots are saved inside it
trainingLogs = os.path.join(inputDir, "logs_training.csv")
validationLogs = os.path.join(inputDir, "logs_validation.csv")
testingLogs = os.path.join(inputDir, "logs_testing.csv")
# Read all the files
df_training = pd.read_csv(trainingLogs)
df_validation = pd.read_csv(validationLogs)
df_testing = pd.read_csv(testingLogs) if os.path.isfile(testingLogs) else None
    # Check for metrics in the log columns and generate the plots
    plot_all(df_training, df_validation, df_testing, outputDir)
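
# Example invocation (hypothetical paths):
#   python gandlf_collectStats -m /path/to/model_dir -o /path/to/output_dir
# where model_dir contains logs_training.csv, logs_validation.csv and, optionally,
# logs_testing.csv.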