Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into sotetsuk/surrpass-cod…
Browse files Browse the repository at this point in the history
…ecov
  • Loading branch information
sotetsuk committed Nov 11, 2023
2 parents 01764df + 478abb3 commit 5d91f2d
Show file tree
Hide file tree
Showing 2 changed files with 267 additions and 1 deletion.
265 changes: 265 additions & 0 deletions colab/mcts_connect_four.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,265 @@
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"authorship_tag": "ABX9TyOAE60e+jLXNtx01Ej6BBgo",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/github/sotetsuk/pgx/blob/sotetsuk%2Fcolab%2Fc4-mcts/colab/mcts_connect_four.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"source": [
"# Simple MCTS example with MCTX\n",
"\n",
"Inspired by https://github.com/Carbon225/mctx-classic"
],
"metadata": {
"id": "125IU6vLfn_z"
}
},
{
"cell_type": "code",
"source": [
"!pip install pgx mctx"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "aPYOMRQZamsk",
"outputId": "9d78b714-6f72-4cc6-e9ee-d62e89664511"
},
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Collecting pgx\n",
" Downloading pgx-2.0.1-py3-none-any.whl (412 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m412.5/412.5 kB\u001b[0m \u001b[31m4.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hCollecting mctx\n",
" Downloading mctx-0.0.3-py3-none-any.whl (44 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m44.6/44.6 kB\u001b[0m \u001b[31m4.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hRequirement already satisfied: jax>=0.4.1 in /usr/local/lib/python3.10/dist-packages (from pgx) (0.4.20)\n",
"Requirement already satisfied: typing-extensions>=4.2.0 in /usr/local/lib/python3.10/dist-packages (from pgx) (4.5.0)\n",
"Collecting svgwrite (from pgx)\n",
" Downloading svgwrite-1.4.3-py3-none-any.whl (67 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m67.1/67.1 kB\u001b[0m \u001b[31m7.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hCollecting pgx-minatar==0.5.1 (from pgx)\n",
" Downloading pgx_minatar-0.5.1-py3-none-any.whl (39 kB)\n",
"Requirement already satisfied: chex>=0.0.8 in /usr/local/lib/python3.10/dist-packages (from mctx) (0.1.7)\n",
"Requirement already satisfied: jaxlib>=0.1.37 in /usr/local/lib/python3.10/dist-packages (from mctx) (0.4.20+cuda11.cudnn86)\n",
"Requirement already satisfied: absl-py>=0.9.0 in /usr/local/lib/python3.10/dist-packages (from chex>=0.0.8->mctx) (1.4.0)\n",
"Requirement already satisfied: dm-tree>=0.1.5 in /usr/local/lib/python3.10/dist-packages (from chex>=0.0.8->mctx) (0.1.8)\n",
"Requirement already satisfied: numpy>=1.18.0 in /usr/local/lib/python3.10/dist-packages (from chex>=0.0.8->mctx) (1.23.5)\n",
"Requirement already satisfied: toolz>=0.9.0 in /usr/local/lib/python3.10/dist-packages (from chex>=0.0.8->mctx) (0.12.0)\n",
"Requirement already satisfied: ml-dtypes>=0.2.0 in /usr/local/lib/python3.10/dist-packages (from jax>=0.4.1->pgx) (0.2.0)\n",
"Requirement already satisfied: opt-einsum in /usr/local/lib/python3.10/dist-packages (from jax>=0.4.1->pgx) (3.3.0)\n",
"Requirement already satisfied: scipy>=1.9 in /usr/local/lib/python3.10/dist-packages (from jax>=0.4.1->pgx) (1.11.3)\n",
"Installing collected packages: pgx-minatar, svgwrite, pgx, mctx\n",
"Successfully installed mctx-0.0.3 pgx-2.0.1 pgx-minatar-0.5.1 svgwrite-1.4.3\n"
]
}
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {
"id": "PSIbeBEzage1"
},
"outputs": [],
"source": [
"from typing import NamedTuple\n",
"from functools import partial\n",
"\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import chex\n",
"import mctx\n",
"\n",
"import pgx\n",
"from pgx.experimental import act_randomly\n",
"\n",
"from IPython.display import *"
]
},
{
"cell_type": "code",
"source": [
"class Config(NamedTuple):\n",
" env_id: pgx.EnvId = \"connect_four\"\n",
" seed: int = 0\n",
" num_simulations: int = 1_000\n",
" batch_size: int = 1\n",
"\n",
"config = Config()\n",
"env = pgx.make(config.env_id)"
],
"metadata": {
"id": "OtFgSQZaa2lL"
},
"execution_count": 36,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def policy_fn(legal_action_mask):\n",
" \"\"\"Return the logits of random policy. -Inf is set to illegal actions.\"\"\"\n",
" chex.assert_shape(legal_action_mask, (env.num_actions,))\n",
"\n",
" logits = legal_action_mask.astype(jnp.float32)\n",
" logits = jnp.where(legal_action_mask, logits, jnp.finfo(logits.dtype).min)\n",
" return logits\n",
"\n",
"\n",
"def value_fn(key, state):\n",
" \"\"\"Return the value based on random rollout.\"\"\"\n",
" chex.assert_rank(state.current_player, 0)\n",
"\n",
" def cond_fn(x):\n",
" state, key = x\n",
" return ~state.terminated\n",
"\n",
" def body_fn(x):\n",
" state, key = x\n",
" key, key_act, key_step = jax.random.split(key, 3)\n",
" action = act_randomly(key_act, state.legal_action_mask)\n",
" state = env.step(state, action, key_step)\n",
" return (state, key)\n",
"\n",
" current_player = state.current_player\n",
" state, _ = jax.lax.while_loop(cond_fn, body_fn, (state, key))\n",
" return state.rewards[current_player]\n",
"\n",
"\n",
"def recurrent_fn(params, rng_key, action, state):\n",
" del params\n",
" current_player = state.current_player\n",
" state = env.step(state, action)\n",
" logits = policy_fn(state.legal_action_mask)\n",
" value = value_fn(rng_key, state)\n",
" reward = state.rewards[current_player]\n",
" value = jax.lax.select(state.terminated, 0.0, value)\n",
" discount = jax.lax.select(state.terminated, 0.0, -1.0)\n",
" recurrent_fn_output = mctx.RecurrentFnOutput(\n",
" reward=reward,\n",
" discount=discount,\n",
" prior_logits=logits,\n",
" value=value,\n",
" )\n",
" return recurrent_fn_output, state\n",
"\n",
"\n",
"def run_mcts(key, state):\n",
" key, subkey = jax.random.split(key)\n",
" keys = jax.random.split(subkey, config.batch_size)\n",
" key, subkey = jax.random.split(key)\n",
"\n",
" root = mctx.RootFnOutput(\n",
" prior_logits=jax.vmap(policy_fn)(state.legal_action_mask),\n",
" value=jax.vmap(value_fn)(keys, state),\n",
" embedding=state\n",
" )\n",
" policy_output = mctx.muzero_policy(\n",
" params=None,\n",
" rng_key=subkey,\n",
" root=root,\n",
" invalid_actions=~state.legal_action_mask,\n",
" recurrent_fn=jax.vmap(recurrent_fn, in_axes=(None, None, 0, 0)),\n",
" num_simulations=config.num_simulations,\n",
" max_depth=env.observation_shape[0] * env.observation_shape[1], # 42 in connect four\n",
" qtransform=partial(mctx.qtransform_by_min_max, min_value=-1, max_value=1),\n",
" dirichlet_fraction=0.0\n",
" )\n",
" return policy_output"
],
"metadata": {
"id": "E0wJd8INbA9H"
},
"execution_count": 37,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def vs_human(is_human_first=True):\n",
" assert config.batch_size == 1\n",
" key = jax.random.PRNGKey(config.seed)\n",
" init_fn = jax.jit(jax.vmap(env.init))\n",
" step_fn = jax.jit(jax.vmap(env.step))\n",
"\n",
" key, subkey = jax.random.split(key)\n",
" keys = jax.random.split(subkey, config.batch_size)\n",
" state: pgx.State = init_fn(keys)\n",
"\n",
" is_human_turn = is_human_first\n",
" while True:\n",
" clear_output(True)\n",
" pgx.save_svg(state, \"/tmp/state.svg\")\n",
" display_svg(SVG(\"/tmp/state.svg\"))\n",
" if state.terminated.all():\n",
" break\n",
" if is_human_turn:\n",
" action = int(input(\"Your action: \"))\n",
" action = jnp.int32([action])\n",
" else:\n",
" policy_output = jax.jit(run_mcts)(key, state)\n",
" action = policy_output.action\n",
" state = step_fn(state, action)\n",
" is_human_turn = not is_human_turn"
],
"metadata": {
"id": "YItP5A_ibnTf"
},
"execution_count": 38,
"outputs": []
},
{
"cell_type": "code",
"source": [
"vs_human()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 301
},
"id": "EjfyCeo2e3KX",
"outputId": "d3b63767-7908-4b19-95f6-7a5fd16ba729"
},
"execution_count": 39,
"outputs": [
{
"output_type": "display_data",
"data": {
"image/svg+xml": "<svg xmlns=\"http://www.w3.org/2000/svg\" xmlns:ev=\"http://www.w3.org/2001/xml-events\" xmlns:xlink=\"http://www.w3.org/1999/xlink\" baseProfile=\"full\" height=\"280.0\" version=\"1.1\" width=\"280.0\"><defs/><rect fill=\"white\" height=\"245\" width=\"245\" x=\"0\" y=\"0\"/><g transform=\"scale(1.0)\"><rect fill=\"white\" height=\"280\" width=\"280\" x=\"0\" y=\"0\"/><g transform=\"translate(17.5,17.5)\"><g id=\"vline\" stroke=\"gray\"><line stroke-width=\"1px\" x1=\"35\" x2=\"35\" y1=\"0\" y2=\"210\"/><line stroke-width=\"1px\" x1=\"70\" x2=\"70\" y1=\"0\" y2=\"210\"/><line stroke-width=\"1px\" x1=\"105\" x2=\"105\" y1=\"0\" y2=\"210\"/><line stroke-width=\"1px\" x1=\"140\" x2=\"140\" y1=\"0\" y2=\"210\"/><line stroke-width=\"1px\" x1=\"175\" x2=\"175\" y1=\"0\" y2=\"210\"/><line stroke-width=\"1px\" x1=\"210\" x2=\"210\" y1=\"0\" y2=\"210\"/></g><g id=\"vline\" stroke=\"gray\"><line stroke-width=\"0.1px\" x1=\"0\" x2=\"245\" y1=\"35\" y2=\"35\"/><line stroke-width=\"0.1px\" x1=\"0\" x2=\"245\" y1=\"70\" y2=\"70\"/><line stroke-width=\"0.1px\" x1=\"0\" x2=\"245\" y1=\"105\" y2=\"105\"/><line stroke-width=\"0.1px\" x1=\"0\" x2=\"245\" y1=\"140\" y2=\"140\"/><line stroke-width=\"0.1px\" x1=\"0\" x2=\"245\" y1=\"175\" y2=\"175\"/><line stroke-width=\"0.1px\" x1=\"0\" x2=\"245\" y1=\"210\" y2=\"210\"/></g><rect fill=\"black\" height=\"6\" stroke=\"black\" width=\"245\" x=\"0\" y=\"210\"/><rect fill=\"black\" height=\"245\" stroke=\"black\" width=\"6\" x=\"-6\" y=\"0\"/><rect fill=\"black\" height=\"245\" stroke=\"black\" width=\"6\" x=\"245\" y=\"0\"/><circle cx=\"52.5\" cy=\"17.5\" fill=\"white\" r=\"11.666666666666666\" stroke=\"black\"/><circle cx=\"87.5\" cy=\"17.5\" fill=\"black\" r=\"11.666666666666666\" stroke=\"black\"/><circle cx=\"52.5\" cy=\"52.5\" fill=\"black\" r=\"11.666666666666666\" stroke=\"black\"/><circle cx=\"87.5\" cy=\"52.5\" fill=\"white\" r=\"11.666666666666666\" stroke=\"black\"/><circle cx=\"192.5\" cy=\"52.5\" fill=\"white\" r=\"11.666666666666666\" stroke=\"black\"/><circle cx=\"52.5\" cy=\"87.5\" fill=\"white\" r=\"11.666666666666666\" stroke=\"black\"/><circle cx=\"87.5\" cy=\"87.5\" fill=\"white\" r=\"11.666666666666666\" stroke=\"black\"/><circle cx=\"122.5\" cy=\"87.5\" fill=\"black\" r=\"11.666666666666666\" stroke=\"black\"/><circle cx=\"192.5\" cy=\"87.5\" fill=\"black\" r=\"11.666666666666666\" stroke=\"black\"/><circle cx=\"52.5\" cy=\"122.5\" fill=\"black\" r=\"11.666666666666666\" stroke=\"black\"/><circle cx=\"87.5\" cy=\"122.5\" fill=\"black\" r=\"11.666666666666666\" stroke=\"black\"/><circle cx=\"122.5\" cy=\"122.5\" fill=\"white\" r=\"11.666666666666666\" stroke=\"black\"/><circle cx=\"192.5\" cy=\"122.5\" fill=\"white\" r=\"11.666666666666666\" stroke=\"black\"/><circle cx=\"52.5\" cy=\"157.5\" fill=\"black\" r=\"11.666666666666666\" stroke=\"black\"/><circle cx=\"87.5\" cy=\"157.5\" fill=\"white\" r=\"11.666666666666666\" stroke=\"black\"/><circle cx=\"122.5\" cy=\"157.5\" fill=\"black\" r=\"11.666666666666666\" stroke=\"black\"/><circle cx=\"157.5\" cy=\"157.5\" fill=\"white\" r=\"11.666666666666666\" stroke=\"black\"/><circle cx=\"192.5\" cy=\"157.5\" fill=\"black\" r=\"11.666666666666666\" stroke=\"black\"/><circle cx=\"17.5\" cy=\"192.5\" fill=\"white\" r=\"11.666666666666666\" stroke=\"black\"/><circle cx=\"52.5\" cy=\"192.5\" fill=\"black\" r=\"11.666666666666666\" stroke=\"black\"/><circle cx=\"87.5\" cy=\"192.5\" fill=\"white\" r=\"11.666666666666666\" stroke=\"black\"/><circle cx=\"122.5\" cy=\"192.5\" fill=\"black\" r=\"11.666666666666666\" stroke=\"black\"/><circle cx=\"157.5\" cy=\"192.5\" fill=\"black\" r=\"11.666666666666666\" stroke=\"black\"/><circle cx=\"192.5\" cy=\"192.5\" fill=\"white\" r=\"11.666666666666666\" stroke=\"black\"/></g></g></svg>"
},
"metadata": {}
}
]
}
]
}
3 changes: 2 additions & 1 deletion requirements/requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ jaxlib
pytest
matplotlib
ipython
dm-haiku
# hot fix. to avoid errors in Py3.8
dm-haiku==0.0.10
pytest-cov
pgx-minatar
black
Expand Down

0 comments on commit 5d91f2d

Please sign in to comment.