diff --git a/ontopy/graph.py b/ontopy/graph.py index 7eafefbdc..0ad9e578a 100644 --- a/ontopy/graph.py +++ b/ontopy/graph.py @@ -371,7 +371,6 @@ def add_branch( # pylint: disable=too-many-arguments,too-many-locals also included.""" if leaves is None: leaves = () - classes = self.ontology.get_branch( root=root, leaves=leaves, @@ -400,9 +399,17 @@ def add_branch( # pylint: disable=too-many-arguments,too-many-locals nodeattrs=nodeattrs, **attrs, ) - + closest_ancestors = False + ancestor_generations = None + if include_parents == "closest": + closest_ancestors = True + elif isinstance(include_parents, int): + ancestor_generations = include_parents parents = self.ontology.get_ancestors( - classes, include=include_parents, strict=True + classes, + closest=closest_ancestors, + generations=ancestor_generations, + strict=True, ) if parents: for parent in parents: diff --git a/ontopy/ontology.py b/ontopy/ontology.py index e2980a859..de97ce424 100644 --- a/ontopy/ontology.py +++ b/ontopy/ontology.py @@ -4,7 +4,7 @@ If desirable some of these additions may be moved back into owlready2. """ # pylint: disable=too-many-lines,fixme,arguments-differ,protected-access -from typing import TYPE_CHECKING, Optional, Union, Sequence +from typing import TYPE_CHECKING, Optional, Union import os import itertools import inspect @@ -1513,18 +1513,27 @@ def closest_common_ancestor(*classes): "A closest common ancestor should always exist !" ) - def get_ancestors(self, classes, include="all", strict=True): + def get_ancestors( + self, + classes: "Union[List, ThingClass]", + closest: bool = False, + generations: int = None, + strict: bool = True, + ) -> set: """Return ancestors of all classes in `classes`. - classes to be provided as list - - The values of `include` may be: - - None: ignore this argument - - "all": Include all ancestors. - - "closest": Include all ancestors up to the closest common - ancestor of all classes. - - int: Include this number of ancestor levels. Here `include` - may be an integer or a string that can be converted to int. + Args: + classes: class(es) for which ancestors should be returned. + generations: Include this number of generations, default is all. + closest: If True, return all ancestors up to and including the + closest common ancestor. Return all if False. + strict: If True returns only real ancestors, i.e. `classes` are + are not included in the returned set. + Returns: + Set of ancestors to `classes`. """ + if not isinstance(classes, Iterable): + classes = [classes] + ancestors = set() if not classes: return ancestors @@ -1535,22 +1544,24 @@ def addancestors(entity, counter, subject): subject.add(parent) addancestors(parent, counter - 1, subject) - if isinstance(include, str) and include.isdigit(): - include = int(include) + if closest: + if generations is not None: + raise ValueError( + "Only one of `generations` or `closest` may be specified." + ) - if include == "all": - ancestors.update(*(_.ancestors() for _ in classes)) - elif include == "closest": - closest = self.closest_common_ancestor(*classes) + closest_ancestor = self.closest_common_ancestor(*classes) for cls in classes: ancestors.update( - _ for _ in cls.ancestors() if closest in _.ancestors() + anc + for anc in cls.ancestors() + if closest_ancestor in anc.ancestors() ) - elif isinstance(include, int): + elif isinstance(generations, int): for entity in classes: - addancestors(entity, int(include), ancestors) - elif include not in (None, "None", "none", ""): - raise ValueError('include must be "all", "closest" or None') + addancestors(entity, generations, ancestors) + else: + ancestors.update(*(cls.ancestors() for cls in classes)) if strict: return ancestors.difference(classes) @@ -1559,12 +1570,12 @@ def addancestors(entity, counter, subject): def get_descendants( self, classes: "Union[List, ThingClass]", - common: bool = False, generations: int = None, + common: bool = False, ) -> set: """Return descendants/subclasses of all classes in `classes`. Args: - classes: to be provided as list. + classes: class(es) for which descendants are desired. common: whether to only return descendants common to all classes. generations: Include this number of generations, default is all. Returns: @@ -1574,7 +1585,7 @@ def get_descendants( 'generations' defaults to all. """ - if not isinstance(classes, Sequence): + if not isinstance(classes, Iterable): classes = [classes] descendants = {name: [] for name in classes} diff --git a/tests/ontopy_tests/test_graph.py b/tests/ontopy_tests/test_graph.py index dbfc55891..0be76f9ec 100644 --- a/tests/ontopy_tests/test_graph.py +++ b/tests/ontopy_tests/test_graph.py @@ -76,6 +76,21 @@ class hasPartRenamed(owlready2.ObjectProperty): graph.add_legend() graph.save(tmpdir / "testonto.png") + with pytest.warns() as record: + graph2 = OntoGraph( + testonto, + testonto.TestClass, + relations="all", + addnodes=True, + edgelabels=None, + ) + assert str(record[0].message) == ( + "Style not defined for relation hasSpecialRelation. " + "Resorting to default style." + ) + graph2.add_legend() + graph2.save(tmpdir / "testonto2.png") + def test_emmo_graphs(emmo: "Ontology", tmpdir: "Path") -> None: """Testing OntoGraph on various aspects of EMMO. @@ -217,8 +232,8 @@ def test_emmo_graphs(emmo: "Ontology", tmpdir: "Path") -> None: graph = OntoGraph(emmo) graph.add_entities(semiotic, relations="all", edgelabels=False) graph.add_legend() - graph.save(tmpdir / "measurement.png") - + graph.save(tmpdir / "measurement.png", fmt="graphviz") + print("reductionistc") # Reductionistic perspective graph = OntoGraph( emmo, @@ -236,7 +251,53 @@ def test_emmo_graphs(emmo: "Ontology", tmpdir: "Path") -> None: edgelabels=None, ) graph.add_legend() - graph.save(tmpdir / "Reductionistic.png", fmt="graphviz") + graph.save(tmpdir / "Reductionistic.png") + + # Reductionistic perspective, choose leaf_generations + graph = OntoGraph( + emmo, + emmo.Reductionistic, + relations="all", + addnodes=False, + parents=2, + edgelabels=None, + ) + graph.add_branch( + emmo.Reductionistic, + leaves=[ + emmo.Quantity, + emmo.String, + emmo.PrefixedUnit, + emmo.SymbolicConstruct, + emmo.Matter, + ], + ) + + graph.add_legend() + graph.save(tmpdir / "Reductionistic_addbranch.png") + + graph2 = OntoGraph( + emmo, + emmo.Reductionistic, + relations="all", + addnodes=False, + # parents=2, + edgelabels=None, + ) + graph2.add_branch( + emmo.Reductionistic, + leaves=[ + emmo.Quantity, + emmo.String, + emmo.PrefixedUnit, + emmo.SymbolicConstruct, + emmo.Matter, + ], + include_parents=2, + ) + + graph2.add_legend() + graph2.save(tmpdir / "Reductionistic_addbranch_2.png") # View modules diff --git a/tests/test_generation_search.py b/tests/test_generation_search.py index b24028072..3afe2c17a 100755 --- a/tests/test_generation_search.py +++ b/tests/test_generation_search.py @@ -80,3 +80,70 @@ def test_descendants(emmo: "Ontology", repo_dir: "Path") -> None: assert onto.get_descendants([onto.Tree, onto.NaturalDye], common=True) == { onto.Avocado } + + +def test_ancestors(emmo: "Ontology", repo_dir: "Path") -> None: + from ontopy import get_ontology + from ontopy.utils import LabelDefinitionError + + ontopath = repo_dir / "tests" / "testonto" / "testontology.ttl" + + onto = get_ontology(ontopath).load() + + # Test that default gives all ancestors. + assert onto.get_ancestors(onto.NorwaySpruce) == { + onto.Spruce, + onto.Tree, + onto.EvergreenTree, + onto.Thing, + } + + # Test that asking for 0 generations returns empty set + assert onto.get_ancestors(onto.NorwaySpruce, generations=0) == set() + + # Check that number of generations are returned correctly + assert onto.get_ancestors(onto.NorwaySpruce, generations=2) == { + onto.Spruce, + onto.EvergreenTree, + } + + assert onto.get_ancestors(onto.NorwaySpruce, generations=1) == { + onto.Spruce, + } + # Check that no error is generated if one of the classes do + # not have enough parents for all given generations + assert onto.get_ancestors(onto.NorwaySpruce, generations=10) == ( + onto.get_ancestors(onto.NorwaySpruce) + ) + + # Check that ancestors of a list is returned correctly + assert onto.get_ancestors([onto.NorwaySpruce, onto.Avocado]) == { + onto.Tree, + onto.EvergreenTree, + onto.Spruce, + onto.NaturalDye, + onto.Thing, + } + # Check that classes up to closest common ancestor are returned + + assert onto.get_ancestors( + [onto.NorwaySpruce, onto.Avocado], closest=True + ) == { + onto.EvergreenTree, + onto.Spruce, + } + + with pytest.raises(ValueError): + onto.get_ancestors(onto.NorwaySpruce, closest=True, generations=4) + + # Test strict == False + assert onto.get_ancestors( + [onto.NorwaySpruce, onto.Avocado], + closest=True, + strict=False, + ) == { + onto.EvergreenTree, + onto.Spruce, + onto.NorwaySpruce, + onto.Avocado, + }