Skip to content

Commit

Permalink
fix: snippetgen handling of repeated enum field
Browse files Browse the repository at this point in the history
  • Loading branch information
dizcology committed Sep 9, 2022
1 parent 9e9971f commit 54f481a
Showing 1 changed file with 23 additions and 11 deletions.
34 changes: 23 additions & 11 deletions gapic/samplegen/samplegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,10 +362,13 @@ def _normal_request_setup(self, base_param_to_attrs, val, request, field):
if attr.message:
base = attr.message
elif attr.enum:
# A little bit hacky, but 'values' is a list, and this is the easiest
# way to verify that the value is a valid enum variant.
witness = any(e.name == val for e in attr.enum.values)
if not witness:
valid_enum_names = set(e.name for e in attr.enum.values)
if isinstance(val, List):
valid = valid_enum_names.issuperset(val)
else:
valid = (val in valid_enum_names)

if not valid:
raise types.InvalidEnumVariant(
"Invalid variant for enum {}: '{}'".format(attr, val)
)
Expand Down Expand Up @@ -936,6 +939,19 @@ def parse_handwritten_specs(sample_configs: Sequence[str]) -> Generator[Dict[str
yield spec


def _field_value_from_field(field: wrappers.Field) -> Any:
if field.is_primitive:
field_value = field.mock_value_original_type
elif field.enum:
# Choose the last enum value in the list since index 0 is often "unspecified"
field_value = field.enum.values[-1].name

if field.repeated:
return [field_value]
else:
return field_value


def generate_request_object(api_schema: api.API, service: wrappers.Service, message: wrappers.MessageType, field_name_prefix: str = ""):
"""Generate dummy input for a given message.
Expand Down Expand Up @@ -969,13 +985,9 @@ def generate_request_object(api_schema: api.API, service: wrappers.Service, mess
field_name = ".".join([field_name_prefix, field.name]).lstrip('.')

# TODO(busunkim): Properly handle map fields
if field.is_primitive:
request.append(
{"field": field_name, "value": field.mock_value_original_type})
elif field.enum:
# Choose the last enum value in the list since index 0 is often "unspecified"
if field.is_primitive or field.enum:
request.append(
{"field": field_name, "value": field.enum.values[-1].name})
{"field": field_name, "value": _field_value_from_field(field)})
else:
# This is a message type, recurse
# TODO(busunkim): Some real world APIs have
Expand Down Expand Up @@ -1014,7 +1026,7 @@ def generate_sample_specs(api_schema: api.API, *, opts) -> Generator[Dict[str, A
spec = {
"rpc": rpc_name,
"transport": transport,
# `request` and `response` is populated in `preprocess_sample`
# `request` and `response` are populated in `preprocess_sample`
"service": f"{api_schema.naming.proto_package}.{service_name}",
"region_tag": region_tag,
"description": f"Snippet for {utils.to_snake_case(rpc_name)}"
Expand Down

0 comments on commit 54f481a

Please sign in to comment.