Skip to content

Commit e8876f4

Browse files
fix: Qdrant module (#463)
This PR adds a module to spawn a [Qdrant](https://qdrant.tech) test container. --------- Co-authored-by: David Ankin <[email protected]>
1 parent 507e466 commit e8876f4

File tree

8 files changed

+338
-1
lines changed

8 files changed

+338
-1
lines changed

conf.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
extensions = [
3232
"sphinx.ext.autodoc",
3333
"sphinx.ext.doctest",
34+
"sphinx.ext.intersphinx",
3435
"sphinx.ext.napoleon",
3536
]
3637

@@ -156,3 +157,7 @@
156157
"Miscellaneous",
157158
),
158159
]
160+
161+
intersphinx_mapping = {
162+
"python": ("https://docs.python.org/3", None),
163+
}

index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ testcontainers-python facilitates the use of Docker containers for functional an
3636
modules/opensearch/README
3737
modules/oracle/README
3838
modules/postgres/README
39+
modules/qdrant/README
3940
modules/rabbitmq/README
4041
modules/redis/README
4142
modules/selenium/README

modules/qdrant/README.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
.. autoclass:: testcontainers.qdrant.QdrantContainer
2+
.. title:: testcontainers.qdrant.QdrantContainer
Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
#
2+
# Licensed under the Apache License, Version 2.0 (the "License"); you may
3+
# not use this file except in compliance with the License. You may obtain
4+
# a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software
9+
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
10+
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
11+
# License for the specific language governing permissions and limitations
12+
# under the License.
13+
import os
14+
from functools import cached_property
15+
from pathlib import Path
16+
from typing import Optional
17+
18+
from testcontainers.core.config import TIMEOUT
19+
from testcontainers.core.generic import DbContainer
20+
from testcontainers.core.waiting_utils import wait_container_is_ready, wait_for_logs
21+
22+
23+
class QdrantContainer(DbContainer):
24+
"""
25+
Qdrant vector database container.
26+
27+
Example:
28+
.. doctest::
29+
30+
>>> from testcontainers.qdrant import QdrantContainer
31+
32+
>>> with QdrantContainer() as qdrant:
33+
... client = qdrant.get_client()
34+
... client.get_collections()
35+
CollectionsResponse(collections=[])
36+
"""
37+
38+
QDRANT_CONFIG_FILE_PATH = "/qdrant/config/config.yaml"
39+
40+
def __init__(
41+
self,
42+
image: str = "qdrant/qdrant:v1.8.3",
43+
rest_port: int = 6333,
44+
grpc_port: int = 6334,
45+
api_key: Optional[str] = None,
46+
config_file_path: Optional[Path] = None,
47+
**kwargs,
48+
) -> None:
49+
super().__init__(image, **kwargs)
50+
self._rest_port = rest_port
51+
self._grpc_port = grpc_port
52+
self._api_key = api_key or os.getenv("QDRANT_CONTAINER_API_KEY")
53+
54+
if config_file_path:
55+
self.with_volume_mapping(host=str(config_file_path), container=QdrantContainer.QDRANT_CONFIG_FILE_PATH)
56+
57+
self.with_exposed_ports(self._rest_port, self._grpc_port)
58+
59+
def _configure(self) -> None:
60+
self.with_env("QDRANT__SERVICE__API_KEY", self._api_key)
61+
62+
@wait_container_is_ready()
63+
def _connect(self) -> None:
64+
wait_for_logs(self, ".*Actix runtime found; starting in Actix runtime.*", TIMEOUT)
65+
66+
def get_client(self, **kwargs):
67+
"""
68+
Get a `qdrant_client.QdrantClient` instance associated with the container.
69+
70+
Args:
71+
**kwargs: Additional keyword arguments to be passed to the `qdrant_client.QdrantClient` constructor.
72+
73+
Returns:
74+
QdrantClient: An instance of the `qdrant_client.QdrantClient` class.
75+
76+
"""
77+
78+
try:
79+
from qdrant_client import QdrantClient
80+
except ImportError as e:
81+
raise ImportError("To use the `get_client` method, you must install the `qdrant_client` package.") from e
82+
return QdrantClient(
83+
host=self.get_container_host_ip(),
84+
port=self.get_exposed_port(self._rest_port),
85+
grpc_port=self.get_exposed_port(self._grpc_port),
86+
api_key=self._api_key,
87+
https=False,
88+
**kwargs,
89+
)
90+
91+
def get_async_client(self, **kwargs):
92+
"""
93+
Get a `qdrant_client.AsyncQdrantClient` instance associated with the container.
94+
95+
Args:
96+
**kwargs: Additional keyword arguments to be passed to the `qdrant_client.AsyncQdrantClient` constructor.
97+
98+
Returns:
99+
QdrantClient: An instance of the `qdrant_client.AsyncQdrantClient` class.
100+
101+
"""
102+
103+
try:
104+
from qdrant_client import AsyncQdrantClient
105+
except ImportError as e:
106+
raise ImportError(
107+
"To use the `get_async_client` method, you must install the `qdrant_client` package."
108+
) from e
109+
return AsyncQdrantClient(
110+
host=self.get_container_host_ip(),
111+
port=self.get_exposed_port(self._rest_port),
112+
grpc_port=self.get_exposed_port(self._grpc_port),
113+
api_key=self._api_key,
114+
https=False,
115+
**kwargs,
116+
)
117+
118+
@cached_property
119+
def rest_host_address(self) -> str:
120+
"""
121+
Get the REST host address of the Qdrant container.
122+
123+
Returns:
124+
str: The REST host address of the Qdrant container.
125+
"""
126+
return f"{self.get_container_host_ip()}:{self.exposed_rest_port}"
127+
128+
@cached_property
129+
def grpc_host_address(self) -> str:
130+
"""
131+
Get the GRPC host address of the Qdrant container.
132+
133+
Returns:
134+
str: The GRPC host address of the Qdrant container.
135+
"""
136+
return f"{self.get_container_host_ip()}:{self.exposed_grpc_port}"
137+
138+
@cached_property
139+
def exposed_rest_port(self) -> int:
140+
"""
141+
Get the exposed REST port of the Qdrant container.
142+
143+
Returns:
144+
int: The REST port of the Qdrant container.
145+
"""
146+
return self.get_exposed_port(self._rest_port)
147+
148+
@cached_property
149+
def exposed_grpc_port(self) -> int:
150+
"""
151+
Get the exposed GRPC port of the Qdrant container.
152+
153+
Returns:
154+
int: The GRPC port of the Qdrant container.
155+
"""
156+
return self.get_exposed_port(self._grpc_port)

modules/qdrant/tests/test_config.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
# Qdrant image configuration file for testing
2+
# Reference: https://qdrant.tech/documentation/guides/configuration/#configuration-file-example
3+
log_level: INFO
4+
5+
service:
6+
api_key: "SOME_TEST_KEY"

modules/qdrant/tests/test_qdrant.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
import pytest
2+
from testcontainers.qdrant import QdrantContainer
3+
import uuid
4+
from grpc import RpcError
5+
from pathlib import Path
6+
7+
import qdrant_client
8+
9+
10+
def test_docker_run_qdrant():
11+
with QdrantContainer() as qdrant:
12+
client = qdrant.get_client()
13+
collections = client.get_collections().collections
14+
assert len(collections) == 0
15+
16+
client = qdrant.get_client(prefer_grpc=True)
17+
collections = client.get_collections().collections
18+
assert len(collections) == 0
19+
20+
21+
def test_qdrant_with_api_key_http():
22+
api_key = uuid.uuid4().hex
23+
24+
with QdrantContainer(api_key=api_key) as qdrant:
25+
with pytest.raises(qdrant_client.http.exceptions.UnexpectedResponse) as e:
26+
# Construct a client without an API key
27+
qdrant_client.QdrantClient(location=f"http://{qdrant.rest_host_address}").get_collections()
28+
29+
assert "Must provide an API key" in str(e.value)
30+
31+
# Construct a client with an API key
32+
collections = (
33+
qdrant_client.QdrantClient(location=f"http://{qdrant.rest_host_address}", api_key=api_key)
34+
.get_collections()
35+
.collections
36+
)
37+
38+
assert len(collections) == 0
39+
40+
# Get an automatically configured client instance
41+
collections = qdrant.get_client().get_collections().collections
42+
43+
assert len(collections) == 0
44+
45+
46+
def test_qdrant_with_api_key_grpc():
47+
api_key = uuid.uuid4().hex
48+
49+
with QdrantContainer(api_key=api_key) as qdrant:
50+
with pytest.raises(RpcError) as e:
51+
qdrant_client.QdrantClient(
52+
url=f"http://{qdrant.grpc_host_address}",
53+
grpc_port=qdrant.exposed_grpc_port,
54+
prefer_grpc=True,
55+
).get_collections()
56+
57+
assert "Must provide an API key" in str(e.value)
58+
59+
collections = (
60+
qdrant_client.QdrantClient(
61+
url=f"http://{qdrant.grpc_host_address}",
62+
grpc_port=qdrant.exposed_grpc_port,
63+
prefer_grpc=True,
64+
api_key=api_key,
65+
)
66+
.get_collections()
67+
.collections
68+
)
69+
70+
assert len(collections) == 0
71+
72+
73+
def test_qdrant_with_config_file():
74+
config_file_path = Path(__file__).with_name("test_config.yaml")
75+
76+
with QdrantContainer(config_file_path=config_file_path) as qdrant:
77+
with pytest.raises(qdrant_client.http.exceptions.UnexpectedResponse) as e:
78+
qdrant_client.QdrantClient(location=f"http://{qdrant.rest_host_address}").get_collections()
79+
80+
assert "Must provide an API key" in str(e.value)

0 commit comments

Comments
 (0)