Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: ensure rest unit tests have complete coverage #1098

Merged
merged 6 commits into from
Dec 2, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,22 @@ class {{service.name}}RestTransport({{service.name}}Transport):
{% endif %}{# service.has_lro #}
{% for method in service.methods.values() %}
{%- if method.http_options and not (method.server_streaming or method.client_streaming) %}

{% if method.input.required_fields %}
__{{ method.name | snake_case }}_required_fields_default_values = {
{% for req_field in method.input.required_fields if req_field.is_primitive %}
"{{ req_field.name | camel_case }}" : {% if req_field.field_pb.default_value is string %}"{{req_field.field_pb.default_value }}"{% else %}{{ req_field.field_pb.default_value }}{% endif %}{# default is str #}
{% endfor %}
}


@staticmethod
def _{{ method.name | snake_case }}_get_unset_required_fields(message_dict):
return {k: v for k, v in {{service.name}}RestTransport.__{{ method.name | snake_case }}_required_fields_default_values.items() if k not in message_dict}


{% endif %}{# required fields #}

def _{{method.name | snake_case}}(self,
request: {{method.input.ident}}, *,
retry: OptionalRetry=gapic_v1.method.DEFAULT,
Expand Down Expand Up @@ -206,21 +222,6 @@ class {{service.name}}RestTransport({{service.name}}Transport):
{% endfor %}
]

{% if method.input.required_fields %}
required_fields = [
# (snake_case_name, camel_case_name)
{% for req_field in method.input.required_fields %}
{% if req_field.is_primitive %}
(
"{{ req_field.name | snake_case }}",
"{{ req_field.name | camel_case }}"
),
{% endif %}{# is primitive #}
{% endfor %}{# required fields #}
]

{% endif %}

request_kwargs = {{method.input.ident}}.to_dict(request)
transcoded_request = path_template.transcode(
http_options, **request_kwargs)
Expand Down Expand Up @@ -254,16 +255,8 @@ class {{service.name}}RestTransport({{service.name}}Transport):
))

{% if method.input.required_fields %}
# Ensure required fields have values in query_params.
# If a required field has a default value, it can get lost
# by the to_json call above.
orig_query_params = transcoded_request["query_params"]
for snake_case_name, camel_case_name in required_fields:
if snake_case_name in orig_query_params:
if camel_case_name not in query_params:
query_params[camel_case_name] = orig_query_params[snake_case_name]

{% endif %}
query_params.update(self._{{ method.name | snake_case }}_get_unset_required_fields(query_params))
{% endif %}{# required fields #}

# Send the request
headers = dict(metadata)
Expand Down
164 changes: 150 additions & 14 deletions gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import mock

import grpc
from grpc.experimental import aio
import json
import math
import pytest
from proto.marshal.rules.dates import DurationRule, TimestampRule
Expand Down Expand Up @@ -1187,6 +1188,7 @@ def test_{{ method_name }}_rest(transport: str = 'rest', request_type={{ method.
{% if "next_page_token" in method.output.fields.values()|map(attribute='name') and not method.paged_result_field %}
{# Cheeser assertion to force code coverage for bad paginated methods #}
assert response.raw_page is response

{% endif %}

# Establish that the response is the type that we expect.
Expand All @@ -1210,6 +1212,130 @@ def test_{{ method_name }}_rest(transport: str = 'rest', request_type={{ method.
{% endif %}


{% if method.input.required_fields %}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might be reasonable to have a test or two just for the hidden required fields update method.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure what you mean. Which method is hidden?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I see, once we refactor as you describe above.
This is quite tricky to test, which is a lot of motivation for these changes. The problem is that typically, though not always, a required field is going to have an expected template in the http rule, so the default value will cause the transcoding to fail. This can be worked around by mocking the transcoding function, but it gets convoluted and ugly.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure what to do here. No code is un-covered. I would argue that testing this logic in the context of the api method itself doesn't necessarily add anything.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, I've added logic to test the actual api method with default-valued required fields.

def test_{{ method_name }}_rest_required_fields(request_type={{ method.input.ident }}):
transport_class = transports.{{ service.rest_transport_name }}

request_init = {}
{% for req_field in method.input.required_fields if req_field.is_primitive %}
{% if req_field.field_pb.default_value is string %}
request_init["{{ req_field.name }}"] = "{{ req_field.field_pb.default_value }}"
{% else %}
request_init["{{ req_field.name }}"] = {{ req_field.field_pb.default_value }}
{% endif %}{# default is str #}
{% endfor %}
request = request_type(request_init)
jsonified_request = json.loads(request_type.to_json(
request,
including_default_value_fields=False,
use_integers_for_enums=False
))

# verify fields with default values are dropped
{% for req_field in method.input.required_fields if req_field.is_primitive %}
{% set field_name = req_field.name | camel_case %}
assert "{{ field_name }}" not in jsonified_request
{% endfor %}

unset_fields = transport_class._{{ method.name | snake_case }}_get_unset_required_fields(jsonified_request)
jsonified_request.update(unset_fields)

# verify required fields with default values are now present
{% for req_field in method.input.required_fields if req_field.is_primitive %}
{% set field_name = req_field.name | camel_case %}
assert "{{ field_name }}" in jsonified_request
assert jsonified_request["{{ field_name }}"] == request_init["{{ req_field.name }}"]
{% endfor %}

{% for req_field in method.input.required_fields if req_field.is_primitive %}
{% set field_name = req_field.name | camel_case %}
{% set mock_value = req_field.primitive_mock_as_str() %}
jsonified_request["{{ field_name }}"] = {{ mock_value }}
{% endfor %}

unset_fields = transport_class._{{ method.name | snake_case }}_get_unset_required_fields(jsonified_request)
jsonified_request.update(unset_fields)

# verify required fields with non-default values are left alone
{% for req_field in method.input.required_fields if req_field.is_primitive %}
{% set field_name = req_field.name | camel_case %}
{% set mock_value = req_field.primitive_mock_as_str() %}
assert "{{ field_name }}" in jsonified_request
assert jsonified_request["{{ field_name }}"] == {{ mock_value }}
{% endfor %}


client = {{ service.client_name }}(
credentials=ga_credentials.AnonymousCredentials(),
transport='rest',
)
request = request_type(request_init)

# Designate an appropriate value for the returned response.
{% if method.void %}
return_value = None
{% elif method.lro %}
return_value = operations_pb2.Operation(name='operations/spam')
{% elif method.server_streaming %}
return_value = iter([{{ method.output.ident }}()])
{% else %}
return_value = {{ method.output.ident }}()
{% endif %}
# Mock the http request call within the method and fake a response.
with mock.patch.object(Session, 'request') as req:
# We need to mock transcode() because providing default values
# for required fields will fail the real version if the http_options
# expect actual values for those fields.
with mock.patch.object(path_template, 'transcode') as transcode:
# A uri without fields and an empty body will force all the
# request fields to show up in the query_params.
transcode_result = {
'uri': 'v1/sample_method',
'method': "{{ method.http_options[0].method }}",
'query_params': request_init,
}
{% if method.http_options[0].body %}
transcode_result['body'] = {}
{% endif %}
transcode.return_value = transcode_result

response_value = Response()
response_value.status_code = 200
{% if method.void %}
json_return_value = ''
{% elif method.lro %}
json_return_value = json_format.MessageToJson(return_value)
{% else %}
json_return_value = {{ method.output.ident }}.to_json(return_value)
{% endif %}
response_value._content = json_return_value.encode('UTF-8')
req.return_value = response_value

{% if method.client_streaming %}
response = client.{{ method.name|snake_case }}(iter(requests))
{% else %}
response = client.{{ method_name }}(request)
{% endif %}

expected_params = [
{% for req_field in method.input.required_fields if req_field.is_primitive %}
(
"{{ req_field.name }}",
{% if req_field.field_pb.default_value is string %}
"{{ req_field.field_pb.default_value }}"
{% else %}
{{ req_field.field_pb.default_value }}
{% endif %}{# default is str #}
)
{% endfor %}
]
actual_params = req.call_args.kwargs['params']
assert expected_params == actual_params


{% endif %}{# required_fields #}


def test_{{ method_name }}_rest_bad_request(transport: str = 'rest', request_type={{ method.input.ident }}):
client = {{ service.client_name }}(
credentials=ga_credentials.AnonymousCredentials(),
Expand Down Expand Up @@ -1325,9 +1451,10 @@ def test_{{ method_name }}_rest_flattened_error(transport: str = 'rest'):


{% if method.paged_result_field %}
def test_{{ method_name }}_rest_pager():
def test_{{ method_name }}_rest_pager(transport: str = 'rest'):
client = {{ service.client_name }}(
credentials=ga_credentials.AnonymousCredentials(),
transport=transport,
)

# Mock the http request call within the method and fake a response.
Expand Down Expand Up @@ -1446,25 +1573,35 @@ def test_{{ method_name }}_rest_error():
credentials=ga_credentials.AnonymousCredentials(),
transport='rest'
)
{%- if not method.http_options %}
# Since a `google.api.http` annotation is required for using a rest transport
# method, this should error.
with pytest.raises(RuntimeError) as runtime_error:
client.{{ method_name }}({})
assert ('Cannot define a method without a valid `google.api.http` annotation.'
in str(runtime_error.value))
{%- else %}

# TODO(yon-mg): Remove when this method has a working implementation
# or testing straegy
with pytest.raises(NotImplementedError):
client.{{ method_name }}({})

{%- endif %}

{% endif %}{% endwith %}{# method_name #}
{% endif %}{# not streaming #}{% endwith %}{# method_name #}

{% endfor -%} {#- method in methods for rest #}

{% for method in service.methods.values() if 'rest' in opts.transport and
not method.http_options %}{% with method_name = method.name|snake_case + "_unary" if method.operation_service else method.name|snake_case %}
def test_{{ method_name }}_rest_error():
client = {{ service.client_name }}(
credentials=ga_credentials.AnonymousCredentials(),
transport='rest'
)
# Since a `google.api.http` annotation is required for using a rest transport
# method, this should error.
with pytest.raises(RuntimeError) as runtime_error:
client.{{ method_name }}({})
assert ("Cannot define a method without a valid 'google.api.http' annotation."
in str(runtime_error.value))


{% endwith %}{# method_name #}
{% endfor %}{# for methods without http_options #}

def test_credentials_transport_error():
# It is an error to provide credentials and a transport instance.
transport = transports.{{ service.name }}{{ opts.transport[0].capitalize() }}Transport(
Expand Down Expand Up @@ -1758,8 +1895,7 @@ def test_{{ service.name|snake_case }}_http_transport_client_cert_source_for_mtl
mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback)


{# TODO(kbandes): re-enable this code when LRO is implmented for rest #}
{% if False and service.has_lro -%}
{% if service.has_lro -%}
def test_{{ service.name|snake_case }}_rest_lro_client():
client = {{ service.client_name }}(
credentials=ga_credentials.AnonymousCredentials(),
Expand All @@ -1770,7 +1906,7 @@ def test_{{ service.name|snake_case }}_rest_lro_client():
# Ensure that we have a api-core operations client.
assert isinstance(
transport.operations_client,
operations_v1.OperationsClient,
operations_v1.AbstractOperationsClient,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why was this changed?

Copy link
Contributor Author

@kbandes kbandes Dec 1, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was at the request of software-dov. The change was made to api-core in a prior PR.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Regarding related PRs: I'm not sure. I believe other people are making these changes and releasing them, I'm not sure of the timeline or PR numbers. The change to noxfile.py is small, but I won't know exactly how to do it until I know what the actual release numbers for showcase and api-core are.

)

# Ensure that subsequent calls to the property send the exact same object.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import grpc
from grpc.experimental import aio
import json
import math
import pytest
from proto.marshal.rules.dates import DurationRule, TimestampRule
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import grpc
from grpc.experimental import aio
import json
import math
import pytest
from proto.marshal.rules.dates import DurationRule, TimestampRule
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import grpc
from grpc.experimental import aio
import json
import math
import pytest
from proto.marshal.rules.dates import DurationRule, TimestampRule
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import grpc
from grpc.experimental import aio
import json
import math
import pytest
from proto.marshal.rules.dates import DurationRule, TimestampRule
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import grpc
from grpc.experimental import aio
import json
import math
import pytest
from proto.marshal.rules.dates import DurationRule, TimestampRule
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import grpc
from grpc.experimental import aio
import json
import math
import pytest
from proto.marshal.rules.dates import DurationRule, TimestampRule
Expand Down