Skip to content

Commit

Permalink
Comvert TypeParent to class, speed up self-reference fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
vemel committed Sep 23, 2024
1 parent 2cee5c7 commit 02e99ac
Show file tree
Hide file tree
Showing 9 changed files with 102 additions and 154 deletions.
2 changes: 1 addition & 1 deletion mypy_boto3_builder/parsers/service_package_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def parse(self) -> ServicePackage:
)

self.shape_parser.fix_typed_dict_names()
self.shape_parser.fix_method_arguments_for_mypy(
self.shape_parser.convert_input_arguments_to_unions(
[
*result.client.methods,
*(result.service_resource.methods if result.service_resource else []),
Expand Down
70 changes: 29 additions & 41 deletions mypy_boto3_builder/parsers/shape_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
from mypy_boto3_builder.type_annotations.type import Type
from mypy_boto3_builder.type_annotations.type_constant import TypeConstant
from mypy_boto3_builder.type_annotations.type_literal import TypeLiteral
from mypy_boto3_builder.type_annotations.type_parent import TypeParent
from mypy_boto3_builder.type_annotations.type_subscript import TypeSubscript
from mypy_boto3_builder.type_annotations.type_typed_dict import TypeTypedDict
from mypy_boto3_builder.type_annotations.type_union import TypeUnion
Expand Down Expand Up @@ -570,7 +571,7 @@ def parse_shape(
if isinstance(result, TypeTypedDict):
replacement = Type.DictStrAny if is_output_or_child else Type.MappingStrAny
mutated_parents = result.replace_self_references(replacement)
for mutated_parent in mutated_parents:
for mutated_parent in sorted(mutated_parents):
self.logger.debug(
f"Replaced self reference for {result.render()} in {mutated_parent.render()}"
)
Expand Down Expand Up @@ -1086,51 +1087,38 @@ def fix_typed_dict_names(self) -> None:

self._response_typed_dict_map.rename(response_typed_dict, new_typed_dict_name)

def fix_method_arguments_for_mypy(self, methods: Sequence[Method]) -> None:
@staticmethod
def _get_parent_type_annotations(
methods: Sequence[Method],
) -> set[TypeParent]:
result: set[TypeParent] = set()
for method in methods:
for argument in method.arguments:
if isinstance(argument.type_annotation, TypeParent):
result.add(argument.type_annotation)
return result

def convert_input_arguments_to_unions(self, methods: Sequence[Method]) -> None:
"""
Accept both input and output shapes in method arguments.
mypy does not compare TypedDicts, so we need to accept both input and output shapes.
https://github.com/youtype/mypy_boto3_builder/issues/209
"""
parent_type_annotations = list(self._get_parent_type_annotations(methods))
for input_typed_dict, output_typed_dict in self._fixed_typed_dict_map.items():
for method in methods:
for argument in method.arguments:
if not argument.type_annotation:
continue
if (
argument.type_annotation.is_typed_dict()
and argument.type_annotation == input_typed_dict
):
self.logger.debug(
f"Adding output shape to {method.name} {argument.name} type:"
f" {input_typed_dict.name} | {output_typed_dict.name}"
)
union_name = self._get_non_clashing_typed_dict_name(
input_typed_dict, "Union"
)
argument.type_annotation = TypeUnion(
name=union_name,
children=[input_typed_dict, output_typed_dict],
)
union_name = self._get_non_clashing_typed_dict_name(input_typed_dict, "Union")
union_type_annotation = TypeUnion(
name=union_name,
children=[input_typed_dict, output_typed_dict],
)
for type_annotation in parent_type_annotations:
parents = type_annotation.find_type_annotation_parents(input_typed_dict)
for parent in sorted(parents):
if parent is union_type_annotation:
continue
if isinstance(argument.type_annotation, TypeSubscript):
parent = argument.type_annotation.find_type_annotation_parent(
input_typed_dict
)
if parent:
self.logger.debug(
f"Adding output shape to {method.name} {argument.name} type:"
f" {input_typed_dict.name} | {output_typed_dict.name}"
)
union_name = self._get_non_clashing_typed_dict_name(
input_typed_dict, "Union"
)
parent.replace_child(
input_typed_dict,
TypeUnion(
name=union_name,
children=[input_typed_dict, output_typed_dict],
),
)
continue
self.logger.debug(
f"Adding output shape to {parent.render()} type:"
f" {input_typed_dict.name} | {output_typed_dict.name}"
)
parent.replace_child(input_typed_dict, union_type_annotation)
2 changes: 1 addition & 1 deletion mypy_boto3_builder/parsers/typed_dict_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,6 @@ def get_sorted_names(self) -> list[str]:
"""
Get real TypedDict names topologically sorted.
"""
sorted_values = TypeDefSorter(self.values()).sort()
sorted_values = TypeDefSorter(set(self.values())).sort()
allowed_names = {i.name for i in self.values()}
return [i.name for i in sorted_values if i.name in allowed_names]
3 changes: 1 addition & 2 deletions mypy_boto3_builder/type_annotations/type_def_sortable.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,10 @@
from mypy_boto3_builder.import_helpers.import_record import ImportRecord
from mypy_boto3_builder.type_annotations.fake_annotation import FakeAnnotation
from mypy_boto3_builder.type_annotations.type_literal import TypeLiteral
from mypy_boto3_builder.type_annotations.type_parent import TypeParent


@runtime_checkable
class TypeDefSortable(TypeParent, Protocol):
class TypeDefSortable(Protocol):
"""
Sortable protocol for TypeDefSorter.
"""
Expand Down
70 changes: 52 additions & 18 deletions mypy_boto3_builder/type_annotations/type_parent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,50 +2,84 @@
Protocol for types with children.
"""

from collections.abc import Iterator
from typing import Any, Protocol, Self, runtime_checkable
from abc import ABC, abstractmethod
from collections.abc import Iterable, Iterator
from typing import Self

from mypy_boto3_builder.type_annotations.fake_annotation import FakeAnnotation
from mypy_boto3_builder.type_annotations.type_def_sortable import TypeDefSortable


@runtime_checkable
class TypeParent(Protocol):
class TypeParent(FakeAnnotation, ABC):
"""
Protocol for types with children.
"""

def iterate_children(self) -> Iterator[Any]:
@abstractmethod
def replace_child(self, child: FakeAnnotation, new_child: FakeAnnotation) -> Self:
"""
Iterate over children.
Replace child type annotation with a new one.
"""
...

def replace_child(self, child: FakeAnnotation, new_child: FakeAnnotation) -> Self:
@abstractmethod
def iterate_children_type_annotations(self) -> Iterator[FakeAnnotation]:
"""
Replace child type annotation with a new one.
Iterate over children type annotations.
"""
...

def find_type_annotation_parent(self, type_annotation: FakeAnnotation) -> "TypeParent | None":
@abstractmethod
def get_children_types(self) -> set[FakeAnnotation]:
"""
Check recursively if child is present in type def.
Extract required type annotations from attributes.
"""
...

def iterate_children_types(self) -> Iterator[FakeAnnotation]:
def find_type_annotation_parents(
self, type_annotation: FakeAnnotation, skip: Iterable[FakeAnnotation] = ()
) -> "set[TypeParent]":
"""
Iterate over children type annotations.
Check recursively if child is present in type def.
"""
...
result: set[TypeParent] = set()
for child_type in self.iterate_children_type_annotations():
if child_type == type_annotation:
result.add(self)
if not isinstance(child_type, TypeParent):
continue

def replace_self_references(self, replacement: FakeAnnotation) -> "list[TypeParent]":
if child_type in skip:
continue

parents = child_type.find_type_annotation_parents(
type_annotation, skip={*skip, child_type}
)
result.update(parents)

return result

def replace_self_references(self, replacement: FakeAnnotation) -> "set[TypeParent]":
"""
Replace self references with a new type annotation to avoid recursion.
"""
...
"""
Replace self references with a new type annotation to avoid recursion.
"""
parents = self.find_type_annotation_parents(self)
for parent in parents:
parent.replace_child(self, replacement)
return parents

def render(self) -> str:
def get_sortable_children(self) -> list[TypeDefSortable]:
"""
Render type annotation.
Extract required TypeDefSortable list from attributes.
"""
...
result: list[TypeDefSortable] = []
children_types = self.get_children_types()
for type_annotation in children_types:
if not isinstance(type_annotation, TypeDefSortable):
continue
result.append(type_annotation)

return result
36 changes: 8 additions & 28 deletions mypy_boto3_builder/type_annotations/type_subscript.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from mypy_boto3_builder.type_annotations.type_parent import TypeParent


class TypeSubscript(FakeAnnotation, TypeParent):
class TypeSubscript(TypeParent):
"""
Wrapper for subscript type annotations, like `List[str]`.
Expand Down Expand Up @@ -102,40 +102,20 @@ def get_local_types(self) -> list[FakeAnnotation]:
result.extend(child.get_local_types())
return result

def iterate_children_types(self) -> Iterator[FakeAnnotation]:
def iterate_children_type_annotations(self) -> Iterator[FakeAnnotation]:
"""
Extract required type annotations from attributes.
"""
yield from self.children

def find_type_annotation_parent(
self: Self, type_annotation: FakeAnnotation
) -> TypeParent | None:
def get_children_types(self) -> set[FakeAnnotation]:
"""
Check recursively if child is present in subscript.
"""
for child_type in self.iterate_children_types():
if child_type == type_annotation:
return self
if isinstance(child_type, TypeParent):
result = child_type.find_type_annotation_parent(type_annotation)
if result is not None:
return result

return None

def replace_self_references(self, replacement: FakeAnnotation) -> list[TypeParent]:
"""
Replace self references with a new type annotation to avoid recursion.
Extract required type annotations from attributes.
"""
result: list[TypeParent] = []
while True:
parent = self.find_type_annotation_parent(self)
if parent is None:
return result

parent.replace_child(self, replacement)
result.append(parent)
result: set[FakeAnnotation] = set()
for child in self.children:
result.update(child.iterate_types())
return result

def replace_child(self, child: FakeAnnotation, new_child: FakeAnnotation) -> Self:
"""
Expand Down
52 changes: 6 additions & 46 deletions mypy_boto3_builder/type_annotations/type_typed_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def mark_as_required(self) -> None:
self.required = True


class TypeTypedDict(FakeAnnotation, TypeDefSortable):
class TypeTypedDict(TypeParent, TypeDefSortable):
"""
Wrapper for `typing/typing_extensions.TypedDict` type annotations.
Expand Down Expand Up @@ -236,26 +236,13 @@ def get_children_types(self) -> set[FakeAnnotation]:
result.update(child.iterate_types())
return result

def iterate_children_types(self) -> Iterator[FakeAnnotation]:
def iterate_children_type_annotations(self) -> Iterator[FakeAnnotation]:
"""
Extract required type annotations from attributes.
"""
for child in self.children:
yield child.type_annotation

def get_sortable_children(self) -> list[TypeDefSortable]:
"""
Extract required TypeDefSortable list from attributes.
"""
result: list[TypeDefSortable] = []
children_types = self.get_children_types()
for type_annotation in children_types:
if not isinstance(type_annotation, TypeDefSortable):
continue
result.append(type_annotation)

return result

def get_children_literals(self, processed: Iterable[str] = ()) -> set[TypeLiteral]:
"""
Extract required TypeLiteral list from attributes.
Expand Down Expand Up @@ -315,35 +302,8 @@ def replace_child(self, child: FakeAnnotation, new_child: FakeAnnotation) -> Sel
if child not in children_types:
raise TypeAnnotationError(f"Child not found: {child}")

index = children_types.index(child)
self.children[index].type_annotation = new_child
return self

def find_type_annotation_parent(
self: Self, type_annotation: FakeAnnotation
) -> TypeParent | None:
"""
Check recursively if child is present in subscript.
"""
for child_type in self.iterate_children_types():
if child_type == type_annotation:
return self
if isinstance(child_type, TypeParent):
result = child_type.find_type_annotation_parent(type_annotation)
if result is not None:
return result

return None
indices = [i for i, x in enumerate(children_types) if x == child]
for index in indices:
self.children[index].type_annotation = new_child

def replace_self_references(self, replacement: FakeAnnotation) -> list[TypeParent]:
"""
Replace self references with a new type annotation to avoid recursion.
"""
result: list[TypeParent] = []
while True:
parent = self.find_type_annotation_parent(self)
if parent is None:
return result

parent.replace_child(self, replacement)
result.append(parent)
return self
13 changes: 0 additions & 13 deletions mypy_boto3_builder/type_annotations/type_union.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,19 +92,6 @@ def get_children_types(self) -> set[FakeAnnotation]:
result.update(child.iterate_types())
return result

def get_sortable_children(self) -> list[TypeDefSortable]:
"""
Extract required TypeDefSortable list from attributes.
"""
result: list[TypeDefSortable] = []
children_types = self.get_children_types()
for type_annotation in children_types:
if not isinstance(type_annotation, TypeDefSortable):
continue
result.append(type_annotation)

return result

def get_children_literals(self, processed: Iterable[str] = ()) -> set[TypeLiteral]:
"""
Extract required TypeLiteral list from attributes.
Expand Down
8 changes: 4 additions & 4 deletions tests/type_annotations/test_type_subscript.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,12 @@ def test_is_list(self) -> None:
def test_copy(self) -> None:
assert self.result.copy().parent == Type.Dict

def test_find_type_annotation_parent(self) -> None:
def test_find_type_annotation_parents(self) -> None:
inner = TypeSubscript(Type.List, [Type.int])
outer = TypeSubscript(Type.Dict, [Type.str, inner])
assert outer.find_type_annotation_parent(Type.int) == inner
assert outer.find_type_annotation_parent(Type.str) == outer
assert outer.find_type_annotation_parent(Type.List) is None
assert outer.find_type_annotation_parents(Type.int) == {inner}
assert outer.find_type_annotation_parents(Type.str) == {outer}
assert outer.find_type_annotation_parents(Type.List) == set()

def test_replace_child(self) -> None:
inner = TypeSubscript(Type.List, [Type.int])
Expand Down

0 comments on commit 02e99ac

Please sign in to comment.