# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A demo of Graphviz visualization of a search tree."""
from typing import Optional, Sequence
from absl import app
from absl import flags
import chex
import jax
import jax.numpy as jnp
import mctx
import pygraphviz
FLAGS = flags.FLAGS
flags.DEFINE_integer("seed", 42, "Random seed.")
flags.DEFINE_integer("num_simulations", 32, "Number of simulations.")
flags.DEFINE_integer("max_num_considered_actions", 16,
"The maximum number of actions expanded at the root.")
flags.DEFINE_integer("max_depth", None, "The maximum search depth.")
flags.DEFINE_string("output_file", "/tmp/search_tree.png",
"The output file for the visualization.")


def convert_tree_to_graph(
tree: mctx.Tree,
action_labels: Optional[Sequence[str]] = None,
batch_index: int = 0
) -> pygraphviz.AGraph:
"""Converts a search tree into a Graphviz graph.
Args:
tree: A `Tree` containing a batch of search data.
action_labels: Optional labels for edges, defaults to the action index.
batch_index: Index of the batch element to plot.
Returns:
A Graphviz graph representation of `tree`.
"""
chex.assert_rank(tree.node_values, 2)
batch_size = tree.node_values.shape[0]
if action_labels is None:
action_labels = range(tree.num_actions)
elif len(action_labels) != tree.num_actions:
raise ValueError(
f"action_labels {action_labels} has the wrong number of actions "
f"({len(action_labels)}). "
f"Expecting {tree.num_actions}.")
def node_to_str(node_i, reward=0, discount=1):
return (f"{node_i}\n"
f"Reward: {reward:.2f}\n"
f"Discount: {discount:.2f}\n"
f"Value: {tree.node_values[batch_index, node_i]:.2f}\n"
f"Visits: {tree.node_visits[batch_index, node_i]}\n")
def edge_to_str(node_i, a_i):
node_index = jnp.full([batch_size], node_i)
probs = jax.nn.softmax(tree.children_prior_logits[batch_index, node_i])
return (f"{action_labels[a_i]}\n"
f"Q: {tree.qvalues(node_index)[batch_index, a_i]:.2f}\n" # pytype: disable=unsupported-operands # always-use-return-annotations
f"p: {probs[a_i]:.2f}\n")
graph = pygraphviz.AGraph(directed=True)
# Add root
graph.add_node(0, label=node_to_str(node_i=0), color="green")
# Add all other nodes and connect them up.
for node_i in range(tree.num_simulations):
for a_i in range(tree.num_actions):
      # Index of the child for this action, or -1 if the action is unexpanded.
children_i = tree.children_index[batch_index, node_i, a_i]
if children_i >= 0:
graph.add_node(
children_i,
label=node_to_str(
node_i=children_i,
reward=tree.children_rewards[batch_index, node_i, a_i],
discount=tree.children_discounts[batch_index, node_i, a_i]),
color="red")
graph.add_edge(node_i, children_i, label=edge_to_str(node_i, a_i))
return graph
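# Example with custom edge labels (hypothetical label names; any sequence of
# `tree.num_actions` strings works, the toy demo below just uses action indices):
#   graph = convert_tree_to_graph(tree, action_labels=["up", "right", "down", "left"])
#   graph.draw("/tmp/labeled_tree.png", prog="dot")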


def _run_demo(rng_key: chex.PRNGKey):
"""Runs a search algorithm on a toy environment."""
# We will define a deterministic toy environment.
# The deterministic `transition_matrix` has shape `[num_states, num_actions]`.
# The `transition_matrix[s, a]` holds the next state.
transition_matrix = jnp.array([
[1, 2, 3, 4],
[0, 5, 0, 0],
[0, 0, 0, 6],
[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0],
], dtype=jnp.int32)
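  # For example, `transition_matrix[0, 1] == 2`: taking action 1 in state 0
  # leads to state 2. Rows of zeros (states 3-6) always lead back to state 0.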
# The `rewards` have shape `[num_states, num_actions]`. The `rewards[s, a]`
# holds the reward for that (s, a) pair.
rewards = jnp.array([
[1, -1, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0],
[10, 0, 20, 0],
], dtype=jnp.float32)
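  # For example, `rewards[0, 0] == 1` and `rewards[6, 2] == 20`; the largest
  # rewards are only available from state 6, reached via `transition_matrix[2, 3]`.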
num_states = rewards.shape[0]
# The discount for each (s, a) pair.
discounts = jnp.where(transition_matrix > 0, 1.0, 0.0)
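  # Transitions back to state 0 get a discount of 0, so they effectively
  # terminate the episode; all other transitions keep a discount of 1.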
# Using optimistic initial values to encourage exploration.
values = jnp.full([num_states], 15.0)
  # The prior policy logits for each state; all zeros give a uniform prior.
all_prior_logits = jnp.zeros_like(rewards)
root, recurrent_fn = _make_batched_env_model(
# Using batch_size=2 to test the batched search.
batch_size=2,
transition_matrix=transition_matrix,
rewards=rewards,
discounts=discounts,
values=values,
prior_logits=all_prior_logits)
# Running the search.
policy_output = mctx.gumbel_muzero_policy(
params=(),
rng_key=rng_key,
root=root,
recurrent_fn=recurrent_fn,
num_simulations=FLAGS.num_simulations,
max_depth=FLAGS.max_depth,
max_num_considered_actions=FLAGS.max_num_considered_actions,
)
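  # `policy_output.action` holds the selected action for each batch element and
  # `policy_output.search_tree` is the `mctx.Tree` visualized in `main`.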
return policy_output


def _make_batched_env_model(
batch_size: int,
*,
transition_matrix: chex.Array,
rewards: chex.Array,
discounts: chex.Array,
values: chex.Array,
prior_logits: chex.Array):
"""Returns a batched `(root, recurrent_fn)`."""
chex.assert_equal_shape([transition_matrix, rewards, discounts,
prior_logits])
num_states, num_actions = transition_matrix.shape
chex.assert_shape(values, [num_states])
# We will start the search at state zero.
root_state = 0
root = mctx.RootFnOutput(
prior_logits=jnp.full([batch_size, num_actions],
prior_logits[root_state]),
value=jnp.full([batch_size], values[root_state]),
# The embedding will hold the state index.
embedding=jnp.zeros([batch_size], dtype=jnp.int32),
)
def recurrent_fn(params, rng_key, action, embedding):
del params, rng_key
chex.assert_shape(action, [batch_size])
chex.assert_shape(embedding, [batch_size])
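    # Indexing the tables with the batched `embedding` and `action` arrays
    # gathers one reward, discount, and next state per batch element.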
recurrent_fn_output = mctx.RecurrentFnOutput(
reward=rewards[embedding, action],
discount=discounts[embedding, action],
prior_logits=prior_logits[embedding],
value=values[embedding])
next_embedding = transition_matrix[embedding, action]
return recurrent_fn_output, next_embedding
return root, recurrent_fn


def main(_):
rng_key = jax.random.PRNGKey(FLAGS.seed)
jitted_run_demo = jax.jit(_run_demo)
print("Starting search.")
policy_output = jitted_run_demo(rng_key)
batch_index = 0
selected_action = policy_output.action[batch_index]
q_value = policy_output.search_tree.summary().qvalues[
batch_index, selected_action]
print("Selected action:", selected_action)
# To estimate the value of the root state, use the Q-value of the selected
# action. The Q-value is not affected by the exploration at the root node.
print("Selected action Q-value:", q_value)
graph = convert_tree_to_graph(policy_output.search_tree)
print("Saving tree diagram to:", FLAGS.output_file)
graph.draw(FLAGS.output_file, prog="dot")


if __name__ == "__main__":
app.run(main)