-
Notifications
You must be signed in to change notification settings - Fork 3.4k
/
weaviate_collection.py
286 lines (252 loc) · 12.1 KB
/
weaviate_collection.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
# Copyright (c) Microsoft. All rights reserved.
import sys
from collections.abc import Sequence
from typing import Any, TypeVar
if sys.version_info >= (3, 12):
from typing import override # pragma: no cover
else:
from typing_extensions import override # pragma: no cover
import weaviate
from pydantic import field_validator
from weaviate.classes.init import Auth
from weaviate.classes.query import Filter
from weaviate.collections.classes.data import DataObject
from weaviate.collections.collection import CollectionAsync
from weaviate.exceptions import WeaviateConnectionError
from semantic_kernel.connectors.memory.weaviate.utils import (
data_model_definition_to_weaviate_named_vectors,
data_model_definition_to_weaviate_properties,
extract_key_from_dict_record_based_on_data_model_definition,
extract_key_from_weaviate_object_based_on_data_model_definition,
extract_properties_from_dict_record_based_on_data_model_definition,
extract_properties_from_weaviate_object_based_on_data_model_definition,
extract_vectors_from_dict_record_based_on_data_model_definition,
extract_vectors_from_weaviate_object_based_on_data_model_definition,
)
from semantic_kernel.connectors.memory.weaviate.weaviate_settings import WeaviateSettings
from semantic_kernel.data.record_definition.vector_store_model_definition import VectorStoreRecordDefinition
from semantic_kernel.data.record_definition.vector_store_record_fields import VectorStoreRecordDataField
from semantic_kernel.data.vector_storage.vector_store_record_collection import VectorStoreRecordCollection
from semantic_kernel.exceptions import (
MemoryConnectorConnectionException,
MemoryConnectorException,
MemoryConnectorInitializationError,
)
from semantic_kernel.kernel_types import OneOrMany
from semantic_kernel.utils.experimental_decorator import experimental_class
# Type of the user-facing record model stored in the collection (passed as `data_model_type`).
TModel = TypeVar("TModel")
# Type of record keys; constrained to either `str` or `int`.
TKey = TypeVar("TKey", str, int)
@experimental_class
class WeaviateCollection(VectorStoreRecordCollection[TKey, TModel]):
    """A Weaviate collection is a collection of records that are stored in a Weaviate database."""

    # Async client used for every operation; either supplied by the caller or built in
    # ``__init__`` from WeaviateSettings.
    async_client: weaviate.WeaviateAsyncClient

    def __init__(
        self,
        data_model_type: type[TModel],
        data_model_definition: VectorStoreRecordDefinition,
        collection_name: str,
        url: str | None = None,
        api_key: str | None = None,
        local_host: str | None = None,
        local_port: int | None = None,
        local_grpc_port: int | None = None,
        use_embed: bool = False,
        async_client: weaviate.WeaviateAsyncClient | None = None,
        env_file_path: str | None = None,
        env_file_encoding: str | None = None,
    ):
        """Initialize a Weaviate collection.

        When ``async_client`` is not given, a client is built from ``WeaviateSettings`` (explicit
        arguments and/or the environment file) and marked as managed, so ``__aexit__`` closes it.
        The settings are checked in this order: Weaviate Cloud ``url``, ``local_host``, ``use_embed``.

        Args:
            data_model_type: The type of the data model.
            data_model_definition: The definition of the data model.
            collection_name: The name of the collection.
            url: The Weaviate URL
            api_key: The Weaviate API key.
            local_host: The local Weaviate host (i.e. Weaviate in a Docker container).
            local_port: The local Weaviate port.
            local_grpc_port: The local Weaviate gRPC port.
            use_embed: Whether to use the client embedding options.
            async_client: A custom Weaviate async client.
            env_file_path: The path to the environment file.
            env_file_encoding: The encoding of the environment file.

        Raises:
            MemoryConnectorInitializationError: If no client was supplied and one could not be
                created from the settings, including when no connection option is configured.
        """
        managed_client: bool = False
        if async_client is None:
            # We create the client, so we own it and are responsible for closing it.
            managed_client = True
            weaviate_settings = WeaviateSettings.create(
                url=url,
                api_key=api_key,
                local_host=local_host,
                local_port=local_port,
                local_grpc_port=local_grpc_port,
                use_embed=use_embed,
                env_file_path=env_file_path,
                env_file_encoding=env_file_encoding,
            )

            try:
                if weaviate_settings.url:
                    async_client = weaviate.use_async_with_weaviate_cloud(
                        cluster_url=str(weaviate_settings.url),
                        auth_credentials=Auth.api_key(weaviate_settings.api_key.get_secret_value()),
                    )
                elif weaviate_settings.local_host:
                    # Forward only the ports that were configured; unset ones fall back to the
                    # defaults of `use_async_with_local`.
                    port_kwargs = {
                        "port": weaviate_settings.local_port,
                        "grpc_port": weaviate_settings.local_grpc_port,
                    }
                    port_kwargs = {k: v for k, v in port_kwargs.items() if v is not None}
                    async_client = weaviate.use_async_with_local(
                        host=weaviate_settings.local_host,
                        **port_kwargs,
                    )
                elif weaviate_settings.use_embed:
                    async_client = weaviate.use_async_with_embedded()
                else:
                    # One message built by implicit string concatenation (previously two separate
                    # exception args, which rendered as a tuple in the wrapped error message).
                    raise NotImplementedError(
                        "Weaviate settings must specify either a custom client, a Weaviate Cloud instance,"
                        " a local Weaviate instance, or the client embedding options."
                    )
            except Exception as e:
                raise MemoryConnectorInitializationError(f"Failed to initialize Weaviate client: {e}") from e

        super().__init__(
            data_model_type=data_model_type,
            data_model_definition=data_model_definition,
            collection_name=collection_name,
            async_client=async_client,
            managed_client=managed_client,
        )

    @field_validator("collection_name")
    @classmethod
    def collection_name_must_start_with_uppercase(cls, value: str) -> str:
        """By convention, the collection name starts with an uppercase letter.

        https://weaviate.io/developers/weaviate/manage-data/collections#create-a-collection

        Will change the collection name to start with an uppercase letter if it does not.
        An empty name is returned unchanged instead of failing with ``IndexError``.
        """
        if not value or value[0].isupper():
            return value
        return value[0].upper() + value[1:]

    @override
    async def _inner_upsert(
        self,
        records: Sequence[Any],
        **kwargs: Any,
    ) -> Sequence[TKey]:
        """Insert the given Weaviate ``DataObject`` records in one batch.

        Args:
            records: Store-model records, as produced by ``_serialize_dicts_to_store_models``.
            **kwargs: Unused.

        Returns:
            The UUIDs reported by Weaviate for the inserted objects, as strings.

        Raises:
            MemoryConnectorException: If the request fails, or if Weaviate reports errors for any
                object in the batch. Objects without errors may already have been written.
        """
        assert all(isinstance(record, DataObject) for record in records)  # nosec

        # NOTE(review): this enters the client's own async context manager for each operation;
        # confirm that is intended when the client was supplied by the caller.
        async with self.async_client:
            try:
                collection: CollectionAsync = self.async_client.collections.get(self.collection_name)
                response = await collection.data.insert_many(records)
            except Exception as ex:
                raise MemoryConnectorException(f"Failed to upsert records: {ex}") from ex

        # Previously, per-object batch failures were silently dropped from the returned keys.
        if response.has_errors:
            raise MemoryConnectorException(f"Failed to upsert records: {response.errors}")
        return [str(uuid) for uuid in response.uuids.values()]

    @override
    async def _inner_get(self, keys: Sequence[TKey], **kwargs: Any) -> OneOrMany[Any] | None:
        """Fetch the Weaviate objects whose ids match any of ``keys``.

        Args:
            keys: The record ids to fetch.
            **kwargs: ``include_vectors`` (bool, default False) requests vectors to be returned.

        Returns:
            The Weaviate objects found.

        Raises:
            MemoryConnectorException: If the query fails.
        """
        include_vectors: bool = kwargs.get("include_vectors", False)
        named_vectors: list[str] = []
        if include_vectors:
            named_vectors = [
                data_field.name
                for data_field in self.data_model_definition.fields.values()
                if isinstance(data_field, VectorStoreRecordDataField) and data_field.has_embedding
            ]

        async with self.async_client:
            try:
                collection: CollectionAsync = self.async_client.collections.get(self.collection_name)
                result = await collection.query.fetch_objects(
                    filters=Filter.any_of([Filter.by_id().equal(key) for key in keys]),
                    # Requires a list of named vectors if it is not empty. Otherwise, a boolean is sufficient.
                    include_vector=named_vectors or include_vectors,
                )
                return result.objects
            except Exception as ex:
                raise MemoryConnectorException(f"Failed to get records: {ex}") from ex

    @override
    async def _inner_delete(self, keys: Sequence[TKey], **kwargs: Any) -> None:
        """Delete the objects whose ids match any of ``keys``.

        Raises:
            MemoryConnectorException: If the delete request fails.
        """
        async with self.async_client:
            try:
                collection: CollectionAsync = self.async_client.collections.get(self.collection_name)
                await collection.data.delete_many(where=Filter.any_of([Filter.by_id().equal(key) for key in keys]))
            except Exception as ex:
                raise MemoryConnectorException(f"Failed to delete records: {ex}") from ex

    @override
    def _serialize_dicts_to_store_models(self, records: Sequence[dict[str, Any]], **kwargs: Any) -> Sequence[Any]:
        """Create a data object from a record based on the data model definition."""
        records_in_store_model: list[DataObject] = []
        for record in records:
            properties = extract_properties_from_dict_record_based_on_data_model_definition(
                record, self.data_model_definition
            )
            # If key is None, Weaviate will generate a UUID
            key = extract_key_from_dict_record_based_on_data_model_definition(record, self.data_model_definition)
            vectors = extract_vectors_from_dict_record_based_on_data_model_definition(
                record, self.data_model_definition
            )
            records_in_store_model.append(DataObject(properties=properties, uuid=key, vector=vectors))
        return records_in_store_model

    @override
    def _deserialize_store_models_to_dicts(self, records: Sequence[Any], **kwargs: Any) -> Sequence[dict[str, Any]]:
        """Convert Weaviate objects to dicts by merging their extracted properties, key and vectors."""
        records_in_dict: list[dict[str, Any]] = []
        for record in records:
            properties = extract_properties_from_weaviate_object_based_on_data_model_definition(
                record, self.data_model_definition
            )
            key = extract_key_from_weaviate_object_based_on_data_model_definition(record, self.data_model_definition)
            vectors = extract_vectors_from_weaviate_object_based_on_data_model_definition(
                record, self.data_model_definition
            )
            # Later operands win on duplicate keys: vectors over key over properties.
            records_in_dict.append(properties | key | vectors)
        return records_in_dict

    @override
    async def create_collection(self, **kwargs) -> None:
        """Create the collection in Weaviate.

        Properties and named-vector configuration are derived from the data model definition.

        Args:
            **kwargs: Additional keyword arguments.

        Raises:
            MemoryConnectorException: If the collection cannot be created.
        """
        async with self.async_client:
            try:
                await self.async_client.collections.create(
                    self.collection_name,
                    properties=data_model_definition_to_weaviate_properties(self.data_model_definition),
                    vectorizer_config=data_model_definition_to_weaviate_named_vectors(self.data_model_definition),
                )
            except Exception as ex:
                raise MemoryConnectorException(f"Failed to create collection: {ex}") from ex

    @override
    async def does_collection_exist(self, **kwargs) -> bool:
        """Check if the collection exists in Weaviate.

        Args:
            **kwargs: Additional keyword arguments.

        Returns:
            bool: Whether the collection exists.

        Raises:
            MemoryConnectorException: If the existence check fails.
        """
        async with self.async_client:
            try:
                # Return Weaviate's answer; previously the result was discarded and True returned.
                return await self.async_client.collections.exists(self.collection_name)
            except Exception as ex:
                raise MemoryConnectorException(f"Failed to check if collection exists: {ex}") from ex

    @override
    async def delete_collection(self, **kwargs) -> None:
        """Delete the collection in Weaviate.

        Args:
            **kwargs: Additional keyword arguments.

        Raises:
            MemoryConnectorException: If the collection cannot be deleted.
        """
        async with self.async_client:
            try:
                await self.async_client.collections.delete(self.collection_name)
            except Exception as ex:
                raise MemoryConnectorException(f"Failed to delete collection: {ex}") from ex

    @override
    async def __aenter__(self) -> "WeaviateCollection":
        """Enter the context manager, connecting the client if it is not live.

        Raises:
            MemoryConnectorConnectionException: If the client cannot connect.
        """
        if not await self.async_client.is_live():
            try:
                await self.async_client.connect()
            except WeaviateConnectionError as exc:
                raise MemoryConnectorConnectionException("Weaviate client cannot connect.") from exc
        return self

    @override
    async def __aexit__(self, exc_type, exc_value, traceback) -> None:
        """Exit the context manager, closing the client only if this collection created it."""
        if self.managed_client:
            await self.async_client.close()