Skip to content

Commit

Permalink
RFC: rewrite find_lowest_subclasses to better match its intent
Browse files Browse the repository at this point in the history
This new implementation is more straightforward with fewer edge cases
that need special handling.
  • Loading branch information
yut23 committed Apr 10, 2023
1 parent c29aae4 commit b083bdf
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 18 deletions.
21 changes: 4 additions & 17 deletions yt/utilities/hierarchy_inspection.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import inspect
from collections import Counter
from functools import reduce
from typing import List, Type


Expand All @@ -23,20 +21,9 @@ def find_lowest_subclasses(candidates: List[Type]) -> List[Type]:
candidates.
"""

# If there is only one input, the input candidate is always the
# lowest class
if len(candidates) == 1:
return candidates
elif len(candidates) == 0:
return []
# The MRO list includes the class itself, which we don't want
superclasses = [set(inspect.getmro(c)) - {c} for c in candidates]

mros = [inspect.getmro(c) for c in candidates]
all_superclasses = set().union(*superclasses)

counters = [Counter(mro) for mro in mros]

if len(counters) == 0:
return []

count = reduce(lambda x, y: x + y, counters)

return [x for x in candidates if count[x] == 1]
return [x for x in candidates if x not in all_superclasses]
13 changes: 12 additions & 1 deletion yt/utilities/tests/test_hierarchy_inspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@ class level4(level3):
pass


def test_empty():
result = find_lowest_subclasses([])
assert len(result) == 0


def test_single():
result = find_lowest_subclasses([level2])
assert len(result) == 1
Expand Down Expand Up @@ -60,4 +65,10 @@ def test_diverging_tree():
def test_without_parents():
result = find_lowest_subclasses([level1, level3])
assert len(result) == 1
assert level3 in result
assert result[0] is level3


def test_without_grandparents():
result = find_lowest_subclasses([level1, level4])
assert len(result) == 1
assert result[0] is level4

0 comments on commit b083bdf

Please sign in to comment.