forked from mindspore-lab/mindcv
-
Notifications
You must be signed in to change notification settings - Fork 0
/
validate.py
101 lines (84 loc) · 2.71 KB
/
validate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import mindspore as ms
import mindspore.nn as nn
from mindspore import Model
from mindcv.data import create_dataset, create_loader, create_transforms
from mindcv.loss import create_loss
from mindcv.models import create_model
from mindcv.utils import ValCallback, check_batch_size
from config import parse_args # isort: skip
def validate(args):
    """Evaluate a model on a dataset split, print the metrics and return them.

    Builds the evaluation dataset, its (non-training) transforms and loader,
    creates the network (``pretrained`` / ``ckpt_path`` / ``ema`` are forwarded
    to ``create_model``) and the loss, then runs ``Model.eval`` with
    ``dataset_sink_mode=False``.

    Args:
        args: Parsed configuration namespace (as produced by ``parse_args``).
            Attributes read here: ``mode``, ``dataset``, ``data_dir``,
            ``val_split``, ``num_parallel_workers``, ``dataset_download``,
            ``image_resize``, ``crop_pct``, ``interpolation``, ``mean``, ``std``,
            ``num_classes``, ``batch_size``, ``model``, ``drop_rate``,
            ``drop_path_rate``, ``pretrained``, ``ckpt_path``, ``ema``, ``loss``,
            ``reduction``, ``label_smoothing``, ``aux_factor``, ``log_interval``.

    Returns:
        The value returned by ``Model.eval`` — a mapping keyed by the metric
        names configured below (``"Top_1_Accuracy"``, ``"Top_5_Accuracy"`` only
        when there are at least 5 classes, and ``"loss"``). It is also printed.

    Raises:
        ValueError: If ``args.num_classes`` is None and the dataset does not
            report a class count either.
    """
    ms.set_context(mode=args.mode)

    # create dataset
    dataset_eval = create_dataset(
        name=args.dataset,
        root=args.data_dir,
        split=args.val_split,
        num_parallel_workers=args.num_parallel_workers,
        download=args.dataset_download,
    )

    # create evaluation (is_training=False) transforms
    transform_list = create_transforms(
        dataset_name=args.dataset,
        is_training=False,
        image_resize=args.image_resize,
        crop_pct=args.crop_pct,
        interpolation=args.interpolation,
        mean=args.mean,
        std=args.std,
    )

    # read num classes: an explicit args.num_classes takes precedence over the dataset
    num_classes = dataset_eval.num_classes() if args.num_classes is None else args.num_classes
    if num_classes is None:
        # MindSpore's Dataset.num_classes() returns None when the class count is
        # unknown; fail early with an actionable message instead of letting None
        # reach create_model() and the `num_classes >= 5` comparison below.
        raise ValueError(
            "Could not infer the number of classes from the dataset; "
            "please set `num_classes` explicitly."
        )

    # check batch size
    batch_size = check_batch_size(dataset_eval.get_dataset_size(), args.batch_size)

    # load dataset (drop_remainder=False: the final partial batch is not dropped)
    loader_eval = create_loader(
        dataset=dataset_eval,
        batch_size=batch_size,
        drop_remainder=False,
        is_training=False,
        transform=transform_list,
        num_parallel_workers=args.num_parallel_workers,
    )

    # create model and switch it out of training mode
    network = create_model(
        model_name=args.model,
        num_classes=num_classes,
        drop_rate=args.drop_rate,
        drop_path_rate=args.drop_path_rate,
        pretrained=args.pretrained,
        checkpoint_path=args.ckpt_path,
        ema=args.ema,
    )
    network.set_train(False)

    # create loss
    loss = create_loss(
        name=args.loss,
        reduction=args.reduction,
        label_smoothing=args.label_smoothing,
        aux_factor=args.aux_factor,
    )

    # Define eval metrics. Top-5 accuracy is only added when there are at least
    # 5 classes. Keys are inserted in Top-1, Top-5, loss order.
    eval_metrics = {"Top_1_Accuracy": nn.Top1CategoricalAccuracy()}
    if num_classes >= 5:
        eval_metrics["Top_5_Accuracy"] = nn.Top5CategoricalAccuracy()
    eval_metrics["loss"] = nn.metrics.Loss()

    # init model
    model = Model(network, loss_fn=loss, metrics=eval_metrics)

    # log
    num_batches = loader_eval.get_dataset_size()
    print(f"Model: {args.model}")
    print(f"Num batches: {num_batches}")
    print("Start validating...")

    # validate
    result = model.eval(loader_eval, dataset_sink_mode=False, callbacks=[ValCallback(args.log_interval)])
    print(result)
    return result
if __name__ == "__main__":
    # Script entry point: build the options namespace, then run one evaluation pass.
    validate(args=parse_args())