-
Notifications
You must be signed in to change notification settings - Fork 28
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge remote-tracking branch 'origin/main' into sotetsuk/surrpass-cod…
…ecov
- Loading branch information
Showing
2 changed files
with
267 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,265 @@ | ||
{ | ||
"nbformat": 4, | ||
"nbformat_minor": 0, | ||
"metadata": { | ||
"colab": { | ||
"provenance": [], | ||
"authorship_tag": "ABX9TyOAE60e+jLXNtx01Ej6BBgo", | ||
"include_colab_link": true | ||
}, | ||
"kernelspec": { | ||
"name": "python3", | ||
"display_name": "Python 3" | ||
}, | ||
"language_info": { | ||
"name": "python" | ||
} | ||
}, | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "view-in-github", | ||
"colab_type": "text" | ||
}, | ||
"source": [ | ||
"<a href=\"https://colab.research.google.com/github/sotetsuk/pgx/blob/sotetsuk%2Fcolab%2Fc4-mcts/colab/mcts_connect_four.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"source": [ | ||
"# Simple MCTS example with MCTX\n", | ||
"\n", | ||
"Inspired by https://github.com/Carbon225/mctx-classic" | ||
], | ||
"metadata": { | ||
"id": "125IU6vLfn_z" | ||
} | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"!pip install pgx mctx" | ||
], | ||
"metadata": { | ||
"colab": { | ||
"base_uri": "https://localhost:8080/" | ||
}, | ||
"id": "aPYOMRQZamsk", | ||
"outputId": "9d78b714-6f72-4cc6-e9ee-d62e89664511" | ||
}, | ||
"execution_count": 1, | ||
"outputs": [ | ||
{ | ||
"output_type": "stream", | ||
"name": "stdout", | ||
"text": [ | ||
"Collecting pgx\n", | ||
" Downloading pgx-2.0.1-py3-none-any.whl (412 kB)\n", | ||
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m412.5/412.5 kB\u001b[0m \u001b[31m4.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | ||
"\u001b[?25hCollecting mctx\n", | ||
" Downloading mctx-0.0.3-py3-none-any.whl (44 kB)\n", | ||
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m44.6/44.6 kB\u001b[0m \u001b[31m4.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | ||
"\u001b[?25hRequirement already satisfied: jax>=0.4.1 in /usr/local/lib/python3.10/dist-packages (from pgx) (0.4.20)\n", | ||
"Requirement already satisfied: typing-extensions>=4.2.0 in /usr/local/lib/python3.10/dist-packages (from pgx) (4.5.0)\n", | ||
"Collecting svgwrite (from pgx)\n", | ||
" Downloading svgwrite-1.4.3-py3-none-any.whl (67 kB)\n", | ||
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m67.1/67.1 kB\u001b[0m \u001b[31m7.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | ||
"\u001b[?25hCollecting pgx-minatar==0.5.1 (from pgx)\n", | ||
" Downloading pgx_minatar-0.5.1-py3-none-any.whl (39 kB)\n", | ||
"Requirement already satisfied: chex>=0.0.8 in /usr/local/lib/python3.10/dist-packages (from mctx) (0.1.7)\n", | ||
"Requirement already satisfied: jaxlib>=0.1.37 in /usr/local/lib/python3.10/dist-packages (from mctx) (0.4.20+cuda11.cudnn86)\n", | ||
"Requirement already satisfied: absl-py>=0.9.0 in /usr/local/lib/python3.10/dist-packages (from chex>=0.0.8->mctx) (1.4.0)\n", | ||
"Requirement already satisfied: dm-tree>=0.1.5 in /usr/local/lib/python3.10/dist-packages (from chex>=0.0.8->mctx) (0.1.8)\n", | ||
"Requirement already satisfied: numpy>=1.18.0 in /usr/local/lib/python3.10/dist-packages (from chex>=0.0.8->mctx) (1.23.5)\n", | ||
"Requirement already satisfied: toolz>=0.9.0 in /usr/local/lib/python3.10/dist-packages (from chex>=0.0.8->mctx) (0.12.0)\n", | ||
"Requirement already satisfied: ml-dtypes>=0.2.0 in /usr/local/lib/python3.10/dist-packages (from jax>=0.4.1->pgx) (0.2.0)\n", | ||
"Requirement already satisfied: opt-einsum in /usr/local/lib/python3.10/dist-packages (from jax>=0.4.1->pgx) (3.3.0)\n", | ||
"Requirement already satisfied: scipy>=1.9 in /usr/local/lib/python3.10/dist-packages (from jax>=0.4.1->pgx) (1.11.3)\n", | ||
"Installing collected packages: pgx-minatar, svgwrite, pgx, mctx\n", | ||
"Successfully installed mctx-0.0.3 pgx-2.0.1 pgx-minatar-0.5.1 svgwrite-1.4.3\n" | ||
] | ||
} | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 35, | ||
"metadata": { | ||
"id": "PSIbeBEzage1" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"from typing import NamedTuple\n", | ||
"from functools import partial\n", | ||
"\n", | ||
"import jax\n", | ||
"import jax.numpy as jnp\n", | ||
"import chex\n", | ||
"import mctx\n", | ||
"\n", | ||
"import pgx\n", | ||
"from pgx.experimental import act_randomly\n", | ||
"\n", | ||
"from IPython.display import *" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"class Config(NamedTuple):\n", | ||
" env_id: pgx.EnvId = \"connect_four\"\n", | ||
" seed: int = 0\n", | ||
" num_simulations: int = 1_000\n", | ||
" batch_size: int = 1\n", | ||
"\n", | ||
"config = Config()\n", | ||
"env = pgx.make(config.env_id)" | ||
], | ||
"metadata": { | ||
"id": "OtFgSQZaa2lL" | ||
}, | ||
"execution_count": 36, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"def policy_fn(legal_action_mask):\n", | ||
" \"\"\"Return the logits of random policy. -Inf is set to illegal actions.\"\"\"\n", | ||
" chex.assert_shape(legal_action_mask, (env.num_actions,))\n", | ||
"\n", | ||
" logits = legal_action_mask.astype(jnp.float32)\n", | ||
" logits = jnp.where(legal_action_mask, logits, jnp.finfo(logits.dtype).min)\n", | ||
" return logits\n", | ||
"\n", | ||
"\n", | ||
"def value_fn(key, state):\n", | ||
" \"\"\"Return the value based on random rollout.\"\"\"\n", | ||
" chex.assert_rank(state.current_player, 0)\n", | ||
"\n", | ||
" def cond_fn(x):\n", | ||
" state, key = x\n", | ||
" return ~state.terminated\n", | ||
"\n", | ||
" def body_fn(x):\n", | ||
" state, key = x\n", | ||
" key, key_act, key_step = jax.random.split(key, 3)\n", | ||
" action = act_randomly(key_act, state.legal_action_mask)\n", | ||
" state = env.step(state, action, key_step)\n", | ||
" return (state, key)\n", | ||
"\n", | ||
" current_player = state.current_player\n", | ||
" state, _ = jax.lax.while_loop(cond_fn, body_fn, (state, key))\n", | ||
" return state.rewards[current_player]\n", | ||
"\n", | ||
"\n", | ||
"def recurrent_fn(params, rng_key, action, state):\n", | ||
" del params\n", | ||
" current_player = state.current_player\n", | ||
" state = env.step(state, action)\n", | ||
" logits = policy_fn(state.legal_action_mask)\n", | ||
" value = value_fn(rng_key, state)\n", | ||
" reward = state.rewards[current_player]\n", | ||
" value = jax.lax.select(state.terminated, 0.0, value)\n", | ||
" discount = jax.lax.select(state.terminated, 0.0, -1.0)\n", | ||
" recurrent_fn_output = mctx.RecurrentFnOutput(\n", | ||
" reward=reward,\n", | ||
" discount=discount,\n", | ||
" prior_logits=logits,\n", | ||
" value=value,\n", | ||
" )\n", | ||
" return recurrent_fn_output, state\n", | ||
"\n", | ||
"\n", | ||
"def run_mcts(key, state):\n", | ||
" key, subkey = jax.random.split(key)\n", | ||
" keys = jax.random.split(subkey, config.batch_size)\n", | ||
" key, subkey = jax.random.split(key)\n", | ||
"\n", | ||
" root = mctx.RootFnOutput(\n", | ||
" prior_logits=jax.vmap(policy_fn)(state.legal_action_mask),\n", | ||
" value=jax.vmap(value_fn)(keys, state),\n", | ||
" embedding=state\n", | ||
" )\n", | ||
" policy_output = mctx.muzero_policy(\n", | ||
" params=None,\n", | ||
" rng_key=subkey,\n", | ||
" root=root,\n", | ||
" invalid_actions=~state.legal_action_mask,\n", | ||
" recurrent_fn=jax.vmap(recurrent_fn, in_axes=(None, None, 0, 0)),\n", | ||
" num_simulations=config.num_simulations,\n", | ||
" max_depth=env.observation_shape[0] * env.observation_shape[1], # 42 in connect four\n", | ||
" qtransform=partial(mctx.qtransform_by_min_max, min_value=-1, max_value=1),\n", | ||
" dirichlet_fraction=0.0\n", | ||
" )\n", | ||
" return policy_output" | ||
], | ||
"metadata": { | ||
"id": "E0wJd8INbA9H" | ||
}, | ||
"execution_count": 37, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"def vs_human(is_human_first=True):\n", | ||
" assert config.batch_size == 1\n", | ||
" key = jax.random.PRNGKey(config.seed)\n", | ||
" init_fn = jax.jit(jax.vmap(env.init))\n", | ||
" step_fn = jax.jit(jax.vmap(env.step))\n", | ||
"\n", | ||
" key, subkey = jax.random.split(key)\n", | ||
" keys = jax.random.split(subkey, config.batch_size)\n", | ||
" state: pgx.State = init_fn(keys)\n", | ||
"\n", | ||
" is_human_turn = is_human_first\n", | ||
" while True:\n", | ||
" clear_output(True)\n", | ||
" pgx.save_svg(state, \"/tmp/state.svg\")\n", | ||
" display_svg(SVG(\"/tmp/state.svg\"))\n", | ||
" if state.terminated.all():\n", | ||
" break\n", | ||
" if is_human_turn:\n", | ||
" action = int(input(\"Your action: \"))\n", | ||
" action = jnp.int32([action])\n", | ||
" else:\n", | ||
" policy_output = jax.jit(run_mcts)(key, state)\n", | ||
" action = policy_output.action\n", | ||
" state = step_fn(state, action)\n", | ||
" is_human_turn = not is_human_turn" | ||
], | ||
"metadata": { | ||
"id": "YItP5A_ibnTf" | ||
}, | ||
"execution_count": 38, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"vs_human()" | ||
], | ||
"metadata": { | ||
"colab": { | ||
"base_uri": "https://localhost:8080/", | ||
"height": 301 | ||
}, | ||
"id": "EjfyCeo2e3KX", | ||
"outputId": "d3b63767-7908-4b19-95f6-7a5fd16ba729" | ||
}, | ||
"execution_count": 39, | ||
"outputs": [ | ||
{ | ||
"output_type": "display_data", | ||
"data": { | ||
"image/svg+xml": "<svg xmlns=\"http://www.w3.org/2000/svg\" xmlns:ev=\"http://www.w3.org/2001/xml-events\" xmlns:xlink=\"http://www.w3.org/1999/xlink\" baseProfile=\"full\" height=\"280.0\" version=\"1.1\" width=\"280.0\"><defs/><rect fill=\"white\" height=\"245\" width=\"245\" x=\"0\" y=\"0\"/><g transform=\"scale(1.0)\"><rect fill=\"white\" height=\"280\" width=\"280\" x=\"0\" y=\"0\"/><g transform=\"translate(17.5,17.5)\"><g id=\"vline\" stroke=\"gray\"><line stroke-width=\"1px\" x1=\"35\" x2=\"35\" y1=\"0\" y2=\"210\"/><line stroke-width=\"1px\" x1=\"70\" x2=\"70\" y1=\"0\" y2=\"210\"/><line stroke-width=\"1px\" x1=\"105\" x2=\"105\" y1=\"0\" y2=\"210\"/><line stroke-width=\"1px\" x1=\"140\" x2=\"140\" y1=\"0\" y2=\"210\"/><line stroke-width=\"1px\" x1=\"175\" x2=\"175\" y1=\"0\" y2=\"210\"/><line stroke-width=\"1px\" x1=\"210\" x2=\"210\" y1=\"0\" y2=\"210\"/></g><g id=\"vline\" stroke=\"gray\"><line stroke-width=\"0.1px\" x1=\"0\" x2=\"245\" y1=\"35\" y2=\"35\"/><line stroke-width=\"0.1px\" x1=\"0\" x2=\"245\" y1=\"70\" y2=\"70\"/><line stroke-width=\"0.1px\" x1=\"0\" x2=\"245\" y1=\"105\" y2=\"105\"/><line stroke-width=\"0.1px\" x1=\"0\" x2=\"245\" y1=\"140\" y2=\"140\"/><line stroke-width=\"0.1px\" x1=\"0\" x2=\"245\" y1=\"175\" y2=\"175\"/><line stroke-width=\"0.1px\" x1=\"0\" x2=\"245\" y1=\"210\" y2=\"210\"/></g><rect fill=\"black\" height=\"6\" stroke=\"black\" width=\"245\" x=\"0\" y=\"210\"/><rect fill=\"black\" height=\"245\" stroke=\"black\" width=\"6\" x=\"-6\" y=\"0\"/><rect fill=\"black\" height=\"245\" stroke=\"black\" width=\"6\" x=\"245\" y=\"0\"/><circle cx=\"52.5\" cy=\"17.5\" fill=\"white\" r=\"11.666666666666666\" stroke=\"black\"/><circle cx=\"87.5\" cy=\"17.5\" fill=\"black\" r=\"11.666666666666666\" stroke=\"black\"/><circle cx=\"52.5\" cy=\"52.5\" fill=\"black\" r=\"11.666666666666666\" stroke=\"black\"/><circle cx=\"87.5\" cy=\"52.5\" fill=\"white\" r=\"11.666666666666666\" stroke=\"black\"/><circle cx=\"192.5\" cy=\"52.5\" fill=\"white\" r=\"11.666666666666666\" stroke=\"black\"/><circle cx=\"52.5\" cy=\"87.5\" fill=\"white\" r=\"11.666666666666666\" stroke=\"black\"/><circle cx=\"87.5\" cy=\"87.5\" fill=\"white\" r=\"11.666666666666666\" stroke=\"black\"/><circle cx=\"122.5\" cy=\"87.5\" fill=\"black\" r=\"11.666666666666666\" stroke=\"black\"/><circle cx=\"192.5\" cy=\"87.5\" fill=\"black\" r=\"11.666666666666666\" stroke=\"black\"/><circle cx=\"52.5\" cy=\"122.5\" fill=\"black\" r=\"11.666666666666666\" stroke=\"black\"/><circle cx=\"87.5\" cy=\"122.5\" fill=\"black\" r=\"11.666666666666666\" stroke=\"black\"/><circle cx=\"122.5\" cy=\"122.5\" fill=\"white\" r=\"11.666666666666666\" stroke=\"black\"/><circle cx=\"192.5\" cy=\"122.5\" fill=\"white\" r=\"11.666666666666666\" stroke=\"black\"/><circle cx=\"52.5\" cy=\"157.5\" fill=\"black\" r=\"11.666666666666666\" stroke=\"black\"/><circle cx=\"87.5\" cy=\"157.5\" fill=\"white\" r=\"11.666666666666666\" stroke=\"black\"/><circle cx=\"122.5\" cy=\"157.5\" fill=\"black\" r=\"11.666666666666666\" stroke=\"black\"/><circle cx=\"157.5\" cy=\"157.5\" fill=\"white\" r=\"11.666666666666666\" stroke=\"black\"/><circle cx=\"192.5\" cy=\"157.5\" fill=\"black\" r=\"11.666666666666666\" stroke=\"black\"/><circle cx=\"17.5\" cy=\"192.5\" fill=\"white\" r=\"11.666666666666666\" stroke=\"black\"/><circle cx=\"52.5\" cy=\"192.5\" fill=\"black\" r=\"11.666666666666666\" stroke=\"black\"/><circle cx=\"87.5\" cy=\"192.5\" fill=\"white\" r=\"11.666666666666666\" stroke=\"black\"/><circle cx=\"122.5\" cy=\"192.5\" fill=\"black\" r=\"11.666666666666666\" stroke=\"black\"/><circle cx=\"157.5\" cy=\"192.5\" fill=\"black\" r=\"11.666666666666666\" stroke=\"black\"/><circle cx=\"192.5\" cy=\"192.5\" fill=\"white\" r=\"11.666666666666666\" stroke=\"black\"/></g></g></svg>" | ||
}, | ||
"metadata": {} | ||
} | ||
] | ||
} | ||
] | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters