Tabular Q solution example (#125)
Adding Tabular Q solution as an example.
JD-ETH authored Mar 20, 2021
1 parent c5c448c commit 6b9dbd1
Showing 4 changed files with 267 additions and 4 deletions.
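
For context, the new example implements standard tabular Q-learning. A minimal sketch of the update rule used by train() in examples/tabular_q.py below (the names here are illustrative, not the file's exact identifiers):

    def q_update(q_table, state, action, reward, next_state, actions,
                 learning_rate=0.1, discount=1.0):
        # Bootstrap the target from the best Q value available at the next state.
        target = reward + discount * max(
            q_table.get((next_state, a), 0) for a in actions
        )
        # Interpolate between the old estimate and the target, damping the change.
        old = q_table.get((state, action), 0)
        q_table[(state, action)] = learning_rate * target + (1 - learning_rate) * old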
21 changes: 21 additions & 0 deletions examples/BUILD
@@ -71,6 +71,27 @@ py_test(
],
)

py_binary(
name = "tabular_q",
srcs = ["tabular_q.py"],
deps = [
"//compiler_gym",
"//compiler_gym/util",
"//compiler_gym/util/flags:benchmark_from_flags",
],
)

py_test(
name = "tabular_q_test",
timeout = "short",
srcs = ["tabular_q_test.py"],
deps = [
":tabular_q",
"//compiler_gym/util",
"//tests:test_main",
],
)

py_binary(
name = "random_walk",
srcs = ["random_walk.py"],
13 changes: 9 additions & 4 deletions examples/brute_force.py
@@ -233,7 +233,7 @@ def run_brute_force(
expected_chunk_count = math.ceil(expected_trial_count / chunksize)
chunk_count = 0
best_reward = -float("inf")

best_action_sequence = []
print(
f"Enumerating all episodes of {len(actions)} actions × {episode_length} steps"
)
@@ -262,15 +262,17 @@ def run_brute_force(
print(
f"\r\033[KRuntime: {humanize.naturaldelta(time() - started)}. "
f"Progress: {chunk_count/expected_chunk_count:.2%}. "
f"Best reward found: {best_reward:.4%}.",
f"Best reward found: {best_reward}.",
file=sys.stderr,
flush=True,
end="",
)
for actions, rewards in chunk:
print(*actions, *rewards, sep=",", file=f, flush=True)
if rewards and rewards[-1] is not None:
best_reward = max(best_reward, rewards[-1])
if sum(rewards) > best_reward:
best_reward = sum(rewards)
best_action_sequence = actions
except KeyboardInterrupt:
print("\nkeyboard interrupt", end="", flush=True)

@@ -288,11 +290,14 @@ def run_brute_force(
worker.join()

num_trials = sum(worker.num_trials for worker in workers)
env: CompilerEnv = make_env()
print(
f"completed {humanize.intcomma(num_trials)} of "
f"{humanize.intcomma(expected_trial_count)} trials "
f"({num_trials / expected_trial_count:.3%})"
f"({num_trials / expected_trial_count:.3%}), best sequence",
" ".join([env.action_space.flags[i] for i in best_action_sequence]),
)
env.close()


def main(argv):
204 changes: 204 additions & 0 deletions examples/tabular_q.py
@@ -0,0 +1,204 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Simple compiler gym tabular q learning example.
Usage python tabular_q.py --benchmark=<benchmark>
Using selected features from Autophase observation space, given a specific training
program as gym environment, find the best action sequence using online q learning.
"""

import random
from typing import Dict, NamedTuple

import gym
from absl import app, flags

from compiler_gym.util.flags.benchmark_from_flags import benchmark_from_flags
from compiler_gym.util.timer import Timer

flags.DEFINE_list(
"actions",
[
"-break-crit-edges",
"-early-cse-memssa",
"-gvn-hoist",
"-gvn",
"-instcombine",
"-instsimplify",
"-jump-threading",
"-loop-reduce",
"-loop-rotate",
"-loop-versioning",
"-mem2reg",
"-newgvn",
"-reg2mem",
"-simplifycfg",
"-sroa",
],
"A list of action names to explore from.",
)
flags.DEFINE_float("discount", 1.0, "The discount factor.")
flags.DEFINE_list(
"features_indices",
[19, 22, 51],
"Indices of Alphaphase features that are used to construct a state",
)
flags.DEFINE_float("learning_rate", 0.1, "learning rate of the q-learning.")
flags.DEFINE_integer("episodes", 5000, "number of episodes used to learn.")
flags.DEFINE_integer(
"log_every", 50, "number of episode interval where progress is reported."
)
flags.DEFINE_float("epsilon", 0.2, "Epsilon rate of exploration. ")
flags.DEFINE_integer("episode_length", 5, "The number of steps in each episode.")
FLAGS = flags.FLAGS


class StateActionTuple(NamedTuple):
"""An state action tuple used as q-table keys"""

autophase0: int
autophase1: int
autophase2: int
cur_step: int
action_index: int


def make_q_table_key(autophase_feature, action, step):
"""Create a hashable Q-table key.
For tabular learning we will be constructing a Q-table which maps a
(state, action) pair to an expected (remaining) reward. The purpose of this
function is to convert the (state, action) properties into a hashable tuple
that can be used as a key for a Q-table dictionary.
In the CompilerGym setup, encoding the true state of the program is not obvious,
so this solution uses the observations from the Autophase features instead.
The default arguments hand-pick 3 indices from the Autophase features that
appear to change a lot during optimization.
In addition, the current step in the episode is added to the state representation.
In the current fixed-episode-length setup, we need to differentiate between
reaching a state at different steps, as they can lead to different final rewards
depending on the remaining optimization steps.
Finally, we add the action index to the key.
"""

return StateActionTuple(
*autophase_feature[FLAGS.features_indices], step, FLAGS.actions.index(action)
)
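

# For illustration only (the feature values below are hypothetical): with the
# default --features_indices=[19, 22, 51] and the --actions list above, a call
# such as make_q_table_key(obs, "-gvn", step=2) would return something like
# StateActionTuple(autophase0=7, autophase1=13, autophase2=42, cur_step=2,
# action_index=3), where 3 is the index of "-gvn" in FLAGS.actions.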


def select_action(q_table, ob, step, epsilon=0.0):
qs = [q_table.get(make_q_table_key(ob, act, step), -1) for act in FLAGS.actions]
if random.random() < epsilon:
return random.choice(FLAGS.actions)
max_indices = [i for i, x in enumerate(qs) if x == max(qs)]
# Breaking ties at random by selecting any of the indices.
return FLAGS.actions[random.choice(max_indices)]


def get_max_q_value(q_table, ob, step):
max_q = 0
for act in FLAGS.actions:
hashed = make_q_table_key(ob, act, step)
max_q = max(q_table.get(hashed, 0), max_q)
return max_q


def rollout(qtable, env, printout=False):
# Roll out the policy greedily using the given Q-table.
observation = env.reset()
action_seq, rewards = [], []
for i in range(FLAGS.episode_length):
a = select_action(qtable, observation, i)
action_seq.append(a)
observation, reward, done, info = env.step(env.action_space.flags.index(a))
rewards.append(reward)
if printout:
print(
"Resulting sequence: ", ",".join(action_seq), f"total reward {sum(rewards)}"
)
return sum(rewards)


def train(q_table, env):
# Keep an old version of the Q-table to inspect training progress.
prev_q = {}

# Run the training process "online", where the policy evaluation and
# policy improvement happen directly after one another.
for i in range(1, FLAGS.episodes + 1):
current_length = 0
obs = env.reset()
while current_length < FLAGS.episode_length:
# Run epsilon greedy policy to allow exploration.
a = select_action(q_table, obs, current_length, FLAGS.epsilon)
hashed = make_q_table_key(obs, a, current_length)
if hashed not in q_table:
q_table[hashed] = 0
# Take a step in the environment, record the reward and state transition.
# Effectively we are evaluating the policy by taking a step in the
# environment.
obs, reward, done, info = env.step(env.action_space.flags.index(a))
current_length += 1

# Compute the target value of the current state, by using the current
# step-reward and bootstrapping from the next state. In Q-learning,
# a greedy policy is implied by the Q-table, thus we can approximate
# the expected reward at the next state as the maximum value of
# all the associated state-action pair rewards (Q values). A discount
# can be used to emphasize immediate rewards and to encourage
# the agent to achieve higher rewards sooner rather than later.
target = reward + FLAGS.discount * get_max_q_value(
q_table, obs, current_length
)

# Update Q value. Instead of replacing the Q value at the current
# state action pair directly, a learning rate is introduced to interpolate
# between the current value and the target value, damping the
# changes. By updating the Q-table, we effectively update the policy.
q_table[hashed] = (
FLAGS.learning_rate * target
+ (1 - FLAGS.learning_rate) * q_table[hashed]
)

if i % FLAGS.log_every == 0:

def compare_qs(q_old, q_new):
diff = [q_new[k] - v for k, v in q_old.items()]
return sum(diff) / len(diff) if diff else 0.0

difference = compare_qs(prev_q, q_table)
# Evaluate the current policy
cur_rewards = rollout(q_table, env)
print(
f"episode={i:4d}, cur_reward={cur_rewards:.5f}, Q-table_entries={len(q_table):5d}, Q-table_diff={difference:.7f}"
)
prev_q = q_table.copy()


def main(argv):
# Initialize a Q table.
q_table: Dict[StateActionTuple, float] = {}
benchmark = benchmark_from_flags()
assert benchmark, "You must specify a benchmark using the --benchmark flag"
env = gym.make("llvm-autophase-ic-v0", benchmark=benchmark)

try:
# Train a Q-table.
with Timer("Constructing Q-table"):
train(q_table, env)

# Rollout resulting policy.
rollout(q_table, env, True)

finally:
env.close()


if __name__ == "__main__":
app.run(main)
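
As a usage note, a plausible command line, using the benchmark from the accompanying smoke test and the default flag values defined above, would be:

    python tabular_q.py --benchmark=cBench-v1/crc32 --episodes=5000 --episode_length=5 --learning_rate=0.1 --epsilon=0.2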
33 changes: 33 additions & 0 deletions examples/tabular_q_test.py
@@ -0,0 +1,33 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Tests for //compiler_gym/bin:tabular_q."""
from absl import flags

from compiler_gym.util.capture_output import capture_output
from examples.tabular_q import main
from tests.test_main import main as _test_main

FLAGS = flags.FLAGS


def test_run_tabular_q_smoke_test():
FLAGS.unparse_flags()
FLAGS(
[
"argv0",
"--episode_length=5",
"--episodes=10",
"--log_every=2",
"--benchmark=cBench-v1/crc32",
]
)
with capture_output() as out:
main(["argv0"])

assert "Resulting sequence" in out.stdout


if __name__ == "__main__":
_test_main()
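
Assuming the Bazel target layout added to examples/BUILD above, the smoke test can presumably be run with:

    bazel test //examples:tabular_q_test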
