Skip to content

Commit

Permalink
fix: ensure rest unit tests have complete coverage (#1098)
Browse files Browse the repository at this point in the history
* fix: rest paging and lro client tests weren't working.

* fix: fix coverage gaps in rest unit tests.

* fix: refactor required fields code to move update out of static method.

* fix: test that api method with required fields handles them correctly.

* fix: removed extra parens from an expression in a test.

Co-authored-by: Kenneth Bandes <kbandes@google.com>
Co-authored-by: Anthonios Partheniou <partheniou@google.com>
  • Loading branch information
3 people authored Dec 2, 2021
1 parent 956078f commit 0705d9c
Show file tree
Hide file tree
Showing 8 changed files with 174 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,22 @@ class {{service.name}}RestTransport({{service.name}}Transport):
{% endif %}{# service.has_lro #}
{% for method in service.methods.values() %}
{%- if method.http_options and not (method.server_streaming or method.client_streaming) %}

{% if method.input.required_fields %}
__{{ method.name | snake_case }}_required_fields_default_values = {
{% for req_field in method.input.required_fields if req_field.is_primitive %}
"{{ req_field.name | camel_case }}" : {% if req_field.field_pb.default_value is string %}"{{req_field.field_pb.default_value }}"{% else %}{{ req_field.field_pb.default_value }}{% endif %}{# default is str #}
{% endfor %}
}


@staticmethod
def _{{ method.name | snake_case }}_get_unset_required_fields(message_dict):
return {k: v for k, v in {{service.name}}RestTransport.__{{ method.name | snake_case }}_required_fields_default_values.items() if k not in message_dict}


{% endif %}{# required fields #}

def _{{method.name | snake_case}}(self,
request: {{method.input.ident}}, *,
retry: OptionalRetry=gapic_v1.method.DEFAULT,
Expand Down Expand Up @@ -206,21 +222,6 @@ class {{service.name}}RestTransport({{service.name}}Transport):
{% endfor %}
]

{% if method.input.required_fields %}
required_fields = [
# (snake_case_name, camel_case_name)
{% for req_field in method.input.required_fields %}
{% if req_field.is_primitive %}
(
"{{ req_field.name | snake_case }}",
"{{ req_field.name | camel_case }}"
),
{% endif %}{# is primitive #}
{% endfor %}{# required fields #}
]

{% endif %}

request_kwargs = {{method.input.ident}}.to_dict(request)
transcoded_request = path_template.transcode(
http_options, **request_kwargs)
Expand Down Expand Up @@ -254,16 +255,8 @@ class {{service.name}}RestTransport({{service.name}}Transport):
))

{% if method.input.required_fields %}
# Ensure required fields have values in query_params.
# If a required field has a default value, it can get lost
# by the to_json call above.
orig_query_params = transcoded_request["query_params"]
for snake_case_name, camel_case_name in required_fields:
if snake_case_name in orig_query_params:
if camel_case_name not in query_params:
query_params[camel_case_name] = orig_query_params[snake_case_name]

{% endif %}
query_params.update(self._{{ method.name | snake_case }}_get_unset_required_fields(query_params))
{% endif %}{# required fields #}

# Send the request
headers = dict(metadata)
Expand Down
164 changes: 150 additions & 14 deletions gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import mock

import grpc
from grpc.experimental import aio
import json
import math
import pytest
from proto.marshal.rules.dates import DurationRule, TimestampRule
Expand Down Expand Up @@ -1187,6 +1188,7 @@ def test_{{ method_name }}_rest(transport: str = 'rest', request_type={{ method.
{% if "next_page_token" in method.output.fields.values()|map(attribute='name') and not method.paged_result_field %}
{# Cheeser assertion to force code coverage for bad paginated methods #}
assert response.raw_page is response

{% endif %}

# Establish that the response is the type that we expect.
Expand All @@ -1210,6 +1212,130 @@ def test_{{ method_name }}_rest(transport: str = 'rest', request_type={{ method.
{% endif %}


{% if method.input.required_fields %}
def test_{{ method_name }}_rest_required_fields(request_type={{ method.input.ident }}):
transport_class = transports.{{ service.rest_transport_name }}

request_init = {}
{% for req_field in method.input.required_fields if req_field.is_primitive %}
{% if req_field.field_pb.default_value is string %}
request_init["{{ req_field.name }}"] = "{{ req_field.field_pb.default_value }}"
{% else %}
request_init["{{ req_field.name }}"] = {{ req_field.field_pb.default_value }}
{% endif %}{# default is str #}
{% endfor %}
request = request_type(request_init)
jsonified_request = json.loads(request_type.to_json(
request,
including_default_value_fields=False,
use_integers_for_enums=False
))

# verify fields with default values are dropped
{% for req_field in method.input.required_fields if req_field.is_primitive %}
{% set field_name = req_field.name | camel_case %}
assert "{{ field_name }}" not in jsonified_request
{% endfor %}

unset_fields = transport_class._{{ method.name | snake_case }}_get_unset_required_fields(jsonified_request)
jsonified_request.update(unset_fields)

# verify required fields with default values are now present
{% for req_field in method.input.required_fields if req_field.is_primitive %}
{% set field_name = req_field.name | camel_case %}
assert "{{ field_name }}" in jsonified_request
assert jsonified_request["{{ field_name }}"] == request_init["{{ req_field.name }}"]
{% endfor %}

{% for req_field in method.input.required_fields if req_field.is_primitive %}
{% set field_name = req_field.name | camel_case %}
{% set mock_value = req_field.primitive_mock_as_str() %}
jsonified_request["{{ field_name }}"] = {{ mock_value }}
{% endfor %}

unset_fields = transport_class._{{ method.name | snake_case }}_get_unset_required_fields(jsonified_request)
jsonified_request.update(unset_fields)

# verify required fields with non-default values are left alone
{% for req_field in method.input.required_fields if req_field.is_primitive %}
{% set field_name = req_field.name | camel_case %}
{% set mock_value = req_field.primitive_mock_as_str() %}
assert "{{ field_name }}" in jsonified_request
assert jsonified_request["{{ field_name }}"] == {{ mock_value }}
{% endfor %}


client = {{ service.client_name }}(
credentials=ga_credentials.AnonymousCredentials(),
transport='rest',
)
request = request_type(request_init)

# Designate an appropriate value for the returned response.
{% if method.void %}
return_value = None
{% elif method.lro %}
return_value = operations_pb2.Operation(name='operations/spam')
{% elif method.server_streaming %}
return_value = iter([{{ method.output.ident }}()])
{% else %}
return_value = {{ method.output.ident }}()
{% endif %}
# Mock the http request call within the method and fake a response.
with mock.patch.object(Session, 'request') as req:
# We need to mock transcode() because providing default values
# for required fields will fail the real version if the http_options
# expect actual values for those fields.
with mock.patch.object(path_template, 'transcode') as transcode:
# A uri without fields and an empty body will force all the
# request fields to show up in the query_params.
transcode_result = {
'uri': 'v1/sample_method',
'method': "{{ method.http_options[0].method }}",
'query_params': request_init,
}
{% if method.http_options[0].body %}
transcode_result['body'] = {}
{% endif %}
transcode.return_value = transcode_result

response_value = Response()
response_value.status_code = 200
{% if method.void %}
json_return_value = ''
{% elif method.lro %}
json_return_value = json_format.MessageToJson(return_value)
{% else %}
json_return_value = {{ method.output.ident }}.to_json(return_value)
{% endif %}
response_value._content = json_return_value.encode('UTF-8')
req.return_value = response_value

{% if method.client_streaming %}
response = client.{{ method.name|snake_case }}(iter(requests))
{% else %}
response = client.{{ method_name }}(request)
{% endif %}

expected_params = [
{% for req_field in method.input.required_fields if req_field.is_primitive %}
(
"{{ req_field.name }}",
{% if req_field.field_pb.default_value is string %}
"{{ req_field.field_pb.default_value }}"
{% else %}
{{ req_field.field_pb.default_value }}
{% endif %}{# default is str #}
)
{% endfor %}
]
actual_params = req.call_args.kwargs['params']
assert expected_params == actual_params


{% endif %}{# required_fields #}


def test_{{ method_name }}_rest_bad_request(transport: str = 'rest', request_type={{ method.input.ident }}):
client = {{ service.client_name }}(
credentials=ga_credentials.AnonymousCredentials(),
Expand Down Expand Up @@ -1325,9 +1451,10 @@ def test_{{ method_name }}_rest_flattened_error(transport: str = 'rest'):


{% if method.paged_result_field %}
def test_{{ method_name }}_rest_pager():
def test_{{ method_name }}_rest_pager(transport: str = 'rest'):
client = {{ service.client_name }}(
credentials=ga_credentials.AnonymousCredentials(),
transport=transport,
)

# Mock the http request call within the method and fake a response.
Expand Down Expand Up @@ -1446,25 +1573,35 @@ def test_{{ method_name }}_rest_error():
credentials=ga_credentials.AnonymousCredentials(),
transport='rest'
)
{%- if not method.http_options %}
# Since a `google.api.http` annotation is required for using a rest transport
# method, this should error.
with pytest.raises(RuntimeError) as runtime_error:
client.{{ method_name }}({})
assert ('Cannot define a method without a valid `google.api.http` annotation.'
in str(runtime_error.value))
{%- else %}

# TODO(yon-mg): Remove when this method has a working implementation
# or testing straegy
with pytest.raises(NotImplementedError):
client.{{ method_name }}({})

{%- endif %}

{% endif %}{% endwith %}{# method_name #}
{% endif %}{# not streaming #}{% endwith %}{# method_name #}

{% endfor -%} {#- method in methods for rest #}

{% for method in service.methods.values() if 'rest' in opts.transport and
not method.http_options %}{% with method_name = method.name|snake_case + "_unary" if method.operation_service else method.name|snake_case %}
def test_{{ method_name }}_rest_error():
client = {{ service.client_name }}(
credentials=ga_credentials.AnonymousCredentials(),
transport='rest'
)
# Since a `google.api.http` annotation is required for using a rest transport
# method, this should error.
with pytest.raises(RuntimeError) as runtime_error:
client.{{ method_name }}({})
assert ("Cannot define a method without a valid 'google.api.http' annotation."
in str(runtime_error.value))


{% endwith %}{# method_name #}
{% endfor %}{# for methods without http_options #}

def test_credentials_transport_error():
# It is an error to provide credentials and a transport instance.
transport = transports.{{ service.name }}{{ opts.transport[0].capitalize() }}Transport(
Expand Down Expand Up @@ -1758,8 +1895,7 @@ def test_{{ service.name|snake_case }}_http_transport_client_cert_source_for_mtl
mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback)


{# TODO(kbandes): re-enable this code when LRO is implmented for rest #}
{% if False and service.has_lro -%}
{% if service.has_lro -%}
def test_{{ service.name|snake_case }}_rest_lro_client():
client = {{ service.client_name }}(
credentials=ga_credentials.AnonymousCredentials(),
Expand All @@ -1770,7 +1906,7 @@ def test_{{ service.name|snake_case }}_rest_lro_client():
# Ensure that we have a api-core operations client.
assert isinstance(
transport.operations_client,
operations_v1.OperationsClient,
operations_v1.AbstractOperationsClient,
)

# Ensure that subsequent calls to the property send the exact same object.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import grpc
from grpc.experimental import aio
import json
import math
import pytest
from proto.marshal.rules.dates import DurationRule, TimestampRule
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import grpc
from grpc.experimental import aio
import json
import math
import pytest
from proto.marshal.rules.dates import DurationRule, TimestampRule
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import grpc
from grpc.experimental import aio
import json
import math
import pytest
from proto.marshal.rules.dates import DurationRule, TimestampRule
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import grpc
from grpc.experimental import aio
import json
import math
import pytest
from proto.marshal.rules.dates import DurationRule, TimestampRule
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import grpc
from grpc.experimental import aio
import json
import math
import pytest
from proto.marshal.rules.dates import DurationRule, TimestampRule
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import grpc
from grpc.experimental import aio
import json
import math
import pytest
from proto.marshal.rules.dates import DurationRule, TimestampRule
Expand Down

0 comments on commit 0705d9c

Please sign in to comment.