-
Notifications
You must be signed in to change notification settings - Fork 34
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add an SK embeddings generator (#10)
- Loading branch information
1 parent
348412d
commit 02dd2d1
Showing
8 changed files
with
225 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
50 changes: 50 additions & 0 deletions
50
src/SmartComponents.LocalEmbeddings.SemanticKernel/LocalTextEmbeddingGenerationService.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
// Licensed to the .NET Foundation under one or more agreements. | ||
// The .NET Foundation licenses this file to you under the MIT license. | ||
|
||
using Microsoft.SemanticKernel; | ||
using Microsoft.SemanticKernel.Embeddings; | ||
|
||
namespace SmartComponents.LocalEmbeddings.SemanticKernel; | ||
|
||
/// <summary>
/// A text embedding service that computes embeddings locally using <see cref="LocalEmbedder"/>.
/// </summary>
/// <remarks>
/// Each instance creates its own <see cref="LocalEmbedder"/> in the constructor and releases it in
/// <see cref="Dispose"/>.
/// </remarks>
public class LocalTextEmbeddingGenerationService : ITextEmbeddingGenerationService, IDisposable
{
    // Attributes is unused, so one shared empty dictionary is returned by every instance.
    private static readonly IReadOnlyDictionary<string, object?> _emptyDict = new Dictionary<string, object?>();

    private readonly LocalEmbedder _embedder;
    private readonly int _maximumTokens;

    /// <summary>
    /// Constructs an instance of <see cref="LocalTextEmbeddingGenerationService"/>.
    /// </summary>
    /// <param name="modelName">The name of the model to load, or <c>null</c> to load the model named <c>"default"</c>. See documentation for <see cref="LocalEmbedder"/>.</param>
    /// <param name="caseSensitive">True if text should be handled case sensitively, otherwise false.</param>
    /// <param name="maximumTokens">The maximum number of tokens to include in the generated embeddings. This limits the amount of processing by truncating longer strings when the limit is reached.</param>
    public LocalTextEmbeddingGenerationService(string? modelName = default, bool caseSensitive = false, int maximumTokens = 512)
    {
        _embedder = new(modelName ?? "default", caseSensitive);
        _maximumTokens = maximumTokens;
    }

    /// <inheritdoc />
    public IReadOnlyDictionary<string, object?> Attributes => _emptyDict;

    /// <inheritdoc />
    /// <remarks>
    /// The embeddings are computed synchronously on the calling thread, one input at a time, and the
    /// returned task is already completed. <paramref name="kernel"/> is not used.
    /// </remarks>
    /// <exception cref="ArgumentNullException"><paramref name="data"/> is <c>null</c>.</exception>
    public Task<IList<ReadOnlyMemory<float>>> GenerateEmbeddingsAsync(IList<string> data, Kernel? kernel = null, CancellationToken cancellationToken = default)
    {
        ArgumentNullException.ThrowIfNull(data);

        var results = new ReadOnlyMemory<float>[data.Count];
        for (var i = 0; i < data.Count; i++)
        {
            // Embed runs synchronously, so cancellation can only be observed between items.
            // Report it through the returned task rather than by throwing from a Task-returning method.
            if (cancellationToken.IsCancellationRequested)
            {
                return Task.FromCanceled<IList<ReadOnlyMemory<float>>>(cancellationToken);
            }

            results[i] = _embedder.Embed(data[i], _maximumTokens).Values;
        }

        return Task.FromResult<IList<ReadOnlyMemory<float>>>(results);
    }

    /// <summary>
    /// Disposes the underlying <see cref="LocalEmbedder"/>.
    /// </summary>
    public void Dispose()
        => _embedder.Dispose();
}
30 changes: 30 additions & 0 deletions
30
...alEmbeddings.SemanticKernel/LocalTextEmbeddingKernelBuilderServiceCollectionExtensions.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
// Licensed to the .NET Foundation under one or more agreements. | ||
// The .NET Foundation licenses this file to you under the MIT license. | ||
|
||
using Microsoft.Extensions.DependencyInjection; | ||
using Microsoft.SemanticKernel.Embeddings; | ||
using SmartComponents.LocalEmbeddings.SemanticKernel; | ||
|
||
namespace Microsoft.SemanticKernel; | ||
|
||
/// <summary>
/// Extension methods that register <see cref="LocalTextEmbeddingGenerationService"/> on an <see cref="IKernelBuilder"/>.
/// </summary>
public static class LocalTextEmbeddingKernelBuilderServiceCollectionExtensions
{
    /// <summary>
    /// Adds a local text embedding generation service.
    /// </summary>
    /// <remarks>
    /// The service is registered as a singleton factory, so it is constructed the first time
    /// <see cref="ITextEmbeddingGenerationService"/> is resolved rather than when this method is called.
    /// </remarks>
    /// <param name="builder">The <see cref="IKernelBuilder"/> instance to augment.</param>
    /// <param name="modelName">The name of the model to load. See documentation for <see cref="LocalEmbedder"/>.</param>
    /// <param name="caseSensitive">True if text should be handled case sensitively, otherwise false.</param>
    /// <param name="maximumTokens">The maximum number of tokens to include in the generated embeddings. This limits the amount of processing by truncating longer strings when the limit is reached.</param>
    /// <returns>The <paramref name="builder"/>.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="builder"/> is <c>null</c>.</exception>
    public static IKernelBuilder AddLocalTextEmbeddingGeneration(
        this IKernelBuilder builder,
        string? modelName = default,
        bool caseSensitive = false,
        int maximumTokens = 512)
    {
        ArgumentNullException.ThrowIfNull(builder);

        // Register a factory instead of a pre-built instance: the container only disposes
        // services it created itself, and this service is IDisposable.
        builder.Services.AddSingleton<ITextEmbeddingGenerationService>(
            _ => new LocalTextEmbeddingGenerationService(modelName, caseSensitive, maximumTokens));
        return builder;
    }
}
18 changes: 18 additions & 0 deletions
18
...ents.LocalEmbeddings.SemanticKernel/SmartComponents.LocalEmbeddings.SemanticKernel.csproj
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
<Project Sdk="Microsoft.NET.Sdk"> | ||
|
||
<PropertyGroup> | ||
<TargetFramework>net6.0</TargetFramework> | ||
<ImplicitUsings>enable</ImplicitUsings> | ||
<Nullable>enable</Nullable> | ||
<NoWarn>$(NoWarn);SKEXP0001</NoWarn> | ||
</PropertyGroup> | ||
|
||
<ItemGroup> | ||
<PackageReference Include="Microsoft.SemanticKernel.Abstractions" /> | ||
</ItemGroup> | ||
|
||
<ItemGroup> | ||
<ProjectReference Include="..\SmartComponents.LocalEmbeddings\SmartComponents.LocalEmbeddings.csproj" /> | ||
</ItemGroup> | ||
|
||
</Project> |
80 changes: 80 additions & 0 deletions
80
test/SmartComponents.LocalEmbeddings.SemanticKernel.Test/EmbeddingsTest.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
// Licensed to the .NET Foundation under one or more agreements. | ||
// The .NET Foundation licenses this file to you under the MIT license. | ||
|
||
using System.Numerics.Tensors; | ||
using Microsoft.Extensions.DependencyInjection; | ||
using Microsoft.SemanticKernel; | ||
using Microsoft.SemanticKernel.Embeddings; | ||
|
||
namespace SmartComponents.LocalEmbeddings.SemanticKernel.Test; | ||
|
||
/// <summary>
/// Tests the local embedding generator when it is resolved from a kernel built with
/// <c>AddLocalTextEmbeddingGeneration</c>.
/// </summary>
public class EmbeddingsTest
{
    /// <summary>
    /// Ranks a fixed set of sentences by cosine similarity to "cat" and checks the resulting order.
    /// </summary>
    [Fact]
    public async Task CanComputeEmbeddings()
    {
        var generator = ResolveEmbeddingGenerator();

        var reference = await generator.GenerateEmbeddingAsync("cat");
        string[] candidates = [
            "dog",
            "kitten!",
            "Cats are good",
            "Cats are bad",
            "Tiger",
            "Wolf",
            "Grimsby Town FC",
            "Elephants are here",
        ];
        var candidateEmbeddings = await generator.GenerateEmbeddingsAsync(candidates);

        // Pair every sentence with its embedding, then sort most-similar-first.
        var paired = candidates
            .Zip(candidateEmbeddings, (text, vector) => (Sentence: text, Embedding: vector))
            .ToArray();
        var ranked = paired
            .OrderByDescending(pair => TensorPrimitives.CosineSimilarity(reference.Span, pair.Embedding.Span))
            .Select(pair => pair.Sentence)
            .ToArray();

        string[] expectedOrder = [
            "Cats are good",
            "kitten!",
            "Cats are bad",
            "Tiger",
            "dog",
            "Wolf",
            "Elephants are here",
            "Grimsby Town FC",
        ];
        Assert.Equal(expectedOrder, ranked);
    }

    /// <summary>
    /// With the default registration, "cat" and "CAT" have a similarity that rounds to 1.
    /// </summary>
    [Fact]
    public async Task IsCaseInsensitiveByDefault()
    {
        var generator = ResolveEmbeddingGenerator();

        var rounded = await RoundedCatSimilarityAsync(generator);

        Assert.Equal(1, rounded);
    }

    /// <summary>
    /// With <c>caseSensitive: true</c>, "cat" and "CAT" have a similarity that does not round to 1.
    /// </summary>
    [Fact]
    public async Task CanBeConfiguredAsCaseSensitive()
    {
        var generator = ResolveEmbeddingGenerator(caseSensitive: true);

        var rounded = await RoundedCatSimilarityAsync(generator);

        Assert.NotEqual(1, rounded);
    }

    // Builds a kernel with the local embedding service registered and resolves that service from it.
    private static ITextEmbeddingGenerationService ResolveEmbeddingGenerator(bool caseSensitive = false)
    {
        var kernelBuilder = Kernel.CreateBuilder();
        kernelBuilder.AddLocalTextEmbeddingGeneration(caseSensitive: caseSensitive);
        var builtKernel = kernelBuilder.Build();

        return builtKernel.Services.GetRequiredService<ITextEmbeddingGenerationService>();
    }

    // Embeds "cat" and "CAT" and returns their cosine similarity rounded to three decimal places.
    private static async Task<float> RoundedCatSimilarityAsync(ITextEmbeddingGenerationService generator)
    {
        var lower = await generator.GenerateEmbeddingAsync("cat");
        var upper = await generator.GenerateEmbeddingAsync("CAT");
        var cosine = TensorPrimitives.CosineSimilarity(lower.Span, upper.Span);

        return MathF.Round(cosine, 3);
    }
}
1 change: 1 addition & 0 deletions
1
test/SmartComponents.LocalEmbeddings.SemanticKernel.Test/GlobalUsings.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
global using Xunit; |
30 changes: 30 additions & 0 deletions
30
...Embeddings.SemanticKernel.Test/SmartComponents.LocalEmbeddings.SemanticKernel.Test.csproj
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
<Project Sdk="Microsoft.NET.Sdk"> | ||
|
||
<PropertyGroup> | ||
<TargetFramework>net8.0</TargetFramework> | ||
<ImplicitUsings>enable</ImplicitUsings> | ||
<Nullable>enable</Nullable> | ||
|
||
<IsPackable>false</IsPackable> | ||
<IsTestProject>true</IsTestProject> | ||
<NoWarn>$(NoWarn);SKEXP0001</NoWarn> | ||
</PropertyGroup> | ||
|
||
<ItemGroup> | ||
<PackageReference Include="Microsoft.NET.Test.Sdk" /> | ||
<PackageReference Include="Microsoft.SemanticKernel.Core" /> | ||
<PackageReference Include="xunit" /> | ||
<PackageReference Include="xunit.runner.visualstudio"> | ||
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets> | ||
<PrivateAssets>all</PrivateAssets> | ||
</PackageReference> | ||
</ItemGroup> | ||
|
||
<ItemGroup> | ||
<ProjectReference Include="..\..\src\SmartComponents.LocalEmbeddings.SemanticKernel\SmartComponents.LocalEmbeddings.SemanticKernel.csproj" /> | ||
</ItemGroup> | ||
|
||
<!-- Only needed when referencing the dependencies as projects. For package references, these are imported automatically. --> | ||
<Import Project="$(RepoRoot)src\SmartComponents.LocalEmbeddings\build\SmartComponents.LocalEmbeddings.targets" /> | ||
|
||
</Project> |