Skip to content

Commit

Permalink
fix pre-commit
Browse files Browse the repository at this point in the history
  • Loading branch information
younik committed Oct 23, 2023
1 parent 20be8e8 commit 1a61d32
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 26 deletions.
45 changes: 20 additions & 25 deletions minari/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,9 @@ def combine_minari_version_specifiers(specifier_set: SpecifierSet) -> SpecifierS
return final_version_specifier


def validate_datasets_to_combine(datasets_to_combine: List[MinariDataset]) -> EnvSpec:
def validate_datasets_to_combine(
datasets_to_combine: List[MinariDataset],
) -> EnvSpec | None:
"""Check if the given datasets can be combined.
Tests if the datasets were created with the same environment (`env_spec`) and re-calculates the
Expand All @@ -162,30 +164,24 @@ def validate_datasets_to_combine(datasets_to_combine: List[MinariDataset]) -> En
Returns:
combined_dataset_env_spec (EnvSpec): the resulting EnvSpec of combining the MinariDatasets
"""
assert all(isinstance(dataset, MinariDataset) for dataset in datasets_to_combine)

# Check if there are any `None` max_episode_steps
if any(
(dataset.spec.env_spec.max_episode_steps is None)
for dataset in datasets_to_combine
):
max_episode_steps = None
else:
max_episode_steps = max(
dataset.spec.env_spec.max_episode_steps for dataset in datasets_to_combine
)

combine_env_spec = []
common_env_spec = copy.deepcopy(datasets_to_combine[0].spec.env_spec)
for dataset in datasets_to_combine:
dataset_env_spec = copy.deepcopy(dataset.spec.env_spec)
dataset_env_spec.max_episode_steps = max_episode_steps
combine_env_spec.append(dataset_env_spec)

assert all(
env_spec == combine_env_spec[0] for env_spec in combine_env_spec
), "The datasets to be combined have different values for `env_spec` attribute."
assert isinstance(dataset, MinariDataset)
env_spec = dataset.spec.env_spec
if env_spec is not None:
assert common_env_spec is not None, ""
if (
common_env_spec.max_episode_steps is None
or env_spec.max_episode_steps is None
):
common_env_spec.max_episode_steps = None
else:
common_env_spec.max_episode_steps = max(
common_env_spec.max_episode_steps, env_spec.max_episode_steps
)

return combine_env_spec[0]
return common_env_spec


class RandomPolicy:
Expand Down Expand Up @@ -214,9 +210,6 @@ def combine_datasets(datasets_to_combine: List[MinariDataset], new_dataset_id: s
Returns:
combined_dataset (MinariDataset): the resulting MinariDataset
"""
if any((dataset.spec.env_spec is None) for dataset in datasets_to_combine):
raise ValueError("One or more datasets have No Env_Spec")

combined_dataset_env_spec = validate_datasets_to_combine(datasets_to_combine)

# Compute intersection of Minari version specifiers
Expand Down Expand Up @@ -403,8 +396,10 @@ def create_dataset_from_buffers(
)

if observation_space is None:
assert env is not None
observation_space = env.observation_space
if action_space is None:
assert env is not None
action_space = env.action_space

if expert_policy is not None and ref_max_score is not None:
Expand Down
3 changes: 2 additions & 1 deletion tests/utils/test_dataset_combine.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def test_combine_datasets():
assert list(combined_dataset.spec.combined_datasets) == test_datasets_ids
assert combined_dataset.spec.total_episodes == num_datasets * num_episodes
assert combined_dataset.spec.total_steps == sum(
d.spec.total_steps for d in test_datasets
int(d.spec.total_steps) for d in test_datasets
)
_check_env_recovery(gym.make("CartPole-v1"), combined_dataset)

Expand Down Expand Up @@ -173,6 +173,7 @@ def test_combine_datasets():
combined_dataset = combine_datasets(
test_datasets, new_dataset_id="cartpole-combined-test-v0"
)
assert combined_dataset.spec.env_spec is not None
assert combined_dataset.spec.env_spec.max_episode_steps == 10
_check_load_and_delete_dataset("cartpole-combined-test-v0")

Expand Down

0 comments on commit 1a61d32

Please sign in to comment.