Skip to content

Commit

Permalink
.Net: Fix bug where redis score was mapped from wrong score field. (#…
Browse files Browse the repository at this point in the history
…9901)

### Motivation and Context

#9900

### Description

Fixing mapping issue.

### Contribution Checklist

<!-- Before submitting this PR, please make sure: -->

- [x] The code builds clean without any errors or warnings
- [x] The PR follows the [SK Contribution
Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [x] All unit tests pass, and I have added new tests where possible
- [x] I didn't break anyone 😄
  • Loading branch information
westey-m authored Dec 6, 2024
1 parent 5e5de6e commit 049cbbf
Show file tree
Hide file tree
Showing 9 changed files with 144 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,8 @@ public sealed class RedisHashSetVectorStoreRecordCollection<TRecord> : IVectorSt
/// <summary>An array of the names of all the data properties that are part of the Redis payload as RedisValue objects, i.e. all properties except the key and vector properties.</summary>
private readonly RedisValue[] _dataStoragePropertyNameRedisValues;

/// <summary>An array of the names of all the data properties that are part of the Redis payload, i.e. all properties except the key and vector properties.</summary>
private readonly string[] _dataStoragePropertyNames;
/// <summary>An array of the names of all the data properties that are part of the Redis payload, i.e. all properties except the key and vector properties, plus the generated score property.</summary>
private readonly string[] _dataStoragePropertyNamesWithScore;

/// <summary>The mapper to use when mapping between the consumer data model and the Redis record.</summary>
private readonly IVectorStoreRecordMapper<TRecord, (string Key, HashEntry[] HashEntries)> _mapper;
Expand Down Expand Up @@ -119,14 +119,12 @@ public RedisHashSetVectorStoreRecordCollection(IDatabase database, string collec
this._propertyReader.VerifyVectorProperties(s_supportedVectorTypes);

// Lookup storage property names.
this._dataStoragePropertyNames = this._propertyReader
.DataPropertyStoragePropertyNames
.ToArray();

this._dataStoragePropertyNameRedisValues = this._dataStoragePropertyNames
this._dataStoragePropertyNameRedisValues = this._propertyReader.DataPropertyStoragePropertyNames
.Select(RedisValue.Unbox)
.ToArray();

this._dataStoragePropertyNamesWithScore = [.. this._propertyReader.DataPropertyStoragePropertyNames, "vector_score"];

// Assign Mapper.
if (this._options.HashEntriesCustomMapper is not null)
{
Expand Down Expand Up @@ -342,7 +340,7 @@ public async Task<VectorSearchResults<TRecord>> VectorizedSearchAsync<TVector>(T
var internalOptions = options ?? s_defaultVectorSearchOptions;

// Build query & search.
var selectFields = internalOptions.IncludeVectors ? null : this._dataStoragePropertyNames;
var selectFields = internalOptions.IncludeVectors ? null : this._dataStoragePropertyNamesWithScore;
byte[] vectorBytes = RedisVectorStoreCollectionSearchMapping.ValidateVectorAndConvertToBytes(vector, "HashSet");
var query = RedisVectorStoreCollectionSearchMapping.BuildQuery(vectorBytes, internalOptions, this._propertyReader.StoragePropertyNamesMap, this._propertyReader.FirstVectorPropertyStoragePropertyName!, selectFields);
var results = await this.RunOperationAsync(
Expand All @@ -369,7 +367,11 @@ public async Task<VectorSearchResults<TRecord>> VectorizedSearchAsync<TVector>(T
return this._mapper.MapFromStorageToDataModel((this.RemoveKeyPrefixIfNeeded(result.Id), retrievedHashEntries), new() { IncludeVectors = internalOptions.IncludeVectors });
});

return new VectorSearchResult<TRecord>(dataModel, result.Score);
// Process the score of the result item.
var distanceFunction = RedisVectorStoreCollectionSearchMapping.ResolveDistanceFunction(internalOptions, this._propertyReader.VectorProperties, this._propertyReader.VectorProperty!);
var score = RedisVectorStoreCollectionSearchMapping.GetOutputScoreFromRedisScore(result["vector_score"].HasValue ? (float)result["vector_score"] : null, distanceFunction);

return new VectorSearchResult<TRecord>(dataModel, score);
});

return new VectorSearchResults<TRecord>(mappedResults.ToAsyncEnumerable());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,11 @@ public async Task<VectorSearchResults<TRecord>> VectorizedSearchAsync<TVector>(T
new() { IncludeVectors = internalOptions.IncludeVectors });
});

return new VectorSearchResult<TRecord>(mappedRecord, result.Score);
// Process the score of the result item.
var distanceFunction = RedisVectorStoreCollectionSearchMapping.ResolveDistanceFunction(internalOptions, this._propertyReader.VectorProperties, this._propertyReader.VectorProperty!);
var score = RedisVectorStoreCollectionSearchMapping.GetOutputScoreFromRedisScore(result["vector_score"].HasValue ? (float)result["vector_score"] : null, distanceFunction);

return new VectorSearchResult<TRecord>(mappedRecord, score);
});

return new VectorSearchResults<TRecord>(mappedResults.ToAsyncEnumerable());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ public static string GetSDKDistanceAlgorithm(VectorStoreRecordVectorProperty vec
return vectorProperty.DistanceFunction switch
{
DistanceFunction.CosineSimilarity => "COSINE",
DistanceFunction.CosineDistance => "COSINE",
DistanceFunction.DotProductSimilarity => "IP",
DistanceFunction.EuclideanDistance => "L2",
_ => throw new InvalidOperationException($"Distance function '{vectorProperty.DistanceFunction}' for {nameof(VectorStoreRecordVectorProperty)} '{vectorProperty.DataModelPropertyName}' is not supported by the Redis VectorStore.")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,53 @@ public static string BuildFilter(VectorSearchFilter? basicVectorSearchFilter, IR
return $"({string.Join(" ", filterClauses)})";
}

/// <summary>
/// Resolve the distance function to use for a search by checking the distance function of the vector property specified in options
/// or by falling back to the distance function of the first vector property, or by falling back to the default distance function.
/// </summary>
/// <param name="options">The search options potentially containing a vector field to search.</param>
/// <param name="vectorProperties">The list of all vector properties.</param>
/// <param name="firstVectorProperty">The first vector property in the record.</param>
/// <returns>The distance function for the vector we want to search.</returns>
/// <exception cref="InvalidOperationException">Thrown when a user asked for a vector property that doesn't exist on the record.</exception>
public static string ResolveDistanceFunction(VectorSearchOptions options, IReadOnlyList<VectorStoreRecordVectorProperty> vectorProperties, VectorStoreRecordVectorProperty firstVectorProperty)
{
if (options.VectorPropertyName == null || vectorProperties.Count == 1)
{
return firstVectorProperty.DistanceFunction ?? DistanceFunction.CosineSimilarity;
}

var vectorProperty = vectorProperties.FirstOrDefault(p => p.DataModelPropertyName == options.VectorPropertyName)
?? throw new InvalidOperationException($"The collection does not have a vector field named '{options.VectorPropertyName}'.");

return vectorProperty.DistanceFunction ?? DistanceFunction.CosineSimilarity;
}

/// <summary>
/// Convert the score from redis into the appropriate output score based on the distance function.
/// Redis doesn't support Cosine Similarity, so we need to convert from distance to similarity if it was chosen.
/// </summary>
/// <param name="redisScore">The redis score to convert.</param>
/// <param name="distanceFunction">The distance function used in the search.</param>
/// <returns>The converted score.</returns>
/// <exception cref="InvalidOperationException">Thrown if the provided distance function is not supported by redis.</exception>
public static float? GetOutputScoreFromRedisScore(float? redisScore, string distanceFunction)
{
if (redisScore is null)
{
return null;
}

return distanceFunction switch
{
DistanceFunction.CosineSimilarity => 1 - redisScore,
DistanceFunction.CosineDistance => redisScore,
DistanceFunction.DotProductSimilarity => redisScore,
DistanceFunction.EuclideanDistance => redisScore,
_ => throw new InvalidOperationException($"The distance function '{distanceFunction}' is not supported."),
};
}

/// <summary>
/// Resolve the vector field name to use for a search by using the storage name for the field name from options
/// if available, and falling back to the first vector field name if not.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,7 @@ public async Task CanSearchWithVectorAndFilterAsync(bool useDefinition, bool inc
{
RedisResult.Create(new RedisValue("1")),
RedisResult.Create(new RedisValue(TestRecordKey1)),
RedisResult.Create(new RedisValue("0.5")),
RedisResult.Create(new RedisValue("0.8")),
RedisResult.Create(
[
new RedisValue("OriginalNameData"),
Expand All @@ -436,6 +436,8 @@ public async Task CanSearchWithVectorAndFilterAsync(bool useDefinition, bool inc
new RedisValue("data 1"),
new RedisValue("vector_storage_name"),
RedisValue.Unbox(MemoryMarshal.AsBytes(new ReadOnlySpan<float>(new float[] { 1, 2, 3, 4 })).ToArray()),
new RedisValue("vector_score"),
new RedisValue("0.25"),
]),
});
var sut = this.CreateRecordCollection(useDefinition);
Expand Down Expand Up @@ -468,9 +470,10 @@ public async Task CanSearchWithVectorAndFilterAsync(bool useDefinition, bool inc
var returnArgs = includeVectors ? Array.Empty<object>() : new object[]
{
"RETURN",
2,
3,
"OriginalNameData",
"data_storage_name"
"data_storage_name",
"vector_score"
};
var expectedArgsPart2 = new object[]
{
Expand All @@ -493,7 +496,7 @@ public async Task CanSearchWithVectorAndFilterAsync(bool useDefinition, bool inc
var results = await actual.Results.ToListAsync();
Assert.Single(results);
Assert.Equal(TestRecordKey1, results.First().Record.Key);
Assert.Equal(0.5d, results.First().Score);
Assert.Equal(0.25d, results.First().Score);
Assert.Equal("original data 1", results.First().Record.OriginalNameData);
Assert.Equal("data 1", results.First().Record.Data);
if (includeVectors)
Expand Down Expand Up @@ -613,7 +616,7 @@ private static SinglePropsModel CreateModel(string key, bool withVectors)
new VectorStoreRecordKeyProperty("Key", typeof(string)),
new VectorStoreRecordDataProperty("OriginalNameData", typeof(string)),
new VectorStoreRecordDataProperty("Data", typeof(string)) { StoragePropertyName = "data_storage_name" },
new VectorStoreRecordVectorProperty("Vector", typeof(ReadOnlyMemory<float>)) { StoragePropertyName = "vector_storage_name" }
new VectorStoreRecordVectorProperty("Vector", typeof(ReadOnlyMemory<float>)) { StoragePropertyName = "vector_storage_name", DistanceFunction = DistanceFunction.CosineDistance }
]
};

Expand All @@ -630,7 +633,7 @@ public sealed class SinglePropsModel
public string Data { get; set; } = string.Empty;

[JsonPropertyName("ignored_vector_json_name")]
[VectorStoreRecordVector(4, StoragePropertyName = "vector_storage_name")]
[VectorStoreRecordVector(4, DistanceFunction.CosineDistance, StoragePropertyName = "vector_storage_name")]
public ReadOnlyMemory<float>? Vector { get; set; }

public string? NotAnnotated { get; set; }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -450,8 +450,14 @@ public async Task CanSearchWithVectorAndFilterAsync(bool useDefinition)
{
RedisResult.Create(new RedisValue("1")),
RedisResult.Create(new RedisValue(TestRecordKey1)),
RedisResult.Create(new RedisValue("0.5")),
RedisResult.Create([new RedisValue("$"), new RedisValue(jsonResult)]),
RedisResult.Create(new RedisValue("0.8")),
RedisResult.Create(
[
new RedisValue("$"),
new RedisValue(jsonResult),
new RedisValue("vector_score"),
new RedisValue("0.25")
]),
});
var sut = this.CreateRecordCollection(useDefinition);

Expand Down Expand Up @@ -496,7 +502,7 @@ public async Task CanSearchWithVectorAndFilterAsync(bool useDefinition)
var results = await actual.Results.ToListAsync();
Assert.Single(results);
Assert.Equal(TestRecordKey1, results.First().Record.Key);
Assert.Equal(0.5d, results.First().Score);
Assert.Equal(0.25d, results.First().Score);
Assert.Equal("data 1", results.First().Record.Data1);
Assert.Equal("data 2", results.First().Record.Data2);
Assert.Equal(new float[] { 1, 2, 3, 4 }, results.First().Record.Vector1!.Value.ToArray());
Expand Down Expand Up @@ -617,7 +623,7 @@ private static MultiPropsModel CreateModel(string key, bool withVectors)
new VectorStoreRecordKeyProperty("Key", typeof(string)),
new VectorStoreRecordDataProperty("Data1", typeof(string)) { IsFilterable = true, StoragePropertyName = "ignored_data1_storage_name" },
new VectorStoreRecordDataProperty("Data2", typeof(string)) { IsFilterable = true },
new VectorStoreRecordVectorProperty("Vector1", typeof(ReadOnlyMemory<float>)) { Dimensions = 4, StoragePropertyName = "ignored_vector1_storage_name" },
new VectorStoreRecordVectorProperty("Vector1", typeof(ReadOnlyMemory<float>)) { Dimensions = 4, DistanceFunction = DistanceFunction.CosineDistance, StoragePropertyName = "ignored_vector1_storage_name" },
new VectorStoreRecordVectorProperty("Vector2", typeof(ReadOnlyMemory<float>)) { Dimensions = 4 }
]
};
Expand All @@ -635,7 +641,7 @@ public sealed class MultiPropsModel
public string Data2 { get; set; } = string.Empty;

[JsonPropertyName("vector1_json_name")]
[VectorStoreRecordVector(4, StoragePropertyName = "ignored_vector1_storage_name")]
[VectorStoreRecordVector(4, DistanceFunction.CosineDistance, StoragePropertyName = "ignored_vector1_storage_name")]
public ReadOnlyMemory<float>? Vector1 { get; set; }

[VectorStoreRecordVector(4)]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -204,4 +204,60 @@ public void BuildFilterThrowsForUnknownFieldName()
var filter = RedisVectorStoreCollectionSearchMapping.BuildFilter(basicVectorSearchFilter, storagePropertyNames);
});
}

[Fact]
public void ResolveDistanceFunctionReturnsCosineSimilarityIfNoDistanceFunctionSpecified()
{
var property = new VectorStoreRecordVectorProperty("Prop", typeof(ReadOnlyMemory<float>));

// Act.
var resolvedDistanceFunction = RedisVectorStoreCollectionSearchMapping.ResolveDistanceFunction(new VectorSearchOptions(), [property], property);

// Assert.
Assert.Equal(DistanceFunction.CosineSimilarity, resolvedDistanceFunction);
}

[Fact]
public void ResolveDistanceFunctionReturnsDistanceFunctionFromFirstPropertyIfNoFieldChosen()
{
var property = new VectorStoreRecordVectorProperty("Prop", typeof(ReadOnlyMemory<float>)) { DistanceFunction = DistanceFunction.DotProductSimilarity };

// Act.
var resolvedDistanceFunction = RedisVectorStoreCollectionSearchMapping.ResolveDistanceFunction(new VectorSearchOptions(), [property], property);

// Assert.
Assert.Equal(DistanceFunction.DotProductSimilarity, resolvedDistanceFunction);
}

[Fact]
public void ResolveDistanceFunctionReturnsDistanceFunctionFromChosenPropertyIfFieldChosen()
{
var property1 = new VectorStoreRecordVectorProperty("Prop1", typeof(ReadOnlyMemory<float>)) { DistanceFunction = DistanceFunction.CosineDistance };
var property2 = new VectorStoreRecordVectorProperty("Prop2", typeof(ReadOnlyMemory<float>)) { DistanceFunction = DistanceFunction.DotProductSimilarity };

// Act.
var resolvedDistanceFunction = RedisVectorStoreCollectionSearchMapping.ResolveDistanceFunction(new VectorSearchOptions() { VectorPropertyName = "Prop2" }, [property1, property2], property1);

// Assert.
Assert.Equal(DistanceFunction.DotProductSimilarity, resolvedDistanceFunction);
}

[Fact]
public void GetOutputScoreFromRedisScoreConvertsCosineDistanceToSimilarity()
{
// Act & Assert.
Assert.Equal(-1, RedisVectorStoreCollectionSearchMapping.GetOutputScoreFromRedisScore(2, DistanceFunction.CosineSimilarity));
Assert.Equal(0, RedisVectorStoreCollectionSearchMapping.GetOutputScoreFromRedisScore(1, DistanceFunction.CosineSimilarity));
Assert.Equal(1, RedisVectorStoreCollectionSearchMapping.GetOutputScoreFromRedisScore(0, DistanceFunction.CosineSimilarity));
}

[Theory]
[InlineData(DistanceFunction.CosineDistance, 2)]
[InlineData(DistanceFunction.DotProductSimilarity, 2)]
[InlineData(DistanceFunction.EuclideanDistance, 2)]
public void GetOutputScoreFromRedisScoreLeavesNonConsineSimilarityUntouched(string distanceFunction, float score)
{
// Act & Assert.
Assert.Equal(score, RedisVectorStoreCollectionSearchMapping.GetOutputScoreFromRedisScore(score, distanceFunction));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ public async Task ItCanCreateACollectionUpsertGetAndSearchAsync(bool useRecordDe

var searchResults = await actual.Results.ToListAsync();
Assert.Single(searchResults);
Assert.Equal(1, searchResults.First().Score);
var searchResultRecord = searchResults.First().Record;
Assert.Equal(record.HotelId, searchResultRecord?.HotelId);
Assert.Equal(record.HotelName, searchResultRecord?.HotelName);
Expand Down Expand Up @@ -325,6 +326,7 @@ public async Task ItCanSearchWithFloat32VectorAndFilterAsync(string filterType,
// Assert
var searchResults = await actual.Results.ToListAsync();
Assert.Single(searchResults);
Assert.Equal(1, searchResults.First().Score);
var searchResult = searchResults.First().Record;
Assert.Equal("HBaseSet-1", searchResult?.HotelId);
Assert.Equal("My Hotel 1", searchResult?.HotelName);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ public async Task ItCanCreateACollectionUpsertGetAndSearchAsync(bool useRecordDe

var searchResults = await actual.Results.ToListAsync();
Assert.Single(searchResults);
Assert.Equal(1, searchResults.First().Score);
var searchResultRecord = searchResults.First().Record;
Assert.Equal(record.HotelId, searchResultRecord?.HotelId);
Assert.Equal(record.HotelName, searchResultRecord?.HotelName);
Expand Down Expand Up @@ -351,6 +352,7 @@ public async Task ItCanSearchWithFloat32VectorAndFilterAsync(string filterType)
// Assert
var searchResults = await actual.Results.ToListAsync();
Assert.Single(searchResults);
Assert.Equal(1, searchResults.First().Score);
var searchResult = searchResults.First().Record;
Assert.Equal("My Hotel 1", searchResults.First().Record.HotelName);
Assert.Equal("BaseSet-1", searchResult?.HotelId);
Expand Down

0 comments on commit 049cbbf

Please sign in to comment.