policy.py
#!/usr/bin/env python
# coding=utf-8
'''
Author: JiangJi
Email: johnjim0816@gmail.com
Date: 2023-05-18 22:41:01
LastEditor: JiangJi
LastEditTime: 2023-05-18 23:18:27
Description:
'''
import torch
import torch.nn as nn
import math
import random
import numpy as np
from algos.base.policies import BasePolicy
from algos.base.networks import QNetwork
class Policy(BasePolicy):
    def __init__(self, cfg) -> None:
        super(Policy, self).__init__(cfg)
        self.cfg = cfg
        self.device = torch.device(cfg.device)
        self.gamma = cfg.gamma
        # epsilon-greedy parameters
        self.sample_count = None
        self.epsilon_start = cfg.epsilon_start
        self.epsilon_end = cfg.epsilon_end
        self.epsilon_decay = cfg.epsilon_decay
        self.batch_size = cfg.batch_size
        self.target_update = cfg.target_update
        self.create_graph() # create graph and optimizer
        self.create_summary() # create summary
    def create_graph(self):
        self.state_size, self.action_size = self.get_state_action_size()
        self.policy_net = QNetwork(self.cfg, self.state_size, self.action_size).to(self.device)
        self.target_net = QNetwork(self.cfg, self.state_size, self.action_size).to(self.device)
        self.target_net.load_state_dict(self.policy_net.state_dict()) # copy parameters to the target network
        self.create_optimizer()
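        # Note: two copies of the Q-network are kept on purpose. policy_net is trained at
        # every update, while target_net only provides the bootstrapping targets and is
        # synchronised with policy_net every `target_update` steps, which stabilises
        # learning (standard DQN practice).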
    def sample_action(self, state, **kwargs):
        ''' sample action
        '''
        # epsilon must decay (linearly, exponentially, etc.) to balance exploration and exploitation
        self.sample_count = kwargs.get('sample_count')
        self.epsilon = self.epsilon_end + (self.epsilon_start - self.epsilon_end) * \
            math.exp(-1. * self.sample_count / self.epsilon_decay)
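        # The schedule above decays epsilon exponentially from roughly epsilon_start towards
        # epsilon_end as sample_count grows; epsilon_decay controls how fast. For example, with
        # epsilon_start=0.95, epsilon_end=0.01, epsilon_decay=500 (illustrative values, not
        # taken from any particular config), epsilon is about 0.95 at step 0, about 0.36 after
        # 500 samples and about 0.03 after 2000 samples.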
        if random.random() > self.epsilon:
            action = self.predict_action(state)
        else:
            action = self.action_space.sample()
        return action
    def predict_action(self, state, **kwargs):
        ''' predict action
        '''
        with torch.no_grad():
            state = torch.tensor(np.array(state), device=self.device, dtype=torch.float32).unsqueeze(dim=0)
            q_values = self.policy_net(state)
            action = q_values.max(1)[1].item() # choose action corresponding to the maximum q value
        return action
    def update_data_after_learn(self):
        self.data_after_train = {'idxs': self.idxs, 'td_errors': self.td_errors}
    def train(self, **kwargs):
        ''' update policy
        '''
        states, actions, next_states, rewards, dones = kwargs.get('states'), kwargs.get('actions'), kwargs.get('next_states'), kwargs.get('rewards'), kwargs.get('dones')
        self.idxs = kwargs.get('idxs')
        weights = kwargs.get('weights')
        update_step = kwargs.get('update_step')
        # convert numpy arrays to tensors
        states = torch.tensor(states, device=self.device, dtype=torch.float32)
        actions = torch.tensor(actions, device=self.device, dtype=torch.int64).unsqueeze(dim=1)
        next_states = torch.tensor(next_states, device=self.device, dtype=torch.float32)
        rewards = torch.tensor(rewards, device=self.device, dtype=torch.float32).unsqueeze(dim=1)
        dones = torch.tensor(dones, device=self.device, dtype=torch.float32).unsqueeze(dim=1)
        weights = torch.tensor(weights, device=self.device, dtype=torch.float32).unsqueeze(dim=1)
        # compute current Q values: Q(s, a) for the actions actually taken
        q_values = self.policy_net(states).gather(1, actions)
        # compute the maximum next-state Q value from the target network, detached so that
        # no gradients flow into the target network
        next_q_values = self.target_net(next_states).detach().max(1)[0].unsqueeze(dim=1)
        # compute target Q values
        target_q_values = rewards + (1 - dones) * self.gamma * next_q_values
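        # i.e. the standard DQN target y = r + gamma * (1 - done) * max_a' Q_target(s', a');
        # the (1 - dones) factor zeroes the bootstrap term for terminal transitions.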
        # compute the importance-sampling-weighted loss; reduction='none' keeps the
        # per-sample errors so that the weights actually scale each transition's contribution
        self.loss = (weights * nn.MSELoss(reduction='none')(q_values, target_q_values)).mean()
        self.td_errors = torch.abs(q_values - target_q_values).cpu().detach().numpy() # shape (batch_size, 1)
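        # idxs, weights and td_errors suggest the transitions come from a prioritized replay
        # buffer; the TD errors are exposed via data_after_train (see update_data_after_learn),
        # presumably so the buffer can refresh its priorities after each update.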
        self.optimizer.zero_grad()
        self.loss.backward()
        # clip to avoid gradient explosion
        for param in self.policy_net.parameters():
            param.grad.data.clamp_(-1, 1)
        self.optimizer.step()
        # update target net every C steps
        if update_step % self.target_update == 0:
            self.target_net.load_state_dict(self.policy_net.state_dict())
        self.update_data_after_learn()
        self.update_summary() # update summary
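
# A minimal, self-contained sketch (not part of the original module) that prints the
# epsilon-greedy schedule used in Policy.sample_action for some illustrative
# hyperparameters, so the decay behaviour can be inspected without building a full
# config object. The values below are assumptions, not taken from any cfg.
if __name__ == "__main__":
    epsilon_start, epsilon_end, epsilon_decay = 0.95, 0.01, 500  # illustrative values
    for sample_count in (0, 100, 500, 1000, 2000, 5000):
        epsilon = epsilon_end + (epsilon_start - epsilon_end) * math.exp(-1. * sample_count / epsilon_decay)
        print(f"sample_count={sample_count:5d} -> epsilon={epsilon:.3f}")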