-
Notifications
You must be signed in to change notification settings - Fork 726
/
infer.py
164 lines (140 loc) · 6.44 KB
/
infer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import os
import paddle.nn as nn
import time
import logging
import sys
import importlib
from eval_utils import evaluate
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
sys.path.append('../../../tools')
from utils.utils_single import load_yaml, load_dy_model_class, get_abs_model, create_data_loader
from utils.save_load import save_model, load_model
from paddle.io import DistributedBatchSampler, DataLoader
import argparse
import numpy as np
from importlib import import_module
# Module-wide logging: INFO and above, each record prefixed with its
# timestamp and level name.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)
def parse_args(argv=None):
    """Parse the command-line arguments of the inference script.

    Args:
        argv: optional list of argument strings. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace with:
          * ``config_yaml`` -- the config path after ``get_abs_model``;
          * ``opt`` -- list of ``key=value`` override strings, or ``None``;
          * ``abs_dir`` -- absolute directory of the config path as given
            on the command line.

    Raises:
        SystemExit: (from argparse) when ``-m/--config_yaml`` is missing or
            the arguments are otherwise invalid.
    """
    parser = argparse.ArgumentParser(description='paddle-rec run')
    # Required: abs_dir below is derived from this path, and without it
    # os.path.abspath(None) would fail with an unhelpful TypeError.
    parser.add_argument(
        "-m",
        "--config_yaml",
        type=str,
        required=True,
        help="path to the model's config yaml")
    parser.add_argument(
        "-o",
        "--opt",
        nargs='*',
        type=str,
        help="config overrides given as key=value pairs")
    args = parser.parse_args(argv)
    # abs_dir is computed from the path as given, before get_abs_model
    # rewrites config_yaml.
    args.abs_dir = os.path.dirname(os.path.abspath(args.config_yaml))
    args.config_yaml = get_abs_model(args.config_yaml)
    return args
def _apply_config_overrides(config, overrides):
    """Apply ``key=value`` command-line overrides to ``config`` in place.

    The value is cast to the type of the existing entry when that entry is
    an int, float or bool (bool: ``"true"`` in any case -> True, anything
    else -> False). Otherwise it is stored as the raw string.

    Raises:
        ValueError: if an override contains no ``=``, or a numeric cast fails.
    """
    for parameter in overrides or []:
        parameter = parameter.strip()
        if "=" not in parameter:
            raise ValueError(
                "config override must be key=value, got {!r}".format(parameter))
        # Split on the first "=" only so that values may contain "=".
        key, value = parameter.split("=", 1)
        current = config.get(key)
        if type(current) is int:
            value = int(value)
        elif type(current) is float:
            value = float(value)
        elif type(current) is bool:
            value = value.lower() == "true"
        config[key] = value


def main(args):
    """Evaluate saved checkpoints for a range of epochs on the test set.

    For each ``epoch_id`` in
    ``range(runner.infer_start_epoch, runner.infer_end_epoch)`` (defaults 0
    and 10), this does the following:

    * loads the checkpoint at ``<runner.infer_load_path>/<epoch_id>``
      (default load path ``"model_output"``);
    * runs ``infer_forward`` over the test set, one dataset item per batch;
    * concatenates the ``y_pred`` outputs;
    * logs AUC / MRR / nDCG@5 / nDCG@10 computed by ``evaluate`` against
      the dataset's ``test_labels`` and ``test_bound``.

    Args:
        args: namespace from ``parse_args`` providing ``config_yaml``,
            ``abs_dir`` and the optional ``opt`` list of ``key=value``
            config overrides.
    """
    paddle.seed(12345)
    # load config
    config = load_yaml(args.config_yaml)
    dy_model_class = load_dy_model_class(args.abs_dir)
    config["config_abs_dir"] = args.abs_dir
    # modify config from command
    _apply_config_overrides(config, args.opt)

    # tools.vars
    use_gpu = config.get("runner.use_gpu", True)
    use_xpu = config.get("runner.use_xpu", False)
    use_visual = config.get("runner.use_visual", False)
    test_data_dir = config.get("runner.test_data_dir", None)
    # None or 0 disables the per-interval throughput log below.
    print_interval = config.get("runner.print_interval", None)
    infer_batch_size = config.get("runner.infer_batch_size", None)
    model_load_path = config.get("runner.infer_load_path", "model_output")
    start_epoch = config.get("runner.infer_start_epoch", 0)
    end_epoch = config.get("runner.infer_end_epoch", 10)

    logger.info("**************common.configs**********")
    logger.info(
        "use_gpu: {}, use_xpu: {}, use_visual: {}, infer_batch_size: {}, test_data_dir: {}, start_epoch: {}, end_epoch: {}, print_interval: {}, model_load_path: {}".
        format(use_gpu, use_xpu, use_visual, infer_batch_size, test_data_dir,
               start_epoch, end_epoch, print_interval, model_load_path))
    logger.info("**************common.configs**********")

    if use_xpu:
        xpu_device = 'xpu:{0}'.format(os.getenv('FLAGS_selected_xpus', 0))
        place = paddle.set_device(xpu_device)
    else:
        place = paddle.set_device('gpu' if use_gpu else 'cpu')

    dy_model = dy_model_class.create_model(config)

    # Create a log_visual object and store the data in the path
    if use_visual:
        from visualdl import LogWriter
        # NOTE(review): the writer is created but nothing is logged to it yet.
        log_visual = LogWriter(args.abs_dir + "/visualDL_log/infer")

    logger.info("read data")
    test_dataloader = create_data_loader(config, place)
    # Re-wrap the dataset so each batch holds exactly one dataset item
    # (presumably because items are not uniformly shaped -- TODO confirm).
    # This size, not runner.infer_batch_size, is what the loop consumes, so
    # it is also the size used for the throughput numbers logged below.
    eval_batch_size = 1
    test_dataloader = paddle.io.DataLoader(
        test_dataloader.dataset, batch_size=eval_batch_size)
    logger.info("number of test batches: {}".format(len(test_dataloader)))

    metric_list, _ = dy_model_class.create_metrics()
    epoch_begin = time.time()

    for epoch_id in range(start_epoch, end_epoch):
        logger.info("load model epoch {}".format(epoch_id))
        model_path = os.path.join(model_load_path, str(epoch_id))
        load_model(model_path, dy_model)
        dy_model.eval()

        infer_reader_cost = 0.0
        infer_run_cost = 0.0
        batch_preds = []
        # Restart interval timing after the checkpoint is loaded so the first
        # interval of every epoch excludes loading/evaluation time.
        interval_begin = time.time()
        reader_start = time.time()
        for batch_id, batch in enumerate(test_dataloader):
            infer_reader_cost += time.time() - reader_start
            infer_start = time.time()
            with paddle.no_grad():
                metric_list, tensor_print_dict = dy_model_class.infer_forward(
                    dy_model, metric_list, batch, config)
                batch_preds.append(tensor_print_dict['y_pred'].cpu().numpy())
            infer_run_cost += time.time() - infer_start

            if print_interval and batch_id % print_interval == 0:
                # Guard against a zero elapsed time on coarse clocks.
                elapsed = max(time.time() - interval_begin, 1e-12)
                logger.info(
                    "epoch: {}, batch_id: {}, ".format(epoch_id, batch_id) +
                    " avg_reader_cost: {:.5f} sec, avg_batch_cost: {:.5f} sec, avg_samples: {:.5f}, ips: {:.2f} ins/s".
                    format(infer_reader_cost / print_interval,
                           (infer_reader_cost + infer_run_cost) /
                           print_interval, eval_batch_size,
                           print_interval * eval_batch_size / elapsed))
                interval_begin = time.time()
                infer_reader_cost = 0.0
                infer_run_cost = 0.0
            reader_start = time.time()

        y_pred = np.concatenate(batch_preds, 0)
        AUC, MRR, nDCG5, nDCG10 = evaluate(y_pred,
                                           test_dataloader.dataset.test_labels,
                                           test_dataloader.dataset.test_bound)
        metric_str = ("AUC : {:.4f}, ".format(AUC) +
                      " MRR : {:.4f}, ".format(MRR) +
                      "nDCG5 : {:.4f}, ".format(nDCG5) +
                      "nDCG10 : {:.4f}".format(nDCG10))
        logger.info("epoch: {} done, ".format(epoch_id) + metric_str +
                    " epoch time: {:.2f} s".format(time.time() - epoch_begin))
        epoch_begin = time.time()
if __name__ == "__main__":
    # Script entry point: read the CLI flags, then run checkpoint evaluation.
    args = parse_args()
    main(args=args)