Skip to content

Commit

Permalink
Add an SK embeddings generator (#10)
Browse files Browse the repository at this point in the history
  • Loading branch information
SteveSandersonMS authored Mar 15, 2024
1 parent 348412d commit 02dd2d1
Show file tree
Hide file tree
Showing 8 changed files with 225 additions and 0 deletions.
2 changes: 2 additions & 0 deletions Directory.Packages.props
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
<PackageVersion Include="Microsoft.AspNetCore.Components.WebAssembly.Server" Version="8.0.0" />
<PackageVersion Include="Azure.AI.OpenAI" Version="1.0.0-beta.11" />
<PackageVersion Include="Microsoft.ML.OnnxRuntime" Version="1.17.0" />
<PackageVersion Include="Microsoft.SemanticKernel.Abstractions" Version="1.6.2" />
<PackageVersion Include="Microsoft.SemanticKernel.Core" Version="1.6.2" />
<PackageVersion Include="System.Numerics.Tensors" Version="8.0.0" />
<PackageVersion Include="System.Runtime.Caching" Version="8.0.0" />
<PackageVersion Include="System.Text.Json" Version="8.0.1" />
Expand Down
14 changes: 14 additions & 0 deletions SmartComponents.sln
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "TestBlazorServerNet6App", "
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "SmartComponents.E2ETest.BlazorNet6", "test\SmartComponents.E2ETest.BlazorNet6\SmartComponents.E2ETest.BlazorNet6.csproj", "{7A919A92-A121-420B-9E18-47A60DCDAA69}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "SmartComponents.LocalEmbeddings.SemanticKernel", "src\SmartComponents.LocalEmbeddings.SemanticKernel\SmartComponents.LocalEmbeddings.SemanticKernel.csproj", "{A31FD69E-2744-4800-AA7E-D734E8737715}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "SmartComponents.LocalEmbeddings.SemanticKernel.Test", "test\SmartComponents.LocalEmbeddings.SemanticKernel.Test\SmartComponents.LocalEmbeddings.SemanticKernel.Test.csproj", "{23031658-179A-4425-82F2-29290DE4F3B2}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
Expand Down Expand Up @@ -127,6 +131,14 @@ Global
{7A919A92-A121-420B-9E18-47A60DCDAA69}.Debug|Any CPU.Build.0 = Debug|Any CPU
{7A919A92-A121-420B-9E18-47A60DCDAA69}.Release|Any CPU.ActiveCfg = Release|Any CPU
{7A919A92-A121-420B-9E18-47A60DCDAA69}.Release|Any CPU.Build.0 = Release|Any CPU
{A31FD69E-2744-4800-AA7E-D734E8737715}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{A31FD69E-2744-4800-AA7E-D734E8737715}.Debug|Any CPU.Build.0 = Debug|Any CPU
{A31FD69E-2744-4800-AA7E-D734E8737715}.Release|Any CPU.ActiveCfg = Release|Any CPU
{A31FD69E-2744-4800-AA7E-D734E8737715}.Release|Any CPU.Build.0 = Release|Any CPU
{23031658-179A-4425-82F2-29290DE4F3B2}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{23031658-179A-4425-82F2-29290DE4F3B2}.Debug|Any CPU.Build.0 = Debug|Any CPU
{23031658-179A-4425-82F2-29290DE4F3B2}.Release|Any CPU.ActiveCfg = Release|Any CPU
{23031658-179A-4425-82F2-29290DE4F3B2}.Release|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
Expand All @@ -151,6 +163,8 @@ Global
{CB0537FA-53A2-4470-A5CA-423C7D09EFC5} = {04F66920-45C0-4410-89ED-F2B5E6223958}
{F8C57083-620C-4D8F-8366-60E06593F720} = {7A830C0D-7E18-4674-A729-726085D9C0D1}
{7A919A92-A121-420B-9E18-47A60DCDAA69} = {03710CDB-ACD6-4712-95C8-B780EEEFAA29}
{A31FD69E-2744-4800-AA7E-D734E8737715} = {B1370349-29FA-49A1-A229-A31F7516A1FF}
{23031658-179A-4425-82F2-29290DE4F3B2} = {03710CDB-ACD6-4712-95C8-B780EEEFAA29}
EndGlobalSection
GlobalSection(ExtensibilityGlobals) = postSolution
SolutionGuid = {9D22AAE3-5C4E-4636-8FE0-FD175D10A3CA}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Embeddings;

namespace SmartComponents.LocalEmbeddings.SemanticKernel;

/// <summary>
/// A text embedding service that computes embeddings locally using <see cref="LocalEmbedder"/>.
/// </summary>
public class LocalTextEmbeddingGenerationService : ITextEmbeddingGenerationService, IDisposable
{
private readonly LocalEmbedder _embedder;
private readonly int _maximumTokens;

/// <summary>
/// Constructs an instance of <see cref="LocalTextEmbeddingGenerationService"/>.
/// </summary>
/// <param name="modelName">The name of the model to load. See documentation for <see cref="LocalEmbedder"/>.</param>
/// <param name="caseSensitive">True if text should be handled case sensitively, otherwise false.</param>
/// <param name="maximumTokens">The maximum number of tokens to include in the generated embeddings. This limits the amount of processing by truncating longer strings when the limit is reached..</param>
public LocalTextEmbeddingGenerationService(string? modelName = default, bool caseSensitive = false, int maximumTokens = 512)
{
_embedder = new(modelName ?? "default", caseSensitive);
_maximumTokens = maximumTokens;
}

// Attributes is unused
private static readonly IReadOnlyDictionary<string, object?> _emptyDict = new Dictionary<string, object?>();

/// <inheritdoc />
public IReadOnlyDictionary<string, object?> Attributes => _emptyDict;

/// <inheritdoc />
public Task<IList<ReadOnlyMemory<float>>> GenerateEmbeddingsAsync(IList<string> data, Kernel? kernel = null, CancellationToken cancellationToken = default)
{
var results = new ReadOnlyMemory<float>[data.Count];
for (var i = 0; i < data.Count; i++)
{
results[i] = _embedder.Embed(data[i], _maximumTokens).Values;
}

return Task.FromResult((IList<ReadOnlyMemory<float>>)results);
}

/// <inheritdoc />
public void Dispose()
=> _embedder.Dispose();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using Microsoft.Extensions.DependencyInjection;
using Microsoft.SemanticKernel.Embeddings;
using SmartComponents.LocalEmbeddings.SemanticKernel;

namespace Microsoft.SemanticKernel;

public static class LocalTextEmbeddingKernelBuilderServiceCollectionExtensions
{
/// <summary>
/// Adds a local text embedding generation service.
/// </summary>
/// <param name="builder">The <see cref="IKernelBuilder"/> instance to augment.</param>
/// <param name="modelName">The name of the model to load. See documentation for <see cref="LocalEmbedder"/>.</param>
/// <param name="caseSensitive">True if text should be handled case sensitively, otherwise false.</param>
/// <param name="maximumTokens">The maximum number of tokens to include in the generated embeddings. This limits the amount of processing by truncating longer strings when the limit is reached.</param>
/// <returns>The <paramref name="builder"/>.</returns>
public static IKernelBuilder AddLocalTextEmbeddingGeneration(
this IKernelBuilder builder,
string? modelName = default,
bool caseSensitive = false,
int maximumTokens = 512)
{
var instance = new LocalTextEmbeddingGenerationService(modelName, caseSensitive, maximumTokens);
builder.Services.AddSingleton<ITextEmbeddingGenerationService>(instance);
return builder;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
<Project Sdk="Microsoft.NET.Sdk">

<PropertyGroup>
<TargetFramework>net6.0</TargetFramework>
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>
<NoWarn>$(NoWarn);SKEXP0001</NoWarn>
</PropertyGroup>

<ItemGroup>
<PackageReference Include="Microsoft.SemanticKernel.Abstractions" />
</ItemGroup>

<ItemGroup>
<ProjectReference Include="..\SmartComponents.LocalEmbeddings\SmartComponents.LocalEmbeddings.csproj" />
</ItemGroup>

</Project>
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Numerics.Tensors;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Embeddings;

namespace SmartComponents.LocalEmbeddings.SemanticKernel.Test;

public class EmbeddingsTest
{
[Fact]
public async Task CanComputeEmbeddings()
{
var builder = Kernel.CreateBuilder();
builder.AddLocalTextEmbeddingGeneration();
var kernel = builder.Build();

var embeddingGenerator = kernel.Services.GetRequiredService<ITextEmbeddingGenerationService>();

var cat = await embeddingGenerator.GenerateEmbeddingAsync("cat");
string[] sentences = [
"dog",
"kitten!",
"Cats are good",
"Cats are bad",
"Tiger",
"Wolf",
"Grimsby Town FC",
"Elephants are here",
];
var sentenceEmbeddings = await embeddingGenerator.GenerateEmbeddingsAsync(sentences);
var sentencesWithEmbeddings = sentences.Zip(sentenceEmbeddings, (s, e) => (Sentence: s, Embedding: e)).ToArray();

var sentencesRankedBySimilarity = sentencesWithEmbeddings
.OrderByDescending(s => TensorPrimitives.CosineSimilarity(cat.Span, s.Embedding.Span))
.Select(s => s.Sentence)
.ToArray();

Assert.Equal([
"Cats are good",
"kitten!",
"Cats are bad",
"Tiger",
"dog",
"Wolf",
"Elephants are here",
"Grimsby Town FC",
], sentencesRankedBySimilarity);
}

[Fact]
public async Task IsCaseInsensitiveByDefault()
{
var builder = Kernel.CreateBuilder();
builder.AddLocalTextEmbeddingGeneration();
var kernel = builder.Build();

var embeddingGenerator = kernel.Services.GetRequiredService<ITextEmbeddingGenerationService>();
var catLower = await embeddingGenerator.GenerateEmbeddingAsync("cat");
var catUpper = await embeddingGenerator.GenerateEmbeddingAsync("CAT");
var similarity = TensorPrimitives.CosineSimilarity(catLower.Span, catUpper.Span);
Assert.Equal(1, MathF.Round(similarity, 3));
}

[Fact]
public async Task CanBeConfiguredAsCaseSensitive()
{
var builder = Kernel.CreateBuilder();
builder.AddLocalTextEmbeddingGeneration(caseSensitive: true);
var kernel = builder.Build();

var embeddingGenerator = kernel.Services.GetRequiredService<ITextEmbeddingGenerationService>();
var catLower = await embeddingGenerator.GenerateEmbeddingAsync("cat");
var catUpper = await embeddingGenerator.GenerateEmbeddingAsync("CAT");
var similarity = TensorPrimitives.CosineSimilarity(catLower.Span, catUpper.Span);
Assert.NotEqual(1, MathF.Round(similarity, 3));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
global using Xunit;
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
<Project Sdk="Microsoft.NET.Sdk">

<PropertyGroup>
<TargetFramework>net8.0</TargetFramework>
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>

<IsPackable>false</IsPackable>
<IsTestProject>true</IsTestProject>
<NoWarn>$(NoWarn);SKEXP0001</NoWarn>
</PropertyGroup>

<ItemGroup>
<PackageReference Include="Microsoft.NET.Test.Sdk" />
<PackageReference Include="Microsoft.SemanticKernel.Core" />
<PackageReference Include="xunit" />
<PackageReference Include="xunit.runner.visualstudio">
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
<PrivateAssets>all</PrivateAssets>
</PackageReference>
</ItemGroup>

<ItemGroup>
<ProjectReference Include="..\..\src\SmartComponents.LocalEmbeddings.SemanticKernel\SmartComponents.LocalEmbeddings.SemanticKernel.csproj" />
</ItemGroup>

<!-- Only needed when referencing the dependencies as projects. For package references, these are imported automatically. -->
<Import Project="$(RepoRoot)src\SmartComponents.LocalEmbeddings\build\SmartComponents.LocalEmbeddings.targets" />

</Project>

0 comments on commit 02dd2d1

Please sign in to comment.