diff --git a/google/generativeai/responder.py b/google/generativeai/responder.py index aa449a810..238e7e13a 100644 --- a/google/generativeai/responder.py +++ b/google/generativeai/responder.py @@ -24,6 +24,47 @@ from google.ai import generativelanguage as glm +Type = glm.Type + +TypeOptions = Union[int, str, Type] + +_TYPE_TYPE: dict[TypeOptions, Type] = { + Type.TYPE_UNSPECIFIED: Type.TYPE_UNSPECIFIED, + 0: Type.TYPE_UNSPECIFIED, + "type_unspecified": Type.TYPE_UNSPECIFIED, + "unspecified": Type.TYPE_UNSPECIFIED, + Type.STRING: Type.STRING, + 1: Type.STRING, + "type_string": Type.STRING, + "string": Type.STRING, + Type.NUMBER: Type.NUMBER, + 2: Type.NUMBER, + "type_number": Type.NUMBER, + "number": Type.NUMBER, + Type.INTEGER: Type.INTEGER, + 3: Type.INTEGER, + "type_integer": Type.INTEGER, + "integer": Type.INTEGER, + Type.BOOLEAN: Type.BOOLEAN, + 4: Type.INTEGER, + "type_boolean": Type.BOOLEAN, + "boolean": Type.BOOLEAN, + Type.ARRAY: Type.ARRAY, + 5: Type.ARRAY, + "type_array": Type.ARRAY, + "array": Type.ARRAY, + Type.OBJECT: Type.OBJECT, + 6: Type.OBJECT, + "type_object": Type.OBJECT, + "object": Type.OBJECT, +} + + +def to_type(x: TypeOptions) -> Type: + if isinstance(x, str): + x = x.lower() + return _TYPE_TYPE[x] + def _generate_schema( f: Callable[..., Any], @@ -115,7 +156,7 @@ def _generate_schema( return schema -def _rename_schema_fields(schema): +def _rename_schema_fields(schema: dict[str, Any]): if schema is None: return schema @@ -123,7 +164,10 @@ def _rename_schema_fields(schema): type_ = schema.pop("type", None) if type_ is not None: - schema["type_"] = type_.upper() + schema["type_"] = type_ + type_ = schema.get("type_", None) + if type_ is not None: + schema["type_"] = to_type(type_) format_ = schema.pop("format", None) if format_ is not None: diff --git a/google/generativeai/types/generation_types.py b/google/generativeai/types/generation_types.py index fb550834f..7c30f1363 100644 --- a/google/generativeai/types/generation_types.py +++ b/google/generativeai/types/generation_types.py @@ -16,13 +16,14 @@ import collections import contextlib -from collections.abc import Iterable, AsyncIterable +import sys +from collections.abc import Iterable, AsyncIterable, Mapping import dataclasses import itertools import json import sys import textwrap -from typing import Union +from typing import Union, Any from typing_extensions import TypedDict import google.protobuf.json_format @@ -30,6 +31,7 @@ from google.ai import generativelanguage as glm from google.generativeai import string_utils +from google.generativeai.responder import _rename_schema_fields __all__ = [ "AsyncGenerateContentResponse", @@ -81,6 +83,7 @@ class GenerationConfigDict(TypedDict, total=False): max_output_tokens: int temperature: float response_mime_type: str + response_schema: glm.Schema | Mapping[str, Any] # fmt: off @dataclasses.dataclass @@ -147,6 +150,10 @@ class GenerationConfig: Supported mimetype: `text/plain`: (default) Text output. `application/json`: JSON response in the candidates. + + response_schema: + Optional. Specifies the format of the JSON requested if response_mime_type is + `application/json`. """ candidate_count: int | None = None @@ -156,21 +163,41 @@ class GenerationConfig: top_p: float | None = None top_k: int | None = None response_mime_type: str | None = None + response_schema: glm.Schema | Mapping[str, Any] | None = None GenerationConfigType = Union[glm.GenerationConfig, GenerationConfigDict, GenerationConfig] +def _normalize_schema(generation_config): + # Convert response_schema to glm.Schema for request + response_schema = generation_config.get("response_schema", None) + if response_schema is None: + return + if isinstance(response_schema, glm.Schema): + return + response_schema = _rename_schema_fields(response_schema) + generation_config["response_schema"] = glm.Schema(response_schema) + + def to_generation_config_dict(generation_config: GenerationConfigType): if generation_config is None: return {} elif isinstance(generation_config, glm.GenerationConfig): - return type(generation_config).to_dict(generation_config) # pytype: disable=attribute-error + schema = generation_config.response_schema + generation_config = type(generation_config).to_dict( + generation_config + ) # pytype: disable=attribute-error + generation_config["response_schema"] = schema + return generation_config elif isinstance(generation_config, GenerationConfig): generation_config = dataclasses.asdict(generation_config) + _normalize_schema(generation_config) return {key: value for key, value in generation_config.items() if value is not None} elif hasattr(generation_config, "keys"): - return dict(generation_config) + generation_config = dict(generation_config) + _normalize_schema(generation_config) + return generation_config else: raise TypeError( "Did not understand `generation_config`, expected a `dict` or" diff --git a/tests/test_generation.py b/tests/test_generation.py index 7d154d186..6d559999f 100644 --- a/tests/test_generation.py +++ b/tests/test_generation.py @@ -561,6 +561,53 @@ def test_repr_for_generate_content_response_from_iterator(self): ) self.assertEqual(expected, result) + @parameterized.named_parameters( + [ + "glm.GenerationConfig", + glm.GenerationConfig( + temperature=0.1, + stop_sequences=["end"], + response_mime_type="application/json", + response_schema=glm.Schema( + type="STRING", format="float", description="This is an example schema." + ), + ), + ], + [ + "GenerationConfigDict", + { + "temperature": 0.1, + "stop_sequences": ["end"], + "response_mime_type": "application/json", + "response_schema": glm.Schema( + type="STRING", format="float", description="This is an example schema." + ), + }, + ], + [ + "GenerationConfig", + generation_types.GenerationConfig( + temperature=0.1, + stop_sequences=["end"], + response_mime_type="application/json", + response_schema=glm.Schema( + type="STRING", format="float", description="This is an example schema." + ), + ), + ], + ) + def test_response_schema(self, config): + gd = generation_types.to_generation_config_dict(config) + self.assertIsInstance(gd, dict) + self.assertEqual(gd["temperature"], 0.1) + self.assertEqual(gd["stop_sequences"], ["end"]) + self.assertEqual(gd["response_mime_type"], "application/json") + actual = gd["response_schema"] + expected = glm.Schema( + type="STRING", format="float", description="This is an example schema." + ) + self.assertEqual(actual, expected) + if __name__ == "__main__": absltest.main()