Skip to content

Commit

Permalink
Merge pull request #230 from dkpro/feature/229-Get-transitive-closure-of-types
Browse files: browse the repository at this point in the history

#229 - Get transitive closure of types
  • Loading branch information
reckart authored Sep 29, 2021
2 parents cab4e46 + d50efa4 commit f91fced
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 6 deletions.
37 changes: 35 additions & 2 deletions cassis/typesystem.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from io import BytesIO
from itertools import chain, filterfalse
from pathlib import Path
from typing import IO, Any, Callable, Dict, Iterator, List, Optional, Union
from typing import IO, Any, Callable, Dict, Iterator, List, Optional, Set, Union

import attr
from deprecation import deprecated
Expand Down Expand Up @@ -401,7 +401,7 @@ def __lt__(self, other):
return self.name < other.name


@attr.s(slots=True)
@attr.s(slots=True, hash=False, eq=True)
class Type:
"""Describes types in a type system.
Expand Down Expand Up @@ -583,6 +583,12 @@ def subsumes(self, other_type: "Type") -> bool:

return False

def __hash__(self):
    """Return a hash derived solely from the type's ``name``.

    Two objects with the same ``name`` therefore hash identically.
    """
    type_name = self.name
    return hash(type_name)

def __eq__(self, other):
    """Compare this type with ``other`` by ``name``.

    Args:
        other: The object to compare against.

    Returns:
        ``True``/``False`` when ``other`` is an instance of this class,
        depending on whether the names match; ``NotImplemented`` otherwise.
    """
    # NOTE(review): the class decorator shown in this diff passes ``eq=True``
    # to ``attr.s``; attrs may generate its own ``__eq__`` and replace this
    # one -- confirm which implementation is actually in effect.

    # Comparing against an unrelated object (``None``, a ``str`` name, ...)
    # must not blow up with AttributeError on ``other.name``; returning
    # NotImplemented lets Python try the reflected comparison and finally
    # fall back to ``False``.
    if not isinstance(other, type(self)):
        return NotImplemented
    return self.name == other.name


class TypeSystem:
def __init__(self, add_document_annotation_type: bool = True):
Expand Down Expand Up @@ -966,6 +972,33 @@ def _add_document_annotation_type(self):
t = self.create_type(name=_DOCUMENT_ANNOTATION_TYPE, supertypeName="uima.tcas.Annotation")
self.create_feature(t, name="language", rangeType="uima.cas.String")

def transitive_closure(self, seed_types: Set[Type], built_in: bool = False) -> Set[Type]:
    """Collect every type reachable from ``seed_types``.

    Starting from the seeds, the supertype of each visited type as well as
    the range type and element type of each of its features (as listed by
    ``all_features``) are followed until no new type turns up.

    Args:
        seed_types: The types to start the traversal from.
        built_in: If ``False`` (default), types whose name is contained in
            ``_PREDEFINED_TYPES`` are neither added to the result nor
            traversed any further.

    Returns:
        The set of all types reached, including the seeds themselves
        (unless they were filtered out because ``built_in`` is ``False``).
    """
    closure = set()

    # The result is a set, so the visiting order is irrelevant. Using the
    # list as a stack keeps each removal O(1); ``pop(0)`` would shift the
    # whole list on every iteration.
    pending = list(seed_types)

    while pending:
        type_ = pending.pop()

        # A type may have been queued several times before it was visited.
        if type_ in closure:
            continue

        if not built_in and type_.name in _PREDEFINED_TYPES:
            continue

        closure.add(type_)

        referenced = [type_.supertype]
        for feature in type_.all_features:
            referenced.append(feature.rangeType)
            referenced.append(feature.elementType)

        # Skip unset references and types that are already part of the result.
        pending.extend(t for t in referenced if t and t not in closure)

    return closure


# Deserializing

Expand Down
41 changes: 37 additions & 4 deletions tests/test_typesystem.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,14 @@
from cassis.typesystem import (
_COLLECTION_TYPES,
TOP_TYPE_NAME,
TYPE_NAME_ANNOTATION,
TYPE_NAME_ANNOTATION_BASE,
TYPE_NAME_ARRAY_BASE,
TYPE_NAME_BOOLEAN,
TYPE_NAME_INTEGER,
TYPE_NAME_SOFA,
TYPE_NAME_STRING,
TYPE_NAME_STRING_ARRAY,
TYPE_NAME_TOP,
TypeCheckError,
)
Expand Down Expand Up @@ -287,7 +292,7 @@ def test_is_instance_of(child_name: str, parent_name: str, expected: bool):
# manually load the type system
path = os.path.join(FIXTURE_DIR, "typesystems", "important_dkpro_types.xml")

with open(path, "r") as f:
with open(path) as f:
ts = load_typesystem(f.read())

assert ts.is_instance_of(child_name, parent_name) == expected
Expand Down Expand Up @@ -643,7 +648,7 @@ def test_that_typesystem_with_redefined_documentation_annotation_works(
],
)
def test_that_merging_compatible_typesystem_works(name, rangeTypeName, elementType, multipleReferencesAllowed):
with open(typesystem_merge_base_path(), "r") as f:
with open(typesystem_merge_base_path()) as f:
base = load_typesystem(f.read())

ts = TypeSystem()
Expand Down Expand Up @@ -677,7 +682,7 @@ def test_that_merging_compatible_typesystem_works(name, rangeTypeName, elementTy
],
)
def test_that_merging_incompatible_typesystem_throws(name, rangeTypeName, elementType, multipleReferencesAllowed):
with open(typesystem_merge_base_path(), "r") as f:
with open(typesystem_merge_base_path()) as f:
base = load_typesystem(f.read())

ts = TypeSystem()
Expand All @@ -692,7 +697,7 @@ def test_that_merging_incompatible_typesystem_throws(name, rangeTypeName, elemen

with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=UserWarning)
with pytest.raises(ValueError, match=r".*\[{0}\].*".format(name)):
with pytest.raises(ValueError, match=fr".*\[{name}\].*"):
merge_typesystems(base, ts)


Expand Down Expand Up @@ -870,3 +875,31 @@ def test_create_same_type_twice_fails():
typesystem.create_type("my.Type")
with pytest.raises(ValueError):
typesystem.create_type("my.Type")


def test_transitive_closure():
    """Check ``TypeSystem.transitive_closure`` with and without built-in types.

    A two-level hierarchy (``ChildType`` -> ``BaseType`` -> Annotation) is
    created; ``ChildType`` gets a primitive feature, an array feature and a
    feature referring to ``BaseType``. The closure seeded with ``ChildType``
    is then compared against the expected sets.
    """
    ts = TypeSystem()
    base = ts.create_type("BaseType", supertypeName=TYPE_NAME_ANNOTATION)
    child = ts.create_type("ChildType", supertypeName="BaseType")
    ts.create_feature("ChildType", "primitiveFeature", TYPE_NAME_STRING)
    ts.create_feature("ChildType", "arrayFeature", TYPE_NAME_STRING_ARRAY, elementType=TYPE_NAME_STRING)
    ts.create_feature("ChildType", "fsFeature", "BaseType")

    # Without built-ins only the two user-defined types may show up.
    user_defined_only = ts.transitive_closure({child}, built_in=False)
    assert user_defined_only == {base, child}

    # With built-ins the listed predefined types are expected as well.
    including_builtins = ts.transitive_closure({child}, built_in=True)

    expected_builtin_names = (
        TYPE_NAME_TOP,
        TYPE_NAME_ANNOTATION_BASE,
        TYPE_NAME_ANNOTATION,
        TYPE_NAME_STRING,
        TYPE_NAME_ARRAY_BASE,
        TYPE_NAME_STRING_ARRAY,
        TYPE_NAME_INTEGER,
        TYPE_NAME_SOFA,
    )
    expected = {base, child}
    for builtin_name in expected_builtin_names:
        expected.add(ts.get_type(builtin_name))

    assert including_builtins == expected

0 comments on commit f91fced

Please sign in to comment.