Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Qdrant module #463

Merged
merged 17 commits into from
Mar 31, 2024
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions INDEX.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ testcontainers-python facilitates the use of Docker containers for functional an
modules/opensearch/README
modules/oracle/README
modules/postgres/README
modules/qdrant/README
modules/rabbitmq/README
modules/redis/README
modules/selenium/README
Expand Down
5 changes: 5 additions & 0 deletions conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
extensions = [
"sphinx.ext.autodoc",
"sphinx.ext.doctest",
"sphinx.ext.intersphinx",
"sphinx.ext.napoleon",
]

Expand Down Expand Up @@ -156,3 +157,7 @@
"Miscellaneous",
),
]

intersphinx_mapping = {
"python": ("https://docs.python.org/3", None),
}
2 changes: 2 additions & 0 deletions modules/qdrant/README.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
.. autoclass:: testcontainers.qdrant.QdrantContainer
.. title:: testcontainers.qdrant.QdrantContainer
145 changes: 145 additions & 0 deletions modules/qdrant/testcontainers/qdrant/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import os
from functools import cached_property
from pathlib import Path
from typing import Optional

from qdrant_client import AsyncQdrantClient, QdrantClient
Anush008 marked this conversation as resolved.
Show resolved Hide resolved

from testcontainers.core.config import TIMEOUT
from testcontainers.core.generic import DbContainer
from testcontainers.core.waiting_utils import wait_container_is_ready, wait_for_logs


class QdrantContainer(DbContainer):
"""
Qdrant vector database container.

Example:
.. doctest::

>>> from testcontainers.qdrant import QdrantContainer

>>> with QdrantContainer() as qdrant:
... client = qdrant.get_client()
... client.get_collections()
"""

QDRANT_CONFIG_FILE_PATH = "/qdrant/config/config.yaml"

def __init__(
self,
image: str = "qdrant/qdrant:v1.8.1",
rest_port: int = 6333,
grpc_port: int = 6334,
api_key: Optional[str] = None,
config_file_path: Optional[Path] = None,
**kwargs,
) -> None:
super().__init__(image, **kwargs)
self._rest_port = rest_port
self._grpc_port = grpc_port
self._api_key = api_key or os.getenv("QDRANT_CONTAINER_API_KEY")

if config_file_path:
self.with_volume_mapping(host=str(config_file_path), container=QdrantContainer.QDRANT_CONFIG_FILE_PATH)

self.with_exposed_ports(self._rest_port, self._grpc_port)

def _configure(self) -> None:
self.with_env("QDRANT__SERVICE__API_KEY", self._api_key)

@wait_container_is_ready()
def _connect(self) -> None:
wait_for_logs(self, ".*Actix runtime found; starting in Actix runtime.*", TIMEOUT)

def get_client(self, **kwargs) -> "QdrantClient":
"""
Get a `qdrant_client.QdrantClient` instance associated with the container.

Args:
**kwargs: Additional keyword arguments to be passed to the `qdrant_client.QdrantClient` constructor.

Returns:
QdrantClient: An instance of the `qdrant_client.QdrantClient` class.

"""
return QdrantClient(
host=self.get_container_host_ip(),
port=self.get_exposed_port(self._rest_port),
grpc_port=self.get_exposed_port(self._grpc_port),
api_key=self._api_key,
https=False,
**kwargs,
)

def get_async_client(self, **kwargs) -> "AsyncQdrantClient":
"""
Get a `qdrant_client.AsyncQdrantClient` instance associated with the container.

Args:
**kwargs: Additional keyword arguments to be passed to the `qdrant_client.AsyncQdrantClient` constructor.

Returns:
QdrantClient: An instance of the `qdrant_client.AsyncQdrantClient` class.

"""
return AsyncQdrantClient(
host=self.get_container_host_ip(),
port=self.get_exposed_port(self._rest_port),
grpc_port=self.get_exposed_port(self._grpc_port),
api_key=self._api_key,
https=False,
**kwargs,
)

@cached_property
def rest_host_address(self) -> str:
"""
Get the REST host address of the Qdrant container.

Returns:
str: The REST host address of the Qdrant container.
"""
return f"{self.get_container_host_ip()}:{self.exposed_rest_port}"

@cached_property
def grpc_host_address(self) -> str:
"""
Get the GRPC host address of the Qdrant container.

Returns:
str: The GRPC host address of the Qdrant container.
"""
return f"{self.get_container_host_ip()}:{self.exposed_grpc_port}"

@cached_property
def exposed_rest_port(self) -> int:
"""
Get the exposed REST port of the Qdrant container.

Returns:
int: The REST port of the Qdrant container.
"""
return self.get_exposed_port(self._rest_port)

@cached_property
def exposed_grpc_port(self) -> int:
"""
Get the exposed GRPC port of the Qdrant container.

Returns:
int: The GRPC port of the Qdrant container.
"""
return self.get_exposed_port(self._grpc_port)
6 changes: 6 additions & 0 deletions modules/qdrant/tests/test_config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Qdrant image configuration file for testing
# Reference: https://qdrant.tech/documentation/guides/configuration/#configuration-file-example
log_level: INFO

service:
api_key: "SOME_TEST_KEY"
78 changes: 78 additions & 0 deletions modules/qdrant/tests/test_qdrant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import pytest
from testcontainers.qdrant import QdrantContainer
import uuid
from qdrant_client import QdrantClient
from qdrant_client.http.exceptions import UnexpectedResponse
from grpc import RpcError
from pathlib import Path


def test_docker_run_qdrant():
with QdrantContainer() as qdrant:
client = qdrant.get_client()
collections = client.get_collections().collections
assert len(collections) == 0

client = qdrant.get_client(prefer_grpc=True)
collections = client.get_collections().collections
assert len(collections) == 0


def test_qdrant_with_api_key_http():
api_key = uuid.uuid4().hex

with QdrantContainer(api_key=api_key) as qdrant:
with pytest.raises(UnexpectedResponse) as e:
# Construct a client without an API key
QdrantClient(location=f"http://{qdrant.rest_host_address}").get_collections()

assert "Invalid api-key" in str(e.value)

# Construct a client with an API key
collections = (
QdrantClient(location=f"http://{qdrant.rest_host_address}", api_key=api_key).get_collections().collections
)

assert len(collections) == 0

# Get an automatically configured client instance
collections = qdrant.get_client().get_collections().collections

assert len(collections) == 0


def test_qdrant_with_api_key_grpc():
api_key = uuid.uuid4().hex

with QdrantContainer(api_key=api_key) as qdrant:
with pytest.raises(RpcError) as e:
QdrantClient(
url=f"http://{qdrant.grpc_host_address}",
grpc_port=qdrant.exposed_grpc_port,
prefer_grpc=True,
).get_collections()

assert "Invalid api-key" in str(e.value)

collections = (
QdrantClient(
url=f"http://{qdrant.grpc_host_address}",
grpc_port=qdrant.exposed_grpc_port,
prefer_grpc=True,
api_key=api_key,
)
.get_collections()
.collections
)

assert len(collections) == 0


def test_qdrant_with_config_file():
config_file_path = Path(__file__).with_name("test_config.yaml")

with QdrantContainer(config_file_path=config_file_path) as qdrant:
with pytest.raises(UnexpectedResponse) as e:
QdrantClient(location=f"http://{qdrant.rest_host_address}").get_collections()

assert "Invalid api-key" in str(e.value)
Loading