Skip to content

Commit

Permalink
support tfrecord for datastore (#15)
Browse files Browse the repository at this point in the history
* simple exporter wip

Signed-off-by: youliang <tan_you_liang@hotmail.com>

* commit files

Signed-off-by: youliang <tan_you_liang@hotmail.com>

* cleaner impl

Signed-off-by: youliang <tan_you_liang@hotmail.com>

* use oxe_logger for rlds recording

Signed-off-by: youliang <tan_you_liang@hotmail.com>

* with make_datastore method

Signed-off-by: youliang <tan_you_liang@hotmail.com>

* minor cleanup and comments

Signed-off-by: youliang <tan_you_liang@hotmail.com>

* nit setup.py

Signed-off-by: youliang <tan_you_liang@hotmail.com>

* try fix ci

Signed-off-by: youliang <tan_you_liang@hotmail.com>

---------

Signed-off-by: youliang <tan_you_liang@hotmail.com>
  • Loading branch information
youliangtan authored Dec 29, 2023
1 parent 3199025 commit 52929ba
Show file tree
Hide file tree
Showing 11 changed files with 489 additions and 92 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/python-app.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ jobs:
# Install the current repo (edgeml)
- name: Install edgeml
run: |
pip install .
pip install -e .
- name: Lint with flake8
run: |
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ __pycache__
*.so
*.pyc
wandb/
logs/
21 changes: 21 additions & 0 deletions LICENSE
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2023 You Liang Tan

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
13 changes: 11 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,20 @@ python3 examples/run_data_store.py
```bash
# Indicate --learner or --actor mode, no tag means async multithreaded mode
# provide --ip for remote connection, default is localhost
# use --render to render the gym env
# use --use_traj_buffer to use trajectory buffer instead of replay buffer
python3 examples/async_learner_actor.py
```

**More option flags**:
- `--env`: gym env name, default is `HalfCheetah-v4`
- `--learner`: run learner mode
- `--actor`: run actor mode
- `--ip`: ip address of the remote server
- `--render`: render the gym env
- `--use_traj_buffer`: use trajectory buffer instead of replay buffer
- `--rlds_log_dir`: directory to save the tfrecords for [RLDS](https://github.com/google-research/rlds)

NOTE: rlds logger requires installation of [oxe_envlogger](https://github.com/rail-berkeley/oxe_envlogger)

---

## Architecture
Expand Down
136 changes: 131 additions & 5 deletions edgeml/data/jaxrl_data_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,27 @@
# NOTE: this requires jaxrl_m to be installed:
# https://github.com/rail-berkeley/jaxrl_minimal

from __future__ import annotations

from threading import Lock
from typing import List, Optional
from edgeml.data.data_store import DataStoreBase
from edgeml.data.trajectory_buffer import TrajectoryBuffer, DataShape
from edgeml.data.sampler import LatestSampler, SequenceSampler

import gym
import jax
import chex
import numpy as np
from jaxrl_m.data.replay_buffer import ReplayBuffer
import tensorflow as tf

# import oxe_envlogger if it is installed
try:
from oxe_envlogger.rlds_logger import RLDSLogger, RLDSStepType
except ImportError:
print("rlds logger is not installed, install it if required: "
"https://github.com/rail-berkeley/oxe_envlogger ")


##############################################################################
Expand All @@ -22,6 +35,7 @@ def __init__(
device: Optional[jax.Device] = None,
seed: int = 0,
min_trajectory_length: int = 2,
rlds_logger: Optional[RLDSLogger] = None,
):
TrajectoryBuffer.__init__(
self,
Expand All @@ -33,11 +47,36 @@ def __init__(
)
DataStoreBase.__init__(self, capacity)
self._lock = Lock()
self._logger = None

if rlds_logger:
self.step_type = RLDSStepType.TERMINATION # to init the state for restart
self._logger = rlds_logger

# ensure thread safety
def insert(self, *args, **kwargs):
def insert(self, data):
with self._lock:
super(TrajectoryBufferDataStore, self).insert(*args, **kwargs)
super(TrajectoryBufferDataStore, self).insert(data)

if self._logger:
# handle restart when it was done before
if self.step_type == RLDSStepType.TERMINATION:
self.step_type = RLDSStepType.RESTART
elif self.step_type == RLDSStepType.TRUNCATION:
self.step_type = RLDSStepType.RESTART
elif not data["masks"]: # 0 is done, 1 is not done
self.step_type = RLDSStepType.TERMINATION
elif data["end_of_trajectory"]:
self.step_type = RLDSStepType.TRUNCATION
else:
self.step_type = RLDSStepType.TRANSITION

self._logger(
action=data["actions"],
obs=data["next_observations"], # TODO: not obs, but next_obs
reward=data["rewards"],
step_type=self.step_type,
)

# ensure thread safety
def sample(self, *args, **kwargs):
Expand All @@ -59,15 +98,41 @@ def __init__(
observation_space: gym.Space,
action_space: gym.Space,
capacity: int,
rlds_logger: Optional[RLDSLogger] = None,
):
ReplayBuffer.__init__(self, observation_space, action_space, capacity)
DataStoreBase.__init__(self, capacity)
self._insert_seq_id = 0 # keeps increasing
self._lock = Lock()
self._logger = None

if rlds_logger:
self.step_type = RLDSStepType.TERMINATION # to init the state for restart
self._logger = rlds_logger

# ensure thread safety
def insert(self, *args, **kwargs):
def insert(self, data):
with self._lock:
super(ReplayBufferDataStore, self).insert(*args, **kwargs)
super(ReplayBufferDataStore, self).insert(data)
self._insert_seq_id += 1

# add data to the rlds logger
# TODO: the current impl of ReplayBuffer doesn't support
# proper truncation of the trajectory
if self._logger:
if self.step_type == RLDSStepType.TERMINATION:
self.step_type = RLDSStepType.RESTART
elif not data["masks"]: # 0 is done, 1 is not done
self.step_type = RLDSStepType.TERMINATION
else:
self.step_type = RLDSStepType.TRANSITION

self._logger(
action=data["actions"],
obs=data["next_observations"], # TODO: check if this is correct
reward=data["rewards"],
step_type=self.step_type,
)

# ensure thread safety
def sample(self, *args, **kwargs):
Expand All @@ -76,8 +141,69 @@ def sample(self, *args, **kwargs):

# NOTE: method for DataStoreBase
def latest_data_id(self):
return self._insert_index
return self._insert_seq_id

# NOTE: method for DataStoreBase
def get_latest_data(self, from_id: int):
raise NotImplementedError # TODO

def __del__(self):
if self._logger:
self._logger.close()

##############################################################################


def make_default_trajectory_buffer(
observation_space: gym.Space,
action_space: gym.Space,
capacity: int,
device: Optional[jax.Device] = None,
rlds_logger: Optional[RLDSLogger] = None,
):
replay_buffer = TrajectoryBufferDataStore(
capacity=capacity,
data_shapes=[
DataShape("observations", observation_space.shape, observation_space.dtype),
DataShape("next_observations", observation_space.shape, observation_space.dtype),
DataShape("actions", action_space.shape, action_space.dtype),
DataShape("rewards", (), np.float64),
DataShape("masks", (), np.float64),
DataShape("end_of_trajectory", (), dtype="bool"),
],
min_trajectory_length=2,
device=device,
rlds_logger=rlds_logger,
)

@jax.jit
def transform_rl_data(batch, mask):
batch_size = jax.tree_util.tree_flatten(batch)[0][0].shape[0]
chex.assert_tree_shape_prefix(batch["observations"], (batch_size, 2))
chex.assert_tree_shape_prefix(mask["observations"], (batch_size, 2))
return {
**batch,
"observations": batch["observations"][:, 0],
"next_observations": batch["observations"][:, 1],
}, {
**mask,
"observations": mask["observations"][:, 0],
"next_observations": mask["observations"][:, 1],
}

replay_buffer.register_sample_config(
"training",
samplers={
"observations": SequenceSampler(
squeeze=False, begin=0, end=2, source="observations"
),
"actions": LatestSampler(),
"rewards": LatestSampler(),
"masks": LatestSampler(),
"next_observations": LatestSampler(),
"end_of_trajectory": LatestSampler(),
},
transform=transform_rl_data,
sample_range=(0, 2),
)
return replay_buffer
116 changes: 116 additions & 0 deletions edgeml/data/tfds.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
from __future__ import annotations

import gym
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

from edgeml.data.jaxrl_data_store import ReplayBufferDataStore
from edgeml.data.jaxrl_data_store import make_default_trajectory_buffer

##############################################################################


def make_datastore(dataset_dir, capacity, type="replay_buffer"):
"""
Load an RLDS dataset from the specified directory and populate it
into the given datastore.
Args:
- dataset_dir: Directory where the RLDS dataset is stored.
- capacity: Capacity of the replay buffer.
- type: supported types are "replay_buffer" and "trajectory_buffer"
Returns:
- replay_buffer: Replay buffer populated with the RLDS dataset.
"""
# Load the dataset
dataset = tfds.builder_from_directory(dataset_dir).as_dataset(split='all')

# get the obs and action spec from the dataset
obs_tensor_spec = dataset.element_spec["steps"].element_spec["observation"]
action_tensor_spec = dataset.element_spec['steps'].element_spec['action']

print("obs spec: ", obs_tensor_spec)
print("action spec: ", action_tensor_spec)

if type == "replay_buffer":
datastore = ReplayBufferDataStore(
observation_space=tensor_spec_to_gym_space(obs_tensor_spec),
action_space=tensor_spec_to_gym_space(action_tensor_spec),
capacity=capacity,
)
elif type == "trajectory_buffer":
datastore = make_default_trajectory_buffer(
observation_space=tensor_spec_to_gym_space(obs_tensor_spec),
action_space=tensor_spec_to_gym_space(action_tensor_spec),
capacity=capacity,
)
else:
raise ValueError(f"Unsupported type: {type}")

# Iterate over episodes in the dataset
for episode in dataset:
steps = episode['steps']
obs = None
# Iterate through steps in the episode
for i, step in enumerate(steps):
if i == 0:
obs = get_numpy_from_tensor(step['observation'])
continue

# Extract relevant data from the step
next_obs = get_numpy_from_tensor(step['observation'])
action = get_numpy_from_tensor(step['action'])
reward = step.get('reward', 0).numpy() # Defaulting to 0 if 'reward' key is missing
terminate = step['is_terminal'].numpy() # or is_last
truncate = step['is_last'].numpy() # truncate is not avail in the ReplayBuffer

data = dict(
observations=obs,
next_observations=next_obs,
actions=action,
rewards=reward,
masks=1 - terminate, # 1 is transition, 0 is terminal
)

if type == "trajectory_buffer":
data["end_of_trajectory"] = truncate

# Insert data into the replay buffer
datastore.insert(data)
obs = next_obs
return datastore

##############################################################################


def tensor_spec_to_gym_space(tensor_spec: tf.data.experimental.TensorSpec):
"""
Convert a TensorSpec to a gym.Space, should support dict and box
"""
if isinstance(tensor_spec, tf.TensorSpec):
return gym.spaces.Box(
low=np.finfo(np.float32).min,
high=np.finfo(np.float32).max,
shape=tensor_spec.shape,
dtype=tensor_spec.dtype.as_numpy_dtype,
)
elif isinstance(tensor_spec, dict):
return gym.spaces.Dict(
{
k: tensor_spec_to_gym_space(v)
for k, v in tensor_spec.items()
}
)
else:
raise TypeError(f"Unsupported tensor spec type: {type(tensor_spec)}")


def get_numpy_from_tensor(tensor):
"""
Convert a tensor to numpy
"""
if isinstance(tensor, dict):
return {k: get_numpy_from_tensor(v) for k, v in tensor.items()}
return tensor.numpy()
13 changes: 11 additions & 2 deletions edgeml/data/trajectory_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,19 @@ def register_sample_config(
def insert(self, data: Dict[str, jax.Array], end_of_trajectory: bool = False):
"""
Insert a single data point into the data store.
# TODO: end_of_trajectory tag defined in data?
the min required data entry is:
data = dict(
observations=observation_data, # np.ndarray or dict
next_observations=next_observation_data, # np.ndarray or dict
actions=np.empty((capacity, *action_space.shape), dtype=action_space.dtype),
rewards=np.empty((capacity,), dtype=np.float32),
masks=np.empty((capacity,), dtype=bool), # is terminal
end_of_trajectory=False, # is last
)
"""
end_of_trajectory = data.get(
"end_of_trajectory", end_of_trajectory) # TODO
"end_of_trajectory", end_of_trajectory) # TODO: if not exist, assume False

# Grab the metadata of the sample we're overwriting
real_insert_idx = self._insert_idx % self.capacity
Expand Down
1 change: 1 addition & 0 deletions edgeml/internal/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def make_compression_method(compression: str) -> Tuple[Callable, Callable]:
:return: compress, decompress functions
def compress(object) -> bytes
def decompress(data) -> object
TODO: support msgpack
"""
if compression == 'lz4':
def compress(data): return lz4.frame.compress(pickle.dumps(data))
Expand Down
Loading

0 comments on commit 52929ba

Please sign in to comment.