From d50efa4fc9909650f2bc49d6190a14b60a161793 Mon Sep 17 00:00:00 2001 From: Richard Eckart de Castilho Date: Wed, 29 Sep 2021 17:18:55 +0200 Subject: [PATCH] #229 - Get transitive closure of types - Added utility method - Added test - Modernize tests --- cassis/typesystem.py | 37 ++++++++++++++++++++++++++++++++++-- tests/test_typesystem.py | 41 ++++++++++++++++++++++++++++++++++++---- 2 files changed, 72 insertions(+), 6 deletions(-) diff --git a/cassis/typesystem.py b/cassis/typesystem.py index 0e209e1..44cfc05 100644 --- a/cassis/typesystem.py +++ b/cassis/typesystem.py @@ -4,7 +4,7 @@ from io import BytesIO from itertools import chain, filterfalse from pathlib import Path -from typing import IO, Any, Callable, Dict, Iterator, List, Optional, Union +from typing import IO, Any, Callable, Dict, Iterator, List, Optional, Set, Union import attr from deprecation import deprecated @@ -401,7 +401,7 @@ def __lt__(self, other): return self.name < other.name -@attr.s(slots=True) +@attr.s(slots=True, hash=False, eq=True) class Type: """Describes types in a type system. @@ -583,6 +583,12 @@ def subsumes(self, other_type: "Type") -> bool: return False + def __hash__(self): + return hash(self.name) + + def __eq__(self, other): + return self.name == other.name + class TypeSystem: def __init__(self, add_document_annotation_type: bool = True): @@ -966,6 +972,33 @@ def _add_document_annotation_type(self): t = self.create_type(name=_DOCUMENT_ANNOTATION_TYPE, supertypeName="uima.tcas.Annotation") self.create_feature(t, name="language", rangeType="uima.cas.String") + def transitive_closure(self, seed_types: Set[Type], built_in: bool = False) -> Set[Type]: + # Build transitive closure of used types by following parents, features, etc. + transitively_referenced_types = set() + openlist = [] + openlist.extend(seed_types) + while openlist: + type_ = openlist.pop(0) + + if type_ in transitively_referenced_types: + continue + + if not built_in and type_.name in _PREDEFINED_TYPES: + continue + + transitively_referenced_types.add(type_) + + if type_.supertype and type_.supertype not in transitively_referenced_types: + openlist.append(type_.supertype) + + for feature in type_.all_features: + if feature.rangeType not in transitively_referenced_types: + openlist.append(feature.rangeType) + if feature.elementType and feature.elementType not in transitively_referenced_types: + openlist.append(feature.elementType) + + return transitively_referenced_types + # Deserializing diff --git a/tests/test_typesystem.py b/tests/test_typesystem.py index f5fd52c..e2b9df2 100644 --- a/tests/test_typesystem.py +++ b/tests/test_typesystem.py @@ -6,9 +6,14 @@ from cassis.typesystem import ( _COLLECTION_TYPES, TOP_TYPE_NAME, + TYPE_NAME_ANNOTATION, + TYPE_NAME_ANNOTATION_BASE, + TYPE_NAME_ARRAY_BASE, TYPE_NAME_BOOLEAN, TYPE_NAME_INTEGER, + TYPE_NAME_SOFA, TYPE_NAME_STRING, + TYPE_NAME_STRING_ARRAY, TYPE_NAME_TOP, TypeCheckError, ) @@ -287,7 +292,7 @@ def test_is_instance_of(child_name: str, parent_name: str, expected: bool): # manually load the type system path = os.path.join(FIXTURE_DIR, "typesystems", "important_dkpro_types.xml") - with open(path, "r") as f: + with open(path) as f: ts = load_typesystem(f.read()) assert ts.is_instance_of(child_name, parent_name) == expected @@ -643,7 +648,7 @@ def test_that_typesystem_with_redefined_documentation_annotation_works( ], ) def test_that_merging_compatible_typesystem_works(name, rangeTypeName, elementType, multipleReferencesAllowed): - with open(typesystem_merge_base_path(), "r") as f: + with open(typesystem_merge_base_path()) as f: base = load_typesystem(f.read()) ts = TypeSystem() @@ -677,7 +682,7 @@ def test_that_merging_compatible_typesystem_works(name, rangeTypeName, elementTy ], ) def test_that_merging_incompatible_typesystem_throws(name, rangeTypeName, elementType, multipleReferencesAllowed): - with open(typesystem_merge_base_path(), "r") as f: + with open(typesystem_merge_base_path()) as f: base = load_typesystem(f.read()) ts = TypeSystem() @@ -692,7 +697,7 @@ def test_that_merging_incompatible_typesystem_throws(name, rangeTypeName, elemen with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=UserWarning) - with pytest.raises(ValueError, match=r".*\[{0}\].*".format(name)): + with pytest.raises(ValueError, match=fr".*\[{name}\].*"): merge_typesystems(base, ts) @@ -870,3 +875,31 @@ def test_create_same_type_twice_fails(): typesystem.create_type("my.Type") with pytest.raises(ValueError): typesystem.create_type("my.Type") + + +def test_transitive_closure(): + typesystem = TypeSystem() + base_type = typesystem.create_type("BaseType", supertypeName=TYPE_NAME_ANNOTATION) + child_type = typesystem.create_type("ChildType", supertypeName="BaseType") + typesystem.create_feature("ChildType", "primitiveFeature", TYPE_NAME_STRING) + typesystem.create_feature("ChildType", "arrayFeature", TYPE_NAME_STRING_ARRAY, elementType=TYPE_NAME_STRING) + typesystem.create_feature("ChildType", "fsFeature", "BaseType") + + transitive_closure_without_builtins = typesystem.transitive_closure({child_type}, built_in=False) + + assert transitive_closure_without_builtins == {base_type, child_type} + + transitive_closure_with_builtins = typesystem.transitive_closure({child_type}, built_in=True) + + assert transitive_closure_with_builtins == { + base_type, + child_type, + typesystem.get_type(TYPE_NAME_TOP), + typesystem.get_type(TYPE_NAME_ANNOTATION_BASE), + typesystem.get_type(TYPE_NAME_ANNOTATION), + typesystem.get_type(TYPE_NAME_STRING), + typesystem.get_type(TYPE_NAME_ARRAY_BASE), + typesystem.get_type(TYPE_NAME_STRING_ARRAY), + typesystem.get_type(TYPE_NAME_INTEGER), + typesystem.get_type(TYPE_NAME_SOFA), + }