trainer_pooled.py
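"""Pooled multi-agent training script for the Volunteers_Dilemma environment.

Builds an RLlib config in which each agent is mapped, uniformly at random,
to one of the pooled policies in POLICIES, then trains with ray.tune.run.
Optional blocks enable in-training evaluation, epsilon-greedy exploration for
discrete action spaces, and a choice between the basic and generalized
masking models registered with ModelCatalog.
"""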
import ray
from ray import tune
from ray.rllib.utils.test_utils import check_learning_achieved
from ray.rllib.models import ModelCatalog
from custom_model import basic_model_with_masking, Generalized_model_with_masking
from env import Volunteers_Dilemma
from utils import custom_eval_function, MyCallbacks, get_args
import numpy as np

POLICIES = ['policy_0', 'policy_1', 'policy_2', 'policy_3', 'policy_4', 'policy_5']


def policy_mapping_fn(agent_id):
    # Draw a policy uniformly at random from the pool for each agent.
    return np.random.choice(POLICIES)


def setup(args):
    env = Volunteers_Dilemma(vars(args))
    obs_space = env.observation_space
    action_space = env.action_space

    # Register both custom models; one of them is selected further below.
    ModelCatalog.register_custom_model("basic_model", basic_model_with_masking)
    ModelCatalog.register_custom_model("generalized_model_with_masking", Generalized_model_with_masking)

    config = {
        "env": Volunteers_Dilemma,
        "env_config": vars(args),
        "num_workers": args.n_workers,
        "framework": "torch",
        "num_gpus": args.n_gpus,
        "lr": 1e-3,
        "callbacks": MyCallbacks,
    }

    # One entry per pooled policy; each policy carries its own beta parameter.
    policies = {}
    for policy in args.policies:
        policies[policy] = (None, obs_space, action_space, {"framework": "torch", "beta": args.policies[policy]})
    policies_to_train = [policy for policy in args.policies]

    config["multiagent"] = {
        "policies": policies,
        "policy_mapping_fn": policy_mapping_fn,
        "policies_to_train": policies_to_train,
    }

    # Conduct evaluation and custom metrics during training.
    if args.evaluate_during_training:
        config["evaluation_num_workers"] = 1
        # Optional custom eval function.
        config["custom_eval_function"] = custom_eval_function
        # Enable evaluation, once per training iteration.
        config["evaluation_interval"] = 1
        # Run 100 episodes each time evaluation runs.
        config["evaluation_num_episodes"] = 100
        # Override the env config for evaluation and disable exploration.
        config["evaluation_config"] = {
            "env_config": vars(args),
            "explore": False,
        }

    # Discrete action space: use epsilon-greedy exploration.
    if args.discrete:
        config['exploration_config'] = {
            "type": "EpsilonGreedy",
            "initial_epsilon": args.initial_epsilon,
            "final_epsilon": args.final_epsilon,
            "epsilon_timesteps": args.stop_iters,
        }

    # Select the custom model.
    if args.basic_model:
        config['model'] = {
            "custom_model": "basic_model",
            "custom_model_config": {},
        }
    else:
        config['model'] = {
            "custom_model": "generalized_model_with_masking",
            "custom_model_config": {
                'args': args,
                'num_embeddings': args.max_system_value,
            },
        }

    # Forward optional observability flags to the model when present on args.
    if hasattr(args, 'full_information'):
        config['model']['custom_model_config']['full_information'] = args.full_information
    if hasattr(args, 'reveal_other_agents_identity'):
        config['model']['custom_model_config']['reveal_other_agents_identity'] = args.reveal_other_agents_identity
    if hasattr(args, 'reveal_other_agents_beta'):
        config['model']['custom_model_config']['reveal_other_agents_beta'] = args.reveal_other_agents_beta

    # Fix the seed only for single-sample runs.
    if args.n_samples == 1:
        config['seed'] = args.seed

    stop = {
        "training_iteration": args.stop_iters,
    }

    # Disable the dueling head and post-model hidden layers so DQN consumes the
    # custom model's (masked) outputs directly.
    if args.run == "DQN":
        config['hiddens'] = []
        config['dueling'] = False

    return config, stop


if __name__ == "__main__":
    args = get_args()
    ray.init(local_mode=args.local_mode)

    config, stop = setup(args)

    results = tune.run(
        args.run,
        config=config,
        stop=stop,
        local_dir=args.log_dir,
        checkpoint_freq=args.checkpoint_frequency,
        num_samples=args.n_samples,
    )

    # check_learning_achieved expects an 'episode_reward_mean' entry in `stop`;
    # setup() only sets 'training_iteration', so a reward threshold must be added
    # to `stop` before enabling args.as_test.
    if args.as_test:
        check_learning_achieved(results, stop['episode_reward_mean'])

    ray.shutdown()
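
# Example invocation (a sketch only; the actual flag names are defined in
# utils.get_args, and the ones below are assumptions):
#   python trainer_pooled.py --run DQN --n-workers 4 --stop-iters 200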