From 1fd07b00b4c9dda67f104650f48f29ebd852d666 Mon Sep 17 00:00:00 2001 From: yut23 Date: Tue, 4 Apr 2023 23:27:42 -0400 Subject: [PATCH] ENH: allow hint keyword for yt.load to select superclasses This allows users to explicitly select a specific class in a Dataset hierarchy, even if one or more of its subclasses are valid. --- yt/loaders.py | 6 +++++- yt/tests/test_load_errors.py | 6 ++++++ yt/utilities/hierarchy_inspection.py | 14 +++----------- 3 files changed, 14 insertions(+), 12 deletions(-) diff --git a/yt/loaders.py b/yt/loaders.py index 0974abef632..b4ca4c474a2 100644 --- a/yt/loaders.py +++ b/yt/loaders.py @@ -116,8 +116,12 @@ def load( if cls._is_valid(fn, *args, **kwargs): candidates.append(cls) + # Filter the candidates if a hint was given + if hint is not None: + candidates = [c for c in candidates if hint.lower() in c.__name__.lower()] + # Find only the lowest subclasses, i.e. most specialised front ends - candidates = find_lowest_subclasses(candidates, hint=hint) + candidates = find_lowest_subclasses(candidates) if len(candidates) == 1: return candidates[0](fn, *args, **kwargs) diff --git a/yt/tests/test_load_errors.py b/yt/tests/test_load_errors.py index a93743b45da..f16092c17ed 100644 --- a/yt/tests/test_load_errors.py +++ b/yt/tests/test_load_errors.py @@ -96,6 +96,10 @@ def _set_code_unit_attributes(self, *args, **kwargs): self.mass_unit = self.quan(1, "kg") self.time_unit = self.quan(1, "s") + @classmethod + def _is_valid(cls, *args, **kwargs): + return True + class AlphaDataset(MockDataset): @classmethod def _is_valid(cls, *args, **kwargs): @@ -139,6 +143,8 @@ def test_load_ambiguous_data(tmp_path): ("beta", "BetaDataset"), ("BeTA", "BetaDataset"), ("b", "BetaDataset"), + ("mock", "MockDataset"), + ("MockDataset", "MockDataset"), ], ) @pytest.mark.usefixtures("ambiguous_dataset_classes") diff --git a/yt/utilities/hierarchy_inspection.py b/yt/utilities/hierarchy_inspection.py index 015e3c5d7b0..7e62ed1684f 100644 --- a/yt/utilities/hierarchy_inspection.py +++ b/yt/utilities/hierarchy_inspection.py @@ -1,12 +1,10 @@ import inspect from collections import Counter from functools import reduce -from typing import List, Optional, Type +from typing import List, Type -def find_lowest_subclasses( - candidates: List[Type], *, hint: Optional[str] = None -) -> List[Type]: +def find_lowest_subclasses(candidates: List[Type]) -> List[Type]: """ This function takes a list of classes, and returns only the ones that are are not super classes of any others in the list. i.e. the ones that are at @@ -18,9 +16,6 @@ def find_lowest_subclasses( An iterable object that is a collection of classes to find the lowest subclass of. - hint : str, optional - Only keep candidates classes that have `hint` in their name (case insensitive) - Returns ------- result : list @@ -44,7 +39,4 @@ def find_lowest_subclasses( count = reduce(lambda x, y: x + y, counters) - retv = [x for x in count.keys() if count[x] == 1] - if hint is not None: - retv = [x for x in retv if hint.lower() in x.__name__.lower()] - return retv + return [x for x in count.keys() if count[x] == 1]