RLProject_a2c.py
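# A2C training and evaluation script for the highway-env "racetrack-v0" task.
# When training is enabled below, a Stable-Baselines3 A2C model is trained on a
# vectorized environment and saved under racetrack_a2c/; otherwise the saved
# model is reloaded and evaluated on the racetrack with video recording.
# A local checkout of highway-env is expected in ./highway-env (see sys.path below).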
import os, sys
import gym
from gym.wrappers import RecordVideo
from stable_baselines3 import A2C
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.noise import NormalActionNoise
from stable_baselines3.common.vec_env import SubprocVecEnv
sys.path.insert(1, "./highway-env")
import highway_env
import warnings
warnings.filterwarnings("ignore")
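# Script behaviour flags:
#   TRAIN              - train (and save) a model instead of only evaluating it
#   USE_PREVIOUS_MODEL - continue training from the previously saved model
#   MANUAL             - enable manual control of the ego vehicle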
TRAIN = False
USE_PREVIOUS_MODEL = False
MANUAL = False
def CreateEnv():
    """Create and configure the single-agent racetrack environment."""
    # Use the racetrack environment as the baseline task
    env = gym.make("racetrack-v0")
    # General config (configure in racetrack_env for training)
    env.configure({
        "collision_reward": -1.5,
        "lane_centering_cost": 4,
        "lane_centering_reward": 3,
        "reward_speed_range": [10, 30],
        "high_speed_reward": 1.5,
        "action_reward": -0.5,
        "screen_width": 600,
        "show_trajectories": False,
        "screen_height": 600
    })
    # Continuous longitudinal and lateral control
    env.configure({
        "action": {
            "type": "ContinuousAction",
            "longitudinal": True,
            "lateral": True,
            "target_speeds": [0, 5, 10]
        },
        "policy_frequency": 10
    })
    # For manual control
    if MANUAL:
        env.config["manual_control"] = True
    # Apply changes
    # env.reset()
    return env

def ConfigureMultiAgent(env, agent_num):
    """Wrap the action/observation configs for multiple controlled vehicles."""
    # Configure the number of controlled agents
    env.configure({"controlled_vehicles": agent_num})
    # Get the config for one agent
    action_config = env.config["action"]
    obs_config = env.config["observation"]
    # Multi-agent action config
    env.configure({
        "action": {
            "type": "MultiAgentAction",
            "action_config": action_config
        },
    })
    # Multi-agent observation config
    env.configure({
        "observation": {
            "type": "MultiAgentObservation",
            "observation_config": obs_config
        }
    })

if __name__ == '__main__':
    # Leave one core free for the main process
    n_cpu = os.cpu_count() - 1
    # env = CreateEnv()
    env = make_vec_env("racetrack-v0", n_envs=n_cpu, vec_env_cls=SubprocVecEnv)

    # If TRAIN, create a new model (or continue from a saved one) and train it
    if TRAIN:
        batch_size = 64
        if not USE_PREVIOUS_MODEL:
            model = A2C("MlpPolicy",
                        env,
                        policy_kwargs=dict(net_arch=[dict(pi=[256, 256], vf=[256, 256])]),
                        n_steps=batch_size * 12 // n_cpu,
                        # batch_size=batch_size,
                        # n_epochs=10,
                        # learning_rate=7e-4,
                        # gamma=0.9,
                        verbose=3,
                        tensorboard_log="racetrack_a2c/")
        else:
            # Continue training the previously saved model
            model = A2C.load("racetrack_a2c/model_a2c", env=env)
        # Train the model
        model.learn(total_timesteps=int(1e7))
        model.save("racetrack_a2c/model_a2c")
        del model

    # Run the trained policy
    model = A2C.load("racetrack_a2c/model_a2c", env=env)
    # Recreate the racetrack environment for evaluation and video recording
    env = CreateEnv()
    ConfigureMultiAgent(env, 1)
    env = RecordVideo(env, video_folder="racetrack_a2c/videos", episode_trigger=lambda e: True)
    env.unwrapped.set_record_video_wrapper(env)
    print(env.config)

    done = truncated = False
    obs, info = env.reset()
    print("number of obs: ", len(obs))
    while not (done or truncated):
        # Dispatch each agent's observation to the model to get the tuple of actions
        actions = tuple(model.predict(obs_i)[0] for obs_i in obs)
        # actions, _states = model.predict(obs, deterministic=True)
        # Execute the actions
        obs, reward, done, truncated, info = env.step(actions)
        # Render
        env.render()
    env.close()