Skip to content

update to adapt change in version of 0.8.0 #4

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions src/main/java/com/example/carina/config/CarinaConfig.java
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
package com.example.carina.config;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.embedding.EmbeddingClient;
import org.springframework.ai.retriever.VectorStoreRetriever;
import org.springframework.ai.vectorstore.PgVectorStore;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
Expand All @@ -21,8 +19,11 @@ public VectorStore vectorStore(EmbeddingClient embeddingClient, JdbcTemplate jdb
}

@Bean
public VectorStoreRetriever vectorStoreRetriever(VectorStore vectorStore) {
return new VectorStoreRetriever(vectorStore, 4, 0.75);
public SearchRequest searchRequest() {
SearchRequest searchRequest = SearchRequest.defaults();
searchRequest.withTopK(4);
searchRequest.withSimilarityThreshold(0.75);
return searchRequest;
}


Expand Down
37 changes: 22 additions & 15 deletions src/main/java/com/example/carina/qa/QAService.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,17 @@

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.client.AiClient;
import org.springframework.ai.client.AiResponse;

import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.prompt.SystemPromptTemplate;
import org.springframework.ai.document.Document;
import org.springframework.ai.prompt.Prompt;
import org.springframework.ai.prompt.SystemPromptTemplate;
import org.springframework.ai.prompt.messages.Message;
import org.springframework.ai.prompt.messages.UserMessage;
import org.springframework.ai.retriever.VectorStoreRetriever;

import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.core.io.Resource;
Expand All @@ -30,14 +33,17 @@ public class QAService {
@Value("classpath:/prompts/system-chatbot.st")
private Resource chatbotSystemPromptResource;

private final AiClient aiClient;
private final ChatClient chatClient;

private final VectorStore vectorStore;

private final VectorStoreRetriever vectorStoreRetriever;
private final SearchRequest searchRequest;

@Autowired
public QAService(AiClient aiClient, VectorStoreRetriever vectorStoreRetriever) {
this.aiClient = aiClient;
this.vectorStoreRetriever = vectorStoreRetriever;
public QAService(ChatClient chatClient, VectorStore vectorStore, SearchRequest searchRequest) {
this.chatClient = chatClient;
this.vectorStore = vectorStore;
this.searchRequest = searchRequest;
}

public String generate(String message, boolean stuffit) {
Expand All @@ -46,15 +52,16 @@ public String generate(String message, boolean stuffit) {
Prompt prompt = new Prompt(List.of(systemMessage, userMessage));

logger.info("Asking AI model to reply to question.");
AiResponse aiResponse = aiClient.generate(prompt);
ChatResponse chatResponse = chatClient.call(prompt);
logger.info("AI responded.");
return aiResponse.getGeneration().getContent();
return chatResponse.getResult().getOutput().getContent();
}

private Message getSystemMessage(String message, boolean stuffit) {
if (stuffit) {
logger.info("Retrieving relevant documents");
List<Document> similarDocuments = vectorStoreRetriever.retrieve(message);
searchRequest.withQuery(message);
List<Document> similarDocuments = vectorStore.similaritySearch(searchRequest);
logger.info(String.format("Found %s relevant documents.", similarDocuments.size()));
String documents = similarDocuments.stream().map(entry -> entry.getContent()).collect(Collectors.joining("\n"));
SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.qaSystemPromptResource);
Expand Down
10 changes: 5 additions & 5 deletions src/main/java/com/example/carina/simple/SimpleAiController.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package com.example.carina.simple;

import org.springframework.ai.client.AiClient;
import org.springframework.ai.chat.ChatClient;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestParam;
Expand All @@ -9,16 +9,16 @@
@RestController
public class SimpleAiController {

private final AiClient aiClient;
private final ChatClient chatClient;

@Autowired
public SimpleAiController(AiClient aiClient) {
this.aiClient = aiClient;
public SimpleAiController(ChatClient chatClient) {
this.chatClient = chatClient;
}

@GetMapping("/ai/simple")
public Completion completion(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) {
return new Completion(aiClient.generate(message));
return new Completion(chatClient.call(message));
}

}