Skip to content

Commit a1db00a

Browse files
linarkouilayaperumalg
authored andcommitted
Fixed message order for JDBC Chat Memory
- set timestamp manually to prevent saving equal timestamps while batch insert - return DESC order into get query (in `JdbcChatMemory`) - Update similar changes for JdbcChatMemoryRepository This PR solves some problems with message ordering: JdbcChatMemory fetches rows in DESC order, and MessageChatMemoryAdvisor get and add all messages in that order - from last to first. So messages list needs to be reversed. After batch insert without specifying timestamp manually all rows have the same timestamp (because database sets current_timestamp by default). Sooo after fetching rows from database the message order is unpredictable - they have same timestamp. Signed-off-by: Linar Abzaltdinov <[email protected]>
1 parent f5ac94c commit a1db00a

File tree

6 files changed

+90
-27
lines changed

6 files changed

+90
-27
lines changed

auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/src/test/java/org/springframework/ai/model/chat/memory/jdbc/autoconfigure/JdbcChatMemoryPostgresqlAutoConfigurationIT.java

+18-3
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,11 @@
1616

1717
package org.springframework.ai.model.chat.memory.jdbc.autoconfigure;
1818

19+
import java.util.List;
20+
import java.util.UUID;
21+
1922
import org.junit.jupiter.api.Test;
23+
2024
import org.springframework.ai.chat.memory.ChatMemory;
2125
import org.springframework.ai.chat.memory.jdbc.JdbcChatMemory;
2226
import org.springframework.ai.chat.memory.jdbc.JdbcChatMemoryRepository;
@@ -29,14 +33,12 @@
2933
import org.springframework.boot.autoconfigure.jdbc.JdbcTemplateAutoConfiguration;
3034
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
3135

32-
import java.util.List;
33-
import java.util.UUID;
34-
3536
import static org.assertj.core.api.Assertions.assertThat;
3637

3738
/**
3839
* @author Jonathan Leijendekker
3940
* @author Thomas Vitale
41+
* @author Linar Abzaltdinov
4042
*/
4143
class JdbcChatMemoryPostgresqlAutoConfigurationIT {
4244

@@ -87,6 +89,12 @@ void addGetAndClear_shouldAllExecute() {
8789
assertThat(chatMemory.get(conversationId, Integer.MAX_VALUE)).hasSize(1);
8890
assertThat(chatMemory.get(conversationId, Integer.MAX_VALUE)).isEqualTo(List.of(userMessage));
8991

92+
var assistantMessage = new AssistantMessage("Message from the assistant");
93+
94+
chatMemory.add(conversationId, List.of(assistantMessage));
95+
96+
assertThat(chatMemory.get(conversationId)).hasSize(2);
97+
assertThat(chatMemory.get(conversationId)).isEqualTo(List.of(userMessage, assistantMessage));
9098
chatMemory.clear(conversationId);
9199

92100
assertThat(chatMemory.get(conversationId, Integer.MAX_VALUE)).isEmpty();
@@ -142,6 +150,13 @@ void useAutoConfiguredChatMemoryWithJdbc() {
142150
assertThat(chatMemory.get(conversationId)).hasSize(1);
143151
assertThat(chatMemory.get(conversationId)).isEqualTo(List.of(userMessage));
144152

153+
var assistantMessage = new AssistantMessage("Message from the assistant");
154+
155+
chatMemory.add(conversationId, List.of(assistantMessage));
156+
157+
assertThat(chatMemory.get(conversationId)).hasSize(2);
158+
assertThat(chatMemory.get(conversationId)).isEqualTo(List.of(userMessage, assistantMessage));
159+
145160
chatMemory.clear(conversationId);
146161

147162
assertThat(chatMemory.get(conversationId)).isEmpty();

memory/spring-ai-model-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemory.java

+18-5
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,11 @@
1919
import java.sql.PreparedStatement;
2020
import java.sql.ResultSet;
2121
import java.sql.SQLException;
22+
import java.sql.Timestamp;
23+
import java.time.Instant;
24+
import java.util.Collections;
2225
import java.util.List;
26+
import java.util.concurrent.atomic.AtomicLong;
2327

2428
import org.springframework.ai.chat.memory.ChatMemory;
2529
import org.springframework.ai.chat.memory.MessageWindowChatMemory;
@@ -38,6 +42,7 @@
3842
* <code>JdbcChatMemory.create(JdbcChatMemoryConfig.builder().jdbcTemplate(jdbcTemplate).build());</code>
3943
*
4044
* @author Jonathan Leijendekker
45+
* @author Linar Abzaltdinov
4146
* @since 1.0.0
4247
* @deprecated in favor of building a {@link MessageWindowChatMemory} (or other
4348
* {@link ChatMemory} implementations) with a {@link JdbcChatMemoryRepository} instance.
@@ -46,10 +51,10 @@
4651
public class JdbcChatMemory implements ChatMemory {
4752

4853
private static final String QUERY_ADD = """
49-
INSERT INTO ai_chat_memory (conversation_id, content, type) VALUES (?, ?, ?)""";
54+
INSERT INTO ai_chat_memory (conversation_id, content, type, "timestamp") VALUES (?, ?, ?, ?)""";
5055

5156
private static final String QUERY_GET = """
52-
SELECT content, type FROM ai_chat_memory WHERE conversation_id = ? ORDER BY "timestamp" LIMIT ?""";
57+
SELECT content, type FROM ai_chat_memory WHERE conversation_id = ? ORDER BY "timestamp" DESC LIMIT ?""";
5358

5459
private static final String QUERY_CLEAR = "DELETE FROM ai_chat_memory WHERE conversation_id = ?";
5560

@@ -70,23 +75,31 @@ public void add(String conversationId, List<Message> messages) {
7075

7176
@Override
7277
public List<Message> get(String conversationId, int lastN) {
73-
return this.jdbcTemplate.query(QUERY_GET, new MessageRowMapper(), conversationId, lastN);
78+
List<Message> messages = this.jdbcTemplate.query(QUERY_GET, new MessageRowMapper(), conversationId, lastN);
79+
Collections.reverse(messages);
80+
return messages;
7481
}
7582

7683
@Override
7784
public void clear(String conversationId) {
7885
this.jdbcTemplate.update(QUERY_CLEAR, conversationId);
7986
}
8087

81-
private record AddBatchPreparedStatement(String conversationId,
82-
List<Message> messages) implements BatchPreparedStatementSetter {
88+
private record AddBatchPreparedStatement(String conversationId, List<Message> messages,
89+
AtomicLong instantSeq) implements BatchPreparedStatementSetter {
90+
91+
private AddBatchPreparedStatement(String conversationId, List<Message> messages) {
92+
this(conversationId, messages, new AtomicLong(Instant.now().toEpochMilli()));
93+
}
94+
8395
@Override
8496
public void setValues(PreparedStatement ps, int i) throws SQLException {
8597
var message = this.messages.get(i);
8698

8799
ps.setString(1, this.conversationId);
88100
ps.setString(2, message.getText());
89101
ps.setString(3, message.getMessageType().name());
102+
ps.setTimestamp(4, new Timestamp(instantSeq.getAndIncrement()));
90103
}
91104

92105
@Override

memory/spring-ai-model-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemoryRepository.java

+25-10
Original file line numberDiff line numberDiff line change
@@ -16,25 +16,34 @@
1616

1717
package org.springframework.ai.chat.memory.jdbc;
1818

19+
import java.sql.PreparedStatement;
20+
import java.sql.ResultSet;
21+
import java.sql.SQLException;
22+
import java.sql.Timestamp;
23+
import java.time.Instant;
24+
import java.util.ArrayList;
25+
import java.util.List;
26+
import java.util.concurrent.atomic.AtomicLong;
27+
1928
import org.springframework.ai.chat.memory.ChatMemoryRepository;
20-
import org.springframework.ai.chat.messages.*;
29+
import org.springframework.ai.chat.messages.AssistantMessage;
30+
import org.springframework.ai.chat.messages.Message;
31+
import org.springframework.ai.chat.messages.MessageType;
32+
import org.springframework.ai.chat.messages.SystemMessage;
33+
import org.springframework.ai.chat.messages.ToolResponseMessage;
34+
import org.springframework.ai.chat.messages.UserMessage;
2135
import org.springframework.jdbc.core.BatchPreparedStatementSetter;
2236
import org.springframework.jdbc.core.JdbcTemplate;
2337
import org.springframework.jdbc.core.RowMapper;
2438
import org.springframework.lang.Nullable;
2539
import org.springframework.util.Assert;
2640

27-
import java.sql.PreparedStatement;
28-
import java.sql.ResultSet;
29-
import java.sql.SQLException;
30-
import java.util.ArrayList;
31-
import java.util.List;
32-
3341
/**
3442
* An implementation of {@link ChatMemoryRepository} for JDBC.
3543
*
3644
* @author Jonathan Leijendekker
3745
* @author Thomas Vitale
46+
* @author Linar Abzaltdinov
3847
* @since 1.0.0
3948
*/
4049
public class JdbcChatMemoryRepository implements ChatMemoryRepository {
@@ -44,7 +53,7 @@ public class JdbcChatMemoryRepository implements ChatMemoryRepository {
4453
""";
4554

4655
private static final String QUERY_ADD = """
47-
INSERT INTO ai_chat_memory (conversation_id, content, type) VALUES (?, ?, ?)
56+
INSERT INTO ai_chat_memory (conversation_id, content, type, "timestamp") VALUES (?, ?, ?, ?)
4857
""";
4958

5059
private static final String QUERY_GET = """
@@ -93,15 +102,21 @@ public void deleteByConversationId(String conversationId) {
93102
this.jdbcTemplate.update(QUERY_CLEAR, conversationId);
94103
}
95104

96-
private record AddBatchPreparedStatement(String conversationId,
97-
List<Message> messages) implements BatchPreparedStatementSetter {
105+
private record AddBatchPreparedStatement(String conversationId, List<Message> messages,
106+
AtomicLong instantSeq) implements BatchPreparedStatementSetter {
107+
108+
private AddBatchPreparedStatement(String conversationId, List<Message> messages) {
109+
this(conversationId, messages, new AtomicLong(Instant.now().toEpochMilli()));
110+
}
111+
98112
@Override
99113
public void setValues(PreparedStatement ps, int i) throws SQLException {
100114
var message = this.messages.get(i);
101115

102116
ps.setString(1, this.conversationId);
103117
ps.setString(2, message.getText());
104118
ps.setString(3, message.getMessageType().name());
119+
ps.setTimestamp(4, new Timestamp(instantSeq.getAndIncrement()));
105120
}
106121

107122
@Override

memory/spring-ai-model-chat-memory-jdbc/src/main/resources/org/springframework/ai/chat/memory/jdbc/schema-mariadb.sql

+2-2
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@ CREATE TABLE IF NOT EXISTS ai_chat_memory (
22
conversation_id VARCHAR(36) NOT NULL,
33
content TEXT NOT NULL,
44
type VARCHAR(10) NOT NULL,
5-
`timestamp` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
5+
`timestamp` TIMESTAMP NOT NULL,
66
CONSTRAINT type_check CHECK (type IN ('USER', 'ASSISTANT', 'SYSTEM', 'TOOL'))
77
);
88

99
CREATE INDEX IF NOT EXISTS ai_chat_memory_conversation_id_timestamp_idx
10-
ON ai_chat_memory(conversation_id, `timestamp`);
10+
ON ai_chat_memory(conversation_id, `timestamp`);

memory/spring-ai-model-chat-memory-jdbc/src/main/resources/org/springframework/ai/chat/memory/jdbc/schema-postgresql.sql

+3-3
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@ CREATE TABLE IF NOT EXISTS ai_chat_memory (
22
conversation_id VARCHAR(36) NOT NULL,
33
content TEXT NOT NULL,
44
type VARCHAR(10) NOT NULL CHECK (type IN ('USER', 'ASSISTANT', 'SYSTEM', 'TOOL')),
5-
"timestamp" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
6-
);
5+
"timestamp" TIMESTAMP NOT NULL
6+
);
77

88
CREATE INDEX IF NOT EXISTS ai_chat_memory_conversation_id_timestamp_idx
9-
ON ai_chat_memory(conversation_id, "timestamp");
9+
ON ai_chat_memory(conversation_id, "timestamp");

memory/spring-ai-model-chat-memory-jdbc/src/test/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemoryIT.java

+24-4
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151

5252
/**
5353
* @author Jonathan Leijendekker
54+
* @author Linar Abzaltdinov
5455
*/
5556
@Testcontainers
5657
class JdbcChatMemoryIT {
@@ -147,10 +148,11 @@ void get_shouldReturnMessages() {
147148
this.contextRunner.run(context -> {
148149
var chatMemory = context.getBean(ChatMemory.class);
149150
var conversationId = UUID.randomUUID().toString();
150-
var messages = List.<Message>of(new AssistantMessage("Message from assistant 1 - " + conversationId),
151-
new AssistantMessage("Message from assistant 2 - " + conversationId),
152-
new UserMessage("Message from user - " + conversationId),
153-
new SystemMessage("Message from system - " + conversationId));
151+
var messages = List.<Message>of(new SystemMessage("Message from system - " + conversationId),
152+
new UserMessage("Message from user 1 - " + conversationId),
153+
new AssistantMessage("Message from assistant 1 - " + conversationId),
154+
new UserMessage("Message from user 2 - " + conversationId),
155+
new AssistantMessage("Message from assistant 2 - " + conversationId));
154156

155157
chatMemory.add(conversationId, messages);
156158

@@ -161,6 +163,24 @@ void get_shouldReturnMessages() {
161163
});
162164
}
163165

166+
@Test
167+
void get_afterMultipleAdds_shouldReturnMessagesInSameOrder() {
168+
this.contextRunner.run(context -> {
169+
var chatMemory = context.getBean(ChatMemory.class);
170+
var conversationId = UUID.randomUUID().toString();
171+
var userMessage = new UserMessage("Message from user - " + conversationId);
172+
var assistantMessage = new AssistantMessage("Message from assistant - " + conversationId);
173+
174+
chatMemory.add(conversationId, userMessage);
175+
chatMemory.add(conversationId, assistantMessage);
176+
177+
var results = chatMemory.get(conversationId, Integer.MAX_VALUE);
178+
179+
assertThat(results.size()).isEqualTo(2);
180+
assertThat(results).isEqualTo(List.of(userMessage, assistantMessage));
181+
});
182+
}
183+
164184
@Test
165185
void clear_shouldDeleteMessages() {
166186
this.contextRunner.run(context -> {

0 commit comments

Comments
 (0)