Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

anthropic[patch]: fix tool call and tool res image_url handling #26587

Merged
merged 4 commits into from
Sep 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 44 additions & 28 deletions libs/partners/anthropic/langchain_anthropic/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,65 +194,81 @@ def _format_messages(

# populate content
content = []
for item in message.content:
if isinstance(item, str):
content.append({"type": "text", "text": item})
elif isinstance(item, dict):
if "type" not in item:
raise ValueError("Dict content item must have a type key")
elif item["type"] == "image_url":
for block in message.content:
if isinstance(block, str):
content.append({"type": "text", "text": block})
elif isinstance(block, dict):
if "type" not in block:
raise ValueError("Dict content block must have a type key")
elif block["type"] == "image_url":
# convert format
source = _format_image(item["image_url"]["url"])
source = _format_image(block["image_url"]["url"])
content.append({"type": "image", "source": source})
elif item["type"] == "tool_use":
elif block["type"] == "tool_use":
# If a tool_call with the same id as a tool_use content block
# exists, the tool_call is preferred.
if isinstance(message, AIMessage) and item["id"] in [
if isinstance(message, AIMessage) and block["id"] in [
tc["id"] for tc in message.tool_calls
]:
overlapping = [
tc
for tc in message.tool_calls
if tc["id"] == item["id"]
if tc["id"] == block["id"]
]
content.extend(
_lc_tool_calls_to_anthropic_tool_use_blocks(overlapping)
)
else:
item.pop("text", None)
content.append(item)
elif item["type"] == "text":
text = item.get("text", "")
block.pop("text", None)
content.append(block)
elif block["type"] == "text":
text = block.get("text", "")
# Only add non-empty strings for now as empty ones are not
# accepted.
# https://github.com/anthropics/anthropic-sdk-python/issues/461
if text.strip():
content.append(
{
k: v
for k, v in item.items()
for k, v in block.items()
if k in ("type", "text", "cache_control")
}
)
elif block["type"] == "tool_result":
tool_content = _format_messages(
[HumanMessage(block["content"])]
)[1][0]["content"]
content.append({**block, **{"content": tool_content}})
else:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would be nice to get rid of the else branches if at all possible (and instead have an elif with an instance check)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there's no other known types so alternative would be to raise an error on else

content.append(item)
content.append(block)
else:
raise ValueError(
f"Content items must be str or dict, instead was: {type(item)}"
f"Content blocks must be str or dict, instead was: "
f"{type(block)}"
)
elif isinstance(message, AIMessage) and message.tool_calls:
content = (
[]
if not message.content
else [{"type": "text", "text": message.content}]
)
# Note: Anthropic can't have invalid tool calls as presently defined,
# since the model already returns dicts args not JSON strings, and invalid
# tool calls are those with invalid JSON for args.
content += _lc_tool_calls_to_anthropic_tool_use_blocks(message.tool_calls)
else:
content = message.content

# Ensure all tool_calls have a tool_use content block
if isinstance(message, AIMessage) and message.tool_calls:
baskaryan marked this conversation as resolved.
Show resolved Hide resolved
content = content or []
content = (
[{"type": "text", "text": message.content}]
if isinstance(content, str) and content
else content
)
tool_use_ids = [
cast(dict, block)["id"]
for block in content
if cast(dict, block)["type"] == "tool_use"
]
missing_tool_calls = [
tc for tc in message.tool_calls if tc["id"] not in tool_use_ids
]
cast(list, content).extend(
_lc_tool_calls_to_anthropic_tool_use_blocks(missing_tool_calls)
)

formatted_messages.append({"role": role, "content": content})
return system, formatted_messages

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ def chat_model_params(self) -> dict:
def supports_image_inputs(self) -> bool:
return True

@property
def supports_image_tool_message(self) -> bool:
    """Whether the model under test accepts image content blocks inside a
    ``ToolMessage`` (gates ``test_image_tool_message``); this subclass
    opts in by returning ``True``."""
    return True

@property
def supports_anthropic_inputs(self) -> bool:
return True
87 changes: 80 additions & 7 deletions libs/partners/anthropic/tests/unit_tests/test_chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,15 +366,36 @@ def test_convert_to_anthropic_tool(
def test__format_messages_with_tool_calls() -> None:
system = SystemMessage("fuzz") # type: ignore[misc]
human = HumanMessage("foo") # type: ignore[misc]
ai = AIMessage( # type: ignore[misc]
"",
ai = AIMessage(
"", # with empty string
tool_calls=[{"name": "bar", "id": "1", "args": {"baz": "buzz"}}],
)
tool = ToolMessage( # type: ignore[misc]
ai2 = AIMessage(
[], # with empty list
tool_calls=[{"name": "bar", "id": "2", "args": {"baz": "buzz"}}],
)
tool = ToolMessage(
"blurb",
tool_call_id="1",
)
messages = [system, human, ai, tool]
tool_image_url = ToolMessage(
[{"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,...."}}],
tool_call_id="2",
)
tool_image = ToolMessage(
[
{
"type": "image",
"source": {
"data": "....",
"type": "base64",
"media_type": "image/jpeg",
},
}
],
tool_call_id="3",
)
messages = [system, human, ai, tool, ai2, tool_image_url, tool_image]
expected = (
"fuzz",
[
Expand All @@ -401,6 +422,52 @@ def test__format_messages_with_tool_calls() -> None:
}
],
},
{
"role": "assistant",
"content": [
{
"type": "tool_use",
"name": "bar",
"id": "2",
"input": {"baz": "buzz"},
}
],
},
{
"role": "user",
"content": [
{
"type": "tool_result",
"content": [
{
"type": "image",
"source": {
"data": "....",
"type": "base64",
"media_type": "image/jpeg",
},
}
],
"tool_use_id": "2",
"is_error": False,
},
{
"type": "tool_result",
"content": [
{
"type": "image",
"source": {
"data": "....",
"type": "base64",
"media_type": "image/jpeg",
},
}
],
"tool_use_id": "3",
"is_error": False,
},
],
},
],
)
actual = _format_messages(messages)
Expand Down Expand Up @@ -454,8 +521,6 @@ def test__format_messages_with_str_content_and_tool_calls() -> None:
def test__format_messages_with_list_content_and_tool_calls() -> None:
system = SystemMessage("fuzz") # type: ignore[misc]
human = HumanMessage("foo") # type: ignore[misc]
# If content and tool_calls are specified and content is a list, then content is
# preferred.
ai = AIMessage( # type: ignore[misc]
[{"type": "text", "text": "thought"}],
tool_calls=[{"name": "bar", "id": "1", "args": {"baz": "buzz"}}],
Expand All @@ -471,7 +536,15 @@ def test__format_messages_with_list_content_and_tool_calls() -> None:
{"role": "user", "content": "foo"},
{
"role": "assistant",
"content": [{"type": "text", "text": "thought"}],
"content": [
{"type": "text", "text": "thought"},
{
"type": "tool_use",
"name": "bar",
"id": "1",
"input": {"baz": "buzz"},
},
],
},
{
"role": "user",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def chat_model_class(self) -> Type[BaseChatModel]:

@property
def chat_model_params(self) -> dict:
return {"model": "gpt-4o", "stream_usage": True}
return {"model": "gpt-4o-mini", "stream_usage": True}

@property
def supports_image_inputs(self) -> bool:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,37 @@ def test_image_inputs(self, model: BaseChatModel) -> None:
)
model.invoke([message])

def test_image_tool_message(self, model: BaseChatModel) -> None:
    """Verify the model accepts an image returned as a tool result.

    Skipped (via early return) when ``supports_image_tool_message`` is
    False. Downloads a sample image, base64-encodes it, and replays a
    tool-call conversation whose ``ToolMessage`` content is an
    ``image_url`` block; passing means ``invoke`` raises no error.
    """
    if not self.supports_image_tool_message:
        return
    image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
    # Fetch the image bytes and base64-encode them for a data URL.
    image_data = base64.b64encode(httpx.get(image_url).content).decode("utf-8")
    messages = [
        HumanMessage("get a random image using the tool and describe the weather"),
        # AIMessage with an empty content list but a populated tool_call,
        # so the following ToolMessage has a matching tool_call_id.
        AIMessage(
            [],
            tool_calls=[
                {"type": "tool_call", "id": "1", "name": "random_image", "args": {}}
            ],
        ),
        # Tool result whose content is an image_url block rather than text.
        ToolMessage(
            content=[
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/jpeg;base64,{image_data}"},
                },
            ],
            tool_call_id="1",
            name="random_image",
        ),
    ]

    def random_image() -> str:
        """Return a random image."""
        # Stub tool: bound only so the assistant tool_call is valid;
        # the test never calls it directly.
        return ""

    # Success is the absence of an exception from the provider.
    model.bind_tools([random_image]).invoke(messages)

def test_anthropic_inputs(self, model: BaseChatModel) -> None:
if not self.supports_anthropic_inputs:
return
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,10 @@ def returns_usage_metadata(self) -> bool:
def supports_anthropic_inputs(self) -> bool:
return False

@property
def supports_image_tool_message(self) -> bool:
    """Default capability flag: image content blocks inside a
    ``ToolMessage`` are assumed unsupported; test subclasses for models
    that handle them override this to return ``True``."""
    return False


class ChatModelUnitTests(ChatModelTests):
@property
Expand Down
Loading