Python: Azure AI Inference tracing SDK (#9693)
### Motivation and Context

Addresses: #9413

### Description

The latest Azure AI Inference SDK has been released with the tracing
package. We are upgrading to it so that we no longer need to instrument
the Azure AI Inference connector with our own model diagnostics module.
A sketch of the SDK-side tracing follows the change list below.

1. Upgrade to the latest Azure AI Inference SDK with the tracing package.
2. Refactor the AI Inference connector to reduce duplicated code.
3. Some other minor fixes.
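
For orientation, here is a minimal, illustrative sketch of what the SDK-side tracing enables. `AIInferenceInstrumentor` is the tracing entry point shipped with `azure-ai-inference >= 1.0.0b6`; the console exporter and setup below are for demonstration only, not the connector's actual wiring (the connector drives this via `AzureAIInferenceTracing`, shown in the diff further down):

```python
# Illustrative sketch: enabling Azure AI Inference SDK tracing directly.
# Assumes azure-ai-inference >= 1.0.0b6 and opentelemetry-sdk are installed.
from azure.ai.inference.tracing import AIInferenceInstrumentor
from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import ConsoleSpanExporter, SimpleSpanProcessor

# Route spans somewhere visible (the console, for simplicity).
provider = TracerProvider()
provider.add_span_processor(SimpleSpanProcessor(ConsoleSpanExporter()))
trace.set_tracer_provider(provider)

# Instrument the Azure AI Inference client library; chat completion and
# embedding calls made afterwards emit OpenTelemetry spans automatically.
instrumentor = AIInferenceInstrumentor()
instrumentor.instrument()

# ... make chat completion / embedding calls here ...

# Tear instrumentation back down when finished.
instrumentor.uninstrument()
```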

### Contribution Checklist

- [x] The code builds clean without any errors or warnings
- [x] The PR follows the [SK Contribution
Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [x] All unit tests pass, and I have added new tests where possible
- [ ] I didn't break anyone 😄
TaoChenOSU authored Nov 19, 2024
1 parent 1ce4769 commit fb5aa6f
Showing 21 changed files with 1,406 additions and 1,020 deletions.
2 changes: 2 additions & 0 deletions python/.cspell.json
@@ -39,6 +39,7 @@
"hnsw",
"httpx",
"huggingface",
"Instrumentor",
"kernelfunction",
"logit",
"logprobs",
@@ -61,6 +62,7 @@
"serde",
"skprompt",
"templating",
"uninstrument",
"vectordb",
"vectorizer",
"vectorstoremodel",
3 changes: 2 additions & 1 deletion python/pyproject.toml
@@ -50,7 +50,8 @@ dependencies = [
### Optional dependencies
[project.optional-dependencies]
azure = [
"azure-ai-inference >= 1.0.0b4",
"azure-ai-inference >= 1.0.0b6",
"azure-core-tracing-opentelemetry >= 1.0.0b11",
"azure-search-documents >= 11.6.0b4",
"azure-identity ~= 1.13",
"azure-cosmos ~= 4.7"
2 changes: 1 addition & 1 deletion python/samples/demos/telemetry/main.py
@@ -140,7 +140,7 @@ async def main(scenario: Literal["ai_service", "kernel_function", "auto_function
with tracer.start_as_current_span("main") as current_span:
print(f"Trace ID: {format_trace_id(current_span.get_span_context().trace_id)}")

stream = True
stream = False

# Scenarios where telemetry is collected in the SDK, from the most basic to the most complex.
if scenario == "ai_service" or scenario == "all":
6 changes: 6 additions & 0 deletions python/samples/demos/telemetry/scenarios.py
@@ -22,6 +22,12 @@ def set_up_kernel() -> Kernel:
# All built-in AI services are instrumented with telemetry.
# Select any AI service to see the telemetry in action.
kernel.add_service(OpenAIChatCompletion(service_id="open_ai"))
# kernel.add_service(
# AzureAIInferenceChatCompletion(
# ai_model_id="serverless-deployment",
# service_id="azure-ai-inference",
# )
# )
# kernel.add_service(GoogleAIChatCompletion(service_id="google_ai"))

if (sample_plugin_path := get_sample_plugin_path()) is None:
@@ -34,4 +34,4 @@ class AzureAIInferenceSettings(KernelBaseSettings):
env_prefix: ClassVar[str] = "AZURE_AI_INFERENCE_"

endpoint: HttpsUrl
api_key: SecretStr
api_key: SecretStr | None = None
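
With `api_key` now optional, settings validation no longer fails when only an endpoint is configured, which lets the connector fall back to Entra ID authentication (see the base-class change below). A minimal usage sketch under that assumption; the deployment name is a placeholder and `AZURE_AI_INFERENCE_ENDPOINT` is assumed to be set:

```python
# Hedged sketch: constructing the connector without an API key so it falls
# back to DefaultAzureCredential-based auth. Assumes AZURE_AI_INFERENCE_ENDPOINT
# is set in the environment and the identity has access to the deployment.
import asyncio

from semantic_kernel.connectors.ai.azure_ai_inference import (
    AzureAIInferenceChatCompletion,
    AzureAIInferenceChatPromptExecutionSettings,
)
from semantic_kernel.contents.chat_history import ChatHistory


async def main() -> None:
    # "serverless-deployment" is a placeholder model/deployment name.
    service = AzureAIInferenceChatCompletion(ai_model_id="serverless-deployment")

    history = ChatHistory()
    history.add_user_message("Hello!")

    responses = await service.get_chat_message_contents(
        chat_history=history,
        settings=AzureAIInferenceChatPromptExecutionSettings(),
    )
    print(responses[0].content)


asyncio.run(main())
```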
@@ -3,23 +3,111 @@
import asyncio
import contextlib
from abc import ABC
from typing import ClassVar
from enum import Enum
from typing import Any

from azure.ai.inference.aio import ChatCompletionsClient, EmbeddingsClient
from azure.core.credentials import AzureKeyCredential
from pydantic import ValidationError

from semantic_kernel.connectors.ai.azure_ai_inference.azure_ai_inference_settings import AzureAIInferenceSettings
from semantic_kernel.exceptions.service_exceptions import ServiceInitializationError
from semantic_kernel.kernel_pydantic import KernelBaseModel
from semantic_kernel.utils.authentication.async_default_azure_credential_wrapper import (
AsyncDefaultAzureCredentialWrapper,
)
from semantic_kernel.utils.experimental_decorator import experimental_class
from semantic_kernel.utils.telemetry.user_agent import SEMANTIC_KERNEL_USER_AGENT


class AzureAIInferenceClientType(Enum):
"""Client type for Azure AI Inference."""

ChatCompletions = "ChatCompletions"
Embeddings = "Embeddings"

@classmethod
def get_client_class(cls, client_type: "AzureAIInferenceClientType") -> Any:
"""Get the client class based on the client type."""
class_mapping = {
cls.ChatCompletions: ChatCompletionsClient,
cls.Embeddings: EmbeddingsClient,
}

return class_mapping[client_type]


@experimental_class
class AzureAIInferenceBase(KernelBaseModel, ABC):
"""Azure AI Inference Chat Completion Service."""

MODEL_PROVIDER_NAME: ClassVar[str] = "azureai"

client: ChatCompletionsClient | EmbeddingsClient
managed_client: bool = False

def __init__(
self,
client_type: AzureAIInferenceClientType,
api_key: str | None = None,
endpoint: str | None = None,
env_file_path: str | None = None,
env_file_encoding: str | None = None,
client: ChatCompletionsClient | EmbeddingsClient | None = None,
**kwargs: Any,
) -> None:
"""Initialize the Azure AI Inference Chat Completion service.
If no arguments are provided, the service will attempt to load the settings from the environment.
The following environment variables are used:
- AZURE_AI_INFERENCE_API_KEY
- AZURE_AI_INFERENCE_ENDPOINT
Args:
client_type (AzureAIInferenceClientType): The client type to use.
api_key (str | None): The API key for the Azure AI Inference service deployment. (Optional)
endpoint (str | None): The endpoint of the Azure AI Inference service deployment. (Optional)
env_file_path (str | None): The path to the environment file. (Optional)
env_file_encoding (str | None): The encoding of the environment file. (Optional)
client (ChatCompletionsClient | None): The Azure AI Inference client to use. (Optional)
**kwargs: Additional keyword arguments.
Raises:
ServiceInitializationError: If an error occurs during initialization.
"""
managed_client = client is None
if not client:
try:
azure_ai_inference_settings = AzureAIInferenceSettings.create(
api_key=api_key,
endpoint=endpoint,
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
)
except ValidationError as e:
raise ServiceInitializationError(f"Failed to validate Azure AI Inference settings: {e}") from e

endpoint = str(azure_ai_inference_settings.endpoint)
if azure_ai_inference_settings.api_key is not None:
client = AzureAIInferenceClientType.get_client_class(client_type)(
endpoint=endpoint,
credential=AzureKeyCredential(azure_ai_inference_settings.api_key.get_secret_value()),
user_agent=SEMANTIC_KERNEL_USER_AGENT,
)
else:
# Try to create the client with a DefaultAzureCredential
client = AzureAIInferenceClientType.get_client_class(client_type)(
endpoint=endpoint,
credential=AsyncDefaultAzureCredentialWrapper(),
user_agent=SEMANTIC_KERNEL_USER_AGENT,
)

super().__init__(
client=client,
managed_client=managed_client,
**kwargs,
)

def __del__(self) -> None:
"""Close the client when the object is deleted."""
with contextlib.suppress(Exception):
asyncio.get_running_loop().create_task(self.client.close())
if self.managed_client:
with contextlib.suppress(Exception):
asyncio.get_running_loop().create_task(self.client.close())
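
A note on the new `managed_client` flag: it records whether the service created the client itself, so `__del__` now only closes clients the connector owns. A sketch of the resulting ownership semantics (endpoint, key, and deployment name are hypothetical):

```python
# Sketch of the client-ownership semantics introduced above: a caller-supplied
# client is never closed by the service's __del__; the caller keeps ownership.
from azure.ai.inference.aio import ChatCompletionsClient
from azure.core.credentials import AzureKeyCredential

from semantic_kernel.connectors.ai.azure_ai_inference import AzureAIInferenceChatCompletion

my_client = ChatCompletionsClient(
    endpoint="https://example.inference.ai.azure.com",  # hypothetical endpoint
    credential=AzureKeyCredential("<api-key>"),  # placeholder key
)

service = AzureAIInferenceChatCompletion(
    ai_model_id="serverless-deployment",  # hypothetical deployment name
    client=my_client,
)

# The service did not create my_client, so it will not close it on deletion.
assert service.managed_client is False
```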
@@ -20,22 +20,19 @@
StreamingChatChoiceUpdate,
StreamingChatCompletionsUpdate,
)
from azure.core.credentials import AzureKeyCredential
from azure.identity import DefaultAzureCredential
from pydantic import ValidationError

from semantic_kernel.connectors.ai.azure_ai_inference import (
AzureAIInferenceChatPromptExecutionSettings,
AzureAIInferenceSettings,
from semantic_kernel.connectors.ai.azure_ai_inference import AzureAIInferenceChatPromptExecutionSettings
from semantic_kernel.connectors.ai.azure_ai_inference.services.azure_ai_inference_base import (
AzureAIInferenceBase,
AzureAIInferenceClientType,
)
from semantic_kernel.connectors.ai.azure_ai_inference.services.azure_ai_inference_base import AzureAIInferenceBase
from semantic_kernel.connectors.ai.azure_ai_inference.services.azure_ai_inference_tracing import AzureAIInferenceTracing
from semantic_kernel.connectors.ai.azure_ai_inference.services.utils import MESSAGE_CONVERTERS
from semantic_kernel.connectors.ai.chat_completion_client_base import ChatCompletionClientBase
from semantic_kernel.connectors.ai.completion_usage import CompletionUsage
from semantic_kernel.connectors.ai.function_call_choice_configuration import FunctionCallChoiceConfiguration
from semantic_kernel.connectors.ai.function_calling_utils import update_settings_from_function_call_configuration
from semantic_kernel.connectors.ai.function_choice_behavior import FunctionChoiceType
from semantic_kernel.connectors.ai.open_ai.const import DEFAULT_AZURE_API_VERSION
from semantic_kernel.contents.chat_history import ChatHistory
from semantic_kernel.contents.chat_message_content import ITEM_TYPES, ChatMessageContent
from semantic_kernel.contents.function_call_content import FunctionCallContent
@@ -45,16 +42,8 @@
from semantic_kernel.contents.text_content import TextContent
from semantic_kernel.contents.utils.author_role import AuthorRole
from semantic_kernel.contents.utils.finish_reason import FinishReason
from semantic_kernel.exceptions.service_exceptions import (
ServiceInitializationError,
ServiceInvalidExecutionSettingsError,
)
from semantic_kernel.exceptions.service_exceptions import ServiceInvalidExecutionSettingsError
from semantic_kernel.utils.experimental_decorator import experimental_class
from semantic_kernel.utils.telemetry.model_diagnostics.decorators import (
trace_chat_completion,
trace_streaming_chat_completion,
)
from semantic_kernel.utils.telemetry.user_agent import SEMANTIC_KERNEL_USER_AGENT

if TYPE_CHECKING:
from semantic_kernel.connectors.ai.prompt_execution_settings import PromptExecutionSettings
@@ -97,39 +86,14 @@ def __init__(
Raises:
ServiceInitializationError: If an error occurs during initialization.
"""
if not client:
try:
azure_ai_inference_settings = AzureAIInferenceSettings.create(
api_key=api_key,
endpoint=endpoint,
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
)
except ValidationError as e:
raise ServiceInitializationError(f"Failed to validate Azure AI Inference settings: {e}") from e

endpoint_to_use: str = str(azure_ai_inference_settings.endpoint)
if azure_ai_inference_settings.api_key is not None:
client = ChatCompletionsClient(
endpoint=endpoint_to_use,
credential=AzureKeyCredential(azure_ai_inference_settings.api_key.get_secret_value()),
user_agent=SEMANTIC_KERNEL_USER_AGENT,
)
else:
# Try to create the client with a DefaultAzureCredential
client = (
ChatCompletionsClient(
endpoint=endpoint_to_use,
credential=DefaultAzureCredential(),
credential_scopes=["https://cognitiveservices.azure.com/.default"],
api_version=DEFAULT_AZURE_API_VERSION,
user_agent=SEMANTIC_KERNEL_USER_AGENT,
),
)

super().__init__(
ai_model_id=ai_model_id,
service_id=service_id or ai_model_id,
client_type=AzureAIInferenceClientType.ChatCompletions,
api_key=api_key,
endpoint=endpoint,
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
client=client,
)

@@ -149,7 +113,6 @@ def service_url(self) -> str | None:
return None

@override
@trace_chat_completion(AzureAIInferenceBase.MODEL_PROVIDER_NAME)
async def _inner_get_chat_message_contents(
self,
chat_history: "ChatHistory",
@@ -160,17 +123,17 @@ async def _inner_get_chat_message_contents(
assert isinstance(settings, AzureAIInferenceChatPromptExecutionSettings) # nosec

assert isinstance(self.client, ChatCompletionsClient) # nosec
response: ChatCompletions = await self.client.complete(
messages=self._prepare_chat_history_for_request(chat_history),
model_extras=settings.extra_parameters,
**settings.prepare_settings_dict(),
)
with AzureAIInferenceTracing():
response: ChatCompletions = await self.client.complete(
messages=self._prepare_chat_history_for_request(chat_history),
model_extras=settings.extra_parameters,
**settings.prepare_settings_dict(),
)
response_metadata = self._get_metadata_from_response(response)

return [self._create_chat_message_content(response, choice, response_metadata) for choice in response.choices]

@override
@trace_streaming_chat_completion(AzureAIInferenceBase.MODEL_PROVIDER_NAME)
async def _inner_get_streaming_chat_message_contents(
self,
chat_history: "ChatHistory",
@@ -181,12 +144,13 @@ async def _inner_get_streaming_chat_message_contents(
assert isinstance(settings, AzureAIInferenceChatPromptExecutionSettings) # nosec

assert isinstance(self.client, ChatCompletionsClient) # nosec
response: AsyncStreamingChatCompletions = await self.client.complete(
stream=True,
messages=self._prepare_chat_history_for_request(chat_history),
model_extras=settings.extra_parameters,
**settings.prepare_settings_dict(),
)
with AzureAIInferenceTracing():
response: AsyncStreamingChatCompletions = await self.client.complete(
stream=True,
messages=self._prepare_chat_history_for_request(chat_history),
model_extras=settings.extra_parameters,
**settings.prepare_settings_dict(),
)

async for chunk in response:
if len(chunk.choices) == 0:
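
The `trace_chat_completion` / `trace_streaming_chat_completion` decorators are replaced by the `with AzureAIInferenceTracing():` blocks above. That class's implementation is not part of this excerpt; what follows is only a hedged sketch of a context manager along those lines, assuming it toggles the SDK's `AIInferenceInstrumentor`:

```python
# Speculative sketch only: AzureAIInferenceTracing's real implementation is not
# shown in this diff. This models a context manager that instruments the Azure
# AI Inference SDK on entry and uninstruments it on exit.
from azure.ai.inference.tracing import AIInferenceInstrumentor


class TracingSketch:
    """Enable Azure AI Inference SDK tracing for the duration of a block."""

    def __init__(self) -> None:
        self._instrumentor = AIInferenceInstrumentor()

    def __enter__(self) -> "TracingSketch":
        # Guard against double instrumentation.
        if not self._instrumentor.is_instrumented():
            self._instrumentor.instrument()
        return self

    def __exit__(self, *exc_info: object) -> None:
        if self._instrumentor.is_instrumented():
            self._instrumentor.uninstrument()
```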
@@ -10,21 +10,17 @@

from azure.ai.inference.aio import EmbeddingsClient
from azure.ai.inference.models import EmbeddingsResult
from azure.core.credentials import AzureKeyCredential
from azure.identity import DefaultAzureCredential
from numpy import array, ndarray
from pydantic import ValidationError

from semantic_kernel.connectors.ai.azure_ai_inference.azure_ai_inference_prompt_execution_settings import (
AzureAIInferenceEmbeddingPromptExecutionSettings,
)
from semantic_kernel.connectors.ai.azure_ai_inference.azure_ai_inference_settings import AzureAIInferenceSettings
from semantic_kernel.connectors.ai.azure_ai_inference.services.azure_ai_inference_base import AzureAIInferenceBase
from semantic_kernel.connectors.ai.azure_ai_inference.services.azure_ai_inference_base import (
AzureAIInferenceBase,
AzureAIInferenceClientType,
)
from semantic_kernel.connectors.ai.embeddings.embedding_generator_base import EmbeddingGeneratorBase
from semantic_kernel.connectors.ai.open_ai.const import DEFAULT_AZURE_API_VERSION
from semantic_kernel.exceptions.service_exceptions import ServiceInitializationError
from semantic_kernel.utils.experimental_decorator import experimental_class
from semantic_kernel.utils.telemetry.user_agent import SEMANTIC_KERNEL_USER_AGENT

if TYPE_CHECKING:
from semantic_kernel.connectors.ai.prompt_execution_settings import PromptExecutionSettings
@@ -63,37 +59,14 @@ def __init__(
Raises:
ServiceInitializationError: If an error occurs during initialization.
"""
if not client:
try:
azure_ai_inference_settings = AzureAIInferenceSettings.create(
api_key=api_key,
endpoint=endpoint,
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
)
except ValidationError as e:
raise ServiceInitializationError(f"Failed to validate Azure AI Inference settings: {e}") from e

endpoint = str(azure_ai_inference_settings.endpoint)
if azure_ai_inference_settings.api_key is not None:
client = EmbeddingsClient(
endpoint=endpoint,
credential=AzureKeyCredential(azure_ai_inference_settings.api_key.get_secret_value()),
user_agent=SEMANTIC_KERNEL_USER_AGENT,
)
else:
# Try to create the client with a DefaultAzureCredential
client = EmbeddingsClient(
endpoint=endpoint,
credential=DefaultAzureCredential(),
credential_scopes=["https://cognitiveservices.azure.com/.default"],
api_version=DEFAULT_AZURE_API_VERSION,
user_agent=SEMANTIC_KERNEL_USER_AGENT,
)

super().__init__(
ai_model_id=ai_model_id,
service_id=service_id or ai_model_id,
client_type=AzureAIInferenceClientType.Embeddings,
api_key=api_key,
endpoint=endpoint,
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
client=client,
)

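
The embedding service now delegates client construction to the shared base in the same way. A short usage sketch (deployment name hypothetical; `AZURE_AI_INFERENCE_*` environment variables assumed to be configured):

```python
# Usage sketch for the refactored embedding service. Relies on the
# AZURE_AI_INFERENCE_API_KEY / AZURE_AI_INFERENCE_ENDPOINT environment variables.
import asyncio

from semantic_kernel.connectors.ai.azure_ai_inference import AzureAIInferenceTextEmbedding


async def main() -> None:
    embedder = AzureAIInferenceTextEmbedding(ai_model_id="embedding-deployment")
    vectors = await embedder.generate_embeddings(["hello world"])
    print(vectors.shape)  # numpy ndarray: (number of inputs, embedding dimension)


asyncio.run(main())
```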
