Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit e32efcf

Browse files
committedApr 11, 2025·
cosmos db entra id support and fixes
Signed-off-by: Theo van Kraay <[email protected]>
1 parent 5864255 commit e32efcf

File tree

8 files changed

+421
-37
lines changed

8 files changed

+421
-37
lines changed
 

‎auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-azure-cosmos-db/pom.xml

+5
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,11 @@
4545
<version>${project.parent.version}</version>
4646
<optional>true</optional>
4747
</dependency>
48+
<dependency>
49+
<groupId>com.azure</groupId>
50+
<artifactId>azure-identity</artifactId>
51+
<version>1.15.4</version> <!-- pinned explicitly; bump deliberately when upgrading the Azure SDK -->
52+
</dependency>
4853
<dependency>
4954
<groupId>org.springframework.boot</groupId>
5055
<artifactId>spring-boot-starter</artifactId>

‎auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-azure-cosmos-db/src/main/java/org/springframework/ai/vectorstore/cosmosdb/autoconfigure/CosmosDBVectorStoreAutoConfiguration.java

+25-6
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import com.azure.cosmos.CosmosAsyncClient;
2020
import com.azure.cosmos.CosmosClientBuilder;
21+
import com.azure.identity.DefaultAzureCredentialBuilder;
2122
import io.micrometer.observation.ObservationRegistry;
2223

2324
import org.springframework.ai.embedding.BatchingStrategy;
@@ -33,6 +34,7 @@
3334
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
3435
import org.springframework.boot.context.properties.EnableConfigurationProperties;
3536
import org.springframework.context.annotation.Bean;
37+
3638
import java.util.List;
3739

3840
/**
@@ -43,19 +45,37 @@
4345
* @author Soby Chacko
4446
* @since 1.0.0
4547
*/
48+
4649
@AutoConfiguration
4750
@ConditionalOnClass({ CosmosDBVectorStore.class, EmbeddingModel.class, CosmosAsyncClient.class })
4851
@EnableConfigurationProperties(CosmosDBVectorStoreProperties.class)
4952
@ConditionalOnProperty(name = SpringAIVectorStoreTypes.TYPE, havingValue = SpringAIVectorStoreTypes.AZURE_COSMOS_DB,
5053
matchIfMissing = true)
5154
public class CosmosDBVectorStoreAutoConfiguration {
5255

56+
private final String agentSuffix = "SpringAI-CDBNoSQL-VectorStore";
57+
5358
@Bean
5459
public CosmosAsyncClient cosmosClient(CosmosDBVectorStoreProperties properties) {
55-
return new CosmosClientBuilder().endpoint(properties.getEndpoint())
56-
.userAgentSuffix("SpringAI-CDBNoSQL-VectorStore")
57-
.key(properties.getKey())
58-
.gatewayMode()
60+
String mode = properties.getConnectionMode();
61+
if (mode == null) {
62+
properties.setConnectionMode("gateway");
63+
}
64+
else if (!mode.equals("direct") && !mode.equals("gateway")) {
65+
throw new IllegalArgumentException("Connection mode must be either 'direct' or 'gateway'");
66+
}
67+
68+
CosmosClientBuilder builder = new CosmosClientBuilder().endpoint(properties.getEndpoint())
69+
.userAgentSuffix(agentSuffix);
70+
71+
if (properties.getKey() == null || properties.getKey().isEmpty()) {
72+
builder.credential(new DefaultAzureCredentialBuilder().build());
73+
}
74+
else {
75+
builder.key(properties.getKey());
76+
}
77+
78+
return ("direct".equals(properties.getConnectionMode()) ? builder.directMode() : builder.gatewayMode())
5979
.buildAsyncClient();
6080
}
6181

@@ -75,12 +95,11 @@ public CosmosDBVectorStore cosmosDBVectorStore(ObservationRegistry observationRe
7595
return CosmosDBVectorStore.builder(cosmosAsyncClient, embeddingModel)
7696
.databaseName(properties.getDatabaseName())
7797
.containerName(properties.getContainerName())
78-
.metadataFields(List.of(properties.getMetadataFields()))
98+
.metadataFields(properties.getMetadataFieldList())
7999
.vectorStoreThroughput(properties.getVectorStoreThroughput())
80100
.vectorDimensions(properties.getVectorDimensions())
81101
.partitionKeyPath(properties.getPartitionKeyPath())
82102
.build();
83-
84103
}
85104

86105
}

‎auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-azure-cosmos-db/src/main/java/org/springframework/ai/vectorstore/cosmosdb/autoconfigure/CosmosDBVectorStoreProperties.java

+19-1
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,15 @@
1919
import org.springframework.ai.vectorstore.properties.CommonVectorStoreProperties;
2020
import org.springframework.boot.context.properties.ConfigurationProperties;
2121

22+
import java.util.Arrays;
23+
import java.util.List;
24+
2225
/**
2326
* Configuration properties for CosmosDB Vector Store.
2427
*
2528
* @author Theo van Kraay
2629
* @since 1.0.0
2730
*/
28-
2931
@ConfigurationProperties(CosmosDBVectorStoreProperties.CONFIG_PREFIX)
3032
public class CosmosDBVectorStoreProperties extends CommonVectorStoreProperties {
3133

@@ -47,6 +49,8 @@ public class CosmosDBVectorStoreProperties extends CommonVectorStoreProperties {
4749

4850
private String key;
4951

52+
private String connectionMode;
53+
5054
public int getVectorStoreThroughput() {
5155
return this.vectorStoreThroughput;
5256
}
@@ -63,6 +67,12 @@ public void setMetadataFields(String metadataFields) {
6367
this.metadataFields = metadataFields;
6468
}
6569

70+
public List<String> getMetadataFieldList() {
71+
return this.metadataFields != null
72+
? Arrays.stream(this.metadataFields.split(",")).map(String::trim).filter(s -> !s.isEmpty()).toList()
73+
: List.of();
74+
}
75+
6676
public String getEndpoint() {
6777
return this.endpoint;
6878
}
@@ -79,6 +89,14 @@ public void setKey(String key) {
7989
this.key = key;
8090
}
8191

92+
/**
 * Sets the Cosmos DB client connection mode.
 * @param connectionMode the mode to use; the auto-configuration accepts
 * {@code "direct"} or {@code "gateway"} and falls back to gateway when this is
 * {@code null}
 */
public void setConnectionMode(String connectionMode) {
93+
this.connectionMode = connectionMode;
94+
}
95+
96+
/**
 * Returns the configured Cosmos DB client connection mode.
 * @return the configured mode, or {@code null} when none has been set
 */
public String getConnectionMode() {
97+
return this.connectionMode;
98+
}
99+
82100
public String getDatabaseName() {
83101
return this.databaseName;
84102
}

‎auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-azure-cosmos-db/src/test/java/org/springframework/ai/vectorstore/cosmosdb/autoconfigure/CosmosDBVectorStoreAutoConfigurationIT.java

+28-14
Original file line numberDiff line numberDiff line change
@@ -44,22 +44,35 @@
4444
* @author Theo van Kraay
4545
* @since 1.0.0
4646
*/
47-
4847
@EnabledIfEnvironmentVariable(named = "AZURE_COSMOSDB_ENDPOINT", matches = ".+")
4948
@EnabledIfEnvironmentVariable(named = "AZURE_COSMOSDB_KEY", matches = ".+")
5049
public class CosmosDBVectorStoreAutoConfigurationIT {
5150

52-
private final ApplicationContextRunner contextRunner = new ApplicationContextRunner()
53-
.withConfiguration(AutoConfigurations.of(CosmosDBVectorStoreAutoConfiguration.class))
54-
.withPropertyValues("spring.ai.vectorstore.cosmosdb.databaseName=test-database")
55-
.withPropertyValues("spring.ai.vectorstore.cosmosdb.containerName=test-container")
56-
.withPropertyValues("spring.ai.vectorstore.cosmosdb.partitionKeyPath=/id")
57-
.withPropertyValues("spring.ai.vectorstore.cosmosdb.metadataFields=country,year,city")
58-
.withPropertyValues("spring.ai.vectorstore.cosmosdb.vectorStoreThroughput=1000")
59-
.withPropertyValues("spring.ai.vectorstore.cosmosdb.vectorDimensions=384")
60-
.withPropertyValues("spring.ai.vectorstore.cosmosdb.endpoint=" + System.getenv("AZURE_COSMOSDB_ENDPOINT"))
61-
.withPropertyValues("spring.ai.vectorstore.cosmosdb.key=" + System.getenv("AZURE_COSMOSDB_KEY"))
62-
.withUserConfiguration(Config.class);
51+
private final ApplicationContextRunner contextRunner;
52+
53+
public CosmosDBVectorStoreAutoConfigurationIT() {
54+
String endpoint = System.getenv("AZURE_COSMOSDB_ENDPOINT");
55+
String key = System.getenv("AZURE_COSMOSDB_KEY");
56+
57+
ApplicationContextRunner contextRunner = new ApplicationContextRunner()
58+
.withConfiguration(AutoConfigurations.of(CosmosDBVectorStoreAutoConfiguration.class))
59+
.withPropertyValues("spring.ai.vectorstore.cosmosdb.databaseName=test-database")
60+
.withPropertyValues("spring.ai.vectorstore.cosmosdb.containerName=test-container")
61+
.withPropertyValues("spring.ai.vectorstore.cosmosdb.partitionKeyPath=/id")
62+
.withPropertyValues("spring.ai.vectorstore.cosmosdb.metadataFields=country,year,city")
63+
.withPropertyValues("spring.ai.vectorstore.cosmosdb.vectorStoreThroughput=1000")
64+
.withPropertyValues("spring.ai.vectorstore.cosmosdb.vectorDimensions=384");
65+
66+
if (endpoint != null && !"null".equalsIgnoreCase(endpoint)) {
67+
contextRunner = contextRunner.withPropertyValues("spring.ai.vectorstore.cosmosdb.endpoint=" + endpoint);
68+
}
69+
70+
if (key != null && !"null".equalsIgnoreCase(key)) {
71+
contextRunner = contextRunner.withPropertyValues("spring.ai.vectorstore.cosmosdb.key=" + key);
72+
}
73+
74+
this.contextRunner = contextRunner.withUserConfiguration(Config.class);
75+
}
6376

6477
private VectorStore vectorStore;
6578

@@ -124,14 +137,15 @@ void testSimilaritySearchWithFilter() {
124137
metadata4.put("country", "US");
125138
metadata4.put("year", 2020);
126139
metadata4.put("city", "Sofia");
127-
128140
Document document1 = new Document("1", "A document about the UK", metadata1);
129141
Document document2 = new Document("2", "A document about the Netherlands", metadata2);
130142
Document document3 = new Document("3", "A document about the US", metadata3);
131143
Document document4 = new Document("4", "A document about the US", metadata4);
132144

133145
this.vectorStore.add(List.of(document1, document2, document3, document4));
146+
134147
FilterExpressionBuilder b = new FilterExpressionBuilder();
148+
135149
List<Document> results = this.vectorStore.similaritySearch(SearchRequest.builder()
136150
.query("The World")
137151
.topK(10)
@@ -190,7 +204,7 @@ public void autoConfigurationEnabledByDefault() {
190204

191205
@Test
192206
public void autoConfigurationEnabledWhenTypeIsAzureCosmosDB() {
193-
this.contextRunner.withPropertyValues("spring.ai.vectorstore.type=azure-cosmmos-db").run(context -> {
207+
this.contextRunner.withPropertyValues("spring.ai.vectorstore.type=azure-cosmos-db").run(context -> {
194208
assertThat(context.getBeansOfType(CosmosDBVectorStoreProperties.class)).isNotEmpty();
195209
assertThat(context.getBeansOfType(VectorStore.class)).isNotEmpty();
196210
assertThat(context.getBean(VectorStore.class)).isInstanceOf(CosmosDBVectorStore.class);

‎vector-stores/spring-ai-azure-cosmos-db-store/pom.xml

+5
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,11 @@
4747
<artifactId>azure-spring-data-cosmos</artifactId>
4848
<version>${azure-cosmos.version}</version>
4949
</dependency>
50+
<dependency>
51+
<groupId>com.azure</groupId>
52+
<artifactId>azure-identity</artifactId>
53+
<version>1.15.4</version> <!-- pinned explicitly; bump deliberately when upgrading the Azure SDK -->
54+
</dependency>
5055
<dependency>
5156
<groupId>org.springframework.ai</groupId>
5257
<artifactId>spring-ai-vector-store</artifactId>

‎vector-stores/spring-ai-azure-cosmos-db-store/src/main/java/org/springframework/ai/vectorstore/cosmosdb/CosmosDBVectorStore.java

+108-14
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,10 @@
1818

1919
import java.util.ArrayList;
2020
import java.util.Collections;
21+
import java.util.HashMap;
22+
import java.util.Iterator;
2123
import java.util.List;
24+
import java.util.Map;
2225
import java.util.Optional;
2326
import java.util.stream.Collectors;
2427
import java.util.stream.IntStream;
@@ -30,6 +33,7 @@
3033
import com.azure.cosmos.models.CosmosBulkOperations;
3134
import com.azure.cosmos.models.CosmosContainerProperties;
3235
import com.azure.cosmos.models.CosmosItemOperation;
36+
import com.azure.cosmos.models.CosmosItemResponse;
3337
import com.azure.cosmos.models.CosmosQueryRequestOptions;
3438
import com.azure.cosmos.models.CosmosVectorDataType;
3539
import com.azure.cosmos.models.CosmosVectorDistanceFunction;
@@ -38,6 +42,7 @@
3842
import com.azure.cosmos.models.CosmosVectorIndexSpec;
3943
import com.azure.cosmos.models.CosmosVectorIndexType;
4044
import com.azure.cosmos.models.ExcludedPath;
45+
import com.azure.cosmos.models.FeedResponse;
4146
import com.azure.cosmos.models.IncludedPath;
4247
import com.azure.cosmos.models.IndexingMode;
4348
import com.azure.cosmos.models.IndexingPolicy;
@@ -48,6 +53,7 @@
4853
import com.azure.cosmos.models.SqlQuerySpec;
4954
import com.azure.cosmos.models.ThroughputProperties;
5055
import com.azure.cosmos.util.CosmosPagedFlux;
56+
5157
import com.fasterxml.jackson.databind.JsonNode;
5258
import com.fasterxml.jackson.databind.ObjectMapper;
5359
import com.fasterxml.jackson.databind.node.ObjectNode;
@@ -86,9 +92,9 @@ public class CosmosDBVectorStore extends AbstractObservationVectorStore implemen
8692

8793
private final String databaseName;
8894

89-
private final String partitionKeyPath;
95+
private String partitionKeyPath;
9096

91-
private final int vectorStoreThroughput;
97+
private int vectorStoreThroughput;
9298

9399
private final long vectorDimensions;
94100

@@ -116,7 +122,16 @@ protected CosmosDBVectorStore(Builder builder) {
116122
this.vectorDimensions = builder.vectorDimensions;
117123
this.metadataFieldsList = builder.metadataFieldsList;
118124

119-
this.cosmosClient.createDatabaseIfNotExists(this.databaseName).block();
125+
try {
126+
this.cosmosClient.createDatabaseIfNotExists(this.databaseName).block();
127+
}
128+
catch (Exception e) {
129+
// likely failed due to RBAC, so database is assumed to be already created
130+
// (and
131+
// if not, it will fail later)
132+
logger.error("Error creating database: {}", e.getMessage());
133+
}
134+
120135
initializeContainer(this.containerName, this.databaseName, this.vectorStoreThroughput, this.vectorDimensions,
121136
this.partitionKeyPath);
122137
}
@@ -130,10 +145,12 @@ private void initializeContainer(String containerName, String databaseName, int
130145

131146
// Set defaults if not provided
132147
if (this.vectorStoreThroughput == 0) {
133-
vectorStoreThroughput = 400;
148+
this.vectorStoreThroughput = 400;
149+
vectorStoreThroughput = this.vectorStoreThroughput;
134150
}
135151
if (this.partitionKeyPath == null) {
136-
partitionKeyPath = "/id";
152+
this.partitionKeyPath = "/id";
153+
partitionKeyPath = this.partitionKeyPath;
137154
}
138155

139156
// handle hierarchical partition key
@@ -222,10 +239,30 @@ public void doAdd(List<Document> documents) {
222239
// Create a list to hold both the CosmosItemOperation and the corresponding
223240
// document ID
224241
List<ImmutablePair<String, CosmosItemOperation>> itemOperationsWithIds = documents.stream().map(doc -> {
242+
String partitionKeyValue;
243+
244+
if ("/id".equals(this.partitionKeyPath)) {
245+
partitionKeyValue = doc.getId();
246+
}
247+
else if (this.partitionKeyPath.startsWith("/metadata/")) {
248+
// Extract the key, e.g. "/metadata/country" -> "country"
249+
String metadataKey = this.partitionKeyPath.substring("/metadata/".length());
250+
Object value = doc.getMetadata() != null ? doc.getMetadata().get(metadataKey) : null;
251+
if (value == null) {
252+
throw new IllegalArgumentException(
253+
"Partition key '" + metadataKey + "' not found in document metadata.");
254+
}
255+
partitionKeyValue = value.toString();
256+
}
257+
else {
258+
throw new IllegalArgumentException("Unsupported partition key path: " + this.partitionKeyPath);
259+
}
260+
225261
CosmosItemOperation operation = CosmosBulkOperations.getCreateItemOperation(
226-
mapCosmosDocument(doc, embeddings.get(documents.indexOf(doc))), new PartitionKey(doc.getId()));
227-
return new ImmutablePair<>(doc.getId(), operation); // Pair the document ID
262+
mapCosmosDocument(doc, embeddings.get(documents.indexOf(doc))),
263+
new PartitionKey(partitionKeyValue)); // Pair the document ID
228264
// with the operation
265+
return new ImmutablePair<>(doc.getId(), operation);
229266
}).toList();
230267

231268
try {
@@ -272,20 +309,60 @@ public void doAdd(List<Document> documents) {
272309
public void doDelete(List<String> idList) {
273310
try {
274311
// Convert the list of IDs into bulk delete operations
275-
List<CosmosItemOperation> itemOperations = idList.stream()
276-
.map(id -> CosmosBulkOperations.getDeleteItemOperation(id, new PartitionKey(id)))
277-
.collect(Collectors.toList());
312+
List<CosmosItemOperation> itemOperations = idList.stream().map(id -> {
313+
String partitionKeyValue;
314+
315+
if ("/id".equals(this.partitionKeyPath)) {
316+
partitionKeyValue = id;
317+
}
318+
319+
else if (this.partitionKeyPath.startsWith("/metadata/")) {
320+
// Will be inefficient for large numbers of documents but there is no
321+
// other way to get the partition key value
322+
// with current method signature. Ideally, we should be able to pass
323+
// the partition key value directly.
324+
String metadataKey = this.partitionKeyPath.substring("/metadata/".length());
325+
326+
// Run a reactive query to fetch the document by ID
327+
String query = String.format("SELECT * FROM c WHERE c.id = '%s'", id);
328+
CosmosPagedFlux<JsonNode> queryFlux = this.container.queryItems(query,
329+
new CosmosQueryRequestOptions(), JsonNode.class);
330+
331+
// Block to retrieve the first page synchronously
332+
List<JsonNode> documents = queryFlux.byPage(1).blockFirst().getResults();
333+
334+
if (documents == null || documents.isEmpty()) {
335+
throw new IllegalArgumentException("No document found for id: " + id);
336+
}
337+
338+
JsonNode document = documents.get(0);
339+
JsonNode metadataNode = document.get("metadata");
340+
341+
if (metadataNode == null || metadataNode.get(metadataKey) == null) {
342+
throw new IllegalArgumentException("Partition key '" + metadataKey
343+
+ "' not found in metadata for document with id: " + id);
344+
}
345+
346+
partitionKeyValue = metadataNode.get(metadataKey).asText();
347+
}
348+
else {
349+
throw new IllegalArgumentException("Unsupported partition key path: " + this.partitionKeyPath);
350+
}
351+
352+
return CosmosBulkOperations.getDeleteItemOperation(id, new PartitionKey(partitionKeyValue));
353+
}).collect(Collectors.toList());
278354

279355
// Execute bulk delete operations synchronously by using blockLast() on the
280356
// Flux
281357
this.container.executeBulkOperations(Flux.fromIterable(itemOperations))
282358
.doOnNext(response -> logger.info("Document deleted with status: {}",
283359
response.getResponse().getStatusCode()))
284360
.doOnError(error -> logger.error("Error deleting document: {}", error.getMessage()))
285-
.blockLast(); // This will block until all operations have finished
361+
.blockLast();
286362
}
287363
catch (Exception e) {
288-
logger.error("Exception while deleting documents: {}", e.getMessage());
364+
logger.error("Exception while deleting documents: {}", e.getMessage(), e);
365+
throw e;
289366
}
290367
}
291368

@@ -347,9 +424,26 @@ public List<Document> doSimilaritySearch(SearchRequest request) {
347424
.flatMap(page -> Flux.fromIterable(page.getResults()))
348425
.collectList()
349426
.block();
427+
428+
// Collect metadata fields from the documents
429+
Map<String, Object> docFields = new HashMap<>();
430+
for (var doc : documents) {
431+
JsonNode metadata = doc.get("metadata");
432+
metadata.fieldNames().forEachRemaining(field -> {
433+
JsonNode value = metadata.get(field);
434+
Object parsedValue = value.isTextual() ? value.asText() : value.isNumber() ? value.numberValue()
435+
: value.isBoolean() ? value.booleanValue() : value.toString();
436+
docFields.put(field, parsedValue);
437+
});
438+
}
439+
350440
// Convert JsonNode to Document
351441
List<Document> docs = documents.stream()
352-
.map(doc -> Document.builder().id(doc.get("id").asText()).text(doc.get("content").asText()).build())
442+
.map(doc -> Document.builder()
443+
.id(doc.get("id").asText())
444+
.text(doc.get("content").asText())
445+
.metadata(docFields)
446+
.build())
353447
.collect(Collectors.toList());
354448

355449
return docs != null ? docs : List.of();
@@ -474,7 +568,7 @@ public Builder vectorDimensions(long vectorDimensions) {
474568
* @return the builder instance
475569
*/
476570
public Builder metadataFields(List<String> metadataFieldsList) {
477-
this.metadataFieldsList = metadataFieldsList != null ? new ArrayList<>(this.metadataFieldsList)
571+
this.metadataFieldsList = metadataFieldsList != null ? new ArrayList<>(metadataFieldsList)
478572
: new ArrayList<>();
479573
return this;
480574
}

‎vector-stores/spring-ai-azure-cosmos-db-store/src/test/java/org/springframework/ai/vectorstore/cosmosdb/CosmosDBVectorStoreIT.java

+7-2
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import com.azure.cosmos.CosmosAsyncClient;
2626
import com.azure.cosmos.CosmosAsyncContainer;
2727
import com.azure.cosmos.CosmosClientBuilder;
28+
import com.azure.identity.DefaultAzureCredentialBuilder;
2829
import org.junit.jupiter.api.BeforeEach;
2930
import org.junit.jupiter.api.Test;
3031
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
@@ -50,7 +51,6 @@
5051
* @since 1.0.0
5152
*/
5253
@EnabledIfEnvironmentVariable(named = "AZURE_COSMOSDB_ENDPOINT", matches = ".+")
53-
@EnabledIfEnvironmentVariable(named = "AZURE_COSMOSDB_KEY", matches = ".+")
5454
public class CosmosDBVectorStoreIT {
5555

5656
private final ApplicationContextRunner contextRunner = new ApplicationContextRunner()
@@ -141,6 +141,11 @@ void testSimilaritySearchWithFilter() {
141141

142142
assertThat(results).hasSize(2);
143143
assertThat(results).extracting(Document::getId).containsExactlyInAnyOrder("1", "2");
144+
for (Document doc : results) {
145+
assertThat(doc.getMetadata().get("country")).isIn("UK", "NL");
146+
assertThat(doc.getMetadata().get("year")).isIn(2021, 2022);
147+
assertThat(doc.getMetadata().get("city")).isIn("London", "Amsterdam").isNotEqualTo("Sofia");
148+
}
144149

145150
List<Document> results2 = this.vectorStore.similaritySearch(SearchRequest.builder()
146151
.query("The World")
@@ -199,7 +204,7 @@ public VectorStore vectorStore(CosmosAsyncClient cosmosClient, EmbeddingModel em
199204
@Bean
200205
public CosmosAsyncClient cosmosClient() {
201206
return new CosmosClientBuilder().endpoint(System.getenv("AZURE_COSMOSDB_ENDPOINT"))
202-
.key(System.getenv("AZURE_COSMOSDB_KEY"))
207+
.credential(new DefaultAzureCredentialBuilder().build())
203208
.userAgentSuffix("SpringAI-CDBNoSQL-VectorStore")
204209
.gatewayMode()
205210
.buildAsyncClient();
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,224 @@
1+
/*
2+
* Copyright 2023-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.vectorstore.cosmosdb;
18+
19+
import com.azure.cosmos.CosmosAsyncClient;
20+
import com.azure.cosmos.CosmosAsyncContainer;
21+
import com.azure.cosmos.CosmosClientBuilder;
22+
import com.azure.identity.DefaultAzureCredentialBuilder;
23+
import org.junit.jupiter.api.BeforeEach;
24+
import org.junit.jupiter.api.Test;
25+
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
26+
import org.springframework.ai.document.Document;
27+
import org.springframework.ai.embedding.EmbeddingModel;
28+
import org.springframework.ai.transformers.TransformersEmbeddingModel;
29+
import org.springframework.ai.vectorstore.SearchRequest;
30+
import org.springframework.ai.vectorstore.VectorStore;
31+
import org.springframework.ai.vectorstore.filter.FilterExpressionBuilder;
32+
import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention;
33+
import org.springframework.boot.SpringBootConfiguration;
34+
import org.springframework.boot.autoconfigure.EnableAutoConfiguration;
35+
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
36+
import org.springframework.context.annotation.Bean;
37+
38+
import java.util.HashMap;
39+
import java.util.List;
40+
import java.util.Map;
41+
import java.util.Optional;
42+
import java.util.UUID;
43+
44+
import static org.assertj.core.api.Assertions.assertThat;
45+
import static org.assertj.core.api.Assertions.assertThatThrownBy;
46+
47+
/**
48+
* @author Theo van Kraay
49+
* @author Thomas Vitale
50+
* @since 1.0.0
51+
*/
52+
@EnabledIfEnvironmentVariable(named = "AZURE_COSMOSDB_ENDPOINT", matches = ".+")
53+
public class CosmosDBVectorStoreWithMetadataPartitionKeyIT {
54+
55+
private final ApplicationContextRunner contextRunner = new ApplicationContextRunner()
56+
.withUserConfiguration(TestApplication.class);
57+
58+
private VectorStore vectorStore;
59+
60+
@BeforeEach
61+
public void setup() {
62+
this.contextRunner.run(context -> this.vectorStore = context.getBean(VectorStore.class));
63+
}
64+
65+
@Test
66+
public void testAddSearchAndDeleteDocuments() {
67+
68+
// Create a sample document
69+
Document document1 = new Document(UUID.randomUUID().toString(), "Sample content1", Map.of("key1", "value1"));
70+
assertThatThrownBy(() -> this.vectorStore.add(List.of(document1))).isInstanceOf(Exception.class)
71+
.hasMessageContaining("Partition key 'country' not found in document metadata.");
72+
73+
Document document2 = new Document(UUID.randomUUID().toString(), "Sample content1", Map.of("country", "UK"));
74+
this.vectorStore.add(List.of(document2));
75+
76+
// Perform a similarity search
77+
List<Document> results = this.vectorStore
78+
.similaritySearch(SearchRequest.builder().query("Sample content1").topK(1).build());
79+
80+
// Verify the search results
81+
assertThat(results).isNotEmpty();
82+
assertThat(results.get(0).getId()).isEqualTo(document2.getId());
83+
84+
// Remove the documents from the vector store
85+
this.vectorStore.delete(List.of(document2.getId()));
86+
87+
// Perform a similarity search again
88+
List<Document> results2 = this.vectorStore
89+
.similaritySearch(SearchRequest.builder().query("Sample content").topK(1).build());
90+
91+
// Verify the search results
92+
assertThat(results2).isEmpty();
93+
94+
}
95+
96+
@Test
97+
void testSimilaritySearchWithFilter() {
98+
99+
// Insert documents using vectorStore.add
100+
Map<String, Object> metadata1;
101+
metadata1 = new HashMap<>();
102+
metadata1.put("country", "UK");
103+
metadata1.put("year", 2021);
104+
metadata1.put("city", "London");
105+
106+
Map<String, Object> metadata2;
107+
metadata2 = new HashMap<>();
108+
metadata2.put("country", "NL");
109+
metadata2.put("year", 2022);
110+
metadata2.put("city", "Amsterdam");
111+
112+
Map<String, Object> metadata3;
113+
metadata3 = new HashMap<>();
114+
metadata3.put("country", "US");
115+
metadata3.put("year", 2019);
116+
metadata3.put("city", "Sofia");
117+
118+
Map<String, Object> metadata4;
119+
metadata4 = new HashMap<>();
120+
metadata4.put("country", "US");
121+
metadata4.put("year", 2020);
122+
metadata4.put("city", "Sofia");
123+
124+
Document document1 = new Document("1", "A document about the UK", metadata1);
125+
Document document2 = new Document("2", "A document about the Netherlands", metadata2);
126+
Document document3 = new Document("3", "A document about the US", metadata3);
127+
Document document4 = new Document("4", "A document about the US", metadata4);
128+
129+
this.vectorStore.add(List.of(document1, document2, document3, document4));
130+
FilterExpressionBuilder b = new FilterExpressionBuilder();
131+
List<Document> results = this.vectorStore.similaritySearch(SearchRequest.builder()
132+
.query("The World")
133+
.topK(10)
134+
.filterExpression((b.in("country", "UK", "NL")).build())
135+
.build());
136+
137+
assertThat(results).hasSize(2);
138+
assertThat(results).extracting(Document::getId).containsExactlyInAnyOrder("1", "2");
139+
for (Document doc : results) {
140+
assertThat(doc.getMetadata().get("country")).isIn("UK", "NL");
141+
assertThat(doc.getMetadata().get("year")).isIn(2021, 2022);
142+
assertThat(doc.getMetadata().get("city")).isIn("London", "Amsterdam").isNotEqualTo("Sofia");
143+
}
144+
145+
List<Document> results2 = this.vectorStore.similaritySearch(SearchRequest.builder()
146+
.query("The World")
147+
.topK(10)
148+
.filterExpression(
149+
b.and(b.or(b.gte("year", 2021), b.eq("country", "NL")), b.ne("city", "Amsterdam")).build())
150+
.build());
151+
152+
assertThat(results2).hasSize(1);
153+
assertThat(results2).extracting(Document::getId).containsExactlyInAnyOrder("1");
154+
155+
List<Document> results3 = this.vectorStore.similaritySearch(SearchRequest.builder()
156+
.query("The World")
157+
.topK(10)
158+
.filterExpression(b.and(b.eq("country", "US"), b.eq("year", 2020)).build())
159+
.build());
160+
161+
assertThat(results3).hasSize(1);
162+
assertThat(results3).extracting(Document::getId).containsExactlyInAnyOrder("4");
163+
164+
this.vectorStore.delete(List.of(document1.getId(), document2.getId(), document3.getId(), document4.getId()));
165+
166+
// Perform a similarity search again
167+
List<Document> results4 = this.vectorStore
168+
.similaritySearch(SearchRequest.builder().query("The World").topK(1).build());
169+
170+
// Verify the search results
171+
assertThat(results4).isEmpty();
172+
}
173+
174+
@Test
175+
void getNativeClientTest() {
176+
this.contextRunner.run(context -> {
177+
CosmosDBVectorStore vectorStore = context.getBean(CosmosDBVectorStore.class);
178+
Optional<CosmosAsyncContainer> nativeClient = vectorStore.getNativeClient();
179+
assertThat(nativeClient).isPresent();
180+
});
181+
}
182+
183+
@SpringBootConfiguration
184+
@EnableAutoConfiguration
185+
public static class TestApplication {
186+
187+
@Bean
188+
public VectorStore vectorStore(CosmosAsyncClient cosmosClient, EmbeddingModel embeddingModel,
189+
VectorStoreObservationConvention convention) {
190+
return CosmosDBVectorStore.builder(cosmosClient, embeddingModel)
191+
.databaseName("test-database")
192+
.containerName("test-container-metadata-partition-key")
193+
.metadataFields(List.of("country", "year", "city"))
194+
.partitionKeyPath("/metadata/country")
195+
.vectorStoreThroughput(1000)
196+
.customObservationConvention(convention)
197+
.build();
198+
}
199+
200+
@Bean
201+
public CosmosAsyncClient cosmosClient() {
202+
return new CosmosClientBuilder().endpoint(System.getenv("AZURE_COSMOSDB_ENDPOINT"))
203+
.credential(new DefaultAzureCredentialBuilder().build())
204+
.userAgentSuffix("SpringAI-CDBNoSQL-VectorStore")
205+
.gatewayMode()
206+
.buildAsyncClient();
207+
}
208+
209+
@Bean
210+
public EmbeddingModel embeddingModel() {
211+
return new TransformersEmbeddingModel();
212+
}
213+
214+
@Bean
215+
public VectorStoreObservationConvention observationConvention() {
216+
// Replace with an actual observation convention or a mock if needed
217+
return new VectorStoreObservationConvention() {
218+
219+
};
220+
}
221+
222+
}
223+
224+
}

0 commit comments

Comments
 (0)
Please sign in to comment.