Skip to content

Commit

Permalink
- Support base classes with typescript syntax
Browse files Browse the repository at this point in the history
- Fixed paginators inheritance
- Changed Paginator and Waiter RequestTypeDef names
  • Loading branch information
vemel committed Dec 19, 2024
1 parent 18b1372 commit d0c704f
Show file tree
Hide file tree
Showing 18 changed files with 223 additions and 183 deletions.
1 change: 1 addition & 0 deletions mypy_boto3_builder/parsers/service_package_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ def _parse_paginators(self) -> list[Paginator]:
paginator_name=paginator_name,
operation_name=operation_name,
service_name=self.service_name,
return_type=self.shape_parser.get_paginator_subscript(paginator_name),
)

paginate_method = self.shape_parser.get_paginate_method(paginator_name)
Expand Down
32 changes: 20 additions & 12 deletions mypy_boto3_builder/parsers/shape_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
StringShape,
StructureShape,
)
from botocore.paginate import PageIterator

from mypy_boto3_builder.boto3_ports.model import Collection as Boto3Collection
from mypy_boto3_builder.boto3_ports.model import ResourceModel
Expand Down Expand Up @@ -652,6 +653,21 @@ def parse_shape(

return result

def get_paginator_subscript(self, paginator_name: str) -> FakeAnnotation:
"""
Get Paginator return class.
"""
operation_name = paginator_name
operation_shape = self._get_operation(operation_name)
if operation_shape.output_shape is None:
return Type.Any

return self._parse_return_type(
"Paginator",
"paginate",
operation_shape.output_shape,
)

def get_paginate_method(self, paginator_name: str) -> Method:
"""
Get Paginator `paginate` method.
Expand Down Expand Up @@ -691,16 +707,8 @@ def get_paginate_method(self, paginator_name: str) -> Method:
arguments.extend(self._get_kw_flags("paginate", shape_arguments))
arguments.extend(shape_arguments)

return_type: FakeAnnotation = Type.none
if operation_shape.output_shape is not None:
page_iterator_import = InternalImport("_PageIterator")
return_item = self._parse_return_type(
"Paginator",
"paginate",
operation_shape.output_shape,
)
return_type = TypeSubscript(page_iterator_import, [return_item])

subscript = self.get_paginator_subscript(operation_name)
return_type = TypeSubscript(ExternalImport.from_class(PageIterator), [subscript])
method = Method(
name="paginate",
arguments=arguments,
Expand All @@ -711,7 +719,7 @@ def get_paginate_method(self, paginator_name: str) -> Method:
method.create_request_type_annotation(
self._get_typed_dict_name(
operation_shape.input_shape,
postfix=f"{paginator_name}Paginate",
postfix="Paginate",
),
)
return method
Expand Down Expand Up @@ -764,7 +772,7 @@ def get_wait_method(self, waiter_name: str) -> Method:
method.create_request_type_annotation(
self._get_typed_dict_name(
operation_shape.input_shape,
postfix=f"{waiter_name}Wait",
postfix="Wait",
),
)
return method
Expand Down
5 changes: 4 additions & 1 deletion mypy_boto3_builder/postprocessors/aio_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from botocore.client import BaseClient
from botocore.config import Config
from botocore.eventstream import EventStream
from botocore.paginate import Paginator
from botocore.paginate import PageIterator, Paginator
from botocore.response import StreamingBody
from botocore.waiter import Waiter

Expand All @@ -31,6 +31,9 @@
ExternalImport.from_class(Paginator): ExternalImport(
Import.aiobotocore + "paginate", "AioPaginator"
),
ExternalImport.from_class(PageIterator): ExternalImport(
Import.aiobotocore + "paginate", "AioPageIterator"
),
ExternalImport.from_class(BaseClient): ExternalImport(
Import.aiobotocore + "client", "AioBaseClient"
),
Expand Down
11 changes: 0 additions & 11 deletions mypy_boto3_builder/postprocessors/aiobotocore.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ def process_package(self) -> None:
Convert all methods to asynchronous.
"""
self._make_async_client()
self._make_async_paginators()
self._make_async_waiters()
self._make_async_service_resource()
self._make_async_collections()
Expand Down Expand Up @@ -96,16 +95,6 @@ def _make_async_collections(self) -> None:
for collection in self.package.service_resource.collections:
self._make_async_collection(collection)

def _make_async_paginators(self) -> None:
for paginator in self.package.paginators:
paginate_method = paginator.get_method("paginate")
if not isinstance(paginate_method.return_type, TypeSubscript):
raise TypeError(
f"{paginator.name}.paginate method return type is not TypeSubscript:"
f" {paginate_method.return_type.render()}",
)
paginate_method.return_type.parent = Type.AsyncIterator

def _make_async_waiters(self) -> None:
for waiter in self.package.waiters:
for method in waiter.methods:
Expand Down
59 changes: 59 additions & 0 deletions mypy_boto3_builder/structures/base_class.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
"""
Base class for all structures that can be rendered to a class.
Copyright 2024 Vlad Emelianov
"""

from collections.abc import Iterator

from mypy_boto3_builder.import_helpers.import_helper import Import
from mypy_boto3_builder.type_annotations.external_import import ExternalImport
from mypy_boto3_builder.type_annotations.fake_annotation import FakeAnnotation
from mypy_boto3_builder.utils.type_checks import is_type_subscript


class BaseClass:
"""
Base class for CLassRecord.
"""

def __init__(self, name: str, annotation: FakeAnnotation) -> None:
self.name = name
self.annotation = annotation

def render(self) -> str:
"""
Render for usage as class base.
If annotation is a subscript, return name.
"""
if is_type_subscript(self.annotation):
return self.name

return self.annotation.render()

def render_definition(self) -> str:
"""
Render definition that works in runtime and TYPE_CHECKING.
Returns empty string if annotation is not a subscript.
"""
annotation = self.annotation
if not is_type_subscript(annotation):
return ""

return (
"if TYPE_CHECKING:\n"
f" {self.name} = {annotation.render_definition()}\n"
"else:\n"
f" {self.name} = {annotation.parent.render_definition()}"
" # type: ignore[assignment]\n"
)

def iterate_types(self) -> Iterator[FakeAnnotation]:
"""
Iterate over definition type annotations.
"""
yield from self.annotation.iterate_types()
if is_type_subscript(self.annotation):
yield ExternalImport(Import.typing, "TYPE_CHECKING")
5 changes: 4 additions & 1 deletion mypy_boto3_builder/structures/class_record.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from mypy_boto3_builder.exceptions import StructureError
from mypy_boto3_builder.import_helpers.import_record import ImportRecord
from mypy_boto3_builder.structures.attribute import Attribute
from mypy_boto3_builder.structures.base_class import BaseClass
from mypy_boto3_builder.structures.method import Method
from mypy_boto3_builder.type_annotations.fake_annotation import FakeAnnotation
from mypy_boto3_builder.utils.strings import xform_name
Expand All @@ -31,7 +32,9 @@ def __init__(
self.name = name
self.methods = list(methods)
self.attributes = list(attributes)
self.bases: list[FakeAnnotation] = list(bases)
self.bases = tuple(
BaseClass(f"_{name}Base{index or ''}", base) for index, base in enumerate(bases)
)
self.use_alias = use_alias
self.docstring = ""

Expand Down
3 changes: 1 addition & 2 deletions mypy_boto3_builder/structures/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,12 @@ class Client(ClassRecord):
}

def __init__(self, name: str, service_name: ServiceName) -> None:
super().__init__(name=name)
super().__init__(name=name, bases=[ExternalImport.from_class(BaseClient)])
self.service_name = service_name
self.exceptions_class = ClassRecord(
name="Exceptions",
bases=(ExternalImport.from_class(BaseClientExceptions),),
)
self.bases = [ExternalImport.from_class(BaseClient)]

def __hash__(self) -> int:
"""
Expand Down
6 changes: 5 additions & 1 deletion mypy_boto3_builder/structures/paginator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@
from mypy_boto3_builder.structures.class_record import ClassRecord
from mypy_boto3_builder.structures.method import Method
from mypy_boto3_builder.type_annotations.external_import import ExternalImport
from mypy_boto3_builder.type_annotations.fake_annotation import FakeAnnotation
from mypy_boto3_builder.type_annotations.type import Type
from mypy_boto3_builder.type_annotations.type_literal import TypeLiteral
from mypy_boto3_builder.type_annotations.type_subscript import TypeSubscript


@functools.total_ordering
Expand All @@ -33,10 +35,12 @@ def __init__(
paginator_name: str,
operation_name: str,
service_name: ServiceName,
return_type: FakeAnnotation | None = None,
) -> None:
base_class = ExternalImport.from_class(BotocorePaginator)
super().__init__(
name=name,
bases=[ExternalImport.from_class(BotocorePaginator)],
bases=[base_class if not return_type else TypeSubscript(base_class, [return_type])],
)
self.operation_name = operation_name
self.paginator_name = paginator_name
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class AioClientCreator(ClientCreator):
) -> AioBaseClient: ...

class AioBaseClient(BaseClient):
def get_paginator(self, operation_name: str) -> AioPaginator: ...
def get_paginator(self, operation_name: str) -> AioPaginator[Any]: ...
def get_waiter(self, waiter_name: str) -> AIOWaiter: ...
async def __aenter__(self: _R) -> _R: ...
async def __aexit__(
Expand Down
27 changes: 14 additions & 13 deletions mypy_boto3_builder/stubs_static/types-aiobotocore/paginate.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,24 @@ Type annotations for aiobotocore.paginate module.
Copyright 2024 Vlad Emelianov
"""

from typing import Any, AsyncIterator, Iterator
from typing import Any, AsyncIterator, Generic, TypeVar

from botocore.paginate import PageIterator, Paginator

class AioPageIterator(PageIterator):
def __aiter__(self) -> Any: ...
resume_token: Any
async def __anext__(self) -> Any: ...
_R = TypeVar("_R")

class AioPageIterator(PageIterator[_R], Generic[_R]):
def __aiter__(self) -> AsyncIterator[_R]: ...
async def __anext__(self) -> _R: ...
def result_key_iters(self) -> Any: ...
async def build_full_result(self) -> Any: ...
async def search(self, expression: str) -> Iterator[Any]: ... # type: ignore[override]
async def build_full_result(self) -> dict[str, Any]: ... # type: ignore[override]
def search(self, expression: str) -> AsyncIterator[_R]: ... # type: ignore[override]

class AioPaginator(Paginator):
PAGE_ITERATOR_CLS: type[AioPageIterator]
class AioPaginator(Paginator[_R], Generic[_R]):
PAGE_ITERATOR_CLS: type[AioPageIterator[Any]] # type: ignore[override]

class ResultKeyIterator:
class ResultKeyIterator(Generic[_R]):
result_key: str
def __init__(self, pages_iterator: PageIterator, result_key: str) -> None: ...
def __aiter__(self) -> AsyncIterator[Any]: ...
async def __anext__(self) -> Any: ...
def __init__(self, pages_iterator: PageIterator[_R], result_key: str) -> None: ...
def __aiter__(self) -> AsyncIterator[_R]: ...
async def __anext__(self) -> _R: ...
35 changes: 22 additions & 13 deletions mypy_boto3_builder/stubs_static/types-aiobotocore/signers.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -40,20 +40,21 @@ class AioRequestSigner(RequestSigner):
expires_in: int = ...,
region_name: str | None = ...,
signing_name: str | None = ...,
) -> Any: ...
) -> str: ...

def add_generate_db_auth_token(class_attributes: Any, **kwargs: Any) -> None: ...
def add_generate_db_auth_token(class_attributes: dict[str, Any], **kwargs: Any) -> None: ...
def add_dsql_generate_db_auth_token_methods(
class_attributes: dict[str, Any], **kwargs: Any
) -> None: ...
async def generate_db_auth_token(
self: Any, DBHostname: Any, Port: Any, DBUsername: Any, Region: Any | None = ...
) -> Any: ...
def add_generate_presigned_url(class_attributes: Any, **kwargs: Any) -> None: ...
async def generate_presigned_url(
self: Any,
ClientMethod: str,
Params: Any | None = ...,
ExpiresIn: int = ...,
HttpMethod: Any | None = ...,
) -> Any: ...
self: Any, DBHostname: str, Port: int, DBUsername: str, Region: str | None = ...
) -> str: ...
async def dsql_generate_db_connect_auth_token(
self: Any, Hostname: str, Region: str | None = ..., ExpiresIn: int = ...
) -> str: ...
async def dsql_generate_db_connect_admin_auth_token(
self: Any, Hostname: str, Region: str | None = ..., ExpiresIn: int = ...
) -> str: ...

class AioS3PostPresigner(S3PostPresigner):
async def generate_presigned_post(
Expand All @@ -62,9 +63,17 @@ class AioS3PostPresigner(S3PostPresigner):
fields: Any | None = ...,
conditions: Any | None = ...,
expires_in: int = ...,
region_name: Any | None = ...,
region_name: str | None = ...,
) -> Any: ...

def add_generate_presigned_url(class_attributes: dict[str, Any], **kwargs: Any) -> None: ...
async def generate_presigned_url(
self: Any,
ClientMethod: str,
Params: Any | None = ...,
ExpiresIn: int = ...,
HttpMethod: str | None = ...,
) -> Any: ...
def add_generate_presigned_post(class_attributes: Any, **kwargs: Any) -> None: ...
async def generate_presigned_post(
self: Any,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,6 @@ Usage::

{{ copyright }}
"""
from typing import TypeVar, Generic, Iterator

from botocore.paginate import PageIterator

{% for import_record in package.get_paginator_required_import_records() -%}
{{ import_record -}}{{ "\n" -}}
{% endfor -%}
Expand All @@ -45,17 +41,6 @@ __all__ = (
{% endfor -%}
)


_ItemTypeDef = TypeVar("_ItemTypeDef")


class _PageIterator(PageIterator, Generic[_ItemTypeDef]):
def __iter__(self) -> Iterator[_ItemTypeDef]:
"""
Proxy method to specify iterator item type.
"""


{{ "\n\n" -}}

{% for paginator in package.paginators -%}
Expand Down
2 changes: 2 additions & 0 deletions mypy_boto3_builder/templates/common/class.py.jinja2
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
{% for base in class.bases -%}{{ base.render_definition() -}}{% endfor -%}

class {{ class.name }}{% if class.bases %}({% for base in class.bases %}{{ base.render() }}{{ ", " if not loop.last else "" -}}{% endfor %}){% endif %}:
{% filter indent(4, True) -%}
{{ "pass" if not class.attributes and not class.methods and not class.docstring else "" -}}
Expand Down
Loading

0 comments on commit d0c704f

Please sign in to comment.