-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
support tfrecord for datastore (#15)
* simple exporter wip Signed-off-by: youliang <tan_you_liang@hotmail.com> * commit files Signed-off-by: youliang <tan_you_liang@hotmail.com> * cleaner impl Signed-off-by: youliang <tan_you_liang@hotmail.com> * use oxe_logger for rlds recording Signed-off-by: youliang <tan_you_liang@hotmail.com> * with make_datastore method Signed-off-by: youliang <tan_you_liang@hotmail.com> * minor cleanup and comments Signed-off-by: youliang <tan_you_liang@hotmail.com> * nit setup.py Signed-off-by: youliang <tan_you_liang@hotmail.com> * try fix ci Signed-off-by: youliang <tan_you_liang@hotmail.com> --------- Signed-off-by: youliang <tan_you_liang@hotmail.com>
- Loading branch information
1 parent
3199025
commit 52929ba
Showing
11 changed files
with
489 additions
and
92 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,3 +5,4 @@ __pycache__ | |
*.so | ||
*.pyc | ||
wandb/ | ||
logs/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
MIT License | ||
|
||
Copyright (c) 2023 You Liang Tan | ||
|
||
Permission is hereby granted, free of charge, to any person obtaining a copy | ||
of this software and associated documentation files (the "Software"), to deal | ||
in the Software without restriction, including without limitation the rights | ||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
copies of the Software, and to permit persons to whom the Software is | ||
furnished to do so, subject to the following conditions: | ||
|
||
The above copyright notice and this permission notice shall be included in all | ||
copies or substantial portions of the Software. | ||
|
||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||
SOFTWARE. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
from __future__ import annotations | ||
|
||
import gym | ||
import numpy as np | ||
import tensorflow as tf | ||
import tensorflow_datasets as tfds | ||
|
||
from edgeml.data.jaxrl_data_store import ReplayBufferDataStore | ||
from edgeml.data.jaxrl_data_store import make_default_trajectory_buffer | ||
|
||
############################################################################## | ||
|
||
|
||
def make_datastore(dataset_dir, capacity, type="replay_buffer"): | ||
""" | ||
Load an RLDS dataset from the specified directory and populate it | ||
into the given datastore. | ||
Args: | ||
- dataset_dir: Directory where the RLDS dataset is stored. | ||
- capacity: Capacity of the replay buffer. | ||
- type: supported types are "replay_buffer" and "trajectory_buffer" | ||
Returns: | ||
- replay_buffer: Replay buffer populated with the RLDS dataset. | ||
""" | ||
# Load the dataset | ||
dataset = tfds.builder_from_directory(dataset_dir).as_dataset(split='all') | ||
|
||
# get the obs and action spec from the dataset | ||
obs_tensor_spec = dataset.element_spec["steps"].element_spec["observation"] | ||
action_tensor_spec = dataset.element_spec['steps'].element_spec['action'] | ||
|
||
print("obs spec: ", obs_tensor_spec) | ||
print("action spec: ", action_tensor_spec) | ||
|
||
if type == "replay_buffer": | ||
datastore = ReplayBufferDataStore( | ||
observation_space=tensor_spec_to_gym_space(obs_tensor_spec), | ||
action_space=tensor_spec_to_gym_space(action_tensor_spec), | ||
capacity=capacity, | ||
) | ||
elif type == "trajectory_buffer": | ||
datastore = make_default_trajectory_buffer( | ||
observation_space=tensor_spec_to_gym_space(obs_tensor_spec), | ||
action_space=tensor_spec_to_gym_space(action_tensor_spec), | ||
capacity=capacity, | ||
) | ||
else: | ||
raise ValueError(f"Unsupported type: {type}") | ||
|
||
# Iterate over episodes in the dataset | ||
for episode in dataset: | ||
steps = episode['steps'] | ||
obs = None | ||
# Iterate through steps in the episode | ||
for i, step in enumerate(steps): | ||
if i == 0: | ||
obs = get_numpy_from_tensor(step['observation']) | ||
continue | ||
|
||
# Extract relevant data from the step | ||
next_obs = get_numpy_from_tensor(step['observation']) | ||
action = get_numpy_from_tensor(step['action']) | ||
reward = step.get('reward', 0).numpy() # Defaulting to 0 if 'reward' key is missing | ||
terminate = step['is_terminal'].numpy() # or is_last | ||
truncate = step['is_last'].numpy() # truncate is not avail in the ReplayBuffer | ||
|
||
data = dict( | ||
observations=obs, | ||
next_observations=next_obs, | ||
actions=action, | ||
rewards=reward, | ||
masks=1 - terminate, # 1 is transition, 0 is terminal | ||
) | ||
|
||
if type == "trajectory_buffer": | ||
data["end_of_trajectory"] = truncate | ||
|
||
# Insert data into the replay buffer | ||
datastore.insert(data) | ||
obs = next_obs | ||
return datastore | ||
|
||
############################################################################## | ||
|
||
|
||
def tensor_spec_to_gym_space(tensor_spec: tf.data.experimental.TensorSpec): | ||
""" | ||
Convert a TensorSpec to a gym.Space, should support dict and box | ||
""" | ||
if isinstance(tensor_spec, tf.TensorSpec): | ||
return gym.spaces.Box( | ||
low=np.finfo(np.float32).min, | ||
high=np.finfo(np.float32).max, | ||
shape=tensor_spec.shape, | ||
dtype=tensor_spec.dtype.as_numpy_dtype, | ||
) | ||
elif isinstance(tensor_spec, dict): | ||
return gym.spaces.Dict( | ||
{ | ||
k: tensor_spec_to_gym_space(v) | ||
for k, v in tensor_spec.items() | ||
} | ||
) | ||
else: | ||
raise TypeError(f"Unsupported tensor spec type: {type(tensor_spec)}") | ||
|
||
|
||
def get_numpy_from_tensor(tensor): | ||
""" | ||
Convert a tensor to numpy | ||
""" | ||
if isinstance(tensor, dict): | ||
return {k: get_numpy_from_tensor(v) for k, v in tensor.items()} | ||
return tensor.numpy() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.