diff --git a/.github/ISSUE_TEMPLATE/bug_report.yaml b/.github/ISSUE_TEMPLATE/bug_report.yaml
index b45393ba..16b5e203 100644
--- a/.github/ISSUE_TEMPLATE/bug_report.yaml
+++ b/.github/ISSUE_TEMPLATE/bug_report.yaml
@@ -30,11 +30,14 @@ body:
description: The skrl version can be obtained with the command `pip show skrl`.
options:
- ---
+ - 1.2.0
+ - 1.1.0
- 1.0.0
- 1.0.0-rc2
- 1.0.0-rc1
- 0.10.2 or 0.10.1
- 0.10.0 or earlier
+ - develop branch
validations:
required: true
- type: input
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 13e167ed..cb058434 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -2,6 +2,20 @@
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
+## [1.2.0] - 2024-06-23
+### Added
+- Define the `environment_info` trainer config to log environment info (PyTorch implementation)
+- Add support to automatically compute the write and checkpoint intervals and make it the default option
+- Single forward-pass in shared models
+- Distributed multi-GPU and multi-node learning (PyTorch implementation)
+
+### Changed
+- Update Orbit-related source code and docs to Isaac Lab
+
+### Fixed
+- Move the batch sampling inside gradient step loop for DDPG and TD3
+- Perform JAX computation on the selected device
+
## [1.1.0] - 2024-02-12
### Added
- MultiCategorical mixin to operate MultiDiscrete action spaces
diff --git a/README.md b/README.md
index 65b7b391..728dfcf0 100644
--- a/README.md
+++ b/README.md
@@ -17,7 +17,7 @@
SKRL - Reinforcement Learning library
-**skrl** is an open-source modular library for Reinforcement Learning written in Python (on top of [PyTorch](https://pytorch.org/) and [JAX](https://jax.readthedocs.io)) and designed with a focus on modularity, readability, simplicity, and transparency of algorithm implementation. In addition to supporting the OpenAI [Gym](https://www.gymlibrary.dev) / Farama [Gymnasium](https://gymnasium.farama.org) and [DeepMind](https://github.com/deepmind/dm_env) and other environment interfaces, it allows loading and configuring [NVIDIA Isaac Gym](https://developer.nvidia.com/isaac-gym/), [NVIDIA Isaac Orbit](https://isaac-orbit.github.io/orbit/index.html) and [NVIDIA Omniverse Isaac Gym](https://docs.omniverse.nvidia.com/isaacsim/latest/tutorial_gym_isaac_gym.html) environments, enabling agents' simultaneous training by scopes (subsets of environments among all available environments), which may or may not share resources, in the same run.
+**skrl** is an open-source modular library for Reinforcement Learning written in Python (on top of [PyTorch](https://pytorch.org/) and [JAX](https://jax.readthedocs.io)) and designed with a focus on modularity, readability, simplicity, and transparency of algorithm implementation. In addition to supporting the OpenAI [Gym](https://www.gymlibrary.dev) / Farama [Gymnasium](https://gymnasium.farama.org) and [DeepMind](https://github.com/deepmind/dm_env) and other environment interfaces, it allows loading and configuring [NVIDIA Isaac Gym](https://developer.nvidia.com/isaac-gym/), [NVIDIA Omniverse Isaac Gym](https://docs.omniverse.nvidia.com/isaacsim/latest/tutorial_gym_isaac_gym.html) and [NVIDIA Isaac Lab](https://isaac-sim.github.io/IsaacLab/index.html) environments, enabling agents' simultaneous training by scopes (subsets of environments among all available environments), which may or may not share resources, in the same run.
diff --git a/docs/source/_static/imgs/example_isaac_orbit.png b/docs/source/_static/imgs/example_isaaclab.png
similarity index 100%
rename from docs/source/_static/imgs/example_isaac_orbit.png
rename to docs/source/_static/imgs/example_isaaclab.png
diff --git a/docs/source/api/agents/a2c.rst b/docs/source/api/agents/a2c.rst
index c98f6448..dfb55a02 100644
--- a/docs/source/api/agents/a2c.rst
+++ b/docs/source/api/agents/a2c.rst
@@ -232,6 +232,10 @@ Support for advanced features is described in the next table
- RNN, LSTM, GRU and any other variant
- .. centered:: :math:`\blacksquare`
- .. centered:: :math:`\square`
+ * - Distributed
+ - Single Program Multi Data (SPMD) multi-GPU
+ - .. centered:: :math:`\blacksquare`
+ - .. centered:: :math:`\square`
.. raw:: html
diff --git a/docs/source/api/agents/amp.rst b/docs/source/api/agents/amp.rst
index c993b67a..e6937f76 100644
--- a/docs/source/api/agents/amp.rst
+++ b/docs/source/api/agents/amp.rst
@@ -237,6 +237,10 @@ Support for advanced features is described in the next table
- \-
- .. centered:: :math:`\square`
- .. centered:: :math:`\square`
+ * - Distributed
+ - Single Program Multi Data (SPMD) multi-GPU
+ - .. centered:: :math:`\blacksquare`
+ - .. centered:: :math:`\square`
.. raw:: html
diff --git a/docs/source/api/agents/cem.rst b/docs/source/api/agents/cem.rst
index 68245818..1bcbee01 100644
--- a/docs/source/api/agents/cem.rst
+++ b/docs/source/api/agents/cem.rst
@@ -175,6 +175,10 @@ Support for advanced features is described in the next table
- \-
- .. centered:: :math:`\square`
- .. centered:: :math:`\square`
+ * - Distributed
+ - \-
+ - .. centered:: :math:`\square`
+ - .. centered:: :math:`\square`
.. raw:: html
diff --git a/docs/source/api/agents/ddpg.rst b/docs/source/api/agents/ddpg.rst
index 00972f3f..461fda20 100644
--- a/docs/source/api/agents/ddpg.rst
+++ b/docs/source/api/agents/ddpg.rst
@@ -47,10 +47,10 @@ Learning algorithm
|
| :literal:`_update(...)`
-| :green:`# sample a batch from memory`
-| [:math:`s, a, r, s', d`] :math:`\leftarrow` states, actions, rewards, next_states, dones of size :guilabel:`batch_size`
| :green:`# gradient steps`
| **FOR** each gradient step up to :guilabel:`gradient_steps` **DO**
+| :green:`# sample a batch from memory`
+| [:math:`s, a, r, s', d`] :math:`\leftarrow` states, actions, rewards, next_states, dones of size :guilabel:`batch_size`
| :green:`# compute target values`
| :math:`a' \leftarrow \mu_{\theta_{target}}(s')`
| :math:`Q_{_{target}} \leftarrow Q_{\phi_{target}}(s', a')`
@@ -236,6 +236,10 @@ Support for advanced features is described in the next table
- RNN, LSTM, GRU and any other variant
- .. centered:: :math:`\blacksquare`
- .. centered:: :math:`\square`
+ * - Distributed
+ - Single Program Multi Data (SPMD) multi-GPU
+ - .. centered:: :math:`\blacksquare`
+ - .. centered:: :math:`\square`
.. raw:: html
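
For reference, the reordering in the :literal:`_update(...)` pseudocode above can be summarized in a minimal PyTorch-style sketch; ``memory``, ``target_policy``, ``critic`` and the optimizer below are hypothetical stand-ins for the agent's actual members, not skrl API.

    import torch
    import torch.nn.functional as F

    def ddpg_critic_update(memory, target_policy, critic, target_critic,
                           critic_optimizer, batch_size, gradient_steps, discount=0.99):
        for _ in range(gradient_steps):
            # sample a fresh batch from memory on every gradient step (the reordering above)
            states, actions, rewards, next_states, dones = memory.sample(batch_size)

            # compute target values
            with torch.no_grad():
                next_actions = target_policy(next_states)
                target_q = target_critic(next_states, next_actions)
                target_values = rewards + discount * (1 - dones.float()) * target_q

            # critic loss and optimization step
            critic_loss = F.mse_loss(critic(states, actions), target_values)
            critic_optimizer.zero_grad()
            critic_loss.backward()
            critic_optimizer.step()
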
diff --git a/docs/source/api/agents/ddqn.rst b/docs/source/api/agents/ddqn.rst
index 91744c33..f2ac0029 100644
--- a/docs/source/api/agents/ddqn.rst
+++ b/docs/source/api/agents/ddqn.rst
@@ -184,6 +184,10 @@ Support for advanced features is described in the next table
- \-
- .. centered:: :math:`\square`
- .. centered:: :math:`\square`
+ * - Distributed
+ - Single Program Multi Data (SPMD) multi-GPU
+ - .. centered:: :math:`\blacksquare`
+ - .. centered:: :math:`\square`
.. raw:: html
diff --git a/docs/source/api/agents/dqn.rst b/docs/source/api/agents/dqn.rst
index 605e57ad..ed196366 100644
--- a/docs/source/api/agents/dqn.rst
+++ b/docs/source/api/agents/dqn.rst
@@ -184,6 +184,10 @@ Support for advanced features is described in the next table
- \-
- .. centered:: :math:`\square`
- .. centered:: :math:`\square`
+ * - Distributed
+ - Single Program Multi Data (SPMD) multi-GPU
+ - .. centered:: :math:`\blacksquare`
+ - .. centered:: :math:`\square`
.. raw:: html
diff --git a/docs/source/api/agents/ppo.rst b/docs/source/api/agents/ppo.rst
index 0bf7b37b..f3c84f20 100644
--- a/docs/source/api/agents/ppo.rst
+++ b/docs/source/api/agents/ppo.rst
@@ -248,6 +248,10 @@ Support for advanced features is described in the next table
- RNN, LSTM, GRU and any other variant
- .. centered:: :math:`\blacksquare`
- .. centered:: :math:`\square`
+ * - Distributed
+ - Single Program Multi Data (SPMD) multi-GPU
+ - .. centered:: :math:`\blacksquare`
+ - .. centered:: :math:`\square`
.. raw:: html
diff --git a/docs/source/api/agents/rpo.rst b/docs/source/api/agents/rpo.rst
index 61947dff..f769e9c7 100644
--- a/docs/source/api/agents/rpo.rst
+++ b/docs/source/api/agents/rpo.rst
@@ -285,6 +285,10 @@ Support for advanced features is described in the next table
- RNN, LSTM, GRU and any other variant
- .. centered:: :math:`\blacksquare`
- .. centered:: :math:`\square`
+ * - Distributed
+ - Single Program Multi Data (SPMD) multi-GPU
+ - .. centered:: :math:`\blacksquare`
+ - .. centered:: :math:`\square`
.. raw:: html
diff --git a/docs/source/api/agents/sac.rst b/docs/source/api/agents/sac.rst
index c55720cc..2c17fc3a 100644
--- a/docs/source/api/agents/sac.rst
+++ b/docs/source/api/agents/sac.rst
@@ -244,6 +244,10 @@ Support for advanced features is described in the next table
- RNN, LSTM, GRU and any other variant
- .. centered:: :math:`\blacksquare`
- .. centered:: :math:`\square`
+ * - Distributed
+ - Single Program Multi Data (SPMD) multi-GPU
+ - .. centered:: :math:`\blacksquare`
+ - .. centered:: :math:`\square`
.. raw:: html
diff --git a/docs/source/api/agents/td3.rst b/docs/source/api/agents/td3.rst
index ee200de7..5d54ae4f 100644
--- a/docs/source/api/agents/td3.rst
+++ b/docs/source/api/agents/td3.rst
@@ -47,10 +47,10 @@ Learning algorithm
|
| :literal:`_update(...)`
-| :green:`# sample a batch from memory`
-| [:math:`s, a, r, s', d`] :math:`\leftarrow` states, actions, rewards, next_states, dones of size :guilabel:`batch_size`
| :green:`# gradient steps`
| **FOR** each gradient step up to :guilabel:`gradient_steps` **DO**
+| :green:`# sample a batch from memory`
+| [:math:`s, a, r, s', d`] :math:`\leftarrow` states, actions, rewards, next_states, dones of size :guilabel:`batch_size`
| :green:`# target policy smoothing`
| :math:`a' \leftarrow \mu_{\theta_{target}}(s')`
| :math:`noise \leftarrow \text{clip}(` :guilabel:`smooth_regularization_noise` :math:`, -c, c) \qquad` with :math:`c` as :guilabel:`smooth_regularization_clip`
@@ -258,6 +258,10 @@ Support for advanced features is described in the next table
- RNN, LSTM, GRU and any other variant
- .. centered:: :math:`\blacksquare`
- .. centered:: :math:`\square`
+ * - Distributed
+ - Single Program Multi Data (SPMD) multi-GPU
+ - .. centered:: :math:`\blacksquare`
+ - .. centered:: :math:`\square`
.. raw:: html
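
The target policy smoothing step in the pseudocode above clips the regularization noise to :math:`[-c, c]` before adding it to the target action. A minimal sketch, assuming a generic ``target_policy`` module and a symmetric action range (both are illustrative, not skrl API):

    import torch

    def smoothed_target_actions(target_policy, next_states, noise_std=0.2,
                                smooth_regularization_clip=0.5, action_limit=1.0):
        next_actions = target_policy(next_states)
        # clip the smoothing noise to [-c, c], with c = smooth_regularization_clip
        noise = (torch.randn_like(next_actions) * noise_std).clamp(
            -smooth_regularization_clip, smooth_regularization_clip)
        # add the clipped noise and keep the target action within its valid range
        return (next_actions + noise).clamp(-action_limit, action_limit)
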
diff --git a/docs/source/api/agents/trpo.rst b/docs/source/api/agents/trpo.rst
index 21460fe9..c482bf87 100644
--- a/docs/source/api/agents/trpo.rst
+++ b/docs/source/api/agents/trpo.rst
@@ -282,6 +282,10 @@ Support for advanced features is described in the next table
- RNN, LSTM, GRU and any other variant
- .. centered:: :math:`\blacksquare`
- .. centered:: :math:`\square`
+ * - Distributed
+ - Single Program Multi Data (SPMD) multi-GPU
+ - .. centered:: :math:`\blacksquare`
+ - .. centered:: :math:`\square`
.. raw:: html
diff --git a/docs/source/api/config/frameworks.rst b/docs/source/api/config/frameworks.rst
index eff8bb77..f25e492f 100644
--- a/docs/source/api/config/frameworks.rst
+++ b/docs/source/api/config/frameworks.rst
@@ -7,6 +7,65 @@ Configurations for behavior modification of Machine Learning (ML) frameworks.
+PyTorch
+-------
+
+PyTorch specific configuration
+
+.. raw:: html
+
+
+
+API
+^^^
+
+.. py:data:: skrl.config.torch.device
+ :type: torch.device
+ :value: "cuda:${LOCAL_RANK}" | "cpu"
+
+ Default device
+
+ The default device, unless specified, is ``cuda:0`` (or ``cuda:LOCAL_RANK`` in a distributed environment) if CUDA is available, ``cpu`` otherwise
+
+.. py:data:: skrl.config.local_rank
+ :type: int
+ :value: 0
+
+ The rank of the worker/process (e.g.: GPU) within a local worker group (e.g.: node)
+
+ This property reads from the ``LOCAL_RANK`` environment variable (``0`` if it doesn't exist).
+ See `torch.distributed `_ for more details
+
+.. py:data:: skrl.config.rank
+ :type: int
+ :value: 0
+
+ The rank of the worker/process (e.g.: GPU) within a worker group (e.g.: across all nodes)
+
+ This property reads from the ``RANK`` environment variable (``0`` if it doesn't exist).
+ See `torch.distributed `_ for more details
+
+.. py:data:: skrl.config.world_size
+ :type: int
+ :value: 1
+
+    The total number of workers/processes (e.g.: GPUs) in a worker group (e.g.: across all nodes)
+
+ This property reads from the ``WORLD_SIZE`` environment variable (``1`` if it doesn't exist).
+ See `torch.distributed `_ for more details
+
+.. py:data:: skrl.config.is_distributed
+ :type: bool
+ :value: False
+
+    Whether running in a distributed environment
+
+    This property is ``True`` when the PyTorch distributed environment variable ``WORLD_SIZE`` is greater than ``1``
+
+.. raw:: html
+
+
+
JAX
---
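
Taken together, the entries documented above are enough to make a training script distribution-aware. A minimal sketch (attribute names follow the entries above; the rank-0 guard is an illustrative pattern, not skrl API):

    from skrl import config

    device = config.torch.device  # resolves to "cuda:LOCAL_RANK" or "cpu"

    if config.is_distributed:
        print(f"worker {config.rank}/{config.world_size} (local rank {config.local_rank})")

    # restrict logging/checkpointing to the main process in a distributed run
    if config.rank == 0:
        ...  # e.g.: create writers, save checkpoints

Such a run is typically started with a standard PyTorch launcher (e.g.: ``torchrun --nproc_per_node=N script.py``), which sets the ``LOCAL_RANK``, ``RANK`` and ``WORLD_SIZE`` environment variables read above.
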
diff --git a/docs/source/api/envs.rst b/docs/source/api/envs.rst
index 3b3091da..6b32cea7 100644
--- a/docs/source/api/envs.rst
+++ b/docs/source/api/envs.rst
@@ -7,8 +7,8 @@ Environments
Wrapping (single-agent)
Wrapping (multi-agents)
Isaac Gym environments
- Isaac Orbit environments
Omniverse Isaac Gym environments
+ Isaac Lab environments
The environment plays a fundamental and crucial role in defining the RL setup. It is the place where the agent interacts, and it is responsible for providing the agent with information about its current state, as well as the rewards/penalties associated with each action.
@@ -16,7 +16,7 @@ The environment plays a fundamental and crucial role in defining the RL setup. I
-Grouped in this section you will find how to load environments from NVIDIA Isaac Gym, Isaac Orbit and Omniverse Isaac Gym with a simple function.
+Grouped in this section you will find how to load environments from NVIDIA Isaac Gym, Omniverse Isaac Gym and Isaac Lab with a simple function.
In addition, you will be able to :doc:`wrap single-agent ` and :doc:`multi-agent ` RL environment interfaces.
@@ -29,10 +29,10 @@ In addition, you will be able to :doc:`wrap single-agent ` and :d
* - :doc:`Isaac Gym environments `
- .. centered:: :math:`\blacksquare`
- .. centered:: :math:`\blacksquare`
- * - :doc:`Isaac Orbit environments `
+ * - :doc:`Omniverse Isaac Gym environments `
- .. centered:: :math:`\blacksquare`
- .. centered:: :math:`\blacksquare`
- * - :doc:`Omniverse Isaac Gym environments `
+ * - :doc:`Isaac Lab environments `
- .. centered:: :math:`\blacksquare`
- .. centered:: :math:`\blacksquare`
@@ -57,10 +57,10 @@ In addition, you will be able to :doc:`wrap single-agent ` and :d
* - Isaac Gym (previews)
- .. centered:: :math:`\blacksquare`
- .. centered:: :math:`\blacksquare`
- * - Isaac Orbit
+ * - Omniverse Isaac Gym |_5| |_5| |_5| |_5| |_2|
- .. centered:: :math:`\blacksquare`
- .. centered:: :math:`\blacksquare`
- * - Omniverse Isaac Gym |_5| |_5| |_5| |_5| |_2|
+ * - Isaac Lab
- .. centered:: :math:`\blacksquare`
- .. centered:: :math:`\blacksquare`
* - PettingZoo
diff --git a/docs/source/api/envs/isaac_orbit.rst b/docs/source/api/envs/isaac_orbit.rst
deleted file mode 100644
index b7b603dd..00000000
--- a/docs/source/api/envs/isaac_orbit.rst
+++ /dev/null
@@ -1,90 +0,0 @@
-Isaac Orbit environments
-========================
-
-.. image:: ../../_static/imgs/example_isaac_orbit.png
- :width: 100%
- :align: center
- :alt: Isaac Orbit environments
-
-.. raw:: html
-
-
-
-Environments
-------------
-
-The repository https://github.com/NVIDIA-Omniverse/Orbit provides the example reinforcement learning environments for Isaac orbit.
-
-These environments can be easily loaded and configured by calling a single function provided with this library. This function also makes it possible to configure the environment from the command line arguments (see Isaac Orbit's `Running an RL environment `_) or from its parameters (:literal:`task_name`, :literal:`num_envs`, :literal:`headless`, and :literal:`cli_args`).
-
-.. note::
-
- The command line arguments has priority over the function parameters.
-
-.. note::
-
- Isaac Orbit environments implement a functionality to get their configuration from the command line. Setting the :literal:`headless` option from the trainer configuration will not work. In this case, it is necessary to set the load function's :literal:`headless` argument to True or to invoke the scripts as follows: :literal:`orbit -p script.py --headless`.
-
-.. raw:: html
-
-
-
-Usage
-^^^^^
-
-.. tabs::
-
- .. tab:: Function parameters
-
- .. tabs::
-
- .. group-tab:: |_4| |pytorch| |_4|
-
- .. literalinclude:: ../../snippets/loaders.py
- :language: python
- :emphasize-lines: 2, 5
- :start-after: [start-isaac-orbit-envs-parameters-torch]
- :end-before: [end-isaac-orbit-envs-parameters-torch]
-
- .. group-tab:: |_4| |jax| |_4|
-
- .. literalinclude:: ../../snippets/loaders.py
- :language: python
- :emphasize-lines: 2, 5
- :start-after: [start-isaac-orbit-envs-parameters-jax]
- :end-before: [end-isaac-orbit-envs-parameters-jax]
-
- .. tab:: Command line arguments (priority)
-
- .. tabs::
-
- .. group-tab:: |_4| |pytorch| |_4|
-
- .. literalinclude:: ../../snippets/loaders.py
- :language: python
- :emphasize-lines: 2, 5
- :start-after: [start-isaac-orbit-envs-cli-torch]
- :end-before: [end-isaac-orbit-envs-cli-torch]
-
- .. group-tab:: |_4| |jax| |_4|
-
- .. literalinclude:: ../../snippets/loaders.py
- :language: python
- :emphasize-lines: 2, 5
- :start-after: [start-isaac-orbit-envs-cli-jax]
- :end-before: [end-isaac-orbit-envs-cli-jax]
-
- Run the main script passing the configuration as command line arguments. For example:
-
- .. code-block::
-
- orbit -p main.py --task Isaac-Cartpole-v0
-
-.. raw:: html
-
-
-
-API
-^^^
-
-.. autofunction:: skrl.envs.loaders.torch.load_isaac_orbit_env
diff --git a/docs/source/api/envs/isaaclab.rst b/docs/source/api/envs/isaaclab.rst
new file mode 100644
index 00000000..9777b4a0
--- /dev/null
+++ b/docs/source/api/envs/isaaclab.rst
@@ -0,0 +1,90 @@
+Isaac Lab environments
+======================
+
+.. image:: ../../_static/imgs/example_isaaclab.png
+ :width: 100%
+ :align: center
+ :alt: Isaac Lab environments
+
+.. raw:: html
+
+
+
+Environments
+------------
+
+The repository https://github.com/isaac-sim/IsaacLab provides example reinforcement learning environments for Isaac Lab (the unification of Orbit and Omniverse Isaac Gym).
+
+These environments can be easily loaded and configured by calling a single function provided with this library. This function also makes it possible to configure the environment from the command line arguments (see Isaac Lab's `Training with an RL Agent `_) or from its parameters (:literal:`task_name`, :literal:`num_envs`, :literal:`headless`, and :literal:`cli_args`).
+
+.. note::
+
+    The command line arguments have priority over the function parameters.
+
+.. note::
+
+    Isaac Lab environments implement functionality to get their configuration from the command line. Because of this, setting the :literal:`headless` option from the trainer configuration will not work. Instead, set the load function's :literal:`headless` argument to True or invoke the scripts as follows: :literal:`isaaclab -p script.py --headless`.
+
+.. raw:: html
+
+
+
+Usage
+^^^^^
+
+.. tabs::
+
+ .. tab:: Function parameters
+
+ .. tabs::
+
+ .. group-tab:: |_4| |pytorch| |_4|
+
+ .. literalinclude:: ../../snippets/loaders.py
+ :language: python
+ :emphasize-lines: 2, 5
+ :start-after: [start-isaaclab-envs-parameters-torch]
+ :end-before: [end-isaaclab-envs-parameters-torch]
+
+ .. group-tab:: |_4| |jax| |_4|
+
+ .. literalinclude:: ../../snippets/loaders.py
+ :language: python
+ :emphasize-lines: 2, 5
+ :start-after: [start-isaaclab-envs-parameters-jax]
+ :end-before: [end-isaaclab-envs-parameters-jax]
+
+ .. tab:: Command line arguments (priority)
+
+ .. tabs::
+
+ .. group-tab:: |_4| |pytorch| |_4|
+
+ .. literalinclude:: ../../snippets/loaders.py
+ :language: python
+ :emphasize-lines: 2, 5
+ :start-after: [start-isaaclab-envs-cli-torch]
+ :end-before: [end-isaaclab-envs-cli-torch]
+
+ .. group-tab:: |_4| |jax| |_4|
+
+ .. literalinclude:: ../../snippets/loaders.py
+ :language: python
+ :emphasize-lines: 2, 5
+ :start-after: [start-isaaclab-envs-cli-jax]
+ :end-before: [end-isaaclab-envs-cli-jax]
+
+ Run the main script passing the configuration as command line arguments. For example:
+
+ .. code-block::
+
+ isaaclab -p main.py --task Isaac-Cartpole-v0
+
+.. raw:: html
+
+
+
+API
+^^^
+
+.. autofunction:: skrl.envs.loaders.torch.load_isaaclab_env
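
As a usage reference, loading and wrapping an Isaac Lab task with the function documented above might look as follows (task name and parameter values are illustrative):

    from skrl.envs.loaders.torch import load_isaaclab_env
    from skrl.envs.wrappers.torch import wrap_env

    # function parameters; command line arguments, if present, take priority
    env = load_isaaclab_env(task_name="Isaac-Cartpole-v0", num_envs=64, headless=True)
    env = wrap_env(env)
    device = env.device

Equivalently, the configuration can be passed on the command line, e.g.: ``isaaclab -p main.py --task Isaac-Cartpole-v0``.
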
diff --git a/docs/source/api/envs/wrapping.rst b/docs/source/api/envs/wrapping.rst
index a725a7e9..8fc99ddb 100644
--- a/docs/source/api/envs/wrapping.rst
+++ b/docs/source/api/envs/wrapping.rst
@@ -14,8 +14,8 @@ This library works with a common API to interact with the following RL environme
* `DeepMind `_
* `robosuite `_
* `NVIDIA Isaac Gym `_ (preview 2, 3 and 4)
-* `NVIDIA Isaac Orbit `_
* `NVIDIA Omniverse Isaac Gym `_
+* `NVIDIA Isaac Lab `_
To operate with them and to support interoperability between these non-compatible interfaces, a **wrapping mechanism is provided** as shown in the diagram below
@@ -44,6 +44,24 @@ Usage
.. tabs::
+ .. tab:: Isaac Lab
+
+ .. tabs::
+
+ .. group-tab:: |_4| |pytorch| |_4|
+
+ .. literalinclude:: ../../snippets/wrapping.py
+ :language: python
+ :start-after: [pytorch-start-isaaclab]
+ :end-before: [pytorch-end-isaaclab]
+
+ .. group-tab:: |_4| |jax| |_4|
+
+ .. literalinclude:: ../../snippets/wrapping.py
+ :language: python
+ :start-after: [jax-start-isaaclab]
+ :end-before: [jax-end-isaaclab]
+
.. tab:: Omniverse Isaac Gym
.. tabs::
@@ -84,24 +102,6 @@ Usage
:start-after: [jax-start-omniverse-isaacgym-mt]
:end-before: [jax-end-omniverse-isaacgym-mt]
- .. tab:: Isaac Orbit
-
- .. tabs::
-
- .. group-tab:: |_4| |pytorch| |_4|
-
- .. literalinclude:: ../../snippets/wrapping.py
- :language: python
- :start-after: [pytorch-start-isaac-orbit]
- :end-before: [pytorch-end-isaac-orbit]
-
- .. group-tab:: |_4| |jax| |_4|
-
- .. literalinclude:: ../../snippets/wrapping.py
- :language: python
- :start-after: [jax-start-isaac-orbit]
- :end-before: [jax-end-isaac-orbit]
-
.. tab:: Isaac Gym
.. tabs::
@@ -365,7 +365,7 @@ Internal API (PyTorch)
.. automethod:: __init__
-.. autoclass:: skrl.envs.wrappers.torch.IsaacOrbitWrapper
+.. autoclass:: skrl.envs.wrappers.torch.IsaacLabWrapper
:undoc-members:
:show-inheritance:
:members:
@@ -443,7 +443,7 @@ Internal API (JAX)
.. automethod:: __init__
-.. autoclass:: skrl.envs.wrappers.jax.IsaacOrbitWrapper
+.. autoclass:: skrl.envs.wrappers.jax.IsaacLabWrapper
:undoc-members:
:show-inheritance:
:members:
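
After wrapping, the environment exposes the library's common API regardless of the original interface. A hedged sketch of a rollout loop, assuming the gymnasium-style ``reset``/``step`` signature used by the wrappers (treat the exact return tuple as an assumption and check the wrapper class references above):

    import torch

    from skrl.envs.loaders.torch import load_isaaclab_env
    from skrl.envs.wrappers.torch import wrap_env

    env = wrap_env(load_isaaclab_env(task_name="Isaac-Cartpole-v0", num_envs=8, headless=True))

    states, infos = env.reset()
    for _ in range(100):
        # zero actions, just to exercise the common API
        actions = torch.zeros((env.num_envs, env.action_space.shape[0]), device=env.device)
        states, rewards, terminated, truncated, infos = env.step(actions)
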
diff --git a/docs/source/api/models/shared_model.rst b/docs/source/api/models/shared_model.rst
index f2dca04b..6b4ebf2d 100644
--- a/docs/source/api/models/shared_model.rst
+++ b/docs/source/api/models/shared_model.rst
@@ -7,7 +7,7 @@ Sometimes it is desirable to define models that use shared layers or network to
* Reduce the number of parameters in the whole system.
-* Make the computation more efficient.
+* Make the computation more efficient (single forward-pass).
.. raw:: html
@@ -42,7 +42,24 @@ The code snippet below shows how to define a shared model. The following practic
.. group-tab:: |_4| |pytorch| |_4|
- .. literalinclude:: ../../snippets/shared_model.py
- :language: python
- :start-after: [start-mlp-torch]
- :end-before: [end-mlp-torch]
+ .. tabs::
+
+ .. group-tab:: Single forward-pass
+
+ .. warning::
+
+               The single forward-pass implementation described here requires that the value-pass always follows the policy-pass (e.g.: ``PPO``), which may not generalize to other algorithms.
+
+               If this requirement is not met, other forms of caching the shared layers/network output could be implemented.
+
+ .. literalinclude:: ../../snippets/shared_model.py
+ :language: python
+ :start-after: [start-mlp-single-forward-pass-torch]
+ :end-before: [end-mlp-single-forward-pass-torch]
+
+ .. group-tab:: Multiple forward-pass
+
+ .. literalinclude:: ../../snippets/shared_model.py
+ :language: python
+ :start-after: [start-mlp-multi-forward-pass-torch]
+ :end-before: [end-mlp-multi-forward-pass-torch]
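
The pattern enabled by the new "Single forward-pass" tab is the same one applied to the example scripts later in this changeset: the policy pass caches the backbone output and the value pass reuses it, so the shared layers run once per policy/value evaluation pair. A condensed sketch (layer sizes and mixin arguments are illustrative):

    import torch
    import torch.nn as nn

    from skrl.models.torch import DeterministicMixin, GaussianMixin, Model

    class SharedModel(GaussianMixin, DeterministicMixin, Model):
        def __init__(self, observation_space, action_space, device):
            Model.__init__(self, observation_space, action_space, device)
            GaussianMixin.__init__(self, clip_actions=False, role="policy")
            DeterministicMixin.__init__(self, clip_actions=False, role="value")

            self.net = nn.Sequential(nn.Linear(self.num_observations, 64), nn.ELU(),
                                     nn.Linear(64, 64), nn.ELU())
            self.mean_layer = nn.Linear(64, self.num_actions)
            self.value_layer = nn.Linear(64, 1)
            self.log_std_parameter = nn.Parameter(torch.zeros(self.num_actions))
            self._shared_output = None

        def act(self, inputs, role):
            if role == "policy":
                return GaussianMixin.act(self, inputs, role)
            elif role == "value":
                return DeterministicMixin.act(self, inputs, role)

        def compute(self, inputs, role):
            if role == "policy":
                # run the shared layers once and cache the result
                self._shared_output = self.net(inputs["states"])
                return self.mean_layer(self._shared_output), self.log_std_parameter, {}
            elif role == "value":
                # reuse the cached output if the policy pass ran first, then clear it
                shared_output = self.net(inputs["states"]) if self._shared_output is None else self._shared_output
                self._shared_output = None
                return self.value_layer(shared_output), {}
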
diff --git a/docs/source/api/multi_agents/ippo.rst b/docs/source/api/multi_agents/ippo.rst
index 9a259326..862fd2b9 100644
--- a/docs/source/api/multi_agents/ippo.rst
+++ b/docs/source/api/multi_agents/ippo.rst
@@ -239,6 +239,10 @@ Support for advanced features is described in the next table
- \-
- .. centered:: :math:`\square`
- .. centered:: :math:`\square`
+ * - Distributed
+ - Single Program Multi Data (SPMD) multi-GPU
+ - .. centered:: :math:`\blacksquare`
+ - .. centered:: :math:`\square`
.. raw:: html
diff --git a/docs/source/api/multi_agents/mappo.rst b/docs/source/api/multi_agents/mappo.rst
index c875ac6a..7ae18de3 100644
--- a/docs/source/api/multi_agents/mappo.rst
+++ b/docs/source/api/multi_agents/mappo.rst
@@ -240,6 +240,10 @@ Support for advanced features is described in the next table
- \-
- .. centered:: :math:`\square`
- .. centered:: :math:`\square`
+ * - Distributed
+ - Single Program Multi Data (SPMD) multi-GPU
+ - .. centered:: :math:`\blacksquare`
+ - .. centered:: :math:`\square`
.. raw:: html
diff --git a/docs/source/api/utils.rst b/docs/source/api/utils.rst
index a86fb2da..ff050aaa 100644
--- a/docs/source/api/utils.rst
+++ b/docs/source/api/utils.rst
@@ -25,7 +25,7 @@ A set of utilities and configurations for managing an RL setup is provided as pa
- .. centered:: |_4| |pytorch| |_4|
- .. centered:: |_4| |jax| |_4|
* - :doc:`ML frameworks ` configuration |_5| |_5| |_5| |_5| |_5| |_2|
- - .. centered:: :math:`\square`
+ - .. centered:: :math:`\blacksquare`
- .. centered:: :math:`\blacksquare`
.. list-table::
diff --git a/docs/source/api/utils/omniverse_isaacgym_utils.rst b/docs/source/api/utils/omniverse_isaacgym_utils.rst
index a5bfe412..57167cca 100644
--- a/docs/source/api/utils/omniverse_isaacgym_utils.rst
+++ b/docs/source/api/utils/omniverse_isaacgym_utils.rst
@@ -17,7 +17,7 @@ Control of robotic manipulators
Differential inverse kinematics
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-This implementation attempts to unify under a single and reusable function the whole set of procedures used to compute the inverse kinematics of a robotic manipulator, originally shown in the Isaac Orbit framework's task space controllers section, but this time for Omniverse Isaac Gym.
+This implementation attempts to unify, under a single reusable function, the whole set of procedures used to compute the inverse kinematics of a robotic manipulator, originally shown in the task space controllers section of the Isaac Lab framework (formerly Orbit), but this time for Omniverse Isaac Gym.
:math:`\Delta\theta =` :guilabel:`scale` :math:`J^\dagger \, \vec{e}`
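
Numerically, the update above multiplies the task-space pose error by a (typically damped) pseudoinverse of the end-effector Jacobian. A minimal sketch, independent of the library's own helper, with illustrative shapes and damping:

    import torch

    def differential_ik_step(jacobian: torch.Tensor, pose_error: torch.Tensor,
                             scale: float = 1.0, damping: float = 0.05) -> torch.Tensor:
        # damped least-squares pseudoinverse: J^T (J J^T + lambda^2 I)^-1
        identity = torch.eye(jacobian.shape[-2], device=jacobian.device)
        j_pinv = jacobian.transpose(-2, -1) @ torch.inverse(
            jacobian @ jacobian.transpose(-2, -1) + (damping ** 2) * identity)
        # delta_theta = scale * J^dagger @ e
        return scale * (j_pinv @ pose_error)

For a 6-dimensional task-space error and an :math:`n`-joint arm, ``jacobian`` has shape ``(..., 6, n)`` and ``pose_error`` has shape ``(..., 6, 1)``.
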
diff --git a/docs/source/conf.py b/docs/source/conf.py
index f04c7874..42df4615 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -16,7 +16,7 @@
if skrl.__version__ != "unknown":
release = version = skrl.__version__
else:
- release = version = "1.1.0"
+ release = version = "1.2.0"
master_doc = "index"
diff --git a/docs/source/examples/isaacgym/torch_allegro_hand_ppo.py b/docs/source/examples/isaacgym/torch_allegro_hand_ppo.py
index 2e2b950d..2d68985d 100644
--- a/docs/source/examples/isaacgym/torch_allegro_hand_ppo.py
+++ b/docs/source/examples/isaacgym/torch_allegro_hand_ppo.py
@@ -47,9 +47,12 @@ def act(self, inputs, role):
def compute(self, inputs, role):
if role == "policy":
- return self.mean_layer(self.net(inputs["states"])), self.log_std_parameter, {}
+ self._shared_output = self.net(inputs["states"])
+ return self.mean_layer(self._shared_output), self.log_std_parameter, {}
elif role == "value":
- return self.value_layer(self.net(inputs["states"])), {}
+ shared_output = self.net(inputs["states"]) if self._shared_output is None else self._shared_output
+ self._shared_output = None
+ return self.value_layer(shared_output), {}
# load and wrap the Isaac Gym environment using the easy-to-use API from NVIDIA
diff --git a/docs/source/examples/isaacgym/torch_allegro_kuka_ppo.py b/docs/source/examples/isaacgym/torch_allegro_kuka_ppo.py
index 16f1fcbe..121e9e76 100644
--- a/docs/source/examples/isaacgym/torch_allegro_kuka_ppo.py
+++ b/docs/source/examples/isaacgym/torch_allegro_kuka_ppo.py
@@ -49,9 +49,12 @@ def act(self, inputs, role):
def compute(self, inputs, role):
if role == "policy":
- return self.mean_layer(self.net(inputs["states"])), self.log_std_parameter, {}
+ self._shared_output = self.net(inputs["states"])
+ return self.mean_layer(self._shared_output), self.log_std_parameter, {}
elif role == "value":
- return self.value_layer(self.net(inputs["states"])), {}
+ shared_output = self.net(inputs["states"]) if self._shared_output is None else self._shared_output
+ self._shared_output = None
+ return self.value_layer(shared_output), {}
# load and wrap the Isaac Gym environment using the easy-to-use API from NVIDIA
diff --git a/docs/source/examples/isaacgym/torch_ant_ppo.py b/docs/source/examples/isaacgym/torch_ant_ppo.py
index c3678f40..2d5f9e12 100644
--- a/docs/source/examples/isaacgym/torch_ant_ppo.py
+++ b/docs/source/examples/isaacgym/torch_ant_ppo.py
@@ -47,9 +47,12 @@ def act(self, inputs, role):
def compute(self, inputs, role):
if role == "policy":
- return self.mean_layer(self.net(inputs["states"])), self.log_std_parameter, {}
+ self._shared_output = self.net(inputs["states"])
+ return self.mean_layer(self._shared_output), self.log_std_parameter, {}
elif role == "value":
- return self.value_layer(self.net(inputs["states"])), {}
+ shared_output = self.net(inputs["states"]) if self._shared_output is None else self._shared_output
+ self._shared_output = None
+ return self.value_layer(shared_output), {}
# load and wrap the Isaac Gym environment
diff --git a/docs/source/examples/isaacgym/torch_anymal_ppo.py b/docs/source/examples/isaacgym/torch_anymal_ppo.py
index 7483bcff..a6a464c4 100644
--- a/docs/source/examples/isaacgym/torch_anymal_ppo.py
+++ b/docs/source/examples/isaacgym/torch_anymal_ppo.py
@@ -47,9 +47,12 @@ def act(self, inputs, role):
def compute(self, inputs, role):
if role == "policy":
- return self.mean_layer(self.net(inputs["states"])), self.log_std_parameter, {}
+ self._shared_output = self.net(inputs["states"])
+ return self.mean_layer(self._shared_output), self.log_std_parameter, {}
elif role == "value":
- return self.value_layer(self.net(inputs["states"])), {}
+ shared_output = self.net(inputs["states"]) if self._shared_output is None else self._shared_output
+ self._shared_output = None
+ return self.value_layer(shared_output), {}
# load and wrap the Isaac Gym environment
diff --git a/docs/source/examples/isaacgym/torch_ball_balance_ppo.py b/docs/source/examples/isaacgym/torch_ball_balance_ppo.py
index af67c206..3eee7e90 100644
--- a/docs/source/examples/isaacgym/torch_ball_balance_ppo.py
+++ b/docs/source/examples/isaacgym/torch_ball_balance_ppo.py
@@ -47,9 +47,12 @@ def act(self, inputs, role):
def compute(self, inputs, role):
if role == "policy":
- return self.mean_layer(self.net(inputs["states"])), self.log_std_parameter, {}
+ self._shared_output = self.net(inputs["states"])
+ return self.mean_layer(self._shared_output), self.log_std_parameter, {}
elif role == "value":
- return self.value_layer(self.net(inputs["states"])), {}
+ shared_output = self.net(inputs["states"]) if self._shared_output is None else self._shared_output
+ self._shared_output = None
+ return self.value_layer(shared_output), {}
# load and wrap the Isaac Gym environment
diff --git a/docs/source/examples/isaacgym/torch_cartpole_ppo.py b/docs/source/examples/isaacgym/torch_cartpole_ppo.py
index 65c104d1..42898a00 100644
--- a/docs/source/examples/isaacgym/torch_cartpole_ppo.py
+++ b/docs/source/examples/isaacgym/torch_cartpole_ppo.py
@@ -45,9 +45,12 @@ def act(self, inputs, role):
def compute(self, inputs, role):
if role == "policy":
- return self.mean_layer(self.net(inputs["states"])), self.log_std_parameter, {}
+ self._shared_output = self.net(inputs["states"])
+ return self.mean_layer(self._shared_output), self.log_std_parameter, {}
elif role == "value":
- return self.value_layer(self.net(inputs["states"])), {}
+ shared_output = self.net(inputs["states"]) if self._shared_output is None else self._shared_output
+ self._shared_output = None
+ return self.value_layer(shared_output), {}
# load and wrap the Isaac Gym environment
diff --git a/docs/source/examples/isaacgym/torch_factory_task_nut_bolt_pick_ppo.py b/docs/source/examples/isaacgym/torch_factory_task_nut_bolt_pick_ppo.py
index 2d34a390..6e150ac2 100644
--- a/docs/source/examples/isaacgym/torch_factory_task_nut_bolt_pick_ppo.py
+++ b/docs/source/examples/isaacgym/torch_factory_task_nut_bolt_pick_ppo.py
@@ -47,9 +47,12 @@ def act(self, inputs, role):
def compute(self, inputs, role):
if role == "policy":
- return self.mean_layer(self.net(inputs["states"])), self.log_std_parameter, {}
+ self._shared_output = self.net(inputs["states"])
+ return self.mean_layer(self._shared_output), self.log_std_parameter, {}
elif role == "value":
- return self.value_layer(self.net(inputs["states"])), {}
+ shared_output = self.net(inputs["states"]) if self._shared_output is None else self._shared_output
+ self._shared_output = None
+ return self.value_layer(shared_output), {}
# load and wrap the Isaac Gym environment
diff --git a/docs/source/examples/isaacgym/torch_factory_task_nut_bolt_place_ppo.py b/docs/source/examples/isaacgym/torch_factory_task_nut_bolt_place_ppo.py
index ef9674f9..d39aeb21 100644
--- a/docs/source/examples/isaacgym/torch_factory_task_nut_bolt_place_ppo.py
+++ b/docs/source/examples/isaacgym/torch_factory_task_nut_bolt_place_ppo.py
@@ -47,9 +47,12 @@ def act(self, inputs, role):
def compute(self, inputs, role):
if role == "policy":
- return self.mean_layer(self.net(inputs["states"])), self.log_std_parameter, {}
+ self._shared_output = self.net(inputs["states"])
+ return self.mean_layer(self._shared_output), self.log_std_parameter, {}
elif role == "value":
- return self.value_layer(self.net(inputs["states"])), {}
+ shared_output = self.net(inputs["states"]) if self._shared_output is None else self._shared_output
+ self._shared_output = None
+ return self.value_layer(shared_output), {}
# load and wrap the Isaac Gym environment
diff --git a/docs/source/examples/isaacgym/torch_factory_task_nut_bolt_screw_ppo.py b/docs/source/examples/isaacgym/torch_factory_task_nut_bolt_screw_ppo.py
index 151db1d9..4b2e3b30 100644
--- a/docs/source/examples/isaacgym/torch_factory_task_nut_bolt_screw_ppo.py
+++ b/docs/source/examples/isaacgym/torch_factory_task_nut_bolt_screw_ppo.py
@@ -47,9 +47,12 @@ def act(self, inputs, role):
def compute(self, inputs, role):
if role == "policy":
- return self.mean_layer(self.net(inputs["states"])), self.log_std_parameter, {}
+ self._shared_output = self.net(inputs["states"])
+ return self.mean_layer(self._shared_output), self.log_std_parameter, {}
elif role == "value":
- return self.value_layer(self.net(inputs["states"])), {}
+ shared_output = self.net(inputs["states"]) if self._shared_output is None else self._shared_output
+ self._shared_output = None
+ return self.value_layer(shared_output), {}
# load and wrap the Isaac Gym environment
diff --git a/docs/source/examples/isaacgym/torch_franka_cabinet_ppo.py b/docs/source/examples/isaacgym/torch_franka_cabinet_ppo.py
index ad1dff1b..a71b39e7 100644
--- a/docs/source/examples/isaacgym/torch_franka_cabinet_ppo.py
+++ b/docs/source/examples/isaacgym/torch_franka_cabinet_ppo.py
@@ -47,9 +47,12 @@ def act(self, inputs, role):
def compute(self, inputs, role):
if role == "policy":
- return self.mean_layer(self.net(inputs["states"])), self.log_std_parameter, {}
+ self._shared_output = self.net(inputs["states"])
+ return self.mean_layer(self._shared_output), self.log_std_parameter, {}
elif role == "value":
- return self.value_layer(self.net(inputs["states"])), {}
+ shared_output = self.net(inputs["states"]) if self._shared_output is None else self._shared_output
+ self._shared_output = None
+ return self.value_layer(shared_output), {}
# load and wrap the Isaac Gym environment
diff --git a/docs/source/examples/isaacgym/torch_franka_cube_stack_ppo.py b/docs/source/examples/isaacgym/torch_franka_cube_stack_ppo.py
index bfdaf5ca..1d89d852 100644
--- a/docs/source/examples/isaacgym/torch_franka_cube_stack_ppo.py
+++ b/docs/source/examples/isaacgym/torch_franka_cube_stack_ppo.py
@@ -47,9 +47,12 @@ def act(self, inputs, role):
def compute(self, inputs, role):
if role == "policy":
- return self.mean_layer(self.net(inputs["states"])), self.log_std_parameter, {}
+ self._shared_output = self.net(inputs["states"])
+ return self.mean_layer(self._shared_output), self.log_std_parameter, {}
elif role == "value":
- return self.value_layer(self.net(inputs["states"])), {}
+ shared_output = self.net(inputs["states"]) if self._shared_output is None else self._shared_output
+ self._shared_output = None
+ return self.value_layer(shared_output), {}
# load and wrap the Isaac Gym environment
diff --git a/docs/source/examples/isaacgym/torch_humanoid_ppo.py b/docs/source/examples/isaacgym/torch_humanoid_ppo.py
index 65779b8a..ed85ef2f 100644
--- a/docs/source/examples/isaacgym/torch_humanoid_ppo.py
+++ b/docs/source/examples/isaacgym/torch_humanoid_ppo.py
@@ -47,9 +47,12 @@ def act(self, inputs, role):
def compute(self, inputs, role):
if role == "policy":
- return self.mean_layer(self.net(inputs["states"])), self.log_std_parameter, {}
+ self._shared_output = self.net(inputs["states"])
+ return self.mean_layer(self._shared_output), self.log_std_parameter, {}
elif role == "value":
- return self.value_layer(self.net(inputs["states"])), {}
+ shared_output = self.net(inputs["states"]) if self._shared_output is None else self._shared_output
+ self._shared_output = None
+ return self.value_layer(shared_output), {}
# load and wrap the Isaac Gym environment
diff --git a/docs/source/examples/isaacgym/torch_ingenuity_ppo.py b/docs/source/examples/isaacgym/torch_ingenuity_ppo.py
index 43289554..9c99fbdc 100644
--- a/docs/source/examples/isaacgym/torch_ingenuity_ppo.py
+++ b/docs/source/examples/isaacgym/torch_ingenuity_ppo.py
@@ -47,9 +47,12 @@ def act(self, inputs, role):
def compute(self, inputs, role):
if role == "policy":
- return self.mean_layer(self.net(inputs["states"])), self.log_std_parameter, {}
+ self._shared_output = self.net(inputs["states"])
+ return self.mean_layer(self._shared_output), self.log_std_parameter, {}
elif role == "value":
- return self.value_layer(self.net(inputs["states"])), {}
+ shared_output = self.net(inputs["states"]) if self._shared_output is None else self._shared_output
+ self._shared_output = None
+ return self.value_layer(shared_output), {}
# load and wrap the Isaac Gym environment using the easy-to-use API from NVIDIA
diff --git a/docs/source/examples/isaacgym/torch_quadcopter_ppo.py b/docs/source/examples/isaacgym/torch_quadcopter_ppo.py
index 118f6ce8..0e8018cd 100644
--- a/docs/source/examples/isaacgym/torch_quadcopter_ppo.py
+++ b/docs/source/examples/isaacgym/torch_quadcopter_ppo.py
@@ -47,9 +47,12 @@ def act(self, inputs, role):
def compute(self, inputs, role):
if role == "policy":
- return self.mean_layer(self.net(inputs["states"])), self.log_std_parameter, {}
+ self._shared_output = self.net(inputs["states"])
+ return self.mean_layer(self._shared_output), self.log_std_parameter, {}
elif role == "value":
- return self.value_layer(self.net(inputs["states"])), {}
+ shared_output = self.net(inputs["states"]) if self._shared_output is None else self._shared_output
+ self._shared_output = None
+ return self.value_layer(shared_output), {}
# load and wrap the Isaac Gym environment
diff --git a/docs/source/examples/isaacgym/torch_shadow_hand_ppo.py b/docs/source/examples/isaacgym/torch_shadow_hand_ppo.py
index 08d12979..f868a71d 100644
--- a/docs/source/examples/isaacgym/torch_shadow_hand_ppo.py
+++ b/docs/source/examples/isaacgym/torch_shadow_hand_ppo.py
@@ -49,9 +49,12 @@ def act(self, inputs, role):
def compute(self, inputs, role):
if role == "policy":
- return self.mean_layer(self.net(inputs["states"])), self.log_std_parameter, {}
+ self._shared_output = self.net(inputs["states"])
+ return self.mean_layer(self._shared_output), self.log_std_parameter, {}
elif role == "value":
- return self.value_layer(self.net(inputs["states"])), {}
+ shared_output = self.net(inputs["states"]) if self._shared_output is None else self._shared_output
+ self._shared_output = None
+ return self.value_layer(shared_output), {}
# load and wrap the Isaac Gym environment
diff --git a/docs/source/examples/isaacgym/torch_trifinger_ppo.py b/docs/source/examples/isaacgym/torch_trifinger_ppo.py
index d8f59ea2..cf313bf6 100644
--- a/docs/source/examples/isaacgym/torch_trifinger_ppo.py
+++ b/docs/source/examples/isaacgym/torch_trifinger_ppo.py
@@ -49,9 +49,12 @@ def act(self, inputs, role):
def compute(self, inputs, role):
if role == "policy":
- return self.mean_layer(self.net(inputs["states"])), self.log_std_parameter, {}
+ self._shared_output = self.net(inputs["states"])
+ return self.mean_layer(self._shared_output), self.log_std_parameter, {}
elif role == "value":
- return self.value_layer(self.net(inputs["states"])), {}
+ shared_output = self.net(inputs["states"]) if self._shared_output is None else self._shared_output
+ self._shared_output = None
+ return self.value_layer(shared_output), {}
# load and wrap the Isaac Gym environment
diff --git a/docs/source/examples/isaacorbit/jax_ant_ddpg.py b/docs/source/examples/isaaclab/jax_ant_ddpg.py
similarity index 78%
rename from docs/source/examples/isaacorbit/jax_ant_ddpg.py
rename to docs/source/examples/isaaclab/jax_ant_ddpg.py
index 25fe1a26..381839bb 100644
--- a/docs/source/examples/isaacorbit/jax_ant_ddpg.py
+++ b/docs/source/examples/isaaclab/jax_ant_ddpg.py
@@ -1,25 +1,11 @@
-"""
-Notes for Isaac Sim 2022.2.1 or earlier (Python 3.7 environment):
- * Python 3.7 is only supported up to jax<=0.3.25.
- See: https://github.com/google/jax/blob/main/CHANGELOG.md#jaxlib-041-dec-13-2022.
- * Builds for jaxlib<=0.3.25 are only available up to NVIDIA CUDA 11 and cuDNN 8.2 versions.
- See: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
- and search for `cuda11/jaxlib-0.3.25+cuda11.cudnn82-cp37-cp37m-manylinux2014_x86_64.whl`.
- * The `jax.Device = jax.xla.Device` statement is required by skrl to support jax<0.4.3.
- * Models require overloading the `__hash__` method to avoid "TypeError: Failed to hash Flax Module".
-"""
-
import flax.linen as nn
import jax
import jax.numpy as jnp
-
-jax.Device = jax.xla.Device # for Isaac Sim 2022.2.1 or earlier
-
# import the skrl components to build the RL system
from skrl import config
from skrl.agents.jax.ddpg import DDPG, DDPG_DEFAULT_CONFIG
-from skrl.envs.loaders.jax import load_isaac_orbit_env
+from skrl.envs.loaders.jax import load_isaaclab_env
from skrl.envs.wrappers.jax import wrap_env
from skrl.memories.jax import RandomMemory
from skrl.models.jax import DeterministicMixin, Model
@@ -42,9 +28,6 @@ def __init__(self, observation_space, action_space, device=None, clip_actions=Fa
Model.__init__(self, observation_space, action_space, device, **kwargs)
DeterministicMixin.__init__(self, clip_actions)
- def __hash__(self): # for Isaac Sim 2022.2.1 or earlier
- return id(self)
-
@nn.compact # marks the given module method allowing inlined submodules
def __call__(self, inputs, role):
x = nn.relu(nn.Dense(512)(inputs["states"]))
@@ -57,9 +40,6 @@ def __init__(self, observation_space, action_space, device=None, clip_actions=Fa
Model.__init__(self, observation_space, action_space, device, **kwargs)
DeterministicMixin.__init__(self, clip_actions)
- def __hash__(self): # for Isaac Sim 2022.2.1 or earlier
- return id(self)
-
@nn.compact # marks the given module method allowing inlined submodules
def __call__(self, inputs, role):
x = jnp.concatenate([inputs["states"], inputs["taken_actions"]], axis=-1)
@@ -69,8 +49,8 @@ def __call__(self, inputs, role):
return x, {}
-# load and wrap the Isaac Orbit environment
-env = load_isaac_orbit_env(task_name="Isaac-Ant-v0", num_envs=64)
+# load and wrap the Isaac Lab environment
+env = load_isaaclab_env(task_name="Isaac-Ant-v0", num_envs=64)
env = wrap_env(env)
device = env.device
diff --git a/docs/source/examples/isaacorbit/jax_ant_ppo.py b/docs/source/examples/isaaclab/jax_ant_ppo.py
similarity index 82%
rename from docs/source/examples/isaacorbit/jax_ant_ppo.py
rename to docs/source/examples/isaaclab/jax_ant_ppo.py
index ba0cc9f3..1f3c62d1 100644
--- a/docs/source/examples/isaacorbit/jax_ant_ppo.py
+++ b/docs/source/examples/isaaclab/jax_ant_ppo.py
@@ -1,25 +1,11 @@
-"""
-Notes for Isaac Sim 2022.2.1 or earlier (Python 3.7 environment):
- * Python 3.7 is only supported up to jax<=0.3.25.
- See: https://github.com/google/jax/blob/main/CHANGELOG.md#jaxlib-041-dec-13-2022.
- * Builds for jaxlib<=0.3.25 are only available up to NVIDIA CUDA 11 and cuDNN 8.2 versions.
- See: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
- and search for `cuda11/jaxlib-0.3.25+cuda11.cudnn82-cp37-cp37m-manylinux2014_x86_64.whl`.
- * The `jax.Device = jax.xla.Device` statement is required by skrl to support jax<0.4.3.
- * Models require overloading the `__hash__` method to avoid "TypeError: Failed to hash Flax Module".
-"""
-
import flax.linen as nn
import jax
import jax.numpy as jnp
-
-jax.Device = jax.xla.Device # for Isaac Sim 2022.2.1 or earlier
-
# import the skrl components to build the RL system
from skrl import config
from skrl.agents.jax.ppo import PPO, PPO_DEFAULT_CONFIG
-from skrl.envs.loaders.jax import load_isaac_orbit_env
+from skrl.envs.loaders.jax import load_isaaclab_env
from skrl.envs.wrappers.jax import wrap_env
from skrl.memories.jax import RandomMemory
from skrl.models.jax import DeterministicMixin, GaussianMixin, Model
@@ -43,9 +29,6 @@ def __init__(self, observation_space, action_space, device=None, clip_actions=Fa
Model.__init__(self, observation_space, action_space, device, **kwargs)
GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std, reduction)
- def __hash__(self): # for Isaac Sim 2022.2.1 or earlier
- return id(self)
-
@nn.compact # marks the given module method allowing inlined submodules
def __call__(self, inputs, role):
x = nn.elu(nn.Dense(256)(inputs["states"]))
@@ -60,9 +43,6 @@ def __init__(self, observation_space, action_space, device=None, clip_actions=Fa
Model.__init__(self, observation_space, action_space, device, **kwargs)
DeterministicMixin.__init__(self, clip_actions)
- def __hash__(self): # for Isaac Sim 2022.2.1 or earlier
- return id(self)
-
@nn.compact # marks the given module method allowing inlined submodules
def __call__(self, inputs, role):
x = nn.elu(nn.Dense(256)(inputs["states"]))
@@ -72,8 +52,8 @@ def __call__(self, inputs, role):
return x, {}
-# load and wrap the Isaac Orbit environment
-env = load_isaac_orbit_env(task_name="Isaac-Ant-v0")
+# load and wrap the Isaac Lab environment
+env = load_isaaclab_env(task_name="Isaac-Ant-v0")
env = wrap_env(env)
device = env.device
diff --git a/docs/source/examples/isaacorbit/jax_ant_sac.py b/docs/source/examples/isaaclab/jax_ant_sac.py
similarity index 79%
rename from docs/source/examples/isaacorbit/jax_ant_sac.py
rename to docs/source/examples/isaaclab/jax_ant_sac.py
index 0b87c430..ca4889d7 100644
--- a/docs/source/examples/isaacorbit/jax_ant_sac.py
+++ b/docs/source/examples/isaaclab/jax_ant_sac.py
@@ -1,25 +1,11 @@
-"""
-Notes for Isaac Sim 2022.2.1 or earlier (Python 3.7 environment):
- * Python 3.7 is only supported up to jax<=0.3.25.
- See: https://github.com/google/jax/blob/main/CHANGELOG.md#jaxlib-041-dec-13-2022.
- * Builds for jaxlib<=0.3.25 are only available up to NVIDIA CUDA 11 and cuDNN 8.2 versions.
- See: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
- and search for `cuda11/jaxlib-0.3.25+cuda11.cudnn82-cp37-cp37m-manylinux2014_x86_64.whl`.
- * The `jax.Device = jax.xla.Device` statement is required by skrl to support jax<0.4.3.
- * Models require overloading the `__hash__` method to avoid "TypeError: Failed to hash Flax Module".
-"""
-
import flax.linen as nn
import jax
import jax.numpy as jnp
-
-jax.Device = jax.xla.Device # for Isaac Sim 2022.2.1 or earlier
-
# import the skrl components to build the RL system
from skrl import config
from skrl.agents.jax.sac import SAC, SAC_DEFAULT_CONFIG
-from skrl.envs.loaders.jax import load_isaac_orbit_env
+from skrl.envs.loaders.jax import load_isaaclab_env
from skrl.envs.wrappers.jax import wrap_env
from skrl.memories.jax import RandomMemory
from skrl.models.jax import DeterministicMixin, GaussianMixin, Model
@@ -42,9 +28,6 @@ def __init__(self, observation_space, action_space, device=None, clip_actions=Fa
Model.__init__(self, observation_space, action_space, device, **kwargs)
GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std, reduction)
- def __hash__(self): # for Isaac Sim 2022.2.1 or earlier
- return id(self)
-
@nn.compact # marks the given module method allowing inlined submodules
def __call__(self, inputs, role):
x = nn.relu(nn.Dense(512)(inputs["states"]))
@@ -58,9 +41,6 @@ def __init__(self, observation_space, action_space, device=None, clip_actions=Fa
Model.__init__(self, observation_space, action_space, device, **kwargs)
DeterministicMixin.__init__(self, clip_actions)
- def __hash__(self): # for Isaac Sim 2022.2.1 or earlier
- return id(self)
-
@nn.compact # marks the given module method allowing inlined submodules
def __call__(self, inputs, role):
x = jnp.concatenate([inputs["states"], inputs["taken_actions"]], axis=-1)
@@ -70,8 +50,8 @@ def __call__(self, inputs, role):
return x, {}
-# load and wrap the Isaac Orbit environment
-env = load_isaac_orbit_env(task_name="Isaac-Ant-v0", num_envs=64)
+# load and wrap the Isaac Lab environment
+env = load_isaaclab_env(task_name="Isaac-Ant-v0", num_envs=64)
env = wrap_env(env)
device = env.device
diff --git a/docs/source/examples/isaacorbit/jax_ant_td3.py b/docs/source/examples/isaaclab/jax_ant_td3.py
similarity index 79%
rename from docs/source/examples/isaacorbit/jax_ant_td3.py
rename to docs/source/examples/isaaclab/jax_ant_td3.py
index c414fda7..f88825a1 100644
--- a/docs/source/examples/isaacorbit/jax_ant_td3.py
+++ b/docs/source/examples/isaaclab/jax_ant_td3.py
@@ -1,25 +1,11 @@
-"""
-Notes for Isaac Sim 2022.2.1 or earlier (Python 3.7 environment):
- * Python 3.7 is only supported up to jax<=0.3.25.
- See: https://github.com/google/jax/blob/main/CHANGELOG.md#jaxlib-041-dec-13-2022.
- * Builds for jaxlib<=0.3.25 are only available up to NVIDIA CUDA 11 and cuDNN 8.2 versions.
- See: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
- and search for `cuda11/jaxlib-0.3.25+cuda11.cudnn82-cp37-cp37m-manylinux2014_x86_64.whl`.
- * The `jax.Device = jax.xla.Device` statement is required by skrl to support jax<0.4.3.
- * Models require overloading the `__hash__` method to avoid "TypeError: Failed to hash Flax Module".
-"""
-
import flax.linen as nn
import jax
import jax.numpy as jnp
-
-jax.Device = jax.xla.Device # for Isaac Sim 2022.2.1 or earlier
-
# import the skrl components to build the RL system
from skrl import config
from skrl.agents.jax.td3 import TD3, TD3_DEFAULT_CONFIG
-from skrl.envs.loaders.jax import load_isaac_orbit_env
+from skrl.envs.loaders.jax import load_isaaclab_env
from skrl.envs.wrappers.jax import wrap_env
from skrl.memories.jax import RandomMemory
from skrl.models.jax import DeterministicMixin, Model
@@ -42,9 +28,6 @@ def __init__(self, observation_space, action_space, device=None, clip_actions=Fa
Model.__init__(self, observation_space, action_space, device, **kwargs)
DeterministicMixin.__init__(self, clip_actions)
- def __hash__(self): # for Isaac Sim 2022.2.1 or earlier
- return id(self)
-
@nn.compact # marks the given module method allowing inlined submodules
def __call__(self, inputs, role):
x = nn.relu(nn.Dense(512)(inputs["states"]))
@@ -57,9 +40,6 @@ def __init__(self, observation_space, action_space, device=None, clip_actions=Fa
Model.__init__(self, observation_space, action_space, device, **kwargs)
DeterministicMixin.__init__(self, clip_actions)
- def __hash__(self): # for Isaac Sim 2022.2.1 or earlier
- return id(self)
-
@nn.compact # marks the given module method allowing inlined submodules
def __call__(self, inputs, role):
x = jnp.concatenate([inputs["states"], inputs["taken_actions"]], axis=-1)
@@ -69,8 +49,8 @@ def __call__(self, inputs, role):
return x, {}
-# load and wrap the Isaac Orbit environment
-env = load_isaac_orbit_env(task_name="Isaac-Ant-v0", num_envs=64)
+# load and wrap the Isaac Lab environment
+env = load_isaaclab_env(task_name="Isaac-Ant-v0", num_envs=64)
env = wrap_env(env)
device = env.device
diff --git a/docs/source/examples/isaacorbit/jax_cartpole_ppo.py b/docs/source/examples/isaaclab/jax_cartpole_ppo.py
similarity index 82%
rename from docs/source/examples/isaacorbit/jax_cartpole_ppo.py
rename to docs/source/examples/isaaclab/jax_cartpole_ppo.py
index d6d5ef1b..04f3a49b 100644
--- a/docs/source/examples/isaacorbit/jax_cartpole_ppo.py
+++ b/docs/source/examples/isaaclab/jax_cartpole_ppo.py
@@ -1,25 +1,11 @@
-"""
-Notes for Isaac Sim 2022.2.1 or earlier (Python 3.7 environment):
- * Python 3.7 is only supported up to jax<=0.3.25.
- See: https://github.com/google/jax/blob/main/CHANGELOG.md#jaxlib-041-dec-13-2022.
- * Builds for jaxlib<=0.3.25 are only available up to NVIDIA CUDA 11 and cuDNN 8.2 versions.
- See: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
- and search for `cuda11/jaxlib-0.3.25+cuda11.cudnn82-cp37-cp37m-manylinux2014_x86_64.whl`.
- * The `jax.Device = jax.xla.Device` statement is required by skrl to support jax<0.4.3.
- * Models require overloading the `__hash__` method to avoid "TypeError: Failed to hash Flax Module".
-"""
-
import flax.linen as nn
import jax
import jax.numpy as jnp
-
-jax.Device = jax.xla.Device # for Isaac Sim 2022.2.1 or earlier
-
# import the skrl components to build the RL system
from skrl import config
from skrl.agents.jax.ppo import PPO, PPO_DEFAULT_CONFIG
-from skrl.envs.loaders.jax import load_isaac_orbit_env
+from skrl.envs.loaders.jax import load_isaaclab_env
from skrl.envs.wrappers.jax import wrap_env
from skrl.memories.jax import RandomMemory
from skrl.models.jax import DeterministicMixin, GaussianMixin, Model
@@ -43,9 +29,6 @@ def __init__(self, observation_space, action_space, device=None, clip_actions=Fa
Model.__init__(self, observation_space, action_space, device, **kwargs)
GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std, reduction)
- def __hash__(self): # for Isaac Sim 2022.2.1 or earlier
- return id(self)
-
@nn.compact # marks the given module method allowing inlined submodules
def __call__(self, inputs, role):
x = nn.elu(nn.Dense(32)(inputs["states"]))
@@ -59,9 +42,6 @@ def __init__(self, observation_space, action_space, device=None, clip_actions=Fa
Model.__init__(self, observation_space, action_space, device, **kwargs)
DeterministicMixin.__init__(self, clip_actions)
- def __hash__(self): # for Isaac Sim 2022.2.1 or earlier
- return id(self)
-
@nn.compact # marks the given module method allowing inlined submodules
def __call__(self, inputs, role):
x = nn.elu(nn.Dense(32)(inputs["states"]))
@@ -70,8 +50,8 @@ def __call__(self, inputs, role):
return x, {}
-# load and wrap the Isaac Orbit environment
-env = load_isaac_orbit_env(task_name="Isaac-Cartpole-v0")
+# load and wrap the Isaac Lab environment
+env = load_isaaclab_env(task_name="Isaac-Cartpole-v0")
env = wrap_env(env)
device = env.device
diff --git a/docs/source/examples/isaacorbit/jax_humanoid_ppo.py b/docs/source/examples/isaaclab/jax_humanoid_ppo.py
similarity index 82%
rename from docs/source/examples/isaacorbit/jax_humanoid_ppo.py
rename to docs/source/examples/isaaclab/jax_humanoid_ppo.py
index dafdb04b..561ccae7 100644
--- a/docs/source/examples/isaacorbit/jax_humanoid_ppo.py
+++ b/docs/source/examples/isaaclab/jax_humanoid_ppo.py
@@ -1,25 +1,11 @@
-"""
-Notes for Isaac Sim 2022.2.1 or earlier (Python 3.7 environment):
- * Python 3.7 is only supported up to jax<=0.3.25.
- See: https://github.com/google/jax/blob/main/CHANGELOG.md#jaxlib-041-dec-13-2022.
- * Builds for jaxlib<=0.3.25 are only available up to NVIDIA CUDA 11 and cuDNN 8.2 versions.
- See: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
- and search for `cuda11/jaxlib-0.3.25+cuda11.cudnn82-cp37-cp37m-manylinux2014_x86_64.whl`.
- * The `jax.Device = jax.xla.Device` statement is required by skrl to support jax<0.4.3.
- * Models require overloading the `__hash__` method to avoid "TypeError: Failed to hash Flax Module".
-"""
-
import flax.linen as nn
import jax
import jax.numpy as jnp
-
-jax.Device = jax.xla.Device # for Isaac Sim 2022.2.1 or earlier
-
# import the skrl components to build the RL system
from skrl import config
from skrl.agents.jax.ppo import PPO, PPO_DEFAULT_CONFIG
-from skrl.envs.loaders.jax import load_isaac_orbit_env
+from skrl.envs.loaders.jax import load_isaaclab_env
from skrl.envs.wrappers.jax import wrap_env
from skrl.memories.jax import RandomMemory
from skrl.models.jax import DeterministicMixin, GaussianMixin, Model
@@ -43,9 +29,6 @@ def __init__(self, observation_space, action_space, device=None, clip_actions=Fa
Model.__init__(self, observation_space, action_space, device, **kwargs)
GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std, reduction)
- def __hash__(self): # for Isaac Sim 2022.2.1 or earlier
- return id(self)
-
@nn.compact # marks the given module method allowing inlined submodules
def __call__(self, inputs, role):
x = nn.elu(nn.Dense(400)(inputs["states"]))
@@ -60,9 +43,6 @@ def __init__(self, observation_space, action_space, device=None, clip_actions=Fa
Model.__init__(self, observation_space, action_space, device, **kwargs)
DeterministicMixin.__init__(self, clip_actions)
- def __hash__(self): # for Isaac Sim 2022.2.1 or earlier
- return id(self)
-
@nn.compact # marks the given module method allowing inlined submodules
def __call__(self, inputs, role):
x = nn.elu(nn.Dense(400)(inputs["states"]))
@@ -72,8 +52,8 @@ def __call__(self, inputs, role):
return x, {}
-# load and wrap the Isaac Orbit environment
-env = load_isaac_orbit_env(task_name="Isaac-Humanoid-v0")
+# load and wrap the Isaac Lab environment
+env = load_isaaclab_env(task_name="Isaac-Humanoid-v0")
env = wrap_env(env)
device = env.device
diff --git a/docs/source/examples/isaacorbit/jax_lift_franka_ppo.py b/docs/source/examples/isaaclab/jax_lift_franka_ppo.py
similarity index 82%
rename from docs/source/examples/isaacorbit/jax_lift_franka_ppo.py
rename to docs/source/examples/isaaclab/jax_lift_franka_ppo.py
index b2d24b69..f292568a 100644
--- a/docs/source/examples/isaacorbit/jax_lift_franka_ppo.py
+++ b/docs/source/examples/isaaclab/jax_lift_franka_ppo.py
@@ -1,25 +1,11 @@
-"""
-Notes for Isaac Sim 2022.2.1 or earlier (Python 3.7 environment):
- * Python 3.7 is only supported up to jax<=0.3.25.
- See: https://github.com/google/jax/blob/main/CHANGELOG.md#jaxlib-041-dec-13-2022.
- * Builds for jaxlib<=0.3.25 are only available up to NVIDIA CUDA 11 and cuDNN 8.2 versions.
- See: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
- and search for `cuda11/jaxlib-0.3.25+cuda11.cudnn82-cp37-cp37m-manylinux2014_x86_64.whl`.
- * The `jax.Device = jax.xla.Device` statement is required by skrl to support jax<0.4.3.
- * Models require overloading the `__hash__` method to avoid "TypeError: Failed to hash Flax Module".
-"""
-
import flax.linen as nn
import jax
import jax.numpy as jnp
-
-jax.Device = jax.xla.Device # for Isaac Sim 2022.2.1 or earlier
-
# import the skrl components to build the RL system
from skrl import config
from skrl.agents.jax.ppo import PPO, PPO_DEFAULT_CONFIG
-from skrl.envs.loaders.jax import load_isaac_orbit_env
+from skrl.envs.loaders.jax import load_isaaclab_env
from skrl.envs.wrappers.jax import wrap_env
from skrl.memories.jax import RandomMemory
from skrl.models.jax import DeterministicMixin, GaussianMixin, Model
@@ -43,9 +29,6 @@ def __init__(self, observation_space, action_space, device=None, clip_actions=Fa
Model.__init__(self, observation_space, action_space, device, **kwargs)
GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std, reduction)
- def __hash__(self): # for Isaac Sim 2022.2.1 or earlier
- return id(self)
-
@nn.compact # marks the given module method allowing inlined submodules
def __call__(self, inputs, role):
x = nn.elu(nn.Dense(256)(inputs["states"]))
@@ -60,9 +43,6 @@ def __init__(self, observation_space, action_space, device=None, clip_actions=Fa
Model.__init__(self, observation_space, action_space, device, **kwargs)
DeterministicMixin.__init__(self, clip_actions)
- def __hash__(self): # for Isaac Sim 2022.2.1 or earlier
- return id(self)
-
@nn.compact # marks the given module method allowing inlined submodules
def __call__(self, inputs, role):
x = nn.elu(nn.Dense(256)(inputs["states"]))
@@ -72,8 +52,8 @@ def __call__(self, inputs, role):
return x, {}
-# load and wrap the Isaac Orbit environment
-env = load_isaac_orbit_env(task_name="Isaac-Lift-Franka-v0")
+# load and wrap the Isaac Lab environment
+env = load_isaaclab_env(task_name="Isaac-Lift-Cube-Franka-v0")
env = wrap_env(env)
device = env.device
diff --git a/docs/source/examples/isaacorbit/jax_reach_franka_ppo.py b/docs/source/examples/isaaclab/jax_reach_franka_ppo.py
similarity index 82%
rename from docs/source/examples/isaacorbit/jax_reach_franka_ppo.py
rename to docs/source/examples/isaaclab/jax_reach_franka_ppo.py
index 2b01af30..180e3283 100644
--- a/docs/source/examples/isaacorbit/jax_reach_franka_ppo.py
+++ b/docs/source/examples/isaaclab/jax_reach_franka_ppo.py
@@ -1,25 +1,11 @@
-"""
-Notes for Isaac Sim 2022.2.1 or earlier (Python 3.7 environment):
- * Python 3.7 is only supported up to jax<=0.3.25.
- See: https://github.com/google/jax/blob/main/CHANGELOG.md#jaxlib-041-dec-13-2022.
- * Builds for jaxlib<=0.3.25 are only available up to NVIDIA CUDA 11 and cuDNN 8.2 versions.
- See: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
- and search for `cuda11/jaxlib-0.3.25+cuda11.cudnn82-cp37-cp37m-manylinux2014_x86_64.whl`.
- * The `jax.Device = jax.xla.Device` statement is required by skrl to support jax<0.4.3.
- * Models require overloading the `__hash__` method to avoid "TypeError: Failed to hash Flax Module".
-"""
-
import flax.linen as nn
import jax
import jax.numpy as jnp
-
-jax.Device = jax.xla.Device # for Isaac Sim 2022.2.1 or earlier
-
# import the skrl components to build the RL system
from skrl import config
from skrl.agents.jax.ppo import PPO, PPO_DEFAULT_CONFIG
-from skrl.envs.loaders.jax import load_isaac_orbit_env
+from skrl.envs.loaders.jax import load_isaaclab_env
from skrl.envs.wrappers.jax import wrap_env
from skrl.memories.jax import RandomMemory
from skrl.models.jax import DeterministicMixin, GaussianMixin, Model
@@ -43,9 +29,6 @@ def __init__(self, observation_space, action_space, device=None, clip_actions=Fa
Model.__init__(self, observation_space, action_space, device, **kwargs)
GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std, reduction)
- def __hash__(self): # for Isaac Sim 2022.2.1 or earlier
- return id(self)
-
@nn.compact # marks the given module method allowing inlined submodules
def __call__(self, inputs, role):
x = nn.elu(nn.Dense(256)(inputs["states"]))
@@ -60,9 +43,6 @@ def __init__(self, observation_space, action_space, device=None, clip_actions=Fa
Model.__init__(self, observation_space, action_space, device, **kwargs)
DeterministicMixin.__init__(self, clip_actions)
- def __hash__(self): # for Isaac Sim 2022.2.1 or earlier
- return id(self)
-
@nn.compact # marks the given module method allowing inlined submodules
def __call__(self, inputs, role):
x = nn.elu(nn.Dense(256)(inputs["states"]))
@@ -72,8 +52,8 @@ def __call__(self, inputs, role):
return x, {}
-# load and wrap the Isaac Orbit environment
-env = load_isaac_orbit_env(task_name="Isaac-Reach-Franka-v0")
+# load and wrap the Isaac Lab environment
+env = load_isaaclab_env(task_name="Isaac-Reach-Franka-v0")
env = wrap_env(env)
device = env.device
diff --git a/docs/source/examples/isaacorbit/jax_velocity_anymal_c_ppo.py b/docs/source/examples/isaaclab/jax_velocity_anymal_c_ppo.py
similarity index 80%
rename from docs/source/examples/isaacorbit/jax_velocity_anymal_c_ppo.py
rename to docs/source/examples/isaaclab/jax_velocity_anymal_c_ppo.py
index 0238a6ec..c92169a1 100644
--- a/docs/source/examples/isaacorbit/jax_velocity_anymal_c_ppo.py
+++ b/docs/source/examples/isaaclab/jax_velocity_anymal_c_ppo.py
@@ -1,25 +1,11 @@
-"""
-Notes for Isaac Sim 2022.2.1 or earlier (Python 3.7 environment):
- * Python 3.7 is only supported up to jax<=0.3.25.
- See: https://github.com/google/jax/blob/main/CHANGELOG.md#jaxlib-041-dec-13-2022.
- * Builds for jaxlib<=0.3.25 are only available up to NVIDIA CUDA 11 and cuDNN 8.2 versions.
- See: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
- and search for `cuda11/jaxlib-0.3.25+cuda11.cudnn82-cp37-cp37m-manylinux2014_x86_64.whl`.
- * The `jax.Device = jax.xla.Device` statement is required by skrl to support jax<0.4.3.
- * Models require overloading the `__hash__` method to avoid "TypeError: Failed to hash Flax Module".
-"""
-
import flax.linen as nn
import jax
import jax.numpy as jnp
-
-jax.Device = jax.xla.Device # for Isaac Sim 2022.2.1 or earlier
-
# import the skrl components to build the RL system
from skrl import config
from skrl.agents.jax.ppo import PPO, PPO_DEFAULT_CONFIG
-from skrl.envs.loaders.jax import load_isaac_orbit_env
+from skrl.envs.loaders.jax import load_isaaclab_env
from skrl.envs.wrappers.jax import wrap_env
from skrl.memories.jax import RandomMemory
from skrl.models.jax import DeterministicMixin, GaussianMixin, Model
@@ -43,9 +29,6 @@ def __init__(self, observation_space, action_space, device=None, clip_actions=Fa
Model.__init__(self, observation_space, action_space, device, **kwargs)
GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std, reduction)
- def __hash__(self): # for Isaac Sim 2022.2.1 or earlier
- return id(self)
-
@nn.compact # marks the given module method allowing inlined submodules
def __call__(self, inputs, role):
x = nn.elu(nn.Dense(128)(inputs["states"]))
@@ -60,9 +43,6 @@ def __init__(self, observation_space, action_space, device=None, clip_actions=Fa
Model.__init__(self, observation_space, action_space, device, **kwargs)
DeterministicMixin.__init__(self, clip_actions)
- def __hash__(self): # for Isaac Sim 2022.2.1 or earlier
- return id(self)
-
@nn.compact # marks the given module method allowing inlined submodules
def __call__(self, inputs, role):
x = nn.elu(nn.Dense(128)(inputs["states"]))
@@ -72,8 +52,8 @@ def __call__(self, inputs, role):
return x, {}
-# load and wrap the Isaac Orbit environment
-env = load_isaac_orbit_env(task_name="Isaac-Velocity-Anymal-C-v0")
+# load and wrap the Isaac Lab environment
+env = load_isaaclab_env(task_name="Isaac-Velocity-Flat-Anymal-C-v0")
env = wrap_env(env)
device = env.device
diff --git a/docs/source/examples/isaacorbit/torch_ant_ddpg.py b/docs/source/examples/isaaclab/torch_ant_ddpg.py
similarity index 96%
rename from docs/source/examples/isaacorbit/torch_ant_ddpg.py
rename to docs/source/examples/isaaclab/torch_ant_ddpg.py
index 7bd90ecc..56e114f3 100644
--- a/docs/source/examples/isaacorbit/torch_ant_ddpg.py
+++ b/docs/source/examples/isaaclab/torch_ant_ddpg.py
@@ -3,7 +3,7 @@
# import the skrl components to build the RL system
from skrl.agents.torch.ddpg import DDPG, DDPG_DEFAULT_CONFIG
-from skrl.envs.loaders.torch import load_isaac_orbit_env
+from skrl.envs.loaders.torch import load_isaaclab_env
from skrl.envs.wrappers.torch import wrap_env
from skrl.memories.torch import RandomMemory
from skrl.models.torch import DeterministicMixin, Model
@@ -48,8 +48,8 @@ def compute(self, inputs, role):
return self.net(torch.cat([inputs["states"], inputs["taken_actions"]], dim=1)), {}
-# load and wrap the Isaac Orbit environment
-env = load_isaac_orbit_env(task_name="Isaac-Ant-v0", num_envs=64)
+# load and wrap the Isaac Lab environment
+env = load_isaaclab_env(task_name="Isaac-Ant-v0", num_envs=64)
env = wrap_env(env)
device = env.device
diff --git a/docs/source/examples/isaacorbit/torch_ant_ppo.py b/docs/source/examples/isaaclab/torch_ant_ppo.py
similarity index 90%
rename from docs/source/examples/isaacorbit/torch_ant_ppo.py
rename to docs/source/examples/isaaclab/torch_ant_ppo.py
index 54f337a8..9c0d08d7 100644
--- a/docs/source/examples/isaacorbit/torch_ant_ppo.py
+++ b/docs/source/examples/isaaclab/torch_ant_ppo.py
@@ -3,7 +3,7 @@
# import the skrl components to build the RL system
from skrl.agents.torch.ppo import PPO, PPO_DEFAULT_CONFIG
-from skrl.envs.loaders.torch import load_isaac_orbit_env
+from skrl.envs.loaders.torch import load_isaaclab_env
from skrl.envs.wrappers.torch import wrap_env
from skrl.memories.torch import RandomMemory
from skrl.models.torch import DeterministicMixin, GaussianMixin, Model
@@ -45,13 +45,16 @@ def act(self, inputs, role):
def compute(self, inputs, role):
if role == "policy":
- return self.mean_layer(self.net(inputs["states"])), self.log_std_parameter, {}
+ self._shared_output = self.net(inputs["states"])
+ return self.mean_layer(self._shared_output), self.log_std_parameter, {}
elif role == "value":
- return self.value_layer(self.net(inputs["states"])), {}
+ shared_output = self.net(inputs["states"]) if self._shared_output is None else self._shared_output
+ self._shared_output = None
+ return self.value_layer(shared_output), {}
-# load and wrap the Isaac Orbit environment
-env = load_isaac_orbit_env(task_name="Isaac-Ant-v0")
+# load and wrap the Isaac Lab environment
+env = load_isaaclab_env(task_name="Isaac-Ant-v0")
env = wrap_env(env)
device = env.device
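The shared PPO models in the torch examples are also rewritten to run the common backbone only once: the policy role caches the backbone features in `self._shared_output`, and the value role reuses (then clears) that cache instead of calling `self.net` again. A condensed sketch of the pattern, with the class boilerplate abridged (the attribute is presumably initialized to `None` in `__init__`, which lies outside these hunks):

    # condensed sketch of the single forward-pass logic in the shared policy/value models
    def compute(self, inputs, role):
        if role == "policy":
            # run the shared backbone once and keep the features for the value call
            self._shared_output = self.net(inputs["states"])
            return self.mean_layer(self._shared_output), self.log_std_parameter, {}
        elif role == "value":
            # reuse the cached features when available, then clear the cache
            shared_output = self.net(inputs["states"]) if self._shared_output is None else self._shared_output
            self._shared_output = None
            return self.value_layer(shared_output), {}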
diff --git a/docs/source/examples/isaacorbit/torch_ant_sac.py b/docs/source/examples/isaaclab/torch_ant_sac.py
similarity index 96%
rename from docs/source/examples/isaacorbit/torch_ant_sac.py
rename to docs/source/examples/isaaclab/torch_ant_sac.py
index b27909c0..089c9d1b 100644
--- a/docs/source/examples/isaacorbit/torch_ant_sac.py
+++ b/docs/source/examples/isaaclab/torch_ant_sac.py
@@ -3,7 +3,7 @@
# import the skrl components to build the RL system
from skrl.agents.torch.sac import SAC, SAC_DEFAULT_CONFIG
-from skrl.envs.loaders.torch import load_isaac_orbit_env
+from skrl.envs.loaders.torch import load_isaaclab_env
from skrl.envs.wrappers.torch import wrap_env
from skrl.memories.torch import RandomMemory
from skrl.models.torch import DeterministicMixin, GaussianMixin, Model
@@ -49,8 +49,8 @@ def compute(self, inputs, role):
return self.net(torch.cat([inputs["states"], inputs["taken_actions"]], dim=1)), {}
-# load and wrap the Isaac Orbit environment
-env = load_isaac_orbit_env(task_name="Isaac-Ant-v0", num_envs=64)
+# load and wrap the Isaac Lab environment
+env = load_isaaclab_env(task_name="Isaac-Ant-v0", num_envs=64)
env = wrap_env(env)
device = env.device
diff --git a/docs/source/examples/isaacorbit/torch_ant_td3.py b/docs/source/examples/isaaclab/torch_ant_td3.py
similarity index 96%
rename from docs/source/examples/isaacorbit/torch_ant_td3.py
rename to docs/source/examples/isaaclab/torch_ant_td3.py
index 68a0a89b..253c564c 100644
--- a/docs/source/examples/isaacorbit/torch_ant_td3.py
+++ b/docs/source/examples/isaaclab/torch_ant_td3.py
@@ -3,7 +3,7 @@
# import the skrl components to build the RL system
from skrl.agents.torch.td3 import TD3, TD3_DEFAULT_CONFIG
-from skrl.envs.loaders.torch import load_isaac_orbit_env
+from skrl.envs.loaders.torch import load_isaaclab_env
from skrl.envs.wrappers.torch import wrap_env
from skrl.memories.torch import RandomMemory
from skrl.models.torch import DeterministicMixin, Model
@@ -48,8 +48,8 @@ def compute(self, inputs, role):
return self.net(torch.cat([inputs["states"], inputs["taken_actions"]], dim=1)), {}
-# load and wrap the Isaac Orbit environment
-env = load_isaac_orbit_env(task_name="Isaac-Ant-v0", num_envs=64)
+# load and wrap the Isaac Lab environment
+env = load_isaaclab_env(task_name="Isaac-Ant-v0", num_envs=64)
env = wrap_env(env)
device = env.device
diff --git a/docs/source/examples/isaacorbit/torch_cartpole_ppo.py b/docs/source/examples/isaaclab/torch_cartpole_ppo.py
similarity index 90%
rename from docs/source/examples/isaacorbit/torch_cartpole_ppo.py
rename to docs/source/examples/isaaclab/torch_cartpole_ppo.py
index cd0d3f78..e1cc4953 100644
--- a/docs/source/examples/isaacorbit/torch_cartpole_ppo.py
+++ b/docs/source/examples/isaaclab/torch_cartpole_ppo.py
@@ -3,7 +3,7 @@
# import the skrl components to build the RL system
from skrl.agents.torch.ppo import PPO, PPO_DEFAULT_CONFIG
-from skrl.envs.loaders.torch import load_isaac_orbit_env
+from skrl.envs.loaders.torch import load_isaaclab_env
from skrl.envs.wrappers.torch import wrap_env
from skrl.memories.torch import RandomMemory
from skrl.models.torch import DeterministicMixin, GaussianMixin, Model
@@ -43,13 +43,16 @@ def act(self, inputs, role):
def compute(self, inputs, role):
if role == "policy":
- return self.mean_layer(self.net(inputs["states"])), self.log_std_parameter, {}
+ self._shared_output = self.net(inputs["states"])
+ return self.mean_layer(self._shared_output), self.log_std_parameter, {}
elif role == "value":
- return self.value_layer(self.net(inputs["states"])), {}
+ shared_output = self.net(inputs["states"]) if self._shared_output is None else self._shared_output
+ self._shared_output = None
+ return self.value_layer(shared_output), {}
-# load and wrap the Isaac Orbit environment
-env = load_isaac_orbit_env(task_name="Isaac-Cartpole-v0")
+# load and wrap the Isaac Lab environment
+env = load_isaaclab_env(task_name="Isaac-Cartpole-v0")
env = wrap_env(env)
device = env.device
diff --git a/docs/source/examples/isaacorbit/torch_humanoid_ppo.py b/docs/source/examples/isaaclab/torch_humanoid_ppo.py
similarity index 90%
rename from docs/source/examples/isaacorbit/torch_humanoid_ppo.py
rename to docs/source/examples/isaaclab/torch_humanoid_ppo.py
index c3e8f876..2625732f 100644
--- a/docs/source/examples/isaacorbit/torch_humanoid_ppo.py
+++ b/docs/source/examples/isaaclab/torch_humanoid_ppo.py
@@ -3,7 +3,7 @@
# import the skrl components to build the RL system
from skrl.agents.torch.ppo import PPO, PPO_DEFAULT_CONFIG
-from skrl.envs.loaders.torch import load_isaac_orbit_env
+from skrl.envs.loaders.torch import load_isaaclab_env
from skrl.envs.wrappers.torch import wrap_env
from skrl.memories.torch import RandomMemory
from skrl.models.torch import DeterministicMixin, GaussianMixin, Model
@@ -45,13 +45,16 @@ def act(self, inputs, role):
def compute(self, inputs, role):
if role == "policy":
- return torch.tanh(self.mean_layer(self.net(inputs["states"]))), self.log_std_parameter, {}
+ self._shared_output = self.net(inputs["states"])
+ return self.mean_layer(self._shared_output), self.log_std_parameter, {}
elif role == "value":
- return self.value_layer(self.net(inputs["states"])), {}
+ shared_output = self.net(inputs["states"]) if self._shared_output is None else self._shared_output
+ self._shared_output = None
+ return self.value_layer(shared_output), {}
-# load and wrap the Isaac Orbit environment
-env = load_isaac_orbit_env(task_name="Isaac-Humanoid-v0")
+# load and wrap the Isaac Lab environment
+env = load_isaaclab_env(task_name="Isaac-Humanoid-v0")
env = wrap_env(env)
device = env.device
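Note that in the Humanoid (and, below, Reach-Franka) policies this rewrite also drops the `torch.tanh` squashing of the action mean, so every shared model in this set now returns the raw `mean_layer` output. The one-line change, side by side:

    # before: mean squashed into (-1, 1)
    return torch.tanh(self.mean_layer(self.net(inputs["states"]))), self.log_std_parameter, {}
    # after: raw mean computed from the cached shared features
    return self.mean_layer(self._shared_output), self.log_std_parameter, {}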
diff --git a/docs/source/examples/isaacorbit/torch_lift_franka_ppo.py b/docs/source/examples/isaaclab/torch_lift_franka_ppo.py
similarity index 90%
rename from docs/source/examples/isaacorbit/torch_lift_franka_ppo.py
rename to docs/source/examples/isaaclab/torch_lift_franka_ppo.py
index 406184b9..be1a444e 100644
--- a/docs/source/examples/isaacorbit/torch_lift_franka_ppo.py
+++ b/docs/source/examples/isaaclab/torch_lift_franka_ppo.py
@@ -3,7 +3,7 @@
# import the skrl components to build the RL system
from skrl.agents.torch.ppo import PPO, PPO_DEFAULT_CONFIG
-from skrl.envs.loaders.torch import load_isaac_orbit_env
+from skrl.envs.loaders.torch import load_isaaclab_env
from skrl.envs.wrappers.torch import wrap_env
from skrl.memories.torch import RandomMemory
from skrl.models.torch import DeterministicMixin, GaussianMixin, Model
@@ -45,13 +45,16 @@ def act(self, inputs, role):
def compute(self, inputs, role):
if role == "policy":
- return self.mean_layer(self.net(inputs["states"])), self.log_std_parameter, {}
+ self._shared_output = self.net(inputs["states"])
+ return self.mean_layer(self._shared_output), self.log_std_parameter, {}
elif role == "value":
- return self.value_layer(self.net(inputs["states"])), {}
+ shared_output = self.net(inputs["states"]) if self._shared_output is None else self._shared_output
+ self._shared_output = None
+ return self.value_layer(shared_output), {}
-# load and wrap the Isaac Orbit environment
-env = load_isaac_orbit_env(task_name="Isaac-Lift-Franka-v0")
+# load and wrap the Isaac Lab environment
+env = load_isaaclab_env(task_name="Isaac-Lift-Cube-Franka-v0")
env = wrap_env(env)
device = env.device
diff --git a/docs/source/examples/isaacorbit/torch_reach_franka_ppo.py b/docs/source/examples/isaaclab/torch_reach_franka_ppo.py
similarity index 90%
rename from docs/source/examples/isaacorbit/torch_reach_franka_ppo.py
rename to docs/source/examples/isaaclab/torch_reach_franka_ppo.py
index bc747c5a..a0bad78c 100644
--- a/docs/source/examples/isaacorbit/torch_reach_franka_ppo.py
+++ b/docs/source/examples/isaaclab/torch_reach_franka_ppo.py
@@ -3,7 +3,7 @@
# import the skrl components to build the RL system
from skrl.agents.torch.ppo import PPO, PPO_DEFAULT_CONFIG
-from skrl.envs.loaders.torch import load_isaac_orbit_env
+from skrl.envs.loaders.torch import load_isaaclab_env
from skrl.envs.wrappers.torch import wrap_env
from skrl.memories.torch import RandomMemory
from skrl.models.torch import DeterministicMixin, GaussianMixin, Model
@@ -45,13 +45,16 @@ def act(self, inputs, role):
def compute(self, inputs, role):
if role == "policy":
- return torch.tanh(self.mean_layer(self.net(inputs["states"]))), self.log_std_parameter, {}
+ self._shared_output = self.net(inputs["states"])
+ return self.mean_layer(self._shared_output), self.log_std_parameter, {}
elif role == "value":
- return self.value_layer(self.net(inputs["states"])), {}
+ shared_output = self.net(inputs["states"]) if self._shared_output is None else self._shared_output
+ self._shared_output = None
+ return self.value_layer(shared_output), {}
-# load and wrap the Isaac Orbit environment
-env = load_isaac_orbit_env(task_name="Isaac-Reach-Franka-v0")
+# load and wrap the Isaac Lab environment
+env = load_isaaclab_env(task_name="Isaac-Reach-Franka-v0")
env = wrap_env(env)
device = env.device
diff --git a/docs/source/examples/isaacorbit/torch_velocity_anymal_c_ppo.py b/docs/source/examples/isaaclab/torch_velocity_anymal_c_ppo.py
similarity index 89%
rename from docs/source/examples/isaacorbit/torch_velocity_anymal_c_ppo.py
rename to docs/source/examples/isaaclab/torch_velocity_anymal_c_ppo.py
index d5224a05..683bc095 100644
--- a/docs/source/examples/isaacorbit/torch_velocity_anymal_c_ppo.py
+++ b/docs/source/examples/isaaclab/torch_velocity_anymal_c_ppo.py
@@ -3,7 +3,7 @@
# import the skrl components to build the RL system
from skrl.agents.torch.ppo import PPO, PPO_DEFAULT_CONFIG
-from skrl.envs.loaders.torch import load_isaac_orbit_env
+from skrl.envs.loaders.torch import load_isaaclab_env
from skrl.envs.wrappers.torch import wrap_env
from skrl.memories.torch import RandomMemory
from skrl.models.torch import DeterministicMixin, GaussianMixin, Model
@@ -45,13 +45,16 @@ def act(self, inputs, role):
def compute(self, inputs, role):
if role == "policy":
- return self.mean_layer(self.net(inputs["states"])), self.log_std_parameter, {}
+ self._shared_output = self.net(inputs["states"])
+ return self.mean_layer(self._shared_output), self.log_std_parameter, {}
elif role == "value":
- return self.value_layer(self.net(inputs["states"])), {}
+ shared_output = self.net(inputs["states"]) if self._shared_output is None else self._shared_output
+ self._shared_output = None
+ return self.value_layer(shared_output), {}
-# load and wrap the Isaac Orbit environment
-env = load_isaac_orbit_env(task_name="Isaac-Velocity-Anymal-C-v0")
+# load and wrap the Isaac Lab environment
+env = load_isaaclab_env(task_name="Isaac-Velocity-Flat-Anymal-C-v0")
env = wrap_env(env)
device = env.device
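Besides the loader rename, a couple of task IDs change as part of the Isaac Lab migration. The renames visible in the hunks above, collected here for reference (the dictionary name is only for illustration, and the list is not exhaustive):

    # task ID renames visible in these example hunks (not exhaustive)
    ISAACLAB_TASK_RENAMES = {
        "Isaac-Lift-Franka-v0": "Isaac-Lift-Cube-Franka-v0",
        "Isaac-Velocity-Anymal-C-v0": "Isaac-Velocity-Flat-Anymal-C-v0",
        # Isaac-Ant-v0, Isaac-Cartpole-v0, Isaac-Humanoid-v0 and Isaac-Reach-Franka-v0 keep their IDs
    }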
diff --git a/docs/source/examples/omniisaacgym/jax_allegro_hand_ppo.py b/docs/source/examples/omniisaacgym/jax_allegro_hand_ppo.py
index 49ca8c32..5e1c6f2e 100644
--- a/docs/source/examples/omniisaacgym/jax_allegro_hand_ppo.py
+++ b/docs/source/examples/omniisaacgym/jax_allegro_hand_ppo.py
@@ -1,21 +1,7 @@
-"""
-Notes for Isaac Sim 2022.2.1 or earlier (Python 3.7 environment):
- * Python 3.7 is only supported up to jax<=0.3.25.
- See: https://github.com/google/jax/blob/main/CHANGELOG.md#jaxlib-041-dec-13-2022.
- * Builds for jaxlib<=0.3.25 are only available up to NVIDIA CUDA 11 and cuDNN 8.2 versions.
- See: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
- and search for `cuda11/jaxlib-0.3.25+cuda11.cudnn82-cp37-cp37m-manylinux2014_x86_64.whl`.
- * The `jax.Device = jax.xla.Device` statement is required by skrl to support jax<0.4.3.
- * Models require overloading the `__hash__` method to avoid "TypeError: Failed to hash Flax Module".
-"""
-
import flax.linen as nn
import jax
import jax.numpy as jnp
-
-jax.Device = jax.xla.Device # for Isaac Sim 2022.2.1 or earlier
-
# import the skrl components to build the RL system
from skrl import config
from skrl.agents.jax.ppo import PPO, PPO_DEFAULT_CONFIG
@@ -43,9 +29,6 @@ def __init__(self, observation_space, action_space, device=None, clip_actions=Fa
Model.__init__(self, observation_space, action_space, device, **kwargs)
GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std, reduction)
- def __hash__(self): # for Isaac Sim 2022.2.1 or earlier
- return id(self)
-
@nn.compact # marks the given module method allowing inlined submodules
def __call__(self, inputs, role):
x = nn.elu(nn.Dense(512)(inputs["states"]))
@@ -60,9 +43,6 @@ def __init__(self, observation_space, action_space, device=None, clip_actions=Fa
Model.__init__(self, observation_space, action_space, device, **kwargs)
DeterministicMixin.__init__(self, clip_actions)
- def __hash__(self): # for Isaac Sim 2022.2.1 or earlier
- return id(self)
-
@nn.compact # marks the given module method allowing inlined submodules
def __call__(self, inputs, role):
x = nn.elu(nn.Dense(512)(inputs["states"]))
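The omniisaacgym JAX examples in the rest of this section change only by dropping the legacy Isaac Sim 2022.2.1 / Python 3.7 workarounds: the `jax.Device = jax.xla.Device` alias (required by skrl only to support jax < 0.4.3) and the per-model `__hash__` override (which avoided "TypeError: Failed to hash Flax Module" on that old stack). For reference, a minimal sketch of the pattern being deleted; it is meaningful only on the legacy environment, and `TinyPolicy` is a hypothetical stand-in for the example models:

    # legacy-only pattern removed throughout these files (jax <= 0.3.25 / Python 3.7)
    import flax.linen as nn
    import jax

    jax.Device = jax.xla.Device  # alias required by skrl only when running jax < 0.4.3


    class TinyPolicy(nn.Module):  # hypothetical stand-in for the example models
        def __hash__(self):       # works around "TypeError: Failed to hash Flax Module" on old Flax
            return id(self)

        @nn.compact
        def __call__(self, inputs):
            return nn.elu(nn.Dense(32)(inputs))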
diff --git a/docs/source/examples/omniisaacgym/jax_ant_ddpg.py b/docs/source/examples/omniisaacgym/jax_ant_ddpg.py
index 8d1358ae..1460ffbd 100644
--- a/docs/source/examples/omniisaacgym/jax_ant_ddpg.py
+++ b/docs/source/examples/omniisaacgym/jax_ant_ddpg.py
@@ -1,21 +1,7 @@
-"""
-Notes for Isaac Sim 2022.2.1 or earlier (Python 3.7 environment):
- * Python 3.7 is only supported up to jax<=0.3.25.
- See: https://github.com/google/jax/blob/main/CHANGELOG.md#jaxlib-041-dec-13-2022.
- * Builds for jaxlib<=0.3.25 are only available up to NVIDIA CUDA 11 and cuDNN 8.2 versions.
- See: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
- and search for `cuda11/jaxlib-0.3.25+cuda11.cudnn82-cp37-cp37m-manylinux2014_x86_64.whl`.
- * The `jax.Device = jax.xla.Device` statement is required by skrl to support jax<0.4.3.
- * Models require overloading the `__hash__` method to avoid "TypeError: Failed to hash Flax Module".
-"""
-
import flax.linen as nn
import jax
import jax.numpy as jnp
-
-jax.Device = jax.xla.Device # for Isaac Sim 2022.2.1 or earlier
-
# import the skrl components to build the RL system
from skrl import config
from skrl.agents.jax.ddpg import DDPG, DDPG_DEFAULT_CONFIG
@@ -42,9 +28,6 @@ def __init__(self, observation_space, action_space, device=None, clip_actions=Fa
Model.__init__(self, observation_space, action_space, device, **kwargs)
DeterministicMixin.__init__(self, clip_actions)
- def __hash__(self): # for Isaac Sim 2022.2.1 or earlier
- return id(self)
-
@nn.compact # marks the given module method allowing inlined submodules
def __call__(self, inputs, role):
x = nn.relu(nn.Dense(512)(inputs["states"]))
@@ -57,9 +40,6 @@ def __init__(self, observation_space, action_space, device=None, clip_actions=Fa
Model.__init__(self, observation_space, action_space, device, **kwargs)
DeterministicMixin.__init__(self, clip_actions)
- def __hash__(self): # for Isaac Sim 2022.2.1 or earlier
- return id(self)
-
@nn.compact # marks the given module method allowing inlined submodules
def __call__(self, inputs, role):
x = jnp.concatenate([inputs["states"], inputs["taken_actions"]], axis=-1)
diff --git a/docs/source/examples/omniisaacgym/jax_ant_mt_ppo.py b/docs/source/examples/omniisaacgym/jax_ant_mt_ppo.py
index 80cce179..645b6417 100644
--- a/docs/source/examples/omniisaacgym/jax_ant_mt_ppo.py
+++ b/docs/source/examples/omniisaacgym/jax_ant_mt_ppo.py
@@ -1,23 +1,9 @@
-"""
-Notes for Isaac Sim 2022.2.1 or earlier (Python 3.7 environment):
- * Python 3.7 is only supported up to jax<=0.3.25.
- See: https://github.com/google/jax/blob/main/CHANGELOG.md#jaxlib-041-dec-13-2022.
- * Builds for jaxlib<=0.3.25 are only available up to NVIDIA CUDA 11 and cuDNN 8.2 versions.
- See: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
- and search for `cuda11/jaxlib-0.3.25+cuda11.cudnn82-cp37-cp37m-manylinux2014_x86_64.whl`.
- * The `jax.Device = jax.xla.Device` statement is required by skrl to support jax<0.4.3.
- * Models require overloading the `__hash__` method to avoid "TypeError: Failed to hash Flax Module".
-"""
-
import threading
import flax.linen as nn
import jax
import jax.numpy as jnp
-
-jax.Device = jax.xla.Device # for Isaac Sim 2022.2.1 or earlier
-
# import the skrl components to build the RL system
from skrl import config
from skrl.agents.jax.ppo import PPO, PPO_DEFAULT_CONFIG
@@ -45,9 +31,6 @@ def __init__(self, observation_space, action_space, device=None, clip_actions=Fa
Model.__init__(self, observation_space, action_space, device, **kwargs)
GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std, reduction)
- def __hash__(self): # for Isaac Sim 2022.2.1 or earlier
- return id(self)
-
@nn.compact # marks the given module method allowing inlined submodules
def __call__(self, inputs, role):
x = nn.elu(nn.Dense(256)(inputs["states"]))
@@ -62,9 +45,6 @@ def __init__(self, observation_space, action_space, device=None, clip_actions=Fa
Model.__init__(self, observation_space, action_space, device, **kwargs)
DeterministicMixin.__init__(self, clip_actions)
- def __hash__(self): # for Isaac Sim 2022.2.1 or earlier
- return id(self)
-
@nn.compact # marks the given module method allowing inlined submodules
def __call__(self, inputs, role):
x = nn.elu(nn.Dense(256)(inputs["states"]))
diff --git a/docs/source/examples/omniisaacgym/jax_ant_ppo.py b/docs/source/examples/omniisaacgym/jax_ant_ppo.py
index fed94204..b9dd4d1b 100644
--- a/docs/source/examples/omniisaacgym/jax_ant_ppo.py
+++ b/docs/source/examples/omniisaacgym/jax_ant_ppo.py
@@ -1,21 +1,7 @@
-"""
-Notes for Isaac Sim 2022.2.1 or earlier (Python 3.7 environment):
- * Python 3.7 is only supported up to jax<=0.3.25.
- See: https://github.com/google/jax/blob/main/CHANGELOG.md#jaxlib-041-dec-13-2022.
- * Builds for jaxlib<=0.3.25 are only available up to NVIDIA CUDA 11 and cuDNN 8.2 versions.
- See: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
- and search for `cuda11/jaxlib-0.3.25+cuda11.cudnn82-cp37-cp37m-manylinux2014_x86_64.whl`.
- * The `jax.Device = jax.xla.Device` statement is required by skrl to support jax<0.4.3.
- * Models require overloading the `__hash__` method to avoid "TypeError: Failed to hash Flax Module".
-"""
-
import flax.linen as nn
import jax
import jax.numpy as jnp
-
-jax.Device = jax.xla.Device # for Isaac Sim 2022.2.1 or earlier
-
# import the skrl components to build the RL system
from skrl import config
from skrl.agents.jax.ppo import PPO, PPO_DEFAULT_CONFIG
@@ -43,9 +29,6 @@ def __init__(self, observation_space, action_space, device=None, clip_actions=Fa
Model.__init__(self, observation_space, action_space, device, **kwargs)
GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std, reduction)
- def __hash__(self): # for Isaac Sim 2022.2.1 or earlier
- return id(self)
-
@nn.compact # marks the given module method allowing inlined submodules
def __call__(self, inputs, role):
x = nn.elu(nn.Dense(256)(inputs["states"]))
@@ -60,9 +43,6 @@ def __init__(self, observation_space, action_space, device=None, clip_actions=Fa
Model.__init__(self, observation_space, action_space, device, **kwargs)
DeterministicMixin.__init__(self, clip_actions)
- def __hash__(self): # for Isaac Sim 2022.2.1 or earlier
- return id(self)
-
@nn.compact # marks the given module method allowing inlined submodules
def __call__(self, inputs, role):
x = nn.elu(nn.Dense(256)(inputs["states"]))
diff --git a/docs/source/examples/omniisaacgym/jax_ant_sac.py b/docs/source/examples/omniisaacgym/jax_ant_sac.py
index 70f5d0e8..f077f3f2 100644
--- a/docs/source/examples/omniisaacgym/jax_ant_sac.py
+++ b/docs/source/examples/omniisaacgym/jax_ant_sac.py
@@ -1,21 +1,7 @@
-"""
-Notes for Isaac Sim 2022.2.1 or earlier (Python 3.7 environment):
- * Python 3.7 is only supported up to jax<=0.3.25.
- See: https://github.com/google/jax/blob/main/CHANGELOG.md#jaxlib-041-dec-13-2022.
- * Builds for jaxlib<=0.3.25 are only available up to NVIDIA CUDA 11 and cuDNN 8.2 versions.
- See: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
- and search for `cuda11/jaxlib-0.3.25+cuda11.cudnn82-cp37-cp37m-manylinux2014_x86_64.whl`.
- * The `jax.Device = jax.xla.Device` statement is required by skrl to support jax<0.4.3.
- * Models require overloading the `__hash__` method to avoid "TypeError: Failed to hash Flax Module".
-"""
-
import flax.linen as nn
import jax
import jax.numpy as jnp
-
-jax.Device = jax.xla.Device # for Isaac Sim 2022.2.1 or earlier
-
# import the skrl components to build the RL system
from skrl import config
from skrl.agents.jax.sac import SAC, SAC_DEFAULT_CONFIG
@@ -42,9 +28,6 @@ def __init__(self, observation_space, action_space, device=None, clip_actions=Fa
Model.__init__(self, observation_space, action_space, device, **kwargs)
GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std, reduction)
- def __hash__(self): # for Isaac Sim 2022.2.1 or earlier
- return id(self)
-
@nn.compact # marks the given module method allowing inlined submodules
def __call__(self, inputs, role):
x = nn.relu(nn.Dense(512)(inputs["states"]))
@@ -58,9 +41,6 @@ def __init__(self, observation_space, action_space, device=None, clip_actions=Fa
Model.__init__(self, observation_space, action_space, device, **kwargs)
DeterministicMixin.__init__(self, clip_actions)
- def __hash__(self): # for Isaac Sim 2022.2.1 or earlier
- return id(self)
-
@nn.compact # marks the given module method allowing inlined submodules
def __call__(self, inputs, role):
x = jnp.concatenate([inputs["states"], inputs["taken_actions"]], axis=-1)
diff --git a/docs/source/examples/omniisaacgym/jax_ant_td3.py b/docs/source/examples/omniisaacgym/jax_ant_td3.py
index ec49187c..555c451d 100644
--- a/docs/source/examples/omniisaacgym/jax_ant_td3.py
+++ b/docs/source/examples/omniisaacgym/jax_ant_td3.py
@@ -1,21 +1,7 @@
-"""
-Notes for Isaac Sim 2022.2.1 or earlier (Python 3.7 environment):
- * Python 3.7 is only supported up to jax<=0.3.25.
- See: https://github.com/google/jax/blob/main/CHANGELOG.md#jaxlib-041-dec-13-2022.
- * Builds for jaxlib<=0.3.25 are only available up to NVIDIA CUDA 11 and cuDNN 8.2 versions.
- See: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
- and search for `cuda11/jaxlib-0.3.25+cuda11.cudnn82-cp37-cp37m-manylinux2014_x86_64.whl`.
- * The `jax.Device = jax.xla.Device` statement is required by skrl to support jax<0.4.3.
- * Models require overloading the `__hash__` method to avoid "TypeError: Failed to hash Flax Module".
-"""
-
import flax.linen as nn
import jax
import jax.numpy as jnp
-
-jax.Device = jax.xla.Device # for Isaac Sim 2022.2.1 or earlier
-
# import the skrl components to build the RL system
from skrl import config
from skrl.agents.jax.td3 import TD3, TD3_DEFAULT_CONFIG
@@ -42,9 +28,6 @@ def __init__(self, observation_space, action_space, device=None, clip_actions=Fa
Model.__init__(self, observation_space, action_space, device, **kwargs)
DeterministicMixin.__init__(self, clip_actions)
- def __hash__(self): # for Isaac Sim 2022.2.1 or earlier
- return id(self)
-
@nn.compact # marks the given module method allowing inlined submodules
def __call__(self, inputs, role):
x = nn.relu(nn.Dense(512)(inputs["states"]))
@@ -57,9 +40,6 @@ def __init__(self, observation_space, action_space, device=None, clip_actions=Fa
Model.__init__(self, observation_space, action_space, device, **kwargs)
DeterministicMixin.__init__(self, clip_actions)
- def __hash__(self): # for Isaac Sim 2022.2.1 or earlier
- return id(self)
-
@nn.compact # marks the given module method allowing inlined submodules
def __call__(self, inputs, role):
x = jnp.concatenate([inputs["states"], inputs["taken_actions"]], axis=-1)
diff --git a/docs/source/examples/omniisaacgym/jax_anymal_ppo.py b/docs/source/examples/omniisaacgym/jax_anymal_ppo.py
index d09cf2df..fb9bc889 100644
--- a/docs/source/examples/omniisaacgym/jax_anymal_ppo.py
+++ b/docs/source/examples/omniisaacgym/jax_anymal_ppo.py
@@ -1,21 +1,7 @@
-"""
-Notes for Isaac Sim 2022.2.1 or earlier (Python 3.7 environment):
- * Python 3.7 is only supported up to jax<=0.3.25.
- See: https://github.com/google/jax/blob/main/CHANGELOG.md#jaxlib-041-dec-13-2022.
- * Builds for jaxlib<=0.3.25 are only available up to NVIDIA CUDA 11 and cuDNN 8.2 versions.
- See: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
- and search for `cuda11/jaxlib-0.3.25+cuda11.cudnn82-cp37-cp37m-manylinux2014_x86_64.whl`.
- * The `jax.Device = jax.xla.Device` statement is required by skrl to support jax<0.4.3.
- * Models require overloading the `__hash__` method to avoid "TypeError: Failed to hash Flax Module".
-"""
-
import flax.linen as nn
import jax
import jax.numpy as jnp
-
-jax.Device = jax.xla.Device # for Isaac Sim 2022.2.1 or earlier
-
# import the skrl components to build the RL system
from skrl import config
from skrl.agents.jax.ppo import PPO, PPO_DEFAULT_CONFIG
@@ -43,9 +29,6 @@ def __init__(self, observation_space, action_space, device=None, clip_actions=Fa
Model.__init__(self, observation_space, action_space, device, **kwargs)
GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std, reduction)
- def __hash__(self): # for Isaac Sim 2022.2.1 or earlier
- return id(self)
-
@nn.compact # marks the given module method allowing inlined submodules
def __call__(self, inputs, role):
x = nn.elu(nn.Dense(256)(inputs["states"]))
@@ -60,9 +43,6 @@ def __init__(self, observation_space, action_space, device=None, clip_actions=Fa
Model.__init__(self, observation_space, action_space, device, **kwargs)
DeterministicMixin.__init__(self, clip_actions)
- def __hash__(self): # for Isaac Sim 2022.2.1 or earlier
- return id(self)
-
@nn.compact # marks the given module method allowing inlined submodules
def __call__(self, inputs, role):
x = nn.elu(nn.Dense(256)(inputs["states"]))
diff --git a/docs/source/examples/omniisaacgym/jax_anymal_terrain_ppo.py b/docs/source/examples/omniisaacgym/jax_anymal_terrain_ppo.py
index 7401460a..3c92265b 100644
--- a/docs/source/examples/omniisaacgym/jax_anymal_terrain_ppo.py
+++ b/docs/source/examples/omniisaacgym/jax_anymal_terrain_ppo.py
@@ -1,21 +1,7 @@
-"""
-Notes for Isaac Sim 2022.2.1 or earlier (Python 3.7 environment):
- * Python 3.7 is only supported up to jax<=0.3.25.
- See: https://github.com/google/jax/blob/main/CHANGELOG.md#jaxlib-041-dec-13-2022.
- * Builds for jaxlib<=0.3.25 are only available up to NVIDIA CUDA 11 and cuDNN 8.2 versions.
- See: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
- and search for `cuda11/jaxlib-0.3.25+cuda11.cudnn82-cp37-cp37m-manylinux2014_x86_64.whl`.
- * The `jax.Device = jax.xla.Device` statement is required by skrl to support jax<0.4.3.
- * Models require overloading the `__hash__` method to avoid "TypeError: Failed to hash Flax Module".
-"""
-
import flax.linen as nn
import jax
import jax.numpy as jnp
-
-jax.Device = jax.xla.Device # for Isaac Sim 2022.2.1 or earlier
-
# import the skrl components to build the RL system
from skrl import config
from skrl.agents.jax.ppo import PPO, PPO_DEFAULT_CONFIG
@@ -43,9 +29,6 @@ def __init__(self, observation_space, action_space, device=None, clip_actions=Fa
Model.__init__(self, observation_space, action_space, device, **kwargs)
GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std, reduction)
- def __hash__(self): # for Isaac Sim 2022.2.1 or earlier
- return id(self)
-
@nn.compact # marks the given module method allowing inlined submodules
def __call__(self, inputs, role):
x = nn.elu(nn.Dense(512)(inputs["states"]))
@@ -60,9 +43,6 @@ def __init__(self, observation_space, action_space, device=None, clip_actions=Fa
Model.__init__(self, observation_space, action_space, device, **kwargs)
DeterministicMixin.__init__(self, clip_actions)
- def __hash__(self): # for Isaac Sim 2022.2.1 or earlier
- return id(self)
-
@nn.compact # marks the given module method allowing inlined submodules
def __call__(self, inputs, role):
x = nn.elu(nn.Dense(512)(inputs["states"]))
diff --git a/docs/source/examples/omniisaacgym/jax_ball_balance_ppo.py b/docs/source/examples/omniisaacgym/jax_ball_balance_ppo.py
index 21057dc3..43cef3a7 100644
--- a/docs/source/examples/omniisaacgym/jax_ball_balance_ppo.py
+++ b/docs/source/examples/omniisaacgym/jax_ball_balance_ppo.py
@@ -1,21 +1,7 @@
-"""
-Notes for Isaac Sim 2022.2.1 or earlier (Python 3.7 environment):
- * Python 3.7 is only supported up to jax<=0.3.25.
- See: https://github.com/google/jax/blob/main/CHANGELOG.md#jaxlib-041-dec-13-2022.
- * Builds for jaxlib<=0.3.25 are only available up to NVIDIA CUDA 11 and cuDNN 8.2 versions.
- See: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
- and search for `cuda11/jaxlib-0.3.25+cuda11.cudnn82-cp37-cp37m-manylinux2014_x86_64.whl`.
- * The `jax.Device = jax.xla.Device` statement is required by skrl to support jax<0.4.3.
- * Models require overloading the `__hash__` method to avoid "TypeError: Failed to hash Flax Module".
-"""
-
import flax.linen as nn
import jax
import jax.numpy as jnp
-
-jax.Device = jax.xla.Device # for Isaac Sim 2022.2.1 or earlier
-
# import the skrl components to build the RL system
from skrl import config
from skrl.agents.jax.ppo import PPO, PPO_DEFAULT_CONFIG
@@ -43,9 +29,6 @@ def __init__(self, observation_space, action_space, device=None, clip_actions=Fa
Model.__init__(self, observation_space, action_space, device, **kwargs)
GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std, reduction)
- def __hash__(self): # for Isaac Sim 2022.2.1 or earlier
- return id(self)
-
@nn.compact # marks the given module method allowing inlined submodules
def __call__(self, inputs, role):
x = nn.elu(nn.Dense(128)(inputs["states"]))
@@ -60,9 +43,6 @@ def __init__(self, observation_space, action_space, device=None, clip_actions=Fa
Model.__init__(self, observation_space, action_space, device, **kwargs)
DeterministicMixin.__init__(self, clip_actions)
- def __hash__(self): # for Isaac Sim 2022.2.1 or earlier
- return id(self)
-
@nn.compact # marks the given module method allowing inlined submodules
def __call__(self, inputs, role):
x = nn.elu(nn.Dense(128)(inputs["states"]))
diff --git a/docs/source/examples/omniisaacgym/jax_cartpole_mt_ppo.py b/docs/source/examples/omniisaacgym/jax_cartpole_mt_ppo.py
index ac719582..73c3798a 100644
--- a/docs/source/examples/omniisaacgym/jax_cartpole_mt_ppo.py
+++ b/docs/source/examples/omniisaacgym/jax_cartpole_mt_ppo.py
@@ -1,23 +1,9 @@
-"""
-Notes for Isaac Sim 2022.2.1 or earlier (Python 3.7 environment):
- * Python 3.7 is only supported up to jax<=0.3.25.
- See: https://github.com/google/jax/blob/main/CHANGELOG.md#jaxlib-041-dec-13-2022.
- * Builds for jaxlib<=0.3.25 are only available up to NVIDIA CUDA 11 and cuDNN 8.2 versions.
- See: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
- and search for `cuda11/jaxlib-0.3.25+cuda11.cudnn82-cp37-cp37m-manylinux2014_x86_64.whl`.
- * The `jax.Device = jax.xla.Device` statement is required by skrl to support jax<0.4.3.
- * Models require overloading the `__hash__` method to avoid "TypeError: Failed to hash Flax Module".
-"""
-
import threading
import flax.linen as nn
import jax
import jax.numpy as jnp
-
-jax.Device = jax.xla.Device # for Isaac Sim 2022.2.1 or earlier
-
# import the skrl components to build the RL system
from skrl import config
from skrl.agents.jax.ppo import PPO, PPO_DEFAULT_CONFIG
@@ -45,9 +31,6 @@ def __init__(self, observation_space, action_space, device=None, clip_actions=Fa
Model.__init__(self, observation_space, action_space, device, **kwargs)
GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std, reduction)
- def __hash__(self): # for Isaac Sim 2022.2.1 or earlier
- return id(self)
-
@nn.compact # marks the given module method allowing inlined submodules
def __call__(self, inputs, role):
x = nn.elu(nn.Dense(32)(inputs["states"]))
@@ -61,9 +44,6 @@ def __init__(self, observation_space, action_space, device=None, clip_actions=Fa
Model.__init__(self, observation_space, action_space, device, **kwargs)
DeterministicMixin.__init__(self, clip_actions)
- def __hash__(self): # for Isaac Sim 2022.2.1 or earlier
- return id(self)
-
@nn.compact # marks the given module method allowing inlined submodules
def __call__(self, inputs, role):
x = nn.elu(nn.Dense(32)(inputs["states"]))
diff --git a/docs/source/examples/omniisaacgym/jax_cartpole_ppo.py b/docs/source/examples/omniisaacgym/jax_cartpole_ppo.py
index 4649e938..34aca8e4 100644
--- a/docs/source/examples/omniisaacgym/jax_cartpole_ppo.py
+++ b/docs/source/examples/omniisaacgym/jax_cartpole_ppo.py
@@ -1,21 +1,7 @@
-"""
-Notes for Isaac Sim 2022.2.1 or earlier (Python 3.7 environment):
- * Python 3.7 is only supported up to jax<=0.3.25.
- See: https://github.com/google/jax/blob/main/CHANGELOG.md#jaxlib-041-dec-13-2022.
- * Builds for jaxlib<=0.3.25 are only available up to NVIDIA CUDA 11 and cuDNN 8.2 versions.
- See: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
- and search for `cuda11/jaxlib-0.3.25+cuda11.cudnn82-cp37-cp37m-manylinux2014_x86_64.whl`.
- * The `jax.Device = jax.xla.Device` statement is required by skrl to support jax<0.4.3.
- * Models require overloading the `__hash__` method to avoid "TypeError: Failed to hash Flax Module".
-"""
-
import flax.linen as nn
import jax
import jax.numpy as jnp
-
-jax.Device = jax.xla.Device # for Isaac Sim 2022.2.1 or earlier
-
# import the skrl components to build the RL system
from skrl import config
from skrl.agents.jax.ppo import PPO, PPO_DEFAULT_CONFIG
@@ -43,9 +29,6 @@ def __init__(self, observation_space, action_space, device=None, clip_actions=Fa
Model.__init__(self, observation_space, action_space, device, **kwargs)
GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std, reduction)
- def __hash__(self): # for Isaac Sim 2022.2.1 or earlier
- return id(self)
-
@nn.compact # marks the given module method allowing inlined submodules
def __call__(self, inputs, role):
x = nn.elu(nn.Dense(32)(inputs["states"]))
@@ -59,9 +42,6 @@ def __init__(self, observation_space, action_space, device=None, clip_actions=Fa
Model.__init__(self, observation_space, action_space, device, **kwargs)
DeterministicMixin.__init__(self, clip_actions)
- def __hash__(self): # for Isaac Sim 2022.2.1 or earlier
- return id(self)
-
@nn.compact # marks the given module method allowing inlined submodules
def __call__(self, inputs, role):
x = nn.elu(nn.Dense(32)(inputs["states"]))
diff --git a/docs/source/examples/omniisaacgym/jax_crazyflie_ppo.py b/docs/source/examples/omniisaacgym/jax_crazyflie_ppo.py
index 52a222c6..a10f3a76 100644
--- a/docs/source/examples/omniisaacgym/jax_crazyflie_ppo.py
+++ b/docs/source/examples/omniisaacgym/jax_crazyflie_ppo.py
@@ -1,21 +1,7 @@
-"""
-Notes for Isaac Sim 2022.2.1 or earlier (Python 3.7 environment):
- * Python 3.7 is only supported up to jax<=0.3.25.
- See: https://github.com/google/jax/blob/main/CHANGELOG.md#jaxlib-041-dec-13-2022.
- * Builds for jaxlib<=0.3.25 are only available up to NVIDIA CUDA 11 and cuDNN 8.2 versions.
- See: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
- and search for `cuda11/jaxlib-0.3.25+cuda11.cudnn82-cp37-cp37m-manylinux2014_x86_64.whl`.
- * The `jax.Device = jax.xla.Device` statement is required by skrl to support jax<0.4.3.
- * Models require overloading the `__hash__` method to avoid "TypeError: Failed to hash Flax Module".
-"""
-
import flax.linen as nn
import jax
import jax.numpy as jnp
-
-jax.Device = jax.xla.Device # for Isaac Sim 2022.2.1 or earlier
-
# import the skrl components to build the RL system
from skrl import config
from skrl.agents.jax.ppo import PPO, PPO_DEFAULT_CONFIG
@@ -43,9 +29,6 @@ def __init__(self, observation_space, action_space, device=None, clip_actions=Fa
Model.__init__(self, observation_space, action_space, device, **kwargs)
GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std, reduction)
- def __hash__(self): # for Isaac Sim 2022.2.1 or earlier
- return id(self)
-
@nn.compact # marks the given module method allowing inlined submodules
def __call__(self, inputs, role):
x = nn.tanh(nn.Dense(256)(inputs["states"]))
@@ -60,9 +43,6 @@ def __init__(self, observation_space, action_space, device=None, clip_actions=Fa
Model.__init__(self, observation_space, action_space, device, **kwargs)
DeterministicMixin.__init__(self, clip_actions)
- def __hash__(self): # for Isaac Sim 2022.2.1 or earlier
- return id(self)
-
@nn.compact # marks the given module method allowing inlined submodules
def __call__(self, inputs, role):
x = nn.tanh(nn.Dense(256)(inputs["states"]))
diff --git a/docs/source/examples/omniisaacgym/jax_factory_task_nut_bolt_pick_ppo.py b/docs/source/examples/omniisaacgym/jax_factory_task_nut_bolt_pick_ppo.py
index 44fd2b15..37efe804 100644
--- a/docs/source/examples/omniisaacgym/jax_factory_task_nut_bolt_pick_ppo.py
+++ b/docs/source/examples/omniisaacgym/jax_factory_task_nut_bolt_pick_ppo.py
@@ -1,21 +1,7 @@
-"""
-Notes for Isaac Sim 2022.2.1 or earlier (Python 3.7 environment):
- * Python 3.7 is only supported up to jax<=0.3.25.
- See: https://github.com/google/jax/blob/main/CHANGELOG.md#jaxlib-041-dec-13-2022.
- * Builds for jaxlib<=0.3.25 are only available up to NVIDIA CUDA 11 and cuDNN 8.2 versions.
- See: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
- and search for `cuda11/jaxlib-0.3.25+cuda11.cudnn82-cp37-cp37m-manylinux2014_x86_64.whl`.
- * The `jax.Device = jax.xla.Device` statement is required by skrl to support jax<0.4.3.
- * Models require overloading the `__hash__` method to avoid "TypeError: Failed to hash Flax Module".
-"""
-
import flax.linen as nn
import jax
import jax.numpy as jnp
-
-jax.Device = jax.xla.Device # for Isaac Sim 2022.2.1 or earlier
-
# import the skrl components to build the RL system
from skrl import config
from skrl.agents.jax.ppo import PPO, PPO_DEFAULT_CONFIG
@@ -42,9 +28,6 @@ def __init__(self, observation_space, action_space, device=None, clip_actions=Fa
Model.__init__(self, observation_space, action_space, device, **kwargs)
GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std, reduction)
- def __hash__(self): # for Isaac Sim 2022.2.1 or earlier
- return id(self)
-
@nn.compact # marks the given module method allowing inlined submodules
def __call__(self, inputs, role):
x = nn.elu(nn.Dense(256)(inputs["states"]))
@@ -59,9 +42,6 @@ def __init__(self, observation_space, action_space, device=None, clip_actions=Fa
Model.__init__(self, observation_space, action_space, device, **kwargs)
DeterministicMixin.__init__(self, clip_actions)
- def __hash__(self): # for Isaac Sim 2022.2.1 or earlier
- return id(self)
-
@nn.compact # marks the given module method allowing inlined submodules
def __call__(self, inputs, role):
x = nn.elu(nn.Dense(256)(inputs["states"]))
diff --git a/docs/source/examples/omniisaacgym/jax_franka_cabinet_ppo.py b/docs/source/examples/omniisaacgym/jax_franka_cabinet_ppo.py
index a773e13a..f21c4822 100644
--- a/docs/source/examples/omniisaacgym/jax_franka_cabinet_ppo.py
+++ b/docs/source/examples/omniisaacgym/jax_franka_cabinet_ppo.py
@@ -1,21 +1,7 @@
-"""
-Notes for Isaac Sim 2022.2.1 or earlier (Python 3.7 environment):
- * Python 3.7 is only supported up to jax<=0.3.25.
- See: https://github.com/google/jax/blob/main/CHANGELOG.md#jaxlib-041-dec-13-2022.
- * Builds for jaxlib<=0.3.25 are only available up to NVIDIA CUDA 11 and cuDNN 8.2 versions.
- See: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
- and search for `cuda11/jaxlib-0.3.25+cuda11.cudnn82-cp37-cp37m-manylinux2014_x86_64.whl`.
- * The `jax.Device = jax.xla.Device` statement is required by skrl to support jax<0.4.3.
- * Models require overloading the `__hash__` method to avoid "TypeError: Failed to hash Flax Module".
-"""
-
import flax.linen as nn
import jax
import jax.numpy as jnp
-
-jax.Device = jax.xla.Device # for Isaac Sim 2022.2.1 or earlier
-
# import the skrl components to build the RL system
from skrl import config
from skrl.agents.jax.ppo import PPO, PPO_DEFAULT_CONFIG
@@ -43,9 +29,6 @@ def __init__(self, observation_space, action_space, device=None, clip_actions=Fa
Model.__init__(self, observation_space, action_space, device, **kwargs)
GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std, reduction)
- def __hash__(self): # for Isaac Sim 2022.2.1 or earlier
- return id(self)
-
@nn.compact # marks the given module method allowing inlined submodules
def __call__(self, inputs, role):
x = nn.elu(nn.Dense(256)(inputs["states"]))
@@ -60,9 +43,6 @@ def __init__(self, observation_space, action_space, device=None, clip_actions=Fa
Model.__init__(self, observation_space, action_space, device, **kwargs)
DeterministicMixin.__init__(self, clip_actions)
- def __hash__(self): # for Isaac Sim 2022.2.1 or earlier
- return id(self)
-
@nn.compact # marks the given module method allowing inlined submodules
def __call__(self, inputs, role):
x = nn.elu(nn.Dense(256)(inputs["states"]))
diff --git a/docs/source/examples/omniisaacgym/jax_humanoid_ppo.py b/docs/source/examples/omniisaacgym/jax_humanoid_ppo.py
index 11c92135..963c44e2 100644
--- a/docs/source/examples/omniisaacgym/jax_humanoid_ppo.py
+++ b/docs/source/examples/omniisaacgym/jax_humanoid_ppo.py
@@ -1,21 +1,7 @@
-"""
-Notes for Isaac Sim 2022.2.1 or earlier (Python 3.7 environment):
- * Python 3.7 is only supported up to jax<=0.3.25.
- See: https://github.com/google/jax/blob/main/CHANGELOG.md#jaxlib-041-dec-13-2022.
- * Builds for jaxlib<=0.3.25 are only available up to NVIDIA CUDA 11 and cuDNN 8.2 versions.
- See: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
- and search for `cuda11/jaxlib-0.3.25+cuda11.cudnn82-cp37-cp37m-manylinux2014_x86_64.whl`.
- * The `jax.Device = jax.xla.Device` statement is required by skrl to support jax<0.4.3.
- * Models require overloading the `__hash__` method to avoid "TypeError: Failed to hash Flax Module".
-"""
-
import flax.linen as nn
import jax
import jax.numpy as jnp
-
-jax.Device = jax.xla.Device # for Isaac Sim 2022.2.1 or earlier
-
# import the skrl components to build the RL system
from skrl import config
from skrl.agents.jax.ppo import PPO, PPO_DEFAULT_CONFIG
@@ -43,9 +29,6 @@ def __init__(self, observation_space, action_space, device=None, clip_actions=Fa
Model.__init__(self, observation_space, action_space, device, **kwargs)
GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std, reduction)
- def __hash__(self): # for Isaac Sim 2022.2.1 or earlier
- return id(self)
-
@nn.compact # marks the given module method allowing inlined submodules
def __call__(self, inputs, role):
x = nn.elu(nn.Dense(400)(inputs["states"]))
@@ -60,9 +43,6 @@ def __init__(self, observation_space, action_space, device=None, clip_actions=Fa
Model.__init__(self, observation_space, action_space, device, **kwargs)
DeterministicMixin.__init__(self, clip_actions)
- def __hash__(self): # for Isaac Sim 2022.2.1 or earlier
- return id(self)
-
@nn.compact # marks the given module method allowing inlined submodules
def __call__(self, inputs, role):
x = nn.elu(nn.Dense(400)(inputs["states"]))
diff --git a/docs/source/examples/omniisaacgym/jax_ingenuity_ppo.py b/docs/source/examples/omniisaacgym/jax_ingenuity_ppo.py
index 5413e1ad..a976afb3 100644
--- a/docs/source/examples/omniisaacgym/jax_ingenuity_ppo.py
+++ b/docs/source/examples/omniisaacgym/jax_ingenuity_ppo.py
@@ -1,21 +1,7 @@
-"""
-Notes for Isaac Sim 2022.2.1 or earlier (Python 3.7 environment):
- * Python 3.7 is only supported up to jax<=0.3.25.
- See: https://github.com/google/jax/blob/main/CHANGELOG.md#jaxlib-041-dec-13-2022.
- * Builds for jaxlib<=0.3.25 are only available up to NVIDIA CUDA 11 and cuDNN 8.2 versions.
- See: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
- and search for `cuda11/jaxlib-0.3.25+cuda11.cudnn82-cp37-cp37m-manylinux2014_x86_64.whl`.
- * The `jax.Device = jax.xla.Device` statement is required by skrl to support jax<0.4.3.
- * Models require overloading the `__hash__` method to avoid "TypeError: Failed to hash Flax Module".
-"""
-
import flax.linen as nn
import jax
import jax.numpy as jnp
-
-jax.Device = jax.xla.Device # for Isaac Sim 2022.2.1 or earlier
-
# import the skrl components to build the RL system
from skrl import config
from skrl.agents.jax.ppo import PPO, PPO_DEFAULT_CONFIG
@@ -43,9 +29,6 @@ def __init__(self, observation_space, action_space, device=None, clip_actions=Fa
Model.__init__(self, observation_space, action_space, device, **kwargs)
GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std, reduction)
- def __hash__(self): # for Isaac Sim 2022.2.1 or earlier
- return id(self)
-
@nn.compact # marks the given module method allowing inlined submodules
def __call__(self, inputs, role):
x = nn.elu(nn.Dense(256)(inputs["states"]))
@@ -60,9 +43,6 @@ def __init__(self, observation_space, action_space, device=None, clip_actions=Fa
Model.__init__(self, observation_space, action_space, device, **kwargs)
DeterministicMixin.__init__(self, clip_actions)
- def __hash__(self): # for Isaac Sim 2022.2.1 or earlier
- return id(self)
-
@nn.compact # marks the given module method allowing inlined submodules
def __call__(self, inputs, role):
x = nn.elu(nn.Dense(256)(inputs["states"]))
diff --git a/docs/source/examples/omniisaacgym/jax_quadcopter_ppo.py b/docs/source/examples/omniisaacgym/jax_quadcopter_ppo.py
index 0d21d945..a5ab2e75 100644
--- a/docs/source/examples/omniisaacgym/jax_quadcopter_ppo.py
+++ b/docs/source/examples/omniisaacgym/jax_quadcopter_ppo.py
@@ -1,21 +1,7 @@
-"""
-Notes for Isaac Sim 2022.2.1 or earlier (Python 3.7 environment):
- * Python 3.7 is only supported up to jax<=0.3.25.
- See: https://github.com/google/jax/blob/main/CHANGELOG.md#jaxlib-041-dec-13-2022.
- * Builds for jaxlib<=0.3.25 are only available up to NVIDIA CUDA 11 and cuDNN 8.2 versions.
- See: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
- and search for `cuda11/jaxlib-0.3.25+cuda11.cudnn82-cp37-cp37m-manylinux2014_x86_64.whl`.
- * The `jax.Device = jax.xla.Device` statement is required by skrl to support jax<0.4.3.
- * Models require overloading the `__hash__` method to avoid "TypeError: Failed to hash Flax Module".
-"""
-
import flax.linen as nn
import jax
import jax.numpy as jnp
-
-jax.Device = jax.xla.Device # for Isaac Sim 2022.2.1 or earlier
-
# import the skrl components to build the RL system
from skrl import config
from skrl.agents.jax.ppo import PPO, PPO_DEFAULT_CONFIG
@@ -43,9 +29,6 @@ def __init__(self, observation_space, action_space, device=None, clip_actions=Fa
Model.__init__(self, observation_space, action_space, device, **kwargs)
GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std, reduction)
- def __hash__(self): # for Isaac Sim 2022.2.1 or earlier
- return id(self)
-
@nn.compact # marks the given module method allowing inlined submodules
def __call__(self, inputs, role):
x = nn.elu(nn.Dense(256)(inputs["states"]))
@@ -60,9 +43,6 @@ def __init__(self, observation_space, action_space, device=None, clip_actions=Fa
Model.__init__(self, observation_space, action_space, device, **kwargs)
DeterministicMixin.__init__(self, clip_actions)
- def __hash__(self): # for Isaac Sim 2022.2.1 or earlier
- return id(self)
-
@nn.compact # marks the given module method allowing inlined submodules
def __call__(self, inputs, role):
x = nn.elu(nn.Dense(256)(inputs["states"]))
diff --git a/docs/source/examples/omniisaacgym/jax_shadow_hand_ppo.py b/docs/source/examples/omniisaacgym/jax_shadow_hand_ppo.py
index 97c530d7..bbaeebfa 100644
--- a/docs/source/examples/omniisaacgym/jax_shadow_hand_ppo.py
+++ b/docs/source/examples/omniisaacgym/jax_shadow_hand_ppo.py
@@ -1,21 +1,7 @@
-"""
-Notes for Isaac Sim 2022.2.1 or earlier (Python 3.7 environment):
- * Python 3.7 is only supported up to jax<=0.3.25.
- See: https://github.com/google/jax/blob/main/CHANGELOG.md#jaxlib-041-dec-13-2022.
- * Builds for jaxlib<=0.3.25 are only available up to NVIDIA CUDA 11 and cuDNN 8.2 versions.
- See: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
- and search for `cuda11/jaxlib-0.3.25+cuda11.cudnn82-cp37-cp37m-manylinux2014_x86_64.whl`.
- * The `jax.Device = jax.xla.Device` statement is required by skrl to support jax<0.4.3.
- * Models require overloading the `__hash__` method to avoid "TypeError: Failed to hash Flax Module".
-"""
-
import flax.linen as nn
import jax
import jax.numpy as jnp
-
-jax.Device = jax.xla.Device # for Isaac Sim 2022.2.1 or earlier
-
# import the skrl components to build the RL system
from skrl import config
from skrl.agents.jax.ppo import PPO, PPO_DEFAULT_CONFIG
@@ -43,9 +29,6 @@ def __init__(self, observation_space, action_space, device=None, clip_actions=Fa
Model.__init__(self, observation_space, action_space, device, **kwargs)
GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std, reduction)
- def __hash__(self): # for Isaac Sim 2022.2.1 or earlier
- return id(self)
-
@nn.compact # marks the given module method allowing inlined submodules
def __call__(self, inputs, role):
x = nn.elu(nn.Dense(512)(inputs["states"]))
@@ -61,9 +44,6 @@ def __init__(self, observation_space, action_space, device=None, clip_actions=Fa
Model.__init__(self, observation_space, action_space, device, **kwargs)
DeterministicMixin.__init__(self, clip_actions)
- def __hash__(self): # for Isaac Sim 2022.2.1 or earlier
- return id(self)
-
@nn.compact # marks the given module method allowing inlined submodules
def __call__(self, inputs, role):
x = nn.elu(nn.Dense(512)(inputs["states"]))
diff --git a/docs/source/examples/omniisaacgym/torch_allegro_hand_ppo.py b/docs/source/examples/omniisaacgym/torch_allegro_hand_ppo.py
index 9559c948..5e8cbf85 100644
--- a/docs/source/examples/omniisaacgym/torch_allegro_hand_ppo.py
+++ b/docs/source/examples/omniisaacgym/torch_allegro_hand_ppo.py
@@ -45,9 +45,12 @@ def act(self, inputs, role):
def compute(self, inputs, role):
if role == "policy":
- return self.mean_layer(self.net(inputs["states"])), self.log_std_parameter, {}
+ self._shared_output = self.net(inputs["states"])
+ return self.mean_layer(self._shared_output), self.log_std_parameter, {}
elif role == "value":
- return self.value_layer(self.net(inputs["states"])), {}
+ shared_output = self.net(inputs["states"]) if self._shared_output is None else self._shared_output
+ self._shared_output = None
+ return self.value_layer(shared_output), {}
# load and wrap the Omniverse Isaac Gym environment
diff --git a/docs/source/examples/omniisaacgym/torch_ant_mt_ppo.py b/docs/source/examples/omniisaacgym/torch_ant_mt_ppo.py
index b17598b4..351401d8 100644
--- a/docs/source/examples/omniisaacgym/torch_ant_mt_ppo.py
+++ b/docs/source/examples/omniisaacgym/torch_ant_mt_ppo.py
@@ -47,9 +47,12 @@ def act(self, inputs, role):
def compute(self, inputs, role):
if role == "policy":
- return self.mean_layer(self.net(inputs["states"])), self.log_std_parameter, {}
+ self._shared_output = self.net(inputs["states"])
+ return self.mean_layer(self._shared_output), self.log_std_parameter, {}
elif role == "value":
- return self.value_layer(self.net(inputs["states"])), {}
+ shared_output = self.net(inputs["states"]) if self._shared_output is None else self._shared_output
+ self._shared_output = None
+ return self.value_layer(shared_output), {}
# load and wrap the multi-threaded Omniverse Isaac Gym environment
diff --git a/docs/source/examples/omniisaacgym/torch_ant_ppo.py b/docs/source/examples/omniisaacgym/torch_ant_ppo.py
index 09697cb6..a7637ae1 100644
--- a/docs/source/examples/omniisaacgym/torch_ant_ppo.py
+++ b/docs/source/examples/omniisaacgym/torch_ant_ppo.py
@@ -45,9 +45,12 @@ def act(self, inputs, role):
def compute(self, inputs, role):
if role == "policy":
- return self.mean_layer(self.net(inputs["states"])), self.log_std_parameter, {}
+ self._shared_output = self.net(inputs["states"])
+ return self.mean_layer(self._shared_output), self.log_std_parameter, {}
elif role == "value":
- return self.value_layer(self.net(inputs["states"])), {}
+ shared_output = self.net(inputs["states"]) if self._shared_output is None else self._shared_output
+ self._shared_output = None
+ return self.value_layer(shared_output), {}
# load and wrap the Omniverse Isaac Gym environment
diff --git a/docs/source/examples/omniisaacgym/torch_anymal_ppo.py b/docs/source/examples/omniisaacgym/torch_anymal_ppo.py
index 804b0b28..076ee05a 100644
--- a/docs/source/examples/omniisaacgym/torch_anymal_ppo.py
+++ b/docs/source/examples/omniisaacgym/torch_anymal_ppo.py
@@ -45,9 +45,12 @@ def act(self, inputs, role):
def compute(self, inputs, role):
if role == "policy":
- return self.mean_layer(self.net(inputs["states"])), self.log_std_parameter, {}
+ self._shared_output = self.net(inputs["states"])
+ return self.mean_layer(self._shared_output), self.log_std_parameter, {}
elif role == "value":
- return self.value_layer(self.net(inputs["states"])), {}
+ shared_output = self.net(inputs["states"]) if self._shared_output is None else self._shared_output
+ self._shared_output = None
+ return self.value_layer(shared_output), {}
# load and wrap the Omniverse Isaac Gym environment
diff --git a/docs/source/examples/omniisaacgym/torch_ball_balance_ppo.py b/docs/source/examples/omniisaacgym/torch_ball_balance_ppo.py
index c16ee623..d70df91e 100644
--- a/docs/source/examples/omniisaacgym/torch_ball_balance_ppo.py
+++ b/docs/source/examples/omniisaacgym/torch_ball_balance_ppo.py
@@ -45,9 +45,12 @@ def act(self, inputs, role):
def compute(self, inputs, role):
if role == "policy":
- return self.mean_layer(self.net(inputs["states"])), self.log_std_parameter, {}
+ self._shared_output = self.net(inputs["states"])
+ return self.mean_layer(self._shared_output), self.log_std_parameter, {}
elif role == "value":
- return self.value_layer(self.net(inputs["states"])), {}
+ shared_output = self.net(inputs["states"]) if self._shared_output is None else self._shared_output
+ self._shared_output = None
+ return self.value_layer(shared_output), {}
# load and wrap the Omniverse Isaac Gym environment
diff --git a/docs/source/examples/omniisaacgym/torch_cartpole_mt_ppo.py b/docs/source/examples/omniisaacgym/torch_cartpole_mt_ppo.py
index f902acb9..f4bc4da4 100644
--- a/docs/source/examples/omniisaacgym/torch_cartpole_mt_ppo.py
+++ b/docs/source/examples/omniisaacgym/torch_cartpole_mt_ppo.py
@@ -45,9 +45,12 @@ def act(self, inputs, role):
def compute(self, inputs, role):
if role == "policy":
- return self.mean_layer(self.net(inputs["states"])), self.log_std_parameter, {}
+ self._shared_output = self.net(inputs["states"])
+ return self.mean_layer(self._shared_output), self.log_std_parameter, {}
elif role == "value":
- return self.value_layer(self.net(inputs["states"])), {}
+ shared_output = self.net(inputs["states"]) if self._shared_output is None else self._shared_output
+ self._shared_output = None
+ return self.value_layer(shared_output), {}
# load and wrap the multi-threaded Omniverse Isaac Gym environment
diff --git a/docs/source/examples/omniisaacgym/torch_cartpole_ppo.py b/docs/source/examples/omniisaacgym/torch_cartpole_ppo.py
index 9dd137e7..8c660587 100644
--- a/docs/source/examples/omniisaacgym/torch_cartpole_ppo.py
+++ b/docs/source/examples/omniisaacgym/torch_cartpole_ppo.py
@@ -43,9 +43,12 @@ def act(self, inputs, role):
def compute(self, inputs, role):
if role == "policy":
- return self.mean_layer(self.net(inputs["states"])), self.log_std_parameter, {}
+ self._shared_output = self.net(inputs["states"])
+ return self.mean_layer(self._shared_output), self.log_std_parameter, {}
elif role == "value":
- return self.value_layer(self.net(inputs["states"])), {}
+ shared_output = self.net(inputs["states"]) if self._shared_output is None else self._shared_output
+ self._shared_output = None
+ return self.value_layer(shared_output), {}
# load and wrap the Omniverse Isaac Gym environment
diff --git a/docs/source/examples/omniisaacgym/torch_crazyflie_ppo.py b/docs/source/examples/omniisaacgym/torch_crazyflie_ppo.py
index c7f71e27..99db4421 100644
--- a/docs/source/examples/omniisaacgym/torch_crazyflie_ppo.py
+++ b/docs/source/examples/omniisaacgym/torch_crazyflie_ppo.py
@@ -45,9 +45,12 @@ def act(self, inputs, role):
def compute(self, inputs, role):
if role == "policy":
- return self.mean_layer(self.net(inputs["states"])), self.log_std_parameter, {}
+ self._shared_output = self.net(inputs["states"])
+ return self.mean_layer(self._shared_output), self.log_std_parameter, {}
elif role == "value":
- return self.value_layer(self.net(inputs["states"])), {}
+ shared_output = self.net(inputs["states"]) if self._shared_output is None else self._shared_output
+ self._shared_output = None
+ return self.value_layer(shared_output), {}
# load and wrap the Omniverse Isaac Gym environment
diff --git a/docs/source/examples/omniisaacgym/torch_factory_task_nut_bolt_pick_ppo.py b/docs/source/examples/omniisaacgym/torch_factory_task_nut_bolt_pick_ppo.py
index e236ad5c..1e2510e2 100644
--- a/docs/source/examples/omniisaacgym/torch_factory_task_nut_bolt_pick_ppo.py
+++ b/docs/source/examples/omniisaacgym/torch_factory_task_nut_bolt_pick_ppo.py
@@ -44,9 +44,12 @@ def act(self, inputs, role):
def compute(self, inputs, role):
if role == "policy":
- return self.mean_layer(self.net(inputs["states"])), self.log_std_parameter, {}
+ self._shared_output = self.net(inputs["states"])
+ return self.mean_layer(self._shared_output), self.log_std_parameter, {}
elif role == "value":
- return self.value_layer(self.net(inputs["states"])), {}
+ shared_output = self.net(inputs["states"]) if self._shared_output is None else self._shared_output
+ self._shared_output = None
+ return self.value_layer(shared_output), {}
# load and wrap the Omniverse Isaac Gym environment
diff --git a/docs/source/examples/omniisaacgym/torch_franka_cabinet_ppo.py b/docs/source/examples/omniisaacgym/torch_franka_cabinet_ppo.py
index e81b5bfc..330ef2d5 100644
--- a/docs/source/examples/omniisaacgym/torch_franka_cabinet_ppo.py
+++ b/docs/source/examples/omniisaacgym/torch_franka_cabinet_ppo.py
@@ -45,9 +45,12 @@ def act(self, inputs, role):
def compute(self, inputs, role):
if role == "policy":
- return self.mean_layer(self.net(inputs["states"])), self.log_std_parameter, {}
+ self._shared_output = self.net(inputs["states"])
+ return self.mean_layer(self._shared_output), self.log_std_parameter, {}
elif role == "value":
- return self.value_layer(self.net(inputs["states"])), {}
+ shared_output = self.net(inputs["states"]) if self._shared_output is None else self._shared_output
+ self._shared_output = None
+ return self.value_layer(shared_output), {}
# load and wrap the Omniverse Isaac Gym environment
diff --git a/docs/source/examples/omniisaacgym/torch_humanoid_ppo.py b/docs/source/examples/omniisaacgym/torch_humanoid_ppo.py
index 3d6b49a2..bdfb4078 100644
--- a/docs/source/examples/omniisaacgym/torch_humanoid_ppo.py
+++ b/docs/source/examples/omniisaacgym/torch_humanoid_ppo.py
@@ -45,9 +45,12 @@ def act(self, inputs, role):
def compute(self, inputs, role):
if role == "policy":
- return self.mean_layer(self.net(inputs["states"])), self.log_std_parameter, {}
+ self._shared_output = self.net(inputs["states"])
+ return self.mean_layer(self._shared_output), self.log_std_parameter, {}
elif role == "value":
- return self.value_layer(self.net(inputs["states"])), {}
+ shared_output = self.net(inputs["states"]) if self._shared_output is None else self._shared_output
+ self._shared_output = None
+ return self.value_layer(shared_output), {}
# load and wrap the Omniverse Isaac Gym environment
diff --git a/docs/source/examples/omniisaacgym/torch_ingenuity_ppo.py b/docs/source/examples/omniisaacgym/torch_ingenuity_ppo.py
index 257c94ad..e68b34be 100644
--- a/docs/source/examples/omniisaacgym/torch_ingenuity_ppo.py
+++ b/docs/source/examples/omniisaacgym/torch_ingenuity_ppo.py
@@ -45,9 +45,12 @@ def act(self, inputs, role):
def compute(self, inputs, role):
if role == "policy":
- return self.mean_layer(self.net(inputs["states"])), self.log_std_parameter, {}
+ self._shared_output = self.net(inputs["states"])
+ return self.mean_layer(self._shared_output), self.log_std_parameter, {}
elif role == "value":
- return self.value_layer(self.net(inputs["states"])), {}
+ shared_output = self.net(inputs["states"]) if self._shared_output is None else self._shared_output
+ self._shared_output = None
+ return self.value_layer(shared_output), {}
# load and wrap the Omniverse Isaac Gym environment
diff --git a/docs/source/examples/omniisaacgym/torch_quadcopter_ppo.py b/docs/source/examples/omniisaacgym/torch_quadcopter_ppo.py
index c94c4cd8..4747536e 100644
--- a/docs/source/examples/omniisaacgym/torch_quadcopter_ppo.py
+++ b/docs/source/examples/omniisaacgym/torch_quadcopter_ppo.py
@@ -45,9 +45,12 @@ def act(self, inputs, role):
def compute(self, inputs, role):
if role == "policy":
- return self.mean_layer(self.net(inputs["states"])), self.log_std_parameter, {}
+ self._shared_output = self.net(inputs["states"])
+ return self.mean_layer(self._shared_output), self.log_std_parameter, {}
elif role == "value":
- return self.value_layer(self.net(inputs["states"])), {}
+ shared_output = self.net(inputs["states"]) if self._shared_output is None else self._shared_output
+ self._shared_output = None
+ return self.value_layer(shared_output), {}
# load and wrap the Omniverse Isaac Gym environment
diff --git a/docs/source/examples/omniisaacgym/torch_shadow_hand_ppo.py b/docs/source/examples/omniisaacgym/torch_shadow_hand_ppo.py
index 5e64f99c..e708d41c 100644
--- a/docs/source/examples/omniisaacgym/torch_shadow_hand_ppo.py
+++ b/docs/source/examples/omniisaacgym/torch_shadow_hand_ppo.py
@@ -47,9 +47,12 @@ def act(self, inputs, role):
def compute(self, inputs, role):
if role == "policy":
- return self.mean_layer(self.net(inputs["states"])), self.log_std_parameter, {}
+ self._shared_output = self.net(inputs["states"])
+ return self.mean_layer(self._shared_output), self.log_std_parameter, {}
elif role == "value":
- return self.value_layer(self.net(inputs["states"])), {}
+ shared_output = self.net(inputs["states"]) if self._shared_output is None else self._shared_output
+ self._shared_output = None
+ return self.value_layer(shared_output), {}
# load and wrap the Omniverse Isaac Gym environment
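
Context for the recurring change above: every shared policy/value example now caches the output of the shared trunk when the "policy" role is computed and reuses it for the "value" role, so the trunk is evaluated once per step instead of twice. The following standalone sketch (plain PyTorch with made-up names, not the skrl Model API) illustrates just the caching logic:

import torch
import torch.nn as nn

class SharedTrunkToy(nn.Module):
    """Illustrative only: mimics the caching pattern used in the examples above."""
    def __init__(self, obs_dim=8, act_dim=2):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(obs_dim, 32), nn.ELU())  # shared trunk
        self.mean_layer = nn.Linear(32, act_dim)                    # "policy" head
        self.value_layer = nn.Linear(32, 1)                         # "value" head
        self._shared_output = None                                  # cache set by the "policy" role

    def compute(self, states, role):
        if role == "policy":
            self._shared_output = self.net(states)                  # single trunk evaluation
            return self.mean_layer(self._shared_output)
        elif role == "value":
            # reuse the cached trunk output if available, then clear it
            shared = self.net(states) if self._shared_output is None else self._shared_output
            self._shared_output = None
            return self.value_layer(shared)

model = SharedTrunkToy()
obs = torch.randn(4, 8)
mean = model.compute(obs, "policy")   # runs the trunk once and caches its output
value = model.compute(obs, "value")   # reuses the cache; no second trunk pass

Clearing the cache after the "value" role prevents a stale activation from being reused in a subsequent step if the roles are ever queried out of order.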
diff --git a/docs/source/index.rst b/docs/source/index.rst
index d587c727..f7983c1c 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -28,15 +28,15 @@ SKRL - Reinforcement Learning library (|version|)
-**skrl** is an open-source library for Reinforcement Learning written in Python (on top of `PyTorch `_ and `JAX `_) and designed with a focus on modularity, readability, simplicity and transparency of algorithm implementation. In addition to supporting the OpenAI `Gym `_ / Farama `Gymnasium `_, `DeepMind `_ and other environment interfaces, it allows loading and configuring `NVIDIA Isaac Gym `_, `NVIDIA Isaac Orbit `_ and `NVIDIA Omniverse Isaac Gym `_ environments, enabling agents' simultaneous training by scopes (subsets of environments among all available environments), which may or may not share resources, in the same run.
+**skrl** is an open-source library for Reinforcement Learning written in Python (on top of `PyTorch `_ and `JAX `_) and designed with a focus on modularity, readability, simplicity and transparency of algorithm implementation. In addition to supporting the OpenAI `Gym `_ / Farama `Gymnasium `_, `DeepMind `_ and other environment interfaces, it allows loading and configuring `NVIDIA Isaac Gym `_, `NVIDIA Omniverse Isaac Gym `_, and `NVIDIA Isaac Lab `_ environments, enabling agents' simultaneous training by scopes (subsets of environments among all available environments), which may or may not share resources, in the same run.
**Main features:**
* PyTorch (|_1| |pytorch| |_1|) and JAX (|_1| |jax| |_1|)
* Clean code
* Modularity and reusability
* Documented library, code and implementations
- * Support for Gym/Gymnasium (single and vectorized), DeepMind, NVIDIA Isaac Gym (preview 2, 3 and 4), NVIDIA Isaac Orbit, NVIDIA Omniverse Isaac Gym environments, among others
- * Simultaneous learning by scopes in Gym/Gymnasium (vectorized), NVIDIA Isaac Gym, NVIDIA Isaac Orbit and NVIDIA Omniverse Isaac Gym
+ * Support for Gym/Gymnasium (single and vectorized), DeepMind, NVIDIA Isaac Gym (preview 2, 3 and 4), NVIDIA Omniverse Isaac Gym, and NVIDIA Isaac Lab environments, among others
+ * Simultaneous learning by scopes in Gym/Gymnasium (vectorized), NVIDIA Isaac Gym, NVIDIA Omniverse Isaac Gym, and NVIDIA Isaac Lab
.. raw:: html
@@ -132,13 +132,13 @@ Multi-agents
Environments
^^^^^^^^^^^^
- Definition of the Isaac Gym (preview 2, 3 and 4), Isaac Orbit and Omniverse Isaac Gym environment loaders, and wrappers for the Gym/Gymnasium, DeepMind, Isaac Gym, Isaac Orbit, Omniverse Isaac Gym environments, among others
+ Definition of the Isaac Gym (preview 2, 3 and 4), Omniverse Isaac Gym, and Isaac Lab environment loaders, and wrappers for the Gym/Gymnasium, DeepMind, Isaac Gym, Omniverse Isaac Gym, and Isaac Lab environments, among others
- * :doc:`Single-agent environment wrapping ` for **Gym/Gymnasium**, **DeepMind**, **Isaac Gym**, **Isaac Orbit**, **Omniverse Isaac Gym** environments, among others
+ * :doc:`Single-agent environment wrapping ` for **Gym/Gymnasium**, **DeepMind**, **Isaac Gym**, **Omniverse Isaac Gym**, **Isaac Lab** environments, among others
* :doc:`Multi-agent environment wrapping ` for **PettingZoo** and **Bi-DexHands** environments
* Loading :doc:`Isaac Gym environments `
- * Loading :doc:`Isaac Orbit environments `
* Loading :doc:`Omniverse Isaac Gym environments `
+ * Loading :doc:`Isaac Lab environments `
Memories
^^^^^^^^
diff --git a/docs/source/intro/data.rst b/docs/source/intro/data.rst
index a8499bb2..d0509a77 100644
--- a/docs/source/intro/data.rst
+++ b/docs/source/intro/data.rst
@@ -33,14 +33,14 @@ Each agent offers the following parameters under the :literal:`"experiment"` key
.. literalinclude:: ../snippets/data.py
:language: python
:emphasize-lines: 5-7
- :start-after: [start-tensorboard-configuration]
- :end-before: [end-tensorboard-configuration]
+ :start-after: [start-data-configuration]
+ :end-before: [end-data-configuration]
* **directory**: directory path where the data generated by the experiments (a subdirectory) are stored. If no value is set, the :literal:`runs` folder (inside the current working directory) will be used (and created if it does not exist).
* **experiment_name**: name of the experiment (subdirectory). If no value is set, it will be the current date and time and the agent's name (e.g. :literal:`22-01-09_22-48-49-816281_DDPG`).
-* **write_interval**: interval for writing metrics and values to TensorBoard (default is 250 timesteps). A value equal to or less than 0 disables tracking and writing to TensorBoard.
+* **write_interval**: interval for writing metrics and values to TensorBoard. A value equal to or less than 0 disables tracking and writing to TensorBoard. If set to ``"auto"`` (default value), the interval will be computed so that metrics are written about 100 times over the course of training/evaluation (``timesteps / 100``).
.. raw:: html
@@ -172,8 +172,8 @@ Each agent offers the following parameters under the :literal:`"experiment"` key
.. literalinclude:: ../snippets/data.py
:language: python
:emphasize-lines: 12-13
- :start-after: [start-wandb-configuration]
- :end-before: [end-wandb-configuration]
+ :start-after: [start-data-configuration]
+ :end-before: [end-data-configuration]
* **wandb**: whether to enable support for Weights & Biases.
@@ -206,10 +206,10 @@ The checkpoint management, as in the previous case, is the responsibility of the
.. literalinclude:: ../snippets/data.py
:language: python
:emphasize-lines: 9,10
- :start-after: [start-checkpoint-configuration]
- :end-before: [end-checkpoint-configuration]
+ :start-after: [start-data-configuration]
+ :end-before: [end-data-configuration]
-* **checkpoint_interval**: interval for checkpoints (default is 1000 timesteps). A value equal to or less than 0 disables the checkpoint creation.
+* **checkpoint_interval**: interval for checkpoints. A value equal to or less than 0 disables checkpoint creation. If set to ``"auto"`` (default value), the interval will be computed so that about 10 checkpoints are created over the course of training/evaluation (``timesteps / 10``).
* **store_separately**: if set to :literal:`True`, all the modules that an agent contains (models, optimizers, preprocessors, etc.) will be saved each one in a separate file. By default (:literal:`False`) the modules are grouped in a dictionary and stored in the same file.
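
For reference, the ``"auto"`` intervals described above are resolved at agent initialization from the trainer's ``timesteps`` value. A quick sketch of the arithmetic (the timestep count is only an assumed example):

timesteps = 160000                          # assumed total training timesteps
write_interval = int(timesteps / 100)       # "auto" -> about 100 TensorBoard writes -> 1600
checkpoint_interval = int(timesteps / 10)   # "auto" -> about 10 checkpoints -> 16000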
diff --git a/docs/source/intro/examples.rst b/docs/source/intro/examples.rst
index 2f16383e..88441672 100644
--- a/docs/source/intro/examples.rst
+++ b/docs/source/intro/examples.rst
@@ -772,28 +772,28 @@ The agent configuration is mapped, as far as possible, from the `IsaacGymEnvs co
-**NVIDIA Isaac Orbit**
-----------------------
+**NVIDIA Isaac Lab**
+--------------------
.. raw:: html
-Isaac Orbit environments
-^^^^^^^^^^^^^^^^^^^^^^^^
+Isaac Lab environments
+^^^^^^^^^^^^^^^^^^^^^^
-Training/evaluation of an agent in `Isaac Orbit environments `_ (**one agent, multiple environments**)
+Training/evaluation of an agent in `Isaac Lab environments `_ (**one agent, multiple environments**)
-.. image:: ../_static/imgs/example_isaac_orbit.png
+.. image:: ../_static/imgs/example_isaaclab.png
:width: 100%
:align: center
- :alt: Isaac Orbit environments
+ :alt: Isaac Lab environments
.. raw:: html
-The agent configuration is mapped, as far as possible, from the `Isaac Orbit configuration `_ for rl_games. Shared models or separated models are used depending on the value of the :literal:`network.separate` variable. The following list shows the mapping between the two configurations:
+The agent configuration is mapped, as far as possible, from the Isaac Lab configuration for rl_games. Shared models or separated models are used depending on the value of the :literal:`network.separate` variable. The following list shows the mapping between the two configurations:
.. tabs::
@@ -853,11 +853,11 @@ The agent configuration is mapped, as far as possible, from the `Isaac Orbit con
# trainer
timesteps = num_steps_per_episode * max_epochs
-**Benchmark results** are listed in `Benchmark results #32 (NVIDIA Isaac Orbit) `_
+**Benchmark results** are listed in `Benchmark results #32 (NVIDIA Isaac Lab) `_
.. note::
- Isaac Orbit environments implement a functionality to get their configuration from the command line. Because of this feature, setting the :literal:`headless` option from the trainer configuration will not work. In this case, it is necessary to invoke the scripts as follows: :literal:`orbit -p script.py --headless`
+ Isaac Lab environments implement functionality to get their configuration from the command line. Because of this feature, setting the :literal:`headless` option from the trainer configuration will not work; instead, the scripts must be invoked as follows: :literal:`isaaclab -p script.py --headless`
.. tabs::
@@ -873,28 +873,28 @@ The agent configuration is mapped, as far as possible, from the `Isaac Orbit con
- Script
- Checkpoint (Hugging Face)
* - Isaac-Ant-v0
- - :download:`torch_ant_ppo.py <../examples/isaacorbit/torch_ant_ppo.py>`
- |br| :download:`torch_ant_ddpg.py <../examples/isaacorbit/torch_ant_ddpg.py>`
- |br| :download:`torch_ant_td3.py <../examples/isaacorbit/torch_ant_td3.py>`
- |br| :download:`torch_ant_sac.py <../examples/isaacorbit/torch_ant_sac.py>`
+ - :download:`torch_ant_ppo.py <../examples/isaaclab/torch_ant_ppo.py>`
+ |br| :download:`torch_ant_ddpg.py <../examples/isaaclab/torch_ant_ddpg.py>`
+ |br| :download:`torch_ant_td3.py <../examples/isaaclab/torch_ant_td3.py>`
+ |br| :download:`torch_ant_sac.py <../examples/isaaclab/torch_ant_sac.py>`
- `IsaacOrbit-Isaac-Ant-v0-PPO `_
|br|
|br|
|br|
* - Isaac-Cartpole-v0
- - :download:`torch_cartpole_ppo.py <../examples/isaacorbit/torch_cartpole_ppo.py>`
+ - :download:`torch_cartpole_ppo.py <../examples/isaaclab/torch_cartpole_ppo.py>`
- `IsaacOrbit-Isaac-Cartpole-v0-PPO `_
* - Isaac-Humanoid-v0
- - :download:`torch_humanoid_ppo.py <../examples/isaacorbit/torch_humanoid_ppo.py>`
+ - :download:`torch_humanoid_ppo.py <../examples/isaaclab/torch_humanoid_ppo.py>`
- `IsaacOrbit-Isaac-Humanoid-v0-PPO `_
* - Isaac-Lift-Franka-v0
- - :download:`torch_lift_franka_ppo.py <../examples/isaacorbit/torch_lift_franka_ppo.py>`
+ - :download:`torch_lift_franka_ppo.py <../examples/isaaclab/torch_lift_franka_ppo.py>`
- `IsaacOrbit-Isaac-Lift-Franka-v0-PPO `_
* - Isaac-Reach-Franka-v0
- - :download:`torch_reach_franka_ppo.py <../examples/isaacorbit/torch_reach_franka_ppo.py>`
+ - :download:`torch_reach_franka_ppo.py <../examples/isaaclab/torch_reach_franka_ppo.py>`
- `IsaacOrbit-Isaac-Reach-Franka-v0-PPO `_
* - Isaac-Velocity-Anymal-C-v0
- - :download:`torch_velocity_anymal_c_ppo.py <../examples/isaacorbit/torch_velocity_anymal_c_ppo.py>`
+ - :download:`torch_velocity_anymal_c_ppo.py <../examples/isaaclab/torch_velocity_anymal_c_ppo.py>`
-
.. group-tab:: |_4| |jax| |_4|
@@ -909,28 +909,28 @@ The agent configuration is mapped, as far as possible, from the `Isaac Orbit con
- Script
- Checkpoint (Hugging Face)
* - Isaac-Ant-v0
- - :download:`jax_ant_ppo.py <../examples/isaacorbit/jax_ant_ppo.py>`
- |br| :download:`jax_ant_ddpg.py <../examples/isaacorbit/jax_ant_ddpg.py>`
- |br| :download:`jax_ant_td3.py <../examples/isaacorbit/jax_ant_td3.py>`
- |br| :download:`jax_ant_sac.py <../examples/isaacorbit/jax_ant_sac.py>`
+ - :download:`jax_ant_ppo.py <../examples/isaaclab/jax_ant_ppo.py>`
+ |br| :download:`jax_ant_ddpg.py <../examples/isaaclab/jax_ant_ddpg.py>`
+ |br| :download:`jax_ant_td3.py <../examples/isaaclab/jax_ant_td3.py>`
+ |br| :download:`jax_ant_sac.py <../examples/isaaclab/jax_ant_sac.py>`
- `IsaacOrbit-Isaac-Ant-v0-PPO `_
|br|
|br|
|br|
* - Isaac-Cartpole-v0
- - :download:`jax_cartpole_ppo.py <../examples/isaacorbit/jax_cartpole_ppo.py>`
+ - :download:`jax_cartpole_ppo.py <../examples/isaaclab/jax_cartpole_ppo.py>`
- `IsaacOrbit-Isaac-Cartpole-v0-PPO `_
* - Isaac-Humanoid-v0
- - :download:`jax_humanoid_ppo.py <../examples/isaacorbit/jax_humanoid_ppo.py>`
+ - :download:`jax_humanoid_ppo.py <../examples/isaaclab/jax_humanoid_ppo.py>`
- `IsaacOrbit-Isaac-Humanoid-v0-PPO `_
* - Isaac-Lift-Franka-v0
- - :download:`jax_lift_franka_ppo.py <../examples/isaacorbit/jax_lift_franka_ppo.py>`
+ - :download:`jax_lift_franka_ppo.py <../examples/isaaclab/jax_lift_franka_ppo.py>`
- `IsaacOrbit-Isaac-Lift-Franka-v0-PPO `_
* - Isaac-Reach-Franka-v0
- - :download:`jax_reach_franka_ppo.py <../examples/isaacorbit/jax_reach_franka_ppo.py>`
+ - :download:`jax_reach_franka_ppo.py <../examples/isaaclab/jax_reach_franka_ppo.py>`
- `IsaacOrbit-Isaac-Reach-Franka-v0-PPO `_
* - Isaac-Velocity-Anymal-C-v0
- - :download:`jax_velocity_anymal_c_ppo.py <../examples/isaacorbit/jax_velocity_anymal_c_ppo.py>`
+ - :download:`jax_velocity_anymal_c_ppo.py <../examples/isaaclab/jax_velocity_anymal_c_ppo.py>`
-
.. raw:: html
diff --git a/docs/source/intro/getting_started.rst b/docs/source/intro/getting_started.rst
index 75320efe..a957b8eb 100644
--- a/docs/source/intro/getting_started.rst
+++ b/docs/source/intro/getting_started.rst
@@ -43,7 +43,7 @@ At each step (also called timestep) of interaction with the environment, the age
The environment plays a fundamental role in the definition of the RL schema. For example, the selection of the agent depends strongly on the observation and action space nature. There are several interfaces to interact with the environments such as OpenAI Gym / Farama Gymnasium or DeepMind. However, each of them has a different API and works with incompatible data types.
-* For **single-agent** environments, skrl offers a function to **wrap environments** based on the Gym/Gymnasium, DeepMind, NVIDIA Isaac Gym, Isaac Orbit and Omniverse Isaac Gym interfaces, among others. The wrapped environments provide, to the library components, a common interface (based on Gym/Gymnasium) as shown in the following figure. Refer to the :doc:`Wrapping (single-agent) <../api/envs/wrapping>` section for more information.
+* For **single-agent** environments, skrl offers a function to **wrap environments** based on the Gym/Gymnasium, DeepMind, NVIDIA Isaac Gym, Omniverse Isaac Gym and Isaac Lab interfaces, among others. The wrapped environments provide, to the library components, a common interface (based on Gym/Gymnasium) as shown in the following figure. Refer to the :doc:`Wrapping (single-agent) <../api/envs/wrapping>` section for more information.
* For **multi-agent** environments, skrl offers a function to **wrap environments** based on the PettingZoo and Bi-DexHands interfaces. The wrapped environments provide, to the library components, a common interface (based on PettingZoo) as shown in the following figure. Refer to the :doc:`Wrapping (multi-agents) <../api/envs/multi_agents_wrapping>` section for more information.
@@ -85,6 +85,24 @@ Among the methods and properties defined in the wrapped environment, the observa
.. tabs::
+ .. tab:: Isaac Lab
+
+ .. tabs::
+
+ .. group-tab:: |_4| |pytorch| |_4|
+
+ .. literalinclude:: ../snippets/wrapping.py
+ :language: python
+ :start-after: [pytorch-start-isaaclab]
+ :end-before: [pytorch-end-isaaclab]
+
+ .. group-tab:: |_4| |jax| |_4|
+
+ .. literalinclude:: ../snippets/wrapping.py
+ :language: python
+ :start-after: [jax-start-isaaclab]
+ :end-before: [jax-end-isaaclab]
+
.. tab:: Omniverse Isaac Gym
.. tabs::
@@ -125,24 +143,6 @@ Among the methods and properties defined in the wrapped environment, the observa
:start-after: [jax-start-omniverse-isaacgym-mt]
:end-before: [jax-end-omniverse-isaacgym-mt]
- .. tab:: Isaac Orbit
-
- .. tabs::
-
- .. group-tab:: |_4| |pytorch| |_4|
-
- .. literalinclude:: ../snippets/wrapping.py
- :language: python
- :start-after: [pytorch-start-isaac-orbit]
- :end-before: [pytorch-end-isaac-orbit]
-
- .. group-tab:: |_4| |jax| |_4|
-
- .. literalinclude:: ../snippets/wrapping.py
- :language: python
- :start-after: [jax-start-isaac-orbit]
- :end-before: [jax-end-isaac-orbit]
-
.. tab:: Isaac Gym
.. tabs::
diff --git a/docs/source/intro/installation.rst b/docs/source/intro/installation.rst
index c59e0e06..1c8fd3fc 100644
--- a/docs/source/intro/installation.rst
+++ b/docs/source/intro/installation.rst
@@ -191,7 +191,7 @@ Bug detection and/or correction, feature requests and everything else are more t
AttributeError: 'Adam' object has no attribute '_warned_capturable_if_run_uncaptured'
-2. When installing the JAX version in Python 3.7 (e.g. OmniIsaacGymEnvs or Isaac Orbit on Isaac Sim 2022.2.1 and earlier).
+2. When installing the JAX version in Python 3.7 (e.g. OmniIsaacGymEnvs or Isaac Lab on Isaac Sim 2022.2.1 and earlier).
.. code-block:: text
@@ -210,7 +210,7 @@ Bug detection and/or correction, feature requests and everything else are more t
* Overload models ``__hash__`` method to avoid :literal:`"TypeError: Failed to hash Flax Module"`.
-3. When training/evaluating using JAX in Python 3.7 (e.g. OmniIsaacGymEnvs or Isaac Orbit on Isaac Sim 2022.2.1 and earlier).
+3. When training/evaluating using JAX in Python 3.7 (e.g. OmniIsaacGymEnvs or Isaac Lab on Isaac Sim 2022.2.1 and earlier).
.. code-block:: text
@@ -223,7 +223,7 @@ Bug detection and/or correction, feature requests and everything else are more t
def __hash__(self):
return id(self)
-4. When training/evaluating using JAX with the NVIDIA Isaac Gym Preview, Isaac Orbit or Omniverse Isaac Gym environments.
+4. When training/evaluating using JAX with the NVIDIA Isaac Gym Preview, Omniverse Isaac Gym or Isaac Lab environments.
.. code-block:: text
diff --git a/docs/source/snippets/data.py b/docs/source/snippets/data.py
index 92782d79..6655a5f9 100644
--- a/docs/source/snippets/data.py
+++ b/docs/source/snippets/data.py
@@ -1,58 +1,20 @@
-# [start-tensorboard-configuration]
+# [start-data-configuration]
DEFAULT_CONFIG = {
# ...
"experiment": {
"directory": "", # experiment's parent directory
"experiment_name": "", # experiment name
- "write_interval": 250, # TensorBoard writing interval (timesteps)
+ "write_interval": "auto", # TensorBoard writing interval (timesteps)
- "checkpoint_interval": 1000, # interval for checkpoints (timesteps)
+ "checkpoint_interval": "auto", # interval for checkpoints (timesteps)
"store_separately": False, # whether to store checkpoints separately
"wandb": False, # whether to use Weights & Biases
"wandb_kwargs": {} # wandb kwargs (see https://docs.wandb.ai/ref/python/init)
}
}
-# [end-tensorboard-configuration]
-
-
-# [start-wandb-configuration]
-DEFAULT_CONFIG = {
- # ...
-
- "experiment": {
- "directory": "", # experiment's parent directory
- "experiment_name": "", # experiment name
- "write_interval": 250, # TensorBoard writing interval (timesteps)
-
- "checkpoint_interval": 1000, # interval for checkpoints (timesteps)
- "store_separately": False, # whether to store checkpoints separately
-
- "wandb": False, # whether to use Weights & Biases
- "wandb_kwargs": {} # wandb kwargs (see https://docs.wandb.ai/ref/python/init)
- }
-}
-# [end-wandb-configuration]
-
-
-# [start-checkpoint-configuration]
-DEFAULT_CONFIG = {
- # ...
-
- "experiment": {
- "directory": "", # experiment's parent directory
- "experiment_name": "", # experiment name
- "write_interval": 250, # TensorBoard writing interval (timesteps)
-
- "checkpoint_interval": 1000, # interval for checkpoints (timesteps)
- "store_separately": False, # whether to store checkpoints separately
-
- "wandb": False, # whether to use Weights & Biases
- "wandb_kwargs": {} # wandb kwargs (see https://docs.wandb.ai/ref/python/init)
- }
-}
-# [end-checkpoint-configuration]
+# [end-data-configuration]
# [start-checkpoint-load-agent-torch]
diff --git a/docs/source/snippets/loaders.py b/docs/source/snippets/loaders.py
index f9a15f5a..49a8cc41 100644
--- a/docs/source/snippets/loaders.py
+++ b/docs/source/snippets/loaders.py
@@ -111,40 +111,40 @@
# =============================================================================
-# [start-isaac-orbit-envs-parameters-torch]
+# [start-isaaclab-envs-parameters-torch]
# import the environment loader
-from skrl.envs.loaders.torch import load_isaac_orbit_env
+from skrl.envs.loaders.torch import load_isaaclab_env
# load environment
-env = load_isaac_orbit_env(task_name="Isaac-Cartpole-v0")
-# [end-isaac-orbit-envs-parameters-torch]
+env = load_isaaclab_env(task_name="Isaac-Cartpole-v0")
+# [end-isaaclab-envs-parameters-torch]
-# [start-isaac-orbit-envs-parameters-jax]
+# [start-isaaclab-envs-parameters-jax]
# import the environment loader
-from skrl.envs.loaders.jax import load_isaac_orbit_env
+from skrl.envs.loaders.jax import load_isaaclab_env
# load environment
-env = load_isaac_orbit_env(task_name="Isaac-Cartpole-v0")
-# [end-isaac-orbit-envs-parameters-jax]
+env = load_isaaclab_env(task_name="Isaac-Cartpole-v0")
+# [end-isaaclab-envs-parameters-jax]
-# [start-isaac-orbit-envs-cli-torch]
+# [start-isaaclab-envs-cli-torch]
# import the environment loader
-from skrl.envs.loaders.torch import load_isaac_orbit_env
+from skrl.envs.loaders.torch import load_isaaclab_env
# load environment
-env = load_isaac_orbit_env()
-# [end-isaac-orbit-envs-cli-torch]
+env = load_isaaclab_env()
+# [end-isaaclab-envs-cli-torch]
-# [start-isaac-orbit-envs-cli-jax]
+# [start-isaaclab-envs-cli-jax]
# import the environment loader
-from skrl.envs.loaders.jax import load_isaac_orbit_env
+from skrl.envs.loaders.jax import load_isaaclab_env
# load environment
-env = load_isaac_orbit_env()
-# [end-isaac-orbit-envs-cli-jax]
+env = load_isaaclab_env()
+# [end-isaaclab-envs-cli-jax]
# =============================================================================
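
Putting the renamed pieces together (the matching wrapper rename appears further down in docs/source/snippets/wrapping.py), a minimal end-to-end usage sketch for PyTorch; the task name is only an example and running it requires an Isaac Lab installation:

# import the environment loader and wrapper
from skrl.envs.loaders.torch import load_isaaclab_env
from skrl.envs.wrappers.torch import wrap_env

# load and wrap the environment (the wrapper type is auto-detected)
env = load_isaaclab_env(task_name="Isaac-Cartpole-v0")
env = wrap_env(env)  # or: wrap_env(env, wrapper="isaaclab")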
diff --git a/docs/source/snippets/shared_model.py b/docs/source/snippets/shared_model.py
index 5fd8c737..8a0ed0b4 100644
--- a/docs/source/snippets/shared_model.py
+++ b/docs/source/snippets/shared_model.py
@@ -1,4 +1,59 @@
-# [start-mlp-torch]
+# [start-mlp-single-forward-pass-torch]
+import torch
+import torch.nn as nn
+
+from skrl.models.torch import Model, GaussianMixin, DeterministicMixin
+
+
+# define the shared model
+class SharedModel(GaussianMixin, DeterministicMixin, Model):
+ def __init__(self, observation_space, action_space, device, clip_actions=False,
+ clip_log_std=True, min_log_std=-20, max_log_std=2, reduction="sum"):
+ Model.__init__(self, observation_space, action_space, device)
+ GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std, reduction, role="policy")
+ DeterministicMixin.__init__(self, clip_actions, role="value")
+
+ # shared layers/network
+ self.net = nn.Sequential(nn.Linear(self.num_observations, 32),
+ nn.ELU(),
+ nn.Linear(32, 32),
+ nn.ELU())
+
+ # separated layers ("policy")
+ self.mean_layer = nn.Linear(32, self.num_actions)
+ self.log_std_parameter = nn.Parameter(torch.zeros(self.num_actions))
+
+ # separated layer ("value")
+ self.value_layer = nn.Linear(32, 1)
+
+ # cache for the shared network output (set by the "policy" role, cleared after the "value" role)
+ self._shared_output = None
+
+ # override the .act(...) method to disambiguate its call
+ def act(self, inputs, role):
+ if role == "policy":
+ return GaussianMixin.act(self, inputs, role)
+ elif role == "value":
+ return DeterministicMixin.act(self, inputs, role)
+
+ # forward the input to compute model output according to the specified role
+ def compute(self, inputs, role):
+ if role == "policy":
+ # save shared layers/network output to perform a single forward-pass
+ self._shared_output = self.net(inputs["states"])
+ return self.mean_layer(self._shared_output), self.log_std_parameter, {}
+ elif role == "value":
+ # use saved shared layers/network output to perform a single forward-pass, if it was saved
+ shared_output = self.net(inputs["states"]) if self._shared_output is None else self._shared_output
+ self._shared_output = None # reset saved shared output to prevent the use of erroneous data in subsequent steps
+ return self.value_layer(shared_output), {}
+
+
+# instantiate the shared model and pass the same instance to the other key
+models = {}
+models["policy"] = SharedModel(env.observation_space, env.action_space, env.device)
+models["value"] = models["policy"]
+# [end-mlp-single-forward-pass-torch]
+
+
+# [start-mlp-multi-forward-pass-torch]
import torch
import torch.nn as nn
@@ -45,4 +100,4 @@ def compute(self, inputs, role):
models = {}
models["policy"] = SharedModel(env.observation_space, env.action_space, env.device)
models["value"] = models["policy"]
-# [end-mlp-torch]
+# [end-mlp-multi-forward-pass-torch]
diff --git a/docs/source/snippets/wrapping.py b/docs/source/snippets/wrapping.py
index 4213ad57..564c1c3a 100644
--- a/docs/source/snippets/wrapping.py
+++ b/docs/source/snippets/wrapping.py
@@ -49,30 +49,30 @@
# [jax-end-omniverse-isaacgym-mt]
-# [pytorch-start-isaac-orbit]
+# [pytorch-start-isaaclab]
# import the environment wrapper and loader
from skrl.envs.wrappers.torch import wrap_env
-from skrl.envs.loaders.torch import load_isaac_orbit_env
+from skrl.envs.loaders.torch import load_isaaclab_env
# load the environment
-env = load_isaac_orbit_env(task_name="Isaac-Cartpole-v0")
+env = load_isaaclab_env(task_name="Isaac-Cartpole-v0")
# wrap the environment
-env = wrap_env(env) # or 'env = wrap_env(env, wrapper="isaac-orbit")'
-# [pytorch-end-isaac-orbit]
+env = wrap_env(env) # or 'env = wrap_env(env, wrapper="isaaclab")'
+# [pytorch-end-isaaclab]
-# [jax-start-isaac-orbit]
+# [jax-start-isaaclab]
# import the environment wrapper and loader
from skrl.envs.wrappers.jax import wrap_env
-from skrl.envs.loaders.jax import load_isaac_orbit_env
+from skrl.envs.loaders.jax import load_isaaclab_env
# load the environment
-env = load_isaac_orbit_env(task_name="Isaac-Cartpole-v0")
+env = load_isaaclab_env(task_name="Isaac-Cartpole-v0")
# wrap the environment
-env = wrap_env(env) # or 'env = wrap_env(env, wrapper="isaac-orbit")'
-# [jax-end-isaac-orbit]
+env = wrap_env(env) # or 'env = wrap_env(env, wrapper="isaaclab")'
+# [jax-end-isaaclab]
# [pytorch-start-isaacgym-preview4-make]
diff --git a/pyproject.toml b/pyproject.toml
index 6de499cb..cceeb007 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "skrl"
-version = "1.1.0"
+version = "1.2.0"
description = "Modular and flexible library for reinforcement learning on PyTorch and JAX"
readme = "README.md"
requires-python = ">=3.6"
diff --git a/skrl/__init__.py b/skrl/__init__.py
index 7a027642..e3e4fa0c 100644
--- a/skrl/__init__.py
+++ b/skrl/__init__.py
@@ -1,6 +1,7 @@
from typing import Union
import logging
+import os
import sys
import numpy as np
@@ -43,6 +44,69 @@ class _Config(object):
def __init__(self) -> None:
"""Machine learning framework specific configuration
"""
+
+ class PyTorch(object):
+ def __init__(self) -> None:
+ """PyTorch configuration
+ """
+ self._device = None
+ # torch.distributed config
+ self._local_rank = int(os.getenv("LOCAL_RANK", "0"))
+ self._rank = int(os.getenv("RANK", "0"))
+ self._world_size = int(os.getenv("WORLD_SIZE", "1"))
+ self._is_distributed = self._world_size > 1
+
+ @property
+ def local_rank(self) -> int:
+ """The rank of the worker/process (e.g.: GPU) within a local worker group (e.g.: node)
+
+ This property reads from the ``LOCAL_RANK`` environment variable (``0`` if it doesn't exist)
+ """
+ return self._local_rank
+
+ @property
+ def rank(self) -> int:
+ """The rank of the worker/process (e.g.: GPU) within a worker group (e.g.: across all nodes)
+
+ This property reads from the ``RANK`` environment variable (``0`` if it doesn't exist)
+ """
+ return self._rank
+
+ @property
+ def world_size(self) -> int:
+ """The total number of workers/process (e.g.: GPUs) in a worker group (e.g.: across all nodes)
+
+ This property reads from the ``WORLD_SIZE`` environment variable (``1`` if it doesn't exist)
+ """
+ return self._world_size
+
+ @property
+ def is_distributed(self) -> bool:
+ """Whether if running in a distributed environment
+
+ This property is ``True`` when the PyTorch distributed environment variable ``WORLD_SIZE`` is greater than 1
+ """
+ return self._is_distributed
+
+ @property
+ def device(self) -> "torch.device":
+ """Default device
+
+ The default device, unless specified, is ``cuda:0`` (or ``cuda:LOCAL_RANK`` in a distributed environment)
+ if CUDA is available, ``cpu`` otherwise
+ """
+ try:
+ import torch
+ if self._device is None:
+ return torch.device(f"cuda:{self._local_rank}" if torch.cuda.is_available() else "cpu")
+ return torch.device(self._device)
+ except ImportError:
+ return self._device
+
+ @device.setter
+ def device(self, device: Union[str, "torch.device"]) -> None:
+ self._device = device
+
class JAX(object):
def __init__(self) -> None:
"""JAX configuration
@@ -72,7 +136,8 @@ def key(self) -> "jax.Array":
if isinstance(self._key, np.ndarray):
try:
import jax
- self._key = jax.random.PRNGKey(self._key[1])
+ with jax.default_device(jax.devices("cpu")[0]):
+ self._key = jax.random.PRNGKey(self._key[1])
except ImportError:
pass
return self._key
@@ -83,11 +148,13 @@ def key(self, value: Union[int, "jax.Array"]) -> None:
# don't import JAX if it has not been imported before
if "jax" in sys.modules:
import jax
- value = jax.random.PRNGKey(value)
+ with jax.default_device(jax.devices("cpu")[0]):
+ value = jax.random.PRNGKey(value)
else:
value = np.array([0, value], dtype=np.uint32)
self._key = value
self.jax = JAX()
+ self.torch = PyTorch()
config = _Config()
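
A short sketch of how the new config.torch block added above can be queried (assuming this change is applied; the environment variables are the ones set by launchers such as torchrun):

import os

# normally set by the launcher (e.g. torchrun --nproc_per_node=N script.py)
os.environ.setdefault("LOCAL_RANK", "0")
os.environ.setdefault("RANK", "0")
os.environ.setdefault("WORLD_SIZE", "1")

from skrl import config

print(config.torch.local_rank)      # rank within the local node (LOCAL_RANK)
print(config.torch.rank)            # global rank (RANK)
print(config.torch.world_size)      # total number of processes (WORLD_SIZE)
print(config.torch.is_distributed)  # True only when WORLD_SIZE > 1
print(config.torch.device)          # cuda:LOCAL_RANK if CUDA is available, otherwise cpu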
diff --git a/skrl/agents/jax/a2c/a2c.py b/skrl/agents/jax/a2c/a2c.py
index fa1acb7b..378a84b7 100644
--- a/skrl/agents/jax/a2c/a2c.py
+++ b/skrl/agents/jax/a2c/a2c.py
@@ -46,9 +46,9 @@
"experiment": {
"directory": "", # experiment's parent directory
"experiment_name": "", # experiment name
- "write_interval": 250, # TensorBoard writing interval (timesteps)
+ "write_interval": "auto", # TensorBoard writing interval (timesteps)
- "checkpoint_interval": 1000, # interval for checkpoints (timesteps)
+ "checkpoint_interval": "auto", # interval for checkpoints (timesteps)
"store_separately": False, # whether to store checkpoints separately
"wandb": False, # whether to use Weights & Biases
@@ -251,8 +251,9 @@ def __init__(self,
else:
self._learning_rate = self._learning_rate_scheduler(self._learning_rate, **self.cfg["learning_rate_scheduler_kwargs"])
# optimizer
- self.policy_optimizer = Adam(model=self.policy, lr=self._learning_rate, grad_norm_clip=self._grad_norm_clip, scale=scale)
- self.value_optimizer = Adam(model=self.value, lr=self._learning_rate, grad_norm_clip=self._grad_norm_clip, scale=scale)
+ with jax.default_device(self.device):
+ self.policy_optimizer = Adam(model=self.policy, lr=self._learning_rate, grad_norm_clip=self._grad_norm_clip, scale=scale)
+ self.value_optimizer = Adam(model=self.value, lr=self._learning_rate, grad_norm_clip=self._grad_norm_clip, scale=scale)
self.checkpoint_modules["policy_optimizer"] = self.policy_optimizer
self.checkpoint_modules["value_optimizer"] = self.value_optimizer
diff --git a/skrl/agents/jax/base.py b/skrl/agents/jax/base.py
index 71e9d091..f5de7d2a 100644
--- a/skrl/agents/jax/base.py
+++ b/skrl/agents/jax/base.py
@@ -53,7 +53,10 @@ def __init__(self,
if device is None:
self.device = jax.devices()[0]
else:
- self.device = device if isinstance(device, jax.Device) else jax.devices(device)[0]
+ self.device = device
+ if type(device) == str:
+ device_type, device_index = f"{device}:0".split(':')[:2]
+ self.device = jax.devices(device_type)[int(device_index)]
if type(memory) is list:
self.memory = memory[0]
@@ -68,7 +71,7 @@ def __init__(self,
pass
self.tracking_data = collections.defaultdict(list)
- self.write_interval = self.cfg.get("experiment", {}).get("write_interval", 1000)
+ self.write_interval = self.cfg.get("experiment", {}).get("write_interval", "auto")
self._track_rewards = collections.deque(maxlen=100)
self._track_timesteps = collections.deque(maxlen=100)
@@ -79,7 +82,7 @@ def __init__(self,
# checkpoint
self.checkpoint_modules = {}
- self.checkpoint_interval = self.cfg.get("experiment", {}).get("checkpoint_interval", 1000)
+ self.checkpoint_interval = self.cfg.get("experiment", {}).get("checkpoint_interval", "auto")
self.checkpoint_store_separately = self.cfg.get("experiment", {}).get("store_separately", False)
self.checkpoint_best_modules = {"timestep": 0, "reward": -2 ** 31, "saved": False, "modules": {}}
@@ -141,10 +144,10 @@ def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None:
:param trainer_cfg: Trainer configuration
:type trainer_cfg: dict, optional
"""
+ trainer_cfg = trainer_cfg if trainer_cfg is not None else {}
# setup Weights & Biases
if self.cfg.get("experiment", {}).get("wandb", False):
# save experiment config
- trainer_cfg = trainer_cfg if trainer_cfg is not None else {}
try:
models_cfg = {k: v.net._modules for (k, v) in self.models.items()}
except AttributeError:
@@ -161,6 +164,8 @@ def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None:
wandb.init(**wandb_kwargs)
# main entry to log data for consumption and visualization by TensorBoard
+ if self.write_interval == "auto":
+ self.write_interval = int(trainer_cfg.get("timesteps", 0) / 100)
if self.write_interval > 0:
self.writer = None
# tensorboard via torch SummaryWriter
@@ -203,6 +208,8 @@ def add_scalar(self, tag, value, step):
logger.warning("The current running process will be terminated.")
exit()
+ if self.checkpoint_interval == "auto":
+ self.checkpoint_interval = int(trainer_cfg.get("timesteps", 0) / 10)
if self.checkpoint_interval > 0:
os.makedirs(os.path.join(self.experiment_dir, "checkpoints"), exist_ok=True)
@@ -470,7 +477,9 @@ def post_interaction(self, timestep: int, timesteps: int) -> None:
self.checkpoint_best_modules["timestep"] = timestep
self.checkpoint_best_modules["reward"] = reward
self.checkpoint_best_modules["saved"] = False
- self.checkpoint_best_modules["modules"] = {k: copy.deepcopy(self._get_internal_value(v)) for k, v in self.checkpoint_modules.items()}
+ with jax.default_device(self.device):
+ self.checkpoint_best_modules["modules"] = \
+ {k: copy.deepcopy(self._get_internal_value(v)) for k, v in self.checkpoint_modules.items()}
# write checkpoints
self.write_checkpoint(timestep, timesteps)
diff --git a/skrl/agents/jax/cem/cem.py b/skrl/agents/jax/cem/cem.py
index 65c401e9..a597cc0e 100644
--- a/skrl/agents/jax/cem/cem.py
+++ b/skrl/agents/jax/cem/cem.py
@@ -38,9 +38,9 @@
"experiment": {
"directory": "", # experiment's parent directory
"experiment_name": "", # experiment name
- "write_interval": 250, # TensorBoard writing interval (timesteps)
+ "write_interval": "auto", # TensorBoard writing interval (timesteps)
- "checkpoint_interval": 1000, # interval for checkpoints (timesteps)
+ "checkpoint_interval": "auto", # interval for checkpoints (timesteps)
"store_separately": False, # whether to store checkpoints separately
"wandb": False, # whether to use Weights & Biases
@@ -117,7 +117,8 @@ def __init__(self,
# set up optimizer and learning rate scheduler
if self.policy is not None:
- self.optimizer = Adam(model=self.policy, lr=self._learning_rate)
+ with jax.default_device(self.device):
+ self.optimizer = Adam(model=self.policy, lr=self._learning_rate)
if self._learning_rate_scheduler is not None:
self.scheduler = self._learning_rate_scheduler(self.optimizer, **self.cfg["learning_rate_scheduler_kwargs"])
diff --git a/skrl/agents/jax/ddpg/ddpg.py b/skrl/agents/jax/ddpg/ddpg.py
index feb0e12f..7c1a811e 100644
--- a/skrl/agents/jax/ddpg/ddpg.py
+++ b/skrl/agents/jax/ddpg/ddpg.py
@@ -48,9 +48,9 @@
"experiment": {
"directory": "", # experiment's parent directory
"experiment_name": "", # experiment name
- "write_interval": 250, # TensorBoard writing interval (timesteps)
+ "write_interval": "auto", # TensorBoard writing interval (timesteps)
- "checkpoint_interval": 1000, # interval for checkpoints (timesteps)
+ "checkpoint_interval": "auto", # interval for checkpoints (timesteps)
"store_separately": False, # whether to store checkpoints separately
"wandb": False, # whether to use Weights & Biases
@@ -188,8 +188,9 @@ def __init__(self,
# set up optimizers and learning rate schedulers
if self.policy is not None and self.critic is not None:
- self.policy_optimizer = Adam(model=self.policy, lr=self._actor_learning_rate, grad_norm_clip=self._grad_norm_clip)
- self.critic_optimizer = Adam(model=self.critic, lr=self._critic_learning_rate, grad_norm_clip=self._grad_norm_clip)
+ with jax.default_device(self.device):
+ self.policy_optimizer = Adam(model=self.policy, lr=self._actor_learning_rate, grad_norm_clip=self._grad_norm_clip)
+ self.critic_optimizer = Adam(model=self.critic, lr=self._critic_learning_rate, grad_norm_clip=self._grad_norm_clip)
if self._learning_rate_scheduler is not None:
self.policy_scheduler = self._learning_rate_scheduler(self.policy_optimizer, **self.cfg["learning_rate_scheduler_kwargs"])
self.critic_scheduler = self._learning_rate_scheduler(self.critic_optimizer, **self.cfg["learning_rate_scheduler_kwargs"])
@@ -384,13 +385,14 @@ def _update(self, timestep: int, timesteps: int) -> None:
:param timesteps: Number of timesteps
:type timesteps: int
"""
- # sample a batch from memory
- sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
- self.memory.sample(names=self._tensors_names, batch_size=self._batch_size)[0]
# gradient steps
for gradient_step in range(self._gradient_steps):
+ # sample a batch from memory
+ sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
+ self.memory.sample(names=self._tensors_names, batch_size=self._batch_size)[0]
+
sampled_states = self._state_preprocessor(sampled_states, train=True)
sampled_next_states = self._state_preprocessor(sampled_next_states, train=True)
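
The DDPG change above (mirrored for TD3) moves the mini-batch sampling inside the gradient-step loop, so every gradient step trains on a freshly drawn batch instead of reusing a single batch for all steps. A toy sketch of the resulting loop shape (the replay memory and update step are stand-ins, not skrl API):

import random

replay = list(range(1000))       # stand-in for the replay memory
def update(batch):               # stand-in for one gradient step
    pass

gradient_steps, batch_size = 4, 64
for _ in range(gradient_steps):
    batch = random.sample(replay, batch_size)   # fresh mini-batch for each gradient step
    update(batch)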
diff --git a/skrl/agents/jax/dqn/ddqn.py b/skrl/agents/jax/dqn/ddqn.py
index 47f7f53b..5bd17fc4 100644
--- a/skrl/agents/jax/dqn/ddqn.py
+++ b/skrl/agents/jax/dqn/ddqn.py
@@ -47,9 +47,9 @@
"experiment": {
"directory": "", # experiment's parent directory
"experiment_name": "", # experiment name
- "write_interval": 250, # TensorBoard writing interval (timesteps)
+ "write_interval": "auto", # TensorBoard writing interval (timesteps)
- "checkpoint_interval": 1000, # interval for checkpoints (timesteps)
+ "checkpoint_interval": "auto", # interval for checkpoints (timesteps)
"store_separately": False, # whether to store checkpoints separately
"wandb": False, # whether to use Weights & Biases
@@ -161,7 +161,8 @@ def __init__(self,
# set up optimizer and learning rate scheduler
if self.q_network is not None:
- self.optimizer = Adam(model=self.q_network, lr=self._learning_rate)
+ with jax.default_device(self.device):
+ self.optimizer = Adam(model=self.q_network, lr=self._learning_rate)
if self._learning_rate_scheduler is not None:
self.scheduler = self._learning_rate_scheduler(self.optimizer, **self.cfg["learning_rate_scheduler_kwargs"])
diff --git a/skrl/agents/jax/dqn/dqn.py b/skrl/agents/jax/dqn/dqn.py
index dcec3140..6bcfac72 100644
--- a/skrl/agents/jax/dqn/dqn.py
+++ b/skrl/agents/jax/dqn/dqn.py
@@ -47,9 +47,9 @@
"experiment": {
"directory": "", # experiment's parent directory
"experiment_name": "", # experiment name
- "write_interval": 250, # TensorBoard writing interval (timesteps)
+ "write_interval": "auto", # TensorBoard writing interval (timesteps)
- "checkpoint_interval": 1000, # interval for checkpoints (timesteps)
+ "checkpoint_interval": "auto", # interval for checkpoints (timesteps)
"store_separately": False, # whether to store checkpoints separately
"wandb": False, # whether to use Weights & Biases
@@ -158,7 +158,8 @@ def __init__(self,
# set up optimizer and learning rate scheduler
if self.q_network is not None:
- self.optimizer = Adam(model=self.q_network, lr=self._learning_rate)
+ with jax.default_device(self.device):
+ self.optimizer = Adam(model=self.q_network, lr=self._learning_rate)
if self._learning_rate_scheduler is not None:
self.scheduler = self._learning_rate_scheduler(self.optimizer, **self.cfg["learning_rate_scheduler_kwargs"])
diff --git a/skrl/agents/jax/ppo/ppo.py b/skrl/agents/jax/ppo/ppo.py
index 3437b049..f8a09dd1 100644
--- a/skrl/agents/jax/ppo/ppo.py
+++ b/skrl/agents/jax/ppo/ppo.py
@@ -53,9 +53,9 @@
"experiment": {
"directory": "", # experiment's parent directory
"experiment_name": "", # experiment name
- "write_interval": 250, # TensorBoard writing interval (timesteps)
+ "write_interval": "auto", # TensorBoard writing interval (timesteps)
- "checkpoint_interval": 1000, # interval for checkpoints (timesteps)
+ "checkpoint_interval": "auto", # interval for checkpoints (timesteps)
"store_separately": False, # whether to store checkpoints separately
"wandb": False, # whether to use Weights & Biases
@@ -277,8 +277,9 @@ def __init__(self,
else:
self._learning_rate = self._learning_rate_scheduler(self._learning_rate, **self.cfg["learning_rate_scheduler_kwargs"])
# optimizer
- self.policy_optimizer = Adam(model=self.policy, lr=self._learning_rate, grad_norm_clip=self._grad_norm_clip, scale=scale)
- self.value_optimizer = Adam(model=self.value, lr=self._learning_rate, grad_norm_clip=self._grad_norm_clip, scale=scale)
+ with jax.default_device(self.device):
+ self.policy_optimizer = Adam(model=self.policy, lr=self._learning_rate, grad_norm_clip=self._grad_norm_clip, scale=scale)
+ self.value_optimizer = Adam(model=self.value, lr=self._learning_rate, grad_norm_clip=self._grad_norm_clip, scale=scale)
self.checkpoint_modules["policy_optimizer"] = self.policy_optimizer
self.checkpoint_modules["value_optimizer"] = self.value_optimizer
diff --git a/skrl/agents/jax/rpo/rpo.py b/skrl/agents/jax/rpo/rpo.py
index 3e90d06d..deba2dcd 100644
--- a/skrl/agents/jax/rpo/rpo.py
+++ b/skrl/agents/jax/rpo/rpo.py
@@ -54,9 +54,9 @@
"experiment": {
"directory": "", # experiment's parent directory
"experiment_name": "", # experiment name
- "write_interval": 250, # TensorBoard writing interval (timesteps)
+ "write_interval": "auto", # TensorBoard writing interval (timesteps)
- "checkpoint_interval": 1000, # interval for checkpoints (timesteps)
+ "checkpoint_interval": "auto", # interval for checkpoints (timesteps)
"store_separately": False, # whether to store checkpoints separately
"wandb": False, # whether to use Weights & Biases
@@ -281,8 +281,9 @@ def __init__(self,
else:
self._learning_rate = self._learning_rate_scheduler(self._learning_rate, **self.cfg["learning_rate_scheduler_kwargs"])
# optimizer
- self.policy_optimizer = Adam(model=self.policy, lr=self._learning_rate, grad_norm_clip=self._grad_norm_clip, scale=scale)
- self.value_optimizer = Adam(model=self.value, lr=self._learning_rate, grad_norm_clip=self._grad_norm_clip, scale=scale)
+ with jax.default_device(self.device):
+ self.policy_optimizer = Adam(model=self.policy, lr=self._learning_rate, grad_norm_clip=self._grad_norm_clip, scale=scale)
+ self.value_optimizer = Adam(model=self.value, lr=self._learning_rate, grad_norm_clip=self._grad_norm_clip, scale=scale)
self.checkpoint_modules["policy_optimizer"] = self.policy_optimizer
self.checkpoint_modules["value_optimizer"] = self.value_optimizer
diff --git a/skrl/agents/jax/sac/sac.py b/skrl/agents/jax/sac/sac.py
index 75cecf11..fd157616 100644
--- a/skrl/agents/jax/sac/sac.py
+++ b/skrl/agents/jax/sac/sac.py
@@ -47,9 +47,9 @@
"experiment": {
"directory": "", # experiment's parent directory
"experiment_name": "", # experiment name
- "write_interval": 250, # TensorBoard writing interval (timesteps)
+ "write_interval": "auto", # TensorBoard writing interval (timesteps)
- "checkpoint_interval": 1000, # interval for checkpoints (timesteps)
+ "checkpoint_interval": "auto", # interval for checkpoints (timesteps)
"store_separately": False, # whether to store checkpoints separately
"wandb": False, # whether to use Weights & Biases
@@ -219,16 +219,18 @@ class StateDict(flax.struct.PyTreeNode):
def value(self):
return self.state_dict.params["params"]
- self.log_entropy_coefficient = _LogEntropyCoefficient(self._entropy_coefficient)
- self.entropy_optimizer = Adam(model=self.log_entropy_coefficient, lr=self._entropy_learning_rate)
+ with jax.default_device(self.device):
+ self.log_entropy_coefficient = _LogEntropyCoefficient(self._entropy_coefficient)
+ self.entropy_optimizer = Adam(model=self.log_entropy_coefficient, lr=self._entropy_learning_rate)
self.checkpoint_modules["entropy_optimizer"] = self.entropy_optimizer
# set up optimizers and learning rate schedulers
if self.policy is not None and self.critic_1 is not None and self.critic_2 is not None:
- self.policy_optimizer = Adam(model=self.policy, lr=self._actor_learning_rate, grad_norm_clip=self._grad_norm_clip)
- self.critic_1_optimizer = Adam(model=self.critic_1, lr=self._critic_learning_rate, grad_norm_clip=self._grad_norm_clip)
- self.critic_2_optimizer = Adam(model=self.critic_2, lr=self._critic_learning_rate, grad_norm_clip=self._grad_norm_clip)
+ with jax.default_device(self.device):
+ self.policy_optimizer = Adam(model=self.policy, lr=self._actor_learning_rate, grad_norm_clip=self._grad_norm_clip)
+ self.critic_1_optimizer = Adam(model=self.critic_1, lr=self._critic_learning_rate, grad_norm_clip=self._grad_norm_clip)
+ self.critic_2_optimizer = Adam(model=self.critic_2, lr=self._critic_learning_rate, grad_norm_clip=self._grad_norm_clip)
if self._learning_rate_scheduler is not None:
self.policy_scheduler = self._learning_rate_scheduler(self.policy_optimizer, **self.cfg["learning_rate_scheduler_kwargs"])
self.critic_1_scheduler = self._learning_rate_scheduler(self.critic_1_optimizer, **self.cfg["learning_rate_scheduler_kwargs"])
diff --git a/skrl/agents/jax/td3/td3.py b/skrl/agents/jax/td3/td3.py
index 1c2544d8..3248aebf 100644
--- a/skrl/agents/jax/td3/td3.py
+++ b/skrl/agents/jax/td3/td3.py
@@ -53,9 +53,9 @@
"experiment": {
"directory": "", # experiment's parent directory
"experiment_name": "", # experiment name
- "write_interval": 250, # TensorBoard writing interval (timesteps)
+ "write_interval": "auto", # TensorBoard writing interval (timesteps)
- "checkpoint_interval": 1000, # interval for checkpoints (timesteps)
+ "checkpoint_interval": "auto", # interval for checkpoints (timesteps)
"store_separately": False, # whether to store checkpoints separately
"wandb": False, # whether to use Weights & Biases
@@ -219,9 +219,10 @@ def __init__(self,
# set up optimizers and learning rate schedulers
if self.policy is not None and self.critic_1 is not None and self.critic_2 is not None:
- self.policy_optimizer = Adam(model=self.policy, lr=self._actor_learning_rate, grad_norm_clip=self._grad_norm_clip)
- self.critic_1_optimizer = Adam(model=self.critic_1, lr=self._critic_learning_rate, grad_norm_clip=self._grad_norm_clip)
- self.critic_2_optimizer = Adam(model=self.critic_2, lr=self._critic_learning_rate, grad_norm_clip=self._grad_norm_clip)
+ with jax.default_device(self.device):
+ self.policy_optimizer = Adam(model=self.policy, lr=self._actor_learning_rate, grad_norm_clip=self._grad_norm_clip)
+ self.critic_1_optimizer = Adam(model=self.critic_1, lr=self._critic_learning_rate, grad_norm_clip=self._grad_norm_clip)
+ self.critic_2_optimizer = Adam(model=self.critic_2, lr=self._critic_learning_rate, grad_norm_clip=self._grad_norm_clip)
if self._learning_rate_scheduler is not None:
self.policy_scheduler = self._learning_rate_scheduler(self.policy_optimizer, **self.cfg["learning_rate_scheduler_kwargs"])
self.critic_1_scheduler = self._learning_rate_scheduler(self.critic_1_optimizer, **self.cfg["learning_rate_scheduler_kwargs"])
@@ -422,12 +423,12 @@ def _update(self, timestep: int, timesteps: int) -> None:
:param timesteps: Number of timesteps
:type timesteps: int
"""
- # sample a batch from memory
- sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
- self.memory.sample(names=self._tensors_names, batch_size=self._batch_size)[0]
# gradient steps
for gradient_step in range(self._gradient_steps):
+ # sample a batch from memory
+ sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
+ self.memory.sample(names=self._tensors_names, batch_size=self._batch_size)[0]
sampled_states = self._state_preprocessor(sampled_states, train=True)
sampled_next_states = self._state_preprocessor(sampled_next_states, train=True)
diff --git a/skrl/agents/torch/a2c/a2c.py b/skrl/agents/torch/a2c/a2c.py
index 97af7cb3..fd495516 100644
--- a/skrl/agents/torch/a2c/a2c.py
+++ b/skrl/agents/torch/a2c/a2c.py
@@ -9,6 +9,7 @@
import torch.nn as nn
import torch.nn.functional as F
+from skrl import config, logger
from skrl.agents.torch import Agent
from skrl.memories.torch import Memory
from skrl.models.torch import Model
@@ -45,9 +46,9 @@
"experiment": {
"directory": "", # experiment's parent directory
"experiment_name": "", # experiment name
- "write_interval": 250, # TensorBoard writing interval (timesteps)
+ "write_interval": "auto", # TensorBoard writing interval (timesteps)
- "checkpoint_interval": 1000, # interval for checkpoints (timesteps)
+ "checkpoint_interval": "auto", # interval for checkpoints (timesteps)
"store_separately": False, # whether to store checkpoints separately
"wandb": False, # whether to use Weights & Biases
@@ -104,6 +105,14 @@ def __init__(self,
self.checkpoint_modules["policy"] = self.policy
self.checkpoint_modules["value"] = self.value
+ # broadcast models' parameters in distributed runs
+ if config.torch.is_distributed:
+ logger.info(f"Broadcasting models' parameters")
+ if self.policy is not None:
+ self.policy.broadcast_parameters()
+ if self.value is not None and self.policy is not self.value:
+ self.value.broadcast_parameters()
+
# configuration
self._mini_batches = self.cfg["mini_batches"]
self._rollouts = self.cfg["rollouts"]
@@ -391,6 +400,10 @@ def compute_gae(rewards: torch.Tensor,
# optimization step
self.optimizer.zero_grad()
(policy_loss + entropy_loss + value_loss).backward()
+ if config.torch.is_distributed:
+ self.policy.reduce_parameters()
+ if self.policy is not self.value:
+ self.value.reduce_parameters()
if self._grad_norm_clip > 0:
if self.policy is self.value:
nn.utils.clip_grad_norm_(self.policy.parameters(), self._grad_norm_clip)
@@ -407,7 +420,7 @@ def compute_gae(rewards: torch.Tensor,
# update learning rate
if self._learning_rate_scheduler:
if isinstance(self.scheduler, KLAdaptiveLR):
- self.scheduler.step(torch.tensor(kl_divergences).mean())
+ self.scheduler.step(torch.tensor(kl_divergences, device=self.device).mean())
else:
self.scheduler.step()
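
The scheduler fix builds the mean KL divergence directly on the agent's device; the previous torch.tensor(kl_divergences) allocated on the CPU and could trigger an implicit host/device transfer when the KL-adaptive scheduler consumed the value. A small sketch of the pattern (the KL values are illustrative):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
kl_divergences = [0.012, 0.009, 0.015]   # per-mini-batch KL estimates

# allocating on the target device up front avoids a CPU tensor plus a later copy
mean_kl = torch.tensor(kl_divergences, device=device).mean()
print(mean_kl.device, float(mean_kl))
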
diff --git a/skrl/agents/torch/a2c/a2c_rnn.py b/skrl/agents/torch/a2c/a2c_rnn.py
index 9b24bc09..8cecc41e 100644
--- a/skrl/agents/torch/a2c/a2c_rnn.py
+++ b/skrl/agents/torch/a2c/a2c_rnn.py
@@ -9,6 +9,7 @@
import torch.nn as nn
import torch.nn.functional as F
+from skrl import config, logger
from skrl.agents.torch import Agent
from skrl.memories.torch import Memory
from skrl.models.torch import Model
@@ -45,9 +46,9 @@
"experiment": {
"directory": "", # experiment's parent directory
"experiment_name": "", # experiment name
- "write_interval": 250, # TensorBoard writing interval (timesteps)
+ "write_interval": "auto", # TensorBoard writing interval (timesteps)
- "checkpoint_interval": 1000, # interval for checkpoints (timesteps)
+ "checkpoint_interval": "auto", # interval for checkpoints (timesteps)
"store_separately": False, # whether to store checkpoints separately
"wandb": False, # whether to use Weights & Biases
@@ -104,6 +105,14 @@ def __init__(self,
self.checkpoint_modules["policy"] = self.policy
self.checkpoint_modules["value"] = self.value
+ # broadcast models' parameters in distributed runs
+ if config.torch.is_distributed:
+ logger.info(f"Broadcasting models' parameters")
+ if self.policy is not None:
+ self.policy.broadcast_parameters()
+ if self.value is not None and self.policy is not self.value:
+ self.value.broadcast_parameters()
+
# configuration
self._mini_batches = self.cfg["mini_batches"]
self._rollouts = self.cfg["rollouts"]
@@ -462,6 +471,10 @@ def compute_gae(rewards: torch.Tensor,
# optimization step
self.optimizer.zero_grad()
(policy_loss + entropy_loss + value_loss).backward()
+ if config.torch.is_distributed:
+ self.policy.reduce_parameters()
+ if self.policy is not self.value:
+ self.value.reduce_parameters()
if self._grad_norm_clip > 0:
if self.policy is self.value:
nn.utils.clip_grad_norm_(self.policy.parameters(), self._grad_norm_clip)
@@ -478,7 +491,7 @@ def compute_gae(rewards: torch.Tensor,
# update learning rate
if self._learning_rate_scheduler:
if isinstance(self.scheduler, KLAdaptiveLR):
- self.scheduler.step(torch.tensor(kl_divergences).mean())
+ self.scheduler.step(torch.tensor(kl_divergences, device=self.device).mean())
else:
self.scheduler.step()
diff --git a/skrl/agents/torch/amp/amp.py b/skrl/agents/torch/amp/amp.py
index e9311dae..1a648c22 100644
--- a/skrl/agents/torch/amp/amp.py
+++ b/skrl/agents/torch/amp/amp.py
@@ -10,6 +10,7 @@
import torch.nn as nn
import torch.nn.functional as F
+from skrl import config, logger
from skrl.agents.torch import Agent
from skrl.memories.torch import Memory
from skrl.models.torch import Model
@@ -62,9 +63,9 @@
"experiment": {
"directory": "", # experiment's parent directory
"experiment_name": "", # experiment name
- "write_interval": 250, # TensorBoard writing interval (timesteps)
+ "write_interval": "auto", # TensorBoard writing interval (timesteps)
- "checkpoint_interval": 1000, # interval for checkpoints (timesteps)
+ "checkpoint_interval": "auto", # interval for checkpoints (timesteps)
"store_separately": False, # whether to store checkpoints separately
"wandb": False, # whether to use Weights & Biases
@@ -147,6 +148,16 @@ def __init__(self,
self.checkpoint_modules["value"] = self.value
self.checkpoint_modules["discriminator"] = self.discriminator
+ # broadcast models' parameters in distributed runs
+ if config.torch.is_distributed:
+ logger.info(f"Broadcasting models' parameters")
+ if self.policy is not None:
+ self.policy.broadcast_parameters()
+ if self.value is not None:
+ self.value.broadcast_parameters()
+ if self.discriminator is not None:
+ self.discriminator.broadcast_parameters()
+
# configuration
self._learning_epochs = self.cfg["learning_epochs"]
self._mini_batches = self.cfg["mini_batches"]
@@ -554,6 +565,10 @@ def compute_gae(rewards: torch.Tensor,
# optimization step
self.optimizer.zero_grad()
(policy_loss + entropy_loss + value_loss + discriminator_loss).backward()
+ if config.torch.is_distributed:
+ self.policy.reduce_parameters()
+ self.value.reduce_parameters()
+ self.discriminator.reduce_parameters()
if self._grad_norm_clip > 0:
nn.utils.clip_grad_norm_(itertools.chain(self.policy.parameters(),
self.value.parameters(),
@@ -571,7 +586,7 @@ def compute_gae(rewards: torch.Tensor,
if self._learning_rate_scheduler:
self.scheduler.step()
- # update AMP repaly buffer
+ # update AMP replay buffer
self.reply_buffer.add_samples(states=amp_states.view(-1, amp_states.shape[-1]))
# record data
diff --git a/skrl/agents/torch/base.py b/skrl/agents/torch/base.py
index 237a0953..d23f27cc 100644
--- a/skrl/agents/torch/base.py
+++ b/skrl/agents/torch/base.py
@@ -11,7 +11,7 @@
import torch
from torch.utils.tensorboard import SummaryWriter
-from skrl import logger
+from skrl import config, logger
from skrl.memories.torch import Memory
from skrl.models.torch import Model
@@ -61,7 +61,7 @@ def __init__(self,
model.to(model.device)
self.tracking_data = collections.defaultdict(list)
- self.write_interval = self.cfg.get("experiment", {}).get("write_interval", 1000)
+ self.write_interval = self.cfg.get("experiment", {}).get("write_interval", "auto")
self._track_rewards = collections.deque(maxlen=100)
self._track_timesteps = collections.deque(maxlen=100)
@@ -72,7 +72,7 @@ def __init__(self,
# checkpoint
self.checkpoint_modules = {}
- self.checkpoint_interval = self.cfg.get("experiment", {}).get("checkpoint_interval", 1000)
+ self.checkpoint_interval = self.cfg.get("experiment", {}).get("checkpoint_interval", "auto")
self.checkpoint_store_separately = self.cfg.get("experiment", {}).get("store_separately", False)
self.checkpoint_best_modules = {"timestep": 0, "reward": -2 ** 31, "saved": False, "modules": {}}
@@ -85,6 +85,12 @@ def __init__(self,
experiment_name = "{}_{}".format(datetime.datetime.now().strftime("%y-%m-%d_%H-%M-%S-%f"), self.__class__.__name__)
self.experiment_dir = os.path.join(directory, experiment_name)
+ # set up distributed runs
+ if config.torch.is_distributed:
+ logger.info(f"Distributed (rank: {config.torch.rank}, local rank: {config.torch.local_rank}, world size: {config.torch.world_size})")
+ torch.distributed.init_process_group("nccl", rank=config.torch.rank, world_size=config.torch.world_size)
+ torch.cuda.set_device(config.torch.local_rank)
+
def __str__(self) -> str:
"""Generate a representation of the agent as string
@@ -129,34 +135,45 @@ def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None:
"""Initialize the agent
This method should be called before the agent is used.
- It will initialize the TensoBoard writer (and optionally Weights & Biases) and create the checkpoints directory
+ It will initialize the TensorBoard writer (and optionally Weights & Biases) and create the checkpoints directory
:param trainer_cfg: Trainer configuration
:type trainer_cfg: dict, optional
"""
+ trainer_cfg = trainer_cfg if trainer_cfg is not None else {}
+
+ # update agent configuration to avoid duplicated logging/checking in distributed runs
+ if config.torch.is_distributed and config.torch.rank:
+ self.write_interval = 0
+ self.checkpoint_interval = 0
+ # TODO: disable wandb
+
# setup Weights & Biases
if self.cfg.get("experiment", {}).get("wandb", False):
- # save experiment config
- trainer_cfg = trainer_cfg if trainer_cfg is not None else {}
+ # save experiment configuration
try:
models_cfg = {k: v.net._modules for (k, v) in self.models.items()}
except AttributeError:
models_cfg = {k: v._modules for (k, v) in self.models.items()}
- config={**self.cfg, **trainer_cfg, **models_cfg}
+ wandb_config={**self.cfg, **trainer_cfg, **models_cfg}
# set default values
wandb_kwargs = copy.deepcopy(self.cfg.get("experiment", {}).get("wandb_kwargs", {}))
wandb_kwargs.setdefault("name", os.path.split(self.experiment_dir)[-1])
wandb_kwargs.setdefault("sync_tensorboard", True)
wandb_kwargs.setdefault("config", {})
- wandb_kwargs["config"].update(config)
+ wandb_kwargs["config"].update(wandb_config)
# init Weights & Biases
import wandb
wandb.init(**wandb_kwargs)
# main entry to log data for consumption and visualization by TensorBoard
+ if self.write_interval == "auto":
+ self.write_interval = int(trainer_cfg.get("timesteps", 0) / 100)
if self.write_interval > 0:
self.writer = SummaryWriter(log_dir=self.experiment_dir)
+ if self.checkpoint_interval == "auto":
+ self.checkpoint_interval = int(trainer_cfg.get("timesteps", 0) / 10)
if self.checkpoint_interval > 0:
os.makedirs(os.path.join(self.experiment_dir, "checkpoints"), exist_ok=True)
@@ -382,7 +399,7 @@ def migrate(self,
name_map: Mapping[str, Mapping[str, str]] = {},
auto_mapping: bool = True,
verbose: bool = False) -> bool:
- """Migrate the specified extrernal checkpoint to the current agent
+ """Migrate the specified external checkpoint to the current agent
The final storage device is determined by the constructor of the agent.
Only files generated by the *rl_games* library are supported at the moment
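
The base.py changes resolve the new "auto" intervals from the trainer's timesteps (roughly 100 TensorBoard writes and 10 checkpoints per run) and silence logging and checkpointing on non-zero ranks in distributed runs. A standalone sketch that mirrors the resolution logic in the hunk above (not the skrl API itself):

def resolve_intervals(write_interval, checkpoint_interval, timesteps, rank=0):
    # non-zero ranks neither write TensorBoard data nor save checkpoints
    if rank:
        return 0, 0
    if write_interval == "auto":
        write_interval = int(timesteps / 100)      # ~100 writes per run
    if checkpoint_interval == "auto":
        checkpoint_interval = int(timesteps / 10)  # ~10 checkpoints per run
    return write_interval, checkpoint_interval

print(resolve_intervals("auto", "auto", timesteps=100000))           # (1000, 10000)
print(resolve_intervals("auto", "auto", timesteps=100000, rank=1))   # (0, 0)
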
diff --git a/skrl/agents/torch/cem/cem.py b/skrl/agents/torch/cem/cem.py
index a99eff5a..06f5ba24 100644
--- a/skrl/agents/torch/cem/cem.py
+++ b/skrl/agents/torch/cem/cem.py
@@ -35,9 +35,9 @@
"experiment": {
"directory": "", # experiment's parent directory
"experiment_name": "", # experiment name
- "write_interval": 250, # TensorBoard writing interval (timesteps)
+ "write_interval": "auto", # TensorBoard writing interval (timesteps)
- "checkpoint_interval": 1000, # interval for checkpoints (timesteps)
+ "checkpoint_interval": "auto", # interval for checkpoints (timesteps)
"store_separately": False, # whether to store checkpoints separately
"wandb": False, # whether to use Weights & Biases
diff --git a/skrl/agents/torch/ddpg/ddpg.py b/skrl/agents/torch/ddpg/ddpg.py
index a5270909..88f6ccc9 100644
--- a/skrl/agents/torch/ddpg/ddpg.py
+++ b/skrl/agents/torch/ddpg/ddpg.py
@@ -8,6 +8,7 @@
import torch.nn as nn
import torch.nn.functional as F
+from skrl import config, logger
from skrl.agents.torch import Agent
from skrl.memories.torch import Memory
from skrl.models.torch import Model
@@ -46,9 +47,9 @@
"experiment": {
"directory": "", # experiment's parent directory
"experiment_name": "", # experiment name
- "write_interval": 250, # TensorBoard writing interval (timesteps)
+ "write_interval": "auto", # TensorBoard writing interval (timesteps)
- "checkpoint_interval": 1000, # interval for checkpoints (timesteps)
+ "checkpoint_interval": "auto", # interval for checkpoints (timesteps)
"store_separately": False, # whether to store checkpoints separately
"wandb": False, # whether to use Weights & Biases
@@ -109,6 +110,14 @@ def __init__(self,
self.checkpoint_modules["critic"] = self.critic
self.checkpoint_modules["target_critic"] = self.target_critic
+ # broadcast models' parameters in distributed runs
+ if config.torch.is_distributed:
+ logger.info(f"Broadcasting models' parameters")
+ if self.policy is not None:
+ self.policy.broadcast_parameters()
+ if self.critic is not None:
+ self.critic.broadcast_parameters()
+
if self.target_policy is not None and self.target_critic is not None:
# freeze target networks with respect to optimizers (update via .update_parameters())
self.target_policy.freeze_parameters(True)
@@ -315,13 +324,14 @@ def _update(self, timestep: int, timesteps: int) -> None:
:param timesteps: Number of timesteps
:type timesteps: int
"""
- # sample a batch from memory
- sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
- self.memory.sample(names=self._tensors_names, batch_size=self._batch_size)[0]
# gradient steps
for gradient_step in range(self._gradient_steps):
+ # sample a batch from memory
+ sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
+ self.memory.sample(names=self._tensors_names, batch_size=self._batch_size)[0]
+
sampled_states = self._state_preprocessor(sampled_states, train=True)
sampled_next_states = self._state_preprocessor(sampled_next_states, train=True)
@@ -340,6 +350,8 @@ def _update(self, timestep: int, timesteps: int) -> None:
# optimization step (critic)
self.critic_optimizer.zero_grad()
critic_loss.backward()
+ if config.torch.is_distributed:
+ self.critic.reduce_parameters()
if self._grad_norm_clip > 0:
nn.utils.clip_grad_norm_(self.critic.parameters(), self._grad_norm_clip)
self.critic_optimizer.step()
@@ -353,6 +365,8 @@ def _update(self, timestep: int, timesteps: int) -> None:
# optimization step (policy)
self.policy_optimizer.zero_grad()
policy_loss.backward()
+ if config.torch.is_distributed:
+ self.policy.reduce_parameters()
if self._grad_norm_clip > 0:
nn.utils.clip_grad_norm_(self.policy.parameters(), self._grad_norm_clip)
self.policy_optimizer.step()
diff --git a/skrl/agents/torch/ddpg/ddpg_rnn.py b/skrl/agents/torch/ddpg/ddpg_rnn.py
index e1a8142e..436184c1 100644
--- a/skrl/agents/torch/ddpg/ddpg_rnn.py
+++ b/skrl/agents/torch/ddpg/ddpg_rnn.py
@@ -8,6 +8,7 @@
import torch.nn as nn
import torch.nn.functional as F
+from skrl import config, logger
from skrl.agents.torch import Agent
from skrl.memories.torch import Memory
from skrl.models.torch import Model
@@ -46,9 +47,9 @@
"experiment": {
"directory": "", # experiment's parent directory
"experiment_name": "", # experiment name
- "write_interval": 250, # TensorBoard writing interval (timesteps)
+ "write_interval": "auto", # TensorBoard writing interval (timesteps)
- "checkpoint_interval": 1000, # interval for checkpoints (timesteps)
+ "checkpoint_interval": "auto", # interval for checkpoints (timesteps)
"store_separately": False, # whether to store checkpoints separately
"wandb": False, # whether to use Weights & Biases
@@ -109,6 +110,14 @@ def __init__(self,
self.checkpoint_modules["critic"] = self.critic
self.checkpoint_modules["target_critic"] = self.target_critic
+ # broadcast models' parameters in distributed runs
+ if config.torch.is_distributed:
+ logger.info(f"Broadcasting models' parameters")
+ if self.policy is not None:
+ self.policy.broadcast_parameters()
+ if self.critic is not None:
+ self.critic.broadcast_parameters()
+
if self.target_policy is not None and self.target_critic is not None:
# freeze target networks with respect to optimizers (update via .update_parameters())
self.target_policy.freeze_parameters(True)
@@ -382,6 +391,8 @@ def _update(self, timestep: int, timesteps: int) -> None:
# optimization step (critic)
self.critic_optimizer.zero_grad()
critic_loss.backward()
+ if config.torch.is_distributed:
+ self.critic.reduce_parameters()
if self._grad_norm_clip > 0:
nn.utils.clip_grad_norm_(self.critic.parameters(), self._grad_norm_clip)
self.critic_optimizer.step()
@@ -395,6 +406,8 @@ def _update(self, timestep: int, timesteps: int) -> None:
# optimization step (policy)
self.policy_optimizer.zero_grad()
policy_loss.backward()
+ if config.torch.is_distributed:
+ self.policy.reduce_parameters()
if self._grad_norm_clip > 0:
nn.utils.clip_grad_norm_(self.policy.parameters(), self._grad_norm_clip)
self.policy_optimizer.step()
diff --git a/skrl/agents/torch/dqn/ddqn.py b/skrl/agents/torch/dqn/ddqn.py
index 84352027..8d181ac8 100644
--- a/skrl/agents/torch/dqn/ddqn.py
+++ b/skrl/agents/torch/dqn/ddqn.py
@@ -8,6 +8,7 @@
import torch
import torch.nn.functional as F
+from skrl import config, logger
from skrl.agents.torch import Agent
from skrl.memories.torch import Memory
from skrl.models.torch import Model
@@ -45,9 +46,9 @@
"experiment": {
"directory": "", # experiment's parent directory
"experiment_name": "", # experiment name
- "write_interval": 250, # TensorBoard writing interval (timesteps)
+ "write_interval": "auto", # TensorBoard writing interval (timesteps)
- "checkpoint_interval": 1000, # interval for checkpoints (timesteps)
+ "checkpoint_interval": "auto", # interval for checkpoints (timesteps)
"store_separately": False, # whether to store checkpoints separately
"wandb": False, # whether to use Weights & Biases
@@ -104,6 +105,12 @@ def __init__(self,
self.checkpoint_modules["q_network"] = self.q_network
self.checkpoint_modules["target_q_network"] = self.target_q_network
+ # broadcast models' parameters in distributed runs
+ if config.torch.is_distributed:
+ logger.info(f"Broadcasting models' parameters")
+ if self.q_network is not None:
+ self.q_network.broadcast_parameters()
+
if self.target_q_network is not None:
# freeze target networks with respect to optimizers (update via .update_parameters())
self.target_q_network.freeze_parameters(True)
@@ -303,6 +310,8 @@ def _update(self, timestep: int, timesteps: int) -> None:
# optimize Q-network
self.optimizer.zero_grad()
q_network_loss.backward()
+ if config.torch.is_distributed:
+ self.q_network.reduce_parameters()
self.optimizer.step()
# update target network
diff --git a/skrl/agents/torch/dqn/dqn.py b/skrl/agents/torch/dqn/dqn.py
index 4c485524..c7f4f709 100644
--- a/skrl/agents/torch/dqn/dqn.py
+++ b/skrl/agents/torch/dqn/dqn.py
@@ -8,6 +8,7 @@
import torch
import torch.nn.functional as F
+from skrl import config, logger
from skrl.agents.torch import Agent
from skrl.memories.torch import Memory
from skrl.models.torch import Model
@@ -45,9 +46,9 @@
"experiment": {
"directory": "", # experiment's parent directory
"experiment_name": "", # experiment name
- "write_interval": 250, # TensorBoard writing interval (timesteps)
+ "write_interval": "auto", # TensorBoard writing interval (timesteps)
- "checkpoint_interval": 1000, # interval for checkpoints (timesteps)
+ "checkpoint_interval": "auto", # interval for checkpoints (timesteps)
"store_separately": False, # whether to store checkpoints separately
"wandb": False, # whether to use Weights & Biases
@@ -104,6 +105,12 @@ def __init__(self,
self.checkpoint_modules["q_network"] = self.q_network
self.checkpoint_modules["target_q_network"] = self.target_q_network
+ # broadcast models' parameters in distributed runs
+ if config.torch.is_distributed:
+ logger.info(f"Broadcasting models' parameters")
+ if self.q_network is not None:
+ self.q_network.broadcast_parameters()
+
if self.target_q_network is not None:
# freeze target networks with respect to optimizers (update via .update_parameters())
self.target_q_network.freeze_parameters(True)
@@ -303,6 +310,8 @@ def _update(self, timestep: int, timesteps: int) -> None:
# optimize Q-network
self.optimizer.zero_grad()
q_network_loss.backward()
+ if config.torch.is_distributed:
+ self.q_network.reduce_parameters()
self.optimizer.step()
# update target network
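
Across the PyTorch agents the distributed pattern is identical: broadcast model parameters from rank 0 at construction time and average gradients across ranks right after backward(), before clipping and the optimizer step. broadcast_parameters and reduce_parameters are skrl model methods; the sketch below only illustrates, with plain torch.distributed collectives, what such helpers are assumed to reduce to, and it creates a single-process "gloo" group so it runs without a launcher:

import torch
import torch.distributed as dist
import torch.nn as nn

# single-process group so the sketch is runnable standalone; a real distributed
# run gets rank and world size from the launcher (e.g. torchrun) instead
dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29500",
                        rank=0, world_size=1)

model = nn.Linear(4, 2)

# broadcast: every rank starts from rank 0's weights
for p in model.parameters():
    dist.broadcast(p.data, src=0)

# reduce: after backward(), average the gradients across ranks before stepping
model(torch.randn(8, 4)).sum().backward()
for p in model.parameters():
    dist.all_reduce(p.grad, op=dist.ReduceOp.SUM)
    p.grad /= dist.get_world_size()

dist.destroy_process_group()
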
diff --git a/skrl/agents/torch/ppo/ppo.py b/skrl/agents/torch/ppo/ppo.py
index 8c2315bd..cb68ac2a 100644
--- a/skrl/agents/torch/ppo/ppo.py
+++ b/skrl/agents/torch/ppo/ppo.py
@@ -9,6 +9,7 @@
import torch.nn as nn
import torch.nn.functional as F
+from skrl import config, logger
from skrl.agents.torch import Agent
from skrl.memories.torch import Memory
from skrl.models.torch import Model
@@ -52,9 +53,9 @@
"experiment": {
"directory": "", # experiment's parent directory
"experiment_name": "", # experiment name
- "write_interval": 250, # TensorBoard writing interval (timesteps)
+ "write_interval": "auto", # TensorBoard writing interval (timesteps)
- "checkpoint_interval": 1000, # interval for checkpoints (timesteps)
+ "checkpoint_interval": "auto", # interval for checkpoints (timesteps)
"store_separately": False, # whether to store checkpoints separately
"wandb": False, # whether to use Weights & Biases
@@ -111,6 +112,14 @@ def __init__(self,
self.checkpoint_modules["policy"] = self.policy
self.checkpoint_modules["value"] = self.value
+ # broadcast models' parameters in distributed runs
+ if config.torch.is_distributed:
+ logger.info(f"Broadcasting models' parameters")
+ if self.policy is not None:
+ self.policy.broadcast_parameters()
+ if self.value is not None and self.policy is not self.value:
+ self.value.broadcast_parameters()
+
# configuration
self._learning_epochs = self.cfg["learning_epochs"]
self._mini_batches = self.cfg["mini_batches"]
@@ -418,6 +427,10 @@ def compute_gae(rewards: torch.Tensor,
# optimization step
self.optimizer.zero_grad()
(policy_loss + entropy_loss + value_loss).backward()
+ if config.torch.is_distributed:
+ self.policy.reduce_parameters()
+ if self.policy is not self.value:
+ self.value.reduce_parameters()
if self._grad_norm_clip > 0:
if self.policy is self.value:
nn.utils.clip_grad_norm_(self.policy.parameters(), self._grad_norm_clip)
@@ -434,7 +447,7 @@ def compute_gae(rewards: torch.Tensor,
# update learning rate
if self._learning_rate_scheduler:
if isinstance(self.scheduler, KLAdaptiveLR):
- self.scheduler.step(torch.tensor(kl_divergences).mean())
+ self.scheduler.step(torch.tensor(kl_divergences, device=self.device).mean())
else:
self.scheduler.step()
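
In the on-policy agents the policy and the value function may be a single shared module, so the distributed hunks reduce gradients once per distinct module: a second reduce_parameters() call on the same object would average the shared gradients twice. Runs are typically launched with torchrun --nproc_per_node=N train.py, which sets the rank and world-size environment variables that skrl's config.torch presumably reads. A tiny sketch of the identity guard alone:

import torch.nn as nn

def distinct_modules(policy, value):
    # yield each underlying module exactly once, mirroring the
    # "if self.policy is not self.value" guards in the hunks above
    yield policy
    if value is not policy:
        yield value

shared = nn.Linear(8, 1)                            # one network playing both roles
print(len(list(distinct_modules(shared, shared))))  # 1
print(len(list(distinct_modules(nn.Linear(8, 1), nn.Linear(8, 1)))))  # 2
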
diff --git a/skrl/agents/torch/ppo/ppo_rnn.py b/skrl/agents/torch/ppo/ppo_rnn.py
index ccabafca..1ec4e7ca 100644
--- a/skrl/agents/torch/ppo/ppo_rnn.py
+++ b/skrl/agents/torch/ppo/ppo_rnn.py
@@ -9,6 +9,7 @@
import torch.nn as nn
import torch.nn.functional as F
+from skrl import config, logger
from skrl.agents.torch import Agent
from skrl.memories.torch import Memory
from skrl.models.torch import Model
@@ -52,9 +53,9 @@
"experiment": {
"directory": "", # experiment's parent directory
"experiment_name": "", # experiment name
- "write_interval": 250, # TensorBoard writing interval (timesteps)
+ "write_interval": "auto", # TensorBoard writing interval (timesteps)
- "checkpoint_interval": 1000, # interval for checkpoints (timesteps)
+ "checkpoint_interval": "auto", # interval for checkpoints (timesteps)
"store_separately": False, # whether to store checkpoints separately
"wandb": False, # whether to use Weights & Biases
@@ -111,6 +112,14 @@ def __init__(self,
self.checkpoint_modules["policy"] = self.policy
self.checkpoint_modules["value"] = self.value
+ # broadcast models' parameters in distributed runs
+ if config.torch.is_distributed:
+ logger.info(f"Broadcasting models' parameters")
+ if self.policy is not None:
+ self.policy.broadcast_parameters()
+ if self.value is not None and self.policy is not self.value:
+ self.value.broadcast_parameters()
+
# configuration
self._learning_epochs = self.cfg["learning_epochs"]
self._mini_batches = self.cfg["mini_batches"]
@@ -490,6 +499,10 @@ def compute_gae(rewards: torch.Tensor,
# optimization step
self.optimizer.zero_grad()
(policy_loss + entropy_loss + value_loss).backward()
+ if config.torch.is_distributed:
+ self.policy.reduce_parameters()
+ if self.policy is not self.value:
+ self.value.reduce_parameters()
if self._grad_norm_clip > 0:
if self.policy is self.value:
nn.utils.clip_grad_norm_(self.policy.parameters(), self._grad_norm_clip)
@@ -506,7 +519,7 @@ def compute_gae(rewards: torch.Tensor,
# update learning rate
if self._learning_rate_scheduler:
if isinstance(self.scheduler, KLAdaptiveLR):
- self.scheduler.step(torch.tensor(kl_divergences).mean())
+ self.scheduler.step(torch.tensor(kl_divergences, device=self.device).mean())
else:
self.scheduler.step()
diff --git a/skrl/agents/torch/q_learning/q_learning.py b/skrl/agents/torch/q_learning/q_learning.py
index 16212d8f..90da91cc 100644
--- a/skrl/agents/torch/q_learning/q_learning.py
+++ b/skrl/agents/torch/q_learning/q_learning.py
@@ -25,9 +25,9 @@
"experiment": {
"directory": "", # experiment's parent directory
"experiment_name": "", # experiment name
- "write_interval": 250, # TensorBoard writing interval (timesteps)
+ "write_interval": "auto", # TensorBoard writing interval (timesteps)
- "checkpoint_interval": 1000, # interval for checkpoints (timesteps)
+ "checkpoint_interval": "auto", # interval for checkpoints (timesteps)
"store_separately": False, # whether to store checkpoints separately
"wandb": False, # whether to use Weights & Biases
diff --git a/skrl/agents/torch/rpo/rpo.py b/skrl/agents/torch/rpo/rpo.py
index 5929f54e..dfb6fa9f 100644
--- a/skrl/agents/torch/rpo/rpo.py
+++ b/skrl/agents/torch/rpo/rpo.py
@@ -9,6 +9,7 @@
import torch.nn as nn
import torch.nn.functional as F
+from skrl import config, logger
from skrl.agents.torch import Agent
from skrl.memories.torch import Memory
from skrl.models.torch import Model
@@ -53,9 +54,9 @@
"experiment": {
"directory": "", # experiment's parent directory
"experiment_name": "", # experiment name
- "write_interval": 250, # TensorBoard writing interval (timesteps)
+ "write_interval": "auto", # TensorBoard writing interval (timesteps)
- "checkpoint_interval": 1000, # interval for checkpoints (timesteps)
+ "checkpoint_interval": "auto", # interval for checkpoints (timesteps)
"store_separately": False, # whether to store checkpoints separately
"wandb": False, # whether to use Weights & Biases
@@ -112,6 +113,14 @@ def __init__(self,
self.checkpoint_modules["policy"] = self.policy
self.checkpoint_modules["value"] = self.value
+ # broadcast models' parameters in distributed runs
+ if config.torch.is_distributed:
+ logger.info(f"Broadcasting models' parameters")
+ if self.policy is not None:
+ self.policy.broadcast_parameters()
+ if self.value is not None and self.policy is not self.value:
+ self.value.broadcast_parameters()
+
# configuration
self._learning_epochs = self.cfg["learning_epochs"]
self._mini_batches = self.cfg["mini_batches"]
@@ -420,6 +429,10 @@ def compute_gae(rewards: torch.Tensor,
# optimization step
self.optimizer.zero_grad()
(policy_loss + entropy_loss + value_loss).backward()
+ if config.torch.is_distributed:
+ self.policy.reduce_parameters()
+ if self.policy is not self.value:
+ self.value.reduce_parameters()
if self._grad_norm_clip > 0:
if self.policy is self.value:
nn.utils.clip_grad_norm_(self.policy.parameters(), self._grad_norm_clip)
@@ -436,7 +449,7 @@ def compute_gae(rewards: torch.Tensor,
# update learning rate
if self._learning_rate_scheduler:
if isinstance(self.scheduler, KLAdaptiveLR):
- self.scheduler.step(torch.tensor(kl_divergences).mean())
+ self.scheduler.step(torch.tensor(kl_divergences, device=self.device).mean())
else:
self.scheduler.step()
diff --git a/skrl/agents/torch/rpo/rpo_rnn.py b/skrl/agents/torch/rpo/rpo_rnn.py
index 382d1efb..94c7c27a 100644
--- a/skrl/agents/torch/rpo/rpo_rnn.py
+++ b/skrl/agents/torch/rpo/rpo_rnn.py
@@ -9,6 +9,7 @@
import torch.nn as nn
import torch.nn.functional as F
+from skrl import config, logger
from skrl.agents.torch import Agent
from skrl.memories.torch import Memory
from skrl.models.torch import Model
@@ -53,9 +54,9 @@
"experiment": {
"directory": "", # experiment's parent directory
"experiment_name": "", # experiment name
- "write_interval": 250, # TensorBoard writing interval (timesteps)
+ "write_interval": "auto", # TensorBoard writing interval (timesteps)
- "checkpoint_interval": 1000, # interval for checkpoints (timesteps)
+ "checkpoint_interval": "auto", # interval for checkpoints (timesteps)
"store_separately": False, # whether to store checkpoints separately
"wandb": False, # whether to use Weights & Biases
@@ -112,6 +113,14 @@ def __init__(self,
self.checkpoint_modules["policy"] = self.policy
self.checkpoint_modules["value"] = self.value
+ # broadcast models' parameters in distributed runs
+ if config.torch.is_distributed:
+ logger.info(f"Broadcasting models' parameters")
+ if self.policy is not None:
+ self.policy.broadcast_parameters()
+ if self.value is not None and self.policy is not self.value:
+ self.value.broadcast_parameters()
+
# configuration
self._learning_epochs = self.cfg["learning_epochs"]
self._mini_batches = self.cfg["mini_batches"]
@@ -492,6 +501,10 @@ def compute_gae(rewards: torch.Tensor,
# optimization step
self.optimizer.zero_grad()
(policy_loss + entropy_loss + value_loss).backward()
+ if config.torch.is_distributed:
+ self.policy.reduce_parameters()
+ if self.policy is not self.value:
+ self.value.reduce_parameters()
if self._grad_norm_clip > 0:
if self.policy is self.value:
nn.utils.clip_grad_norm_(self.policy.parameters(), self._grad_norm_clip)
@@ -508,7 +521,7 @@ def compute_gae(rewards: torch.Tensor,
# update learning rate
if self._learning_rate_scheduler:
if isinstance(self.scheduler, KLAdaptiveLR):
- self.scheduler.step(torch.tensor(kl_divergences).mean())
+ self.scheduler.step(torch.tensor(kl_divergences, device=self.device).mean())
else:
self.scheduler.step()
diff --git a/skrl/agents/torch/sac/sac.py b/skrl/agents/torch/sac/sac.py
index 22468d80..dc8678a6 100644
--- a/skrl/agents/torch/sac/sac.py
+++ b/skrl/agents/torch/sac/sac.py
@@ -10,6 +10,7 @@
import torch.nn as nn
import torch.nn.functional as F
+from skrl import config, logger
from skrl.agents.torch import Agent
from skrl.memories.torch import Memory
from skrl.models.torch import Model
@@ -46,9 +47,9 @@
"experiment": {
"base_directory": "", # base directory for the experiment
"experiment_name": "", # experiment name
- "write_interval": 250, # TensorBoard writing interval (timesteps)
+ "write_interval": "auto", # TensorBoard writing interval (timesteps)
- "checkpoint_interval": 1000, # interval for checkpoints (timesteps)
+ "checkpoint_interval": "auto", # interval for checkpoints (timesteps)
"store_separately": False, # whether to store checkpoints separately
"wandb": False, # whether to use Weights & Biases
@@ -111,6 +112,16 @@ def __init__(self,
self.checkpoint_modules["target_critic_1"] = self.target_critic_1
self.checkpoint_modules["target_critic_2"] = self.target_critic_2
+ # broadcast models' parameters in distributed runs
+ if config.torch.is_distributed:
+ logger.info(f"Broadcasting models' parameters")
+ if self.policy is not None:
+ self.policy.broadcast_parameters()
+ if self.critic_1 is not None:
+ self.critic_1.broadcast_parameters()
+ if self.critic_2 is not None:
+ self.critic_2.broadcast_parameters()
+
if self.target_critic_1 is not None and self.target_critic_2 is not None:
# freeze target networks with respect to optimizers (update via .update_parameters())
self.target_critic_1.freeze_parameters(True)
@@ -325,6 +336,9 @@ def _update(self, timestep: int, timesteps: int) -> None:
# optimization step (critic)
self.critic_optimizer.zero_grad()
critic_loss.backward()
+ if config.torch.is_distributed:
+ self.critic_1.reduce_parameters()
+ self.critic_2.reduce_parameters()
if self._grad_norm_clip > 0:
nn.utils.clip_grad_norm_(itertools.chain(self.critic_1.parameters(), self.critic_2.parameters()), self._grad_norm_clip)
self.critic_optimizer.step()
@@ -339,6 +353,8 @@ def _update(self, timestep: int, timesteps: int) -> None:
# optimization step (policy)
self.policy_optimizer.zero_grad()
policy_loss.backward()
+ if config.torch.is_distributed:
+ self.policy.reduce_parameters()
if self._grad_norm_clip > 0:
nn.utils.clip_grad_norm_(self.policy.parameters(), self._grad_norm_clip)
self.policy_optimizer.step()
diff --git a/skrl/agents/torch/sac/sac_rnn.py b/skrl/agents/torch/sac/sac_rnn.py
index 755cbeab..6553d7aa 100644
--- a/skrl/agents/torch/sac/sac_rnn.py
+++ b/skrl/agents/torch/sac/sac_rnn.py
@@ -10,6 +10,7 @@
import torch.nn as nn
import torch.nn.functional as F
+from skrl import config, logger
from skrl.agents.torch import Agent
from skrl.memories.torch import Memory
from skrl.models.torch import Model
@@ -46,9 +47,9 @@
"experiment": {
"base_directory": "", # base directory for the experiment
"experiment_name": "", # experiment name
- "write_interval": 250, # TensorBoard writing interval (timesteps)
+ "write_interval": "auto", # TensorBoard writing interval (timesteps)
- "checkpoint_interval": 1000, # interval for checkpoints (timesteps)
+ "checkpoint_interval": "auto", # interval for checkpoints (timesteps)
"store_separately": False, # whether to store checkpoints separately
"wandb": False, # whether to use Weights & Biases
@@ -111,6 +112,16 @@ def __init__(self,
self.checkpoint_modules["target_critic_1"] = self.target_critic_1
self.checkpoint_modules["target_critic_2"] = self.target_critic_2
+ # broadcast models' parameters in distributed runs
+ if config.torch.is_distributed:
+ logger.info(f"Broadcasting models' parameters")
+ if self.policy is not None:
+ self.policy.broadcast_parameters()
+ if self.critic_1 is not None:
+ self.critic_1.broadcast_parameters()
+ if self.critic_2 is not None:
+ self.critic_2.broadcast_parameters()
+
if self.target_critic_1 is not None and self.target_critic_2 is not None:
# freeze target networks with respect to optimizers (update via .update_parameters())
self.target_critic_1.freeze_parameters(True)
@@ -367,6 +378,9 @@ def _update(self, timestep: int, timesteps: int) -> None:
# optimization step (critic)
self.critic_optimizer.zero_grad()
critic_loss.backward()
+ if config.torch.is_distributed:
+ self.critic_1.reduce_parameters()
+ self.critic_2.reduce_parameters()
if self._grad_norm_clip > 0:
nn.utils.clip_grad_norm_(itertools.chain(self.critic_1.parameters(), self.critic_2.parameters()), self._grad_norm_clip)
self.critic_optimizer.step()
@@ -381,6 +395,8 @@ def _update(self, timestep: int, timesteps: int) -> None:
# optimization step (policy)
self.policy_optimizer.zero_grad()
policy_loss.backward()
+ if config.torch.is_distributed:
+ self.policy.reduce_parameters()
if self._grad_norm_clip > 0:
nn.utils.clip_grad_norm_(self.policy.parameters(), self._grad_norm_clip)
self.policy_optimizer.step()
diff --git a/skrl/agents/torch/sarsa/sarsa.py b/skrl/agents/torch/sarsa/sarsa.py
index 4cc14c34..5abc260a 100644
--- a/skrl/agents/torch/sarsa/sarsa.py
+++ b/skrl/agents/torch/sarsa/sarsa.py
@@ -25,9 +25,9 @@
"experiment": {
"directory": "", # experiment's parent directory
"experiment_name": "", # experiment name
- "write_interval": 250, # TensorBoard writing interval (timesteps)
+ "write_interval": "auto", # TensorBoard writing interval (timesteps)
- "checkpoint_interval": 1000, # interval for checkpoints (timesteps)
+ "checkpoint_interval": "auto", # interval for checkpoints (timesteps)
"store_separately": False, # whether to store checkpoints separately
"wandb": False, # whether to use Weights & Biases
diff --git a/skrl/agents/torch/td3/td3.py b/skrl/agents/torch/td3/td3.py
index 86275243..abbcc467 100644
--- a/skrl/agents/torch/td3/td3.py
+++ b/skrl/agents/torch/td3/td3.py
@@ -9,7 +9,7 @@
import torch.nn as nn
import torch.nn.functional as F
-from skrl import logger
+from skrl import config, logger
from skrl.agents.torch import Agent
from skrl.memories.torch import Memory
from skrl.models.torch import Model
@@ -52,9 +52,9 @@
"experiment": {
"directory": "", # experiment's parent directory
"experiment_name": "", # experiment name
- "write_interval": 250, # TensorBoard writing interval (timesteps)
+ "write_interval": "auto", # TensorBoard writing interval (timesteps)
- "checkpoint_interval": 1000, # interval for checkpoints (timesteps)
+ "checkpoint_interval": "auto", # interval for checkpoints (timesteps)
"store_separately": False, # whether to store checkpoints separately
"wandb": False, # whether to use Weights & Biases
@@ -119,6 +119,16 @@ def __init__(self,
self.checkpoint_modules["target_critic_1"] = self.target_critic_1
self.checkpoint_modules["target_critic_2"] = self.target_critic_2
+ # broadcast models' parameters in distributed runs
+ if config.torch.is_distributed:
+ logger.info(f"Broadcasting models' parameters")
+ if self.policy is not None:
+ self.policy.broadcast_parameters()
+ if self.critic_1 is not None:
+ self.critic_1.broadcast_parameters()
+ if self.critic_2 is not None:
+ self.critic_2.broadcast_parameters()
+
if self.target_policy is not None and self.target_critic_1 is not None and self.target_critic_2 is not None:
# freeze target networks with respect to optimizers (update via .update_parameters())
self.target_policy.freeze_parameters(True)
@@ -336,13 +346,14 @@ def _update(self, timestep: int, timesteps: int) -> None:
:param timesteps: Number of timesteps
:type timesteps: int
"""
- # sample a batch from memory
- sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
- self.memory.sample(names=self._tensors_names, batch_size=self._batch_size)[0]
# gradient steps
for gradient_step in range(self._gradient_steps):
+ # sample a batch from memory
+ sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
+ self.memory.sample(names=self._tensors_names, batch_size=self._batch_size)[0]
+
sampled_states = self._state_preprocessor(sampled_states, train=True)
sampled_next_states = self._state_preprocessor(sampled_next_states, train=True)
@@ -371,6 +382,9 @@ def _update(self, timestep: int, timesteps: int) -> None:
# optimization step (critic)
self.critic_optimizer.zero_grad()
critic_loss.backward()
+ if config.torch.is_distributed:
+ self.critic_1.reduce_parameters()
+ self.critic_2.reduce_parameters()
if self._grad_norm_clip > 0:
nn.utils.clip_grad_norm_(itertools.chain(self.critic_1.parameters(), self.critic_2.parameters()), self._grad_norm_clip)
self.critic_optimizer.step()
@@ -388,6 +402,8 @@ def _update(self, timestep: int, timesteps: int) -> None:
# optimization step (policy)
self.policy_optimizer.zero_grad()
policy_loss.backward()
+ if config.torch.is_distributed:
+ self.policy.reduce_parameters()
if self._grad_norm_clip > 0:
nn.utils.clip_grad_norm_(self.policy.parameters(), self._grad_norm_clip)
self.policy_optimizer.step()
diff --git a/skrl/agents/torch/td3/td3_rnn.py b/skrl/agents/torch/td3/td3_rnn.py
index fdd619d8..abd6a922 100644
--- a/skrl/agents/torch/td3/td3_rnn.py
+++ b/skrl/agents/torch/td3/td3_rnn.py
@@ -9,7 +9,7 @@
import torch.nn as nn
import torch.nn.functional as F
-from skrl import logger
+from skrl import config, logger
from skrl.agents.torch import Agent
from skrl.memories.torch import Memory
from skrl.models.torch import Model
@@ -52,9 +52,9 @@
"experiment": {
"directory": "", # experiment's parent directory
"experiment_name": "", # experiment name
- "write_interval": 250, # TensorBoard writing interval (timesteps)
+ "write_interval": "auto", # TensorBoard writing interval (timesteps)
- "checkpoint_interval": 1000, # interval for checkpoints (timesteps)
+ "checkpoint_interval": "auto", # interval for checkpoints (timesteps)
"store_separately": False, # whether to store checkpoints separately
"wandb": False, # whether to use Weights & Biases
@@ -119,6 +119,16 @@ def __init__(self,
self.checkpoint_modules["target_critic_1"] = self.target_critic_1
self.checkpoint_modules["target_critic_2"] = self.target_critic_2
+ # broadcast models' parameters in distributed runs
+ if config.torch.is_distributed:
+ logger.info(f"Broadcasting models' parameters")
+ if self.policy is not None:
+ self.policy.broadcast_parameters()
+ if self.critic_1 is not None:
+ self.critic_1.broadcast_parameters()
+ if self.critic_2 is not None:
+ self.critic_2.broadcast_parameters()
+
if self.target_policy is not None and self.target_critic_1 is not None and self.target_critic_2 is not None:
# freeze target networks with respect to optimizers (update via .update_parameters())
self.target_policy.freeze_parameters(True)
@@ -413,6 +423,9 @@ def _update(self, timestep: int, timesteps: int) -> None:
# optimization step (critic)
self.critic_optimizer.zero_grad()
critic_loss.backward()
+ if config.torch.is_distributed:
+ self.critic_1.reduce_parameters()
+ self.critic_2.reduce_parameters()
if self._grad_norm_clip > 0:
nn.utils.clip_grad_norm_(itertools.chain(self.critic_1.parameters(), self.critic_2.parameters()), self._grad_norm_clip)
self.critic_optimizer.step()
@@ -430,6 +443,8 @@ def _update(self, timestep: int, timesteps: int) -> None:
# optimization step (policy)
self.policy_optimizer.zero_grad()
policy_loss.backward()
+ if config.torch.is_distributed:
+ self.policy.reduce_parameters()
if self._grad_norm_clip > 0:
nn.utils.clip_grad_norm_(self.policy.parameters(), self._grad_norm_clip)
self.policy_optimizer.step()
diff --git a/skrl/agents/torch/trpo/trpo.py b/skrl/agents/torch/trpo/trpo.py
index 2c00b69b..96565bef 100644
--- a/skrl/agents/torch/trpo/trpo.py
+++ b/skrl/agents/torch/trpo/trpo.py
@@ -9,6 +9,7 @@
import torch.nn.functional as F
from torch.nn.utils.convert_parameters import parameters_to_vector, vector_to_parameters
+from skrl import config, logger
from skrl.agents.torch import Agent
from skrl.memories.torch import Memory
from skrl.models.torch import Model
@@ -51,9 +52,9 @@
"experiment": {
"directory": "", # experiment's parent directory
"experiment_name": "", # experiment name
- "write_interval": 250, # TensorBoard writing interval (timesteps)
+ "write_interval": "auto", # TensorBoard writing interval (timesteps)
- "checkpoint_interval": 1000, # interval for checkpoints (timesteps)
+ "checkpoint_interval": "auto", # interval for checkpoints (timesteps)
"store_separately": False, # whether to store checkpoints separately
"wandb": False, # whether to use Weights & Biases
@@ -112,6 +113,14 @@ def __init__(self,
self.checkpoint_modules["policy"] = self.policy
self.checkpoint_modules["value"] = self.value
+ # broadcast models' parameters in distributed runs
+ if config.torch.is_distributed:
+ logger.info(f"Broadcasting models' parameters")
+ if self.policy is not None:
+ self.policy.broadcast_parameters()
+ if self.value is not None:
+ self.value.broadcast_parameters()
+
# configuration
self._learning_epochs = self.cfg["learning_epochs"]
self._mini_batches = self.cfg["mini_batches"]
@@ -521,6 +530,9 @@ def kl_divergence(policy_1: Model, policy_2: Model, states: torch.Tensor) -> tor
if restore_policy_flag:
self.policy.update_parameters(self.backup_policy)
+ if config.torch.is_distributed:
+ self.policy.reduce_parameters()
+
# sample mini-batches from memory
sampled_batches = self.memory.sample_all(names=self._tensors_names_value, mini_batches=self._mini_batches)
@@ -542,6 +554,8 @@ def kl_divergence(policy_1: Model, policy_2: Model, states: torch.Tensor) -> tor
# optimization step (value)
self.value_optimizer.zero_grad()
value_loss.backward()
+ if config.torch.is_distributed:
+ self.value.reduce_parameters()
if self._grad_norm_clip > 0:
nn.utils.clip_grad_norm_(self.value.parameters(), self._grad_norm_clip)
self.value_optimizer.step()
diff --git a/skrl/agents/torch/trpo/trpo_rnn.py b/skrl/agents/torch/trpo/trpo_rnn.py
index 58b187e5..4b3ad05a 100644
--- a/skrl/agents/torch/trpo/trpo_rnn.py
+++ b/skrl/agents/torch/trpo/trpo_rnn.py
@@ -9,6 +9,7 @@
import torch.nn.functional as F
from torch.nn.utils.convert_parameters import parameters_to_vector, vector_to_parameters
+from skrl import config, logger
from skrl.agents.torch import Agent
from skrl.memories.torch import Memory
from skrl.models.torch import Model
@@ -51,9 +52,9 @@
"experiment": {
"directory": "", # experiment's parent directory
"experiment_name": "", # experiment name
- "write_interval": 250, # TensorBoard writing interval (timesteps)
+ "write_interval": "auto", # TensorBoard writing interval (timesteps)
- "checkpoint_interval": 1000, # interval for checkpoints (timesteps)
+ "checkpoint_interval": "auto", # interval for checkpoints (timesteps)
"store_separately": False, # whether to store checkpoints separately
"wandb": False, # whether to use Weights & Biases
@@ -112,6 +113,14 @@ def __init__(self,
self.checkpoint_modules["policy"] = self.policy
self.checkpoint_modules["value"] = self.value
+ # broadcast models' parameters in distributed runs
+ if config.torch.is_distributed:
+ logger.info(f"Broadcasting models' parameters")
+ if self.policy is not None:
+ self.policy.broadcast_parameters()
+ if self.value is not None:
+ self.value.broadcast_parameters()
+
# configuration
self._learning_epochs = self.cfg["learning_epochs"]
self._mini_batches = self.cfg["mini_batches"]
@@ -591,6 +600,9 @@ def kl_divergence(policy_1: Model, policy_2: Model, states: torch.Tensor) -> tor
if restore_policy_flag:
self.policy.update_parameters(self.backup_policy)
+ if config.torch.is_distributed:
+ self.policy.reduce_parameters()
+
# sample mini-batches from memory
sampled_batches = self.memory.sample_all(names=self._tensors_names_value, mini_batches=self._mini_batches, sequence_length=self._rnn_sequence_length)
@@ -622,6 +634,8 @@ def kl_divergence(policy_1: Model, policy_2: Model, states: torch.Tensor) -> tor
# optimization step (value)
self.value_optimizer.zero_grad()
value_loss.backward()
+ if config.torch.is_distributed:
+ self.value.reduce_parameters()
if self._grad_norm_clip > 0:
nn.utils.clip_grad_norm_(self.value.parameters(), self._grad_norm_clip)
self.value_optimizer.step()
diff --git a/skrl/envs/jax.py b/skrl/envs/jax.py
index 540fbdff..2f87c756 100644
--- a/skrl/envs/jax.py
+++ b/skrl/envs/jax.py
@@ -8,10 +8,10 @@
from skrl.envs.loaders.jax import (
load_bidexhands_env,
- load_isaac_orbit_env,
load_isaacgym_env_preview2,
load_isaacgym_env_preview3,
load_isaacgym_env_preview4,
+ load_isaaclab_env,
load_omniverse_isaacgym_env
)
from skrl.envs.wrappers.jax import MultiAgentEnvWrapper, Wrapper, wrap_env
diff --git a/skrl/envs/loaders/jax/__init__.py b/skrl/envs/loaders/jax/__init__.py
index 43690f93..a200ec9f 100644
--- a/skrl/envs/loaders/jax/__init__.py
+++ b/skrl/envs/loaders/jax/__init__.py
@@ -1,8 +1,8 @@
from skrl.envs.loaders.jax.bidexhands_envs import load_bidexhands_env
-from skrl.envs.loaders.jax.isaac_orbit_envs import load_isaac_orbit_env
from skrl.envs.loaders.jax.isaacgym_envs import (
load_isaacgym_env_preview2,
load_isaacgym_env_preview3,
load_isaacgym_env_preview4
)
+from skrl.envs.loaders.jax.isaaclab_envs import load_isaaclab_env
from skrl.envs.loaders.jax.omniverse_isaacgym_envs import load_omniverse_isaacgym_env
diff --git a/skrl/envs/loaders/jax/isaac_orbit_envs.py b/skrl/envs/loaders/jax/isaac_orbit_envs.py
deleted file mode 100644
index 9a7f59b2..00000000
--- a/skrl/envs/loaders/jax/isaac_orbit_envs.py
+++ /dev/null
@@ -1,3 +0,0 @@
-# since Isaac Orbit environments are implemented on top of PyTorch, the loader is the same
-
-from skrl.envs.loaders.torch import load_isaac_orbit_env
diff --git a/skrl/envs/loaders/jax/isaaclab_envs.py b/skrl/envs/loaders/jax/isaaclab_envs.py
new file mode 100644
index 00000000..33b6c9a4
--- /dev/null
+++ b/skrl/envs/loaders/jax/isaaclab_envs.py
@@ -0,0 +1,3 @@
+# since Isaac Lab environments are implemented on top of PyTorch, the loader is the same
+
+from skrl.envs.loaders.torch import load_isaaclab_env
diff --git a/skrl/envs/loaders/torch/__init__.py b/skrl/envs/loaders/torch/__init__.py
index 417c9958..a7a97e9a 100644
--- a/skrl/envs/loaders/torch/__init__.py
+++ b/skrl/envs/loaders/torch/__init__.py
@@ -1,8 +1,8 @@
from skrl.envs.loaders.torch.bidexhands_envs import load_bidexhands_env
-from skrl.envs.loaders.torch.isaac_orbit_envs import load_isaac_orbit_env
from skrl.envs.loaders.torch.isaacgym_envs import (
load_isaacgym_env_preview2,
load_isaacgym_env_preview3,
load_isaacgym_env_preview4
)
+from skrl.envs.loaders.torch.isaaclab_envs import load_isaaclab_env
from skrl.envs.loaders.torch.omniverse_isaacgym_envs import load_omniverse_isaacgym_env
diff --git a/skrl/envs/loaders/torch/isaac_orbit_envs.py b/skrl/envs/loaders/torch/isaaclab_envs.py
similarity index 73%
rename from skrl/envs/loaders/torch/isaac_orbit_envs.py
rename to skrl/envs/loaders/torch/isaaclab_envs.py
index 0f86b12f..c4516595 100644
--- a/skrl/envs/loaders/torch/isaac_orbit_envs.py
+++ b/skrl/envs/loaders/torch/isaaclab_envs.py
@@ -6,7 +6,7 @@
from skrl import logger
-__all__ = ["load_isaac_orbit_env"]
+__all__ = ["load_isaaclab_env"]
def _print_cfg(d, indent=0) -> None:
@@ -24,16 +24,16 @@ def _print_cfg(d, indent=0) -> None:
print(" | " * indent + f" |-- {key}: {value}")
-def load_isaac_orbit_env(task_name: str = "",
- num_envs: Optional[int] = None,
- headless: Optional[bool] = None,
- cli_args: Sequence[str] = [],
- show_cfg: bool = True):
- """Load an Isaac Orbit environment
+def load_isaaclab_env(task_name: str = "",
+ num_envs: Optional[int] = None,
+ headless: Optional[bool] = None,
+ cli_args: Sequence[str] = [],
+ show_cfg: bool = True):
+ """Load an Isaac Lab environment
- Isaac Orbit: https://isaac-orbit.github.io/orbit/index.html
+ Isaac Lab: https://isaac-sim.github.io/IsaacLab
- This function includes the definition and parsing of command line arguments used by Isaac Orbit:
+ This function includes the definition and parsing of command line arguments used by Isaac Lab:
- ``--headless``: Force display off at all times
- ``--cpu``: Use CPU pipeline
@@ -53,19 +53,19 @@ def load_isaac_orbit_env(task_name: str = "",
If not specified, the default task configuration is used.
Command line argument has priority over function parameter if both are specified
:type headless: bool, optional
- :param cli_args: Isaac Orbit configuration and command line arguments (default: ``[]``)
+ :param cli_args: Isaac Lab configuration and command line arguments (default: ``[]``)
:type cli_args: list of str, optional
:param show_cfg: Whether to print the configuration (default: ``True``)
:type show_cfg: bool, optional
:raises ValueError: The task name has not been defined, neither by the function parameter nor by the command line arguments
- :return: Isaac Orbit environment
+ :return: Isaac Lab environment
:rtype: gym.Env
"""
import argparse
import atexit
- import gym
+ import gymnasium as gym
# check task from command line arguments
defined = False
@@ -121,46 +121,42 @@ def load_isaac_orbit_env(task_name: str = "",
sys.argv += cli_args
# parse arguments
- parser = argparse.ArgumentParser("Welcome to Orbit: Omniverse Robotics Environments!")
- parser.add_argument("--headless", action="store_true", default=False, help="Force display off at all times.")
+ parser = argparse.ArgumentParser("Isaac Lab: Omniverse Robotics Environments!")
parser.add_argument("--cpu", action="store_true", default=False, help="Use CPU pipeline.")
parser.add_argument("--num_envs", type=int, default=None, help="Number of environments to simulate.")
parser.add_argument("--task", type=str, default=None, help="Name of the task.")
parser.add_argument("--seed", type=int, default=None, help="Seed used for the environment")
- args = parser.parse_args()
-
- # load the most efficient kit configuration in headless mode
- if args.headless:
- app_experience = f"{os.environ['EXP_PATH']}/omni.isaac.sim.python.gym.headless.kit"
- else:
- app_experience = f"{os.environ['EXP_PATH']}/omni.isaac.sim.python.kit"
+ parser.add_argument("--video", action="store_true", default=False, help="Record videos during training.")
+ parser.add_argument("--video_length", type=int, default=200, help="Length of the recorded video (in steps).")
+ parser.add_argument("--video_interval", type=int, default=2000, help="Interval between video recordings (in steps).")
+ parser.add_argument("--disable_fabric", action="store_true", default=False, help="Disable fabric and use USD I/O operations.")
+ parser.add_argument("--distributed", action="store_true", default=False, help="Run training with multiple GPUs or nodes.")
- # launch the simulator
- from omni.isaac.kit import SimulationApp # type: ignore
+ # launch the simulation app
+ from omni.isaac.lab.app import AppLauncher
- config = {"headless": args.headless}
- simulation_app = SimulationApp(config, experience=app_experience)
+ AppLauncher.add_app_launcher_args(parser)
+ args = parser.parse_args()
+ app_launcher = AppLauncher(args)
@atexit.register
def close_the_simulator():
- simulation_app.close()
+ app_launcher.app.close()
- # import orbit extensions
- import omni.isaac.contrib_envs # type: ignore
- import omni.isaac.orbit_envs # type: ignore
- from omni.isaac.orbit_envs.utils import parse_env_cfg # type: ignore
+ import omni.isaac.lab_tasks # type: ignore
+ from omni.isaac.lab_tasks.utils import parse_env_cfg # type: ignore
- cfg = parse_env_cfg(args.task, use_gpu=not args.cpu, num_envs=args.num_envs)
+ cfg = parse_env_cfg(args.task, use_gpu=not args.cpu, num_envs=args.num_envs, use_fabric=not args.disable_fabric)
# print config
if show_cfg:
- print(f"\nIsaac Orbit environment ({args.task})")
+ print(f"\nIsaac Lab environment ({args.task})")
try:
_print_cfg(cfg)
except AttributeError as e:
pass
# load environment
- env = gym.make(args.task, cfg=cfg, headless=args.headless)
+ env = gym.make(args.task, cfg=cfg, render_mode="rgb_array" if args.video else None)
return env
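
A typical usage of the renamed loader (a hedged sketch: the task id `"Isaac-Cartpole-v0"` is an assumed example and an Isaac Lab installation is required; `wrap_env` is updated in the wrappers' changes further down):

```python
from skrl.envs.loaders.torch import load_isaaclab_env
from skrl.envs.wrappers.torch import wrap_env

# load and wrap an Isaac Lab environment (the task id is an example; any registered task works)
env = load_isaaclab_env(task_name="Isaac-Cartpole-v0", num_envs=64)
env = wrap_env(env)  # "auto" detection now resolves to the IsaacLabWrapper
```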
diff --git a/skrl/envs/torch.py b/skrl/envs/torch.py
index b820c297..d5922a7f 100644
--- a/skrl/envs/torch.py
+++ b/skrl/envs/torch.py
@@ -8,10 +8,10 @@
from skrl.envs.loaders.torch import (
load_bidexhands_env,
- load_isaac_orbit_env,
load_isaacgym_env_preview2,
load_isaacgym_env_preview3,
load_isaacgym_env_preview4,
+ load_isaaclab_env,
load_omniverse_isaacgym_env
)
from skrl.envs.wrappers.torch import MultiAgentEnvWrapper, Wrapper, wrap_env
diff --git a/skrl/envs/wrappers/jax/__init__.py b/skrl/envs/wrappers/jax/__init__.py
index d80735b7..fd0fdf6a 100644
--- a/skrl/envs/wrappers/jax/__init__.py
+++ b/skrl/envs/wrappers/jax/__init__.py
@@ -8,8 +8,8 @@
from skrl.envs.wrappers.jax.bidexhands_envs import BiDexHandsWrapper
from skrl.envs.wrappers.jax.gym_envs import GymWrapper
from skrl.envs.wrappers.jax.gymnasium_envs import GymnasiumWrapper
-from skrl.envs.wrappers.jax.isaac_orbit_envs import IsaacOrbitWrapper
from skrl.envs.wrappers.jax.isaacgym_envs import IsaacGymPreview2Wrapper, IsaacGymPreview3Wrapper
+from skrl.envs.wrappers.jax.isaaclab_envs import IsaacLabWrapper
from skrl.envs.wrappers.jax.omniverse_isaacgym_envs import OmniverseIsaacGymWrapper
from skrl.envs.wrappers.jax.pettingzoo_envs import PettingZooWrapper
@@ -52,7 +52,7 @@ def wrap_env(env: Any, wrapper: str = "auto", verbose: bool = True) -> Union[Wra
+--------------------+-------------------------+
|Omniverse Isaac Gym |``"omniverse-isaacgym"`` |
+--------------------+-------------------------+
- |Isaac Sim (orbit) |``"isaac-orbit"`` |
+ |Isaac Lab |``"isaaclab"`` |
+--------------------+-------------------------+
:type wrapper: str, optional
:param verbose: Whether to print the wrapper type (default: ``True``)
@@ -63,50 +63,44 @@ def wrap_env(env: Any, wrapper: str = "auto", verbose: bool = True) -> Union[Wra
:return: Wrapped environment
:rtype: Wrapper or MultiAgentEnvWrapper
"""
- if verbose:
- logger.info("Environment class: {}".format(", ".join([str(base).replace("", "") \
- for base in env.__class__.__bases__])))
- if wrapper == "auto":
- base_classes = [str(base) for base in env.__class__.__bases__]
- if "" in base_classes or \
- "" in base_classes:
- if verbose:
- logger.info("Environment wrapper: Omniverse Isaac Gym")
- return OmniverseIsaacGymWrapper(env)
- elif isinstance(env, gym.core.Env) or isinstance(env, gym.core.Wrapper):
- # isaac-orbit
- if hasattr(env, "sim") and hasattr(env, "env_ns"):
- if verbose:
- logger.info("Environment wrapper: Isaac Orbit")
- return IsaacOrbitWrapper(env)
- # gym
- if verbose:
- logger.info("Environment wrapper: Gym")
- return GymWrapper(env)
- elif isinstance(env, gymnasium.core.Env) or isinstance(env, gymnasium.core.Wrapper):
- if verbose:
- logger.info("Environment wrapper: Gymnasium")
- return GymnasiumWrapper(env)
- elif "" in base_classes:
- if verbose:
- logger.info("Environment wrapper: DeepMind")
- return DeepMindWrapper(env)
- elif "" in base_classes:
- if verbose:
- logger.info("Environment wrapper: Isaac Gym (preview 2)")
- return IsaacGymPreview2Wrapper(env)
+ def _get_wrapper_name(env, verbose):
+ def _in(value, container):
+ for item in container:
+ if value in item:
+ return True
+ return False
+
+ base_classes = [str(base).replace("<class '", "").replace("'>", "") for base in env.__class__.__bases__]
+ try:
+ base_classes += [str(base).replace("<class '", "").replace("'>", "") for base in env.unwrapped.__class__.__bases__]
+ except:
+ pass
+ base_classes = sorted(list(set(base_classes)))
if verbose:
- logger.info("Environment wrapper: Isaac Gym (preview 3/4)")
- return IsaacGymPreview3Wrapper(env) # preview 4 is the same as 3
- elif wrapper == "gym":
+ logger.info(f"Environment wrapper: 'auto' (class: {', '.join(base_classes)})")
+
+ if _in("omni.isaac.lab.envs.manager_based_env.ManagerBasedEnv", base_classes) or _in("omni.isaac.lab.envs.direct_rl_env.DirectRLEnv", base_classes):
+ return "isaaclab"
+ elif _in("omni.isaac.gym.vec_env.vec_env_base.VecEnvBase", base_classes) or _in("omni.isaac.gym.vec_env.vec_env_mt.VecEnvMT", base_classes):
+ return "omniverse-isaacgym"
+ elif _in("rlgpu.tasks.base.vec_task.VecTask", base_classes):
+ return "isaacgym-preview2"
+ elif _in("robosuite.environments.", base_classes):
+ return "robosuite"
+ elif _in("dm_env._environment.Environment.", base_classes):
+ return "dm"
+ elif _in("pettingzoo.utils.env", base_classes) or _in("pettingzoo.utils.wrappers", base_classes):
+ return "pettingzoo"
+ elif _in("gymnasium.core.Env", base_classes) or _in("gymnasium.core.Wrapper", base_classes):
+ return "gymnasium"
+ elif _in("gym.core.Env", base_classes) or _in("gym.core.Wrapper", base_classes):
+ return "gym"
+ return base_classes
+
+ if wrapper == "auto":
+ wrapper = _get_wrapper_name(env, verbose)
+
+ if wrapper == "gym":
if verbose:
logger.info("Environment wrapper: Gym")
return GymWrapper(env)
@@ -146,9 +140,9 @@ def wrap_env(env: Any, wrapper: str = "auto", verbose: bool = True) -> Union[Wra
if verbose:
logger.info("Environment wrapper: Omniverse Isaac Gym")
return OmniverseIsaacGymWrapper(env)
- elif wrapper == "isaac-orbit":
+ elif wrapper == "isaaclab" or wrapper == "isaac-orbit":
if verbose:
- logger.info("Environment wrapper: Isaac Orbit")
- return IsaacOrbitWrapper(env)
+ logger.info("Environment wrapper: Isaac Lab")
+ return IsaacLabWrapper(env)
else:
raise ValueError(f"Unknown wrapper type: {wrapper}")
diff --git a/skrl/envs/wrappers/jax/base.py b/skrl/envs/wrappers/jax/base.py
index 0be8b742..9e8bc467 100644
--- a/skrl/envs/wrappers/jax/base.py
+++ b/skrl/envs/wrappers/jax/base.py
@@ -20,12 +20,19 @@ def __init__(self, env: Any) -> None:
self._env = env
# device (faster than @property)
- self.device = jax.devices()[0]
+ self.device = None
if hasattr(self._env, "device"):
- try:
- self.device = jax.devices(self._env.device.split(':')[0] if type(self._env.device) == str else self._env.device.type)[0]
- except RuntimeError:
- pass
+ if type(self._env.device) == str:
+ device_type, device_index = f"{self._env.device}:0".split(':')[:2]
+ try:
+ self.device = jax.devices(device_type)[int(device_index)]
+ except (RuntimeError, IndexError):
+ self.device = None
+ else:
+ self.device = self._env.device
+ if self.device is None:
+ self.device = jax.devices()[0]
+
# spaces
try:
self._action_space = self._env.single_action_space
@@ -135,12 +142,18 @@ def __init__(self, env: Any) -> None:
self._env = env
# device (faster than @property)
- self.device = jax.devices()[0]
+ self.device = None
if hasattr(self._env, "device"):
- try:
- self.device = jax.devices(self._env.device.split(':')[0] if type(self._env.device) == str else self._env.device.type)[0]
- except RuntimeError:
- pass
+ if type(self._env.device) == str:
+ device_type, device_index = f"{self._env.device}:0".split(':')[:2]
+ try:
+ self.device = jax.devices(device_type)[int(device_index)]
+ except (RuntimeError, IndexError):
+ self.device = None
+ else:
+ self.device = self._env.device
+ if self.device is None:
+ self.device = jax.devices()[0]
self.possible_agents = []
diff --git a/skrl/envs/wrappers/jax/isaac_orbit_envs.py b/skrl/envs/wrappers/jax/isaaclab_envs.py
similarity index 89%
rename from skrl/envs/wrappers/jax/isaac_orbit_envs.py
rename to skrl/envs/wrappers/jax/isaaclab_envs.py
index c1e897c9..0d5f13bc 100644
--- a/skrl/envs/wrappers/jax/isaac_orbit_envs.py
+++ b/skrl/envs/wrappers/jax/isaaclab_envs.py
@@ -19,7 +19,7 @@
# jaxlib.xla_extension.XlaRuntimeError: INVALID_ARGUMENT: DLPack tensor is on GPU, but no GPU backend was provided.
_CPU = jax.devices()[0].device_kind.lower() == "cpu"
if _CPU:
- logger.warning("Isaac Orbit runs on GPU, but there is no GPU backend for JAX. JAX operations will run on CPU.")
+ logger.warning("Isaac Lab runs on GPU, but there is no GPU backend for JAX. JAX operations will run on CPU.")
def _jax2torch(array, device, from_jax=True):
if from_jax:
@@ -32,18 +32,20 @@ def _torch2jax(tensor, to_jax=True):
return tensor.cpu().numpy()
-class IsaacOrbitWrapper(Wrapper):
+class IsaacLabWrapper(Wrapper):
def __init__(self, env: Any) -> None:
- """Isaac Orbit environment wrapper
+ """Isaac Lab environment wrapper
:param env: The environment to wrap
- :type env: Any supported Isaac Orbit environment
+ :type env: Any supported Isaac Lab environment
"""
super().__init__(env)
self._reset_once = True
self._obs_dict = None
+ self._observation_space = self._observation_space["policy"]
+
def step(self, actions: Union[np.ndarray, jax.Array]) -> \
Tuple[Union[np.ndarray, jax.Array], Union[np.ndarray, jax.Array],
Union[np.ndarray, jax.Array], Union[np.ndarray, jax.Array], Any]:
diff --git a/skrl/envs/wrappers/torch/__init__.py b/skrl/envs/wrappers/torch/__init__.py
index ab2d8c3d..1fc9898b 100644
--- a/skrl/envs/wrappers/torch/__init__.py
+++ b/skrl/envs/wrappers/torch/__init__.py
@@ -9,8 +9,8 @@
from skrl.envs.wrappers.torch.deepmind_envs import DeepMindWrapper
from skrl.envs.wrappers.torch.gym_envs import GymWrapper
from skrl.envs.wrappers.torch.gymnasium_envs import GymnasiumWrapper
-from skrl.envs.wrappers.torch.isaac_orbit_envs import IsaacOrbitWrapper
from skrl.envs.wrappers.torch.isaacgym_envs import IsaacGymPreview2Wrapper, IsaacGymPreview3Wrapper
+from skrl.envs.wrappers.torch.isaaclab_envs import IsaacLabWrapper
from skrl.envs.wrappers.torch.omniverse_isaacgym_envs import OmniverseIsaacGymWrapper
from skrl.envs.wrappers.torch.pettingzoo_envs import PettingZooWrapper
from skrl.envs.wrappers.torch.robosuite_envs import RobosuiteWrapper
@@ -58,7 +58,7 @@ def wrap_env(env: Any, wrapper: str = "auto", verbose: bool = True) -> Union[Wra
+--------------------+-------------------------+
|Omniverse Isaac Gym |``"omniverse-isaacgym"`` |
+--------------------+-------------------------+
- |Isaac Sim (orbit) |``"isaac-orbit"`` |
+ |Isaac Lab |``"isaaclab"`` |
+--------------------+-------------------------+
:type wrapper: str, optional
:param verbose: Whether to print the wrapper type (default: ``True``)
@@ -69,50 +69,44 @@ def wrap_env(env: Any, wrapper: str = "auto", verbose: bool = True) -> Union[Wra
:return: Wrapped environment
:rtype: Wrapper or MultiAgentEnvWrapper
"""
- if verbose:
- logger.info("Environment class: {}".format(", ".join([str(base).replace("", "") \
- for base in env.__class__.__bases__])))
- if wrapper == "auto":
- base_classes = [str(base) for base in env.__class__.__bases__]
- if "" in base_classes or \
- "" in base_classes:
- if verbose:
- logger.info("Environment wrapper: Omniverse Isaac Gym")
- return OmniverseIsaacGymWrapper(env)
- elif isinstance(env, gym.core.Env) or isinstance(env, gym.core.Wrapper):
- # isaac-orbit
- if hasattr(env, "sim") and hasattr(env, "env_ns"):
- if verbose:
- logger.info("Environment wrapper: Isaac Orbit")
- return IsaacOrbitWrapper(env)
- # gym
- if verbose:
- logger.info("Environment wrapper: Gym")
- return GymWrapper(env)
- elif isinstance(env, gymnasium.core.Env) or isinstance(env, gymnasium.core.Wrapper):
- if verbose:
- logger.info("Environment wrapper: Gymnasium")
- return GymnasiumWrapper(env)
- elif "" in base_classes:
- if verbose:
- logger.info("Environment wrapper: DeepMind")
- return DeepMindWrapper(env)
- elif "" in base_classes:
- if verbose:
- logger.info("Environment wrapper: Isaac Gym (preview 2)")
- return IsaacGymPreview2Wrapper(env)
+ def _get_wrapper_name(env, verbose):
+ def _in(value, container):
+ for item in container:
+ if value in item:
+ return True
+ return False
+
+ base_classes = [str(base).replace("<class '", "").replace("'>", "") for base in env.__class__.__bases__]
+ try:
+ base_classes += [str(base).replace("<class '", "").replace("'>", "") for base in env.unwrapped.__class__.__bases__]
+ except:
+ pass
+ base_classes = sorted(list(set(base_classes)))
if verbose:
- logger.info("Environment wrapper: Isaac Gym (preview 3/4)")
- return IsaacGymPreview3Wrapper(env) # preview 4 is the same as 3
- elif wrapper == "gym":
+ logger.info(f"Environment wrapper: 'auto' (class: {', '.join(base_classes)})")
+
+ if _in("omni.isaac.lab.envs.manager_based_env.ManagerBasedEnv", base_classes) or _in("omni.isaac.lab.envs.direct_rl_env.DirectRLEnv", base_classes):
+ return "isaaclab"
+ elif _in("omni.isaac.gym.vec_env.vec_env_base.VecEnvBase", base_classes) or _in("omni.isaac.gym.vec_env.vec_env_mt.VecEnvMT", base_classes):
+ return "omniverse-isaacgym"
+ elif _in("rlgpu.tasks.base.vec_task.VecTask", base_classes):
+ return "isaacgym-preview2"
+ elif _in("robosuite.environments.", base_classes):
+ return "robosuite"
+ elif _in("dm_env._environment.Environment.", base_classes):
+ return "dm"
+ elif _in("pettingzoo.utils.env", base_classes) or _in("pettingzoo.utils.wrappers", base_classes):
+ return "pettingzoo"
+ elif _in("gymnasium.core.Env", base_classes) or _in("gymnasium.core.Wrapper", base_classes):
+ return "gymnasium"
+ elif _in("gym.core.Env", base_classes) or _in("gym.core.Wrapper", base_classes):
+ return "gym"
+ return base_classes
+
+ if wrapper == "auto":
+ wrapper = _get_wrapper_name(env, verbose)
+
+ if wrapper == "gym":
if verbose:
logger.info("Environment wrapper: Gym")
return GymWrapper(env)
@@ -152,9 +146,9 @@ def wrap_env(env: Any, wrapper: str = "auto", verbose: bool = True) -> Union[Wra
if verbose:
logger.info("Environment wrapper: Omniverse Isaac Gym")
return OmniverseIsaacGymWrapper(env)
- elif wrapper == "isaac-orbit":
+ elif wrapper == "isaaclab" or wrapper == "isaac-orbit":
if verbose:
- logger.info("Environment wrapper: Isaac Orbit")
- return IsaacOrbitWrapper(env)
+ logger.info("Environment wrapper: Isaac Lab")
+ return IsaacLabWrapper(env)
else:
raise ValueError(f"Unknown wrapper type: {wrapper}")
diff --git a/skrl/envs/wrappers/torch/isaac_orbit_envs.py b/skrl/envs/wrappers/torch/isaaclab_envs.py
similarity index 87%
rename from skrl/envs/wrappers/torch/isaac_orbit_envs.py
rename to skrl/envs/wrappers/torch/isaaclab_envs.py
index c9670ce1..9f849764 100644
--- a/skrl/envs/wrappers/torch/isaac_orbit_envs.py
+++ b/skrl/envs/wrappers/torch/isaaclab_envs.py
@@ -5,18 +5,20 @@
from skrl.envs.wrappers.torch.base import Wrapper
-class IsaacOrbitWrapper(Wrapper):
+class IsaacLabWrapper(Wrapper):
def __init__(self, env: Any) -> None:
- """Isaac Orbit environment wrapper
+ """Isaac Lab environment wrapper
:param env: The environment to wrap
- :type env: Any supported Isaac Orbit environment
+ :type env: Any supported Isaac Lab environment
"""
super().__init__(env)
self._reset_once = True
self._obs_dict = None
+ self._observation_space = self._observation_space["policy"]
+
def step(self, actions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Any]:
"""Perform a step in the environment
diff --git a/skrl/memories/jax/base.py b/skrl/memories/jax/base.py
index 41cea1b1..c87af33b 100644
--- a/skrl/memories/jax/base.py
+++ b/skrl/memories/jax/base.py
@@ -70,7 +70,10 @@ def __init__(self,
if device is None:
self.device = jax.devices()[0]
else:
- self.device = device if isinstance(device, jax.Device) else jax.devices(device)[0]
+ self.device = device
+ if type(device) == str:
+ device_type, device_index = f"{device}:0".split(':')[:2]
+ self.device = jax.devices(device_type)[int(device_index)]
# internal variables
self.filled = False
@@ -236,18 +239,21 @@ def create_tensor(self,
view_shape = (-1, *size) if keep_dimensions else (-1, size)
# create tensor (_tensor_) and add it to the internal storage
if self._jax:
- setattr(self, f"_tensor_{name}", jnp.zeros(tensor_shape, dtype=dtype))
+ with jax.default_device(self.device):
+ setattr(self, f"_tensor_{name}", jnp.zeros(tensor_shape, dtype=dtype))
else:
setattr(self, f"_tensor_{name}", np.zeros(tensor_shape, dtype=dtype))
# update internal variables
self.tensors[name] = getattr(self, f"_tensor_{name}")
- self.tensors_view[name] = self.tensors[name].reshape(*view_shape)
+ with jax.default_device(self.device):
+ self.tensors_view[name] = self.tensors[name].reshape(*view_shape)
self.tensors_keep_dimensions[name] = keep_dimensions
# fill the tensors (float tensors) with NaN
for name, tensor in self.tensors.items():
if tensor.dtype == np.float32 or tensor.dtype == np.float64:
if self._jax:
- self.tensors[name] = _copyto(self.tensors[name], float("nan"))
+ with jax.default_device(self.device):
+ self.tensors[name] = _copyto(self.tensors[name], float("nan"))
else:
self.tensors[name].fill(float("nan"))
# check views
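
The JAX changes above repeat two device-handling idioms (also applied to the models, noises and preprocessors below): resolving a device string such as `"cuda:1"` into a `jax.Device`, and allocating arrays under `jax.default_device` so they land on that device. A small sketch with a hypothetical `parse_device` helper:

```python
import jax
import jax.numpy as jnp


def parse_device(device=None):
    """Resolve None, a jax.Device, or a string like "cpu", "cuda" or "cuda:1" into a jax.Device."""
    if device is None:
        return jax.devices()[0]
    if isinstance(device, str):
        device_type, device_index = f"{device}:0".split(":")[:2]
        return jax.devices(device_type)[int(device_index)]
    return device


device = parse_device("cpu")          # e.g. "cuda:0" would select the first GPU backend, if present
with jax.default_device(device):      # arrays created here are placed on the selected device
    data = jnp.zeros((100, 4), dtype=jnp.float32)
```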
diff --git a/skrl/models/jax/base.py b/skrl/models/jax/base.py
index 30db0e83..9a3738a5 100644
--- a/skrl/models/jax/base.py
+++ b/skrl/models/jax/base.py
@@ -79,7 +79,10 @@ def __call__(self, inputs, role):
if device is None:
self.device = jax.devices()[0]
else:
- self.device = device if isinstance(device, jax.Device) else jax.devices(device)[0]
+ self.device = device
+ if type(device) == str:
+ device_type, device_index = f"{device}:0".split(':')[:2]
+ self.device = jax.devices(device_type)[int(device_index)]
self.observation_space = observation_space
self.action_space = action_space
@@ -119,7 +122,8 @@ def init_state_dict(self,
if isinstance(inputs["states"], (int, np.int32, np.int64)):
inputs["states"] = np.array(inputs["states"]).reshape(-1,1)
# init internal state dict
- self.state_dict = StateDict.create(apply_fn=self.apply, params=self.init(key, inputs, role))
+ with jax.default_device(self.device):
+ self.state_dict = StateDict.create(apply_fn=self.apply, params=self.init(key, inputs, role))
def _get_space_size(self,
space: Union[int, Sequence[int], gym.Space, gymnasium.Space],
diff --git a/skrl/models/jax/categorical.py b/skrl/models/jax/categorical.py
index 95b0331a..c9aa11a2 100644
--- a/skrl/models/jax/categorical.py
+++ b/skrl/models/jax/categorical.py
@@ -122,9 +122,10 @@ def act(self,
>>> print(actions.shape, log_prob.shape, outputs["net_output"].shape)
(4096, 1) (4096, 1) (4096, 2)
"""
- self._i += 1
- subkey = jax.random.fold_in(self._key, self._i)
- inputs["key"] = subkey
+ with jax.default_device(self.device):
+ self._i += 1
+ subkey = jax.random.fold_in(self._key, self._i)
+ inputs["key"] = subkey
# map from states/observations to normalized probabilities or unnormalized log probabilities
net_output, outputs = self.apply(self.state_dict.params if params is None else params, inputs, role)
diff --git a/skrl/models/jax/gaussian.py b/skrl/models/jax/gaussian.py
index fce2d90c..53245372 100644
--- a/skrl/models/jax/gaussian.py
+++ b/skrl/models/jax/gaussian.py
@@ -173,9 +173,10 @@ def act(self,
>>> print(actions.shape, log_prob.shape, outputs["mean_actions"].shape)
(4096, 8) (4096, 1) (4096, 8)
"""
- self._i += 1
- subkey = jax.random.fold_in(self._key, self._i)
- inputs["key"] = subkey
+ with jax.default_device(self.device):
+ self._i += 1
+ subkey = jax.random.fold_in(self._key, self._i)
+ inputs["key"] = subkey
# map from states/observations to mean actions and log standard deviations
mean_actions, log_std, outputs = self.apply(self.state_dict.params if params is None else params, inputs, role)
diff --git a/skrl/models/jax/multicategorical.py b/skrl/models/jax/multicategorical.py
index 9ad1bb3e..b3a91e84 100644
--- a/skrl/models/jax/multicategorical.py
+++ b/skrl/models/jax/multicategorical.py
@@ -136,9 +136,10 @@ def act(self,
>>> print(actions.shape, log_prob.shape, outputs["net_output"].shape)
(4096, 2) (4096, 1) (4096, 5)
"""
- self._i += 1
- subkey = jax.random.fold_in(self._key, self._i)
- inputs["key"] = subkey
+ with jax.default_device(self.device):
+ self._i += 1
+ subkey = jax.random.fold_in(self._key, self._i)
+ inputs["key"] = subkey
# map from states/observations to normalized probabilities or unnormalized log probabilities
net_output, outputs = self.apply(self.state_dict.params if params is None else params, inputs, role)
diff --git a/skrl/models/torch/base.py b/skrl/models/torch/base.py
index 757a8ba2..bf90ba8a 100644
--- a/skrl/models/torch/base.py
+++ b/skrl/models/torch/base.py
@@ -7,7 +7,7 @@
import numpy as np
import torch
-from skrl import logger
+from skrl import config, logger
class Model(torch.nn.Module):
@@ -743,3 +743,48 @@ def update_parameters(self, model: torch.nn.Module, polyak: float = 1) -> None:
for parameters, model_parameters in zip(self.parameters(), model.parameters()):
parameters.data.mul_(1 - polyak)
parameters.data.add_(polyak * model_parameters.data)
+
+ def broadcast_parameters(self, rank: int = 0):
+ """Broadcast model parameters to the whole group (e.g.: across all nodes) in distributed runs
+
+ After calling this method, the distributed model will contain the broadcasted parameters from ``rank``
+
+ :param rank: Worker/process rank from which to broadcast model parameters (default: ``0``)
+ :type rank: int
+
+ Example::
+
+ # broadcast model parameter from worker/process with rank 1
+ >>> if config.torch.is_distributed:
+ ... model.broadcast_parameters(rank=1)
+ """
+ object_list = [self.state_dict()]
+ torch.distributed.broadcast_object_list(object_list, rank)
+ self.load_state_dict(object_list[0])
+
+ def reduce_parameters(self):
+ """Reduce model parameters across all workers/processes in the whole group (e.g.: across all nodes)
+
+ After calling this method, the gradients will be averaged across all workers/processes, keeping the distributed model parameters bitwise identical for all of them
+
+ Example::
+
+ # reduce model parameter across all workers/processes
+ >>> if config.torch.is_distributed:
+ ... model.reduce_parameters()
+ """
+ # batch all_reduce ops: https://github.com/entity-neural-network/incubator/pull/220
+ gradients = []
+ for parameters in self.parameters():
+ if parameters.grad is not None:
+ gradients.append(parameters.grad.view(-1))
+ gradients = torch.cat(gradients)
+
+ torch.distributed.all_reduce(gradients, op=torch.distributed.ReduceOp.SUM)
+
+ offset = 0
+ for parameters in self.parameters():
+ if parameters.grad is not None:
+ parameters.grad.data.copy_(gradients[offset:offset + parameters.numel()] \
+ .view_as(parameters.grad.data) / config.torch.world_size)
+ offset += parameters.numel()
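
The `broadcast_parameters` / `reduce_parameters` helpers rely on `torch.distributed` being initialized and on `config.torch` exposing the run's rank and world size (its definition is outside this excerpt, so the launch details below are a hedged assumption). A typical multi-GPU run would be started with `torchrun`, which sets the standard `LOCAL_RANK`/`RANK`/`WORLD_SIZE` environment variables for each worker:

```python
# typical launch (outside Python), assuming a single node with 2 GPUs:
#   torchrun --nnodes=1 --nproc_per_node=2 train.py
from skrl import config

if config.torch.is_distributed:
    print(f"worker {config.torch.rank} of {config.torch.world_size}")
```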
diff --git a/skrl/multi_agents/jax/base.py b/skrl/multi_agents/jax/base.py
index 77619b81..e177610d 100644
--- a/skrl/multi_agents/jax/base.py
+++ b/skrl/multi_agents/jax/base.py
@@ -60,7 +60,10 @@ def __init__(self,
if device is None:
self.device = jax.devices()[0]
else:
- self.device = device if isinstance(device, jax.Device) else jax.devices(device)[0]
+ self.device = device
+ if type(device) == str:
+ device_type, device_index = f"{device}:0".split(':')[:2]
+ self.device = jax.devices(device_type)[int(device_index)]
# convert the models to their respective device
for _models in self.models.values():
diff --git a/skrl/multi_agents/torch/ippo/ippo.py b/skrl/multi_agents/torch/ippo/ippo.py
index 45913edd..3a1f7932 100644
--- a/skrl/multi_agents/torch/ippo/ippo.py
+++ b/skrl/multi_agents/torch/ippo/ippo.py
@@ -9,6 +9,7 @@
import torch.nn as nn
import torch.nn.functional as F
+from skrl import config, logger
from skrl.memories.torch import Memory
from skrl.models.torch import Model
from skrl.multi_agents.torch import MultiAgent
@@ -109,9 +110,18 @@ def __init__(self,
self.values = {uid: self.models[uid].get("value", None) for uid in self.possible_agents}
for uid in self.possible_agents:
+ # checkpoint models
self.checkpoint_modules[uid]["policy"] = self.policies[uid]
self.checkpoint_modules[uid]["value"] = self.values[uid]
+ # broadcast models' parameters in distributed runs
+ if config.torch.is_distributed:
+ logger.info(f"Broadcasting models' parameters")
+ if self.policies[uid] is not None:
+ self.policies[uid].broadcast_parameters()
+ if self.values[uid] is not None and self.policies[uid] is not self.values[uid]:
+ self.values[uid].broadcast_parameters()
+
# configuration
self._learning_epochs = self._as_dict(self.cfg["learning_epochs"])
self._mini_batches = self._as_dict(self.cfg["mini_batches"])
@@ -437,6 +447,10 @@ def compute_gae(rewards: torch.Tensor,
# optimization step
self.optimizers[uid].zero_grad()
(policy_loss + entropy_loss + value_loss).backward()
+ if config.torch.is_distributed:
+ policy.reduce_parameters()
+ if policy is not value:
+ value.reduce_parameters()
if self._grad_norm_clip[uid] > 0:
if policy is value:
nn.utils.clip_grad_norm_(policy.parameters(), self._grad_norm_clip[uid])
@@ -453,7 +467,7 @@ def compute_gae(rewards: torch.Tensor,
# update learning rate
if self._learning_rate_scheduler[uid]:
if isinstance(self.schedulers[uid], KLAdaptiveLR):
- self.schedulers[uid].step(torch.tensor(kl_divergences).mean())
+ self.schedulers[uid].step(torch.tensor(kl_divergences, device=self.device).mean())
else:
self.schedulers[uid].step()
diff --git a/skrl/multi_agents/torch/mappo/mappo.py b/skrl/multi_agents/torch/mappo/mappo.py
index 98fff05c..ef0e7bc2 100644
--- a/skrl/multi_agents/torch/mappo/mappo.py
+++ b/skrl/multi_agents/torch/mappo/mappo.py
@@ -9,6 +9,7 @@
import torch.nn as nn
import torch.nn.functional as F
+from skrl import config, logger
from skrl.memories.torch import Memory
from skrl.models.torch import Model
from skrl.multi_agents.torch import MultiAgent
@@ -116,9 +117,18 @@ def __init__(self,
self.values = {uid: self.models[uid].get("value", None) for uid in self.possible_agents}
for uid in self.possible_agents:
+ # checkpoint models
self.checkpoint_modules[uid]["policy"] = self.policies[uid]
self.checkpoint_modules[uid]["value"] = self.values[uid]
+ # broadcast models' parameters in distributed runs
+ if config.torch.is_distributed:
+ logger.info(f"Broadcasting models' parameters")
+ if self.policies[uid] is not None:
+ self.policies[uid].broadcast_parameters()
+ if self.values[uid] is not None and self.policies[uid] is not self.values[uid]:
+ self.values[uid].broadcast_parameters()
+
# configuration
self._learning_epochs = self._as_dict(self.cfg["learning_epochs"])
self._mini_batches = self._as_dict(self.cfg["mini_batches"])
@@ -457,6 +467,10 @@ def compute_gae(rewards: torch.Tensor,
# optimization step
self.optimizers[uid].zero_grad()
(policy_loss + entropy_loss + value_loss).backward()
+ if config.torch.is_distributed:
+ policy.reduce_parameters()
+ if policy is not value:
+ value.reduce_parameters()
if self._grad_norm_clip[uid] > 0:
if policy is value:
nn.utils.clip_grad_norm_(policy.parameters(), self._grad_norm_clip[uid])
@@ -473,7 +487,7 @@ def compute_gae(rewards: torch.Tensor,
# update learning rate
if self._learning_rate_scheduler[uid]:
if isinstance(self.schedulers[uid], KLAdaptiveLR):
- self.schedulers[uid].step(torch.tensor(kl_divergences).mean())
+ self.schedulers[uid].step(torch.tensor(kl_divergences, device=self.device).mean())
else:
self.schedulers[uid].step()
diff --git a/skrl/resources/noises/jax/base.py b/skrl/resources/noises/jax/base.py
index 2df54e20..e51e6fd3 100644
--- a/skrl/resources/noises/jax/base.py
+++ b/skrl/resources/noises/jax/base.py
@@ -31,7 +31,10 @@ def sample(self, size):
if device is None:
self.device = jax.devices()[0]
else:
- self.device = device if isinstance(device, jax.Device) else jax.devices(device)[0]
+ self.device = device
+ if type(device) == str:
+ device_type, device_index = f"{device}:0".split(':')[:2]
+ self.device = jax.devices(device_type)[int(device_index)]
def sample_like(self, tensor: Union[np.ndarray, jax.Array]) -> Union[np.ndarray, jax.Array]:
"""Sample a noise with the same size (shape) as the input tensor
diff --git a/skrl/resources/preprocessors/jax/running_standard_scaler.py b/skrl/resources/preprocessors/jax/running_standard_scaler.py
index 1d319605..97d5eb63 100644
--- a/skrl/resources/preprocessors/jax/running_standard_scaler.py
+++ b/skrl/resources/preprocessors/jax/running_standard_scaler.py
@@ -95,14 +95,18 @@ def __init__(self,
if device is None:
self.device = jax.devices()[0]
else:
- self.device = device if isinstance(device, jax.Device) else jax.devices(device)[0]
+ self.device = device
+ if type(device) == str:
+ device_type, device_index = f"{device}:0".split(':')[:2]
+ self.device = jax.devices(device_type)[int(device_index)]
size = self._get_space_size(size)
if self._jax:
- self.running_mean = jnp.zeros(size, dtype=jnp.float32)
- self.running_variance = jnp.ones(size, dtype=jnp.float32)
- self.current_count = jnp.ones((1,), dtype=jnp.float32)
+ with jax.default_device(self.device):
+ self.running_mean = jnp.zeros(size, dtype=jnp.float32)
+ self.running_variance = jnp.ones(size, dtype=jnp.float32)
+ self.current_count = jnp.ones((1,), dtype=jnp.float32)
else:
self.running_mean = np.zeros(size, dtype=np.float32)
self.running_variance = np.ones(size, dtype=np.float32)
diff --git a/skrl/resources/schedulers/torch/kl_adaptive.py b/skrl/resources/schedulers/torch/kl_adaptive.py
index 0c97f50f..4d7d5771 100644
--- a/skrl/resources/schedulers/torch/kl_adaptive.py
+++ b/skrl/resources/schedulers/torch/kl_adaptive.py
@@ -1,8 +1,12 @@
from typing import Optional, Union
+from packaging import version
+
import torch
from torch.optim.lr_scheduler import _LRScheduler
+from skrl import config
+
class KLAdaptiveLR(_LRScheduler):
def __init__(self,
@@ -25,6 +29,10 @@ def __init__(self,
This scheduler is only available for PPO at the moment.
Applying it to other agents will not change the learning rate
+ .. note::
+
+ In distributed runs, the KL divergence is averaged across all workers/processes and the resulting learning rate is broadcast to all of them
+
Example::
>>> scheduler = KLAdaptiveLR(optimizer, kl_threshold=0.01)
@@ -50,6 +58,8 @@ def __init__(self,
:param verbose: Verbose mode (default: ``False``)
:type verbose: bool, optional
"""
+ if version.parse(torch.__version__) >= version.parse("2.2"):
+ verbose = "deprecated"
super().__init__(optimizer, last_epoch, verbose)
self.kl_threshold = kl_threshold
@@ -82,10 +92,25 @@ def step(self, kl: Optional[Union[torch.Tensor, float]] = None, epoch: Optional[
:type epoch: int, optional
"""
if kl is not None:
- for group in self.optimizer.param_groups:
+ # reduce (collect from all workers/processes) learning rate in distributed runs
+ if config.torch.is_distributed:
+ torch.distributed.all_reduce(kl, op=torch.distributed.ReduceOp.SUM)
+ kl /= config.torch.world_size
+
+ for i, group in enumerate(self.optimizer.param_groups):
+ # adjust the learning rate
+ lr = group['lr']
if kl > self.kl_threshold * self._kl_factor:
- group['lr'] = max(group['lr'] / self._lr_factor, self.min_lr)
+ lr = max(lr / self._lr_factor, self.min_lr)
elif kl < self.kl_threshold / self._kl_factor:
- group['lr'] = min(group['lr'] * self._lr_factor, self.max_lr)
+ lr = min(lr * self._lr_factor, self.max_lr)
+
+ # broadcast learning rate in distributed runs
+ if config.torch.is_distributed:
+ lr_tensor = torch.tensor([lr], device=config.torch.device)
+ torch.distributed.broadcast(lr_tensor, 0)
+ lr = lr_tensor.item()
- self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
+ # update value
+ group['lr'] = lr
+ self._last_lr[i] = lr
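
The refactored loop keeps the original adaptation rule; only the distributed averaging/broadcast is new. The rule itself, extracted as a tiny pure function for clarity (the factor and bound values below are illustrative defaults, not taken from this hunk):

```python
def kl_adaptive_lr(lr, kl, kl_threshold, kl_factor=2.0, lr_factor=1.5, min_lr=1e-6, max_lr=1e-2):
    """Shrink the learning rate when the KL divergence is too large, grow it when too small."""
    if kl > kl_threshold * kl_factor:
        return max(lr / lr_factor, min_lr)
    elif kl < kl_threshold / kl_factor:
        return min(lr * lr_factor, max_lr)
    return lr

assert kl_adaptive_lr(1e-3, kl=0.05, kl_threshold=0.008) == 1e-3 / 1.5    # KL too high -> decrease
assert kl_adaptive_lr(1e-3, kl=0.001, kl_threshold=0.008) == 1e-3 * 1.5   # KL too low  -> increase
```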
diff --git a/skrl/trainers/torch/base.py b/skrl/trainers/torch/base.py
index a13d70a4..8784185b 100644
--- a/skrl/trainers/torch/base.py
+++ b/skrl/trainers/torch/base.py
@@ -6,7 +6,7 @@
import torch
-from skrl import logger
+from skrl import config, logger
from skrl.agents.torch import Agent
from skrl.envs.wrappers.torch import Wrapper
@@ -59,6 +59,7 @@ def __init__(self,
self.headless = self.cfg.get("headless", False)
self.disable_progressbar = self.cfg.get("disable_progressbar", False)
self.close_environment_at_exit = self.cfg.get("close_environment_at_exit", True)
+ self.environment_info = self.cfg.get("environment_info", "episode")
self.initial_timestep = 0
@@ -74,6 +75,11 @@ def close_env():
self.env.close()
logger.info("Environment closed")
+ # update trainer configuration to avoid duplicated info/data in distributed runs
+ if config.torch.is_distributed:
+ if config.torch.rank:
+ self.disable_progressbar = True
+
def __str__(self) -> str:
"""Generate a string representation of the trainer
@@ -190,6 +196,12 @@ def single_agent_train(self) -> None:
timestep=timestep,
timesteps=self.timesteps)
+ # log environment info
+ if self.environment_info in infos:
+ for k, v in infos[self.environment_info].items():
+ if isinstance(v, torch.Tensor) and v.numel() == 1:
+ self.agents.track_data(f"Info / {k}", v.item())
+
# post-interaction
self.agents.post_interaction(timestep=timestep, timesteps=self.timesteps)
@@ -244,6 +256,12 @@ def single_agent_eval(self) -> None:
timesteps=self.timesteps)
super(type(self.agents), self.agents).post_interaction(timestep=timestep, timesteps=self.timesteps)
+ # log environment info
+ if self.environment_info in infos:
+ for k, v in infos[self.environment_info].items():
+ if isinstance(v, torch.Tensor) and v.numel() == 1:
+ self.agents.track_data(f"Info / {k}", v.item())
+
# reset environments
if self.env.num_envs > 1:
states = next_states
@@ -304,6 +322,12 @@ def multi_agent_train(self) -> None:
timestep=timestep,
timesteps=self.timesteps)
+ # log environment info
+ if self.environment_info in infos:
+ for k, v in infos[self.environment_info].items():
+ if isinstance(v, torch.Tensor) and v.numel() == 1:
+ self.agents.track_data(f"Info / {k}", v.item())
+
# post-interaction
self.agents.post_interaction(timestep=timestep, timesteps=self.timesteps)
@@ -361,6 +385,12 @@ def multi_agent_eval(self) -> None:
timesteps=self.timesteps)
super(type(self.agents), self.agents).post_interaction(timestep=timestep, timesteps=self.timesteps)
+ # log environment info
+ if self.environment_info in infos:
+ for k, v in infos[self.environment_info].items():
+ if isinstance(v, torch.Tensor) and v.numel() == 1:
+ self.agents.track_data(f"Info / {k}", v.item())
+
# reset environments
if not self.env.agents:
states, infos = self.env.reset()
diff --git a/skrl/trainers/torch/parallel.py b/skrl/trainers/torch/parallel.py
index 68b9b9d8..1a086e0e 100644
--- a/skrl/trainers/torch/parallel.py
+++ b/skrl/trainers/torch/parallel.py
@@ -18,6 +18,7 @@
"headless": False, # whether to use headless mode (no rendering)
"disable_progressbar": False, # whether to disable the progressbar. If None, disable on non-TTY
"close_environment_at_exit": True, # whether to close the environment on normal program termination
+ "environment_info": "episode", # key used to get and log environment info
}
# [end-config-dict-torch]
diff --git a/skrl/trainers/torch/sequential.py b/skrl/trainers/torch/sequential.py
index 49952351..67b43ca7 100644
--- a/skrl/trainers/torch/sequential.py
+++ b/skrl/trainers/torch/sequential.py
@@ -17,6 +17,7 @@
"headless": False, # whether to use headless mode (no rendering)
"disable_progressbar": False, # whether to disable the progressbar. If None, disable on non-TTY
"close_environment_at_exit": True, # whether to close the environment on normal program termination
+ "environment_info": "episode", # key used to get and log environment info
}
# [end-config-dict-torch]
@@ -116,6 +117,13 @@ def train(self) -> None:
timestep=timestep,
timesteps=self.timesteps)
+ # log environment info
+ if self.environment_info in infos:
+ for k, v in infos[self.environment_info].items():
+ if isinstance(v, torch.Tensor) and v.numel() == 1:
+ for agent in self.agents:
+ agent.track_data(f"Info / {k}", v.item())
+
# post-interaction
for agent in self.agents:
agent.post_interaction(timestep=timestep, timesteps=self.timesteps)
@@ -184,6 +192,13 @@ def eval(self) -> None:
timesteps=self.timesteps)
super(type(agent), agent).post_interaction(timestep=timestep, timesteps=self.timesteps)
+ # log environment info
+ if self.environment_info in infos:
+ for k, v in infos[self.environment_info].items():
+ if isinstance(v, torch.Tensor) and v.numel() == 1:
+ for agent in self.agents:
+ agent.track_data(f"Info / {k}", v.item())
+
# reset environments
if terminated.any() or truncated.any():
states, infos = self.env.reset()
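
Using the new `environment_info` option only requires adding the key to the trainer configuration; scalar tensors found under `infos["episode"]` (the default key) are then logged as `Info / <name>`. A hedged usage sketch, assuming `env` and `agent` were created beforehand:

```python
from skrl.trainers.torch import SequentialTrainer

# env: a wrapped environment, agent: a skrl agent (both created elsewhere)
cfg_trainer = {"timesteps": 100000, "environment_info": "episode"}
trainer = SequentialTrainer(env=env, agents=agent, cfg=cfg_trainer)
trainer.train()
```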
diff --git a/skrl/trainers/torch/step.py b/skrl/trainers/torch/step.py
index c60476f1..3bfd9acc 100644
--- a/skrl/trainers/torch/step.py
+++ b/skrl/trainers/torch/step.py
@@ -17,6 +17,7 @@
"headless": False, # whether to use headless mode (no rendering)
"disable_progressbar": False, # whether to disable the progressbar. If None, disable on non-TTY
"close_environment_at_exit": True, # whether to close the environment on normal program termination
+ "environment_info": "episode", # key used to get and log environment info
}
# [end-config-dict-torch]
@@ -129,8 +130,8 @@ def train(self, timestep: Optional[int] = None, timesteps: Optional[int] = None)
self.env.render()
if self.num_simultaneous_agents == 1:
- # record the environments' transitions
with torch.no_grad():
+ # record the environments' transitions
self.agents.record_transition(states=self.states,
actions=actions,
rewards=rewards,
@@ -141,12 +142,18 @@ def train(self, timestep: Optional[int] = None, timesteps: Optional[int] = None)
timestep=timestep,
timesteps=timesteps)
+ # log environment info
+ if self.environment_info in infos:
+ for k, v in infos[self.environment_info].items():
+ if isinstance(v, torch.Tensor) and v.numel() == 1:
+ self.agents.track_data(f"Info / {k}", v.item())
+
# post-interaction
self.agents.post_interaction(timestep=timestep, timesteps=timesteps)
else:
- # record the environments' transitions
with torch.no_grad():
+ # record the environments' transitions
for agent, scope in zip(self.agents, self.agents_scope):
agent.record_transition(states=self.states[scope[0]:scope[1]],
actions=actions[scope[0]:scope[1]],
@@ -158,6 +165,13 @@ def train(self, timestep: Optional[int] = None, timesteps: Optional[int] = None)
timestep=timestep,
timesteps=timesteps)
+ # log environment info
+ if self.environment_info in infos:
+ for k, v in infos[self.environment_info].items():
+ if isinstance(v, torch.Tensor) and v.numel() == 1:
+ for agent in self.agents:
+ agent.track_data(f"Info / {k}", v.item())
+
# post-interaction
for agent in self.agents:
agent.post_interaction(timestep=timestep, timesteps=timesteps)
@@ -242,6 +256,12 @@ def eval(self, timestep: Optional[int] = None, timesteps: Optional[int] = None)
timesteps=timesteps)
super(type(self.agents), self.agents).post_interaction(timestep=timestep, timesteps=timesteps)
+ # log environment info
+ if self.environment_info in infos:
+ for k, v in infos[self.environment_info].items():
+ if isinstance(v, torch.Tensor) and v.numel() == 1:
+ self.agents.track_data(f"Info / {k}", v.item())
+
else:
# write data to TensorBoard
for agent, scope in zip(self.agents, self.agents_scope):
@@ -256,6 +276,13 @@ def eval(self, timestep: Optional[int] = None, timesteps: Optional[int] = None)
timesteps=timesteps)
super(type(agent), agent).post_interaction(timestep=timestep, timesteps=timesteps)
+ # log environment info
+ if self.environment_info in infos:
+ for k, v in infos[self.environment_info].items():
+ if isinstance(v, torch.Tensor) and v.numel() == 1:
+ for agent in self.agents:
+ agent.track_data(f"Info / {k}", v.item())
+
# reset environments
if terminated.any() or truncated.any():
self.states, infos = self.env.reset()
diff --git a/skrl/utils/__init__.py b/skrl/utils/__init__.py
index e67e66cf..6ebdccf1 100644
--- a/skrl/utils/__init__.py
+++ b/skrl/utils/__init__.py
@@ -14,8 +14,14 @@ def set_seed(seed: Optional[int] = None, deterministic: bool = False) -> int:
"""
Set the seed for the random number generators
- Due to NumPy's legacy seeding constraint the seed must be between 0 and 2**32 - 1.
- Otherwise a NumPy exception (``ValueError: Seed must be between 0 and 2**32 - 1``) will be raised
+ .. note::
+
+ In distributed runs, each worker/process seed is offset from the defined value by its rank
+
+ .. warning::
+
+ Due to NumPy's legacy seeding constraint the seed must be between 0 and 2**32 - 1.
+ Otherwise a NumPy exception (``ValueError: Seed must be between 0 and 2**32 - 1``) will be raised
Modified packages:
@@ -65,8 +71,12 @@ def set_seed(seed: Optional[int] = None, deterministic: bool = False) -> int:
except NotImplementedError:
seed = int(time.time() * 1000)
seed %= 2 ** 31 # NumPy's legacy seeding seed must be between 0 and 2**32 - 1
-
seed = int(seed)
+
+ # set different seeds in distributed runs
+ if config.torch.is_distributed:
+ seed += config.torch.rank
+
logger.info(f"Seed: {seed}")
# numpy
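
With the change above, every worker can call `set_seed` with the same value and still draw different random streams:

```python
from skrl.utils import set_seed

# in a distributed run with ranks 0..N-1 this yields seeds 42, 43, ..., 41 + N
seed = set_seed(42)
```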
diff --git a/skrl/utils/model_instantiators/torch/__init__.py b/skrl/utils/model_instantiators/torch/__init__.py
index 4fbcb588..761fdad0 100644
--- a/skrl/utils/model_instantiators/torch/__init__.py
+++ b/skrl/utils/model_instantiators/torch/__init__.py
@@ -486,7 +486,8 @@ def shared_model(observation_space: Optional[Union[int, Tuple[int], gym.Space, g
device: Optional[Union[str, torch.device]] = None,
structure: str = "",
roles: Sequence[str] = [],
- parameters: Sequence[Mapping[str, Any]] = []) -> Model:
+ parameters: Sequence[Mapping[str, Any]] = [],
+ single_forward_pass: bool = True) -> Model:
"""Instantiate a shared model
:param observation_space: Observation/state space or shape (default: None).
@@ -505,12 +506,14 @@ def shared_model(observation_space: Optional[Union[int, Tuple[int], gym.Space, g
:type roles: sequence of strings, optional
:param parameters: Organized list of model instantiator parameters (default: ``[]``)
:type parameters: sequence of dict, optional
+ :param single_forward_pass: Whether to perform a single forward-pass for the shared layers/network (default: ``True``)
+ :type single_forward_pass: bool
:return: Shared model instance
:rtype: Model
"""
class GaussianDeterministicModel(GaussianMixin, DeterministicMixin, Model):
- def __init__(self, observation_space, action_space, device, roles, metadata):
+ def __init__(self, observation_space, action_space, device, roles, metadata, single_forward_pass):
Model.__init__(self, observation_space, action_space, device)
GaussianMixin.__init__(self,
clip_actions=metadata[0]["clip_actions"],
@@ -521,6 +524,7 @@ def __init__(self, observation_space, action_space, device, roles, metadata):
DeterministicMixin.__init__(self, clip_actions=metadata[1]["clip_actions"], role=roles[1])
self._roles = roles
+ self._single_forward_pass = single_forward_pass
self.instantiator_input_type = metadata[0]["input_shape"].value
self.instantiator_output_scales = [m["output_scale"] for m in metadata]
@@ -555,21 +559,33 @@ def act(self, inputs, role):
def compute(self, inputs, role):
if self.instantiator_input_type == 0:
- output = self.net(inputs["states"])
+ net_inputs = inputs["states"]
elif self.instantiator_input_type == -1:
- output = self.net(inputs["taken_actions"])
+ net_inputs = inputs["taken_actions"]
elif self.instantiator_input_type == -2:
- output = self.net(torch.cat((inputs["states"], inputs["taken_actions"]), dim=1))
-
- if role == self._roles[0]:
- return self.instantiator_output_scales[0] * self.mean_net(output), self.log_std_parameter, {}
- elif role == self._roles[1]:
- return self.instantiator_output_scales[1] * self.value_net(output), {}
+ net_inputs = torch.cat((inputs["states"], inputs["taken_actions"]), dim=1)
+
+ # single shared layers/network forward-pass
+ if self._single_forward_pass:
+ if role == self._roles[0]:
+ self._shared_output = self.net(net_inputs)
+ return self.instantiator_output_scales[0] * self.mean_net(self._shared_output), self.log_std_parameter, {}
+ elif role == self._roles[1]:
+ shared_output = self.net(net_inputs) if self._shared_output is None else self._shared_output
+ self._shared_output = None
+ return self.instantiator_output_scales[1] * self.value_net(shared_output), {}
+ # multiple shared layers/network forward-pass
+ else:
+ if role == self._roles[0]:
+ return self.instantiator_output_scales[0] * self.mean_net(self.net(net_inputs)), self.log_std_parameter, {}
+ elif role == self._roles[1]:
+ return self.instantiator_output_scales[1] * self.value_net(self.net(net_inputs)), {}
# TODO: define the model using the specified structure
return GaussianDeterministicModel(observation_space=observation_space,
- action_space=action_space,
- device=device,
- roles=roles,
- metadata=parameters)
+ action_space=action_space,
+ device=device,
+ roles=roles,
+ metadata=parameters,
+ single_forward_pass=single_forward_pass)
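
The single forward-pass option caches the shared backbone's output when the policy role is computed and reuses it for the value role, instead of running the backbone twice. An illustrative standalone module (not the skrl API; sizes are hypothetical) showing that caching pattern:

```python
import torch
import torch.nn as nn


class SharedPolicyValue(nn.Module):
    """Sketch of the single forward-pass pattern used above."""
    def __init__(self, obs_size=8, hidden_size=64, action_size=2):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(obs_size, hidden_size), nn.ELU())  # shared layers
        self.mean_net = nn.Linear(hidden_size, action_size)
        self.value_net = nn.Linear(hidden_size, 1)
        self._shared_output = None

    def compute(self, states, role):
        if role == "policy":
            self._shared_output = self.net(states)       # single shared forward-pass (cached)
            return self.mean_net(self._shared_output)
        elif role == "value":
            shared = self.net(states) if self._shared_output is None else self._shared_output
            self._shared_output = None                   # consume the cached output
            return self.value_net(shared)


model = SharedPolicyValue()
states = torch.randn(4, 8)
actions_mean = model.compute(states, "policy")
values = model.compute(states, "value")                  # reuses the cached shared output
```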