From 9ad98ca6833f1b280bf3c04c858f92276d59ffbe Mon Sep 17 00:00:00 2001 From: Dov Shlachter Date: Tue, 18 Jan 2022 08:36:04 -0800 Subject: [PATCH] fix: only set unset fields if they are query params (#1130) --- .../services/%service/transports/rest.py.j2 | 4 +- .../%name_%version/%sub/test_%service.py.j2 | 26 +++++-- gapic/schema/wrappers.py | 2 +- .../services/%service/transports/rest.py.j2 | 2 +- .../%name_%version/%sub/test_%service.py.j2 | 75 ++++--------------- noxfile.py | 4 - .../test_multiple_required_fields.proto | 11 ++- tests/unit/schema/wrappers/test_method.py | 4 + 8 files changed, 48 insertions(+), 80 deletions(-) diff --git a/gapic/ads-templates/%namespace/%name/%version/%sub/services/%service/transports/rest.py.j2 b/gapic/ads-templates/%namespace/%name/%version/%sub/services/%service/transports/rest.py.j2 index 3a47d1e363..488646be12 100644 --- a/gapic/ads-templates/%namespace/%name/%version/%sub/services/%service/transports/rest.py.j2 +++ b/gapic/ads-templates/%namespace/%name/%version/%sub/services/%service/transports/rest.py.j2 @@ -182,8 +182,8 @@ class {{service.name}}RestTransport({{service.name}}Transport): {% if not (method.server_streaming or method.client_streaming) %} {% if method.input.required_fields %} __REQUIRED_FIELDS_DEFAULT_VALUES = { - {% for req_field in method.input.required_fields if req_field.is_primitive %} - "{{ req_field.name | camel_case }}" : {% if req_field.field_pb.type == 9 %}"{{req_field.field_pb.default_value }}"{% else %}{{ req_field.field_pb.default_value or 0 }}{% endif %},{# default is str #} + {% for req_field in method.input.required_fields if req_field.is_primitive and req_field.name in method.query_params %} + "{{ req_field.name | camel_case }}" : {% if req_field.field_pb.type == 9 %}"{{req_field.field_pb.default_value }}"{% else %}{{ req_field.type.python_type(req_field.field_pb.default_value or 0) }}{% endif %},{# default is str #} {% endfor %} } diff --git a/gapic/ads-templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 b/gapic/ads-templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 index 38d91f3ae9..af7d28335d 100644 --- a/gapic/ads-templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 +++ b/gapic/ads-templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 @@ -994,7 +994,7 @@ def test_{{ method_name }}_rest_required_fields(request_type={{ method.input.ide )) # verify fields with default values are dropped - {% for req_field in method.input.required_fields if req_field.is_primitive %} + {% for req_field in method.input.required_fields if req_field.is_primitive and req_field.name in method.query_params %} {% set field_name = req_field.name | camel_case %} assert "{{ field_name }}" not in jsonified_request {% endfor %} @@ -1003,23 +1003,32 @@ def test_{{ method_name }}_rest_required_fields(request_type={{ method.input.ide jsonified_request.update(unset_fields) # verify required fields with default values are now present - {% for req_field in method.input.required_fields if req_field.is_primitive %} + {% for req_field in method.input.required_fields if req_field.is_primitive and req_field.name in method.query_params %} {% set field_name = req_field.name | camel_case %} assert "{{ field_name }}" in jsonified_request assert jsonified_request["{{ field_name }}"] == request_init["{{ req_field.name }}"] {% endfor %} - {% for req_field in method.input.required_fields if req_field.is_primitive %} + {% for req_field in method.input.required_fields if req_field.is_primitive and req_field.name in method.query_params %} {% set field_name = req_field.name | camel_case %} {% set mock_value = req_field.primitive_mock_as_str() %} + {% if method.query_params %} + # Check that path parameters and body parameters are not mixing in. + assert not set(unset_fields) - set(({% for param in method.query_params %}"{{param|camel_case }}", {% endfor %})) + {% endif %} jsonified_request["{{ field_name }}"] = {{ mock_value }} {% endfor %} unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).{{ method.name | snake_case }}._get_unset_required_fields(jsonified_request) + {% if method.query_params %} + # Check that path parameters and body parameters are not mixing in. + assert not set(unset_fields) - set(({% for param in method.query_params %}"{{param}}", +{% endfor %})) + {% endif %} jsonified_request.update(unset_fields) # verify required fields with non-default values are left alone - {% for req_field in method.input.required_fields if req_field.is_primitive %} + {% for req_field in method.input.required_fields if req_field.is_primitive and req_field.name in method.query_params %} {% set field_name = req_field.name | camel_case %} {% set mock_value = req_field.primitive_mock_as_str() %} assert "{{ field_name }}" in jsonified_request @@ -1080,7 +1089,7 @@ def test_{{ method_name }}_rest_required_fields(request_type={{ method.input.ide {% endif %} expected_params = [ - {% for req_field in method.input.required_fields if req_field.is_primitive %} + {% for req_field in method.input.required_fields if req_field.is_primitive and req_field.name in method.query_params %} ( "{{ req_field.name | camel_case }}", {% if req_field.field_pb.type == 9 %} @@ -1095,6 +1104,13 @@ def test_{{ method_name }}_rest_required_fields(request_type={{ method.input.ide assert expected_params == actual_params +def test_{{ method_name }}_rest_unset_required_fields(): + transport = transports.{{ service.rest_transport_name }}(credentials=ga_credentials.AnonymousCredentials) + + unset_fields = transport.{{ method.name|snake_case }}._get_unset_required_fields({}) + assert set(unset_fields) == (set(({% for param in method.query_params %}"{{ param|camel_case }}", {% endfor %})) & set(({% for param in method.input.required_fields %}"{{param.name|camel_case}}", {% endfor %}))) + + {% endif %}{# required_fields #} diff --git a/gapic/schema/wrappers.py b/gapic/schema/wrappers.py index c4c9e6bec0..67fa3afe23 100644 --- a/gapic/schema/wrappers.py +++ b/gapic/schema/wrappers.py @@ -1026,7 +1026,7 @@ def path_params(self) -> Sequence[str]: if self.http_opt is None: return [] - pattern = r'\{(\w+)\}' + pattern = r'\{(\w+)(?:=.+?)?\}' return re.findall(pattern, self.http_opt['url']) @property diff --git a/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/rest.py.j2 b/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/rest.py.j2 index bfec823923..488646be12 100644 --- a/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/rest.py.j2 +++ b/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/rest.py.j2 @@ -182,7 +182,7 @@ class {{service.name}}RestTransport({{service.name}}Transport): {% if not (method.server_streaming or method.client_streaming) %} {% if method.input.required_fields %} __REQUIRED_FIELDS_DEFAULT_VALUES = { - {% for req_field in method.input.required_fields if req_field.is_primitive %} + {% for req_field in method.input.required_fields if req_field.is_primitive and req_field.name in method.query_params %} "{{ req_field.name | camel_case }}" : {% if req_field.field_pb.type == 9 %}"{{req_field.field_pb.default_value }}"{% else %}{{ req_field.type.python_type(req_field.field_pb.default_value or 0) }}{% endif %},{# default is str #} {% endfor %} } diff --git a/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 b/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 index 4cfacf699e..6da9da3efa 100644 --- a/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 +++ b/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 @@ -1287,65 +1287,8 @@ def test_{{ method_name }}_raw_page_lro(): {% endfor %} {# method in methods for grpc #} {% for method in service.methods.values() if 'rest' in opts.transport %}{% with method_name = method.name|snake_case + "_unary" if method.operation_service else method.name|snake_case %}{% if method.http_options %} -{# TODO(kbandes): remove this if condition when streaming are supported. #} -{% if not (method.server_streaming or method.client_streaming) %} -@pytest.mark.parametrize("request_type", [ - {{ method.input.ident }}, - dict, -]) -def test_{{ method_name }}_rest(request_type, transport: str = 'rest'): - client = {{ service.client_name }}( - credentials=ga_credentials.AnonymousCredentials(), - transport="rest", - ) - # Send a request that will satisfy transcoding - request = {{ method.input.ident }}({{ method.http_options[0].sample_request(method) }}) - {% if method.client_streaming %} - requests = [request] - {% endif %} - - - with mock.patch.object(type(client.transport._session), 'request') as req: - {% if method.void %} - return_value = None - {% elif method.lro %} - return_value = operations_pb2.Operation(name='operations/spam') - {% elif method.server_streaming %} - return_value = iter([{{ method.output.ident }}()]) - {% else %} - return_value = {{ method.output.ident }}( - {% for field in method.output.fields.values() | rejectattr('message')%} - {% if not field.oneof or field.proto3_optional %} - {{ field.name }}={{ field.mock_value }}, - {% endif %}{% endfor %} - {# This is a hack to only pick one field #} - {% for oneof_fields in method.output.oneof_fields().values() %} - {% with field = oneof_fields[0] %} - {{ field.name }}={{ field.mock_value }}, - {% endwith %} - {% endfor %} - ) - {% endif %} - req.return_value = Response() - req.return_value.status_code = 500 - req.return_value.request = PreparedRequest() - {% if method.void %} - json_return_value = '' - {% else %} - json_return_value = {{ method.output.ident }}.to_json(return_value) - {% endif %} - req.return_value._content = json_return_value.encode("UTF-8") - with pytest.raises(core_exceptions.GoogleAPIError): - # We only care that the correct exception is raised when putting - # the request over the wire, so an empty request is fine. - {% if method.client_streaming %} - client.{{ method_name }}(iter([requests])) - {% else %} - client.{{ method_name }}(request) - {% endif %} - - {# TODO(kbandes): remove this if condition when lro and streaming are supported. #} +{% if not (method.server_streaming or method.client_streaming) %} @pytest.mark.parametrize("request_type", [ {{ method.input.ident }}, dict, @@ -1458,7 +1401,7 @@ def test_{{ method_name }}_rest_required_fields(request_type={{ method.input.ide )) # verify fields with default values are dropped - {% for req_field in method.input.required_fields if req_field.is_primitive %} + {% for req_field in method.input.required_fields if req_field.is_primitive and req_field.name in method.query_params %} {% set field_name = req_field.name | camel_case %} assert "{{ field_name }}" not in jsonified_request {% endfor %} @@ -1467,7 +1410,7 @@ def test_{{ method_name }}_rest_required_fields(request_type={{ method.input.ide jsonified_request.update(unset_fields) # verify required fields with default values are now present - {% for req_field in method.input.required_fields if req_field.is_primitive %} + {% for req_field in method.input.required_fields if req_field.is_primitive and req_field.name in method.query_params %} {% set field_name = req_field.name | camel_case %} assert "{{ field_name }}" in jsonified_request assert jsonified_request["{{ field_name }}"] == request_init["{{ req_field.name }}"] @@ -1480,6 +1423,10 @@ def test_{{ method_name }}_rest_required_fields(request_type={{ method.input.ide {% endfor %} unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).{{ method.name | snake_case }}._get_unset_required_fields(jsonified_request) + {% if method.query_params %} + # Check that path parameters and body parameters are not mixing in. + assert not set(unset_fields) - set(({% for param in method.query_params %}"{{param}}", {% endfor %})) + {% endif %} jsonified_request.update(unset_fields) # verify required fields with non-default values are left alone @@ -1544,7 +1491,7 @@ def test_{{ method_name }}_rest_required_fields(request_type={{ method.input.ide {% endif %} expected_params = [ - {% for req_field in method.input.required_fields if req_field.is_primitive %} + {% for req_field in method.input.required_fields if req_field.is_primitive and req_field.name in method.query_params %} ( "{{ req_field.name | camel_case }}", {% if req_field.field_pb.type == 9 %} @@ -1559,6 +1506,12 @@ def test_{{ method_name }}_rest_required_fields(request_type={{ method.input.ide assert expected_params == actual_params +def test_{{ method_name }}_rest_unset_required_fields(): + transport = transports.{{ service.rest_transport_name }}(credentials=ga_credentials.AnonymousCredentials) + + unset_fields = transport.{{ method.name|snake_case }}._get_unset_required_fields({}) + assert set(unset_fields) == (set(({% for param in method.query_params %}"{{ param|camel_case }}", {% endfor %})) & set(({% for param in method.input.required_fields %}"{{ param.name|camel_case }}", {% endfor %}))) + {% endif %}{# required_fields #} diff --git a/noxfile.py b/noxfile.py index f281461e80..6154ea94bb 100644 --- a/noxfile.py +++ b/noxfile.py @@ -103,9 +103,6 @@ def __call__(self, frag): f"--python_gapic_opt=transport=grpc+rest,python-gapic-templates={templates}{maybe_old_naming}", ] - if self.use_ads_templates: - session_args.extend([]) - outputs.append( self.session.run(*session_args, str(frag), external=True, silent=True,) ) @@ -114,7 +111,6 @@ def __call__(self, frag): # Note: install into the tempdir to prevent issues # with running pip concurrently. self.session.install(tmp_dir, "-e", ".", "-t", tmp_dir, "-qqq") - # Run the fragment's generated unit tests. # Don't bother parallelizing them: we already parallelize # the fragments, and there usually aren't too many tests per fragment. diff --git a/tests/fragments/test_multiple_required_fields.proto b/tests/fragments/test_multiple_required_fields.proto index 8e4cdfe157..55a207b40a 100644 --- a/tests/fragments/test_multiple_required_fields.proto +++ b/tests/fragments/test_multiple_required_fields.proto @@ -30,14 +30,13 @@ service MultipleRequiredFields { } } -message Description { - string description = 1; -} - message MethodRequest { string kingdom = 1 [(google.api.field_behavior) = REQUIRED]; string phylum = 2 [(google.api.field_behavior) = REQUIRED]; - Description description = 3 [(google.api.field_behavior) = REQUIRED]; + string name = 3 [(google.api.field_behavior) = REQUIRED]; + int32 armor_class = 4 [(google.api.field_behavior) = REQUIRED]; } -message MethodResponse{} \ No newline at end of file +message MethodResponse{ + string text = 1; +} \ No newline at end of file diff --git a/tests/unit/schema/wrappers/test_method.py b/tests/unit/schema/wrappers/test_method.py index 889dc629a3..814893c39c 100644 --- a/tests/unit/schema/wrappers/test_method.py +++ b/tests/unit/schema/wrappers/test_method.py @@ -325,6 +325,10 @@ def test_method_path_params(): method = make_method('DoSomething', http_rule=http_rule) assert method.path_params == ['project'] + http_rule2 = http_pb2.HttpRule(post='/v1beta1/{name=rooms/*/blurbs/*}') + method2 = make_method("DoSomething", http_rule=http_rule2) + assert method2.path_params == ["name"] + def test_method_path_params_no_http_rule(): method = make_method('DoSomething')