-
Notifications
You must be signed in to change notification settings - Fork 3.4k
/
utils.py
292 lines (225 loc) · 10.6 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
# Copyright (c) Microsoft. All rights reserved.
from typing import TYPE_CHECKING, Any
from weaviate.classes.config import Configure, Property
from weaviate.classes.query import Filter
from weaviate.collections.classes.config_named_vectors import _NamedVectorConfigCreate
from weaviate.collections.classes.config_vector_index import _VectorIndexConfigCreate
from weaviate.collections.classes.config_vectorizers import VectorDistances
from semantic_kernel.connectors.memory.weaviate.const import TYPE_MAPPER_DATA
from semantic_kernel.data.const import DistanceFunction, IndexKind
from semantic_kernel.data.filter_clauses.any_tags_equal_to_filter_clause import AnyTagsEqualTo
from semantic_kernel.data.filter_clauses.equal_to_filter_clause import EqualTo
from semantic_kernel.data.record_definition.vector_store_model_definition import VectorStoreRecordDefinition
from semantic_kernel.data.record_definition.vector_store_record_fields import (
VectorStoreRecordDataField,
VectorStoreRecordVectorField,
)
from semantic_kernel.data.vector_search.vector_search_filter import VectorSearchFilter
from semantic_kernel.exceptions.memory_connector_exceptions import (
VectorStoreModelDeserializationException,
)
if TYPE_CHECKING:
from weaviate.collections.classes.filters import _Filters
def data_model_definition_to_weaviate_properties(
    data_model_definition: VectorStoreRecordDefinition,
) -> list[Property]:
    """Build the Weaviate property list for the data fields of a data model.

    Only fields that are ``VectorStoreRecordDataField`` instances are converted;
    every other field in the definition is skipped.

    Args:
        data_model_definition (VectorStoreRecordDefinition): The data model definition.

    Returns:
        list[Property]: One Weaviate property per data field, in definition order.
    """
    return [
        Property(
            name=data_field.name,
            # A falsy property type falls back to the "default" entry of the mapper.
            data_type=TYPE_MAPPER_DATA[data_field.property_type or "default"],
            index_filterable=data_field.is_filterable,
            index_full_text=data_field.is_full_text_searchable,
        )
        for data_field in data_model_definition.fields.values()
        if isinstance(data_field, VectorStoreRecordDataField)
    ]
def data_model_definition_to_weaviate_named_vectors(
    data_model_definition: VectorStoreRecordDefinition,
) -> list[_NamedVectorConfigCreate]:
    """Build the Weaviate named-vector configurations for a data model.

    Each vector field of the definition yields one ``Configure.NamedVectors.none``
    entry carrying the field name and its vector index configuration.

    Args:
        data_model_definition (VectorStoreRecordDefinition): The data model definition.

    Returns:
        list[_NamedVectorConfigCreate]: One named-vector configuration per vector field.
    """
    return [
        Configure.NamedVectors.none(
            name=field.name,  # type: ignore
            vector_index_config=to_weaviate_vector_index_config(field),
        )
        for field in data_model_definition.vector_fields
    ]
def to_weaviate_vector_index_config(vector: VectorStoreRecordVectorField) -> _VectorIndexConfigCreate:
"""Convert a vector field to a Weaviate vector index configuration.
Args:
vector (VectorStoreRecordVectorField): The vector field.
Returns:
The Weaviate vector index configuration.
"""
if vector.index_kind == IndexKind.HNSW:
return Configure.VectorIndex.hnsw(
distance_metric=to_weaviate_vector_distance(vector.distance_function),
)
if vector.index_kind == IndexKind.FLAT:
return Configure.VectorIndex.flat(
distance_metric=to_weaviate_vector_distance(vector.distance_function),
)
return Configure.VectorIndex.none()
def to_weaviate_vector_distance(distance_function: DistanceFunction | None) -> VectorDistances | None:
"""Convert a distance function to a Weaviate vector distance metric.
Args:
distance_function (DistanceFunction | None): The distance function.
Returns:
str: The Weaviate vector distance metric name.
"""
match distance_function:
case DistanceFunction.COSINE_DISTANCE:
return VectorDistances.COSINE
case DistanceFunction.DOT_PROD:
return VectorDistances.DOT
case DistanceFunction.EUCLIDEAN_SQUARED_DISTANCE:
return VectorDistances.L2_SQUARED
case DistanceFunction.MANHATTAN:
return VectorDistances.MANHATTAN
case DistanceFunction.HAMMING:
return VectorDistances.HAMMING
raise ValueError(f"Unsupported distance function for Weaviate: {distance_function}")
# region Serialization helpers
def extract_properties_from_dict_record_based_on_data_model_definition(
    record: dict[str, Any],
    data_model_definition: VectorStoreRecordDefinition,
) -> dict[str, Any]:
    """Extract Weaviate object properties from a dictionary record based on the data model definition.

    Expecting the record to have all the data fields defined in the data model definition.
    The returned object can be used to construct a Weaviate object.

    Args:
        record (dict[str, Any]): The record.
        data_model_definition (VectorStoreRecordDefinition): The data model definition.

    Returns:
        dict[str, Any]: The record values keyed by data field name.

    Raises:
        KeyError: If the record lacks a data field named in the data model definition.
    """
    return {
        field.name: record[field.name]
        for field in data_model_definition.fields.values()
        # Only data fields with a non-empty name become Weaviate properties.
        if isinstance(field, VectorStoreRecordDataField) and field.name
    }
def extract_key_from_dict_record_based_on_data_model_definition(
    record: dict[str, Any],
    data_model_definition: VectorStoreRecordDefinition,
) -> str | None:
    """Look up the record's key value using the key field of the data model definition.

    The key maps to a Weaviate object ID, so the returned value can be used to
    construct a Weaviate object.

    Args:
        record (dict[str, Any]): The record.
        data_model_definition (VectorStoreRecordDefinition): The data model definition.

    Returns:
        str | None: The record value stored under the key field name, or None when
            the key field has no name.
    """
    key_name = data_model_definition.key_field.name
    if not key_name:
        return None
    return record[key_name]
def extract_vectors_from_dict_record_based_on_data_model_definition(
    record: dict[str, Any],
    data_model_definition: VectorStoreRecordDefinition,
    named_vectors: bool,
) -> dict[str, Any] | Any | None:
    """Pull the vector value(s) out of a dictionary record.

    With named vectors, every vector field of the definition is returned under its
    own name. Without named vectors, only the first vector field's value is
    returned, unwrapped.

    The returned object can be used to construct a Weaviate object.

    Args:
        record (dict[str, Any]): The record.
        data_model_definition (VectorStoreRecordDefinition): The data model definition.
        named_vectors (bool): Whether to use named vectors.

    Returns:
        dict[str, Any] | Any | None: A name-to-vector mapping when named vectors are
            used; otherwise the first vector field's value, or None when the
            definition has no vector fields.
    """
    if not named_vectors:
        vector_fields = data_model_definition.vector_fields
        if not vector_fields:
            return None
        return record[vector_fields[0].name]

    vectors_by_name = {}
    for vector_field in data_model_definition.vector_fields:
        vectors_by_name[vector_field.name] = record[vector_field.name]
    return vectors_by_name
# endregion
# region Deserialization helpers
def extract_properties_from_weaviate_object_based_on_data_model_definition(
    weaviate_object,
    data_model_definition: VectorStoreRecordDefinition,
) -> dict[str, Any]:
    """Collect the data-field values of a Weaviate object.

    A data field of the definition is included only when the Weaviate object's
    properties contain an entry with the same name; missing ones are skipped.

    Args:
        weaviate_object: The Weaviate object.
        data_model_definition (VectorStoreRecordDefinition): The data model definition.

    Returns:
        dict[str, Any]: The property values keyed by data field name.
    """
    extracted: dict[str, Any] = {}
    for data_field in data_model_definition.fields.values():
        if not isinstance(data_field, VectorStoreRecordDataField):
            continue
        if data_field.name in weaviate_object.properties:
            extracted[data_field.name] = weaviate_object.properties[data_field.name]
    return extracted
def extract_key_from_weaviate_object_based_on_data_model_definition(
    weaviate_object,
    data_model_definition: VectorStoreRecordDefinition,
) -> dict[str, str]:
    """Map the Weaviate object's uuid onto the key field of the data model.

    Args:
        weaviate_object: The Weaviate object.
        data_model_definition (VectorStoreRecordDefinition): The data model definition.

    Returns:
        dict[str, str]: A single-entry mapping from the key field name to the
            object's uuid.

    Raises:
        VectorStoreModelDeserializationException: If the key field has no name or
            the Weaviate object has no uuid.
    """
    key_name = data_model_definition.key_field.name
    if not (key_name and weaviate_object.uuid):
        # This is not supposed to happen
        raise VectorStoreModelDeserializationException("Unable to extract id/key from Weaviate store model")
    return {key_name: weaviate_object.uuid}
def extract_vectors_from_weaviate_object_based_on_data_model_definition(
    weaviate_object,
    data_model_definition: VectorStoreRecordDefinition,
    named_vectors: bool,
) -> dict[str, Any]:
    """Read the vectors of a Weaviate object back into data-model field names.

    With named vectors, each vector field present on the object is copied under
    its own name. Without named vectors, the object's ``"default"`` vector is
    returned under the name of the first vector field.

    Args:
        weaviate_object: The Weaviate object.
        data_model_definition (VectorStoreRecordDefinition): The data model definition.
        named_vectors (bool): Whether the collection uses named vectors.

    Returns:
        dict[str, Any]: The vectors keyed by vector field name; empty when the
            object carries no vector or the definition has no vector field.
    """
    stored_vectors = weaviate_object.vector
    if not stored_vectors:
        return {}

    if not named_vectors:
        vector_fields = data_model_definition.vector_fields
        if not vector_fields or not vector_fields[0]:
            return {}
        # An unnamed vector is read from the "default" key of the object's vectors.
        return {vector_fields[0].name: stored_vectors["default"]}

    vectors_by_name: dict[str, Any] = {}
    for vector_field in data_model_definition.vector_fields:
        if vector_field.name in stored_vectors:
            vectors_by_name[vector_field.name] = stored_vectors[vector_field.name]
    return vectors_by_name
# endregion
# region VectorSearch helpers
def create_filter_from_vector_search_filters(filters: VectorSearchFilter | None) -> "_Filters | None":
    """Translate a vector search filter into a combined Weaviate filter.

    ``EqualTo`` clauses become property ``equal`` filters and ``AnyTagsEqualTo``
    clauses become property ``like`` filters; all resulting filters are joined
    with ``Filter.all_of``.

    Args:
        filters (VectorSearchFilter | None): The vector search filter, if any.

    Returns:
        _Filters | None: The combined Weaviate filter, or None when no filter was
            given or it contains no clauses.

    Raises:
        ValueError: If a clause is neither ``EqualTo`` nor ``AnyTagsEqualTo``.
    """
    if not filters:
        return None

    weaviate_filters: list["_Filters"] = []
    for clause in filters.filters:
        if isinstance(clause, EqualTo):
            converted = Filter.by_property(clause.field_name).equal(clause.value)
        elif isinstance(clause, AnyTagsEqualTo):
            converted = Filter.by_property(clause.field_name).like(clause.value)
        else:
            raise ValueError(f"Unsupported filter type: {clause}")
        weaviate_filters.append(converted)

    if not weaviate_filters:
        return None
    return Filter.all_of(weaviate_filters)