Skip to content

Commit

Permalink
Merge pull request #4397 from yut23/load_hint_superclasses
Browse files Browse the repository at this point in the history
  • Loading branch information
neutrinoceros authored Apr 6, 2023
2 parents 93b47d9 + 1fd07b0 commit 45ca56c
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 12 deletions.
6 changes: 5 additions & 1 deletion yt/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,12 @@ def load(
if cls._is_valid(fn, *args, **kwargs):
candidates.append(cls)

# Filter the candidates if a hint was given
if hint is not None:
candidates = [c for c in candidates if hint.lower() in c.__name__.lower()]

# Find only the lowest subclasses, i.e. most specialised front ends
candidates = find_lowest_subclasses(candidates, hint=hint)
candidates = find_lowest_subclasses(candidates)

if len(candidates) == 1:
return candidates[0](fn, *args, **kwargs)
Expand Down
6 changes: 6 additions & 0 deletions yt/tests/test_load_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,10 @@ def _set_code_unit_attributes(self, *args, **kwargs):
self.mass_unit = self.quan(1, "kg")
self.time_unit = self.quan(1, "s")

@classmethod
def _is_valid(cls, *args, **kwargs):
return True

class AlphaDataset(MockDataset):
@classmethod
def _is_valid(cls, *args, **kwargs):
Expand Down Expand Up @@ -139,6 +143,8 @@ def test_load_ambiguous_data(tmp_path):
("beta", "BetaDataset"),
("BeTA", "BetaDataset"),
("b", "BetaDataset"),
("mock", "MockDataset"),
("MockDataset", "MockDataset"),
],
)
@pytest.mark.usefixtures("ambiguous_dataset_classes")
Expand Down
14 changes: 3 additions & 11 deletions yt/utilities/hierarchy_inspection.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
import inspect
from collections import Counter
from functools import reduce
from typing import List, Optional, Type
from typing import List, Type


def find_lowest_subclasses(
candidates: List[Type], *, hint: Optional[str] = None
) -> List[Type]:
def find_lowest_subclasses(candidates: List[Type]) -> List[Type]:
"""
This function takes a list of classes, and returns only the ones that are
are not super classes of any others in the list. i.e. the ones that are at
Expand All @@ -18,9 +16,6 @@ def find_lowest_subclasses(
An iterable object that is a collection of classes to find the lowest
subclass of.
hint : str, optional
Only keep candidates classes that have `hint` in their name (case insensitive)
Returns
-------
result : list
Expand All @@ -44,7 +39,4 @@ def find_lowest_subclasses(

count = reduce(lambda x, y: x + y, counters)

retv = [x for x in count.keys() if count[x] == 1]
if hint is not None:
retv = [x for x in retv if hint.lower() in x.__name__.lower()]
return retv
return [x for x in count.keys() if count[x] == 1]

0 comments on commit 45ca56c

Please sign in to comment.