Skip to content

Commit

Permalink
Python: small fix for CH serialization (#5738)
Browse files Browse the repository at this point in the history
### Motivation and Context

<!-- Thank you for your contribution to the semantic-kernel repo!
Please help reviewers and future users, providing the following
information:
  1. Why is this change required?
  2. What problem does it solve?
  3. What scenario does it contribute to?
  4. If it fixes an open issue, please link to the issue here.
-->
Noticed a small issue when ChatHistory is created with different types
of messages, and then serialized.
This fixed that and adds a test for that case.

### Description

<!-- Describe your changes, the overall approach, the underlying design.
These notes will help understanding how your code works. Thanks! -->

### Contribution Checklist

<!-- Before submitting this PR, please make sure: -->

- [x] The code builds clean without any errors or warnings
- [x] The PR follows the [SK Contribution
Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [x] All unit tests pass, and I have added new tests where possible
- [x] I didn't break anyone 😄
  • Loading branch information
eavanvalkenburg authored Apr 2, 2024
1 parent c65644a commit 290f44d
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,10 @@ def _validate_tool_calls(cls, tool_calls: Any) -> Optional[List[ToolCall]]:
if isinstance(tool_calls, list):
for index, call in enumerate(tool_calls):
if not isinstance(call, ToolCall):
tool_calls[index] = ToolCall.model_validate_json(call)
if isinstance(call, dict):
tool_calls[index] = ToolCall.model_validate(call)
else:
tool_calls[index] = ToolCall.model_validate_json(call)
return tool_calls
if isinstance(tool_calls, str):
return [ToolCall.model_validate_json(call) for call in tool_calls.split("|")]
Expand All @@ -53,6 +56,8 @@ def _validate_function_call(cls, function_call: Any) -> Optional[FunctionCall]:
return None
if isinstance(function_call, FunctionCall):
return function_call
if isinstance(function_call, dict):
return FunctionCall.model_validate(function_call)
return FunctionCall.model_validate_json(function_call)

@staticmethod
Expand Down
16 changes: 15 additions & 1 deletion python/semantic_kernel/contents/chat_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from xml.etree.ElementTree import Element, tostring

from defusedxml.ElementTree import XML, ParseError
from pydantic import field_validator

from semantic_kernel.contents.chat_message_content import ChatMessageContent
from semantic_kernel.contents.chat_message_content_base import ChatMessageContentBase
Expand Down Expand Up @@ -34,7 +35,7 @@ class ChatHistory(KernelBaseModel):
messages (List[ChatMessageContent]): The list of chat messages in the history.
"""

messages: list["ChatMessageContent"]
messages: list[ChatMessageContent]
message_type: TYPES_CHAT_MESSAGE_CONTENT = CHAT_MESSAGE_CONTENT

def __init__(self, **data: Any):
Expand Down Expand Up @@ -75,6 +76,19 @@ def __init__(self, **data: Any):
data["messages"] = []
super().__init__(**data)

@field_validator("messages", mode="before")
@classmethod
def _validate_messages(cls, messages: List[ChatMessageContent]) -> List[ChatMessageContent]:
if not messages:
return messages
out_msgs: List[ChatMessageContent] = []
for message in messages:
if isinstance(message, dict):
out_msgs.append(ChatMessageContentBase.from_dict(message))
else:
out_msgs.append(message)
return out_msgs

def add_system_message(self, content: str, **kwargs: Any) -> None:
"""Add a system message to the chat history."""
self.add_message(message=self._prepare_for_add(ChatRole.SYSTEM, content, **kwargs))
Expand Down
13 changes: 12 additions & 1 deletion python/tests/unit/contents/test_chat_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def test_add_system_message(chat_history: ChatHistory):
assert chat_history.messages[-1].role == ChatRole.SYSTEM


def test_add_system_message_at_init(chat_history: ChatHistory):
def test_add_system_message_at_init():
content = "System message"
chat_history = ChatHistory(system_message=content)
assert chat_history.messages[-1].content == content
Expand Down Expand Up @@ -190,6 +190,17 @@ def test_serialize(): # ignore: E501
)


def test_serialize_and_deserialize_to_chat_history_mixed_content():
system_msg = "a test system prompt"
msgs = [ChatMessageContent(role=ChatRole.USER, content=f"Message {i}") for i in range(3)]
msgs.extend([OpenAIChatMessageContent(role=ChatRole.USER, content=f"Message {i}") for i in range(3)])
msgs.extend([AzureChatMessageContent(role=ChatRole.USER, content=f"Message {i}") for i in range(3)])
chat_history = ChatHistory(messages=msgs, system_message=system_msg)
json_str = chat_history.serialize()
new_chat_history = ChatHistory.restore_chat_history(json_str)
assert new_chat_history == chat_history


def test_serialize_and_deserialize_to_chat_history():
system_msg = "a test system prompt"
msgs = [ChatMessageContent(role=ChatRole.USER, content=f"Message {i}") for i in range(3)]
Expand Down

0 comments on commit 290f44d

Please sign in to comment.