diff --git a/autogen/agentchat/contrib/vectordb/chromadb.py b/autogen/agentchat/contrib/vectordb/chromadb.py index 7411b4efa4..c6e082fc22 100644 --- a/autogen/agentchat/contrib/vectordb/chromadb.py +++ b/autogen/agentchat/contrib/vectordb/chromadb.py @@ -15,6 +15,7 @@ if chromadb.__version__ < "0.4.15": raise ImportError("Please upgrade chromadb to version 0.4.15 or later.") + import chromadb.errors import chromadb.utils.embedding_functions as ef from chromadb.api.models.Collection import Collection except ImportError: @@ -90,7 +91,7 @@ def create_collection( collection = self.active_collection else: collection = self.client.get_collection(collection_name, embedding_function=self.embedding_function) - except ValueError: + except (ValueError, chromadb.errors.ChromaError): collection = None if collection is None: return self.client.create_collection( diff --git a/test/agentchat/contrib/vectordb/test_chromadb.py b/test/agentchat/contrib/vectordb/test_chromadb.py index 7b7992f717..6966269136 100644 --- a/test/agentchat/contrib/vectordb/test_chromadb.py +++ b/test/agentchat/contrib/vectordb/test_chromadb.py @@ -13,6 +13,7 @@ try: import chromadb + import chromadb.errors import sentence_transformers from autogen.agentchat.contrib.vectordb.chromadb import ChromaVectorDB @@ -32,12 +33,18 @@ def test_chromadb(): # test_delete_collection db.delete_collection(collection_name) - pytest.raises(ValueError, db.get_collection, collection_name) + pytest.raises((ValueError, chromadb.errors.ChromaError), db.get_collection, collection_name) # test more create collection collection = db.create_collection(collection_name, overwrite=False, get_or_create=False) assert collection.name == collection_name - pytest.raises(ValueError, db.create_collection, collection_name, overwrite=False, get_or_create=False) + pytest.raises( + (ValueError, chromadb.errors.ChromaError), + db.create_collection, + collection_name, + overwrite=False, + get_or_create=False, + ) collection = db.create_collection(collection_name, overwrite=True, get_or_create=False) assert collection.name == collection_name collection = db.create_collection(collection_name, overwrite=False, get_or_create=True)