-
Notifications
You must be signed in to change notification settings - Fork 19
/
main_imagenet_test.py
executable file
·127 lines (95 loc) · 3.4 KB
/
main_imagenet_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
#!/usr/bin/env python3
# Copyright (c) 2014-2022 Megvii Inc. All rights reserved.
"""
ImageNet testing script modified from BaseCls
https://github.com/megvii-research/basecls/blob/main/basecls/tools/cls_test.py
"""
import argparse
import importlib
import multiprocessing as mp
import os
import sys
import megengine as mge
import megengine.distributed as dist
from basecore.config import ConfigDict
from loguru import logger
from basecls.engine import ClsTester
from basecls.models import build_model, load_model
from basecls.utils import default_logging, registers, set_nccl_env, set_num_threads, setup_logger
from model_replknet import RepLKNet
def make_parser() -> argparse.ArgumentParser:
    """Create the command-line parser used by the ImageNet test script.

    Recognised arguments:
        -f/--file: path to the Python description (config) file.
        -w/--weight_file: optional checkpoint path; defaults to None.
        opts: every remaining token, collected verbatim (the worker later
            merges them into the config).

    Returns:
        A configured ``argparse.ArgumentParser``.
    """
    arg_parser = argparse.ArgumentParser()
    # (flags, add_argument keyword arguments) registered in declaration order.
    option_specs = (
        (("-f", "--file"), {"type": str, "help": "testing process description file"}),
        (("-w", "--weight_file"), {"default": None, "type": str, "help": "weight file"}),
        (("opts",), {"default": None, "nargs": argparse.REMAINDER, "help": "other options"}),
    )
    for flags, spec in option_specs:
        arg_parser.add_argument(*flags, **spec)
    return arg_parser
@logger.catch
def worker(args: argparse.Namespace):
    """Run the test pipeline in the current process.

    The description file ``args.file`` is imported as a Python module and its
    ``Cfg`` class is instantiated. The resulting config is completed with an
    absolute output directory, a weights path and the command-line overrides,
    then frozen and passed to ``build``. The resulting tester is run.

    Args:
        args: parsed command-line arguments (see ``make_parser``).
    """
    logger.info(f"Init process group for gpu{dist.get_rank()} done")

    # Make the description file importable by its bare module name.
    description_path = args.file
    sys.path.append(os.path.dirname(description_path))
    cfg_module_name = os.path.splitext(os.path.basename(description_path))[0]
    cfg = importlib.import_module(cfg_module_name).Cfg()

    # Fall back to ./logs_<module name> when the config sets no output dir.
    out_dir = cfg.output_dir
    if out_dir is None:
        out_dir = f"./logs_{cfg_module_name}"
    cfg.output_dir = os.path.abspath(out_dir)

    # An explicit -w wins; otherwise use <output_dir>/latest.pkl.
    cfg.weights = (
        args.weight_file
        if args.weight_file
        else os.path.join(cfg.output_dir, "latest.pkl")
    )

    # Command-line overrides are applied last, then the config is frozen.
    cfg.merge(args.opts)
    cfg.set_mode("freeze")

    # Only rank 0 creates the directory; the barrier presumably keeps the other
    # ranks from continuing before it exists.
    is_rank_zero = dist.get_rank() == 0
    if is_rank_zero and not os.path.exists(cfg.output_dir):
        os.makedirs(cfg.output_dir)
    dist.group_barrier()

    setup_logger(cfg.output_dir, "test_log.txt", to_loguru=True)
    logger.info(f"args: {args}")

    if cfg.fastrun:
        logger.info("Using fastrun mode...")
        mge.functional.debug_param.set_execution_strategy("PROFILE")

    build(cfg).test()
def build(cfg: ConfigDict):
    """Assemble the tester described by a testing config.

    The steps are:
    1. Build the model from ``cfg`` and load ``cfg.weights`` into it.
    2. Convert a ``RepLKNet`` with ``RepLKNet.convert_to_deploy``.
    3. Call ``default_logging``.
    4. Create the dataloader registered under ``cfg.data.name``.

    Args:
        cfg: testing configuration (``worker`` passes a frozen one).

    Returns:
        A ``ClsTester`` constructed from ``cfg``, the model and the dataloader.
    """
    net = build_model(cfg)
    load_model(net, cfg.weights)
    if isinstance(net, RepLKNet):
        # NOTE(review): convert_to_deploy presumably re-parameterizes the
        # training-time structure into its inference form; confirm in model_replknet.
        net = RepLKNet.convert_to_deploy(net)
    default_logging(cfg, net)

    # The second argument is False; presumably this selects the evaluation
    # (non-training) pipeline. Confirm against the registered dataloader builder.
    loader_factory = registers.dataloaders.get(cfg.data.name)
    test_loader = loader_factory.build(cfg, False)

    # FIXME: need atomic user_pop, maybe in MegEngine 1.5? The generic
    # alternative would be BaseTester(model, dataloader, AccEvaluator()).
    return ClsTester(cfg, net, test_loader)
def main():
    """Entry point of the testing script.

    Parses and validates the command line, then configures multiprocessing,
    NCCL and thread settings, and runs ``worker``:
    - directly when zero or one GPU is visible (with a warning on CPU);
    - through ``dist.launcher`` when more than one GPU is visible.

    Raises:
        SystemExit: via ``parser.error`` when ``-f/--file`` is not given.
        ValueError: if the description file does not exist.
    """
    parser = make_parser()
    args = parser.parse_args()

    # Validate before mutating any process-global state. ``-f`` is not declared
    # required, so it may be None, and os.path.exists(None) raises an opaque
    # TypeError instead of a usage message.
    if args.file is None:
        parser.error("the following arguments are required: -f/--file")
    if not os.path.exists(args.file):
        raise ValueError("Description file does not exist")

    # NOTE(review): "spawn" is presumably required for GPU work in child
    # processes; confirm against MegEngine's distributed docs.
    mp.set_start_method("spawn")
    set_nccl_env()
    set_num_threads()

    device_count = mge.device.get_device_count("gpu")
    if device_count > 1:
        # dist.launcher wraps worker so it runs in launched processes
        # (presumably one per visible GPU by default).
        dist.launcher(worker)(args)
        return

    if device_count == 0:
        logger.warning("No GPU was found, testing on CPU")
    worker(args)
if __name__ == "__main__":
    # Run the test pipeline only when executed as a script, not on import.
    main()