Commit 5764c8c
Python: Azure Cosmos DB NoSQL Vector Store & Collection implementation (#9296)
### Motivation and Context

We are implementing the Azure Cosmos DB NoSQL vector store and vector collection.

### Description

Azure Cosmos DB NoSQL vector store & collection implementation.

### Contribution Checklist

- [x] The code builds clean without any errors or warnings
- [x] The PR follows the [SK Contribution Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md) and the [pre-submission formatting script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts) raises no violations
- [x] All unit tests pass, and I have added new tests where possible
- [x] I didn't break anyone 😄

Co-authored-by: Eduard van Valkenburg <[email protected]>
1 parent 051a3d0 commit 5764c8c

26 files changed (+2077 −98 lines)

.github/workflows/python-integration-tests.yml (+29 −1)

```diff
@@ -100,6 +100,8 @@ jobs:
       VERTEX_AI_GEMINI_MODEL_ID: ${{ vars.VERTEX_AI_GEMINI_MODEL_ID }}
       VERTEX_AI_EMBEDDING_MODEL_ID: ${{ vars.VERTEX_AI_EMBEDDING_MODEL_ID }}
       REDIS_CONNECTION_STRING: ${{ vars.REDIS_CONNECTION_STRING }}
+      AZURE_COSMOS_DB_NO_SQL_URL: ${{ vars.AZURE_COSMOS_DB_NO_SQL_URL }}
+      AZURE_COSMOS_DB_NO_SQL_KEY: ${{ secrets.AZURE_COSMOS_DB_NO_SQL_KEY }}
     steps:
       - uses: actions/checkout@v4
       - name: Set up uv
@@ -150,6 +152,12 @@ jobs:
         run: docker run -d --name redis-stack-server -p 6379:6379 redis/redis-stack-server:latest
       - name: Setup Weaviate docker deployment
         run: docker run -d -p 8080:8080 -p 50051:50051 cr.weaviate.io/semitechnologies/weaviate:1.26.6
+      - name: Start Azure Cosmos DB emulator
+        if: matrix.os == 'windows-latest'
+        run: |
+          Write-Host "Launching Cosmos DB Emulator"
+          Import-Module "$env:ProgramFiles\Azure Cosmos DB Emulator\PSModules\Microsoft.Azure.CosmosDB.Emulator"
+          Start-CosmosDbEmulator
       - name: Azure CLI Login
         if: github.event_name != 'pull_request'
         uses: azure/login@v2
@@ -159,31 +167,37 @@ jobs:
           subscription-id: ${{ secrets.AZURE_SUBSCRIPTION_ID }}
       - name: Run Integration Tests - Completions
         id: run_tests_completions
+        timeout-minutes: 10
         shell: bash
         run: |
           uv run pytest -n logical --dist loadfile --dist worksteal ./tests/integration/completions -v --junitxml=pytest-completions.xml
       - name: Run Integration Tests - Embeddings
         id: run_tests_embeddings
+        timeout-minutes: 5
         shell: bash
         run: |
           uv run pytest -n logical --dist loadfile --dist worksteal ./tests/integration/embeddings -v --junitxml=pytest-embeddings.xml
       - name: Run Integration Tests - Memory
         id: run_tests_memory
+        timeout-minutes: 5
         shell: bash
         run: |
           uv run pytest -n logical --dist loadfile --dist worksteal ./tests/integration/memory -v --junitxml=pytest-memory.xml
       - name: Run Integration Tests - Cross Language
         id: run_tests_cross_language
+        timeout-minutes: 5
         shell: bash
         run: |
           uv run pytest -n logical --dist loadfile --dist worksteal ./tests/integration/cross_language -v --junitxml=pytest-cross.xml
       - name: Run Integration Tests - Planning
         id: run_tests_planning
+        timeout-minutes: 5
         shell: bash
         run: |
           uv run pytest -n logical --dist loadfile --dist worksteal ./tests/integration/planning -v --junitxml=pytest-planning.xml
       - name: Run Integration Tests - Samples
         id: run_tests_samples
+        timeout-minutes: 5
         shell: bash
         run: |
           uv run pytest -n logical --dist loadfile --dist worksteal ./tests/samples -v --junitxml=pytest-samples.xml
@@ -255,6 +269,8 @@ jobs:
       VERTEX_AI_GEMINI_MODEL_ID: ${{ vars.VERTEX_AI_GEMINI_MODEL_ID }}
       VERTEX_AI_EMBEDDING_MODEL_ID: ${{ vars.VERTEX_AI_EMBEDDING_MODEL_ID }}
       REDIS_CONNECTION_STRING: ${{ vars.REDIS_CONNECTION_STRING }}
+      AZURE_COSMOS_DB_NO_SQL_URL: ${{ vars.AZURE_COSMOS_DB_NO_SQL_URL }}
+      AZURE_COSMOS_DB_NO_SQL_KEY: ${{ secrets.AZURE_COSMOS_DB_NO_SQL_KEY }}
     steps:
       - uses: actions/checkout@v4
       - name: Set up uv
@@ -305,6 +321,12 @@ jobs:
         run: docker run -d --name redis-stack-server -p 6379:6379 redis/redis-stack-server:latest
       - name: Setup Weaviate docker deployment
         run: docker run -d -p 8080:8080 -p 50051:50051 cr.weaviate.io/semitechnologies/weaviate:1.26.6
+      - name: Start Azure Cosmos DB emulator
+        if: matrix.os == 'windows-latest'
+        run: |
+          Write-Host "Launching Cosmos DB Emulator"
+          Import-Module "$env:ProgramFiles\Azure Cosmos DB Emulator\PSModules\Microsoft.Azure.CosmosDB.Emulator"
+          Start-CosmosDbEmulator
       - name: Azure CLI Login
         if: github.event_name != 'pull_request'
         uses: azure/login@v2
@@ -314,31 +336,37 @@ jobs:
           subscription-id: ${{ secrets.AZURE_SUBSCRIPTION_ID }}
       - name: Run Integration Tests - Completions
         id: run_tests_completions
+        timeout-minutes: 10
         shell: bash
         run: |
           uv run pytest -n logical --dist loadfile --dist worksteal ./tests/integration/completions -v --junitxml=pytest-completions.xml
       - name: Run Integration Tests - Embeddings
         id: run_tests_embeddings
+        timeout-minutes: 5
         shell: bash
         run: |
           uv run pytest -n logical --dist loadfile --dist worksteal ./tests/integration/embeddings -v --junitxml=pytest-embeddings.xml
       - name: Run Integration Tests - Memory
         id: run_tests_memory
+        timeout-minutes: 5
         shell: bash
         run: |
           uv run pytest -n logical --dist loadfile --dist worksteal ./tests/integration/memory -v --junitxml=pytest-memory.xml
       - name: Run Integration Tests - Cross Language
         id: run_tests_cross_language
+        timeout-minutes: 5
         shell: bash
         run: |
           uv run pytest -n logical --dist loadfile --dist worksteal ./tests/integration/cross_language -v --junitxml=pytest-cross.xml
       - name: Run Integration Tests - Planning
         id: run_tests_planning
+        timeout-minutes: 5
         shell: bash
         run: |
           uv run pytest -n logical --dist loadfile --dist worksteal ./tests/integration/planning -v --junitxml=pytest-planning.xml
       - name: Run Integration Tests - Samples
         id: run_tests_samples
+        timeout-minutes: 5
         shell: bash
         run: |
           uv run pytest -n logical --dist loadfile --dist worksteal ./tests/samples -v --junitxml=pytest-samples.xml
@@ -418,4 +446,4 @@ jobs:
         dry_run: ${{ env.run_type != 'Daily' && env.run_type != 'Manual'}}
         job: ${{ toJson(job) }}
         steps: ${{ toJson(steps) }}
-        overwrite: "{title: ` ${{ env.run_type }}: ${{ env.date }} `, text: ` ${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}`}"
+        overwrite: "{title: ` ${{ env.run_type }}: ${{ env.date }} `, text: ` ${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}`}"
```

python/.cspell.json (+1)

```diff
@@ -46,6 +46,7 @@
     "mongocluster",
     "ndarray",
     "nopep",
+    "NOSQL",
    "ollama",
    "onyourdatatest",
    "OPENAI",
```

python/samples/concepts/memory/new_memory.py (+97 −74)

```diff
@@ -12,6 +12,9 @@
 from semantic_kernel.connectors.ai.open_ai import OpenAIEmbeddingPromptExecutionSettings, OpenAITextEmbedding
 from semantic_kernel.connectors.ai.open_ai.services.azure_text_embedding import AzureTextEmbedding
 from semantic_kernel.connectors.memory.azure_ai_search import AzureAISearchCollection
+from semantic_kernel.connectors.memory.azure_cosmos_db.azure_cosmos_db_no_sql_collection import (
+    AzureCosmosDBNoSQLCollection,
+)
 from semantic_kernel.connectors.memory.in_memory import InMemoryVectorCollection
 from semantic_kernel.connectors.memory.postgres.postgres_collection import PostgresCollection
 from semantic_kernel.connectors.memory.qdrant import QdrantCollection
@@ -25,55 +28,64 @@
     VectorStoreRecordVectorField,
     vectorstoremodel,
 )
-
-
-@vectorstoremodel
-@dataclass
-class MyDataModelArray:
-    vector: Annotated[
-        np.ndarray | None,
-        VectorStoreRecordVectorField(
-            embedding_settings={"embedding": OpenAIEmbeddingPromptExecutionSettings(dimensions=1536)},
-            index_kind="hnsw",
-            dimensions=1536,
-            distance_function="cosine_similarity",
-            property_type="float",
-            serialize_function=np.ndarray.tolist,
-            deserialize_function=np.array,
-        ),
-    ] = None
-    other: str | None = None
-    id: Annotated[str, VectorStoreRecordKeyField()] = field(default_factory=lambda: str(uuid4()))
-    content: Annotated[
-        str, VectorStoreRecordDataField(has_embedding=True, embedding_property_name="vector", property_type="str")
-    ] = "content1"
-
-
-@vectorstoremodel
-@dataclass
-class MyDataModelList:
-    vector: Annotated[
-        list[float] | None,
-        VectorStoreRecordVectorField(
-            embedding_settings={"embedding": OpenAIEmbeddingPromptExecutionSettings(dimensions=1536)},
-            index_kind="hnsw",
-            dimensions=1536,
-            distance_function="cosine_similarity",
-            property_type="float",
-        ),
-    ] = None
-    other: str | None = None
-    id: Annotated[str, VectorStoreRecordKeyField()] = field(default_factory=lambda: str(uuid4()))
-    content: Annotated[
-        str, VectorStoreRecordDataField(has_embedding=True, embedding_property_name="vector", property_type="str")
-    ] = "content1"
+from semantic_kernel.data.const import DistanceFunction, IndexKind
+
+
+def get_data_model_array(index_kind: IndexKind, distance_function: DistanceFunction) -> type:
+    @vectorstoremodel
+    @dataclass
+    class DataModelArray:
+        vector: Annotated[
+            np.ndarray | None,
+            VectorStoreRecordVectorField(
+                embedding_settings={"embedding": OpenAIEmbeddingPromptExecutionSettings(dimensions=1536)},
+                index_kind=index_kind,
+                dimensions=1536,
+                distance_function=distance_function,
+                property_type="float",
+                serialize_function=np.ndarray.tolist,
+                deserialize_function=np.array,
+            ),
+        ] = None
+        other: str | None = None
+        id: Annotated[str, VectorStoreRecordKeyField()] = field(default_factory=lambda: str(uuid4()))
+        content: Annotated[
+            str, VectorStoreRecordDataField(has_embedding=True, embedding_property_name="vector", property_type="str")
+        ] = "content1"
+
+    return DataModelArray
+
+
+def get_data_model_list(index_kind: IndexKind, distance_function: DistanceFunction) -> type:
+    @vectorstoremodel
+    @dataclass
+    class DataModelList:
+        vector: Annotated[
+            list[float] | None,
+            VectorStoreRecordVectorField(
+                embedding_settings={"embedding": OpenAIEmbeddingPromptExecutionSettings(dimensions=1536)},
+                index_kind=index_kind,
+                dimensions=1536,
+                distance_function=distance_function,
+                property_type="float",
+            ),
+        ] = None
+        other: str | None = None
+        id: Annotated[str, VectorStoreRecordKeyField()] = field(default_factory=lambda: str(uuid4()))
+        content: Annotated[
+            str, VectorStoreRecordDataField(has_embedding=True, embedding_property_name="vector", property_type="str")
+        ] = "content1"
+
+    return DataModelList
 
 
 collection_name = "test"
-MyDataModel = MyDataModelArray
+# Depending on the vector database, the index kind and distance function may need to be adjusted,
+# since not all combinations are supported by all databases.
+DataModel = get_data_model_array(IndexKind.HNSW, DistanceFunction.COSINE)
 
 # A list of VectorStoreRecordCollection that can be used.
-# Available stores are:
+# Available collections are:
 # - ai_search: Azure AI Search
 # - postgres: PostgreSQL
 # - redis_json: Redis JSON
@@ -83,63 +95,74 @@ class MyDataModelList:
 # - weaviate: Weaviate
 # Please either configure the weaviate settings via environment variables or provide them through the constructor.
 # Note that embed mode is not supported on Windows: https://github.com/weaviate/weaviate/issues/3315
-#
-# This is represented as a mapping from the store name to a
-# function which returns the store.
-# Using a function allows for lazy initialization of the store,
-# so that settings for unused stores do not cause validation errors.
-stores: dict[str, Callable[[], VectorStoreRecordCollection]] = {
-    "ai_search": lambda: AzureAISearchCollection[MyDataModel](
-        data_model_type=MyDataModel,
+# - azure_cosmos_nosql: Azure Cosmos NoSQL
+#   https://learn.microsoft.com/en-us/azure/cosmos-db/nosql/how-to-create-account?tabs=azure-portal
+#   Please see the link above to learn how to set up an Azure Cosmos NoSQL account.
+#   https://learn.microsoft.com/en-us/azure/cosmos-db/how-to-develop-emulator?tabs=windows%2Cpython&pivots=api-nosql
+#   Please see the link above to learn how to set up the Azure Cosmos NoSQL emulator on your machine.
+#   For this sample to work with Azure Cosmos NoSQL, please adjust the index_kind of the data model to QUANTIZED_FLAT.
+# This is represented as a mapping from the collection name to a
+# function which returns the collection.
+# Using a function allows for lazy initialization of the collection,
+# so that settings for unused collections do not cause validation errors.
+collections: dict[str, Callable[[], VectorStoreRecordCollection]] = {
+    "ai_search": lambda: AzureAISearchCollection[DataModel](
+        data_model_type=DataModel,
     ),
-    "postgres": lambda: PostgresCollection[str, MyDataModel](
-        data_model_type=MyDataModel,
+    "postgres": lambda: PostgresCollection[str, DataModel](
+        data_model_type=DataModel,
         collection_name=collection_name,
     ),
-    "redis_json": lambda: RedisJsonCollection[MyDataModel](
-        data_model_type=MyDataModel,
+    "redis_json": lambda: RedisJsonCollection[DataModel](
+        data_model_type=DataModel,
         collection_name=collection_name,
         prefix_collection_name_to_key_names=True,
     ),
-    "redis_hashset": lambda: RedisHashsetCollection[MyDataModel](
-        data_model_type=MyDataModel,
+    "redis_hashset": lambda: RedisHashsetCollection[DataModel](
+        data_model_type=DataModel,
         collection_name=collection_name,
         prefix_collection_name_to_key_names=True,
    ),
-    "qdrant": lambda: QdrantCollection[MyDataModel](
-        data_model_type=MyDataModel, collection_name=collection_name, prefer_grpc=True, named_vectors=False
+    "qdrant": lambda: QdrantCollection[DataModel](
+        data_model_type=DataModel, collection_name=collection_name, prefer_grpc=True, named_vectors=False
+    ),
+    "in_memory": lambda: InMemoryVectorCollection[DataModel](
+        data_model_type=DataModel,
+        collection_name=collection_name,
     ),
-    "in_memory": lambda: InMemoryVectorCollection[MyDataModel](
-        data_model_type=MyDataModel,
+    "weaviate": lambda: WeaviateCollection[DataModel](
+        data_model_type=DataModel,
         collection_name=collection_name,
     ),
-    "weaviate": lambda: WeaviateCollection[MyDataModel](
-        data_model_type=MyDataModel,
+    "azure_cosmos_nosql": lambda: AzureCosmosDBNoSQLCollection(
+        data_model_type=DataModel,
+        database_name="sample_database",
         collection_name=collection_name,
+        create_database=True,
     ),
 }
 
 
-async def main(store: str, use_azure_openai: bool, embedding_model: str):
+async def main(collection: str, use_azure_openai: bool, embedding_model: str):
     kernel = Kernel()
     service_id = "embedding"
     if use_azure_openai:
         kernel.add_service(AzureTextEmbedding(service_id=service_id, deployment_name=embedding_model))
     else:
         kernel.add_service(OpenAITextEmbedding(service_id=service_id, ai_model_id=embedding_model))
-    async with stores[store]() as record_store:
-        await record_store.create_collection_if_not_exists()
+    async with collections[collection]() as record_collection:
+        await record_collection.create_collection_if_not_exists()
 
-        record1 = MyDataModel(content="My text", id="e6103c03-487f-4d7d-9c23-4723651c17f4")
-        record2 = MyDataModel(content="My other text", id="09caec77-f7e1-466a-bcec-f1d51c5b15be")
+        record1 = DataModel(content="My text", id="e6103c03-487f-4d7d-9c23-4723651c17f4")
+        record2 = DataModel(content="My other text", id="09caec77-f7e1-466a-bcec-f1d51c5b15be")
 
         records = await VectorStoreRecordUtils(kernel).add_vector_to_records(
-            [record1, record2], data_model_type=MyDataModel
+            [record1, record2], data_model_type=DataModel
         )
-        keys = await record_store.upsert_batch(records)
+        keys = await record_collection.upsert_batch(records)
         print(f"upserted {keys=}")
 
-        results = await record_store.get_batch([record1.id, record2.id])
+        results = await record_collection.get_batch([record1.id, record2.id])
         if results:
             for result in results:
                 print(f"found {result.id=}")
@@ -156,7 +179,7 @@ async def main(store: str, use_azure_openai: bool, embedding_model: str):
     argparse.ArgumentParser()
 
     parser = argparse.ArgumentParser()
-    parser.add_argument("--store", default="in_memory", choices=stores.keys(), help="What store to use.")
+    parser.add_argument("--collection", default="in_memory", choices=collections.keys(), help="What collection to use.")
     # Option of whether to use OpenAI or Azure OpenAI.
     parser.add_argument("--use-azure-openai", action="store_true", help="Use Azure OpenAI instead of OpenAI.")
     # Model
@@ -165,4 +188,4 @@ async def main(store: str, use_azure_openai: bool, embedding_model: str):
     )
     args = parser.parse_args()
 
-    asyncio.run(main(store=args.store, use_azure_openai=args.use_azure_openai, embedding_model=args.model))
+    asyncio.run(main(collection=args.collection, use_azure_openai=args.use_azure_openai, embedding_model=args.model))
```
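The sample above maps each collection name to a zero-argument factory, so only the backend actually selected on the command line is constructed and unused backends never run their settings validation. A minimal, dependency-free sketch of that lazy-initialization pattern — `FakeCollection` and `get_collection` are stand-ins for the real `VectorStoreRecordCollection` implementations, not part of the sample:

```python
from collections.abc import Callable


class FakeCollection:
    """Stand-in for a real collection whose constructor validates settings."""

    instances = 0  # tracks how many collections were actually constructed

    def __init__(self, name: str) -> None:
        FakeCollection.instances += 1
        self.name = name


# Each value is a zero-argument factory; nothing is constructed at dict
# creation time, so a misconfigured backend only fails when selected.
collections: dict[str, Callable[[], FakeCollection]] = {
    "in_memory": lambda: FakeCollection("in_memory"),
    "azure_cosmos_nosql": lambda: FakeCollection("azure_cosmos_nosql"),
}


def get_collection(name: str) -> FakeCollection:
    # Only the selected factory runs; the other entries stay untouched.
    return collections[name]()
```

This is why the sample can list seven backends while requiring credentials for only the one you pass via `--collection`.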

python/semantic_kernel/connectors/memory/azure_cosmos_db/__init__.py

Whitespace-only changes.
