-
Notifications
You must be signed in to change notification settings - Fork 0
/
s3_s2_predict.py
74 lines (47 loc) · 2.07 KB
/
s3_s2_predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import os
# import glob
# import time
# import datetime
# import pandas as pd
from utils.settings_builder import Settings
from utils.model_predict import run_models, run_tasks
# *****************
# *****************
# Settings file for this run. The relative path is resolved against the
# directory containing this script, so the script can be launched from any
# working directory.
json_path = "settings/settings_example.json"
json_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), json_path)

# *****************
# *****************

# json_path = "settings/ghana_2008_dhs.json"

s = Settings()
s.load(json_path)

base_path = s.base_path
mode = s.config["second_stage_mode"]
model_tag = s.config["model_tag"]

s3_info = s.data["third_stage"]

# Not used below. The lookup is kept so a settings file without this key
# still fails here, as it did before.
model_inputs = s3_info["predict"]["inputs"]

# Hash of the prediction settings selected by config["predict"]; it is part
# of the model file names built below.
predict_settings = s.data[s.config["predict"]]
predict_hash = s.build_hash(predict_settings, nchar=7)

# Values that do not depend on the loop variables are computed once.
version = s.config["version"]
predict_tag = s.config["predict_tag"]
grid_predict_id = "{}_{}".format(s3_info["grid"]["boundary_id"], s3_info["predict"]["imagery_year"])
train_predict_id = predict_hash

# (key under third_stage.predict listing model names, tag stored in the task).
# Order matters: for each parameter hash, "class" tasks are queued before
# "proba" tasks.
model_groups = (
    ("class_models", "class"),
    ("proba_models", "proba"),
)

tasks = s.hashed_iter()

# Each queued task is a tuple: (grid prediction csv, model joblib file, tag).
qlist = []
for param_hash, _params in tasks:

    grid_id_string = "{}_{}_{}_{}".format(
        param_hash, grid_predict_id, version, predict_tag
    )
    grid_predict_path = os.path.join(
        base_path, "output/s3_s1_predict/predict_{}.csv".format(grid_id_string))

    train_id_string = "{}_{}_{}_{}".format(
        param_hash, train_predict_id, version, predict_tag
    )

    for models_key, predict_kind in model_groups:
        for name in s3_info["predict"][models_key]:
            joblib_path = os.path.join(
                base_path,
                "output/s2_models/models_{}_INPUT_{}_{}.joblib".format(
                    name, train_id_string, model_tag))
            qlist.append((grid_predict_path, joblib_path, predict_kind))


# Hand the whole queue to the task runner once, after it is fully built.
run_tasks(tasks=qlist, func=run_models, args=s, mode=mode)