Skip to content

Commit

Permalink
Simplify docs in jax_trainer.py
Browse files Browse the repository at this point in the history
Signed-off-by: Fabrice Normandin <normandf@mila.quebec>
  • Loading branch information
lebrice committed Oct 7, 2024
1 parent 51356c8 commit f845cfe
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions project/trainers/jax_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@
class JaxModule(Protocol[Ts, _B, _MetricsT]):
"""A protocol for algorithms that can be trained by the `JaxTrainer`.
The [JaxRLExample][project.algorithms.jax_rl_example.JaxRLExample] class is an example of a
class that follows this structure and can be trained with a [JaxTrainer][project.algorithms.jax_trainer.JaxTrainer].
The `JaxRLExample` is an example that follows this structure and can be trained with a
`JaxTrainer`.
"""

def init_train_state(self, rng: chex.PRNGKey) -> Ts:
Expand Down Expand Up @@ -95,8 +95,8 @@ class JaxTrainer(flax.struct.PyTreeNode):
## Assumptions:
- The algo object must match the [JaxModule][project.algorithms.jax_trainer.JaxModule] protocol (in
other words, it should implement its methods).
- The algo object must match the `JaxModule` protocol (in other words, it should implement its
methods).
## Training loop
Expand Down

0 comments on commit f845cfe

Please sign in to comment.