diff --git a/gapic/samplegen/samplegen.py b/gapic/samplegen/samplegen.py index 12ebd93d04..5ea5293d62 100644 --- a/gapic/samplegen/samplegen.py +++ b/gapic/samplegen/samplegen.py @@ -362,10 +362,13 @@ def _normal_request_setup(self, base_param_to_attrs, val, request, field): if attr.message: base = attr.message elif attr.enum: - # A little bit hacky, but 'values' is a list, and this is the easiest - # way to verify that the value is a valid enum variant. - witness = any(e.name == val for e in attr.enum.values) - if not witness: + valid_enum_names = set(e.name for e in attr.enum.values) + if isinstance(val, List): + valid = valid_enum_names.issuperset(val) + else: + valid = (val in valid_enum_names) + + if not valid: raise types.InvalidEnumVariant( "Invalid variant for enum {}: '{}'".format(attr, val) ) @@ -936,6 +939,19 @@ def parse_handwritten_specs(sample_configs: Sequence[str]) -> Generator[Dict[str yield spec +def _field_value_from_field(field: wrappers.Field) -> Any: + if field.is_primitive: + field_value = field.mock_value_original_type + elif field.enum: + # Choose the last enum value in the list since index 0 is often "unspecified" + field_value = field.enum.values[-1].name + + if field.repeated: + return [field_value] + else: + return field_value + + def generate_request_object(api_schema: api.API, service: wrappers.Service, message: wrappers.MessageType, field_name_prefix: str = ""): """Generate dummy input for a given message. @@ -969,13 +985,9 @@ def generate_request_object(api_schema: api.API, service: wrappers.Service, mess field_name = ".".join([field_name_prefix, field.name]).lstrip('.') # TODO(busunkim): Properly handle map fields - if field.is_primitive: - request.append( - {"field": field_name, "value": field.mock_value_original_type}) - elif field.enum: - # Choose the last enum value in the list since index 0 is often "unspecified" + if field.is_primitive or field.enum: request.append( - {"field": field_name, "value": field.enum.values[-1].name}) + {"field": field_name, "value": _field_value_from_field(field)}) else: # This is a message type, recurse # TODO(busunkim): Some real world APIs have @@ -1014,7 +1026,7 @@ def generate_sample_specs(api_schema: api.API, *, opts) -> Generator[Dict[str, A spec = { "rpc": rpc_name, "transport": transport, - # `request` and `response` is populated in `preprocess_sample` + # `request` and `response` are populated in `preprocess_sample` "service": f"{api_schema.naming.proto_package}.{service_name}", "region_tag": region_tag, "description": f"Snippet for {utils.to_snake_case(rpc_name)}"