# Copyright (c) Microsoft. All rights reserved.

import asyncio
import copy
import logging
from abc import ABC
from collections.abc import AsyncGenerator, Callable
from functools import reduce
from typing import TYPE_CHECKING, Any, ClassVar

from opentelemetry.trace import Span, Tracer, get_tracer, use_span

from semantic_kernel.connectors.ai.function_call_behavior import FunctionCallBehavior
from semantic_kernel.connectors.ai.function_call_choice_configuration import FunctionCallChoiceConfiguration
from semantic_kernel.connectors.ai.function_calling_utils import merge_function_results
from semantic_kernel.connectors.ai.function_choice_behavior import FunctionChoiceBehavior, FunctionChoiceType
from semantic_kernel.const import AUTO_FUNCTION_INVOCATION_SPAN_NAME
from semantic_kernel.contents.annotation_content import AnnotationContent
from semantic_kernel.contents.file_reference_content import FileReferenceContent
from semantic_kernel.contents.function_call_content import FunctionCallContent
from semantic_kernel.exceptions.service_exceptions import ServiceInvalidExecutionSettingsError
from semantic_kernel.services.ai_service_client_base import AIServiceClientBase
from semantic_kernel.utils.telemetry.model_diagnostics.gen_ai_attributes import AVAILABLE_FUNCTIONS

if TYPE_CHECKING:
    from semantic_kernel.connectors.ai.prompt_execution_settings import PromptExecutionSettings
    from semantic_kernel.contents.chat_history import ChatHistory
    from semantic_kernel.contents.chat_message_content import ChatMessageContent
    from semantic_kernel.contents.streaming_chat_message_content import StreamingChatMessageContent
    from semantic_kernel.kernel import Kernel

logger: logging.Logger = logging.getLogger(__name__)

tracer: Tracer = get_tracer(__name__)


class ChatCompletionClientBase(AIServiceClientBase, ABC):
    """Base class for chat completion AI services."""

    # Connectors that support function calling should set this to True
    SUPPORTS_FUNCTION_CALLING: ClassVar[bool] = False

    # region Internal methods to be implemented by the derived classes

    async def _inner_get_chat_message_contents(
        self,
        chat_history: "ChatHistory",
        settings: "PromptExecutionSettings",
    ) -> list["ChatMessageContent"]:
        """Send a chat request to the AI service.

        Args:
            chat_history (ChatHistory): The chat history to send.
            settings (PromptExecutionSettings): The settings for the request.

        Returns:
            chat_message_contents (list[ChatMessageContent]): The chat message contents representing the response(s).
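
        Example:
            A minimal sketch of a connector override; the `client` attribute and its
            `chat` call are illustrative placeholders, not a real SDK API:

                async def _inner_get_chat_message_contents(self, chat_history, settings):
                    # Send the rendered history and settings to the backing service (hypothetical client).
                    response = await self.client.chat(
                        messages=self._prepare_chat_history_for_request(chat_history),
                        **settings.prepare_settings_dict(),
                    )
                    # Wrap each returned choice in a ChatMessageContent.
                    return [
                        ChatMessageContent(role=AuthorRole.ASSISTANT, content=choice.text)
                        for choice in response.choices
                    ]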
"""
raise NotImplementedError("The _inner_get_chat_message_contents method is not implemented.")

    async def _inner_get_streaming_chat_message_contents(
        self,
        chat_history: "ChatHistory",
        settings: "PromptExecutionSettings",
    ) -> AsyncGenerator[list["StreamingChatMessageContent"], Any]:
        """Send a streaming chat request to the AI service.

        Args:
            chat_history (ChatHistory): The chat history to send.
            settings (PromptExecutionSettings): The settings for the request.

        Yields:
            streaming_chat_message_contents (list[StreamingChatMessageContent]): The streaming chat message contents.
        """
        raise NotImplementedError("The _inner_get_streaming_chat_message_contents method is not implemented.")
        # Below is needed for mypy: https://mypy.readthedocs.io/en/stable/more_types.html#asynchronous-iterators
        if False:
            yield

    # endregion

    # region Public methods

    async def get_chat_message_contents(
        self,
        chat_history: "ChatHistory",
        settings: "PromptExecutionSettings",
        **kwargs: Any,
    ) -> list["ChatMessageContent"]:
        """Create chat message contents, in the number specified by the settings.

        Args:
            chat_history (ChatHistory): A chat history object containing the messages to send,
                which can be rendered into messages from system, user, assistant and tools.
            settings (PromptExecutionSettings): Settings for the request.
            **kwargs (Any): The optional arguments.

        Returns:
            A list of chat message contents representing the response(s) from the LLM.
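
        Example:
            A usage sketch with automatic function calling, assuming an OpenAI connector;
            the service id and model name are illustrative:

                from semantic_kernel import Kernel
                from semantic_kernel.connectors.ai.function_choice_behavior import FunctionChoiceBehavior
                from semantic_kernel.connectors.ai.open_ai import OpenAIChatCompletion
                from semantic_kernel.connectors.ai.open_ai import OpenAIChatPromptExecutionSettings
                from semantic_kernel.contents.chat_history import ChatHistory

                kernel = Kernel()
                service = OpenAIChatCompletion(service_id="chat", ai_model_id="gpt-4o")
                kernel.add_service(service)

                # Auto() enables the auto invoke loop implemented by this method.
                settings = OpenAIChatPromptExecutionSettings(
                    function_choice_behavior=FunctionChoiceBehavior.Auto(),
                )
                chat_history = ChatHistory()
                chat_history.add_user_message("What is the weather in Paris?")

                results = await service.get_chat_message_contents(
                    chat_history=chat_history, settings=settings, kernel=kernel
                )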
"""
# Create a copy of the settings to avoid modifying the original settings
settings = copy.deepcopy(settings)
if not self.SUPPORTS_FUNCTION_CALLING:
return await self._inner_get_chat_message_contents(chat_history, settings)
# For backwards compatibility we need to convert the `FunctionCallBehavior` to `FunctionChoiceBehavior`
# if this method is called with a `FunctionCallBehavior` object as part of the settings
if hasattr(settings, "function_call_behavior") and isinstance(
settings.function_call_behavior, FunctionCallBehavior
):
settings.function_choice_behavior = FunctionChoiceBehavior.from_function_call_behavior(
settings.function_call_behavior
)
kernel: "Kernel" = kwargs.get("kernel") # type: ignore
if settings.function_choice_behavior is not None:
if kernel is None:
raise ServiceInvalidExecutionSettingsError("The kernel is required for function calls.")
self._verify_function_choice_settings(settings)
if settings.function_choice_behavior and kernel:
# Configure the function choice behavior into the settings object
# that will become part of the request to the AI service
settings.function_choice_behavior.configure(
kernel=kernel,
update_settings_callback=self._update_function_choice_settings_callback(),
settings=settings,
)
if (
settings.function_choice_behavior is None
or not settings.function_choice_behavior.auto_invoke_kernel_functions
):
return await self._inner_get_chat_message_contents(chat_history, settings)
# Auto invoke loop
with use_span(self._start_auto_function_invocation_activity(kernel, settings), end_on_exit=True) as _:
for request_index in range(settings.function_choice_behavior.maximum_auto_invoke_attempts):
completions = await self._inner_get_chat_message_contents(chat_history, settings)
# Get the function call contents from the chat message. There is only one chat message,
# which should be checked in the `_verify_function_choice_settings` method.
function_calls = [item for item in completions[0].items if isinstance(item, FunctionCallContent)]
if (fc_count := len(function_calls)) == 0:
return completions
# Since we have a function call, add the assistant's tool call message to the history
chat_history.add_message(message=completions[0])
logger.info(f"processing {fc_count} tool calls in parallel.")
# This function either updates the chat history with the function call results
# or returns the context, with terminate set to True in which case the loop will
# break and the function calls are returned.
results = await asyncio.gather(
*[
kernel.invoke_function_call(
function_call=function_call,
chat_history=chat_history,
arguments=kwargs.get("arguments"),
function_call_count=fc_count,
request_index=request_index,
function_behavior=settings.function_choice_behavior,
)
for function_call in function_calls
],
)
if any(result.terminate for result in results if result is not None):
return merge_function_results(chat_history.messages[-len(results) :])
else:
# Do a final call, without function calling when the max has been reached.
self._reset_function_choice_settings(settings)
return await self._inner_get_chat_message_contents(chat_history, settings)

    async def get_chat_message_content(
        self, chat_history: "ChatHistory", settings: "PromptExecutionSettings", **kwargs: Any
    ) -> "ChatMessageContent | None":
        """This is the method that is called from the kernel to get a response from a chat-optimized LLM.

        Args:
            chat_history (ChatHistory): A chat history object containing the messages to send,
                which can be rendered into messages from system, user, assistant and tools.
            settings (PromptExecutionSettings): Settings for the request.
            **kwargs (Any): The optional arguments.

        Returns:
            A chat message content representing the response from the LLM, or None.
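
        Example:
            A brief sketch, assuming `service`, `chat_history`, `settings` and `kernel`
            are set up as in the `get_chat_message_contents` example:

                result = await service.get_chat_message_content(
                    chat_history=chat_history, settings=settings, kernel=kernel
                )
                if result is not None:
                    print(result.content)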
"""
results = await self.get_chat_message_contents(chat_history=chat_history, settings=settings, **kwargs)
if results:
return results[0]
# this should not happen, should error out before returning an empty list
return None # pragma: no cover

    async def get_streaming_chat_message_contents(
        self,
        chat_history: "ChatHistory",
        settings: "PromptExecutionSettings",
        **kwargs: Any,
    ) -> AsyncGenerator[list["StreamingChatMessageContent"], Any]:
        """Create streaming chat message contents, in the number specified by the settings.

        Args:
            chat_history (ChatHistory): A chat history object containing the messages to send,
                which can be rendered into messages from system, user, assistant and tools.
            settings (PromptExecutionSettings): Settings for the request.
            **kwargs (Any): The optional arguments.

        Yields:
            A stream representing the response(s) from the LLM.
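
        Example:
            A usage sketch, assuming `service`, `chat_history`, `settings` and `kernel`
            are set up as in the `get_chat_message_contents` example:

                async for messages in service.get_streaming_chat_message_contents(
                    chat_history=chat_history, settings=settings, kernel=kernel
                ):
                    for message in messages:
                        print(message.content, end="")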
"""
# Create a copy of the settings to avoid modifying the original settings
settings = copy.deepcopy(settings)
if not self.SUPPORTS_FUNCTION_CALLING:
async for streaming_chat_message_contents in self._inner_get_streaming_chat_message_contents(
chat_history, settings
):
yield streaming_chat_message_contents
return
# For backwards compatibility we need to convert the `FunctionCallBehavior` to `FunctionChoiceBehavior`
# if this method is called with a `FunctionCallBehavior` object as part of the settings
if hasattr(settings, "function_call_behavior") and isinstance(
settings.function_call_behavior, FunctionCallBehavior
):
settings.function_choice_behavior = FunctionChoiceBehavior.from_function_call_behavior(
settings.function_call_behavior
)
kernel: "Kernel" = kwargs.get("kernel") # type: ignore
if settings.function_choice_behavior is not None:
if kernel is None:
raise ServiceInvalidExecutionSettingsError("The kernel is required for function calls.")
self._verify_function_choice_settings(settings)
if settings.function_choice_behavior and kernel:
# Configure the function choice behavior into the settings object
# that will become part of the request to the AI service
settings.function_choice_behavior.configure(
kernel=kernel,
update_settings_callback=self._update_function_choice_settings_callback(),
settings=settings,
)
if (
settings.function_choice_behavior is None
or not settings.function_choice_behavior.auto_invoke_kernel_functions
):
async for streaming_chat_message_contents in self._inner_get_streaming_chat_message_contents(
chat_history, settings
):
yield streaming_chat_message_contents
return
# Auto invoke loop
with use_span(self._start_auto_function_invocation_activity(kernel, settings), end_on_exit=True) as _:
for request_index in range(settings.function_choice_behavior.maximum_auto_invoke_attempts):
# Hold the messages, if there are more than one response, it will not be used, so we flatten
all_messages: list["StreamingChatMessageContent"] = []
function_call_returned = False
async for messages in self._inner_get_streaming_chat_message_contents(chat_history, settings):
for msg in messages:
if msg is not None:
all_messages.append(msg)
if any(isinstance(item, FunctionCallContent) for item in msg.items):
function_call_returned = True
yield messages
if not function_call_returned:
return
# There is one FunctionCallContent response stream in the messages, combining now to create
# the full completion depending on the prompt, the message may contain both function call
# content and others
full_completion: StreamingChatMessageContent = reduce(lambda x, y: x + y, all_messages)
function_calls = [item for item in full_completion.items if isinstance(item, FunctionCallContent)]
chat_history.add_message(message=full_completion)
fc_count = len(function_calls)
logger.info(f"processing {fc_count} tool calls in parallel.")
# This function either updates the chat history with the function call results
# or returns the context, with terminate set to True in which case the loop will
# break and the function calls are returned.
results = await asyncio.gather(
*[
kernel.invoke_function_call(
function_call=function_call,
chat_history=chat_history,
arguments=kwargs.get("arguments"),
function_call_count=fc_count,
request_index=request_index,
function_behavior=settings.function_choice_behavior,
)
for function_call in function_calls
],
)
if any(result.terminate for result in results if result is not None):
yield merge_function_results(chat_history.messages[-len(results) :]) # type: ignore
break

    async def get_streaming_chat_message_content(
        self,
        chat_history: "ChatHistory",
        settings: "PromptExecutionSettings",
        **kwargs: Any,
    ) -> AsyncGenerator["StreamingChatMessageContent | None", Any]:
        """This is the method that is called from the kernel to get a stream response from a chat-optimized LLM.

        Args:
            chat_history (ChatHistory): A chat history object containing the messages to send,
                which can be rendered into messages from system, user, assistant and tools.
            settings (PromptExecutionSettings): Settings for the request.
            **kwargs (Any): The optional arguments.

        Yields:
            A stream representing the response(s) from the LLM.
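
        Example:
            A brief sketch, printing the first choice of each streamed chunk:

                async for chunk in service.get_streaming_chat_message_content(
                    chat_history=chat_history, settings=settings, kernel=kernel
                ):
                    if chunk is not None:
                        print(chunk.content, end="")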
"""
async for streaming_chat_message_contents in self.get_streaming_chat_message_contents(
chat_history, settings, **kwargs
):
if streaming_chat_message_contents:
yield streaming_chat_message_contents[0]
else:
# this should not happen, should error out before returning an empty list
yield None # pragma: no cover
# endregion
# region internal handlers

    def _prepare_chat_history_for_request(
        self,
        chat_history: "ChatHistory",
        role_key: str = "role",
        content_key: str = "content",
    ) -> Any:
        """Prepare the chat history for a request.

        Allowing customization of the key names for role/author, and optionally overriding the role.

        ChatRole.TOOL messages need to be formatted differently than system/user/assistant messages:
        they require a "tool_call_id" and (function) "name" key, and the "metadata" and
        "encoding" keys should be removed.

        Override this method to customize the formatting of the chat history for a request.

        Args:
            chat_history (ChatHistory): The chat history to prepare.
            role_key (str): The key name for the role/author.
            content_key (str): The key name for the content/message.

        Returns:
            prepared_chat_history (Any): The prepared chat history for a request.
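
        Example:
            With the default keys, each message is rendered via `message.to_dict`; a
            typical result might look like:

                [
                    {"role": "system", "content": "You are a helpful assistant."},
                    {"role": "user", "content": "What is the weather in Paris?"},
                ]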
"""
return [
message.to_dict(role_key=role_key, content_key=content_key)
for message in chat_history.messages
if not isinstance(message, (AnnotationContent, FileReferenceContent))
]
def _verify_function_choice_settings(self, settings: "PromptExecutionSettings") -> None:
"""Additional verification to validate settings for function choice behavior.
Override this method to add additional verification for the settings.
Args:
settings (PromptExecutionSettings): The settings to verify.
"""
return

    def _update_function_choice_settings_callback(
        self,
    ) -> Callable[[FunctionCallChoiceConfiguration, "PromptExecutionSettings", FunctionChoiceType], None]:
        """Return the callback function to update the settings from a function call configuration.

        Override this method to provide a custom callback function to
        update the settings from a function call configuration.
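
        Example:
            A sketch of an override for a connector whose settings expose a hypothetical
            `tools` field; the field name and the `to_tool_definition` helper are
            illustrative, not part of Semantic Kernel:

                def _update_function_choice_settings_callback(self):
                    def update_settings(configuration, settings, choice_type):
                        # Map the available functions onto the request's tool definitions.
                        settings.tools = [
                            to_tool_definition(f) for f in configuration.available_functions or []
                        ]

                    return update_settings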
"""
return lambda configuration, settings, choice_type: None
def _reset_function_choice_settings(self, settings: "PromptExecutionSettings") -> None:
"""Reset the settings updated by `_update_function_choice_settings_callback`.
Override this method to reset the settings updated by `_update_function_choice_settings_callback`.
Args:
settings (PromptExecutionSettings): The prompt execution settings to reset.
"""
return
def _start_auto_function_invocation_activity(self, kernel: "Kernel", settings: "PromptExecutionSettings") -> Span:
"""Start the auto function invocation activity.
Args:
kernel (Kernel): The kernel instance.
settings (PromptExecutionSettings): The prompt execution settings.
"""
span = tracer.start_span(AUTO_FUNCTION_INVOCATION_SPAN_NAME)
if settings.function_choice_behavior is not None:
available_functions = settings.function_choice_behavior.get_config(kernel).available_functions or []
span.set_attribute(
AVAILABLE_FUNCTIONS,
",".join([f.fully_qualified_name for f in available_functions]),
)
return span
# endregion