diff --git a/gapic/schema/wrappers.py b/gapic/schema/wrappers.py index aecda19f44..24422244a9 100644 --- a/gapic/schema/wrappers.py +++ b/gapic/schema/wrappers.py @@ -925,9 +925,21 @@ def query_params(self) -> Set[str]: return set(self.input.fields) - params + @property + def body_fields(self) -> Mapping[str, Field]: + bindings = self.http_options + if bindings and bindings[0].body and bindings[0].body != "*": + return self._fields_mapping([bindings[0].body]) + return {} + # TODO(yon-mg): refactor as there may be more than one method signature @utils.cached_property def flattened_fields(self) -> Mapping[str, Field]: + signatures = self.options.Extensions[client_pb2.method_signature] + return self._fields_mapping(signatures) + + # TODO(yon-mg): refactor as there may be more than one method signature + def _fields_mapping(self, signatures) -> Mapping[str, Field]: """Return the signature defined for this method.""" cross_pkg_request = self.input.ident.package != self.ident.package @@ -946,7 +958,6 @@ def filter_fields(sig: str) -> Iterable[Tuple[str, Field]]: yield name, field - signatures = self.options.Extensions[client_pb2.method_signature] answer: Dict[str, Field] = collections.OrderedDict( name_and_field for sig in signatures diff --git a/gapic/templates/%namespace/%name_%version/%sub/services/%service/client.py.j2 b/gapic/templates/%namespace/%name_%version/%sub/services/%service/client.py.j2 index 02bfe76135..809f728dd1 100644 --- a/gapic/templates/%namespace/%name_%version/%sub/services/%service/client.py.j2 +++ b/gapic/templates/%namespace/%name_%version/%sub/services/%service/client.py.j2 @@ -306,9 +306,7 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta): client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, - {% if "grpc" in opts.transport %} always_use_jwt_access=True, - {% endif %} ) diff --git a/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/rest.py.j2 b/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/rest.py.j2 index 994e30d9e4..d85695b76f 100644 --- a/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/rest.py.j2 +++ b/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/rest.py.j2 @@ -1,14 +1,16 @@ -from google.auth.transport.requests import AuthorizedSession +from google.auth.transport.requests import AuthorizedSession # type: ignore import json # type: ignore import grpc # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore -from google.auth import credentials as ga_credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth import credentials as ga_credentials # type: ignore from google.api_core import exceptions as core_exceptions # type: ignore -from google.api_core import retry as retries # type: ignore -from google.api_core import rest_helpers # type: ignore -from google.api_core import path_template # type: ignore -from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.api_core import rest_helpers # type: ignore +from google.api_core import path_template # type: ignore +from google.api_core import gapic_v1 # type: ignore +{% if service.has_lro %} from google.api_core import operations_v1 +{% endif %} from requests import __version__ as requests_version from typing import Callable, Dict, Optional, Sequence, Tuple, Union import warnings diff --git a/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 b/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 index d4ec2c3142..53ef176529 100644 --- a/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 +++ b/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 @@ -1106,7 +1106,14 @@ def test_{{ method.name|snake_case }}_rest(transport: str = 'rest', request_type ) # send a request that will satisfy transcoding - request = request_type({{ method.http_options[0].sample_request}}) + request_init = {{ method.http_options[0].sample_request}} + {% for field in method.body_fields.values() %} + {% if not field.oneof or field.proto3_optional %} + {# ignore oneof fields that might conflict with sample_request #} + request_init["{{ field.name }}"] = {{ field.mock_value }} + {% endif %} + {% endfor %} + request = request_type(request_init) {% if method.client_streaming %} requests = [request] {% endif %} @@ -2419,6 +2426,7 @@ async def test_test_iam_permissions_from_dict_async(): {% endif %} +{% if 'grpc' in opts.transport %} @pytest.mark.asyncio async def test_transport_close_async(): client = {{ service.async_client_name }}( @@ -2429,6 +2437,7 @@ async def test_transport_close_async(): async with client: close.assert_not_called() close.assert_called_once() +{% endif %} def test_transport_close(): transports = { diff --git a/tests/unit/schema/wrappers/test_method.py b/tests/unit/schema/wrappers/test_method.py index c6a81d9128..d377375036 100644 --- a/tests/unit/schema/wrappers/test_method.py +++ b/tests/unit/schema/wrappers/test_method.py @@ -330,6 +330,35 @@ def test_method_path_params_no_http_rule(): assert method.path_params == [] +def test_body_fields(): + http_rule = http_pb2.HttpRule( + post='/v1/{arms_shape=arms/*}/squids', + body='mantle' + ) + + mantle_stuff = make_field(name='mantle_stuff', type=9) + message = make_message('Mantle', fields=(mantle_stuff,)) + mantle = make_field('mantle', type=11, type_name='Mantle', message=message) + arms_shape = make_field('arms_shape', type=9) + input_message = make_message('Squid', fields=(mantle, arms_shape)) + method = make_method( + 'PutSquid', input_message=input_message, http_rule=http_rule) + assert set(method.body_fields) == {'mantle'} + mock_value = method.body_fields['mantle'].mock_value + assert mock_value == "baz.Mantle(mantle_stuff='mantle_stuff_value')" + + +def test_body_fields_no_body(): + http_rule = http_pb2.HttpRule( + post='/v1/{arms_shape=arms/*}/squids', + ) + + method = make_method( + 'PutSquid', http_rule=http_rule) + + assert not method.body_fields + + def test_method_http_options(): verbs = [ 'get', @@ -363,7 +392,7 @@ def test_method_http_options_no_http_rule(): assert method.path_params == [] -def test_method_http_options_body(): +def test_method_http_options_body_star(): http_rule = http_pb2.HttpRule( post='/v1/{parent=projects/*}/topics', body='*' @@ -376,6 +405,19 @@ def test_method_http_options_body(): }] +def test_method_http_options_body_field(): + http_rule = http_pb2.HttpRule( + post='/v1/{parent=projects/*}/topics', + body='body_field' + ) + method = make_method('DoSomething', http_rule=http_rule) + assert [dataclasses.asdict(http) for http in method.http_options] == [{ + 'method': 'post', + 'uri': '/v1/{parent=projects/*}/topics', + 'body': 'body_field' + }] + + def test_method_http_options_additional_bindings(): http_rule = http_pb2.HttpRule( post='/v1/{parent=projects/*}/topics',