Skip to content

Commit

Permalink
Merge pull request #115 from FLAIROx/req-update
Browse files Browse the repository at this point in the history
Update requirement structure, increment to new version.
  • Loading branch information
amacrutherford authored Oct 14, 2024
2 parents 16a8f27 + 1c81f41 commit 42e7d63
Show file tree
Hide file tree
Showing 19 changed files with 56 additions and 47 deletions.
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ RUN apt-get update && \
apt-get install -y tmux

#jaxmarl from source if needed, all the requirements
RUN pip install -e .
RUN pip install -e .[algs,dev]

USER ${MYUSER}

Expand Down
14 changes: 11 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@

## Multi-Agent Reinforcement Learning in JAX

🎉 **Update: JaxMARL was accepted at NeurIPS 2024 on Datasets and Benchmarks Track. See you in Vacouver!**

JaxMARL combines ease-of-use with GPU-enabled efficiency, and supports a wide range of commonly used MARL environments as well as popular baseline algorithms. Our aim is for one library that enables thorough evaluation of MARL methods across a wide range of tasks and against relevant baselines. We also introduce SMAX, a vectorised, simplified version of the popular StarCraft Multi-Agent Challenge, which removes the need to run the StarCraft II game engine.

For more details, take a look at our [blog post](https://blog.foersterlab.com/jaxmarl/) or our [Colab notebook](https://colab.research.google.com/github/FLAIROx/JaxMARL/blob/main/jaxmarl/tutorials/JaxMARL_Walkthrough.ipynb), which walks through the basic usage.
Expand Down Expand Up @@ -72,7 +74,7 @@ We follow CleanRL's philosophy of providing single file implementations which ca

<h2 name="install" id="install">Installation 🧗 </h2>

**Environments** - Before installing, ensure you have the correct [JAX version](https://github.com/google/jax#installation) for your hardware accelerator. The JaxMARL environments can be installed directly from PyPi:
**Environments** - Before installing, ensure you have the correct [JAX installation](https://github.com/google/jax#installation) for your hardware accelerator. We have tested up to JAX version 0.4.25. The JaxMARL environments can be installed directly from PyPi:

```
pip install jaxmarl
Expand All @@ -84,11 +86,15 @@ pip install jaxmarl
```
git clone https://github.com/FLAIROx/JaxMARL.git && cd JaxMARL
```
2. The requirements for IPPO & MAPPO can be installed with:
2. Install requirements:
```
pip install -e .
pip install -e .[algs]
export PYTHONPATH=./JaxMARL:$PYTHONPATH
```
3. For the fastest start, we reccoment using our Dockerfile, the usage of which is outlined below.

**Development** - If you would like to run our test suite, install the additonal dependencies with:
`pip install -e .[dev]`, after cloning the repository.

<h2 name="start" id="start">Quick Start 🚀 </h2>

Expand Down Expand Up @@ -151,10 +157,12 @@ JAX-native algorithms:
- [Mava](https://github.com/instadeepai/Mava): JAX implementations of IPPO and MAPPO, two popular MARL algorithms.
- [PureJaxRL](https://github.com/luchris429/purejaxrl): JAX implementation of PPO, and demonstration of end-to-end JAX-based RL training.
- [Minimax](https://github.com/facebookresearch/minimax/): JAX implementations of autocurricula baselines for RL.
- [JaxIRL](https://github.com/FLAIROx/jaxirl?tab=readme-ov-file): JAX implementation of algorithms for inverse reinforcement learning.

JAX-native environments:
- [Gymnax](https://github.com/RobertTLange/gymnax): Implementations of classic RL tasks including classic control, bsuite and MinAtar.
- [Jumanji](https://github.com/instadeepai/jumanji): A diverse set of environments ranging from simple games to NP-hard combinatorial problems.
- [Pgx](https://github.com/sotetsuk/pgx): JAX implementations of classic board games, such as Chess, Go and Shogi.
- [Brax](https://github.com/google/brax): A fully differentiable physics engine written in JAX, features continuous control tasks.
- [XLand-MiniGrid](https://github.com/corl-team/xland-minigrid): Meta-RL gridworld environments inspired by XLand and MiniGrid.
- [Craftax](https://github.com/MichaelTMatthews/Craftax): (Crafter + NetHack) in JAX.
2 changes: 1 addition & 1 deletion jaxmarl/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .registration import make, registered_envs

__all__ = ["make", "registered_envs"]
__version__ = "0.0.5"
__version__ = "0.0.6"
2 changes: 1 addition & 1 deletion jaxmarl/environments/hanabi/hanabi.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import chex
from typing import Tuple, Dict
from functools import partial
from gymnax.environments.spaces import Discrete
from jaxmarl.environments.spaces import Discrete
from .hanabi_game import HanabiGame, State


Expand Down
2 changes: 1 addition & 1 deletion jaxmarl/environments/mabrax/mabrax_env.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Dict, Literal, Optional, Tuple
import chex
from jaxmarl.environments.multi_agent_env import MultiAgentEnv
from gymnax.environments import spaces
from jaxmarl.environments import spaces
from brax import envs
import jax
import jax.numpy as jnp
Expand Down
2 changes: 1 addition & 1 deletion jaxmarl/environments/mpe/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from jaxmarl.environments.multi_agent_env import MultiAgentEnv
from jaxmarl.environments.mpe.default_params import *
import chex
from gymnax.environments.spaces import Box, Discrete
from jaxmarl.environments.spaces import Box, Discrete
from flax import struct
from typing import Tuple, Optional, Dict
from functools import partial
Expand Down
2 changes: 1 addition & 1 deletion jaxmarl/environments/mpe/simple_adversary.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from functools import partial
from jaxmarl.environments.mpe.simple import State, SimpleMPE
from jaxmarl.environments.mpe.default_params import *
from gymnax.environments.spaces import Box
from jaxmarl.environments.spaces import Box


class SimpleAdversaryMPE(SimpleMPE):
Expand Down
2 changes: 1 addition & 1 deletion jaxmarl/environments/mpe/simple_crypto.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from functools import partial
from jaxmarl.environments.mpe.simple import SimpleMPE, State
from jaxmarl.environments.mpe.default_params import *
from gymnax.environments.spaces import Box, Discrete
from jaxmarl.environments.spaces import Box, Discrete

SPEAKER = "alice_0"
LISTENER = "bob_0"
Expand Down
2 changes: 1 addition & 1 deletion jaxmarl/environments/mpe/simple_facmac.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Tuple, Dict
from functools import partial
from jaxmarl.environments.mpe.simple import State, SimpleMPE
from gymnax.environments.spaces import Box
from jaxmarl.environments.spaces import Box
from jaxmarl.environments.mpe.default_params import *


Expand Down
2 changes: 1 addition & 1 deletion jaxmarl/environments/mpe/simple_push.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from functools import partial
from jaxmarl.environments.mpe.simple import SimpleMPE, State
from jaxmarl.environments.mpe.default_params import *
from gymnax.environments.spaces import Box
from jaxmarl.environments.spaces import Box

# Obstacle Colours
COLOUR_1 = jnp.array([0.1, 0.9, 0.1])
Expand Down
2 changes: 1 addition & 1 deletion jaxmarl/environments/mpe/simple_reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from functools import partial
from jaxmarl.environments.mpe.simple import SimpleMPE, State
from jaxmarl.environments.mpe.default_params import *
from gymnax.environments.spaces import Box, Discrete
from jaxmarl.environments.spaces import Box, Discrete

# Obstacle Colours
OBS_COLOUR = [(191, 64, 64), (64, 191, 64), (64, 64, 191)]
Expand Down
2 changes: 1 addition & 1 deletion jaxmarl/environments/mpe/simple_speaker_listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Tuple, Dict
from jaxmarl.environments.mpe.simple import SimpleMPE, State
from jaxmarl.environments.mpe.default_params import *
from gymnax.environments.spaces import Box, Discrete
from jaxmarl.environments.spaces import Box, Discrete

SPEAKER = "speaker_0"
LISTENER = "listener_0"
Expand Down
2 changes: 1 addition & 1 deletion jaxmarl/environments/mpe/simple_spread.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from functools import partial
from jaxmarl.environments.mpe.simple import SimpleMPE, State
from jaxmarl.environments.mpe.default_params import *
from gymnax.environments.spaces import Box
from jaxmarl.environments.spaces import Box


class SimpleSpreadMPE(SimpleMPE):
Expand Down
2 changes: 1 addition & 1 deletion jaxmarl/environments/mpe/simple_tag.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Tuple, Dict
from functools import partial
from jaxmarl.environments.mpe.simple import SimpleMPE, State
from gymnax.environments.spaces import Box
from jaxmarl.environments.spaces import Box
from jaxmarl.environments.mpe.default_params import *


Expand Down
3 changes: 1 addition & 2 deletions jaxmarl/environments/mpe/simple_world_comm.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@
OBS_COLOUR,
)
from jaxmarl.environments.mpe.default_params import *
from gymnax.environments.spaces import Box, Discrete

from jaxmarl.environments.spaces import Box, Discrete

# NOTE food and forests are part of world.landmarks

Expand Down
1 change: 1 addition & 0 deletions jaxmarl/environments/spaces.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
""" Built off Gymnax spaces.py, this module contains jittable classes for action and observation spaces. """
from typing import Tuple, Union, Sequence
from collections import OrderedDict
import chex
Expand Down
32 changes: 30 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ include = ['jaxmarl*']

[tool.setuptools.dynamic]
version = {attr = "jaxmarl.__version__"}
dependencies = {file = ["requirements/requirements.txt"]}

[project]
name = "jaxmarl"
Expand All @@ -17,7 +16,7 @@ description = "Multi-Agent Reinforcement Learning with JAX"
authors = [
{name = "Foerster Lab for AI Research", email = "arutherford@robots.ox.ac.uk"},
]
dynamic = ["version", "dependencies"]
dynamic = ["version"]
license = {file = "LICENSE"}
requires-python = ">=3.10"
classifiers = [
Expand All @@ -31,6 +30,35 @@ classifiers = [
"Topic :: Software Development :: Libraries :: Python Modules",
"License :: OSI Approved :: Apache Software License",
]
dependencies = [
"jax>=0.4.16.0,<=0.4.25",
"jaxlib>=0.4.16.0,<=0.4.25",
"flax",
"safetensors",
"chex",
"brax==0.10.3",
"mujoco==3.1.3",
"matplotlib",
"pillow",
"scipy<=1.12",
"gymnax",
]

[project.optional-dependencies]
algs = [
"optax",
"distrax",
"flashbax==0.1.0",
"wandb",
"hydra-core>=1.3.2",
"omegaconf>=2.3.0",
"pettingzoo>=1.24.3",
"tqdm>=4.66.0",
]
dev = [
"pytest",
"pygame",
]

[project.urls]
"Homepage" = "https://github.com/FLAIROx/JaxMARL"
Expand Down
26 changes: 0 additions & 26 deletions requirements/requirements.txt

This file was deleted.

1 change: 0 additions & 1 deletion tests/hanabi/test_hanabi.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import jax
from jax import numpy as jnp
from jaxmarl import make
from jaxmarl.wrappers.baselines import LogWrapper

env = make("hanabi")
dir_path = os.path.dirname(os.path.realpath(__file__))
Expand Down

0 comments on commit 42e7d63

Please sign in to comment.