sarsa.py
73 lines (49 loc) · 1.88 KB
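"""SARSA (on-policy TD control) agent for the gym-botenv environment.

Runs SARSA with an epsilon-greedy behaviour policy over the BotenvEnv
environment and plots per-episode statistics with utils.plotting.
(Docstring added for orientation; it only restates what the code below does.)
"""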
import gym
import itertools
import matplotlib
import numpy as np
import pandas as pd
import sys

if "./gym-botenv/" not in sys.path:
    sys.path.append("./gym-botenv/")

from collections import defaultdict
from gym_botenv.envs.botenv_env import BotenvEnv
from utils import plotting

env = BotenvEnv(1000)


def make_epsilon_greedy_policy(Q, epsilon, nA):
    """Return a policy function that picks the greedy action with probability
    1 - epsilon and otherwise explores uniformly over the nA actions."""
    def policy_fn(observation):
        A = np.ones(nA, dtype=float) * epsilon / nA
        best_action = np.argmax(Q[observation])
        A[best_action] += (1.0 - epsilon)
        return A
    return policy_fn


def sarsa(env, num_episodes, discount_factor=1.0, alpha=0.5, epsilon=0.1):
    """On-policy TD control: learn the action-value function Q while following
    an epsilon-greedy policy derived from the current Q estimates."""
    # Action-value function: maps state -> array of values, one per action.
    Q = defaultdict(lambda: np.zeros(env.nA))

    # Per-episode statistics for plotting.
    stats = plotting.EpisodeStats(
        episode_lengths=np.zeros(num_episodes),
        episode_rewards=np.zeros(num_episodes)
    )

    # The policy we follow and improve (epsilon-greedy w.r.t. Q).
    policy = make_epsilon_greedy_policy(Q, epsilon, env.nA)

    for i_episode in range(num_episodes):
        print("\rEpisode {}/{}".format(i_episode + 1, num_episodes), end="")
        sys.stdout.flush()

        # Start a new episode and pick the first action.
        state = env.reset()
        action_probs = policy(state)
        action = np.random.choice(np.arange(len(action_probs)), p=action_probs)

        for t in itertools.count():
            # Take a step, then choose the next action from the same policy.
            next_state, reward, done, _ = env.step(action)
            next_action_probs = policy(next_state)
            next_action = np.random.choice(np.arange(len(next_action_probs)), p=next_action_probs)

            # Track statistics for plotting (episode length was allocated but
            # never filled in the original; recorded here so the plot is meaningful).
            stats.episode_rewards[i_episode] += reward
            stats.episode_lengths[i_episode] = t

            # TD update: Q(s, a) <- Q(s, a) + alpha * [r + gamma * Q(s', a') - Q(s, a)]
            td_target = reward + discount_factor * Q[next_state][next_action]
            td_delta = td_target - Q[state][action]
            Q[state][action] += alpha * td_delta

            if done:
                break

            action = next_action
            state = next_state

    return Q, stats


if __name__ == '__main__':
    Q, stats = sarsa(env, 1500)
    plotting.plot_episode_stats(stats)
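
    # A minimal follow-up sketch, not part of the original script: recover the
    # greedy action for each state that SARSA visited, for quick inspection.
    # The name `greedy_actions` is illustrative only.
    greedy_actions = {state: int(np.argmax(values)) for state, values in Q.items()}
    print("\nLearned greedy actions for {} visited states".format(len(greedy_actions)))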