Skip to content

Commit

Permalink
fix some tests
Browse files Browse the repository at this point in the history
  • Loading branch information
younik committed Sep 3, 2023
1 parent 8d8f799 commit e6355ce
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 11 deletions.
8 changes: 8 additions & 0 deletions minari/dataset/minari_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,14 @@ def total_episodes(self) -> np.int64:
assert type(total_episodes) == np.int64
return total_episodes

@property
def total_steps(self) -> np.int64:
"""Total steps in the dataset."""
with h5py.File(self.data_path, "r") as file:
total_episodes = file.attrs["total_steps"]
assert type(total_episodes) == np.int64
return total_episodes

def get_h5py_subgroup(group: h5py.Group, name: str):
if name in group:
subgroup = group.get(name)
Expand Down
12 changes: 8 additions & 4 deletions tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,26 +461,30 @@ def check_data_integrity(data: MinariStorage, episode_indices: Iterable[int]):
# verify we have the right number of episodes, available at the right indices
assert data.total_episodes == len(episodes)
total_steps = 0

observation_space = data.metadata["observation_space"]
action_space = data.metadata["action_space"]

# verify the actions and observations are in the appropriate action space and observation space, and that the episode lengths are correct
for episode in episodes:
total_steps += episode["total_timesteps"]
_check_space_elem(
episode["observations"],
data.observation_space,
observation_space,
episode["total_timesteps"] + 1,
)
_check_space_elem(
episode["actions"], data.action_space, episode["total_timesteps"]
episode["actions"], action_space, episode["total_timesteps"]
)

for i in range(episode["total_timesteps"] + 1):
obs = _reconstuct_obs_or_action_at_index_recursive(
episode["observations"], i
)
assert data.observation_space.contains(obs)
assert observation_space.contains(obs)
for i in range(episode["total_timesteps"]):
action = _reconstuct_obs_or_action_at_index_recursive(episode["actions"], i)
assert data.action_space.contains(action)
assert action_space.contains(action)

assert episode["total_timesteps"] == len(episode["rewards"])
assert episode["total_timesteps"] == len(episode["terminations"])
Expand Down
10 changes: 3 additions & 7 deletions tests/utils/test_dataset_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,11 @@ def test_generate_dataset_with_collector_env(dataset_id, env_id):
env.reset(seed=42)

for episode in range(num_episodes):
terminated = False
truncated = False
while not terminated and not truncated:
done = False
while not done:
action = env.action_space.sample() # User-defined policy function
_, _, terminated, truncated, _ = env.step(action)
if terminated or truncated:
assert not env._buffer[-1]
else:
assert env._buffer[-1]
done = terminated or truncated

env.reset()

Expand Down

0 comments on commit e6355ce

Please sign in to comment.