Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(agent): Implement more tolerant json_loads function #7016

Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@
LanguageModelClassification,
PromptStrategy,
)
from autogpt.core.prompting.utils import json_loads
from autogpt.core.resource.model_providers.schema import (
AssistantChatMessage,
ChatMessage,
ChatModelProvider,
CompletionModelFunction,
)
from autogpt.core.utils.json_schema import JSONSchema
from autogpt.json_utils.utilities import json_loads

logger = logging.getLogger(__name__)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
CompletionModelFunction,
)
from autogpt.core.utils.json_schema import JSONSchema
from autogpt.json_utils.utilities import extract_dict_from_response
from autogpt.json_utils.utilities import extract_dict_from_response, json_loads
from autogpt.prompts.utils import format_numbered_list, indent


Expand Down Expand Up @@ -439,7 +439,7 @@ def extract_command(
raise InvalidAgentResponseError("No 'tool_calls' in assistant reply")
assistant_reply_json["command"] = {
"name": assistant_reply.tool_calls[0].function.name,
"args": json.loads(assistant_reply.tool_calls[0].function.arguments),
"args": json_loads(assistant_reply.tool_calls[0].function.arguments),
}
try:
if not isinstance(assistant_reply_json, dict):
Expand Down
4 changes: 2 additions & 2 deletions autogpts/autogpt/autogpt/commands/image_gen.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Commands to generate images based on text input"""

import io
import json
import logging
import time
import uuid
Expand All @@ -15,6 +14,7 @@
from autogpt.agents.agent import Agent
from autogpt.command_decorator import command
from autogpt.core.utils.json_schema import JSONSchema
from autogpt.json_utils.utilities import json_loads

COMMAND_CATEGORY = "text_to_image"
COMMAND_CATEGORY_TITLE = "Text to Image"
Expand Down Expand Up @@ -102,7 +102,7 @@ def generate_image_with_hf(prompt: str, output_file: Path, agent: Agent) -> str:
break
else:
try:
error = json.loads(response.text)
error = json_loads(response.text)
kcze marked this conversation as resolved.
Show resolved Hide resolved
if "estimated_time" in error:
delay = error["estimated_time"]
logger.debug(response.text)
Expand Down
3 changes: 2 additions & 1 deletion autogpts/autogpt/autogpt/commands/web_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from autogpt.agents.utils.exceptions import ConfigurationError
from autogpt.command_decorator import command
from autogpt.core.utils.json_schema import JSONSchema
from autogpt.json_utils.utilities import json_loads

COMMAND_CATEGORY = "web_search"
COMMAND_CATEGORY_TITLE = "Web Search"
Expand Down Expand Up @@ -134,7 +135,7 @@ def google(query: str, agent: Agent, num_results: int = 8) -> str | list[str]:

except HttpError as e:
# Handle errors in the API call
error_details = json.loads(e.content.decode())
error_details = json_loads(e.content.decode())
kcze marked this conversation as resolved.
Show resolved Hide resolved

# Check if the error is related to an invalid or missing API key
if error_details.get("error", {}).get(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@
from autogpt.core.planning.schema import Task, TaskType
from autogpt.core.prompting import PromptStrategy
from autogpt.core.prompting.schema import ChatPrompt, LanguageModelClassification
from autogpt.core.prompting.utils import json_loads, to_numbered_list
from autogpt.core.prompting.utils import to_numbered_list
from autogpt.core.resource.model_providers import (
AssistantChatMessage,
ChatMessage,
CompletionModelFunction,
)
from autogpt.core.utils.json_schema import JSONSchema
from autogpt.json_utils.utilities import json_loads
kcze marked this conversation as resolved.
Show resolved Hide resolved

logger = logging.getLogger(__name__)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
from autogpt.core.configuration import SystemConfiguration, UserConfigurable
from autogpt.core.prompting import PromptStrategy
from autogpt.core.prompting.schema import ChatPrompt, LanguageModelClassification
from autogpt.core.prompting.utils import json_loads
from autogpt.core.resource.model_providers import (
AssistantChatMessage,
ChatMessage,
CompletionModelFunction,
)
from autogpt.core.utils.json_schema import JSONSchema
from autogpt.json_utils.utilities import json_loads

logger = logging.getLogger(__name__)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@
from autogpt.core.planning.schema import Task
from autogpt.core.prompting import PromptStrategy
from autogpt.core.prompting.schema import ChatPrompt, LanguageModelClassification
from autogpt.core.prompting.utils import json_loads, to_numbered_list
from autogpt.core.prompting.utils import to_numbered_list
from autogpt.core.resource.model_providers import (
AssistantChatMessage,
ChatMessage,
CompletionModelFunction,
)
from autogpt.core.utils.json_schema import JSONSchema
from autogpt.json_utils.utilities import json_loads

logger = logging.getLogger(__name__)

Expand Down
20 changes: 0 additions & 20 deletions autogpts/autogpt/autogpt/core/prompting/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
import ast
import json


def to_numbered_list(
items: list[str], no_items_response: str = "", **template_args
) -> str:
Expand All @@ -11,19 +7,3 @@ def to_numbered_list(
)
else:
return no_items_response


def json_loads(json_str: str):
# TODO: this is a hack function for now. We'll see what errors show up in testing.
# Can hopefully just replace with a call to ast.literal_eval.
# Can't use json.loads because the function API still sometimes returns json strings
# with minor issues like trailing commas.
try:
json_str = json_str[json_str.index("{") : json_str.rindex("}") + 1]
return ast.literal_eval(json_str)
except json.decoder.JSONDecodeError as e:
try:
print(f"json decode error {e}. trying literal eval")
return ast.literal_eval(json_str)
except Exception:
breakpoint()
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
ModelTokenizer,
)
from autogpt.core.utils.json_schema import JSONSchema
from autogpt.json_utils.utilities import json_loads

_T = TypeVar("_T")
_P = ParamSpec("_P")
Expand Down Expand Up @@ -758,19 +759,18 @@ def _functions_compat_fix_kwargs(


def _tool_calls_compat_extract_calls(response: str) -> Iterator[AssistantToolCall]:
import json
import re
import uuid

logging.debug(f"Trying to extract tool calls from response:\n{response}")

if response[0] == "[":
tool_calls: list[AssistantToolCallDict] = json.loads(response)
tool_calls: list[AssistantToolCallDict] = json_loads(response)
else:
block = re.search(r"```(?:tool_calls)?\n(.*)\n```\s*$", response, re.DOTALL)
if not block:
raise ValueError("Could not find tool_calls block in response")
tool_calls: list[AssistantToolCallDict] = json.loads(block.group(1))
tool_calls: list[AssistantToolCallDict] = json_loads(block.group(1))

for t in tool_calls:
t["id"] = str(uuid.uuid4())
Expand Down
39 changes: 36 additions & 3 deletions autogpts/autogpt/autogpt/json_utils/utilities.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,45 @@
"""Utilities for the json_fixes package."""
import json

import io
import logging
import re
from typing import Any

import demjson3

logger = logging.getLogger(__name__)


def json_loads(json_str: str) -> Any:
"""Parse a JSON string. This function is tolerant
to minor issues in the JSON string:
kcze marked this conversation as resolved.
Show resolved Hide resolved
- Missing commas between elements
- Trailing commas or extra commas in objects
- Extraneous newlines and spaces outside of string literals
- Inconsistent spacing after colons and commas
- Missing closing brackets or braces
- Comments

Args:
json_str: The JSON string to parse.

Returns:
The parsed JSON object, same as built-in json.loads.
"""
error_buffer = io.StringIO()
json_result = demjson3.decode(
json_str, return_errors=True, write_errors=error_buffer
)

if error_buffer.getvalue():
logger.debug(f"JSON parse errors:\n{error_buffer.getvalue()}")

if json_result is None:
raise ValueError(f"Failed to parse JSON string: {json_str}")
Pwuts marked this conversation as resolved.
Show resolved Hide resolved

return json_result.object


def extract_dict_from_response(response_content: str) -> dict[str, Any]:
# Sometimes the response includes the JSON in a code block with ```
pattern = r"```(?:json|JSON)*([\s\S]*?)```"
Expand All @@ -22,7 +55,7 @@ def extract_dict_from_response(response_content: str) -> dict[str, Any]:
if match:
response_content = match.group()

result = json.loads(response_content)
result = json_loads(response_content)
if not isinstance(result, dict):
raise ValueError(
f"Response '''{response_content}''' evaluated to "
Expand All @@ -46,7 +79,7 @@ def extract_list_from_response(response_content: str) -> list[Any]:
if match:
response_content = match.group()

result = json.loads(response_content)
result = json_loads(response_content)
if not isinstance(result, list):
raise ValueError(
f"Response '''{response_content}''' evaluated to "
Expand Down
15 changes: 13 additions & 2 deletions autogpts/autogpt/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions autogpts/autogpt/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ spacy = "^3.0.0"
tenacity = "^8.2.2"
tiktoken = "^0.5.0"
webdriver-manager = "*"
demjson3 = "^3.0.0"
kcze marked this conversation as resolved.
Show resolved Hide resolved

# OpenAI and Generic plugins import
openapi-python-client = "^0.14.0"
Expand Down
75 changes: 75 additions & 0 deletions autogpts/autogpt/tests/unit/test_json_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import pytest

from autogpt.json_utils.utilities import json_loads

_JSON_FIXABLE: list[tuple[str, str]] = [
# Missing comma
('{"name": "John Doe" "age": 30,}', '{"name": "John Doe", "age": 30}'),
("[1, 2 3]", "[1, 2, 3]"),
# Trailing comma
('{"name": "John Doe", "age": 30,}', '{"name": "John Doe", "age": 30}'),
("[1, 2, 3,]", "[1, 2, 3]"),
# Extra comma in object
('{"name": "John Doe",, "age": 30}', '{"name": "John Doe", "age": 30}'),
# Extra newlines
('{"name": "John Doe",\n"age": 30}', '{"name": "John Doe", "age": 30}'),
("[1, 2,\n3]", "[1, 2, 3]"),
# Missing brace
('{"name": "John Doe", "age": 30', '{"name": "John Doe", "age": 30}'),
# Missing bracket
("[1, 2, 3", "[1, 2, 3]"),
# Different numerals
("[+1, ---2, .5, +-4.5, 123.]", "[1, -2, 0.5, -4.5, 123]"),
('{"bin": 0b1001, "hex": 0x1A, "oct": 0o17}', '{"bin": 9, "hex": 26, "oct": 15}'),
# Broken array
(
'[1, 2 3, "yes" true, false null, 25, {"obj": "var"}',
'[1, 2, 3, "yes", true, false, null, 25, {"obj": "var"}]',
),
# Multiple problems
(
'{"name":"John Doe" "age": 30\n "empty": "","address": '
"// random comment\n"
'{"city": "New\nYork", "state": "NY",'
'"skills": ["Python" "C++", "Java",""],}',
'{"name": "John Doe", "age": 30, "empty": "", "address": '
'{"city": "New\nYork", "state": "NY", '
'"skills": ["Python", "C++", "Java", ""]}}',
),
# All good
(
'{"name": "John Doe", "age": 30, "address": '
'{"city": "New\nYork", "state": "NY", '
'"skills": ["Python", "C++", "Java"]}}',
'{"name": "John Doe", "age": 30, "address": '
'{"city": "New\nYork", "state": "NY", '
'"skills": ["Python", "C++", "Java"]}}',
),
]

_JSON_UNFIXABLE: list[tuple[str, str]] = [
# Broken booleans
("[TRUE, False]", "[true, false]"),
# Missing values in array
("[1, , 3]", "[1, 3]"),
# Leading zeros (are treated as octal)
("[0023, 015]", "[23, 15]"),
]


@pytest.fixture(params=_JSON_FIXABLE)
def fixable_json(request) -> tuple[str, str]:
return request.param


@pytest.fixture(params=_JSON_UNFIXABLE)
def unfixable_json(request) -> tuple[str, str]:
return request.param


def test_json_loads_fixable(fixable_json: tuple[str, str]):
assert json_loads(fixable_json[0]) == json_loads(fixable_json[1])


def test_json_loads_unfixable(unfixable_json: tuple[str, str]):
assert json_loads(unfixable_json[0]) != json_loads(unfixable_json[1])
Loading