From 15487c0d7099ae7365e76d80469d4e35ae95cc48 Mon Sep 17 00:00:00 2001 From: "Eric T. Johnson" Date: Wed, 12 Apr 2023 10:22:45 -0400 Subject: [PATCH] RFC: reduce code complexity in find_lowest_subclasses helper function --- yt/utilities/hierarchy_inspection.py | 21 +++---------------- .../tests/test_hierarchy_inspection.py | 13 +++++++++++- 2 files changed, 15 insertions(+), 19 deletions(-) diff --git a/yt/utilities/hierarchy_inspection.py b/yt/utilities/hierarchy_inspection.py index 8a2e0aae92f..65f9385c2bd 100644 --- a/yt/utilities/hierarchy_inspection.py +++ b/yt/utilities/hierarchy_inspection.py @@ -1,8 +1,9 @@ import inspect from collections import Counter -from functools import reduce from typing import List, Type +from more_itertools import flatten + def find_lowest_subclasses(candidates: List[Type]) -> List[Type]: """ @@ -22,21 +23,5 @@ def find_lowest_subclasses(candidates: List[Type]) -> List[Type]: A list of classes which are not super classes for any others in candidates. """ - - # If there is only one input, the input candidate is always the - # lowest class - if len(candidates) == 1: - return candidates - elif len(candidates) == 0: - return [] - - mros = [inspect.getmro(c) for c in candidates] - - counters = [Counter(mro) for mro in mros] - - if len(counters) == 0: - return [] - - count = reduce(lambda x, y: x + y, counters) - + count = Counter(flatten(inspect.getmro(c) for c in candidates)) return [x for x in candidates if count[x] == 1] diff --git a/yt/utilities/tests/test_hierarchy_inspection.py b/yt/utilities/tests/test_hierarchy_inspection.py index c5d36f90952..8783a5fcbab 100644 --- a/yt/utilities/tests/test_hierarchy_inspection.py +++ b/yt/utilities/tests/test_hierarchy_inspection.py @@ -27,6 +27,11 @@ class level4(level3): pass +def test_empty(): + result = find_lowest_subclasses([]) + assert len(result) == 0 + + def test_single(): result = find_lowest_subclasses([level2]) assert len(result) == 1 @@ -60,4 +65,10 @@ def test_diverging_tree(): def test_without_parents(): result = find_lowest_subclasses([level1, level3]) assert len(result) == 1 - assert level3 in result + assert result[0] is level3 + + +def test_without_grandparents(): + result = find_lowest_subclasses([level1, level4]) + assert len(result) == 1 + assert result[0] is level4