Make EvoTorch future-proof #77

Merged
merged 10 commits on Oct 2, 2023
1 change: 1 addition & 0 deletions .gitignore
@@ -90,6 +90,7 @@ target/

# Jupyter Notebook
.ipynb_checkpoints
.virtual_documents

# IPython
profile_default/
69 changes: 59 additions & 10 deletions examples/notebooks/Brax_Experiments_with_PGPE.ipynb
@@ -175,8 +175,11 @@
"metadata": {},
"outputs": [],
"source": [
"ENV_NAME = \"brax::humanoid\" # solve the brax task named \"humanoid\"\n",
"# ENV_NAME = \"brax::old::humanoid\" # solve the \"humanoid\" task defined within 'brax.v1`\n",
"\n",
"problem = VecGymNE(\n",
" env=\"brax::humanoid\", # solve the brax task named \"humanoid\"\n",
" env=ENV_NAME,\n",
" network=policy,\n",
" #\n",
" # Collect observation stats, and use those stats to normalize incoming observations\n",
@@ -202,6 +205,19 @@
"problem, problem.solution_length"
]
},
{
"cell_type": "markdown",
"id": "bce02d7c-400c-4c22-9bb8-aa70fa4b1da2",
"metadata": {},
"source": [
"---\n",
"\n",
"**Note.**\n",
"At the time of writing this (15 June 2023), the [arXiv paper of EvoTorch](https://arxiv.org/abs/2302.12600v3) reports results based on the old implementations of the brax tasks (which were the default until brax v0.1.2). In brax version v0.9.0, these old task implementations moved into the namespace `brax.v1`. If you wish to reproduce the results reported in the arXiv paper of EvoTorch, you might want to specify the environment name as `\"brax::old::humanoid\"` (where the substring `\"old::\"` causes `VecGymNE` to instantiate the environment using the namespace `brax.v1`), so that you will observe scores and execution times compatible with the ones reported in that arXiv paper.\n",
"\n",
"---"
]
},
{
"cell_type": "markdown",
"id": "95417793-3835-47b1-b10a-7f36e78fa3ad",
@@ -343,11 +359,21 @@
"import jax\n",
"\n",
"import brax\n",
"import brax.envs\n",
"import brax.jumpy as jp\n",
"\n",
"from brax.io import html\n",
"from brax.io import image\n",
"if ENV_NAME.startswith(\"brax::old::\"):\n",
" import brax.v1\n",
" import brax.v1.envs\n",
" import brax.v1.jumpy as jp\n",
" from brax.v1.io import html\n",
" from brax.v1.io import image\n",
"else:\n",
" try:\n",
" import jumpy as jp\n",
" except ImportError:\n",
" import brax.jumpy as jp\n",
" import brax.envs\n",
" from brax.io import html\n",
" from brax.io import image\n",
"\n",
"from IPython.display import HTML, Image\n",
"\n",
@@ -417,7 +443,11 @@
"metadata": {},
"outputs": [],
"source": [
"env = brax.envs.create(env_name=\"humanoid\")\n",
"if ENV_NAME.startswith(\"brax::old::\"):\n",
" env = brax.v1.envs.create(env_name=ENV_NAME[11:])\n",
"else:\n",
" env = brax.envs.create(env_name=ENV_NAME[6:])\n",
"\n",
"reset = jax.jit(env.reset)\n",
"step = jax.jit(env.step)"
]
@@ -438,7 +468,11 @@
"outputs": [],
"source": [
"seed = random.randint(0, (2 ** 32) - 1)\n",
"state = reset(rng=jp.random_prngkey(seed=seed))\n",
"\n",
"if hasattr(jp, \"random_prngkey\"):\n",
" state = reset(rng=jp.random_prngkey(seed=seed))\n",
"else:\n",
" state = reset(rng=jax.random.PRNGKey(seed=seed))\n",
"\n",
"h = None\n",
"states = []\n",
@@ -482,11 +516,26 @@
{
"cell_type": "code",
"execution_count": null,
"id": "29b9c3aa-8068-40bc-9ec2-e5ff759ed22c",
"id": "7ec60419-1ad0-4f19-bc1e-f2048577ea29",
"metadata": {},
"outputs": [],
"source": [
"if ENV_NAME.startswith(\"brax::old::\"):\n",
" env_sys = env.sys\n",
" states_to_render = [state.qp for state in states]\n",
"else:\n",
" env_sys = env.sys.replace(dt=env.dt)\n",
" states_to_render = [state.pipeline_state for state in states]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a07c70f6-2c93-43a1-b4c3-edd3f395302a",
"metadata": {},
"outputs": [],
"source": [
"HTML(html.render(env.sys, [state.qp for state in states]))"
"HTML(html.render(env_sys, states_to_render))"
]
}
],
@@ -506,7 +555,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.12"
"version": "3.8.16"
}
},
"nbformat": 4,
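The Brax notebook changes above follow a single pattern: dispatch on the `"brax::old::"` prefix of `ENV_NAME`, and fall back gracefully when a module (such as `jumpy`) is not importable. The snippet below is only an illustrative sketch of that pattern, not code from this PR; the helper name `make_brax_env` is made up, and it assumes a brax release recent enough to ship the `brax.v1` compatibility namespace.

```python
# Illustrative sketch (not part of this PR) of the prefix-based dispatch used in the notebook.
# The helper name `make_brax_env` is hypothetical.

def make_brax_env(env_name: str):
    if env_name.startswith("brax::old::"):
        # Old task implementations, kept around for reproducing the arXiv paper's results.
        import brax.v1.envs as brax_envs
        return brax_envs.create(env_name=env_name[len("brax::old::"):])
    elif env_name.startswith("brax::"):
        # Current task implementations shipped with recent brax releases.
        import brax.envs as brax_envs
        return brax_envs.create(env_name=env_name[len("brax::"):])
    else:
        raise ValueError(f"Expected an environment name with a 'brax::' prefix, got: {env_name!r}")
```

The same defensive style appears in the seeding cell: `jp.random_prngkey` is used when the imported jumpy module provides it, and `jax.random.PRNGKey` otherwise.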
8 changes: 4 additions & 4 deletions examples/notebooks/Gym_Experiments_with_PGPE_and_CoSyNE.ipynb
@@ -7,10 +7,10 @@
"source": [
"## Training Policies for Gym using PGPE and CoSyNE\n",
"\n",
"This example demonstrates how you can train policies using EvoTorch and Gym. To execute this example, you will need to install Gym's subpackages with:\n",
"This example demonstrates how you can train policies using EvoTorch and Gym. To execute this example, you will need to install the subpackages of `gymnasium` via:\n",
"\n",
"```bash\n",
" pip install 'gym[box2d,mujoco]'\n",
"pip install 'gymnasium[box2d,mujoco]'\n",
"```\n",
"\n",
"This example is based on our paper [1] where we describe the ClipUp optimiser and compare it to the Adam optimiser. In particular, we will re-implement the experiment for the \"LunarLanderContinuous-v2\" environment. "
@@ -279,7 +279,7 @@
"source": [
"#### References\n",
"\n",
"[1] Toklu, et. al. \"Clipup: a simple and powerful optimizer for distribution-based policy evolution.\" [International Conference on Parallel Problem Solving from Nature](https://dl.acm.org/doi/abs/10.1007/978-3-030-58115-2_36). Springer, Cham, 2020.\n",
"[1] Toklu, et. al. \"ClipUp: a simple and powerful optimizer for distribution-based policy evolution.\" [International Conference on Parallel Problem Solving from Nature](https://dl.acm.org/doi/abs/10.1007/978-3-030-58115-2_36). Springer, Cham, 2020.\n",
"\n",
"[2] Gomez, Faustino, et al. [\"Accelerated Neural Evolution through Cooperatively Coevolved Synapses.\"](https://www.jmlr.org/papers/volume9/gomez08a/gomez08a.pdf) Journal of Machine Learning Research 9.5 (2008)."
]
@@ -301,7 +301,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.13"
"version": "3.9.16"
}
},
"nbformat": 4,
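Beyond the package rename, the practical difference for the Gym notebook is gymnasium's calling convention: `reset()` returns an `(observation, info)` pair and `step()` returns a 5-tuple. The cell below is a minimal sketch of that convention for readers updating their own code, not code from this PR; it uses the same environment id as the notebook and assumes the `box2d` extra is installed.

```python
# Minimal sketch of the gymnasium API the updated notebooks rely on (not code from this PR).
# Assumes: pip install 'gymnasium[box2d]'
import gymnasium as gym

env = gym.make("LunarLanderContinuous-v2")
observation, info = env.reset(seed=0)   # reset() returns (observation, info)
terminated = truncated = False
while not (terminated or truncated):
    action = env.action_space.sample()  # random actions, purely for illustration
    observation, reward, terminated, truncated, info = env.step(action)  # 5-tuple
env.close()
```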
@@ -50,7 +50,7 @@
"Because this example focuses on the `Reacher-v4` reinforcement learning environment, `gym` with `mujoco` support is required. One can install the `mujoco` support for `gym` via:\n",
"\n",
"```bash\n",
"pip install 'gym[mujoco]'\n",
"pip install 'gymnasium[mujoco]'\n",
"```"
]
},
@@ -84,7 +84,7 @@
"\n",
"from typing import Iterable\n",
"\n",
"import gym"
"import gymnasium as gym"
]
},
{
@@ -124,7 +124,7 @@
"source": [
"## Definitions\n",
"\n",
"We begin our definitions with a helper function, $\\text{reacher_state}(\\text{observation})$ which extracts the state ($s_t$) of the robotic arm from the observation vector returned by the environment."
"We begin our definitions with a helper function, $\\text{reacher\\_state}(\\text{observation})$ which extracts the state ($s_t$) of the robotic arm from the observation vector returned by the environment."
]
},
{
Expand All @@ -147,7 +147,7 @@
"id": "4909a2aa-1f07-4c6e-b757-ddcbe7496a78",
"metadata": {},
"source": [
"We now define the function $\\text{predict_next_state}(s_t, a_t)$ which, given a state $s_t$ and an action $a_t$ ($t$ being the current timestep), returns the predicted next state $\\tilde{s}_{t+1}$.\n",
"We now define the function $\\text{predict\\_next\\_state}(s_t, a_t)$ which, given a state $s_t$ and an action $a_t$ ($t$ being the current timestep), returns the predicted next state $\\tilde{s}_{t+1}$.\n",
"\n",
"Within itself, this function uses the neural network $\\pi$ to make its predictions."
]
@@ -173,9 +173,9 @@
"source": [
"Let us now define a _plan_ $p_t$ as a series of actions planned for future timesteps, i.e.: $p_t = (a_t, a_{t+1}, a_{t+2}, ..., a_{t+(H-1)})$ where $H$ is the horizon, determining how far into the future we are planning.\n",
"\n",
"With this, we define the function $\\text{predict_plan_outcome}(s_t, p_t)$ which receives the current state $s_t$ and a plan $p_t$ and returns a predicted future state $\\tilde{s}_{t+H}$, which represents the predicted outcome of following the plan. Within $\\text{predict_plan_outcome}(\\cdot)$, the predictions are made with the help of $\\text{predict_next_state}(\\cdot)$ which in turn uses the neural network $\\pi$.\n",
"With this, we define the function $\\text{predict\\_plan\\_outcome}(s_t, p_t)$ which receives the current state $s_t$ and a plan $p_t$ and returns a predicted future state $\\tilde{s}_{t+H}$, which represents the predicted outcome of following the plan. Within $\\text{predict\\_plan\\_outcome}(\\cdot)$, the predictions are made with the help of $\\text{predict\\_next\\_state}(\\cdot)$ which in turn uses the neural network $\\pi$.\n",
"\n",
"An implementation detail to be noted here is that, $\\text{predict_plan_outcome}(\\cdot)$ expects not a single plan, but a batch of plans, and uses PyTorch's vectorization capabilities to make predictions for all those plans in a performant manner."
"An implementation detail to be noted here is that, $\\text{predict\\_plan\\_outcome}(\\cdot)$ expects not a single plan, but a batch of plans, and uses PyTorch's vectorization capabilities to make predictions for all those plans in a performant manner."
]
},
{
@@ -213,7 +213,7 @@
"\\begin{array}{c c l}\n",
" p_t =\n",
" & \\text{arg min} & ||(\\tilde{s}_{t+H}^x,\\tilde{s}_{t+H}^y)-(g^x, g^y)|| \\\\\n",
" & \\text{subject to} & \\tilde{s}_{t+H} = \\text{predict_plan_outcome}(s_t, p_t)\n",
" & \\text{subject to} & \\tilde{s}_{t+H} = \\text{predict\\_plan\\_outcome}(s_t, p_t)\n",
"\\end{array}\n",
"$$\n",
"\n",
@@ -304,19 +304,7 @@
"metadata": {},
"outputs": [],
"source": [
"from packaging.version import Version\n",
"old_render_api = Version(gym.__version__) < Version(\"0.26\")\n",
"\n",
"if old_render_api:\n",
" # For gym versions older than 0.26, we do not have to specify additional\n",
" # keyword arguments for human-mode rendering.\n",
" env = gym.make(\"Reacher-v4\")\n",
"else:\n",
" # For gym versions beginning with 0.26, we have to explicitly specify\n",
" # that the rendering mode is \"human\" if we wish to do the rendering on\n",
" # the screen.\n",
" env = gym.make(\"Reacher-v4\", render_mode=\"human\")\n",
"\n",
"env = gym.make(\"Reacher-v4\", render_mode=\"human\")\n",
"env"
]
},
@@ -337,15 +337,7 @@
"outputs": [],
"source": [
"def run_episode(visualize: bool = False):\n",
" reset_result = env.reset()\n",
" if isinstance(reset_result, tuple):\n",
" # If the result of the `reset()` method is a tuple, then we assume\n",
" # that it returned a tuple in the form (observation, info).\n",
" observation, _ = reset_result\n",
" else:\n",
" # If the result of the `reset()` method is not a tuple, then we\n",
" # assume that it returned the observation by itself.\n",
" observation = reset_result\n",
" observation, _ = env.reset()\n",
"\n",
" if visualize:\n",
" env.render()\n",
@@ -354,14 +334,8 @@
" action = do_planning(observation)\n",
" action = np.clip(action, -1.0, 1.0)\n",
" \n",
" step_result = env.step(action)\n",
" if len(step_result) == 5:\n",
" observation, reward, terminated, truncated, info = step_result\n",
" done = terminated | truncated\n",
" elif len(step_result) == 4:\n",
" observation, reward, done, info = step_result\n",
" else:\n",
" assert False, \"Unexpected number of items returned by `.step(...)`\"\n",
" observation, reward, terminated, truncated, info = env.step(action)\n",
" done = terminated | truncated\n",
" \n",
" if visualize:\n",
" env.render()\n",
@@ -419,7 +393,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.13"
"version": "3.8.17"
}
},
"nbformat": 4,
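The Reacher notebook's markdown (quoted in the hunks above) describes `predict_plan_outcome` as taking a whole batch of candidate plans and evaluating them in one vectorized pass through the neural network. The sketch below only illustrates that batching idea; the network, the state/action sizes, and the name `predict_plan_outcomes` are assumptions made for this example and are not taken from the notebook.

```python
# Hypothetical illustration of batched plan evaluation, as described in the notebook's markdown.
# `policy_net`, the sizes, and the function name are assumptions, not the notebook's actual code.
import torch

STATE_SIZE, ACTION_SIZE, HORIZON = 8, 2, 5

# Stand-in for the learned model that predicts the next state from (state, action).
policy_net = torch.nn.Sequential(
    torch.nn.Linear(STATE_SIZE + ACTION_SIZE, 64),
    torch.nn.Tanh(),
    torch.nn.Linear(64, STATE_SIZE),
)

@torch.no_grad()
def predict_plan_outcomes(state: torch.Tensor, plans: torch.Tensor) -> torch.Tensor:
    # state: shape (STATE_SIZE,); plans: shape (num_plans, HORIZON * ACTION_SIZE)
    num_plans = plans.shape[0]
    states = state.unsqueeze(0).expand(num_plans, -1)
    for t in range(HORIZON):
        actions = plans[:, t * ACTION_SIZE : (t + 1) * ACTION_SIZE]
        # One forward pass predicts the next state for every plan in the batch.
        states = policy_net(torch.cat([states, actions], dim=-1))
    return states  # predicted states at the planning horizon, one row per plan
```

With an evolutionary search over `plans`, picking the plan whose predicted final state is closest to the goal recovers the planning formulation given in the notebook.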