Skip to content

Commit

Permalink
Refactor speed benchmark scripts (#409)
Browse files Browse the repository at this point in the history
  • Loading branch information
sotetsuk authored Feb 24, 2023
1 parent 2237454 commit 942c741
Show file tree
Hide file tree
Showing 25 changed files with 200 additions and 2,051 deletions.
1 change: 1 addition & 0 deletions speed_benchmark/install.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
python3 -m pip install pettingzoo open_spiel tianshou pygame cloudpickle chess pgx
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@
import argparse
import json
import collections
import time
import numpy as np
import pyspiel
from open_spiel.python.rl_environment import Environment, ChanceEventSampler

# Copied from https://github.com/deepmind/open_spiel/blob/master/open_spiel/python/vector_env.py

# SyncVectorEnv is copied from
# https://github.com/deepmind/open_spiel/blob/master/open_spiel/python/vector_env.py
#
# Copyright 2022 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -14,9 +23,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A vectorized RL Environment."""


class SyncVectorEnv(object):
"""A vectorized RL Environment.
This environment is synchronized - games do not execute in parallel. Speedups
Expand Down Expand Up @@ -75,4 +81,47 @@ def reset(self, envs_to_reset=None):
if envs_to_reset[i] else self.envs[i].get_time_step()
for i in range(len(self.envs))
]
return time_steps
return time_steps


def make_single_env(env_name: str, seed: int):
    """Create one OpenSpiel RL ``Environment`` for the game ``env_name``.

    The game is loaded with ``pyspiel.load_game`` and chance events are drawn
    by a ``ChanceEventSampler`` constructed with ``seed``.
    """
    game = pyspiel.load_game(env_name)
    sampler = ChanceEventSampler(seed=seed)
    return Environment(game, chance_event_sampler=sampler)


def make_env(env_name: str, n_envs: int, seed: int) -> SyncVectorEnv:
    """Wrap ``n_envs`` copies of ``env_name`` in a ``SyncVectorEnv``.

    The i-th sub-environment is built with chance seed ``seed + i``.
    """
    envs = []
    for offset in range(n_envs):
        envs.append(make_single_env(env_name, seed + offset))
    return SyncVectorEnv(envs)


def random_play(env: SyncVectorEnv, n_steps_lim: int, batch_size: int):
    """Drive ``env`` with uniformly random legal actions and count the steps.

    Each loop iteration steps every sub-environment once and adds
    ``batch_size`` to the counter; the loop ends once the counter reaches
    ``n_steps_lim``.  Returns the final counter value.
    """
    StepOutput = collections.namedtuple("step_output", ["action"])
    time_step = env.reset()
    # One time step per sub-environment, i.e. the batch really is vectorized.
    assert len(env.envs) == len(time_step)
    rng = np.random.default_rng()
    total_steps = 0
    while total_steps < n_steps_lim:
        # Action selection follows
        # https://github.com/deepmind/open_spiel/blob/master/open_spiel/python/examples/rl_example.py
        step_outputs = []
        for ts in time_step:
            obs = ts.observations
            legal = obs["legal_actions"][obs["current_player"]]
            step_outputs.append(StepOutput(action=rng.choice(legal)))
        # Finished games are reset inside step() (reset_if_done=True).
        time_step, _reward, _done, _unreset = env.step(step_outputs, reset_if_done=True)
        total_steps += batch_size
    return total_steps


if __name__ == "__main__":
    # Usage: ENV_NAME BATCH_SIZE [N_STEPS_LIM] [--seed SEED]
    parser = argparse.ArgumentParser()
    parser.add_argument("env_name")  # go, chess, backgammon, tic_tac_toe
    parser.add_argument("batch_size", type=int)
    # nargs="?" makes this positional optional; without it argparse always
    # requires the argument and the default is never used.
    parser.add_argument("n_steps_lim", nargs="?", default=2 ** 10 * 10, type=int)
    parser.add_argument("--seed", default=0, type=int)
    args = parser.parse_args()
    # Report bad sizes as usage errors (a bare assert disappears under -O and
    # a zero batch size would otherwise fail with ZeroDivisionError).
    if args.batch_size <= 0:
        parser.error("batch_size must be a positive integer")
    if args.n_steps_lim % args.batch_size != 0:
        parser.error("n_steps_lim must be a multiple of batch_size")
    env = make_env(args.env_name, args.batch_size, args.seed)
    time_sta = time.time()
    steps_num = random_play(env, args.n_steps_lim, args.batch_size)
    time_end = time.time()
    sec = time_end - time_sta
    # Print the record: json.dumps() alone only builds the string and the
    # benchmark result would be silently discarded.
    print(json.dumps({"game": args.env_name, "venv": "for-loop", "library": "open_spiel", "total_steps": steps_num, "total_sec": sec, "steps/sec": steps_num/sec, "batch_size": args.batch_size}))
Original file line number Diff line number Diff line change
@@ -1,37 +1,26 @@
"""
Copied from the Tianshou repository:
https://github.com/thu-ml/tianshou/blob/master/tianshou/env/pettingzoo_env.py
Distributed under MIT LICENSE:
https://github.com/thu-ml/tianshou/blob/master/LICENSE
Modified to use OpenSpiel in SubprocVecEnv (see #384 for changes)
"""



import json
from tianshou.env import SubprocVectorEnv
import numpy as np
import time
import argparse
import warnings
from abc import ABC
from typing import Any, Dict, List, Tuple

import pettingzoo
from gymnasium import spaces
from packaging import version
import pyspiel
from pettingzoo.utils.env import AECEnv
from pettingzoo.utils.wrappers import BaseWrapper
from open_spiel.python.rl_environment import Environment, ChanceEventSampler


# Warn (without failing) when the installed PettingZoo is older than 1.21.0.
if version.parse(pettingzoo.__version__) < version.parse("1.21.0"):
    warnings.warn(
        (
            f"You are using PettingZoo {pettingzoo.__version__}. "
            "Future tianshou versions may not support PettingZoo<1.21.0. "
            "Consider upgrading your PettingZoo version."
        ),
        DeprecationWarning,
    )


# OpenSpielEnv is modified from the Tianshou repository (see #384 for changes).
# This wrapper makes it possible to use Tianshou's SubprocVectorEnv with OpenSpiel.
#
# https://github.com/thu-ml/tianshou/blob/master/tianshou/env/pettingzoo_env.py
#
# Distributed under MIT LICENSE:
#
# https://github.com/thu-ml/tianshou/blob/master/LICENSE
class OpenSpielEnv(AECEnv, ABC):
"""The interface for petting zoo environments.
Expand Down Expand Up @@ -100,3 +89,41 @@ def seed(self, seed: Any = None) -> None:

def render(self) -> Any:
return self.env.render()


def make_single_env(env_name: str, seed: int):
    """Return a fresh OpenSpiel ``Environment`` for ``env_name``.

    ``seed`` is handed to the ``ChanceEventSampler`` that the environment
    uses for chance events.
    """
    loaded_game = pyspiel.load_game(env_name)
    return Environment(
        loaded_game,
        chance_event_sampler=ChanceEventSampler(seed=seed),
    )

def make_env(env_name: str, n_envs: int, seed: int):
    """Build a ``SubprocVectorEnv`` of ``n_envs`` OpenSpiel environments.

    Each worker gets its own chance seed (``seed + i``), matching the
    for-loop OpenSpiel benchmark; previously every worker was created with
    the same ``seed``.

    Args:
        env_name: OpenSpiel game name passed to ``make_single_env``.
        n_envs: Number of sub-environments (one factory per worker).
        seed: Base seed; worker ``i`` uses ``seed + i``.
    """
    # The seed is bound as a lambda default so each factory keeps its own
    # value instead of sharing the loop variable (late-binding closures).
    env_fns = [
        lambda env_seed=seed + i: OpenSpielEnv(make_single_env(env_name, env_seed))
        for i in range(n_envs)
    ]
    return SubprocVectorEnv(env_fns)


def random_play(env: SubprocVectorEnv, n_steps_lim: int, batch_size: int):
    """Step ``env`` with randomly chosen actions until the step budget is used.

    Every iteration issues one batched ``env.step`` and adds ``batch_size``
    to the counter; the final counter value is returned.
    """
    rng = np.random.default_rng()
    total_steps = 0
    observation, info = env.reset()
    while total_steps < n_steps_lim:
        # NOTE(review): the action is sampled from the elements of the "mask"
        # entry itself -- presumably OpenSpielEnv stores legal action ids
        # there rather than a boolean mask; confirm against the wrapper.
        action = []
        for i in range(len(observation)):
            action.append(rng.choice(observation[i]["mask"]))
        observation, _reward, _terminated, _, info = env.step(action)
        total_steps += batch_size
    return total_steps


if __name__ == "__main__":
    # Usage: ENV_NAME BATCH_SIZE [N_STEPS_LIM] [--seed SEED]
    parser = argparse.ArgumentParser()
    parser.add_argument("env_name")
    parser.add_argument("batch_size", type=int)
    # nargs="?" makes this positional optional; without it argparse always
    # requires the argument and the default is never used.
    parser.add_argument("n_steps_lim", nargs="?", default=2 ** 10 * 10, type=int)
    # The seed must be parsed as int: with type=bool every non-empty value
    # (even "0") becomes True.
    parser.add_argument("--seed", default=100, type=int)
    args = parser.parse_args()
    # Report bad sizes as usage errors (a bare assert disappears under -O and
    # a zero batch size would otherwise fail with ZeroDivisionError).
    if args.batch_size <= 0:
        parser.error("batch_size must be a positive integer")
    if args.n_steps_lim % args.batch_size != 0:
        parser.error("n_steps_lim must be a multiple of batch_size")
    env = make_env(args.env_name, args.batch_size, args.seed)
    time_sta = time.time()
    steps_num = random_play(env, args.n_steps_lim, args.batch_size)
    time_end = time.time()
    sec = time_end - time_sta
    # One JSON record per run on stdout.
    print(json.dumps({"game": args.env_name, "venv": "subproc", "library": "open_spiel", "total_steps": steps_num, "total_sec": sec, "steps/sec": steps_num/sec, "batch_size": args.batch_size}))
77 changes: 77 additions & 0 deletions speed_benchmark/run_petting_zoo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import argparse
import time
import json
import numpy as np
import collections
from tianshou.env.pettingzoo_env import PettingZooEnv


class AutoResetPettingZooEnv(PettingZooEnv):
    """``PettingZooEnv`` whose ``step`` resets the env when an episode terminates."""

    def __init__(self, env):
        super().__init__(env)

    def step(self, action):
        """Step the wrapped env; on termination return the reset result as obs.

        Only the ``term`` flag triggers the reset (``trunc`` does not).  The
        reward, flags and info of the terminating step are returned unchanged.
        """
        obs, reward, term, trunc, info = super().step(action)
        # Whatever reset() returns is forwarded in the observation slot.
        next_obs = super().reset() if term else obs
        return next_obs, reward, term, trunc, info


def make_env(env_name, n_envs, vec_env):
    """Build a tianshou vector env of ``n_envs`` auto-resetting PettingZoo games.

    Args:
        env_name: One of ``"go"``, ``"tic_tac_toe"`` or ``"chess"``.
        n_envs: Number of sub-environments to create.
        vec_env: ``"for-loop"`` (``DummyVectorEnv``) or ``"subproc"``
            (``SubprocVectorEnv``).

    Raises:
        ValueError: If ``env_name`` or ``vec_env`` is not one of the supported
            values (previously this surfaced as an ``UnboundLocalError``).
    """

    # Game modules are imported lazily inside the factories so that only the
    # requested game has to be importable.
    def get_go_env():
        from pettingzoo.classic.go import go
        return AutoResetPettingZooEnv(go.env())

    def get_tictactoe_env():
        from pettingzoo.classic.tictactoe import tictactoe
        return AutoResetPettingZooEnv(tictactoe.env())

    def get_chess_env():
        from pettingzoo.classic.chess import chess
        return AutoResetPettingZooEnv(chess.env())

    env_fns = {
        "go": get_go_env,
        "tic_tac_toe": get_tictactoe_env,
        "chess": get_chess_env,
    }
    if env_name not in env_fns:
        raise ValueError(
            f"unknown env_name {env_name!r}; expected one of {sorted(env_fns)}"
        )

    if vec_env == "for-loop":
        from tianshou.env import DummyVectorEnv as VecEnv
    elif vec_env == "subproc":
        from tianshou.env import SubprocVectorEnv as VecEnv
    else:
        raise ValueError(
            f"unknown vec_env {vec_env!r}; expected 'for-loop' or 'subproc'"
        )

    env_fn = env_fns[env_name]
    return VecEnv([env_fn for _ in range(n_envs)])


def random_play(env, n_steps_lim: int, batch_size: int) -> int:
    """Run a random agent on a vectorized PettingZoo env and count the steps.

    Each iteration picks, for every one of the ``batch_size`` sub-envs, a
    uniformly random index where the observation's ``"mask"`` is truthy,
    steps the whole batch once and adds ``batch_size`` to the counter.
    The loop stops once the counter reaches ``n_steps_lim``; the counter is
    returned.
    """
    rng = np.random.default_rng()
    observation = env.reset()
    # One observation per env factory, i.e. the batch really is vectorized.
    assert len(env._env_fns) == len(observation)
    steps_done = 0
    while steps_done < n_steps_lim:
        action = []
        for i in range(batch_size):
            legal_ids = np.where(observation[i]["mask"])[0]
            action.append(rng.choice(legal_ids))  # choose an action at random
        observation, _reward, _terminated, _, _ = env.step(action)
        steps_done += batch_size
    return steps_done


if __name__ == "__main__":
    # Usage: ENV_NAME VENV BATCH_SIZE [N_STEPS_LIM] [--seed SEED]
    parser = argparse.ArgumentParser()
    # choices= turns an unsupported name into a usage error instead of a
    # failure deep inside make_env.
    parser.add_argument("env_name", choices=["go", "chess", "tic_tac_toe"])
    parser.add_argument("venv", choices=["for-loop", "subproc"])
    parser.add_argument("batch_size", type=int)
    # nargs="?" makes this positional optional; without it argparse always
    # requires the argument and the default is never used.
    parser.add_argument("n_steps_lim", nargs="?", default=2 ** 10 * 10, type=int)
    # Accepted but not forwarded: make_env below takes no seed.
    parser.add_argument("--seed", default=0, type=int)
    args = parser.parse_args()
    # Report bad sizes as usage errors (a bare assert disappears under -O and
    # a zero batch size would otherwise fail with ZeroDivisionError).
    if args.batch_size <= 0:
        parser.error("batch_size must be a positive integer")
    if args.n_steps_lim % args.batch_size != 0:
        parser.error("n_steps_lim must be a multiple of batch_size")
    env = make_env(args.env_name, args.batch_size, args.venv)
    time_sta = time.time()
    steps_num = random_play(env, args.n_steps_lim, args.batch_size)
    time_end = time.time()
    sec = time_end - time_sta
    # This script benchmarks PettingZoo, so the record must not be labelled
    # "open_spiel" (copy-paste slip from the OpenSpiel scripts).
    print(json.dumps({"game": args.env_name, "venv": args.venv, "library": "petting_zoo", "total_steps": steps_num, "total_sec": sec, "steps/sec": steps_num/sec, "batch_size": args.batch_size}))
37 changes: 16 additions & 21 deletions workspace/speed_benchmark.py → speed_benchmark/run_pgx.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import time
import json
import jax
import pgx
from pgx.utils import act_randomly
Expand Down Expand Up @@ -37,28 +38,22 @@ def benchmark(env_id: pgx.EnvId, batch_size, num_steps=(2 ** 12) * 1000):
state = step(state, action)
te = time.time()

return f"{num_steps / (te - ts):.05f}"
return num_steps, te - ts


N = (2 ** 12) * 100
print(f"Total # of steps: {N}")
bs_list = [2 ** i for i in range(5, 13)]
print("| env_id |" + "|".join([str(bs) for bs in bs_list]) + "|")
print("|:---:|" + "|".join([":---:" for bs in bs_list]) + "|")
for env_id in get_args(pgx.EnvId):
s = f"|{env_id}|"
for bs in tqdm(bs_list, leave=False):
s += benchmark(env_id, bs, N)
s += "|"
print(s)
# Maps the short game name reported in the JSON output ("game" field) to the
# pgx environment id that is passed to benchmark().
games = {
"tic_tac_toe": "tic_tac_toe/v0",
"backgammon": "backgammon/v0",
"shogi": "shogi/v0",
"go": "go-19x19/v0",
}


"""
Total # of steps: 409600
| env_id |32|64|128|256|512|1024|2048|4096|
|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
|tic_tac_toe/v0|25193.54921|51203.80125|99197.43688|206175.78196|413948.81221|723250.23824|1664977.36893|3265886.46947|
|go-19x19/v0|15146.45769|29891.62027|58064.94882|108400.19704|173638.40814|286740.95368|379331.88909|449555.32632|
|shogi/v0|21047.48879|42130.05279|82988.81210|175415.01266|259940.79393|290410.20642|299800.69880|308552.26434|
|minatar/asterix/v0|13066.68075|25836.00751|52134.46018|102929.03752|205880.59846|384825.85566|806843.53100|1553951.76960|
"""
# Benchmark every game at every batch size and print one JSON record per run.
N = 2 ** 10 * 10  # passed to benchmark() as num_steps
bs_list = [2 ** i for i in range(1, 11)]  # batch sizes 2, 4, ..., 1024
for game, env_id in games.items():
    for bs in bs_list:
        num_steps, sec = benchmark(env_id, bs, N)
        record = {
            "game": game,
            "library": "pgx",
            "total_steps": num_steps,
            "total_sec": sec,
            "steps/sec": num_steps / sec,
            "batch_size": bs,
        }
        print(json.dumps(record))
6 changes: 0 additions & 6 deletions workspace/compariosn/compare.sh

This file was deleted.

Loading

0 comments on commit 942c741

Please sign in to comment.