Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add typing to proto.Message based class attributes #1474

Merged
merged 10 commits into from
Nov 9, 2022
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
{% if "ListOperations" in api.mixin_api_methods %}
def list_operations(
self,
request: operations_pb2.ListOperationsRequest = None,
request: Optional[operations_pb2.ListOperationsRequest] = None,
*,
retry: OptionalRetry = gapic_v1.method.DEFAULT,
timeout: float = None,
timeout: Optional[float] = None,
metadata: Sequence[Tuple[str, str]] = (),
) -> operations_pb2.ListOperationsResponse:
r"""Lists operations that match the specified filter in the request.
Expand Down Expand Up @@ -57,10 +57,10 @@
{% if "GetOperation" in api.mixin_api_methods %}
def get_operation(
self,
request: operations_pb2.GetOperationRequest = None,
request: Optional[operations_pb2.GetOperationRequest] = None,
*,
retry: OptionalRetry = gapic_v1.method.DEFAULT,
timeout: float = None,
timeout: Optional[float] = None,
metadata: Sequence[Tuple[str, str]] = (),
) -> operations_pb2.Operation:
r"""Gets the latest state of a long-running operation.
Expand Down Expand Up @@ -110,10 +110,10 @@
{% if "DeleteOperation" in api.mixin_api_methods %}
def delete_operation(
self,
request: operations_pb2.DeleteOperationRequest = None,
request: Optional[operations_pb2.DeleteOperationRequest] = None,
*,
retry: OptionalRetry = gapic_v1.method.DEFAULT,
timeout: float = None,
timeout: Optional[float] = None,
metadata: Sequence[Tuple[str, str]] = (),
) -> None:
r"""Deletes a long-running operation.
Expand Down Expand Up @@ -163,10 +163,10 @@
{% if "CancelOperation" in api.mixin_api_methods %}
def cancel_operation(
self,
request: operations_pb2.CancelOperationRequest = None,
request: Optional[operations_pb2.CancelOperationRequest] = None,
*,
retry: OptionalRetry = gapic_v1.method.DEFAULT,
timeout: float = None,
timeout: Optional[float] = None,
metadata: Sequence[Tuple[str, str]] = (),
) -> None:
r"""Starts asynchronous cancellation on a long-running operation.
Expand Down Expand Up @@ -215,10 +215,10 @@
{% if "WaitOperation" in api.mixin_api_methods %}
def wait_operation(
self,
request: operations_pb2.WaitOperationRequest = None,
request: Optional[operations_pb2.WaitOperationRequest] = None,
*,
retry: OptionalRetry = gapic_v1.method.DEFAULT,
timeout: float = None,
timeout: Optional[float] = None,
metadata: Sequence[Tuple[str, str]] = (),
) -> operations_pb2.Operation:
r"""Waits until the specified long-running operation is done or reaches at most
Expand Down Expand Up @@ -274,10 +274,10 @@
{% if "SetIamPolicy" in api.mixin_api_methods %}
def set_iam_policy(
self,
request: iam_policy_pb2.SetIamPolicyRequest = None,
request: Optional[iam_policy_pb2.SetIamPolicyRequest] = None,
*,
retry: OptionalRetry = gapic_v1.method.DEFAULT,
timeout: float = None,
timeout: Optional[float] = None,
metadata: Sequence[Tuple[str, str]] = (),
) -> policy_pb2.Policy:
r"""Sets the IAM access control policy on the specified function.
Expand Down Expand Up @@ -393,10 +393,10 @@
{% if "GetIamPolicy" in api.mixin_api_methods %}
def get_iam_policy(
self,
request: iam_policy_pb2.GetIamPolicyRequest = None,
request: Optional[iam_policy_pb2.GetIamPolicyRequest] = None,
*,
retry: OptionalRetry = gapic_v1.method.DEFAULT,
timeout: float = None,
timeout: Optional[float] = None,
metadata: Sequence[Tuple[str, str]] = (),
) -> policy_pb2.Policy:
r"""Gets the IAM access control policy for a function.
Expand Down Expand Up @@ -513,10 +513,10 @@
{% if "TestIamPermissions" in api.mixin_api_methods %}
def test_iam_permissions(
self,
request: iam_policy_pb2.TestIamPermissionsRequest = None,
request: Optional[iam_policy_pb2.TestIamPermissionsRequest] = None,
*,
retry: OptionalRetry = gapic_v1.method.DEFAULT,
timeout: float = None,
timeout: Optional[float] = None,
metadata: Sequence[Tuple[str, str]] = (),
) -> iam_policy_pb2.TestIamPermissionsResponse:
r"""Tests the specified IAM permissions against the IAM access control
Expand Down Expand Up @@ -575,10 +575,10 @@
{% if "GetLocation" in api.mixin_api_methods %}
def get_location(
self,
request: locations_pb2.GetLocationRequest = None,
request: Optional[locations_pb2.GetLocationRequest] = None,
*,
retry: OptionalRetry = gapic_v1.method.DEFAULT,
timeout: float = None,
timeout: Optional[float] = None,
metadata: Sequence[Tuple[str, str]] = (),
) -> locations_pb2.Location:
r"""Gets information about a location.
Expand Down Expand Up @@ -628,10 +628,10 @@
{% if "ListLocations" in api.mixin_api_methods %}
def list_locations(
self,
request: locations_pb2.ListLocationsRequest = None,
request: Optional[locations_pb2.ListLocationsRequest] = None,
*,
retry: OptionalRetry = gapic_v1.method.DEFAULT,
timeout: float = None,
timeout: Optional[float] = None,
metadata: Sequence[Tuple[str, str]] = (),
) -> locations_pb2.ListLocationsResponse:
r"""Lists information about the supported locations for this service.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from collections import OrderedDict
import os
import re
from typing import Callable, Dict, Mapping, Optional, {% if service.any_server_streaming %}Iterable, {% endif %}{% if service.any_client_streaming %}Iterator, {% endif %}Sequence, Tuple, Type, Union, cast
from typing import Callable, Dict, Mapping, MutableMapping, MutableSequence, Optional, {% if service.any_server_streaming %}Iterable, {% endif %}{% if service.any_client_streaming %}Iterator, {% endif %}Sequence, Tuple, Type, Union, cast
import pkg_resources
{% if service.any_deprecated %}
import warnings
Expand Down Expand Up @@ -68,7 +68,7 @@ class {{ service.client_name }}Meta(type):
{% endif %}

def get_transport_class(cls,
label: str = None,
label: Optional[str] = None,
) -> Type[{{ service.name }}Transport]:
"""Returns an appropriate transport class.

Expand Down Expand Up @@ -218,7 +218,7 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):

def __init__(self, *,
credentials: Optional[ga_credentials.Credentials] = None,
transport: Union[str, {{ service.name }}Transport, None] = None,
transport: Optional[Union[str, {{ service.name }}Transport]] = None,
client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None,
client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
) -> None:
Expand All @@ -230,7 +230,7 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
credentials identify the application to the service; if none
are specified, the client will attempt to ascertain the
credentials from the environment.
transport (Union[str, {{ service.name }}Transport]): The
transport (Optional[Union[str, {{ service.name }}Transport]]): The
transport to use. If set to None, a transport is chosen
automatically.
{% if 'rest' in opts.transport and not opts.rest_numeric_enums %}
Expand Down Expand Up @@ -340,17 +340,17 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
def {{ method.name|snake_case }}(self,
{% endif %}{# Extended Operations LRO #}
{% if not method.client_streaming %}
request: Union[{{ method.input.ident }}, dict] = None,
request: Optional[Union[{{ method.input.ident }}, dict]] = None,
*,
{% for field in method.flattened_fields.values() %}
{{ field.name }}: {{ field.ident }} = None,
{{ field.name }}: Optional[{{ field.ident }}] = None,
{% endfor %}
{% else %}
requests: Iterator[{{ method.input.ident }}] = None,
requests: Optional[Iterator[{{ method.input.ident }}]] = None,
*,
{% endif %}
retry: OptionalRetry = gapic_v1.method.DEFAULT,
timeout: float = None,
timeout: Optional[float] = None,
metadata: Sequence[Tuple[str, str]] = (),
{% if not method.server_streaming %}
) -> {{ method.client_output.ident }}:
Expand All @@ -361,7 +361,7 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):

Args:
{% if not method.client_streaming %}
request (Union[{{ method.input.ident.sphinx }}, dict]):
request (Union[{{ method.input.ident.sphinx }}, dict, None]):
The request object.{{ " " }}
{{- method.input.meta.doc|wrap(width=72, offset=36, indent=16) }}
{% for key, field in method.flattened_fields.items() %}
Expand Down Expand Up @@ -516,10 +516,10 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
{% if opts.add_iam_methods %}
def set_iam_policy(
self,
request: iam_policy_pb2.SetIamPolicyRequest = None,
request: Optional[iam_policy_pb2.SetIamPolicyRequest] = None,
*,
retry: OptionalRetry = gapic_v1.method.DEFAULT,
timeout: float = None,
timeout: Optional[float] = None,
metadata: Sequence[Tuple[str, str]] = (),
) -> policy_pb2.Policy:
r"""Sets the IAM access control policy on the specified function.
Expand Down Expand Up @@ -633,10 +633,10 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):

def get_iam_policy(
self,
request: iam_policy_pb2.GetIamPolicyRequest = None,
request: Optional[iam_policy_pb2.GetIamPolicyRequest] = None,
*,
retry: OptionalRetry = gapic_v1.method.DEFAULT,
timeout: float = None,
timeout: Optional[float] = None,
metadata: Sequence[Tuple[str, str]] = (),
) -> policy_pb2.Policy:
r"""Gets the IAM access control policy for a function.
Expand Down Expand Up @@ -750,10 +750,10 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):

def test_iam_permissions(
self,
request: iam_policy_pb2.TestIamPermissionsRequest = None,
request: Optional[iam_policy_pb2.TestIamPermissionsRequest] = None,
*,
retry: OptionalRetry = gapic_v1.method.DEFAULT,
timeout: float = None,
timeout: Optional[float] = None,
metadata: Sequence[Tuple[str, str]] = (),
) -> iam_policy_pb2.TestIamPermissionsResponse:
r"""Tests the specified IAM permissions against the IAM access control
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ class {{ method.name }}Pager:
def get(self, key: str) -> Optional[{{ method.paged_result_field.type.fields.get('value').ident }}]:
return self._response.{{ method.paged_result_field.name }}.get(key)
{% else %}
def __iter__(self) -> {{ method.paged_result_field.ident | replace('Sequence', 'Iterator') }}:
def __iter__(self) -> {{ method.paged_result_field.ident | replace('MutableSequence', 'Iterator') }}:
for page in self.pages:
yield from page.{{ method.paged_result_field.name }}
{% endif %}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class {{ service.name }}Transport(abc.ABC):
def __init__(
self, *,
host: str = DEFAULT_HOST,
credentials: ga_credentials.Credentials = None,
credentials: Optional[ga_credentials.Credentials] = None,
credentials_file: Optional[str] = None,
scopes: Optional[Sequence[str]] = None,
quota_project_id: Optional[str] = None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,14 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):

def __init__(self, *,
host: str{% if service.host %} = '{{ service.host }}'{% endif %},
credentials: ga_credentials.Credentials = None,
credentials_file: str = None,
scopes: Sequence[str] = None,
channel: grpc.Channel = None,
api_mtls_endpoint: str = None,
client_cert_source: Callable[[], Tuple[bytes, bytes]] = None,
ssl_channel_credentials: grpc.ChannelCredentials = None,
client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None,
credentials: Optional[ga_credentials.Credentials] = None,
credentials_file: Optional[str] = None,
scopes: Optional[Sequence[str]] = None,
channel: Optional[grpc.Channel] = None,
api_mtls_endpoint: Optional[str] = None,
client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None,
ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None,
client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None,
quota_project_id: Optional[str] = None,
client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
always_use_jwt_access: Optional[bool] = False,
Expand Down Expand Up @@ -186,8 +186,8 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):
@classmethod
def create_channel(cls,
host: str{% if service.host %} = '{{ service.host }}'{% endif %},
credentials: ga_credentials.Credentials = None,
credentials_file: str = None,
credentials: Optional[ga_credentials.Credentials] = None,
credentials_file: Optional[str] = None,
scopes: Optional[Sequence[str]] = None,
quota_project_id: Optional[str] = None,
**kwargs) -> grpc.Channel:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,15 +140,15 @@ class {{service.name}}RestTransport({{service.name}}Transport):
{# TODO(yon-mg): handle mtls stuff if that is relevant for rest transport #}
def __init__(self, *,
host: str{% if service.host %} = '{{ service.host }}'{% endif %},
credentials: ga_credentials.Credentials=None,
credentials_file: str=None,
scopes: Sequence[str]=None,
client_cert_source_for_mtls: Callable[[
], Tuple[bytes, bytes]]=None,
quota_project_id: Optional[str]=None,
client_info: gapic_v1.client_info.ClientInfo=DEFAULT_CLIENT_INFO,
always_use_jwt_access: Optional[bool]=False,
url_scheme: str='https',
credentials: Optional[ga_credentials.Credentials] = None,
credentials_file: Optional[str] = None,
scopes: Optional[Sequence[str]] = None,
client_cert_source_for_mtls: Optional[Callable[[
], Tuple[bytes, bytes]]] = None,
quota_project_id: Optional[str] = None,
client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
always_use_jwt_access: Optional[bool] = False,
url_scheme: str = 'https',
interceptor: Optional[{{ service.name }}RestInterceptor] = None,
) -> None:
"""Instantiate the transport.
Expand Down Expand Up @@ -282,9 +282,9 @@ class {{service.name}}RestTransport({{service.name}}Transport):

def __call__(self,
request: {{method.input.ident}}, *,
retry: OptionalRetry=gapic_v1.method.DEFAULT,
timeout: float=None,
metadata: Sequence[Tuple[str, str]]=(),
retry: OptionalRetry = gapic_v1.method.DEFAULT,
timeout: Optional[float] = None,
metadata: Sequence[Tuple[str, str]] = (),
){% if not method.void %} -> {% if not method.server_streaming %}{{method.output.ident}}{% else %}rest_streaming.ResponseIterator{% endif %}{% endif %}:
{% if method.http_options and not method.client_streaming %}
r"""Call the {{- ' ' -}}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

{% with p = proto.disambiguate('proto') %}
{% if proto.messages|length or proto.all_enums|length %}
from typing import MutableMapping, MutableSequence

import proto{% if p != 'proto' %} as {{ p }}{% endif %} # type: ignore
{% endif %}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class {{ message.name }}({{ p }}.Message):
{% for field in message.fields.values() %}
{% if field.map %}
{% with key_field = field.message.fields['key'], value_field = field.message.fields['value'] %}
{{ field.name }} = {{ p }}.MapField(
{{ field.name }}: MutableMapping[{{ key_field.type.ident.rel(message.ident) }}, {{ value_field.type.ident.rel(message.ident) }}] = {{ p }}.MapField(
{{ p }}.{{ key_field.proto_type }},
{{ p }}.{{ value_field.proto_type }},
number={{ field.number }},
Expand All @@ -61,7 +61,7 @@ class {{ message.name }}({{ p }}.Message):
)
{% endwith %}
{% else %}
{{ field.name }} = {{ p }}.{% if field.repeated %}Repeated{% endif %}Field(
{{ field.name }}: {% if field.is_primitive %}{{ field.ident }}{% else %}{% if field.repeated %}MutableSequence[{% endif %}{{ field.type.ident.rel(message.ident) }}{% if field.repeated %}]{% endif %}{% endif %} = {{ p }}.{% if field.repeated %}Repeated{% endif %}Field(
{{ p }}.{{ field.proto_type }},
number={{ field.number }},
{% if field.proto3_optional %}
Expand Down
4 changes: 2 additions & 2 deletions gapic/generator/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import os
import pathlib
import typing
from typing import Any, DefaultDict, Dict, Mapping, Tuple
from typing import Any, DefaultDict, Dict, Mapping, Optional, Tuple
from hashlib import sha256
from collections import OrderedDict, defaultdict
from gapic.samplegen_utils.utils import coerce_response_name, is_valid_sample_cfg, render_format_string
Expand Down Expand Up @@ -362,7 +362,7 @@ def _get_file(
return {fn: cgr_file}

def _get_filename(
self, template_name: str, *, api_schema: api.API, context: dict = None,
self, template_name: str, *, api_schema: api.API, context: Optional[dict] = None,
) -> str:
"""Return the appropriate output filename for this template.

Expand Down
4 changes: 2 additions & 2 deletions gapic/samplegen/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import os
import time
from typing import Tuple
from typing import Optional, Tuple

from gapic.samplegen_utils import (types, yaml)
from gapic.utils import case
Expand Down Expand Up @@ -45,7 +45,7 @@ def generate(
api_schema,
*,
environment: yaml.Map = PYTHON3_ENVIRONMENT,
manifest_time: int = None
manifest_time: Optional[int] = None
) -> Tuple[str, yaml.Doc]:
"""Generate a samplegen manifest for use by sampletest

Expand Down
Loading