Skip to content

Commit e68bdeb

Browse files
lemon0029 authored and markpollack committed
Add OllamaClient implementation of AiClient
1 parent c156631 commit e68bdeb

File tree

6 files changed

+491
-0
lines changed

6 files changed

+491
-0
lines changed

Diff for: pom.xml

+1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
<module>spring-ai-core</module>
1717
<module>spring-ai-openai</module>
1818
<module>spring-ai-azure-openai</module>
19+
<module>spring-ai-ollama</module>
1920
<module>spring-ai-spring-boot-autoconfigure</module>
2021
<module>spring-ai-spring-boot-starters/spring-ai-starter-openai</module>
2122
<module>spring-ai-spring-boot-starters/spring-ai-starter-azure-openai</module>

Diff for: spring-ai-ollama/README.md

+9
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
## Ollama
2+
3+
Ollama lets you get up and running with large language models locally.
4+
5+
Refer to the official [README](https://github.com/jmorganca/ollama) to get started.
6+
7+
Note: running `ollama run llama2` will download a model of roughly 4 GB.
8+
9+
You can run the disabled test in `OllamaClientTests.java` to kick the tires.

Diff for: spring-ai-ollama/pom.xml

+42
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
		 xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
		 xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
	<modelVersion>4.0.0</modelVersion>
	<parent>
		<groupId>org.springframework.experimental.ai</groupId>
		<artifactId>spring-ai</artifactId>
		<version>0.7.0-SNAPSHOT</version>
	</parent>

	<artifactId>spring-ai-ollama</artifactId>
	<packaging>jar</packaging>
	<name>Spring AI Ollama</name>
	<description>Ollama support</description>

	<properties>
		<maven.compiler.source>17</maven.compiler.source>
		<maven.compiler.target>17</maven.compiler.target>
		<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
	</properties>

	<dependencies>
		<dependency>
			<groupId>org.springframework.experimental.ai</groupId>
			<artifactId>spring-ai-core</artifactId>
			<version>${project.parent.version}</version>
		</dependency>

		<dependency>
			<groupId>org.springframework.boot</groupId>
			<artifactId>spring-boot-starter-logging</artifactId>
		</dependency>

		<!-- test dependencies -->
		<dependency>
			<groupId>org.springframework.boot</groupId>
			<artifactId>spring-boot-starter-test</artifactId>
			<scope>test</scope>
		</dependency>
	</dependencies>
</project>
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,248 @@
package org.springframework.ai.ollama.client;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.*;
import java.util.function.Consumer;
import java.util.stream.Collectors;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.springframework.ai.client.AiClient;
import org.springframework.ai.client.AiResponse;
import org.springframework.ai.client.Generation;
import org.springframework.ai.prompt.Prompt;
import org.springframework.util.CollectionUtils;
/**
27+
* A client implementation for interacting with Ollama Service. This class acts as an
28+
* interface between the application and the Ollama AI Service, handling request creation,
29+
* communication, and response processing.
30+
*
31+
* @author nullptr
32+
*/
33+
public class OllamaClient implements AiClient {
34+
35+
/** Logger for logging the events and messages. */
36+
private static final Logger log = LoggerFactory.getLogger(OllamaClient.class);
37+
38+
/** Mapper for JSON serialization and deserialization. */
39+
private static final ObjectMapper jsonMapper = new ObjectMapper();
40+
41+
/** HTTP client for making asynchronous calls to the Ollama Service. */
42+
private static final HttpClient httpClient = HttpClient.newBuilder().build();
43+
44+
/** Base URL of the Ollama Service. */
45+
private final String baseUrl;
46+
47+
/** Name of the model to be used for the AI service. */
48+
private final String model;
49+
50+
/** Optional callback to handle individual generation results. */
51+
private Consumer<OllamaGenerateResult> simpleCallback;
52+
53+
/**
54+
* Constructs an OllamaClient with the specified base URL and model.
55+
* @param baseUrl Base URL of the Ollama Service.
56+
* @param model Model specification for the AI service.
57+
*/
58+
public OllamaClient(String baseUrl, String model) {
59+
this.baseUrl = baseUrl;
60+
this.model = model;
61+
}
62+
63+
/**
64+
* Constructs an OllamaClient with the specified base URL, model, and a callback.
65+
* @param baseUrl Base URL of the Ollama Service.
66+
* @param model Model specification for the AI service.
67+
* @param simpleCallback Callback to handle individual generation results.
68+
*/
69+
public OllamaClient(String baseUrl, String model, Consumer<OllamaGenerateResult> simpleCallback) {
70+
this(baseUrl, model);
71+
this.simpleCallback = simpleCallback;
72+
}
73+
74+
@Override
75+
public AiResponse generate(Prompt prompt) {
76+
validatePrompt(prompt);
77+
78+
HttpRequest request = buildHttpRequest(prompt);
79+
var response = sendRequest(request);
80+
81+
List<OllamaGenerateResult> results = readGenerateResults(response.body());
82+
return getAiResponse(results);
83+
}
84+
85+
/**
86+
* Validates the provided prompt.
87+
* @param prompt The prompt to validate.
88+
*/
89+
protected void validatePrompt(Prompt prompt) {
90+
if (CollectionUtils.isEmpty(prompt.getMessages())) {
91+
throw new RuntimeException("The prompt message cannot be empty.");
92+
}
93+
94+
if (prompt.getMessages().size() > 1) {
95+
log.warn("Only the first prompt message will be used; subsequent messages will be ignored.");
96+
}
97+
}
98+
99+
/**
100+
* Constructs an HTTP request for the provided prompt.
101+
* @param prompt The prompt for which the request needs to be built.
102+
* @return The constructed HttpRequest.
103+
*/
104+
protected HttpRequest buildHttpRequest(Prompt prompt) {
105+
String requestBody = getGenerateRequestBody(prompt.getMessages().get(0).getContent());
106+
107+
// remove the suffix '/' if necessary
108+
String url = !this.baseUrl.endsWith("/") ? this.baseUrl : this.baseUrl.substring(0, this.baseUrl.length() - 1);
109+
110+
return HttpRequest.newBuilder()
111+
.uri(URI.create("%s/api/generate".formatted(url)))
112+
.POST(HttpRequest.BodyPublishers.ofString(requestBody))
113+
.timeout(Duration.ofMinutes(5L))
114+
.build();
115+
}
116+
117+
/**
118+
* Sends the constructed HttpRequest and retrieves the HttpResponse.
119+
* @param request The HttpRequest to be sent.
120+
* @return HttpResponse containing the response data.
121+
*/
122+
protected HttpResponse<InputStream> sendRequest(HttpRequest request) {
123+
var response = httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofInputStream()).join();
124+
if (response.statusCode() != 200) {
125+
throw new RuntimeException("Ollama call returned an unexpected status: " + response.statusCode());
126+
}
127+
return response;
128+
}
129+
130+
/**
131+
* Serializes the prompt into a request body for the Ollama API call.
132+
* @param prompt The prompt to be serialized.
133+
* @return Serialized request body as a String.
134+
*/
135+
private String getGenerateRequestBody(String prompt) {
136+
var data = Map.of("model", model, "prompt", prompt);
137+
try {
138+
return jsonMapper.writeValueAsString(data);
139+
}
140+
catch (JsonProcessingException ex) {
141+
throw new RuntimeException("Failed to serialize the prompt to JSON", ex);
142+
}
143+
144+
}
145+
146+
/**
147+
* Reads and processes the results from the InputStream provided by the Ollama
148+
* Service.
149+
* @param inputStream InputStream containing the results from the Ollama Service.
150+
* @return List of OllamaGenerateResult.
151+
*/
152+
protected List<OllamaGenerateResult> readGenerateResults(InputStream inputStream) {
153+
try (BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(inputStream))) {
154+
var results = new ArrayList<OllamaGenerateResult>();
155+
String line;
156+
while ((line = bufferedReader.readLine()) != null) {
157+
processResponseLine(line, results);
158+
}
159+
return results;
160+
}
161+
catch (IOException e) {
162+
throw new RuntimeException("Error parsing Ollama generation response.", e);
163+
}
164+
}
165+
166+
/**
167+
* Processes a single line from the Ollama response.
168+
* @param line The line to be processed.
169+
* @param results List to which parsed results will be added.
170+
*/
171+
protected void processResponseLine(String line, List<OllamaGenerateResult> results) {
172+
if (line.isBlank())
173+
return;
174+
175+
log.debug("Received ollama generate response: {}", line);
176+
177+
OllamaGenerateResult result;
178+
try {
179+
result = jsonMapper.readValue(line, OllamaGenerateResult.class);
180+
}
181+
catch (IOException e) {
182+
throw new RuntimeException("Error parsing response line from Ollama.", e);
183+
}
184+
185+
if (result.getModel() == null || result.getDone() == null) {
186+
throw new IllegalStateException("Received invalid data from Ollama. Model = " + result.getModel()
187+
+ " , Done = " + result.getDone());
188+
189+
}
190+
191+
if (simpleCallback != null) {
192+
simpleCallback.accept(result);
193+
}
194+
195+
results.add(result);
196+
}
197+
198+
/**
199+
* Converts the list of OllamaGenerateResult into a structured AiResponse.
200+
* @param results List of OllamaGenerateResult.
201+
* @return Formulated AiResponse.
202+
*/
203+
protected AiResponse getAiResponse(List<OllamaGenerateResult> results) {
204+
var ollamaResponse = results.stream()
205+
.filter(Objects::nonNull)
206+
.filter(it -> it.getResponse() != null && !it.getResponse().isBlank())
207+
.filter(it -> it.getDone() != null)
208+
.map(OllamaGenerateResult::getResponse)
209+
.collect(Collectors.joining(""));
210+
211+
var generation = new Generation(ollamaResponse);
212+
213+
// TODO investigate mapping of additional metadata/runtime info to the response.
214+
// Determine if should be top
215+
// level map vs. nested map
216+
return new AiResponse(Collections.singletonList(generation), Map.of("ollama-generate-results", results));
217+
}
218+
219+
/**
220+
* @return Model name for the AI service.
221+
*/
222+
public String getModel() {
223+
return model;
224+
}
225+
226+
/**
227+
* @return Base URL of the Ollama Service.
228+
*/
229+
public String getBaseUrl() {
230+
return baseUrl;
231+
}
232+
233+
/**
234+
* @return Callback that handles individual generation results.
235+
*/
236+
public Consumer<OllamaGenerateResult> getSimpleCallback() {
237+
return simpleCallback;
238+
}
239+
240+
/**
241+
* Sets the callback that handles individual generation results.
242+
* @param simpleCallback The callback to be set.
243+
*/
244+
public void setSimpleCallback(Consumer<OllamaGenerateResult> simpleCallback) {
245+
this.simpleCallback = simpleCallback;
246+
}
247+
248+
}

0 commit comments

Comments
 (0)