From c585169626c45adfc1e6df1ab79dedf7a62e0484 Mon Sep 17 00:00:00 2001 From: Xiaonan Wei Date: Mon, 4 Mar 2024 11:34:54 -0600 Subject: [PATCH] update to adapt change in version of 0.8.0 --- .../example/carina/config/CarinaConfig.java | 11 +++--- .../java/com/example/carina/qa/QAService.java | 37 +++++++++++-------- .../carina/simple/SimpleAiController.java | 10 ++--- 3 files changed, 33 insertions(+), 25 deletions(-) diff --git a/src/main/java/com/example/carina/config/CarinaConfig.java b/src/main/java/com/example/carina/config/CarinaConfig.java index fedee46..74687f8 100644 --- a/src/main/java/com/example/carina/config/CarinaConfig.java +++ b/src/main/java/com/example/carina/config/CarinaConfig.java @@ -1,10 +1,8 @@ package com.example.carina.config; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; import org.springframework.ai.embedding.EmbeddingClient; -import org.springframework.ai.retriever.VectorStoreRetriever; import org.springframework.ai.vectorstore.PgVectorStore; +import org.springframework.ai.vectorstore.SearchRequest; import org.springframework.ai.vectorstore.VectorStore; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; @@ -21,8 +19,11 @@ public VectorStore vectorStore(EmbeddingClient embeddingClient, JdbcTemplate jdb } @Bean - public VectorStoreRetriever vectorStoreRetriever(VectorStore vectorStore) { - return new VectorStoreRetriever(vectorStore, 4, 0.75); + public SearchRequest searchRequest() { + SearchRequest searchRequest = SearchRequest.defaults(); + searchRequest.withTopK(4); + searchRequest.withSimilarityThreshold(0.75); + return searchRequest; } diff --git a/src/main/java/com/example/carina/qa/QAService.java b/src/main/java/com/example/carina/qa/QAService.java index 3f91b74..dd58499 100644 --- a/src/main/java/com/example/carina/qa/QAService.java +++ b/src/main/java/com/example/carina/qa/QAService.java @@ -2,14 +2,17 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.springframework.ai.client.AiClient; -import org.springframework.ai.client.AiResponse; + +import org.springframework.ai.chat.ChatClient; +import org.springframework.ai.chat.ChatResponse; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.SystemPromptTemplate; import org.springframework.ai.document.Document; -import org.springframework.ai.prompt.Prompt; -import org.springframework.ai.prompt.SystemPromptTemplate; -import org.springframework.ai.prompt.messages.Message; -import org.springframework.ai.prompt.messages.UserMessage; -import org.springframework.ai.retriever.VectorStoreRetriever; + +import org.springframework.ai.vectorstore.SearchRequest; +import org.springframework.ai.vectorstore.VectorStore; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; import org.springframework.core.io.Resource; @@ -30,14 +33,17 @@ public class QAService { @Value("classpath:/prompts/system-chatbot.st") private Resource chatbotSystemPromptResource; - private final AiClient aiClient; + private final ChatClient chatClient; + + private final VectorStore vectorStore; - private final VectorStoreRetriever vectorStoreRetriever; + private final SearchRequest searchRequest; @Autowired - public QAService(AiClient aiClient, VectorStoreRetriever vectorStoreRetriever) { - this.aiClient = aiClient; - this.vectorStoreRetriever = vectorStoreRetriever; + public QAService(ChatClient chatClient, VectorStore vectorStore, SearchRequest searchRequest) { + this.chatClient = chatClient; + this.vectorStore = vectorStore; + this.searchRequest = searchRequest; } public String generate(String message, boolean stuffit) { @@ -46,15 +52,16 @@ public String generate(String message, boolean stuffit) { Prompt prompt = new Prompt(List.of(systemMessage, userMessage)); logger.info("Asking AI model to reply to question."); - AiResponse aiResponse = aiClient.generate(prompt); + ChatResponse chatResponse = chatClient.call(prompt); logger.info("AI responded."); - return aiResponse.getGeneration().getContent(); + return chatResponse.getResult().getOutput().getContent(); } private Message getSystemMessage(String message, boolean stuffit) { if (stuffit) { logger.info("Retrieving relevant documents"); - List similarDocuments = vectorStoreRetriever.retrieve(message); + searchRequest.withQuery(message); + List similarDocuments = vectorStore.similaritySearch(searchRequest); logger.info(String.format("Found %s relevant documents.", similarDocuments.size())); String documents = similarDocuments.stream().map(entry -> entry.getContent()).collect(Collectors.joining("\n")); SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.qaSystemPromptResource); diff --git a/src/main/java/com/example/carina/simple/SimpleAiController.java b/src/main/java/com/example/carina/simple/SimpleAiController.java index 61a8b03..8a8d0b0 100644 --- a/src/main/java/com/example/carina/simple/SimpleAiController.java +++ b/src/main/java/com/example/carina/simple/SimpleAiController.java @@ -1,6 +1,6 @@ package com.example.carina.simple; -import org.springframework.ai.client.AiClient; +import org.springframework.ai.chat.ChatClient; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.RequestParam; @@ -9,16 +9,16 @@ @RestController public class SimpleAiController { - private final AiClient aiClient; + private final ChatClient chatClient; @Autowired - public SimpleAiController(AiClient aiClient) { - this.aiClient = aiClient; + public SimpleAiController(ChatClient chatClient) { + this.chatClient = chatClient; } @GetMapping("/ai/simple") public Completion completion(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { - return new Completion(aiClient.generate(message)); + return new Completion(chatClient.call(message)); } }