-
Notifications
You must be signed in to change notification settings - Fork 3.4k
/
chat_history.py
307 lines (253 loc) · 12.3 KB
/
chat_history.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
# Copyright (c) Microsoft. All rights reserved.
from __future__ import annotations
import logging
from typing import Any, Iterator, List
from xml.etree.ElementTree import Element, tostring
from defusedxml.ElementTree import XML, ParseError
from pydantic import field_validator
from semantic_kernel.contents.chat_message_content import ChatMessageContent
from semantic_kernel.contents.chat_message_content_base import ChatMessageContentBase
from semantic_kernel.contents.chat_role import ChatRole
from semantic_kernel.contents.const import (
CHAT_MESSAGE_CONTENT,
ROOT_KEY_HISTORY,
ROOT_KEY_MESSAGE,
TYPES_CHAT_MESSAGE_CONTENT,
)
from semantic_kernel.exceptions import ContentInitializationError, ContentSerializationError
from semantic_kernel.kernel_pydantic import KernelBaseModel
logger = logging.getLogger(__name__)
class ChatHistory(KernelBaseModel):
    """Hold the ordered list of messages exchanged in a chat conversation.

    Note: the constructor takes a system_message parameter, which is not part
    of the class definition. This is to allow the system_message to be passed in
    as a keyword argument, but not be part of the class definition.

    Attributes:
        messages (List[ChatMessageContent]): The list of chat messages in the history.
        message_type (TYPES_CHAT_MESSAGE_CONTENT): The type passed along when messages are
            built from plain fields or from dictionaries that do not specify a 'type'.
            Defaults to CHAT_MESSAGE_CONTENT.
    """

    messages: list[ChatMessageContent]
    message_type: TYPES_CHAT_MESSAGE_CONTENT = CHAT_MESSAGE_CONTENT

    def __init__(self, **data: Any):
        """Initialize the history, optionally prepending a system message.

        If both 'messages' and 'system_message' are provided, the system message is
        placed in front of the provided messages. If only 'system_message' is provided,
        the history starts with that single message. If only 'messages' are provided,
        they are used as is. If neither is provided, the history starts empty.

        Args:
            **data: Arbitrary keyword arguments. The constructor looks for these optional keys:
                - 'messages' (Optional[List[ChatMessageContent]]): chat messages to include
                  in the history.
                - 'system_message' (Optional[str]): text of a system message to place at the
                  start of the history.
                - 'message_type': the type used to build the system message; defaults to
                  CHAT_MESSAGE_CONTENT.

        Note: 'system_message' is removed from the keyword arguments before they are passed
        to the superclass constructor, so it is not retained as an attribute.
        """
        system_message_content = data.pop("system_message", None)
        message_type = data.get("message_type", CHAT_MESSAGE_CONTENT)
        if system_message_content:
            system_message = ChatMessageContentBase.from_fields(
                role=ChatRole.SYSTEM, content=system_message_content, type=message_type
            )
            # Unpacking (rather than list concatenation) also accepts tuples and other iterables.
            data["messages"] = [system_message, *data.get("messages", [])]
        data.setdefault("messages", [])
        super().__init__(**data)

    @field_validator("messages", mode="before")
    @classmethod
    def _validate_messages(cls, messages: List[ChatMessageContent]) -> List[ChatMessageContent]:
        """Convert dictionary entries in 'messages' into message objects before validation.

        Entries that are not dictionaries are passed through unchanged. A falsy input
        (for example an empty list) is returned as is.
        """
        if not messages:
            return messages
        return [
            ChatMessageContentBase.from_dict(message) if isinstance(message, dict) else message
            for message in messages
        ]

    def add_system_message(self, content: str, **kwargs: Any) -> None:
        """Add a system message to the chat history."""
        self.add_message(message=self._prepare_for_add(ChatRole.SYSTEM, content, **kwargs))

    def add_user_message(self, content: str, **kwargs: Any) -> None:
        """Add a user message to the chat history."""
        self.add_message(message=self._prepare_for_add(ChatRole.USER, content, **kwargs))

    def add_assistant_message(self, content: str, **kwargs: Any) -> None:
        """Add an assistant message to the chat history."""
        self.add_message(message=self._prepare_for_add(ChatRole.ASSISTANT, content, **kwargs))

    def add_tool_message(
        self, content: str | None = None, metadata: dict[str, Any] | None = None, **kwargs: Any
    ) -> None:
        """Add a tool message to the chat history."""
        self.add_message(message=self._prepare_for_add(ChatRole.TOOL, content, **kwargs), metadata=metadata)

    def add_message(
        self,
        message: ChatMessageContent | dict[str, Any],
        encoding: str | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> None:
        """Add a message to the history.

        This method accepts either a ChatMessageContent instance or a
        dictionary with the necessary information to construct a ChatMessageContent instance.

        Args:
            message (Union[ChatMessageContent, dict]): The message to add, either as
                a pre-constructed ChatMessageContent instance or a dictionary specifying
                at least 'role'. The dictionary is not modified.
            encoding (Optional[str]): The encoding of the message. Only applied when
                'message' is a dict.
            metadata (Optional[dict[str, Any]]): Any metadata to attach to the message.
                Only applied when 'message' is a dict.

        Raises:
            ContentInitializationError: If 'message' is a dict without a 'role' key.
        """
        if isinstance(message, ChatMessageContent):
            self.messages.append(message)
            return
        if "role" not in message:
            raise ContentInitializationError(f"Dictionary must contain at least the role. Got: {message}")
        # Work on a shallow copy so the caller's dictionary is left untouched.
        message_fields = dict(message)
        if encoding:
            message_fields["encoding"] = encoding
        if metadata:
            message_fields["metadata"] = metadata
        if "type" not in message_fields:
            message_fields["type"] = self.message_type
        self.messages.append(ChatMessageContentBase.from_dict(message_fields))

    def _prepare_for_add(self, role: ChatRole, content: str | None = None, **kwargs: Any) -> dict[str, Any]:
        """Prepare a message to be added to the history.

        Returns the keyword arguments as a dictionary with 'role' and 'content' set.
        """
        kwargs["role"] = role
        kwargs["content"] = content
        return kwargs

    def remove_message(self, message: ChatMessageContent) -> bool:
        """Remove a message from the history.

        Args:
            message (ChatMessageContent): The message to remove.

        Returns:
            bool: True if the message was removed, False if the message was not found.
        """
        try:
            self.messages.remove(message)
        except ValueError:
            return False
        return True

    def __len__(self) -> int:
        """Return the number of messages in the history."""
        return len(self.messages)

    def __getitem__(self, index: int) -> ChatMessageContent:
        """Get a message from the history using the [] operator.

        Args:
            index (int): The index of the message to get.

        Returns:
            ChatMessageContent: The message at the specified index.
        """
        return self.messages[index]

    def __contains__(self, item: ChatMessageContent) -> bool:
        """Check if a message is in the history.

        Args:
            item (ChatMessageContent): The message to check for.

        Returns:
            bool: True if the message is in the history, False otherwise.
        """
        return item in self.messages

    def __str__(self) -> str:
        """Return the history as an XML string with one child element per message."""
        chat_history_xml = Element(ROOT_KEY_HISTORY)
        for message in self.messages:
            chat_history_xml.append(message.to_element(root_key=ROOT_KEY_MESSAGE))
        return tostring(chat_history_xml, encoding="unicode", short_empty_elements=True)

    def __iter__(self) -> Iterator[ChatMessageContent]:
        """Return an iterator over the messages in the history."""
        return iter(self.messages)

    def __eq__(self, other: Any) -> bool:
        """Check if two ChatHistory instances are equal.

        Only the messages are compared; any non-ChatHistory object is unequal.
        """
        if not isinstance(other, ChatHistory):
            return False
        return self.messages == other.messages

    @classmethod
    def from_rendered_prompt(cls, rendered_prompt: str, message_type: str = CHAT_MESSAGE_CONTENT) -> "ChatHistory":
        """Create a ChatHistory instance from a rendered prompt.

        The stripped prompt is wrapped in a <prompt> root element and parsed as XML.
        Text before the first element becomes a system message, message elements and
        the children of history elements are converted to messages, and text following
        an element becomes a user message. If the prompt is not valid XML, the whole
        prompt becomes a single user message.

        Args:
            rendered_prompt (str): The rendered prompt to convert to a ChatHistory instance.
            message_type (str): The type used for messages built from plain text and stored
                on the resulting history. Defaults to CHAT_MESSAGE_CONTENT.

        Returns:
            ChatHistory: The ChatHistory instance created from the rendered prompt.
        """
        messages: List[ChatMessageContent] = []
        prompt = rendered_prompt.strip()
        try:
            xml_prompt = XML(text=f"<prompt>{prompt}</prompt>")
        except ParseError:
            logger.info("Could not parse prompt %s as xml, treating as text", prompt)
            # Keep message_type on the fallback path too, matching the normal return below.
            return cls(
                messages=[ChatMessageContentBase.from_fields(role=ChatRole.USER, content=prompt, type=message_type)],
                message_type=message_type,
            )
        if xml_prompt.text and xml_prompt.text.strip():
            messages.append(
                ChatMessageContentBase.from_fields(
                    role=ChatRole.SYSTEM, content=xml_prompt.text.strip(), type=message_type
                )
            )
        for item in xml_prompt:
            if item.tag == ROOT_KEY_MESSAGE:
                messages.append(ChatMessageContentBase.from_element(item))
            elif item.tag == ROOT_KEY_HISTORY:
                messages.extend(ChatMessageContentBase.from_element(message) for message in item)
            # Text that trails an element (before the next one) is treated as user input.
            if item.tail and item.tail.strip():
                messages.append(
                    ChatMessageContentBase.from_fields(role=ChatRole.USER, content=item.tail.strip(), type=message_type)
                )
        # A prompt consisting only of leading text is a user message, not a system message.
        if len(messages) == 1 and messages[0].role == ChatRole.SYSTEM:
            messages[0].role = ChatRole.USER
        return cls(messages=messages, message_type=message_type)

    def serialize(self) -> str:
        """Serialize the ChatHistory instance to a JSON string.

        Returns:
            str: A JSON string representation of the ChatHistory instance.

        Raises:
            ContentSerializationError: If the ChatHistory instance cannot be serialized to JSON.
        """
        try:
            return self.model_dump_json(indent=4, exclude_none=True)
        except Exception as e:
            raise ContentSerializationError(f"Unable to serialize ChatHistory to JSON: {e}") from e

    @classmethod
    def restore_chat_history(cls, chat_history_json: str) -> "ChatHistory":
        """Restore a ChatHistory instance from a JSON string.

        Args:
            chat_history_json (str): The JSON string to deserialize
                into a ChatHistory instance.

        Returns:
            ChatHistory: The deserialized ChatHistory instance.

        Raises:
            ContentInitializationError: If the JSON string is invalid or the deserialized data
                fails validation.
        """
        try:
            # Validate through cls so subclasses are restored as their own type.
            return cls.model_validate_json(chat_history_json)
        except Exception as e:
            raise ContentInitializationError(f"Invalid JSON format: {e}") from e

    def store_chat_history_to_file(self, file_path: str) -> None:
        """Store the serialized ChatHistory to a file.

        The file is written as UTF-8 so the result does not depend on the platform's
        default encoding.

        Args:
            file_path (str): The path to the file where the serialized data will be stored.
        """
        json_str = self.serialize()
        with open(file_path, "w", encoding="utf-8") as file:
            file.write(json_str)

    @classmethod
    def load_chat_history_from_file(cls, file_path: str) -> "ChatHistory":
        """Load the ChatHistory from a file.

        The file is read as UTF-8.

        Args:
            file_path (str): The path to the file from which to load the ChatHistory.

        Returns:
            ChatHistory: The deserialized ChatHistory instance.
        """
        with open(file_path, "r", encoding="utf-8") as file:
            json_str = file.read()
        return cls.restore_chat_history(json_str)