Skip to content

Commit

Permalink
add test case
Browse files Browse the repository at this point in the history
  • Loading branch information
im-Kitsch committed Sep 11, 2023
1 parent 12c8cb3 commit 551106e
Showing 1 changed file with 103 additions and 0 deletions.
103 changes: 103 additions & 0 deletions tests/utils/test_dataset_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,3 +292,106 @@ def test_generate_dataset_with_space_subset_external_buffer():
env.close()

check_load_and_delete_dataset(dataset_id)


def test_generate_dataset_with_buffer_incomplete_traj():
"""Test create dataset from external buffers with incomplete trajectories."""
dataset_id = "cartpole-test-v0"
env_id = "CartPole-v1"

# delete the test dataset if it already exists
local_datasets = minari.list_local_datasets()
if dataset_id in local_datasets:
minari.delete_dataset(dataset_id)

env = gym.make(env_id)
env.reset(seed=42)

obs_all, act_all, rew_all, term_all, trunc_all = [], [], [], [], []
buffer = []
num_episodes = 10
# Step the environment, DataCollectorV0 wrapper will do the data collection job
for episode in range(num_episodes):
observations, actions, rewards, terminations, truncations = [], [], [], [], []

observation, _ = env.reset()
observations.append(observation)
_term_i, _trunc_i = False, False

while not _term_i and not _trunc_i:
_act_i = env.action_space.sample() # User-defined policy function
_obs_i, _rwd_i, _term_i, _trunc_i, _ = env.step(_act_i)
observations, actions, rewards, terminations, truncations = map(
lambda x, y: x + [y],
[observations, actions, rewards, terminations, truncations],
[_obs_i, _act_i, _rwd_i, _term_i, _trunc_i],
)

# last episoode manually change the last truncation and termination to False to verify
if episode == num_episodes - 1:
terminations[-1] = False
truncations[-1] = False

obs_all, act_all, rew_all, term_all, trunc_all = map(
lambda x, y: x + [np.array(y)],
[obs_all, act_all, rew_all, term_all, trunc_all],
[observations, actions, rewards, terminations, truncations],
)

buffer.append(
{
"observations": observations,
"actions": actions,
"rewards": rewards,
"terminations": terminations,
"truncations": truncations,
}
)

# Create Minari dataset and store locally
dataset = minari.create_dataset_from_buffers(
dataset_id=dataset_id,
env=env,
buffer=buffer,
algorithm_name="random_policy",
code_permalink="https://github.com/Farama-Foundation/Minari/blob/f095bfe07f8dc6642082599e07779ec1dd9b2667/tutorials/LocalStorage/local_storage.py",
author="WillDudley",
author_email="wdudley@farama.org",
)

assert isinstance(dataset, MinariDataset)
assert dataset.total_episodes == num_episodes
assert dataset.spec.total_episodes == num_episodes
assert len(dataset.episode_indices) == num_episodes

check_data_integrity(dataset._data, dataset.episode_indices)
check_env_recovery(env, dataset)
env.close()

dataset_loaded = minari.load_dataset(dataset_id)
obs_loaded, act_loaded, rew_loaded, term_loaded, trunc_loaded = [], [], [], [], []
for _eps in dataset_loaded:
obs_loaded.append(_eps.observations)
act_loaded.append(_eps.actions)
rew_loaded.append(_eps.rewards)
term_loaded.append(_eps.terminations)
trunc_loaded.append(_eps.truncations)

obs_loaded, act_loaded, rew_loaded, term_loaded, trunc_loaded = map(
lambda x: np.concatenate(x),
[obs_loaded, act_loaded, rew_loaded, term_loaded, trunc_loaded],
)
obs_original, act_original, rew_original, term_original, trunc_original = map(
lambda x: np.concatenate(x), [obs_all, act_all, rew_all, term_all, trunc_all]
)

assert np.all(obs_loaded == obs_original)
assert np.all(act_loaded == act_original)
assert np.all(rew_loaded == rew_original)
assert np.all(term_loaded == term_original)
assert np.all(trunc_loaded[:-1] == trunc_original[-1])
assert trunc_loaded[-1].item() is True
assert trunc_original[-1].item() is False

check_load_and_delete_dataset(dataset_id)
return

0 comments on commit 551106e

Please sign in to comment.