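"""Evaluation entry point (test.py).

Builds the datamodule and model from the test config, loads the pretrained
checkpoints (and optionally a pretrained VAE), runs the PyTorch Lightning
test loop for ``cfg.TEST.REPLICATION_TIMES`` replications, aggregates the
metrics as mean and 95% confidence interval, prints them as a table and
writes them to a JSON file next to the sample output directory.
"""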
import json
import os
from pathlib import Path

import numpy as np
import pytorch_lightning as pl
import torch
from omegaconf import OmegaConf
from rich import get_console
from rich.table import Table

from mGPT.callback import build_callbacks
from mGPT.config import parse_args
from mGPT.data.build_data import build_data
from mGPT.models.build_model import build_model
from mGPT.utils.logger import create_logger
from mGPT.utils.load_checkpoint import load_pretrained, load_pretrained_vae


def print_table(title, metrics, logger=None):
table = Table(title=title)
table.add_column("Metrics", style="cyan", no_wrap=True)
table.add_column("Value", style="magenta")
for key, value in metrics.items():
table.add_row(key, str(value))
console = get_console()
console.print(table, justify="center")
    if logger:
        logger.info(metrics)


def get_metric_statistics(values, replication_times):
mean = np.mean(values, axis=0)
std = np.std(values, axis=0)
conf_interval = 1.96 * std / np.sqrt(replication_times)
    return mean, conf_interval


def main():
# parse options
cfg = parse_args(phase="test") # parse config file
cfg.FOLDER = cfg.TEST.FOLDER
# Logger
logger = create_logger(cfg, phase="test")
logger.info(OmegaConf.to_yaml(cfg))
# Output dir
model_name = cfg.model.target.split('.')[-2].lower()
output_dir = Path(
os.path.join(cfg.FOLDER, model_name, cfg.NAME, "samples_" + cfg.TIME))
if cfg.TEST.SAVE_PREDICTIONS:
output_dir.mkdir(parents=True, exist_ok=True)
logger.info(f"Saving predictions to {str(output_dir)}")
# Seed
pl.seed_everything(cfg.SEED_VALUE)
# Environment Variables
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Callbacks
callbacks = build_callbacks(cfg, logger=logger, phase="test")
logger.info("Callbacks initialized")
# Dataset
datamodule = build_data(cfg)
logger.info("datasets module {} initialized".format("".join(
cfg.DATASET.target.split('.')[-2])))
# Model
model = build_model(cfg, datamodule)
logger.info("model {} loaded".format(cfg.model.target))
# Lightning Trainer
trainer = pl.Trainer(
benchmark=False,
max_epochs=cfg.TRAIN.END_EPOCH,
accelerator=cfg.ACCELERATOR,
devices=list(range(len(cfg.DEVICE))),
default_root_dir=cfg.FOLDER_EXP,
reload_dataloaders_every_n_epochs=1,
deterministic=False,
detect_anomaly=False,
enable_progress_bar=True,
logger=None,
callbacks=callbacks,
)
# Strict load vae model
if cfg.TRAIN.PRETRAINED_VAE:
load_pretrained_vae(cfg, model, logger)
# loading state dict
if cfg.TEST.CHECKPOINTS:
load_pretrained(cfg, model, logger, phase="test")
else:
logger.warning("No checkpoints provided!!!")
# Calculate metrics
all_metrics = {}
replication_times = cfg.TEST.REPLICATION_TIMES
for i in range(replication_times):
metrics_type = ", ".join(cfg.METRIC.TYPE)
logger.info(f"Evaluating {metrics_type} - Replication {i}")
metrics = trainer.test(model, datamodule=datamodule)[0]
if "TM2TMetrics" in metrics_type and cfg.model.params.task == "t2m" and cfg.model.params.stage != 'vae':
            # MultiModality metrics
logger.info(f"Evaluating MultiModality - Replication {i}")
datamodule.mm_mode(True)
mm_metrics = trainer.test(model, datamodule=datamodule)[0]
metrics.update(mm_metrics)
datamodule.mm_mode(False)
for key, item in metrics.items():
if key not in all_metrics:
all_metrics[key] = [item]
else:
all_metrics[key] += [item]
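    # Aggregate over replications: report the mean and a 95% confidence
    # interval for every metric.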
all_metrics_new = {}
for key, item in all_metrics.items():
mean, conf_interval = get_metric_statistics(np.array(item),
replication_times)
all_metrics_new[key + "/mean"] = mean
all_metrics_new[key + "/conf_interval"] = conf_interval
print_table(f"Mean Metrics", all_metrics_new, logger=logger)
all_metrics_new.update(all_metrics)
# Save metrics to file
metric_file = output_dir.parent / f"metrics_{cfg.TIME}.json"
with open(metric_file, "w", encoding="utf-8") as f:
json.dump(all_metrics_new, f, indent=4)
logger.info(f"Testing done, the metrics are saved to {str(metric_file)}")
if __name__ == "__main__":
main()