-
Notifications
You must be signed in to change notification settings - Fork 3
/
train.py
165 lines (141 loc) · 5.12 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
#!/usr/bin/env python
import sys
import os
import wandb
import torch
import socket
import numpy as np
from arguments import *
from env import MultiCellNetEnv
from utils import *
from env.env_wrappers import ShareSubprocVecEnv, ShareDummyVecEnv
# Optionally attach a debugger before anything else runs.
# NOTE(review): DEBUG and loadDebugger are not defined in this file —
# presumably they arrive via `from arguments import *` or `from utils import *`;
# confirm against those modules. The globals() check guards against the
# helper not being exported at all.
if DEBUG and 'loadDebugger' in globals():
    loadDebugger()
def get_env_kwargs(args):
    """Return the attributes of *args* whose value is not None, as a dict.

    Falsy-but-set values (0, '', False) are kept; only None is treated
    as "unset" so the environment's own defaults apply.
    """
    kwargs = {}
    for name, value in vars(args).items():
        if value is not None:
            kwargs[name] = value
    return kwargs
def get_default_env_config(args, env_args):
    """Probe a throwaway environment and copy its derived settings into *args*.

    Builds one MultiCellNetEnv from *env_args* purely to read episode and
    traffic statistics, then writes them onto *args* in place.
    """
    probe = MultiCellNetEnv(**get_env_kwargs(env_args))
    probe.print_info()
    # probe.net.traffic_model.print_info()
    traffic = probe.net.traffic_model
    vars(args).update(
        episode_length=probe.episode_len // args.n_rollout_threads,
        episode_secs=probe.episode_time_len,
        avg_traffic_density=traffic.density_mean,
        traffic_density_std=traffic.density_std,
        accelerate=probe.net.accelerate,
        # w_pc=probe.w_pc,
        w_qos=probe.w_qos,
        w_xqos=probe.w_xqos,
        # w_drop=probe.w_drop,
        # w_delay=probe.w_delay,
    )
def make_env(args, env_args, for_eval=False):
    """Create a vectorized environment with one sub-env per rollout thread.

    Each rank starts at a staggered point of the episode and gets its own
    seed; evaluation seeds are kept disjoint from training seeds. A dummy
    (in-process) wrapper is used when only one thread is requested.
    """
    n_threads = args.n_rollout_threads

    def env_factory(rank):
        def _thunk():
            kwargs = get_env_kwargs(env_args)
            kwargs['start_time'] = rank / n_threads * args.episode_secs
            kwargs['episode_len'] = args.episode_length
            env = MultiCellNetEnv(**kwargs)
            # Wide eval spacing avoids colliding with any training seed.
            seed = (args.seed * 50000 + rank * 10000 if for_eval
                    else args.seed + rank * 1000)
            env.seed(seed)
            return env
        return _thunk

    if n_threads == 1:
        return ShareDummyVecEnv([env_factory(0)])
    return ShareSubprocVecEnv([env_factory(rank) for rank in range(n_threads)])
def main(args):
    """Entry point: parse arguments, set up device/logging/run dir, train.

    Parameters
    ----------
    args : list[str]
        Raw command-line arguments (without the program name). Parsing
        cascades: generic config first, then environment config, then the
        algorithm-specific parser consumes whatever remains.
    """
    parser = get_config()
    env_parser = get_env_config()
    args, env_args = parser.parse_known_args(args)
    env_args, rl_args = env_parser.parse_known_args(env_args)

    # Select the trainer class and its algorithm-specific argument parser.
    if args.algorithm_name == "rmappo":
        assert (args.use_recurrent_policy or args.use_naive_recurrent_policy), (
            "check recurrent policy!")
        from trainers.mappo_trainer import MappoTrainer as Trainer, get_mappo_config
        rl_parser = get_mappo_config()
    elif args.algorithm_name == "mappo":
        # Plain MAPPO: force both recurrent-policy variants off.
        args.use_recurrent_policy = False
        args.use_naive_recurrent_policy = False
        from trainers.mappo_trainer import MappoTrainer as Trainer, get_mappo_config
        rl_parser = get_mappo_config()
    elif args.algorithm_name == "dqn":
        from trainers.dqn_trainer import DQNTrainer as Trainer, get_dqn_config
        rl_parser = get_dqn_config()
    else:
        raise NotImplementedError(
            "unknown algorithm: %r" % args.algorithm_name)
    rl_args = rl_parser.parse_args(rl_args)
    # Merge algorithm-specific options into the main namespace.
    vars(args).update(vars(rl_args))

    # Device selection.
    if args.cuda and torch.cuda.is_available():
        print("choose to use gpu...")
        device = torch.device("cuda:0")
        if args.cuda_deterministic:
            # Reproducibility at the cost of cuDNN autotuning.
            torch.backends.cudnn.benchmark = False
            torch.backends.cudnn.deterministic = True
    else:
        print("choose to use cpu...")
        device = torch.device("cpu")
    # Common to both device paths (was duplicated in each branch).
    torch.set_num_threads(args.n_training_threads)

    # Simulation logging.
    set_log_level(args.log_level)
    if args.sim_log_path is None:
        fn = '{}_{}_{}.log'.format(
            env_args.scenario, args.algorithm_name, args.experiment_name)
        args.sim_log_path = 'logs/' + fn
    set_log_file(args.sim_log_path)

    # Derive episode/traffic defaults from a probe environment.
    get_default_env_config(args, env_args)

    # Run directory (a pathlib.Path, judging by .exists()/iterdir() usage).
    run_dir = get_run_dir(args, env_args)
    if not run_dir.exists():
        os.makedirs(str(run_dir))

    # Experiment tracking.
    if args.use_wandb:
        wandb.init(
            config=args,
            project=args.env_name,
            entity=args.user_name,
            notes=socket.gethostname(),
            name=f"{args.algorithm_name}_{args.experiment_name}_seed{args.seed}",
            group=env_args.scenario,
            dir=str(run_dir),
            job_type="training",
            reinit=True,
        )
    else:
        # run_dir was created above, so it always exists here: number the
        # run after any existing 'runN' subdirectories. (The previous
        # `if not run_dir.exists()` short-circuit was dead code.)
        run_nums = [int(str(folder.name).split('run')[1])
                    for folder in run_dir.iterdir()
                    if str(folder.name).startswith('run')]
        curr_run = 'run%i' % (max(run_nums) + 1) if run_nums else 'run1'
        run_dir = run_dir / curr_run
        if not run_dir.exists():
            os.makedirs(str(run_dir))

    # Seed every RNG in play (random is imported explicitly at the top of
    # the file; it previously relied on a wildcard import).
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)

    # Environments: training always, evaluation only when requested.
    envs = make_env(args, env_args)
    eval_envs = make_env(args, env_args, for_eval=True) if args.use_eval else None

    config = {
        "all_args": args,
        "envs": envs,
        "eval_envs": eval_envs,
        "num_agents": MultiCellNetEnv.num_agents,
        "device": device,
        "run_dir": run_dir,
    }
    trainer = Trainer(config)
    trainer.train()
    trainer.close()
if __name__ == "__main__":
    # Drop the program name; pass only the CLI arguments through.
    main(sys.argv[1:])