From 2301c461c26642d3fb6482de2388568dbf8f215d Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Mon, 7 Oct 2024 20:50:44 +0000 Subject: [PATCH] Move things around, add pytest.mark.slow marks Signed-off-by: Fabrice Normandin --- project/algorithms/jax_rl_example_test.py | 524 +++++++++++----------- 1 file changed, 268 insertions(+), 256 deletions(-) diff --git a/project/algorithms/jax_rl_example_test.py b/project/algorithms/jax_rl_example_test.py index 519d5aa2..d50699aa 100644 --- a/project/algorithms/jax_rl_example_test.py +++ b/project/algorithms/jax_rl_example_test.py @@ -57,259 +57,6 @@ ) logger = getLogger(__name__) -## Pytorch-Lightning wrapper around this learner: - - -class PPOLightningModule(lightning.LightningModule): - """Uses the same code as [project.algorithms.jax_rl_example.JaxRLExample][], but the training - loop is run with pytorch-lightning. - - This is currently only meant to be used to compare the difference fully-jitted training loop - and lightning. - """ - - def __init__( - self, - learner: JaxRLExample, - ts: PPOState, - ): - # https://github.com/keraJLi/rejax/blob/a1428ad3d661e31985c5c19460cec70bc95aef6e/configs/gymnax/pendulum.yaml#L1 - - super().__init__() - self.learner = learner - self.ts = ts - - self.save_hyperparameters(hparams_to_dict(learner)) - self.actor_params = torch.nn.ParameterList( - jax.tree.leaves( - jax.tree.map( - torch_jax_interop.to_torch.jax_to_torch_tensor, - self.ts.actor_ts.params, - ) - ) - ) - self.critic_params = torch.nn.ParameterList( - jax.tree.leaves( - jax.tree.map( - torch_jax_interop.to_torch.jax_to_torch_tensor, - self.ts.critic_ts.params, - ) - ) - ) - - self.automatic_optimization = False - - iteration_steps = self.learner.hp.num_envs * self.learner.hp.num_steps - # number of "iterations" (collecting batches of episodes in the environment) per epoch. 
- self.num_train_iterations = np.ceil(self.learner.hp.eval_freq / iteration_steps).astype( - int - ) - - @override - def training_step(self, batch: torch.Tensor, batch_idx: int): - start = time.perf_counter() - with jax.disable_jit(self.learner.hp.debug): - algo_struct = self.learner - self.ts, train_metrics = algo_struct.fused_training_step(batch_idx, self.ts) - - duration = time.perf_counter() - start - logger.debug(f"Training step took {duration:.1f} seconds.") - actor_losses = train_metrics.actor_losses - critic_losses = train_metrics.critic_losses - self.log("train/actor_loss", actor_losses.mean().item(), logger=True, prog_bar=True) - self.log("train/critic_loss", critic_losses.mean().item(), logger=True, prog_bar=True) - - updates_per_second = ( - self.learner.hp.num_epochs * self.learner.hp.num_minibatches - ) / duration - self.log("train/updates_per_second", updates_per_second, logger=True, prog_bar=True) - minibatch_size = ( - self.learner.hp.num_envs * self.learner.hp.num_steps - ) // self.learner.hp.num_minibatches - samples_per_update = minibatch_size - self.log( - "train/samples_per_second", - updates_per_second * samples_per_update, - logger=True, - prog_bar=True, - on_step=True, - ) - - # for jax_param, torch_param in zip( - # jax.tree.leaves(self.train_state.actor_ts.params), self.actor_params - # ): - # torch_param.set_(torch_jax_interop.to_torch.jax_to_torch_tensor(jax_param)) - - # for jax_param, torch_param in zip( - # jax.tree.leaves(self.train_state.critic_ts.params), self.critic_params - # ): - # torch_param.set_(torch_jax_interop.to_torch.jax_to_torch_tensor(jax_param)) - - return - - @override - def train_dataloader(self) -> Iterable[Trajectory]: - # BUG: what's probably happening is that the dataloader keeps getting batches with the - # initial train state! - from torch.utils.data import TensorDataset - - dataset = TensorDataset(torch.arange(self.num_train_iterations, device=self.device)) - return DataLoader(dataset, batch_size=None, num_workers=0, shuffle=False, collate_fn=None) - - def val_dataloader(self) -> Any: - # todo: unsure what this should be yielding.. - from torch.utils.data import TensorDataset - - dataset = TensorDataset(torch.arange(1, device=self.device)) - return DataLoader(dataset, batch_size=None, num_workers=0, shuffle=False, collate_fn=None) - - def validation_step(self, batch: int, batch_index: int): - # self.learner.eval_callback() - # return # skip the rest for now while we compare the performance? - eval_metrics = self.learner.eval_callback(ts=self.ts) - episode_lengths = eval_metrics.episode_length - cumulative_rewards = eval_metrics.cumulative_reward - self.log("val/episode_lengths", episode_lengths.mean().item(), batch_size=1) - self.log("val/rewards", cumulative_rewards.mean().item(), batch_size=1) - - def on_train_epoch_start(self) -> None: - if not isinstance(self.learner.env, gymnax.environments.environment.Environment): - return - assert self.trainer.log_dir is not None - gif_path = Path(self.trainer.log_dir) / f"epoch_{self.current_epoch}.gif" - self.learner.visualize(ts=self.ts, gif_path=gif_path) - return # skip the rest for now while we compare the performance - actor = make_actor(ts=self.train_state, hp=self.hp) - render_episode( - actor=actor, - env=self.env, - env_params=self.env_params, - gif_path=gif_path, - num_steps=200, - ) - return super().on_train_epoch_end() - - @override - def configure_optimizers(self) -> Any: - # todo: Note, this one isn't used atm! 
- from torch.optim.adam import Adam - - return Adam(self.parameters(), lr=1e-3) - - @override - def configure_callbacks(self) -> list[lightning.Callback]: - return [RlThroughputCallback()] - - @override - def transfer_batch_to_device( - self, batch: TrajectoryWithLastObs | int, device: torch.device, dataloader_idx: int - ) -> TrajectoryWithLastObs | int: - if isinstance(batch, int): - # FIXME: valid dataloader currently just yields ints, not trajectories. - return batch - if isinstance(batch, list) and len(batch) == 1: - # FIXME: train dataloader currently just yields ints, not trajectories. - return batch - - _batch_jax_devices = batch.trajectories.obs.devices() - assert len(_batch_jax_devices) == 1 - batch_jax_device = _batch_jax_devices.pop() - torch_self_device = device - if ( - torch_self_device.type == "cuda" - and "cuda" in str(batch_jax_device) - and (torch_self_device.index == -1 or torch_self_device.index == batch_jax_device.id) - ): - # All good, both are on the same GPU. - return batch - - jax_self_device = torch_jax_interop.to_jax.torch_to_jax_device(torch_self_device) - return jax.tree.map(functools.partial(jax.device_put, device=jax_self_device), batch) - - -class RlThroughputCallback(MeasureSamplesPerSecondCallback): - """A callback to measure the throughput of RL algorithms.""" - - def __init__(self, num_optimizers: int | None = 1): - super().__init__(num_optimizers=num_optimizers) - self.total_transitions = 0 - self.total_episodes = 0 - self._start = time.perf_counter() - self._updates = 0 - - @override - def on_fit_start( - self, - trainer: lightning.Trainer, - pl_module: lightning.LightningModule, - ) -> None: - super().on_fit_start(trainer, pl_module) - self.total_transitions = 0 - self.total_episodes = 0 - self._start = time.perf_counter() - - @override - def on_train_batch_end( - self, - trainer: lightning.Trainer, - pl_module: lightning.LightningModule, - outputs: dict[str, torch.Tensor], - batch: TrajectoryWithLastObs, - batch_index: int, - ) -> None: - super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_index) - if not isinstance(batch, TrajectoryWithLastObs): - return - episodes = batch.trajectories - assert episodes.obs.shape - num_episodes = episodes.obs.shape[0] - num_transitions = np.prod(episodes.obs.shape[:2]) - self.total_episodes += num_episodes - self.total_transitions += num_transitions - steps_per_second = self.total_transitions / (time.perf_counter() - self._start) - updates_per_second = (self._updates) / (time.perf_counter() - self._start) - episodes_per_second = self.total_episodes / (time.perf_counter() - self._start) - logger.info( - f"Total transitions: {self.total_transitions}, total episodes: {self.total_episodes}" - ) - # print(f"Steps per second: {steps_per_second}") - logger.info(f"Steps per second: {steps_per_second}") - logger.info(f"Episodes per second: {episodes_per_second}") - logger.info(f"Updates per second: {updates_per_second}") - - @override - def get_num_samples(self, batch: TrajectoryWithLastObs) -> int: - if isinstance(batch, int): # fixme - return 1 - return int(np.prod(batch.trajectories.obs.shape[:2]).item()) - - @override - def on_fit_end(self, trainer: lightning.Trainer, pl_module: lightning.LightningModule) -> None: - super().on_fit_end(trainer, pl_module) - - def log( - self, - name: str, - value: Any, - module: JaxRLExample, - trainer: lightning.Trainer | JaxTrainer, - **kwargs, - ): - # Used to possibly customize how the values are logged (e.g. for non-LightningModules). 
- # By default, uses the LightningModule.log method. - # TODO: Somehow log the metrics without an actual trainer? - # Should we create a Trainer / LightningModule "facade" that the callbacks can interact with? - if trainer.logger: - trainer.logger.log_metrics({name: value}, step=trainer.global_step, **kwargs) - - # if trainer.progress_bar_callback: - # trainer.progress_bar_callback.log_metrics({name: value}, step=trainer.global_step, **kwargs) - # return trainer.logger.log_metrics().log( - # name, - # value, - # **kwargs, - # ) - @pytest.fixture(params=["Pendulum-v1"]) def env_id(request: pytest.FixtureRequest) -> str: @@ -427,20 +174,20 @@ def _add_gitignore_if_needed(original_datadir: Path): gitignore_file.parent.mkdir(exist_ok=True, parents=True) gitignore_file.write_text("*.gif\n") - +@pytest.mark.slow +@pytest.mark.timeout(35) def test_train_ours( algo: JaxRLExample, rng: chex.PRNGKey, original_datadir: Path, tmp_path: Path, file_regression: FileRegressionFixture, - # ndarrays_regression: NDArraysRegressionFixture, tensor_regression: TensorRegressionFixture, ): + """Test our tweaked version of `rejax.PPO` algorithm using the `JaxRLExample.train` method.""" _add_gitignore_if_needed(original_datadir) train_state, evaluations = algo.train(rng=rng) - # ndarrays_regression.check(dataclasses.asdict(evals)) tensor_regression.check( jax.tree.map(torch_jax_interop.jax_to_torch, dataclasses.asdict(evaluations)) ) @@ -450,6 +197,8 @@ def test_train_ours( file_regression.check(_gif_path.read_bytes(), binary=True, extension=".gif") +@pytest.mark.slow +@pytest.mark.timeout(35) def test_train_ours_with_trainer( algo: JaxRLExample, rng: chex.PRNGKey, @@ -460,6 +209,8 @@ def test_train_ours_with_trainer( # ndarrays_regression: NDArraysRegressionFixture, tensor_regression: TensorRegressionFixture, ): + """Test our tweaked version of `rejax.PPO` algorithm using a `JaxTrainer`.""" + _add_gitignore_if_needed(original_datadir) train_fn = trainer.fit @@ -477,6 +228,8 @@ def test_train_ours_with_trainer( file_regression.check(_gif_path.read_bytes(), binary=True, extension=".gif") +@pytest.mark.slow +@pytest.mark.timeout(35) def test_rejax( algo: JaxRLExample, rng: chex.PRNGKey, @@ -526,6 +279,8 @@ def test_rejax( # Sort-of slow. +@pytest.mark.slow +@pytest.mark.timeout(70) @pytest.mark.parametrize( "with_callbacks", [ @@ -604,6 +359,261 @@ def test_ours_with_vmap( gif_path=figures_dir / "pure_jax_avg.gif", ) +## Pytorch-Lightning wrapper around this learner: + +# Don't allow tests to run for more than 5 seconds. +# pytestmark = pytest.mark.timeout(5) + +class PPOLightningModule(lightning.LightningModule): + """Uses the same code as [project.algorithms.jax_rl_example.JaxRLExample][], but the training + loop is run with pytorch-lightning. + + This is currently only meant to be used to compare the difference fully-jitted training loop + and lightning. 
+ """ + + def __init__( + self, + learner: JaxRLExample, + ts: PPOState, + ): + # https://github.com/keraJLi/rejax/blob/a1428ad3d661e31985c5c19460cec70bc95aef6e/configs/gymnax/pendulum.yaml#L1 + + super().__init__() + self.learner = learner + self.ts = ts + + self.save_hyperparameters(hparams_to_dict(learner)) + self.actor_params = torch.nn.ParameterList( + jax.tree.leaves( + jax.tree.map( + torch_jax_interop.to_torch.jax_to_torch_tensor, + self.ts.actor_ts.params, + ) + ) + ) + self.critic_params = torch.nn.ParameterList( + jax.tree.leaves( + jax.tree.map( + torch_jax_interop.to_torch.jax_to_torch_tensor, + self.ts.critic_ts.params, + ) + ) + ) + + self.automatic_optimization = False + + iteration_steps = self.learner.hp.num_envs * self.learner.hp.num_steps + # number of "iterations" (collecting batches of episodes in the environment) per epoch. + self.num_train_iterations = np.ceil(self.learner.hp.eval_freq / iteration_steps).astype( + int + ) + + @override + def training_step(self, batch: torch.Tensor, batch_idx: int): + start = time.perf_counter() + with jax.disable_jit(self.learner.hp.debug): + algo_struct = self.learner + self.ts, train_metrics = algo_struct.fused_training_step(batch_idx, self.ts) + + duration = time.perf_counter() - start + logger.debug(f"Training step took {duration:.1f} seconds.") + actor_losses = train_metrics.actor_losses + critic_losses = train_metrics.critic_losses + self.log("train/actor_loss", actor_losses.mean().item(), logger=True, prog_bar=True) + self.log("train/critic_loss", critic_losses.mean().item(), logger=True, prog_bar=True) + + updates_per_second = ( + self.learner.hp.num_epochs * self.learner.hp.num_minibatches + ) / duration + self.log("train/updates_per_second", updates_per_second, logger=True, prog_bar=True) + minibatch_size = ( + self.learner.hp.num_envs * self.learner.hp.num_steps + ) // self.learner.hp.num_minibatches + samples_per_update = minibatch_size + self.log( + "train/samples_per_second", + updates_per_second * samples_per_update, + logger=True, + prog_bar=True, + on_step=True, + ) + + # for jax_param, torch_param in zip( + # jax.tree.leaves(self.train_state.actor_ts.params), self.actor_params + # ): + # torch_param.set_(torch_jax_interop.to_torch.jax_to_torch_tensor(jax_param)) + + # for jax_param, torch_param in zip( + # jax.tree.leaves(self.train_state.critic_ts.params), self.critic_params + # ): + # torch_param.set_(torch_jax_interop.to_torch.jax_to_torch_tensor(jax_param)) + + return + + @override + def train_dataloader(self) -> Iterable[Trajectory]: + # BUG: what's probably happening is that the dataloader keeps getting batches with the + # initial train state! + from torch.utils.data import TensorDataset + + dataset = TensorDataset(torch.arange(self.num_train_iterations, device=self.device)) + return DataLoader(dataset, batch_size=None, num_workers=0, shuffle=False, collate_fn=None) + + def val_dataloader(self) -> Any: + # todo: unsure what this should be yielding.. + from torch.utils.data import TensorDataset + + dataset = TensorDataset(torch.arange(1, device=self.device)) + return DataLoader(dataset, batch_size=None, num_workers=0, shuffle=False, collate_fn=None) + + def validation_step(self, batch: int, batch_index: int): + # self.learner.eval_callback() + # return # skip the rest for now while we compare the performance? 
+ eval_metrics = self.learner.eval_callback(ts=self.ts) + episode_lengths = eval_metrics.episode_length + cumulative_rewards = eval_metrics.cumulative_reward + self.log("val/episode_lengths", episode_lengths.mean().item(), batch_size=1) + self.log("val/rewards", cumulative_rewards.mean().item(), batch_size=1) + + def on_train_epoch_start(self) -> None: + if not isinstance(self.learner.env, gymnax.environments.environment.Environment): + return + assert self.trainer.log_dir is not None + gif_path = Path(self.trainer.log_dir) / f"epoch_{self.current_epoch}.gif" + self.learner.visualize(ts=self.ts, gif_path=gif_path) + return # skip the rest for now while we compare the performance + actor = make_actor(ts=self.train_state, hp=self.hp) + render_episode( + actor=actor, + env=self.env, + env_params=self.env_params, + gif_path=gif_path, + num_steps=200, + ) + return super().on_train_epoch_end() + + @override + def configure_optimizers(self) -> Any: + # todo: Note, this one isn't used atm! + from torch.optim.adam import Adam + + return Adam(self.parameters(), lr=1e-3) + + @override + def configure_callbacks(self) -> list[lightning.Callback]: + return [RlThroughputCallback()] + + @override + def transfer_batch_to_device( + self, batch: TrajectoryWithLastObs | int, device: torch.device, dataloader_idx: int + ) -> TrajectoryWithLastObs | int: + if isinstance(batch, int): + # FIXME: valid dataloader currently just yields ints, not trajectories. + return batch + if isinstance(batch, list) and len(batch) == 1: + # FIXME: train dataloader currently just yields ints, not trajectories. + return batch + + _batch_jax_devices = batch.trajectories.obs.devices() + assert len(_batch_jax_devices) == 1 + batch_jax_device = _batch_jax_devices.pop() + torch_self_device = device + if ( + torch_self_device.type == "cuda" + and "cuda" in str(batch_jax_device) + and (torch_self_device.index == -1 or torch_self_device.index == batch_jax_device.id) + ): + # All good, both are on the same GPU. 
+ return batch + + jax_self_device = torch_jax_interop.to_jax.torch_to_jax_device(torch_self_device) + return jax.tree.map(functools.partial(jax.device_put, device=jax_self_device), batch) + + +class RlThroughputCallback(MeasureSamplesPerSecondCallback): + """A callback to measure the throughput of RL algorithms.""" + + def __init__(self, num_optimizers: int | None = 1): + super().__init__(num_optimizers=num_optimizers) + self.total_transitions = 0 + self.total_episodes = 0 + self._start = time.perf_counter() + self._updates = 0 + + @override + def on_fit_start( + self, + trainer: lightning.Trainer, + pl_module: lightning.LightningModule, + ) -> None: + super().on_fit_start(trainer, pl_module) + self.total_transitions = 0 + self.total_episodes = 0 + self._start = time.perf_counter() + + @override + def on_train_batch_end( + self, + trainer: lightning.Trainer, + pl_module: lightning.LightningModule, + outputs: dict[str, torch.Tensor], + batch: TrajectoryWithLastObs, + batch_index: int, + ) -> None: + super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_index) + if not isinstance(batch, TrajectoryWithLastObs): + return + episodes = batch.trajectories + assert episodes.obs.shape + num_episodes = episodes.obs.shape[0] + num_transitions = np.prod(episodes.obs.shape[:2]) + self.total_episodes += num_episodes + self.total_transitions += num_transitions + steps_per_second = self.total_transitions / (time.perf_counter() - self._start) + updates_per_second = (self._updates) / (time.perf_counter() - self._start) + episodes_per_second = self.total_episodes / (time.perf_counter() - self._start) + logger.info( + f"Total transitions: {self.total_transitions}, total episodes: {self.total_episodes}" + ) + # print(f"Steps per second: {steps_per_second}") + logger.info(f"Steps per second: {steps_per_second}") + logger.info(f"Episodes per second: {episodes_per_second}") + logger.info(f"Updates per second: {updates_per_second}") + + @override + def get_num_samples(self, batch: TrajectoryWithLastObs) -> int: + if isinstance(batch, int): # fixme + return 1 + return int(np.prod(batch.trajectories.obs.shape[:2]).item()) + + @override + def on_fit_end(self, trainer: lightning.Trainer, pl_module: lightning.LightningModule) -> None: + super().on_fit_end(trainer, pl_module) + + def log( + self, + name: str, + value: Any, + module: JaxRLExample, + trainer: lightning.Trainer | JaxTrainer, + **kwargs, + ): + # Used to possibly customize how the values are logged (e.g. for non-LightningModules). + # By default, uses the LightningModule.log method. + # TODO: Somehow log the metrics without an actual trainer? + # Should we create a Trainer / LightningModule "facade" that the callbacks can interact with? + if trainer.logger: + trainer.logger.log_metrics({name: value}, step=trainer.global_step, **kwargs) + + # if trainer.progress_bar_callback: + # trainer.progress_bar_callback.log_metrics({name: value}, step=trainer.global_step, **kwargs) + # return trainer.logger.log_metrics().log( + # name, + # value, + # **kwargs, + # ) + @pytest.fixture def lightning_trainer(max_epochs: int, tmp_path: Path): @@ -619,6 +629,8 @@ def lightning_trainer(max_epochs: int, tmp_path: Path): # reducing the max_epochs from 75 down to 3 because it's just wayyy too slow otherwise. +@pytest.mark.slow +@pytest.mark.timeout(80) @pytest.mark.parametrize("max_epochs", [3], indirect=True) def test_lightning( algo: JaxRLExample,
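
Note on the marks introduced above (supplementary, not part of the patch): `@pytest.mark.timeout(...)` is provided by the pytest-timeout plugin, and `slow` is a custom marker that pytest expects to be registered to avoid "unknown mark" warnings. The diff does not show where this repository registers it, so the snippet below is only a minimal sketch under that assumption, and the command lines are illustrative rather than taken from the project's CI configuration.

    # conftest.py (hypothetical location; adjust to wherever the project keeps its pytest config)
    def pytest_configure(config):
        # Register the custom "slow" marker so pytest.mark.slow does not emit a warning.
        config.addinivalue_line(
            "markers", 'slow: marks a test as slow (deselect with -m "not slow")'
        )

    # Typical invocations (assumed, not taken from this repository):
    #   pytest -m "not slow"   # skip the slow RL training tests marked in this patch
    #   pytest -m slow         # run only the slow tests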