Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

make env optional arg while creating from buffers #137

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
1a9c761
make env optional arg while creating from buffers
avjmachine Aug 28, 2023
28403a7
Merge branch 'main' into env-optional-create-ds-buffers
avjmachine Sep 18, 2023
3bd0795
resolve review comments on PR#137
avjmachine Sep 18, 2023
3661142
Merge branch 'main' into env-optional-create-ds-buffers
avjmachine Oct 10, 2023
a4fdfa6
Merge branch 'main' into env-optional-create-ds-buffers
avjmachine Oct 19, 2023
f49b89f
reduce complexity in raise error, ignore pyright
avjmachine Oct 19, 2023
5835868
Merge branch 'main' into env-optional-create-ds-buffers
avjmachine Oct 22, 2023
2e12d14
handle env_spec, scores data when env is None
avjmachine Oct 22, 2023
1b50d31
make env_spec optional, handle cases when None
avjmachine Oct 22, 2023
20be8e8
modify test cases when env, env_spec is None
avjmachine Oct 22, 2023
5824961
fix pre-commit
younik Oct 23, 2023
eeaa6bf
add assert message
younik Oct 23, 2023
e4dd0ae
fix env_spec equality check for combining datasets
avjmachine Oct 27, 2023
8b35aff
refactor generate test dataset w/o env into common
avjmachine Oct 27, 2023
ec0d376
Merge branch 'main' into env-optional-create-ds-buffers
avjmachine Oct 30, 2023
d6566ec
fix method to validate datasets to combine
avjmachine Oct 30, 2023
6e5a4c3
remove redundant exception while recovering env when env_spec None
avjmachine Oct 30, 2023
993381e
recover action & observation space through gym.make from env as str |…
avjmachine Oct 30, 2023
7ce14bd
correct wrong position of assert env not None to when obs & action sp…
avjmachine Oct 30, 2023
938fdf7
correct args description for eval_env in create_dataset_from_buffers …
avjmachine Nov 4, 2023
9ca7af1
made warning checks for no env_spec and no eval_env_spec independent …
avjmachine Nov 4, 2023
34b8a0a
fix and optimize validating combine datasets
avjmachine Nov 4, 2023
ca21ecd
Merge branch 'main' into env-optional-create-ds-buffers
younik Nov 16, 2023
f7d05ef
fix, bypass pyright precommit errors
avjmachine Nov 16, 2023
62d84a8
assert env is a gym env before accessing obs & action space, to avoid…
avjmachine Nov 16, 2023
4d1c60e
update pre-commit
younik Nov 16, 2023
90591f3
add fix to pre-commit
younik Nov 17, 2023
d6a8fd3
fix pre-commit
younik Nov 17, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions .github/workflows/pre-commit.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,11 @@ jobs:
pre-commit:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v3
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: '3.11'
- run: pip install .
- run: pip install pre-commit
- run: pre-commit --version
- run: pre-commit install
Expand Down
16 changes: 11 additions & 5 deletions minari/dataset/minari_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def parse_dataset_id(dataset_id: str) -> tuple[str | None, str, int]:

@dataclass
class MinariDatasetSpec:
env_spec: EnvSpec
env_spec: Optional[EnvSpec]
total_episodes: int
total_steps: np.int64
dataset_id: str
Expand Down Expand Up @@ -99,9 +99,11 @@ def __init__(

metadata = self._data.metadata

env_spec = metadata["env_spec"]
assert isinstance(env_spec, str)
self._env_spec = EnvSpec.from_json(env_spec)
env_spec = metadata.get("env_spec")
if env_spec is not None:
assert isinstance(env_spec, str)
env_spec = EnvSpec.from_json(env_spec)
self._env_spec = env_spec

eval_env_spec = metadata.get("eval_env_spec")
if eval_env_spec is not None:
Expand Down Expand Up @@ -150,7 +152,11 @@ def recover_environment(self, eval_env: bool = False, **kwargs) -> gym.Env:
logger.info(
f"`eval_env` has been set to True but the dataset {self._dataset_id} doesn't provide an evaluation environment. Instead, the environment used for collecting the data will be returned: {self._env_spec}"
)
return gym.make(self._env_spec, **kwargs)

if self.env_spec is None:
raise ValueError("Environment cannot be recovered when env_spec is None")

return gym.make(self.env_spec, **kwargs)

def set_seed(self, seed: int):
"""Set seed for random episode sampling generator."""
Expand Down
14 changes: 7 additions & 7 deletions minari/dataset/minari_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def update_episodes(self, episodes: Iterable[dict]):
_add_episode_to_group(eps_buff, episode_group)

current_steps = file.attrs["total_steps"]
assert isinstance(current_steps, np.int64)
assert isinstance(current_steps, np.integer)
total_steps = current_steps + additional_steps
total_episodes = len(file.keys())

Expand All @@ -273,7 +273,7 @@ def update_from_storage(self, storage: MinariStorage):

with h5py.File(self._file_path, "a", track_order=True) as file:
last_episode_id = file.attrs["total_episodes"]
assert isinstance(last_episode_id, np.int64)
assert isinstance(last_episode_id, np.integer)
storage_total_episodes = storage.total_episodes

for id in range(storage.total_episodes):
Expand All @@ -294,7 +294,7 @@ def update_from_storage(self, storage: MinariStorage):
"total_episodes", last_episode_id + storage_total_episodes
)
total_steps = file.attrs["total_steps"]
assert isinstance(total_steps, np.int64)
assert isinstance(total_steps, np.integer)
file.attrs.modify("total_steps", total_steps + storage.total_steps)

storage_metadata = storage.metadata
Expand All @@ -316,19 +316,19 @@ def data_path(self) -> PathLike:
return os.path.dirname(self._file_path)

@property
def total_episodes(self) -> np.int64:
def total_episodes(self) -> np.integer:
"""Total episodes in the dataset."""
with h5py.File(self._file_path, "r") as file:
total_episodes = file.attrs["total_episodes"]
assert isinstance(total_episodes, np.int64)
assert isinstance(total_episodes, np.integer)
return total_episodes

@property
def total_steps(self) -> np.int64:
def total_steps(self) -> np.integer:
"""Total steps in the dataset."""
with h5py.File(self._file_path, "r") as file:
total_steps = file.attrs["total_steps"]
assert isinstance(total_steps, np.int64)
assert isinstance(total_steps, np.integer)
return total_steps

@property
Expand Down
102 changes: 65 additions & 37 deletions minari/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,9 @@ def combine_minari_version_specifiers(specifier_set: SpecifierSet) -> SpecifierS
return final_version_specifier


def validate_datasets_to_combine(datasets_to_combine: List[MinariDataset]) -> EnvSpec:
def validate_datasets_to_combine(
datasets_to_combine: List[MinariDataset],
) -> EnvSpec | None:
"""Check if the given datasets can be combined.

Tests if the datasets were created with the same environment (`env_spec`) and re-calculates the
Expand All @@ -163,36 +165,42 @@ def validate_datasets_to_combine(datasets_to_combine: List[MinariDataset]) -> En

Returns:
combined_dataset_env_spec (EnvSpec): the resulting EnvSpec of combining the MinariDatasets

"""
assert all(isinstance(dataset, MinariDataset) for dataset in datasets_to_combine), f"Some of the datasets to combine are not of type {MinariDataset}"

# Check if there are any `None` max_episode_steps
if any(
(dataset.spec.env_spec.max_episode_steps is None)
for dataset in datasets_to_combine
):
max_episode_steps = None
else:
max_episode_steps = max(
dataset.spec.env_spec.max_episode_steps for dataset in datasets_to_combine
)
# get first among the dataset's env_spec which is not None
first_not_none_env_spec = next((dataset.spec.env_spec for dataset in datasets_to_combine if dataset.spec.env_spec is not None), None)

combine_env_spec = []
for dataset in datasets_to_combine:
dataset_env_spec = copy.deepcopy(dataset.spec.env_spec)
dataset_env_spec.max_episode_steps = max_episode_steps
combine_env_spec.append(dataset_env_spec)
# early return where all datasets have no env_spec
if first_not_none_env_spec is None:
return None

common_env_spec = copy.deepcopy(first_not_none_env_spec)

assert all(
env_spec == combine_env_spec[0] for env_spec in combine_env_spec
), "The datasets to be combined have different values for `env_spec` attribute."
# updating the common_env_spec's max_episode_steps & checking equivalence of all env specs
for dataset in datasets_to_combine:
assert isinstance(dataset, MinariDataset)
env_spec = dataset.spec.env_spec
if env_spec is not None:
younik marked this conversation as resolved.
Show resolved Hide resolved
if (
common_env_spec.max_episode_steps is None
or env_spec.max_episode_steps is None
):
common_env_spec.max_episode_steps = None
else:
common_env_spec.max_episode_steps = max(
common_env_spec.max_episode_steps, env_spec.max_episode_steps
)
younik marked this conversation as resolved.
Show resolved Hide resolved
# setting max_episode_steps in object's copy to same value for sake of checking equality
env_spec_copy = copy.deepcopy(env_spec)
env_spec_copy.max_episode_steps = common_env_spec.max_episode_steps
if env_spec_copy != common_env_spec:
raise ValueError(
"The datasets to be combined have different values for `env_spec` attribute."
)
else:
raise ValueError("Cannot combine datasets having env_spec with those having no env_spec.")

# Check that all datasets have the same action/observation space
if any(dataset.action_space != datasets_to_combine[0].action_space for dataset in datasets_to_combine):
raise ValueError("The datasets to combine must have the same action space.")
if any(dataset.observation_space != datasets_to_combine[0].observation_space for dataset in datasets_to_combine):
raise ValueError("The datasets to combine must have the same observation space.")
return combine_env_spec[0]
return common_env_spec


class RandomPolicy:
Expand Down Expand Up @@ -301,7 +309,7 @@ def get_average_reference_score(
for _ in range(num_episodes):
while True:
action = policy(obs)
obs, _, terminated, truncated, info = env.step(action)
obs, _, terminated, truncated, info = env.step(action) # pyright: ignore[reportGeneralTypeIssues]
if terminated or truncated:
episode_returns.append(info["episode"]["r"])
obs, _ = env.reset()
Expand All @@ -326,8 +334,8 @@ def _generate_dataset_path(dataset_id: str) -> str | os.PathLike:


def _generate_dataset_metadata(
env_spec: EnvSpec,
dataset_id: str,
env_spec: Optional[EnvSpec],
eval_env: Optional[str | gym.Env | EnvSpec],
algorithm_name: Optional[str],
author: Optional[str],
Expand Down Expand Up @@ -403,7 +411,7 @@ def _generate_dataset_metadata(
if eval_env is None:
warnings.warn(
f"`eval_env` is set to None. If another environment is intended to be used for evaluation please specify corresponding Gymnasium environment (gym.Env | gym.envs.registration.EnvSpec).\
If None the environment used to collect the data (`env={env_spec}`) will be used for this purpose.",
If None the environment used to collect the data (`env={env_spec}`) will be used for this purpose.",
UserWarning,
)
eval_env_spec = env_spec
Expand All @@ -421,7 +429,13 @@ def _generate_dataset_metadata(
assert eval_env_spec is not None
dataset_metadata["eval_env_spec"] = eval_env_spec.to_json()

if expert_policy is not None or ref_max_score is not None:
if env_spec is None:
warnings.warn(
"env_spec is None, no environment spec is provided during collection for this dataset",
UserWarning,
)

if eval_env_spec is not None and (expert_policy is not None or ref_max_score is not None):
env_ref_score = gym.make(eval_env_spec)
if ref_min_score is None:
ref_min_score = get_average_reference_score(
Expand All @@ -441,8 +455,8 @@ def _generate_dataset_metadata(

def create_dataset_from_buffers(
dataset_id: str,
env: str | gym.Env | EnvSpec,
buffer: List[Dict[str, Union[list, Dict]]],
env: Optional[str | gym.Env | EnvSpec] = None,
eval_env: Optional[str | gym.Env | EnvSpec] = None,
algorithm_name: Optional[str] = None,
author: Optional[str] = None,
Expand Down Expand Up @@ -473,10 +487,10 @@ def create_dataset_from_buffers(

Args:
dataset_id (str): name id to identify Minari dataset.
env (str|gym.Env|EnvSpec): Gymnasium environment(gym.Env)/environment id(str)/environment spec(EnvSpec) used to collect the buffer data.
buffer (list[Dict[str, Union[list, Dict]]]): list of episode dictionaries with data.
env (Optional[str|gym.Env|EnvSpec]): Gymnasium environment(gym.Env)/environment id(str)/environment spec(EnvSpec) used to collect the buffer data. Defaults to None.
eval_env (Optional[str|gym.Env|EnvSpec]): Gymnasium environment(gym.Env)/environment id(str)/environment spec(EnvSpec) to use for evaluation with the dataset. After loading the dataset, the environment can be recovered as follows: `MinariDataset.recover_environment(eval_env=True).
If None the `env` used to collect the buffer data should be used for evaluation.
If None, and if the `env` used to collect the buffer data is available, the latter will be used for evaluation.
algorithm_name (Optional[str], optional): name of the algorithm used to collect the data. Defaults to None.
author (Optional[str], optional): author that generated the dataset. Defaults to None.
author_email (Optional[str], optional): email of the author that generated the dataset. Defaults to None.
Expand All @@ -503,11 +517,25 @@ def create_dataset_from_buffers(
env_spec = env
elif isinstance(env, gym.Env):
env_spec = env.spec
elif env is None:
if observation_space is None or action_space is None:
raise ValueError("Both observation space and action space must be provided, if env is None")
env_spec = None
else:
raise ValueError("The `env` argument must be of types str|EnvSpec|gym.Env")
raise ValueError("The `env` argument must be of types str|EnvSpec|gym.Env|None")

if isinstance(env, (str, EnvSpec)):
env = gym.make(env)
if observation_space is None:
assert isinstance(env, gym.Env)
observation_space = env.observation_space
if action_space is None:
assert isinstance(env, gym.Env)
action_space = env.action_space

metadata = _generate_dataset_metadata(
env_spec,
dataset_id,
env_spec,
eval_env,
algorithm_name,
author,
Expand Down Expand Up @@ -575,8 +603,8 @@ def create_dataset_from_collector_env(
assert collector_env.datasets_path is not None
dataset_path = _generate_dataset_path(dataset_id)
metadata: Dict[str, Any] = _generate_dataset_metadata(
copy.deepcopy(collector_env.env.spec),
dataset_id,
copy.deepcopy(collector_env.env.spec),
eval_env,
algorithm_name,
author,
Expand Down
64 changes: 63 additions & 1 deletion tests/common.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import copy
import sys
import unicodedata
from typing import Any, Iterable, List, Optional, Union
from collections import OrderedDict
from typing import Any, Dict, Iterable, List, Optional, Union

import gymnasium as gym
import numpy as np
Expand Down Expand Up @@ -627,3 +629,63 @@ def check_episode_data_integrity(
assert episode.total_timesteps == len(episode.rewards)
assert episode.total_timesteps == len(episode.terminations)
assert episode.total_timesteps == len(episode.truncations)


def _space_subset_helper(entry: Dict):

return OrderedDict(
{
"component_2": OrderedDict(
{"subcomponent_2": entry["component_2"]["subcomponent_2"]}
)
}
)


def get_sample_buffer_for_dataset_from_env(env, num_episodes=10):
    """Roll out ``num_episodes`` random episodes in *env* and return a buffer.

    Each element of the returned list is one episode dictionary with keys
    ``observations``, ``actions``, ``rewards``, ``terminations`` and
    ``truncations``. Observations and actions are reduced to the
    component_2/subcomponent_2 subset via ``_space_subset_helper``.

    Args:
        env: Gymnasium-style environment (``reset``/``step``/``action_space``).
        num_episodes (int): number of episodes to collect. Defaults to 10.

    Returns:
        list[dict]: the collected episode buffers.
    """
    buffer = []
    observations = []
    actions = []
    rewards = []
    terminations = []
    truncations = []

    # Single seeded reset so the rollout is reproducible. The previous
    # version reset a second time without a seed immediately afterwards,
    # which discarded the seeded state and made `seed=42` a no-op.
    observation, _ = env.reset(seed=42)
    observations.append(_space_subset_helper(observation))
    for _ in range(num_episodes):
        terminated = False
        truncated = False

        while not terminated and not truncated:
            # Random rollout policy.
            action = env.action_space.sample()
            observation, reward, terminated, truncated, _ = env.step(action)
            observations.append(_space_subset_helper(observation))
            actions.append(_space_subset_helper(action))
            rewards.append(reward)
            terminations.append(terminated)
            truncations.append(truncated)

        episode_buffer = {
            "observations": copy.deepcopy(observations),
            "actions": copy.deepcopy(actions),
            "rewards": np.asarray(rewards),
            "terminations": np.asarray(terminations),
            "truncations": np.asarray(truncations),
        }

        buffer.append(episode_buffer)

        # Reuse the accumulators for the next episode.
        observations.clear()
        actions.clear()
        rewards.clear()
        terminations.clear()
        truncations.clear()

        observation, _ = env.reset()
        observations.append(_space_subset_helper(observation))

    return buffer
Loading