-
Notifications
You must be signed in to change notification settings - Fork 25
/
train_ppo_diffusion_agent.py
483 lines (452 loc) · 21.6 KB
/
train_ppo_diffusion_agent.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
"""
DPPO fine-tuning.
"""
import os
import pickle
import einops
import numpy as np
import torch
import logging
import wandb
import math
log = logging.getLogger(__name__)
from util.timer import Timer
from agent.finetune.train_ppo_agent import TrainPPOAgent
from util.scheduler import CosineAnnealingWarmupRestarts
class TrainPPODiffusionAgent(TrainPPOAgent):
    """PPO fine-tuning loop for a diffusion policy (DPPO).

    Each iteration rolls out the current policy in the vectorized env and
    records the full denoising chain of every sampled action. In training
    iterations it then runs PPO updates in which every (env step, denoising
    step) pair is treated as a separate sample. Iterations with
    ``itr % val_freq == 0`` are evaluation iterations (deterministic sampling,
    no update) unless ``force_train`` is set.
    """

    def __init__(self, cfg):
        """Build the agent from ``cfg``.

        On top of the ``TrainPPOAgent`` setup, reads the reward horizon. When
        the model learns eta, it also creates a dedicated AdamW optimizer and
        a cosine-with-warmup LR scheduler for the eta parameters.
        """
        super().__init__(cfg)

        # Reward horizon forwarded to model.loss; defaults to act_steps unless
        # cfg provides "reward_horizon".
        self.reward_horizon = cfg.get("reward_horizon", self.act_steps)

        # Eta - between DDIM (=0 for eval) and DDPM (=1 for training)
        self.learn_eta = self.model.learn_eta
        if self.learn_eta:
            # In run(), eta_optimizer.step() is only applied on minibatches
            # whose index is a multiple of this interval.
            self.eta_update_interval = cfg.train.eta_update_interval
            self.eta_optimizer = torch.optim.AdamW(
                self.model.eta.parameters(),
                lr=cfg.train.eta_lr,
                weight_decay=cfg.train.eta_weight_decay,
            )
            self.eta_lr_scheduler = CosineAnnealingWarmupRestarts(
                self.eta_optimizer,
                first_cycle_steps=cfg.train.eta_lr_scheduler.first_cycle_steps,
                cycle_mult=1.0,
                max_lr=cfg.train.eta_lr,
                min_lr=cfg.train.eta_lr_scheduler.min_lr,
                warmup_steps=cfg.train.eta_lr_scheduler.warmup_steps,
                gamma=1.0,
            )

    def _summarize_episodes(self, firsts_trajs, reward_trajs):
        """Compute reward statistics over episodes completed in this iteration.

        An episode is the span between two consecutive ``firsts`` markers of
        the same env. Spans of length 1 are skipped, and episodes still
        running at the end of the iteration are not counted.

        Args:
            firsts_trajs: array of shape (n_steps + 1, n_envs); 1 where a new
                episode starts (row 0 is the iteration start, row t + 1 marks
                done after step t).
            reward_trajs: array of shape (n_steps, n_envs) of per-step rewards.

        Returns:
            Tuple ``(num_episode_finished, avg_episode_reward, avg_best_reward,
            success_rate)``. All four are 0 when no episode completed.
        """
        reward_trajs_split = []
        for env_ind in range(self.n_envs):
            env_steps = np.where(firsts_trajs[:, env_ind] == 1)[0]
            for start, next_start in zip(env_steps[:-1], env_steps[1:]):
                if next_start - start > 1:
                    # Episode covers steps start .. next_start - 1 inclusive.
                    reward_trajs_split.append(
                        reward_trajs[start:next_start, env_ind]
                    )

        if not reward_trajs_split:
            log.info("[WARNING] No episode completed within the iteration!")
            return 0, 0, 0, 0

        episode_reward = np.array(
            [np.sum(reward_traj) for reward_traj in reward_trajs_split]
        )
        if self.furniture_sparse_reward:
            # only for furniture tasks, where reward only occurs in one env step
            episode_best_reward = episode_reward
        else:
            episode_best_reward = np.array(
                [
                    np.max(reward_traj) / self.act_steps
                    for reward_traj in reward_trajs_split
                ]
            )
        avg_episode_reward = np.mean(episode_reward)
        avg_best_reward = np.mean(episode_best_reward)
        success_rate = np.mean(
            episode_best_reward >= self.best_reward_threshold_for_success
        )
        return (
            len(reward_trajs_split),
            avg_episode_reward,
            avg_best_reward,
            success_rate,
        )

    def _compute_values_and_logprobs(self, obs_flat, chains_flat):
        """Evaluate critic values and chain log-probs in batches of
        ``logprob_batch_size`` to bound memory use.

        Args:
            obs_flat: observation tensor with (step, env) flattened step-major
                into the leading axis, i.e. (n_steps * n_envs, n_cond_step,
                obs_dim).
            chains_flat: denoising chains flattened the same way, i.e.
                (n_steps * n_envs, ft_denoising_steps + 1, horizon_steps,
                action_dim).

        Returns:
            ``(values_trajs, logprobs_trajs)`` as numpy arrays of shape
            (n_steps, n_envs) and (n_steps * n_envs, ft_denoising_steps,
            horizon_steps, action_dim).
        """
        obs_batches = [
            {"state": obs_b}
            for obs_b in torch.split(obs_flat, self.logprob_batch_size, dim=0)
        ]
        chains_batches = torch.split(chains_flat, self.logprob_batch_size, dim=0)

        # Concatenate the flat per-batch outputs and reshape once, so a
        # logprob_batch_size that is not a multiple of n_envs still works.
        values_trajs = np.concatenate(
            [self.model.critic(obs).cpu().numpy().flatten() for obs in obs_batches]
        ).reshape(-1, self.n_envs)

        logprob_shape = (
            self.model.ft_denoising_steps,
            self.horizon_steps,
            self.action_dim,
        )
        logprobs_trajs = np.concatenate(
            [
                self.model.get_logprobs(obs, chains)
                .cpu()
                .numpy()
                .reshape(-1, *logprob_shape)
                for obs, chains in zip(obs_batches, chains_batches)
            ],
            axis=0,
        )
        return values_trajs, logprobs_trajs

    def _compute_gae(self, reward_trajs, values_trajs, terminated_trajs, last_values):
        """Generalized Advantage Estimation over the collected rollout.

        Rewards are multiplied by ``reward_scale_const`` before use. Only
        termination flags stop bootstrapping and advantage accumulation;
        truncation is not passed in, so truncated steps count as non-terminal.
        NOTE(review): if the env auto-resets on truncation, ``values_trajs[t + 1]``
        then belongs to the next episode — confirm this is intended.

        Args:
            reward_trajs: (n_steps, n_envs) rewards.
            values_trajs: (n_steps, n_envs) critic values V(s_t).
            terminated_trajs: (n_steps, n_envs) termination flags (0/1).
            last_values: (n_envs,) critic values of the observation after the
                final step, used to bootstrap the last step.

        Returns:
            ``(advantages_trajs, returns_trajs)``, both (n_steps, n_envs).
        """
        advantages_trajs = np.zeros_like(reward_trajs)
        lastgaelam = 0
        for t in reversed(range(self.n_steps)):
            nextvalues = last_values if t == self.n_steps - 1 else values_trajs[t + 1]
            nonterminal = 1.0 - terminated_trajs[t]
            # delta = r + gamma*V(st+1) - V(st)
            delta = (
                reward_trajs[t] * self.reward_scale_const
                + self.gamma * nextvalues * nonterminal
                - values_trajs[t]
            )
            # A = delta_t + gamma*lambda*delta_{t+1} + ...
            lastgaelam = delta + self.gamma * self.gae_lambda * nonterminal * lastgaelam
            advantages_trajs[t] = lastgaelam
        returns_trajs = advantages_trajs + values_trajs
        return advantages_trajs, returns_trajs

    def run(self):
        """Run the fine-tuning loop until ``self.itr`` reaches ``n_train_itr``.

        Each iteration:
          1. picks eval vs. train mode and resets the envs when required;
          2. collects ``n_steps`` multi-step actions per env, including the
             denoising chains;
          3. summarizes rewards of episodes that completed in the iteration;
          4. in train mode, computes values, log-probs and GAE, then runs PPO
             updates over (env step, denoising step) samples;
          5. steps LR schedulers and ``self.model.step()``, calls
             ``save_model()`` on schedule, and logs / pickles results.
        """
        # Start training loop
        timer = Timer()
        run_results = []
        cnt_train_step = 0
        last_itr_eval = False
        done_venv = np.zeros((1, self.n_envs))
        # Unset until the first reset; forces a reset on the first iteration
        # even when neither reset_at_iteration nor eval mode would trigger one.
        prev_obs_venv = None
        while self.itr < self.n_train_itr:
            # Prepare video paths for each envs --- only applies for the first set of episodes if allowing reset within iteration and each iteration has multiple episodes from one env
            options_venv = [{} for _ in range(self.n_envs)]
            if self.itr % self.render_freq == 0 and self.render_video:
                for env_ind in range(self.n_render):
                    options_venv[env_ind]["video_path"] = os.path.join(
                        self.render_dir, f"itr-{self.itr}_trial-{env_ind}.mp4"
                    )

            # Define train or eval - all envs restart
            eval_mode = self.itr % self.val_freq == 0 and not self.force_train
            if eval_mode:
                self.model.eval()
            else:
                self.model.train()

            # Reset env before iteration starts (1) if specified, (2) at eval
            # mode, or (3) right after eval mode. last_itr_eval still holds the
            # PREVIOUS iteration's mode here; it is updated only afterwards.
            firsts_trajs = np.zeros((self.n_steps + 1, self.n_envs))
            if (
                self.reset_at_iteration
                or eval_mode
                or last_itr_eval
                or prev_obs_venv is None
            ):
                prev_obs_venv = self.reset_env_all(options_venv=options_venv)
                firsts_trajs[0] = 1
            else:
                # if done at the end of last iteration, the envs are just reset
                firsts_trajs[0] = done_venv
            last_itr_eval = eval_mode

            # Holders
            obs_trajs = {
                "state": np.zeros(
                    (self.n_steps, self.n_envs, self.n_cond_step, self.obs_dim)
                )
            }
            chains_trajs = np.zeros(
                (
                    self.n_steps,
                    self.n_envs,
                    self.model.ft_denoising_steps + 1,
                    self.horizon_steps,
                    self.action_dim,
                )
            )
            terminated_trajs = np.zeros((self.n_steps, self.n_envs))
            reward_trajs = np.zeros((self.n_steps, self.n_envs))
            if self.save_full_observations:  # state-only
                obs_full_trajs = np.empty((0, self.n_envs, self.obs_dim))
                obs_full_trajs = np.vstack(
                    (obs_full_trajs, prev_obs_venv["state"][:, -1][None])
                )

            # Collect a set of trajectories from env
            for step in range(self.n_steps):
                if step % 10 == 0:
                    print(f"Processed step {step} of {self.n_steps}")

                # Select action
                with torch.no_grad():
                    cond = {
                        "state": torch.from_numpy(prev_obs_venv["state"])
                        .float()
                        .to(self.device)
                    }
                    samples = self.model(
                        cond=cond,
                        deterministic=eval_mode,
                        return_chain=True,
                    )
                    # n_env x horizon x act
                    output_venv = samples.trajectories.cpu().numpy()
                    # n_env x denoising x horizon x act
                    chains_venv = samples.chains.cpu().numpy()
                action_venv = output_venv[:, : self.act_steps]

                # Apply multi-step action
                (
                    obs_venv,
                    reward_venv,
                    terminated_venv,
                    truncated_venv,
                    info_venv,
                ) = self.venv.step(action_venv)
                done_venv = terminated_venv | truncated_venv
                if self.save_full_observations:  # state-only
                    obs_full_venv = np.array(
                        [info["full_obs"]["state"] for info in info_venv]
                    )  # n_envs x act_steps x obs_dim
                    obs_full_trajs = np.vstack(
                        (obs_full_trajs, obs_full_venv.transpose(1, 0, 2))
                    )
                obs_trajs["state"][step] = prev_obs_venv["state"]
                chains_trajs[step] = chains_venv
                reward_trajs[step] = reward_venv
                terminated_trajs[step] = terminated_venv
                firsts_trajs[step + 1] = done_venv

                # update for next step
                prev_obs_venv = obs_venv

                # count steps --- not accounting for done within action chunk
                if not eval_mode:
                    cnt_train_step += self.n_envs * self.act_steps

            # Summarize episode reward. Only episodes that finish within the
            # iteration are counted.
            (
                num_episode_finished,
                avg_episode_reward,
                avg_best_reward,
                success_rate,
            ) = self._summarize_episodes(firsts_trajs, reward_trajs)

            # Update models
            if not eval_mode:
                with torch.no_grad():
                    obs_trajs["state"] = (
                        torch.from_numpy(obs_trajs["state"]).float().to(self.device)
                    )
                    # Flatten (step, env) into one step-major sample axis.
                    obs_flat = einops.rearrange(
                        obs_trajs["state"],
                        "s e ... -> (s e) ...",
                    )
                    chains_k = einops.rearrange(
                        torch.from_numpy(chains_trajs).float().to(self.device),
                        "s e t h d -> (s e) t h d",
                    )

                    # Calculate value and logprobs - split into batches to prevent out of memory
                    values_trajs, logprobs_trajs = self._compute_values_and_logprobs(
                        obs_flat, chains_k
                    )

                    # normalize reward with running variance if specified
                    if self.reward_scale_running:
                        reward_trajs_transpose = self.running_reward_scaler(
                            reward=reward_trajs.T, first=firsts_trajs[:-1].T
                        )
                        reward_trajs = reward_trajs_transpose.T

                    # bootstrap value with GAE if not terminal - apply reward scaling with constant if specified
                    obs_venv_ts = {
                        "state": torch.from_numpy(obs_venv["state"])
                        .float()
                        .to(self.device)
                    }
                    last_values = (
                        self.model.critic(obs_venv_ts).reshape(-1).cpu().numpy()
                    )
                    advantages_trajs, returns_trajs = self._compute_gae(
                        reward_trajs, values_trajs, terminated_trajs, last_values
                    )

                # k for environment step
                obs_k = {"state": obs_flat}
                returns_k = (
                    torch.tensor(returns_trajs, device=self.device).float().reshape(-1)
                )
                values_k = (
                    torch.tensor(values_trajs, device=self.device).float().reshape(-1)
                )
                advantages_k = (
                    torch.tensor(advantages_trajs, device=self.device)
                    .float()
                    .reshape(-1)
                )
                logprobs_k = torch.tensor(logprobs_trajs, device=self.device).float()

                # Update policy and critic; each (env step, denoising step)
                # pair is an independent sample.
                total_steps = self.n_steps * self.n_envs * self.model.ft_denoising_steps
                clipfracs = []
                for update_epoch in range(self.update_epochs):
                    # for each epoch, go through all data in batches
                    flag_break = False
                    inds_k = torch.randperm(total_steps, device=self.device)
                    num_batch = max(1, total_steps // self.batch_size)  # skip last ones
                    for batch in range(num_batch):
                        start = batch * self.batch_size
                        end = start + self.batch_size
                        inds_b = inds_k[start:end]  # b for batch
                        batch_inds_b, denoising_inds_b = torch.unravel_index(
                            inds_b,
                            (self.n_steps * self.n_envs, self.model.ft_denoising_steps),
                        )
                        obs_b = {"state": obs_k["state"][batch_inds_b]}
                        chains_prev_b = chains_k[batch_inds_b, denoising_inds_b]
                        chains_next_b = chains_k[batch_inds_b, denoising_inds_b + 1]
                        returns_b = returns_k[batch_inds_b]
                        values_b = values_k[batch_inds_b]
                        advantages_b = advantages_k[batch_inds_b]
                        logprobs_b = logprobs_k[batch_inds_b, denoising_inds_b]

                        # get loss
                        (
                            pg_loss,
                            entropy_loss,
                            v_loss,
                            clipfrac,
                            approx_kl,
                            ratio,
                            bc_loss,
                            eta,
                        ) = self.model.loss(
                            obs_b,
                            chains_prev_b,
                            chains_next_b,
                            denoising_inds_b,
                            returns_b,
                            values_b,
                            advantages_b,
                            logprobs_b,
                            use_bc_loss=self.use_bc_loss,
                            reward_horizon=self.reward_horizon,
                        )
                        loss = (
                            pg_loss
                            + entropy_loss * self.ent_coef
                            + v_loss * self.vf_coef
                            + bc_loss * self.bc_loss_coeff
                        )
                        clipfracs.append(clipfrac)

                        # update policy and critic; the actor (and eta) only
                        # step once critic warmup iterations are over
                        self.actor_optimizer.zero_grad()
                        self.critic_optimizer.zero_grad()
                        if self.learn_eta:
                            self.eta_optimizer.zero_grad()
                        loss.backward()
                        if self.itr >= self.n_critic_warmup_itr:
                            if self.max_grad_norm is not None:
                                torch.nn.utils.clip_grad_norm_(
                                    self.model.actor_ft.parameters(), self.max_grad_norm
                                )
                            self.actor_optimizer.step()
                            if self.learn_eta and batch % self.eta_update_interval == 0:
                                self.eta_optimizer.step()
                        self.critic_optimizer.step()
                        log.info(
                            f"approx_kl: {approx_kl}, update_epoch: {update_epoch}, num_batch: {num_batch}"
                        )

                        # Stop gradient update if KL difference reaches target
                        if self.target_kl is not None and approx_kl > self.target_kl:
                            flag_break = True
                            break
                    if flag_break:
                        break

                # Explained variation of future rewards using value function
                y_pred, y_true = values_k.cpu().numpy(), returns_k.cpu().numpy()
                var_y = np.var(y_true)
                explained_var = (
                    np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y
                )

            # Plot state trajectories (only in D3IL)
            if (
                self.itr % self.render_freq == 0
                and self.n_render > 0
                and self.traj_plotter is not None
            ):
                self.traj_plotter(
                    obs_full_trajs=obs_full_trajs,
                    n_render=self.n_render,
                    max_episode_steps=self.max_episode_steps,
                    render_dir=self.render_dir,
                    itr=self.itr,
                )

            # Update lr, min_sampling_std
            if self.itr >= self.n_critic_warmup_itr:
                self.actor_lr_scheduler.step()
                if self.learn_eta:
                    self.eta_lr_scheduler.step()
            self.critic_lr_scheduler.step()
            self.model.step()
            diffusion_min_sampling_std = self.model.get_min_sampling_denoising_std()

            # Save model
            if self.itr % self.save_model_freq == 0 or self.itr == self.n_train_itr - 1:
                self.save_model()

            # Log loss and save metrics
            run_results.append(
                {
                    "itr": self.itr,
                    "step": cnt_train_step,
                }
            )
            if self.save_trajs:
                run_results[-1]["obs_full_trajs"] = obs_full_trajs
                run_results[-1]["obs_trajs"] = obs_trajs
                run_results[-1]["chains_trajs"] = chains_trajs
                run_results[-1]["reward_trajs"] = reward_trajs
            if self.itr % self.log_freq == 0:
                elapsed_time = timer()
                run_results[-1]["time"] = elapsed_time
                if eval_mode:
                    log.info(
                        f"eval: success rate {success_rate:8.4f} | avg episode reward {avg_episode_reward:8.4f} | avg best reward {avg_best_reward:8.4f}"
                    )
                    if self.use_wandb:
                        wandb.log(
                            {
                                "success rate - eval": success_rate,
                                "avg episode reward - eval": avg_episode_reward,
                                "avg best reward - eval": avg_best_reward,
                                "num episode - eval": num_episode_finished,
                            },
                            step=self.itr,
                            commit=False,
                        )
                    run_results[-1]["eval_success_rate"] = success_rate
                    run_results[-1]["eval_episode_reward"] = avg_episode_reward
                    run_results[-1]["eval_best_reward"] = avg_best_reward
                else:
                    log.info(
                        f"{self.itr}: step {cnt_train_step:8d} | loss {loss:8.4f} | pg loss {pg_loss:8.4f} | value loss {v_loss:8.4f} | bc loss {bc_loss:8.4f} | reward {avg_episode_reward:8.4f} | eta {eta:8.4f} | t:{elapsed_time:8.4f}"
                    )
                    if self.use_wandb:
                        wandb.log(
                            {
                                "total env step": cnt_train_step,
                                "loss": loss,
                                "pg loss": pg_loss,
                                "value loss": v_loss,
                                "bc loss": bc_loss,
                                "eta": eta,
                                "approx kl": approx_kl,
                                "ratio": ratio,
                                "clipfrac": np.mean(clipfracs),
                                "explained variance": explained_var,
                                "avg episode reward - train": avg_episode_reward,
                                "num episode - train": num_episode_finished,
                                "diffusion - min sampling std": diffusion_min_sampling_std,
                                "actor lr": self.actor_optimizer.param_groups[0]["lr"],
                                "critic lr": self.critic_optimizer.param_groups[0][
                                    "lr"
                                ],
                            },
                            step=self.itr,
                            commit=True,
                        )
                    run_results[-1]["train_episode_reward"] = avg_episode_reward
                with open(self.result_path, "wb") as f:
                    pickle.dump(run_results, f)
            self.itr += 1