Skip to content

Commit

Permalink
Enable bidirectional adapters between SK and Microsoft.Extensions.AI …
Browse files Browse the repository at this point in the history
…interfaces
  • Loading branch information
stephentoub committed Oct 26, 2024
1 parent ea5ceb1 commit 11fbef3
Show file tree
Hide file tree
Showing 8 changed files with 910 additions and 3 deletions.
1 change: 1 addition & 0 deletions dotnet/Directory.Packages.props
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
<PackageVersion Include="Microsoft.DeepDev.TokenizerLib" Version="1.3.3" />
<PackageVersion Include="SharpToken" Version="2.0.3" />
<!-- Microsoft.Extensions.* -->
<PackageVersion Include="Microsoft.Extensions.AI.Abstractions" Version="9.0.0-preview.9.24525.1" />
<PackageVersion Include="Microsoft.Extensions.Configuration" Version="8.0.0" />
<PackageVersion Include="Microsoft.Extensions.Configuration.Binder" Version="8.0.2" />
<PackageVersion Include="Microsoft.Extensions.Configuration.EnvironmentVariables" Version="8.0.0" />
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.AI;
using Microsoft.SemanticKernel.Services;

namespace Microsoft.SemanticKernel.Embeddings;

Expand All @@ -24,7 +28,6 @@ public static class EmbeddingGenerationExtensions
/// <param name="kernel">The <see cref="Kernel"/> containing services, plugins, and other state for use throughout the operation.</param>
/// <param name="cancellationToken">Cancellation token</param>
/// <returns>A list of embedding structs representing the input <paramref name="value"/>.</returns>
[Experimental("SKEXP0001")]
public static async Task<ReadOnlyMemory<TEmbedding>> GenerateEmbeddingAsync<TValue, TEmbedding>(
this IEmbeddingGenerationService<TValue, TEmbedding> generator,
TValue value,
Expand All @@ -35,4 +38,125 @@ public static async Task<ReadOnlyMemory<TEmbedding>> GenerateEmbeddingAsync<TVal
Verify.NotNull(generator);
return (await generator.GenerateEmbeddingsAsync([value], kernel, cancellationToken).ConfigureAwait(false)).FirstOrDefault();
}

/// <summary>Creates an <see cref="IEmbeddingGenerator{TInput, TEmbedding}"/> for the specified <see cref="IEmbeddingGenerationService{TValue, TEmbedding}"/>.</summary>
/// <param name="service">The embedding generation service to be represented as an embedding generator.</param>
/// <returns>
/// The <see cref="IEmbeddingGenerator{TInput, TEmbedding}"/>. If the <paramref name="service"/> is an <see cref="IEmbeddingGenerator{TInput, TEmbedding}"/>,
/// the <paramref name="service"/> will be returned. Otherwise, a new <see cref="IEmbeddingGenerator{TInput, TEmbedding}"/> will be created that wraps the <paramref name="service"/>.
/// </returns>
public static IEmbeddingGenerator<TValue, Embedding<TEmbedding>> AsEmbeddingGenerator<TValue, TEmbedding>(
this IEmbeddingGenerationService<TValue, TEmbedding> service)
where TEmbedding : unmanaged
{
Verify.NotNull(service);

return service is IEmbeddingGenerator<TValue, Embedding<TEmbedding>> embeddingGenerator ?
embeddingGenerator :
new EmbeddingGenerationServiceEmbeddingGenerator<TValue, TEmbedding>(service);
}

/// <summary>Creates an <see cref="IEmbeddingGenerationService{TInput, TEmbedding}"/> for the specified <see cref="IEmbeddingGenerator{TValue, TEmbedding}"/>.</summary>
/// <param name="generator">The embedding generator to be represented as an embedding generation service.</param>
/// <param name="serviceProvider">An optional <see cref="IServiceProvider"/> that can be used to resolve services to use in the instance.</param>
/// <returns>
/// The <see cref="IEmbeddingGenerationService{TInput, TEmbedding}"/>. If the <paramref name="generator"/> is an <see cref="IEmbeddingGenerationService{TInput, TEmbedding}"/>,
/// the <paramref name="generator"/> will be returned. Otherwise, a new <see cref="IEmbeddingGenerationService{TValue, TEmbedding}"/> will be created that wraps the <paramref name="generator"/>.
/// </returns>
public static IEmbeddingGenerationService<TValue, TEmbedding> AsEmbeddingGenerationService<TValue, TEmbedding>(
this IEmbeddingGenerator<TValue, Embedding<TEmbedding>> generator,
IServiceProvider? serviceProvider = null)
where TEmbedding : unmanaged
{
Verify.NotNull(generator);

return generator is IEmbeddingGenerationService<TValue, TEmbedding> service ?
service :
new EmbeddingGeneratorEmbeddingGenerationService<TValue, TEmbedding>(generator, serviceProvider);
}

/// <summary>Provides an implementation of <see cref="IEmbeddingGenerator{TInput, TEmbedding}"/> around an <see cref="IEmbeddingGenerationService{TValue, TEmbedding}"/>.</summary>
private sealed class EmbeddingGenerationServiceEmbeddingGenerator<TValue, TEmbedding> : IEmbeddingGenerator<TValue, Embedding<TEmbedding>>
where TEmbedding : unmanaged
{
/// <summary>The wrapped <see cref="IEmbeddingGenerationService{TValue, TEmbedding}"/></summary>
private readonly IEmbeddingGenerationService<TValue, TEmbedding> _service;

/// <summary>Initializes the <see cref="EmbeddingGenerationServiceEmbeddingGenerator{TValue, TEmbedding}"/> for <paramref name="service"/>.</summary>
public EmbeddingGenerationServiceEmbeddingGenerator(IEmbeddingGenerationService<TValue, TEmbedding> service)
{
this._service = service;
this.Metadata = new EmbeddingGeneratorMetadata(
service.GetType().Name,
service.GetEndpoint() is string endpoint ? new Uri(endpoint) : null,
service.GetModelId());
}

/// <inheritdoc />
public EmbeddingGeneratorMetadata Metadata { get; }

/// <inheritdoc />
public void Dispose()
{
(this._service as IDisposable)?.Dispose();
}

/// <inheritdoc />
public async Task<GeneratedEmbeddings<Embedding<TEmbedding>>> GenerateAsync(IEnumerable<TValue> values, EmbeddingGenerationOptions? options = null, CancellationToken cancellationToken = default)
{
IList<ReadOnlyMemory<TEmbedding>> result = await this._service.GenerateEmbeddingsAsync(values.ToList(), kernel: null, cancellationToken).ConfigureAwait(false);
return new(result.Select(e => new Embedding<TEmbedding>(e)));
}

/// <inheritdoc />
public TService? GetService<TService>(object? key = null) where TService : class
{
return
typeof(TService) == typeof(IEmbeddingGenerator<TValue, Embedding<TEmbedding>>) ? (TService)(object)this :
this._service as TService;
}
}

/// <summary>Provides an implementation of <see cref="IEmbeddingGenerationService{TInput, TEmbedding}"/> around an <see cref="EmbeddingGeneratorEmbeddingGenerationService{TValue, TEmbedding}"/>.</summary>
private sealed class EmbeddingGeneratorEmbeddingGenerationService<TValue, TEmbedding> : IEmbeddingGenerationService<TValue, TEmbedding>
where TEmbedding : unmanaged
{
/// <summary>The wrapped <see cref="IEmbeddingGenerator{TValue, TEmbedding}"/></summary>
private readonly IEmbeddingGenerator<TValue, Embedding<TEmbedding>> _generator;

/// <summary>Initializes the <see cref="EmbeddingGeneratorEmbeddingGenerationService{TValue, TEmbedding}"/> for <paramref name="generator"/>.</summary>
public EmbeddingGeneratorEmbeddingGenerationService(
IEmbeddingGenerator<TValue, Embedding<TEmbedding>> generator, IServiceProvider? serviceProvider)
{
// Store the generator.
this._generator = generator;

// Initialize the attributes.
var attrs = new Dictionary<string, object?>();
this.Attributes = new ReadOnlyDictionary<string, object?>(attrs);

var metadata = generator.Metadata;
if (metadata.ProviderUri is not null)
{
attrs[AIServiceExtensions.EndpointKey] = metadata.ProviderUri.ToString();
}
if (metadata.ModelId is not null)
{
attrs[AIServiceExtensions.ModelIdKey] = metadata.ModelId;
}
}

/// <inheritdoc />
public IReadOnlyDictionary<string, object?> Attributes { get; }

/// <inheritdoc />
public async Task<IList<ReadOnlyMemory<TEmbedding>>> GenerateEmbeddingsAsync(IList<TValue> data, Kernel? kernel = null, CancellationToken cancellationToken = default)
{
Verify.NotNull(data);

var embeddings = await this._generator.GenerateAsync(data, cancellationToken: cancellationToken).ConfigureAwait(false);

return embeddings.Select(e => e.Vector).ToList();
}
}
}
62 changes: 62 additions & 0 deletions dotnet/src/SemanticKernel.Abstractions/AbstractionsJsonContext.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Text.Json;
using System.Text.Json.Serialization;
using System.Text.Json.Serialization.Metadata;

namespace Microsoft.SemanticKernel;

[JsonSourceGenerationOptions(JsonSerializerDefaults.Web,
UseStringEnumConverter = true,
DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull,
WriteIndented = true)]
[JsonSerializable(typeof(IDictionary<string, object?>))]
[JsonSerializable(typeof(JsonElement))]
[JsonSerializable(typeof(PromptExecutionSettings))]
internal sealed partial class AbstractionsJsonContext : JsonSerializerContext
{
/// <summary>Gets the <see cref="JsonSerializerOptions"/> singleton used as the default in JSON serialization operations.</summary>
private static readonly JsonSerializerOptions s_defaultToolJsonOptions = CreateDefaultToolJsonOptions();

/// <summary>Gets JSON type information for the specified type.</summary>
/// <remarks>
/// This first tries to get the type information from <paramref name="firstOptions"/>,
/// falling back to <see cref="s_defaultToolJsonOptions"/> if it can't.
/// </remarks>
public static JsonTypeInfo GetTypeInfo(Type type, JsonSerializerOptions? firstOptions)
{
return firstOptions?.TryGetTypeInfo(type, out JsonTypeInfo? info) is true ?
info :
s_defaultToolJsonOptions.GetTypeInfo(type);
}

/// <summary>Creates the default <see cref="JsonSerializerOptions"/> to use for serialization-related operations.</summary>
[UnconditionalSuppressMessage("AotAnalysis", "IL3050", Justification = "DefaultJsonTypeInfoResolver is only used when reflection-based serialization is enabled")]
[UnconditionalSuppressMessage("ReflectionAnalysis", "IL2026", Justification = "DefaultJsonTypeInfoResolver is only used when reflection-based serialization is enabled")]
private static JsonSerializerOptions CreateDefaultToolJsonOptions()
{
// If reflection-based serialization is enabled by default, use it, as it's the most permissive in terms of what it can serialize,
// and we want to be flexible in terms of what can be put into the various collections in the object model.
// Otherwise, use the source-generated options to enable trimming and Native AOT.

if (JsonSerializer.IsReflectionEnabledByDefault)
{
// Keep in sync with the JsonSourceGenerationOptions attribute on JsonContext above.
JsonSerializerOptions options = new(JsonSerializerDefaults.Web)
{
TypeInfoResolver = new DefaultJsonTypeInfoResolver(),
Converters = { new JsonStringEnumConverter() },
DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull,
WriteIndented = true,
};

options.MakeReadOnly();
return options;
}

return Default.Options;
}
}
69 changes: 69 additions & 0 deletions dotnet/src/SemanticKernel.Abstractions/Functions/KernelFunction.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.AI;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
using Microsoft.SemanticKernel.Diagnostics;
Expand Down Expand Up @@ -537,4 +538,72 @@ private void LogFunctionResult(ILogger logger, FunctionResult functionResult)
logger.LogFunctionResultValue(functionResult);
}
}

/// <summary>Creates an <see cref="AIFunction"/> for this <see cref="KernelFunction"/>.</summary>
/// <param name="kernel">
/// The <see cref="Kernel"/> instance to pass to the <see cref="KernelFunction"/> when it's invoked as part of the <see cref="AIFunction"/>'s invocation.
/// </param>
/// <returns>An instance of <see cref="AIFunction"/> that, when invoked, will in turn invoke the current <see cref="KernelFunction"/>.</returns>
[Experimental("SKEXP0001")]
public AIFunction AsAIFunction(Kernel? kernel = null)
{
return new KernelAIFunction(this, kernel);
}

/// <summary>An <see cref="AIFunction"/> wrapper around a <see cref="KernelFunction"/>.</summary>
private sealed class KernelAIFunction : AIFunction
{
private readonly KernelFunction _kernelFunction;
private readonly Kernel? _kernel;

public KernelAIFunction(KernelFunction kernelFunction, Kernel? kernel)
{
this._kernelFunction = kernelFunction;
this._kernel = kernel;

string name = string.IsNullOrWhiteSpace(kernelFunction.PluginName) ?
kernelFunction.Name :
$"{kernelFunction.PluginName}_{kernelFunction.Name}";

this.Metadata = new AIFunctionMetadata(name)
{
Description = kernelFunction.Description,

Parameters = kernelFunction.Metadata.Parameters.Select(p => new AIFunctionParameterMetadata(p.Name)
{
Description = p.Description,
ParameterType = p.ParameterType,
IsRequired = p.IsRequired,
HasDefaultValue = p.DefaultValue is not null,
DefaultValue = p.DefaultValue,
Schema = p.Schema?.RootElement,
}).ToList(),

ReturnParameter = new AIFunctionReturnParameterMetadata()
{
Description = kernelFunction.Metadata.ReturnParameter.Description,
ParameterType = kernelFunction.Metadata.ReturnParameter.ParameterType,
Schema = kernelFunction.Metadata.ReturnParameter.Schema?.RootElement,
},
};
}

public override AIFunctionMetadata Metadata { get; }

protected override async Task<object?> InvokeCoreAsync(IEnumerable<KeyValuePair<string, object?>> arguments, CancellationToken cancellationToken)
{
Verify.NotNull(arguments);

KernelArguments args = [];
foreach (var argument in arguments)
{
args[argument.Key] = argument.Value;
}

var functionResult = await this._kernelFunction.InvokeAsync(this._kernel ?? new(), args, cancellationToken).ConfigureAwait(false);
return functionResult.Value is object value ? JsonSerializer.SerializeToElement(
value,
AbstractionsJsonContext.GetTypeInfo(value.GetType(), this._kernelFunction.JsonSerializerOptions)) : null;
}
}
}
12 changes: 12 additions & 0 deletions dotnet/src/SemanticKernel.Abstractions/Functions/KernelPlugin.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using Microsoft.Extensions.AI;

#pragma warning disable CA1716 // Identifiers should not match keywords

Expand Down Expand Up @@ -92,6 +93,17 @@ public IList<KernelFunctionMetadata> GetFunctionsMetadata()
/// <inheritdoc/>
public abstract IEnumerator<KernelFunction> GetEnumerator();

/// <summary>Produces an <see cref="AIFunction"/> for every <see cref="KernelFunction"/> in this plugin.</summary>
/// <returns>An enumerable of <see cref="AIFunction"/> instances, one for each <see cref="KernelFunction"/> in this plugin.</returns>
[Experimental("SKEXP0001")]
public IEnumerable<AIFunction> AsAIFunctions()
{
foreach (KernelFunction function in this)
{
yield return function.AsAIFunction();
}
}

/// <inheritdoc/>
IEnumerator IEnumerable.GetEnumerator() => this.GetEnumerator();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
<AssemblyName>Microsoft.SemanticKernel.Abstractions</AssemblyName>
<RootNamespace>Microsoft.SemanticKernel</RootNamespace>
<TargetFrameworks>net8.0;netstandard2.0</TargetFrameworks>
<NoWarn>$(NoWarn);SKEXP0001;NU5104;SKEXP0120</NoWarn>
<NoWarn>$(NoWarn);NU5104;SKEXP0001;NU5104;SKEXP0120</NoWarn>
<EnablePackageValidation>true</EnablePackageValidation>
<IsAotCompatible Condition="$([MSBuild]::IsTargetFrameworkCompatible('$(TargetFramework)', 'net7.0'))">true</IsAotCompatible>
</PropertyGroup>
Expand All @@ -29,6 +29,7 @@
<ItemGroup>
<PackageReference Include="Microsoft.Bcl.AsyncInterfaces" />
<PackageReference Include="Microsoft.Bcl.HashCode" />
<PackageReference Include="Microsoft.Extensions.AI.Abstractions" />
<PackageReference Include="Microsoft.Extensions.DependencyInjection.Abstractions" />
<PackageReference Include="Microsoft.Extensions.Logging.Abstractions" />
<PackageReference Include="System.Diagnostics.DiagnosticSource" />
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel.Embeddings;
using Xunit;

namespace SemanticKernel.UnitTests.AI;

public class ServiceConversionExtensionsTests
{
[Fact]
public void InvalidArgumentsThrow()
{
Assert.Throws<ArgumentNullException>("service", () => ChatCompletionServiceExtensions.AsChatClient(null!));
Assert.Throws<ArgumentNullException>("client", () => ChatCompletionServiceExtensions.AsChatCompletionService(null!));

Assert.Throws<ArgumentNullException>("service", () => EmbeddingGenerationExtensions.AsEmbeddingGenerator<string, float>(null!));
Assert.Throws<ArgumentNullException>("generator", () => EmbeddingGenerationExtensions.AsEmbeddingGenerationService<string, float>(null!));
}
}

0 comments on commit 11fbef3

Please sign in to comment.