From f845cfe64af1571a128f073c4cb72584e039dd5d Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Mon, 7 Oct 2024 15:13:16 +0000 Subject: [PATCH] Simplify docs in `jax_trainer.py` Signed-off-by: Fabrice Normandin --- project/trainers/jax_trainer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/project/trainers/jax_trainer.py b/project/trainers/jax_trainer.py index c41da079..79c51bc2 100644 --- a/project/trainers/jax_trainer.py +++ b/project/trainers/jax_trainer.py @@ -37,8 +37,8 @@ class JaxModule(Protocol[Ts, _B, _MetricsT]): """A protocol for algorithms that can be trained by the `JaxTrainer`. - The [JaxRLExample][project.algorithms.jax_rl_example.JaxRLExample] class is an example of a - class that follows this structure and can be trained with a [JaxTrainer][project.algorithms.jax_trainer.JaxTrainer]. + The `JaxRLExample` is an example that follows this structure and can be trained with a + `JaxTrainer`. """ def init_train_state(self, rng: chex.PRNGKey) -> Ts: @@ -95,8 +95,8 @@ class JaxTrainer(flax.struct.PyTreeNode): ## Assumptions: - - The algo object must match the [JaxModule][project.algorithms.jax_trainer.JaxModule] protocol (in - other words, it should implement its methods). + - The algo object must match the `JaxModule` protocol (in other words, it should implement its + methods). ## Training loop