From 52929ba5f4c4496d00063194652616d44984d0c7 Mon Sep 17 00:00:00 2001 From: youliang Date: Sat, 30 Dec 2023 07:34:08 +0800 Subject: [PATCH] support tfrecord for datastore (#15) * simple exporter wip Signed-off-by: youliang * commit files Signed-off-by: youliang * cleaner impl Signed-off-by: youliang * use oxe_logger for rlds recording Signed-off-by: youliang * with make_datastore method Signed-off-by: youliang * minor cleanup and comments Signed-off-by: youliang * nit setup.py Signed-off-by: youliang * try fix ci Signed-off-by: youliang --------- Signed-off-by: youliang --- .github/workflows/python-app.yml | 2 +- .gitignore | 1 + LICENSE | 21 +++++ README.md | 13 ++- edgeml/data/jaxrl_data_store.py | 136 +++++++++++++++++++++++++++++-- edgeml/data/tfds.py | 116 ++++++++++++++++++++++++++ edgeml/data/trajectory_buffer.py | 13 ++- edgeml/internal/utils.py | 1 + edgeml/tests/test_tfds.py | 126 ++++++++++++++++++++++++++++ examples/async_learner_actor.py | 98 +++++++++++++++------- examples/jaxrl_m_common.py | 54 +----------- 11 files changed, 489 insertions(+), 92 deletions(-) create mode 100644 LICENSE create mode 100644 edgeml/data/tfds.py create mode 100644 edgeml/tests/test_tfds.py diff --git a/.github/workflows/python-app.yml b/.github/workflows/python-app.yml index 444309e..35d516a 100644 --- a/.github/workflows/python-app.yml +++ b/.github/workflows/python-app.yml @@ -29,7 +29,7 @@ jobs: # Install the current repo (edgeml) - name: Install edgeml run: | - pip install . + pip install -e . - name: Lint with flake8 run: | diff --git a/.gitignore b/.gitignore index 259e285..5993925 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,4 @@ __pycache__ *.so *.pyc wandb/ +logs/ diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..3b12512 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 You Liang Tan + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
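To make the feature in this patch concrete before the diffs that follow, here is a minimal sketch (not part of the patch itself) of the round trip it enables: transitions inserted into a `ReplayBufferDataStore` are mirrored to RLDS tfrecords through `oxe_envlogger`'s `RLDSLogger`, and `make_datastore` later reloads those tfrecords into a fresh datastore. The environment name, log directory, dataset name, and capacity values below are arbitrary illustration choices; only the constructor arguments and function signatures come from the code added in this patch.

```python
# Minimal sketch of the tfrecord/RLDS datastore support added by this patch.
# Assumes oxe_envlogger is installed; "logs/demo_rlds", "demo_env", and the
# capacity values are hypothetical illustration choices.
import gym

from oxe_envlogger.rlds_logger import RLDSLogger
from edgeml.data.jaxrl_data_store import ReplayBufferDataStore
from edgeml.data.tfds import make_datastore

env = gym.make("Pendulum-v1")

logger = RLDSLogger(
    observation_space=env.observation_space,
    action_space=env.action_space,
    dataset_name="demo_env",
    directory="logs/demo_rlds",
    max_episodes_per_file=5,
)

# every insert() below is also written out as an RLDS step by the logger
datastore = ReplayBufferDataStore(
    env.observation_space,
    env.action_space,
    capacity=1000,
    rlds_logger=logger,
)

obs, _ = env.reset()
for _ in range(200):
    action = env.action_space.sample()
    next_obs, reward, done, truncated, _ = env.step(action)
    datastore.insert(dict(
        observations=obs,
        next_observations=next_obs,
        actions=action,
        rewards=reward,
        masks=1 - done,  # 0 is terminal, 1 otherwise
    ))
    obs = env.reset()[0] if (done or truncated) else next_obs

logger.close()  # flush the tfrecord writer before reading the files back

# reload the logged episodes into a new datastore
restored = make_datastore("logs/demo_rlds", capacity=1000, type="replay_buffer")
print("restored datastore size:", len(restored))
```

The same `rlds_logger` argument also works for the trajectory buffer built by `make_default_trajectory_buffer`, as the diffs below show.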
diff --git a/README.md b/README.md index 54d88ae..ab46b02 100644 --- a/README.md +++ b/README.md @@ -37,11 +37,20 @@ python3 examples/run_data_store.py ```bash # Indicate --learner or --actor mode, no tag means async multithreaded mode # provide --ip for remote connection, default is localhost -# use --render to render the gym env -# use --use_traj_buffer to use trajectory buffer instead of replay buffer python3 examples/async_learner_actor.py ``` +**More option flags**: + - `--env`: gym env name, default is `HalfCheetah-v4` + - `--learner`: run learner mode + - `--actor`: run actor mode + - `--ip`: ip address of the remote server + - `--render`: render the gym env + - `--use_traj_buffer`: use trajectory buffer instead of replay buffer + - `--rlds_log_dir`: directory to save the tfrecords for [RLDS](https://github.com/google-research/rlds) + +NOTE: rlds logger requires installation of [oxe_envlogger](https://github.com/rail-berkeley/oxe_envlogger) + --- ## Architecture diff --git a/edgeml/data/jaxrl_data_store.py b/edgeml/data/jaxrl_data_store.py index 11a78f8..105c5a2 100644 --- a/edgeml/data/jaxrl_data_store.py +++ b/edgeml/data/jaxrl_data_store.py @@ -2,14 +2,27 @@ # NOTE: this requires jaxrl_m to be installed: # https://github.com/rail-berkeley/jaxrl_minimal +from __future__ import annotations + from threading import Lock from typing import List, Optional from edgeml.data.data_store import DataStoreBase from edgeml.data.trajectory_buffer import TrajectoryBuffer, DataShape +from edgeml.data.sampler import LatestSampler, SequenceSampler import gym import jax +import chex +import numpy as np from jaxrl_m.data.replay_buffer import ReplayBuffer +import tensorflow as tf + +# import oxe_envlogger if it is installed +try: + from oxe_envlogger.rlds_logger import RLDSLogger, RLDSStepType +except ImportError: + print("rlds logger is not installed, install it if required: " + "https://github.com/rail-berkeley/oxe_envlogger ") ############################################################################## @@ -22,6 +35,7 @@ def __init__( device: Optional[jax.Device] = None, seed: int = 0, min_trajectory_length: int = 2, + rlds_logger: Optional[RLDSLogger] = None, ): TrajectoryBuffer.__init__( self, @@ -33,11 +47,36 @@ def __init__( ) DataStoreBase.__init__(self, capacity) self._lock = Lock() + self._logger = None + + if rlds_logger: + self.step_type = RLDSStepType.TERMINATION # to init the state for restart + self._logger = rlds_logger # ensure thread safety - def insert(self, *args, **kwargs): + def insert(self, data): with self._lock: - super(TrajectoryBufferDataStore, self).insert(*args, **kwargs) + super(TrajectoryBufferDataStore, self).insert(data) + + if self._logger: + # handle restart when it was done before + if self.step_type == RLDSStepType.TERMINATION: + self.step_type = RLDSStepType.RESTART + elif self.step_type == RLDSStepType.TRUNCATION: + self.step_type = RLDSStepType.RESTART + elif not data["masks"]: # 0 is done, 1 is not done + self.step_type = RLDSStepType.TERMINATION + elif data["end_of_trajectory"]: + self.step_type = RLDSStepType.TRUNCATION + else: + self.step_type = RLDSStepType.TRANSITION + + self._logger( + action=data["actions"], + obs=data["next_observations"], # TODO: not obs, but next_obs + reward=data["rewards"], + step_type=self.step_type, + ) # ensure thread safety def sample(self, *args, **kwargs): @@ -59,15 +98,41 @@ def __init__( observation_space: gym.Space, action_space: gym.Space, capacity: int, + rlds_logger: Optional[RLDSLogger] = None, ): 
ReplayBuffer.__init__(self, observation_space, action_space, capacity) DataStoreBase.__init__(self, capacity) + self._insert_seq_id = 0 # keeps increasing self._lock = Lock() + self._logger = None + + if rlds_logger: + self.step_type = RLDSStepType.TERMINATION # to init the state for restart + self._logger = rlds_logger # ensure thread safety - def insert(self, *args, **kwargs): + def insert(self, data): with self._lock: - super(ReplayBufferDataStore, self).insert(*args, **kwargs) + super(ReplayBufferDataStore, self).insert(data) + self._insert_seq_id += 1 + + # add data to the rlds logger + # TODO: the current impl of ReplayBuffer doesn't support + # proper truncation of the trajectory + if self._logger: + if self.step_type == RLDSStepType.TERMINATION: + self.step_type = RLDSStepType.RESTART + elif not data["masks"]: # 0 is done, 1 is not done + self.step_type = RLDSStepType.TERMINATION + else: + self.step_type = RLDSStepType.TRANSITION + + self._logger( + action=data["actions"], + obs=data["next_observations"], # TODO: check if this is correct + reward=data["rewards"], + step_type=self.step_type, + ) # ensure thread safety def sample(self, *args, **kwargs): @@ -76,8 +141,69 @@ def sample(self, *args, **kwargs): # NOTE: method for DataStoreBase def latest_data_id(self): - return self._insert_index + return self._insert_seq_id # NOTE: method for DataStoreBase def get_latest_data(self, from_id: int): raise NotImplementedError # TODO + + def __del__(self): + if self._logger: + self._logger.close() + +############################################################################## + + +def make_default_trajectory_buffer( + observation_space: gym.Space, + action_space: gym.Space, + capacity: int, + device: Optional[jax.Device] = None, + rlds_logger: Optional[RLDSLogger] = None, +): + replay_buffer = TrajectoryBufferDataStore( + capacity=capacity, + data_shapes=[ + DataShape("observations", observation_space.shape, observation_space.dtype), + DataShape("next_observations", observation_space.shape, observation_space.dtype), + DataShape("actions", action_space.shape, action_space.dtype), + DataShape("rewards", (), np.float64), + DataShape("masks", (), np.float64), + DataShape("end_of_trajectory", (), dtype="bool"), + ], + min_trajectory_length=2, + device=device, + rlds_logger=rlds_logger, + ) + + @jax.jit + def transform_rl_data(batch, mask): + batch_size = jax.tree_util.tree_flatten(batch)[0][0].shape[0] + chex.assert_tree_shape_prefix(batch["observations"], (batch_size, 2)) + chex.assert_tree_shape_prefix(mask["observations"], (batch_size, 2)) + return { + **batch, + "observations": batch["observations"][:, 0], + "next_observations": batch["observations"][:, 1], + }, { + **mask, + "observations": mask["observations"][:, 0], + "next_observations": mask["observations"][:, 1], + } + + replay_buffer.register_sample_config( + "training", + samplers={ + "observations": SequenceSampler( + squeeze=False, begin=0, end=2, source="observations" + ), + "actions": LatestSampler(), + "rewards": LatestSampler(), + "masks": LatestSampler(), + "next_observations": LatestSampler(), + "end_of_trajectory": LatestSampler(), + }, + transform=transform_rl_data, + sample_range=(0, 2), + ) + return replay_buffer diff --git a/edgeml/data/tfds.py b/edgeml/data/tfds.py new file mode 100644 index 0000000..6306149 --- /dev/null +++ b/edgeml/data/tfds.py @@ -0,0 +1,116 @@ +from __future__ import annotations + +import gym +import numpy as np +import tensorflow as tf +import tensorflow_datasets as tfds + +from 
edgeml.data.jaxrl_data_store import ReplayBufferDataStore
+from edgeml.data.jaxrl_data_store import make_default_trajectory_buffer
+
+##############################################################################
+
+
+def make_datastore(dataset_dir, capacity, type="replay_buffer"):
+    """
+    Load an RLDS dataset from the specified directory and populate it
+    into a newly created datastore.
+
+    Args:
+    - dataset_dir: Directory where the RLDS dataset is stored.
+    - capacity: Capacity of the replay buffer.
+    - type: supported types are "replay_buffer" and "trajectory_buffer"
+
+    Returns:
+    - datastore: Datastore of the requested type, populated with the RLDS dataset.
+    """
+    # Load the dataset
+    dataset = tfds.builder_from_directory(dataset_dir).as_dataset(split='all')
+
+    # get the obs and action spec from the dataset
+    obs_tensor_spec = dataset.element_spec["steps"].element_spec["observation"]
+    action_tensor_spec = dataset.element_spec['steps'].element_spec['action']
+
+    print("obs spec: ", obs_tensor_spec)
+    print("action spec: ", action_tensor_spec)
+
+    if type == "replay_buffer":
+        datastore = ReplayBufferDataStore(
+            observation_space=tensor_spec_to_gym_space(obs_tensor_spec),
+            action_space=tensor_spec_to_gym_space(action_tensor_spec),
+            capacity=capacity,
+        )
+    elif type == "trajectory_buffer":
+        datastore = make_default_trajectory_buffer(
+            observation_space=tensor_spec_to_gym_space(obs_tensor_spec),
+            action_space=tensor_spec_to_gym_space(action_tensor_spec),
+            capacity=capacity,
+        )
+    else:
+        raise ValueError(f"Unsupported type: {type}")
+
+    # Iterate over episodes in the dataset
+    for episode in dataset:
+        steps = episode['steps']
+        obs = None
+        # Iterate through steps in the episode
+        for i, step in enumerate(steps):
+            if i == 0:
+                obs = get_numpy_from_tensor(step['observation'])
+                continue
+
+            # Extract relevant data from the step
+            next_obs = get_numpy_from_tensor(step['observation'])
+            action = get_numpy_from_tensor(step['action'])
+            reward = step.get('reward', 0).numpy()  # Defaulting to 0 if 'reward' key is missing
+            terminate = step['is_terminal'].numpy()  # or is_last
+            truncate = step['is_last'].numpy()  # truncate is not avail in the ReplayBuffer
+
+            data = dict(
+                observations=obs,
+                next_observations=next_obs,
+                actions=action,
+                rewards=reward,
+                masks=1 - terminate,  # 1 is transition, 0 is terminal
+            )
+
+            if type == "trajectory_buffer":
+                data["end_of_trajectory"] = truncate
+
+            # Insert data into the replay buffer
+            datastore.insert(data)
+            obs = next_obs
+    return datastore
+
+##############################################################################
+
+
+def tensor_spec_to_gym_space(tensor_spec: tf.data.experimental.TensorSpec):
+    """
+    Convert a TensorSpec to a gym.Space; supports dict and box spaces.
+    """
+    if isinstance(tensor_spec, tf.TensorSpec):
+        return gym.spaces.Box(
+            low=np.finfo(np.float32).min,
+            high=np.finfo(np.float32).max,
+            shape=tensor_spec.shape,
+            dtype=tensor_spec.dtype.as_numpy_dtype,
+        )
+    elif isinstance(tensor_spec, dict):
+        return gym.spaces.Dict(
+            {
+                k: tensor_spec_to_gym_space(v)
+                for k, v in tensor_spec.items()
+            }
+        )
+    else:
+        raise TypeError(f"Unsupported tensor spec type: {type(tensor_spec)}")
+
+
+def get_numpy_from_tensor(tensor):
+    """
+    Convert a tensor to numpy
+    """
+    if isinstance(tensor, dict):
+        return {k: get_numpy_from_tensor(v) for k, v in tensor.items()}
+    return tensor.numpy()
diff --git a/edgeml/data/trajectory_buffer.py b/edgeml/data/trajectory_buffer.py
index 38d9856..25e9357 100644
--- a/edgeml/data/trajectory_buffer.py
+++ b/edgeml/data/trajectory_buffer.py
@@ -101,10 +101,19 @@ def register_sample_config(
     def insert(self, data: Dict[str, jax.Array], end_of_trajectory: bool = False):
         """
         Insert a single data point into the data store.
-        # TODO: end_of_trajectory tag defined in data?
+
+        The minimum required data entry is:
+            data = dict(
+                observations=observation_data,  # np.ndarray or dict
+                next_observations=next_observation_data,  # np.ndarray or dict
+                actions=action_data,  # np.ndarray with shape action_space.shape
+                rewards=reward,  # float
+                masks=1 - terminal,  # 0 if terminal, 1 otherwise
+                end_of_trajectory=False,  # True if this is the last step of a trajectory
+            )
         """
         end_of_trajectory = data.get(
-            "end_of_trajectory", end_of_trajectory)  # TODO
+            "end_of_trajectory", end_of_trajectory)  # defaults to False if not present

         # Grab the metadata of the sample we're overwriting
         real_insert_idx = self._insert_idx % self.capacity
diff --git a/edgeml/internal/utils.py b/edgeml/internal/utils.py
index 8402e48..f7296be 100644
--- a/edgeml/internal/utils.py
+++ b/edgeml/internal/utils.py
@@ -50,6 +50,7 @@ def make_compression_method(compression: str) -> Tuple[Callable, Callable]:
     :return: compress, decompress functions
         def compress(object) -> bytes
         def decompress(data) -> object
+    TODO: support msgpack
     """
     if compression == 'lz4':
         def compress(data): return lz4.frame.compress(pickle.dumps(data))
diff --git a/edgeml/tests/test_tfds.py b/edgeml/tests/test_tfds.py
new file mode 100644
index 0000000..670fbef
--- /dev/null
+++ b/edgeml/tests/test_tfds.py
@@ -0,0 +1,126 @@
+#!/usr/bin/env python3
+
+from edgeml.data.tfds import make_datastore
+from edgeml.data.jaxrl_data_store import ReplayBufferDataStore
+from edgeml.data.jaxrl_data_store import make_default_trajectory_buffer
+
+from oxe_envlogger.rlds_logger import RLDSLogger
+
+import gym
+from gym import spaces
+import numpy as np
+import os
+
+
+class CustomEnv(gym.Env):
+    """
+    A custom environment that uses a dictionary for the observation space.
+ """ + + def __init__(self): + super(CustomEnv, self).__init__() + self.action_space = spaces.Box(low=-1, high=1, shape=(1,), dtype=np.float32) + self.observation_space = spaces.Dict({ + 'position': spaces.Box(low=0, high=10, shape=(1,), dtype=np.int32), + 'velocity': spaces.Box(low=-1, high=1, shape=(1,), dtype=np.float32) + }) + + def step(self, action): + observation = { + 'position': np.random.randint(0, 10, size=(1,)), + 'velocity': np.random.uniform(-1, 1, size=(1,)) + } + # Example reward, done, and info + reward = 1.0 + done = False + info = {} + return observation, reward, done, info + + def reset(self): + # Example initial observation + observation = { + 'position': np.random.randint(0, 10, size=(1,)), + 'velocity': np.random.uniform(-1, 1, size=(1,)) + } + return observation, {} + + +def run_rlds_logger(env, capacity=20, type="replay_buffer"): + log_dir = "logs/test_rlds_env" + + logger = RLDSLogger( + observation_space=env.observation_space, + action_space=env.action_space, + dataset_name="test_rlds_env", + directory=log_dir, + max_episodes_per_file=5, # TODO: arbitrary number + ) + + if type == "replay_buffer": + data_store = ReplayBufferDataStore( + env.observation_space, + env.action_space, + capacity=capacity, + rlds_logger=logger, + ) + elif type == "trajectory_buffer": + data_store = make_default_trajectory_buffer( + env.observation_space, + env.action_space, + capacity=capacity, + rlds_logger=logger, + ) + + # create some fake data + sample_obs = env.reset()[0] + action_shape = env.action_space.shape + sample_action = np.random.randn(*action_shape) + + print("inserting data") + for j in range(10): # 10 episodes + for i in range(15): # arbitrary number of 15 samples + done = 0 if i < 14 else 1 # last sample is terminal + + sample = dict( + observations=sample_obs, + next_observations=sample_obs, + actions=sample_action, + rewards=np.random.randn(), + masks=1 - done, # 1 is transition, 0 is terminal + ) + + if type == "trajectory_buffer": + sample["end_of_trajectory"] = False + + data_store.insert(sample) + logger.close() + print("done inserting data") + + # check if log dir has more than 3 files + files = os.listdir(log_dir) + assert len(files) == 4, "expected 2 tfrecord files, and 2 json config files" + + # This will + stored_buffer = make_datastore( + log_dir, + capacity=200, + type=type, + ) + + print("total data size: ", len(stored_buffer)) + assert len(stored_buffer) == 15*10 - 10 # first 10 samples are ignored since + + +if __name__ == "__main__": + # print(" testing custom env") + env = CustomEnv() + run_rlds_logger(env) + + print("testing pendulum env") + env = gym.make("Pendulum-v1") + run_rlds_logger(env) + + # NOTE: trajectory buffer only support obs and action space + # of type array (not dict) + env = gym.make("HalfCheetah-v4") + run_rlds_logger(env, type="trajectory_buffer") + print("all tests passed") diff --git a/examples/async_learner_actor.py b/examples/async_learner_actor.py index bf17eae..a41a683 100644 --- a/examples/async_learner_actor.py +++ b/examples/async_learner_actor.py @@ -22,8 +22,9 @@ from edgeml.trainer import TrainerServer, TrainerClient, TrainerTunnel from edgeml.data.data_store import QueuedDataStore from edgeml.data.jaxrl_data_store import ReplayBufferDataStore +from edgeml.data.jaxrl_data_store import make_default_trajectory_buffer -from jaxrl_m_common import make_agent, make_trainer_config, make_wandb_logger, make_efficient_replay_buffer +from jaxrl_m_common import make_agent, make_trainer_config, make_wandb_logger FLAGS = 
flags.FLAGS @@ -53,16 +54,24 @@ flags.DEFINE_boolean("render", False, "Render the environment.") flags.DEFINE_string("ip", "localhost", "IP address of the learner.") -# experimental with efficient replay buffer +# experimental with trajectory buffer flags.DEFINE_boolean("use_traj_buffer", False, "Use efficient replay buffer.") +# save replaybuffer data as rlds tfrecord +flags.DEFINE_string("rlds_log_dir", None, "Directory to log data.") + + def print_green(x): return print("\033[92m {}\033[00m" .format(x)) + def print_yellow(x): return print("\033[93m {}\033[00m" .format(x)) ############################################################################## +global_rlds_logger = None + + def actor(agent: SACAgent, data_store, env, sampling_rng, tunnel=None): """ This is the actor loop, which runs when "--actor" is set to True. @@ -133,7 +142,7 @@ def update_params(params): if FLAGS.use_traj_buffer: # NOTE: end_of_trajectory is used in TrajectoryBuffer - data_payload["end_of_trajectory"] = done or truncated + data_payload["end_of_trajectory"] = truncated # TODO: check if ignore None is okay data_store.insert(data_payload) @@ -211,7 +220,7 @@ def stats_callback(type: str, payload: dict) -> dict: with timer.context("sample_replay_buffer"): if FLAGS.use_traj_buffer: batch, mask = replay_buffer.sample( - "training", # define in the TrajectoryBuffer.register_sample_config + "training", # define in the TrajectoryBuffer.register_sample_config FLAGS.batch_size, ) # replay_buffer's batch is default in cpu, put it to devices @@ -236,6 +245,51 @@ def stats_callback(type: str, payload: dict) -> dict: ############################################################################## +def create_datastore_and_wandb_logger(env): + """ + Utility function to create replay buffer and wandb logger. 
+ """ + if FLAGS.rlds_log_dir: + print_yellow(f"Saving replay buffer data to {FLAGS.rlds_log_dir}") + # Install from: https://github.com/rail-berkeley/oxe_envlogger + from oxe_envlogger.rlds_logger import RLDSLogger + + logger = RLDSLogger( + observation_space=env.observation_space, + action_space=env.action_space, + dataset_name=FLAGS.env, + directory=FLAGS.rlds_log_dir, + max_episodes_per_file=10, # TODO: arbitrary number + ) + global global_rlds_logger + global_rlds_logger = logger + + if FLAGS.use_traj_buffer: + print_yellow(f"Using experimental Trajectory buffer") + replay_buffer = make_default_trajectory_buffer( + env.observation_space, + env.action_space, + capacity=FLAGS.replay_buffer_capacity, + rlds_logger=logger if FLAGS.rlds_log_dir else None, + ) + else: + replay_buffer = ReplayBufferDataStore( + env.observation_space, + env.action_space, + capacity=FLAGS.replay_buffer_capacity, + rlds_logger=logger if FLAGS.rlds_log_dir else None, + ) + + # set up wandb and logging + wandb_logger = make_wandb_logger( + project="jaxrl_minimal", + description=FLAGS.exp_name or FLAGS.env, + ) + return replay_buffer, wandb_logger + + +############################################################################## + def main(_): devices = jax.local_devices() num_devices = len(devices) @@ -263,30 +317,8 @@ def main(_): jax.tree_map(jnp.array, agent), sharding.replicate() ) - def create_replay_buffer_and_wandb_logger(): - if FLAGS.use_traj_buffer: - print_yellow(f"Using experimental Efficient replay buffer") - replay_buffer = make_efficient_replay_buffer( - env.observation_space, - env.action_space, - capacity=FLAGS.replay_buffer_capacity, - ) - else: - replay_buffer = ReplayBufferDataStore( - env.observation_space, - env.action_space, - capacity=FLAGS.replay_buffer_capacity, - ) - - # set up wandb and logging - wandb_logger = make_wandb_logger( - project="jaxrl_minimal", - description=FLAGS.exp_name or FLAGS.env, - ) - return replay_buffer, wandb_logger - if FLAGS.learner: - replay_buffer, wandb_logger = create_replay_buffer_and_wandb_logger() + replay_buffer, wandb_logger = create_datastore_and_wandb_logger(env) # learner loop print_green("starting learner loop") @@ -306,7 +338,7 @@ def create_replay_buffer_and_wandb_logger(): # In this example, the tunnel acts as the transport layer for the # trainerServer and trainerClient. Also, both actor and learner shares # the same replay buffer. 
- replay_buffer, wandb_logger = create_replay_buffer_and_wandb_logger() + replay_buffer, wandb_logger = create_datastore_and_wandb_logger(env) tunnel = TrainerTunnel() sampling_rng = jax.device_put(sampling_rng, sharding.replicate()) @@ -325,4 +357,12 @@ def create_replay_buffer_and_wandb_logger(): if __name__ == "__main__": - app.run(main) + try: + app.run(main) + finally: + # NOTE: manually flush the logger when exit to prevent data loss + # this is required as the envlogger writer doesn't handle + # destruction of the object gracefully + if global_rlds_logger: + global_rlds_logger.close() + print_green("done exit") diff --git a/examples/jaxrl_m_common.py b/examples/jaxrl_m_common.py index f506529..29a9f7c 100644 --- a/examples/jaxrl_m_common.py +++ b/examples/jaxrl_m_common.py @@ -19,6 +19,7 @@ from edgeml.trainer import TrainerConfig from jax import nn +from oxe_envlogger.rlds_logger import RLDSLogger ############################################################################## @@ -75,56 +76,3 @@ def make_wandb_logger( variant={}, ) return wandb_logger - - -def make_efficient_replay_buffer( - observation_space: gym.Space, - action_space: gym.Space, - capacity: int, - device: Optional[jax.Device] = None, -): - replay_buffer = TrajectoryBufferDataStore( - capacity=capacity, - data_shapes=[ - DataShape("observations", observation_space.shape, np.float32), - DataShape("next_observations", observation_space.shape, np.float32), - DataShape("actions", action_space.shape, np.float32), - DataShape("rewards", (), np.float32), - DataShape("masks", (), np.float32), - DataShape("end_of_trajectory", (), dtype="bool"), - ], - min_trajectory_length=2, - device=device, - ) - - @jax.jit - def transform_rl_data(batch, mask): - batch_size = jax.tree_util.tree_flatten(batch)[0][0].shape[0] - chex.assert_tree_shape_prefix(batch["observations"], (batch_size, 2)) - chex.assert_tree_shape_prefix(mask["observations"], (batch_size, 2)) - return { - **batch, - "observations": batch["observations"][:, 0], - "next_observations": batch["observations"][:, 1], - }, { - **mask, - "observations": mask["observations"][:, 0], - "next_observations": mask["observations"][:, 1], - } - - replay_buffer.register_sample_config( - "training", - samplers={ - "observations": SequenceSampler( - squeeze=False, begin=0, end=2, source="observations" - ), - "actions": LatestSampler(), - "rewards": LatestSampler(), - "masks": LatestSampler(), - "next_observations": LatestSampler(), - "end_of_trajectory": LatestSampler(), - }, - transform=transform_rl_data, - sample_range=(0, 2), - ) - return replay_buffer
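Since `make_efficient_replay_buffer` is removed from `examples/jaxrl_m_common.py` and replaced by `make_default_trajectory_buffer` in `edgeml/data/jaxrl_data_store.py`, here is a minimal usage sketch (not part of the patch; the capacity, step count, and batch size are arbitrary illustration values, and the env only needs array, non-dict, observation and action spaces):

```python
# Minimal sketch of the relocated trajectory-buffer helper and its "training"
# sample config; capacity=10_000 and batch size 32 are arbitrary choices.
import gym

from edgeml.data.jaxrl_data_store import make_default_trajectory_buffer

env = gym.make("HalfCheetah-v4")  # any env with array (non-dict) spaces

buffer = make_default_trajectory_buffer(
    env.observation_space,
    env.action_space,
    capacity=10_000,
)

obs, _ = env.reset()
for _ in range(1000):
    action = env.action_space.sample()
    next_obs, reward, done, truncated, _ = env.step(action)
    buffer.insert(dict(
        observations=obs,
        next_observations=next_obs,
        actions=action,
        rewards=reward,
        masks=1 - done,               # 0 is terminal, 1 otherwise
        end_of_trajectory=truncated,  # marks truncation for the trajectory buffer
    ))
    obs = env.reset()[0] if (done or truncated) else next_obs

# "training" is the sample config registered by make_default_trajectory_buffer;
# it yields consecutive observation pairs remapped to observations/next_observations.
batch, mask = buffer.sample("training", 32)
print(batch["observations"].shape, batch["actions"].shape)
```

Passing `rlds_logger=` to `make_default_trajectory_buffer` additionally mirrors every insert to RLDS tfrecords, exactly as in the replay-buffer case.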