NTT123/a0-jax

a0-jax

AlphaZero in JAX using DeepMind's mctx library.

pip install -r requirements.txt

Train agent

Connect-Two game

python train_agent.py --weight-decay=1e-2 --num-iterations=3

Connect-Four game

TF_CPP_MIN_LOG_LEVEL=2 \
python train_agent.py \
    --game-class="games.connect_four_game.Connect4Game" \
    --agent-class="policies.resnet_policy.ResnetPolicyValueNet" \
    --batch-size=4096 \
    --num-simulations-per-move=32 \
    --num-self-plays-per-iteration=102400 \
    --learning-rate=1e-2 \
    --num-iterations=500 \
    --lr-decay-steps=200000

A live Connect-4 agent is running at https://huggingface.co/spaces/ntt123/Connect-4-Game. We use tensorflow.js to run the policy in the browser.

Caro (Gomoku) game

TF_CPP_MIN_LOG_LEVEL=2 \
python3 train_agent.py \
    --game-class="games.caro_game.CaroGame" \
    --agent-class="policies.resnet_policy.ResnetPolicyValueNet128" \
    --selfplay-batch-size=1024 \
    --training-batch-size=1024 \
    --num-simulations-per-move=32 \
    --num-self-plays-per-iteration=102400 \
    --learning-rate=1e-2 \
    --random-seed=42 \
    --ckpt-filename="./caro_agent_9x9_128.ckpt" \
    --num-iterations=100 \
    --lr-decay-steps=500000
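
If you would rather keep the options in Python than in a long shell line, the command above can be assembled from a dictionary. This is only a sketch: `build_train_command` is a hypothetical helper, not part of this repository, and it assumes `train_agent.py` accepts the hyphenated flag spellings shown above.

```python
import shlex


def build_train_command(flags, script="train_agent.py", python="python3"):
    """Turn a dict of options into an argv list for the training script.

    Hypothetical helper: keys use underscores and are emitted as
    --hyphenated-flags, matching the spelling in the command above.
    """
    argv = [python, script]
    for name, value in flags.items():
        argv.append(f"--{name.replace('_', '-')}={value}")
    return argv


# Values copied from the Caro command above; kept as strings so that
# "1e-2" is passed through exactly as written.
caro_flags = {
    "game_class": "games.caro_game.CaroGame",
    "agent_class": "policies.resnet_policy.ResnetPolicyValueNet128",
    "selfplay_batch_size": "1024",
    "training_batch_size": "1024",
    "num_simulations_per_move": "32",
    "num_self_plays_per_iteration": "102400",
    "learning_rate": "1e-2",
    "random_seed": "42",
    "ckpt_filename": "./caro_agent_9x9_128.ckpt",
    "num_iterations": "100",
    "lr_decay_steps": "500000",
}

command = build_train_command(caro_flags)
print(shlex.join(command))
```

To actually start training, pass `command` to `subprocess.run(command, check=True)` with `TF_CPP_MIN_LOG_LEVEL=2` set in the environment.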

A live Caro agent is running at https://caro.ntt123.repl.co.

Go game

TF_CPP_MIN_LOG_LEVEL=2 \
python3 train_agent.py \
    --game-class="games.go_game.GoBoard9x9" \
    --agent-class="policies.resnet_policy.ResnetPolicyValueNet128" \
    --selfplay-batch-size=1024 \
    --training-batch-size=1024 \
    --num-simulations-per-move=32 \
    --num-self-plays-per-iteration=102400 \
    --learning-rate=1e-2 \
    --random-seed=42 \
    --ckpt-filename="./go_agent_9x9_128.ckpt" \
    --num-iterations=200 \
    --lr-decay-steps=1000000
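
To get a feel for the size of this run, the numbers in the command above can be multiplied out. This is a back-of-the-envelope sketch, not code from the repository: it assumes each iteration plays exactly `num-self-plays-per-iteration` games and that those games are generated in batches of `selfplay-batch-size`. `selfplay_budget` is a hypothetical helper name.

```python
def selfplay_budget(num_iterations, plays_per_iteration, selfplay_batch_size):
    """Return (total self-play games, self-play batches per iteration).

    Assumes every iteration plays exactly `plays_per_iteration` games,
    generated in batches of `selfplay_batch_size`.
    """
    total_games = num_iterations * plays_per_iteration
    batches_per_iteration = plays_per_iteration // selfplay_batch_size
    return total_games, batches_per_iteration


# Values from the Go command above.
total_games, batches = selfplay_budget(
    num_iterations=200,
    plays_per_iteration=102400,
    selfplay_batch_size=1024,
)
print(f"{total_games:,} games, {batches} batches per iteration")
# prints "20,480,000 games, 100 batches per iteration"
```

Under the same assumptions, the Caro command (100 iterations) comes to 10,240,000 games.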

A live Go agent is running at https://go.ntt123.repl.co. You can run the agent on your local machine with the go_web_app.py script.

We also have an interactive Colab notebook that runs the agent on a GPU to reduce inference time.

Plot the search tree

python plot_search_tree.py
# ./search_tree.png

Play

python play.py

TPU sponsor

The agents in the demos above were trained on Google TPUs sponsored by Google under the TPU Research Cloud program.
