
Commit 2c50ced

Python: introducing Vector Search for Qdrant Collection (#9621)
### Motivation and Context

Adds the vector search functions to the existing QdrantCollection. Currently only VectorizedSearch is supported.

### Contribution Checklist

- [x] The code builds clean without any errors or warnings
- [x] The PR follows the [SK Contribution Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md) and the [pre-submission formatting script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts) raises no violations
- [x] All unit tests pass, and I have added new tests where possible
- [x] I didn't break anyone 😄
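For orientation, here is a minimal sketch of the new surface, assuming an already created and populated collection whose vector field is named `vector` (as in the updated sample further down); names other than the imported classes are illustrative:

```python
from semantic_kernel.connectors.memory.qdrant.qdrant_collection import QdrantCollection
from semantic_kernel.data.vector_search.vector_search_options import VectorSearchOptions


async def search_example(collection: QdrantCollection, query_vector: list[float]) -> None:
    # QdrantCollection now mixes in VectorizedSearchMixin, so vectorized_search is available.
    search_results = await collection.vectorized_search(
        vector=query_vector,
        options=VectorSearchOptions(vector_field_name="vector", include_vectors=True),
    )
    # Results are streamed as VectorSearchResult items carrying the deserialized record and its score.
    async for result in search_results.results:
        print(result.record, result.score)
```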
1 parent 4cd7f07 commit 2c50ced

File tree

10 files changed (+163 -33 lines changed)

- python/.cspell.json
- python/samples/concepts/memory/new_memory.py
- python/semantic_kernel/connectors/memory/azure_ai_search/azure_ai_search_collection.py
- python/semantic_kernel/connectors/memory/in_memory/in_memory_collection.py
- python/semantic_kernel/connectors/memory/qdrant/qdrant_collection.py
- python/semantic_kernel/data/vector_search/vector_search.py
- python/semantic_kernel/data/vector_storage/vector_store_record_collection.py
- python/semantic_kernel/utils/list_handler.py
- python/tests/unit/connectors/memory/in_memory/test_in_memory.py
- python/tests/unit/connectors/memory/qdrant/test_qdrant.py

python/.cspell.json
+2 -1

@@ -73,6 +73,7 @@
     "opentelemetry",
     "SEMANTICKERNEL",
     "OTEL",
-    "vectorizable"
+    "vectorizable",
+    "desync"
   ]
 }

python/samples/concepts/memory/new_memory.py
+41 -13

@@ -1,6 +1,7 @@
 # Copyright (c) Microsoft. All rights reserved.

 import argparse
+import asyncio
 from collections.abc import Callable
 from dataclasses import dataclass, field
 from typing import Annotated
@@ -29,6 +30,8 @@
     vectorstoremodel,
 )
 from semantic_kernel.data.const import DistanceFunction, IndexKind
+from semantic_kernel.data.vector_search.vector_search_options import VectorSearchOptions
+from semantic_kernel.data.vector_search.vectorized_search import VectorizedSearchMixin


 def get_data_model_array(index_kind: IndexKind, distance_function: DistanceFunction) -> type:
@@ -82,7 +85,7 @@ class DataModelList:
 collection_name = "test"
 # Depending on the vector database, the index kind and distance function may need to be adjusted,
 # since not all combinations are supported by all databases.
-DataModel = get_data_model_array(IndexKind.HNSW, DistanceFunction.COSINE)
+DataModel = get_data_model_array(IndexKind.HNSW, DistanceFunction.COSINE_SIMILARITY)

 # A list of VectorStoreRecordCollection that can be used.
 # Available collections are:
@@ -144,38 +147,63 @@ class DataModelList:


 async def main(collection: str, use_azure_openai: bool, embedding_model: str):
+    print("-" * 30)
     kernel = Kernel()
     service_id = "embedding"
     if use_azure_openai:
-        kernel.add_service(AzureTextEmbedding(service_id=service_id, deployment_name=embedding_model))
+        embedder = AzureTextEmbedding(service_id=service_id, deployment_name=embedding_model)
     else:
-        kernel.add_service(OpenAITextEmbedding(service_id=service_id, ai_model_id=embedding_model))
+        embedder = OpenAITextEmbedding(service_id=service_id, ai_model_id=embedding_model)
+    kernel.add_service(embedder)
     async with collections[collection]() as record_collection:
+        print(f"Creating {collection} collection!")
         await record_collection.create_collection_if_not_exists()

-        record1 = DataModel(content="My text", id="e6103c03-487f-4d7d-9c23-4723651c17f4")
-        record2 = DataModel(content="My other text", id="09caec77-f7e1-466a-bcec-f1d51c5b15be")
+        record1 = DataModel(content="Semantic Kernel is awesome", id="e6103c03-487f-4d7d-9c23-4723651c17f4")
+        record2 = DataModel(
+            content="Semantic Kernel is available in dotnet, python and Java.",
+            id="09caec77-f7e1-466a-bcec-f1d51c5b15be",
+        )

+        print("Adding records!")
         records = await VectorStoreRecordUtils(kernel).add_vector_to_records(
             [record1, record2], data_model_type=DataModel
         )
         keys = await record_collection.upsert_batch(records)
-        print(f"upserted {keys=}")
-
+        print(f" Upserted {keys=}")
+        print("Getting records!")
         results = await record_collection.get_batch([record1.id, record2.id])
         if results:
             for result in results:
-                print(f"found {result.id=}")
-                print(f"{result.content=}")
+                print(f" Found id: {result.id}")
+                print(f" Content: {result.content}")
                 if result.vector is not None:
-                    print(f"{result.vector[:5]=}")
+                    print(f" Vector (first five): {result.vector[:5]}")
         else:
-            print("not found")
+            print("Nothing found...")
+        if isinstance(record_collection, VectorizedSearchMixin):
+            print("-" * 30)
+            print("Using vectorized search, the distance function is set to cosine_similarity.")
+            print("This means that the higher the score the more similar.")
+            search_results = await record_collection.vectorized_search(
+                vector=(await embedder.generate_raw_embeddings(["python"]))[0],
+                options=VectorSearchOptions(vector_field_name="vector", include_vectors=True),
+            )
+            results = [record async for record in search_results.results]
+            for result in results:
+                print(f" Found id: {result.record.id}")
+                print(f" Content: {result.record.content}")
+                if result.record.vector is not None:
+                    print(f" Vector (first five): {result.record.vector[:5]}")
+                print(f" Score: {result.score:.4f}")
+                print("")
+        print("-" * 30)
+        print("Deleting collection!")
+        await record_collection.delete_collection()
+        print("Done!")


 if __name__ == "__main__":
-    import asyncio
-
     argparse.ArgumentParser()

     parser = argparse.ArgumentParser()

python/semantic_kernel/connectors/memory/azure_ai_search/azure_ai_search_collection.py
+1 -1

@@ -282,7 +282,7 @@ async def _inner_search(
         ]
         raw_results = await self.search_client.search(**search_args)
         return KernelSearchResults(
-            results=self._get_vector_search_results_from_results(raw_results),
+            results=self._get_vector_search_results_from_results(raw_results, options),
            total_count=await raw_results.get_count() if options.include_total_count else None,
         )

python/semantic_kernel/connectors/memory/in_memory/in_memory_collection.py
+2 -2

@@ -132,7 +132,7 @@ async def _inner_search_text(
         if return_records:
             return KernelSearchResults(
                 results=self._get_vector_search_results_from_results(
-                    self._generate_return_list(return_records, options)
+                    self._generate_return_list(return_records, options), options
                 ),
                 total_count=len(return_records) if options and options.include_total_count else None,
             )
@@ -167,7 +167,7 @@ async def _inner_search_vectorized(
         if sorted_records:
             return KernelSearchResults(
                 results=self._get_vector_search_results_from_results(
-                    self._generate_return_list(sorted_records, options)
+                    self._generate_return_list(sorted_records, options), options
                 ),
                 total_count=len(return_records) if options and options.include_total_count else None,
             )

python/semantic_kernel/connectors/memory/qdrant/qdrant_collection.py
+64 -6

@@ -3,7 +3,7 @@
 import logging
 import sys
 from collections.abc import Mapping, Sequence
-from typing import Any, ClassVar, TypeVar
+from typing import Any, ClassVar, Generic, TypeVar

 if sys.version_info >= (3, 12):
     from typing import override  # pragma: no cover
@@ -12,17 +12,22 @@

 from pydantic import ValidationError
 from qdrant_client.async_qdrant_client import AsyncQdrantClient
-from qdrant_client.models import PointStruct, VectorParams
+from qdrant_client.models import FieldCondition, Filter, MatchAny, PointStruct, QueryResponse, ScoredPoint, VectorParams

 from semantic_kernel.connectors.memory.qdrant.const import DISTANCE_FUNCTION_MAP, TYPE_MAPPER_VECTOR
 from semantic_kernel.connectors.memory.qdrant.utils import AsyncQdrantClientWrapper
+from semantic_kernel.data.kernel_search_results import KernelSearchResults
 from semantic_kernel.data.record_definition import VectorStoreRecordDefinition, VectorStoreRecordVectorField
-from semantic_kernel.data.vector_storage import VectorStoreRecordCollection
+from semantic_kernel.data.vector_search.vector_search import VectorSearchBase
+from semantic_kernel.data.vector_search.vector_search_options import VectorSearchOptions
+from semantic_kernel.data.vector_search.vector_search_result import VectorSearchResult
+from semantic_kernel.data.vector_search.vectorized_search import VectorizedSearchMixin
 from semantic_kernel.exceptions import (
     MemoryConnectorInitializationError,
     VectorStoreModelValidationError,
 )
 from semantic_kernel.exceptions.memory_connector_exceptions import MemoryConnectorException
+from semantic_kernel.exceptions.search_exceptions import VectorSearchExecutionException
 from semantic_kernel.kernel_types import OneOrMany
 from semantic_kernel.utils.experimental_decorator import experimental_class
 from semantic_kernel.utils.telemetry.user_agent import APP_INFO, prepend_semantic_kernel_to_user_agent
@@ -34,7 +39,11 @@


 @experimental_class
-class QdrantCollection(VectorStoreRecordCollection[str | int, TModel]):
+class QdrantCollection(
+    VectorSearchBase[str | int, TModel],
+    VectorizedSearchMixin[TModel],
+    Generic[TModel],
+):
     """A QdrantCollection is a memory collection that uses Qdrant as the backend."""

     qdrant_client: AsyncQdrantClient
@@ -163,6 +172,53 @@ async def _inner_delete(self, keys: Sequence[TKey], **kwargs: Any) -> None:
            **kwargs,
         )

+    @override
+    async def _inner_search(
+        self,
+        options: VectorSearchOptions,
+        search_text: str | None = None,
+        vectorizable_text: str | None = None,
+        vector: list[float | int] | None = None,
+        **kwargs: Any,
+    ) -> KernelSearchResults[VectorSearchResult[TModel]]:
+        query_vector: tuple[str, list[float | int]] | list[float | int] | None = None
+        if vector is not None:
+            if self.named_vectors and options.vector_field_name:
+                query_vector = (options.vector_field_name, vector)
+            else:
+                query_vector = vector
+        if query_vector is None:
+            raise VectorSearchExecutionException("Search requires a vector.")
+        results = await self.qdrant_client.search(
+            collection_name=self.collection_name,
+            query_vector=query_vector,
+            query_filter=self._create_filter(options),
+            with_vectors=options.include_vectors,
+            limit=options.top,
+            offset=options.skip,
+            **kwargs,
+        )
+        return KernelSearchResults(
+            results=self._get_vector_search_results_from_results(results, options),
+            total_count=len(results) if options.include_total_count else None,
+        )
+
+    @override
+    def _get_record_from_result(self, result: ScoredPoint | QueryResponse) -> Any:
+        return result
+
+    @override
+    def _get_score_from_result(self, result: ScoredPoint) -> float:
+        return result.score
+
+    def _create_filter(self, options: VectorSearchOptions) -> Filter:
+        return Filter(
+            must=[
+                FieldCondition(key=filter.field_name, match=MatchAny(any=filter.value))
+                for filter in options.filter.filters
+            ]
+        )
+
     @override
     def _serialize_dicts_to_store_models(
         self,
@@ -183,15 +239,17 @@ def _serialize_dicts_to_store_models(
     @override
     def _deserialize_store_models_to_dicts(
         self,
-        records: Sequence[PointStruct],
+        records: Sequence[PointStruct] | Sequence[ScoredPoint],
         **kwargs: Any,
     ) -> Sequence[dict[str, Any]]:
         return [
             {
                 self._key_field_name: record.id,
                 **(record.payload if record.payload else {}),
                 **(
-                    record.vector
+                    {}
+                    if not record.vector
+                    else record.vector
                     if isinstance(record.vector, dict)
                     else {self.data_model_definition.vector_field_names[0]: record.vector}
                 ),
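For reference, `_create_filter` turns each clause in `options.filter` into a Qdrant `FieldCondition` with a `MatchAny` match inside a `must` group. A rough sketch of the object it produces for a clause whose value is a list of accepted matches (the payload field name and values here are made up):

```python
from qdrant_client.models import FieldCondition, Filter, MatchAny

# Equivalent of one "field matches any of these values" clause after _create_filter has run
# (hypothetical payload field "content" and values).
qdrant_filter = Filter(
    must=[FieldCondition(key="content", match=MatchAny(any=["python", "dotnet"]))]
)
print(qdrant_filter)
```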

python/semantic_kernel/data/vector_search/vector_search.py
+8 -3

@@ -2,7 +2,7 @@

 import logging
 from abc import abstractmethod
-from collections.abc import AsyncIterable
+from collections.abc import AsyncIterable, Sequence
 from typing import Any, Generic, TypeVar

 from semantic_kernel.data.kernel_search_results import KernelSearchResults
@@ -11,6 +11,7 @@
 from semantic_kernel.data.vector_search.vector_search_result import VectorSearchResult
 from semantic_kernel.data.vector_storage.vector_store_record_collection import VectorStoreRecordCollection
 from semantic_kernel.utils.experimental_decorator import experimental_class
+from semantic_kernel.utils.list_handler import desync_list

 TModel = TypeVar("TModel")
 TKey = TypeVar("TKey")
@@ -100,10 +101,14 @@ def _get_score_from_result(self, result: Any) -> float | None:
     # region: New methods

     async def _get_vector_search_results_from_results(
-        self, results: AsyncIterable[Any]
+        self, results: AsyncIterable[Any] | Sequence[Any], options: VectorSearchOptions | None = None
     ) -> AsyncIterable[VectorSearchResult[TModel]]:
+        if isinstance(results, Sequence):
+            results = desync_list(results)
         async for result in results:
-            record = self.deserialize(self._get_record_from_result(result))
+            record = self.deserialize(
+                self._get_record_from_result(result), include_vectors=options.include_vectors if options else True
+            )
             score = self._get_score_from_result(result)
             if record:
                 # single records are always returned as single records by the deserializer

python/semantic_kernel/data/vector_storage/vector_store_record_collection.py
+11 -6

@@ -529,6 +529,7 @@ def _deserialize_dict_to_data_model(self, record: OneOrMany[dict[str, Any]], **k

         The input of this should come from the _deserialized_store_model_to_dict function.
         """
+        include_vectors = kwargs.get("include_vectors", True)
         if self.data_model_definition.from_dict:
             if isinstance(record, Sequence):
                 return self.data_model_definition.from_dict(record, **kwargs)
@@ -544,24 +545,28 @@ def _deserialize_dict_to_data_model(self, record: OneOrMany[dict[str, Any]], **k
             try:
                 if not any(field.serialize_function is not None for field in self.data_model_definition.vector_fields):
                     return self.data_model_type.model_validate(record)  # type: ignore
-                for field in self.data_model_definition.vector_fields:
-                    if field.serialize_function:
-                        record[field.name] = field.serialize_function(record[field.name])  # type: ignore
+                if include_vectors:
+                    for field in self.data_model_definition.vector_fields:
+                        if field.serialize_function:
+                            record[field.name] = field.serialize_function(record[field.name])  # type: ignore
                 return self.data_model_type.model_validate(record)  # type: ignore
             except Exception as exc:
                 raise VectorStoreModelDeserializationException(f"Error deserializing record: {exc}") from exc
         if hasattr(self.data_model_type, "from_dict"):
             try:
                 if not any(field.serialize_function is not None for field in self.data_model_definition.vector_fields):
                     return self.data_model_type.from_dict(record)  # type: ignore
-                for field in self.data_model_definition.vector_fields:
-                    if field.serialize_function:
-                        record[field.name] = field.serialize_function(record[field.name])  # type: ignore
+                if include_vectors:
+                    for field in self.data_model_definition.vector_fields:
+                        if field.serialize_function:
+                            record[field.name] = field.serialize_function(record[field.name])  # type: ignore
                 return self.data_model_type.from_dict(record)  # type: ignore
             except Exception as exc:
                 raise VectorStoreModelDeserializationException(f"Error deserializing record: {exc}") from exc
         data_model_dict: dict[str, Any] = {}
         for field_name in self.data_model_definition.fields:  # type: ignore
+            if not include_vectors and field_name in self.data_model_definition.vector_field_names:
+                continue
             try:
                 value = record[field_name]
                 if func := getattr(self.data_model_definition.fields[field_name], "deserialize_function", None):
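The effect of the new `include_vectors` kwarg is simply that vector fields are skipped during deserialization when the caller did not ask for them. A standalone sketch of the same pattern on a plain dict (the field names here are hypothetical and this is not the SK API itself):

```python
from typing import Any

VECTOR_FIELD_NAMES = {"vector"}  # hypothetical vector field names


def to_data_model(record: dict[str, Any], include_vectors: bool = True) -> dict[str, Any]:
    # Copy every field, dropping vector fields when vectors were not requested.
    return {
        name: value
        for name, value in record.items()
        if include_vectors or name not in VECTOR_FIELD_NAMES
    }


print(to_data_model({"id": "1", "content": "hi", "vector": [0.1, 0.2]}, include_vectors=False))
# {'id': '1', 'content': 'hi'}
```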
python/semantic_kernel/utils/list_handler.py
+13

@@ -0,0 +1,13 @@
+# Copyright (c) Microsoft. All rights reserved.
+
+
+from collections.abc import AsyncIterable, Sequence
+from typing import TypeVar
+
+_T = TypeVar("_T")
+
+
+async def desync_list(sync_list: Sequence[_T]) -> AsyncIterable[_T]:  # noqa: RUF029
+    """Desynchronize a list of synchronous objects."""
+    for x in sync_list:
+        yield x
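A quick usage sketch of the new helper, mirroring how `_get_vector_search_results_from_results` consumes plain sequences:

```python
import asyncio

from semantic_kernel.utils.list_handler import desync_list


async def main() -> None:
    # Wrap an ordinary list so it can be consumed with `async for`.
    async for item in desync_list(["a", "b", "c"]):
        print(item)


asyncio.run(main())
```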

python/tests/unit/connectors/memory/in_memory/test_in_memory.py
+1 -1

@@ -120,7 +120,7 @@ async def test_vectorized_search_similar(collection, distance_function):
     await collection.upsert_batch([record1, record2])
     results = await collection.vectorized_search(
         vector=[0.9, 0.9, 0.9, 0.9, 0.9],
-        options=VectorSearchOptions(vector_field_name="vector", include_total_count=True),
+        options=VectorSearchOptions(vector_field_name="vector", include_total_count=True, include_vectors=True),
     )
     assert results.total_count == 2
     idx = 0

python/tests/unit/connectors/memory/qdrant/test_qdrant.py
+20

@@ -9,6 +9,7 @@
 from semantic_kernel.connectors.memory.qdrant.qdrant_collection import QdrantCollection
 from semantic_kernel.connectors.memory.qdrant.qdrant_store import QdrantStore
 from semantic_kernel.data.record_definition.vector_store_record_fields import VectorStoreRecordVectorField
+from semantic_kernel.data.vector_search.vector_search_options import VectorSearchOptions
 from semantic_kernel.exceptions.memory_connector_exceptions import (
     MemoryConnectorException,
     MemoryConnectorInitializationError,
@@ -107,6 +108,17 @@ def mock_delete():
         yield mock_delete


+@fixture(autouse=True)
+def mock_search():
+    with patch(f"{BASE_PATH}.search") as mock_search:
+        from qdrant_client.models import ScoredPoint
+
+        response1 = ScoredPoint(id="id1", version=1, score=0.0, payload={"content": "content"})
+        response2 = ScoredPoint(id="id2", version=1, score=0.0, payload={"content": "content"})
+        mock_search.return_value = [response1, response2]
+        yield mock_search
+
+
 def test_vector_store_defaults(vector_store):
     assert vector_store.qdrant_client is not None
     assert vector_store.qdrant_client._client.rest_uri == "http://localhost:6333"
@@ -269,3 +281,11 @@ async def test_create_index_fail(collection_to_use, request):
     collection.data_model_definition.fields["vector"].dimensions = None
     with raises(MemoryConnectorException, match="Vector field must have dimensions."):
         await collection.create_collection()
+
+
+@mark.asyncio
+async def test_search(collection):
+    results = await collection._inner_search(vector=[1.0, 2.0, 3.0], options=VectorSearchOptions(include_vectors=False))
+    async for result in results.results:
+        assert result.record["id"] == "id1"
+        break
