
Feature/multi rag #745

Closed
wants to merge 7 commits
src/lib/buildPrompt.ts (20 changes: 15 additions & 5 deletions)
@@ -13,21 +13,31 @@ interface buildPromptOptions {
 	id?: Conversation["_id"];
 	model: BackendModel;
 	locals?: App.Locals;
-	ragContext?: RagContext;
+	ragContexts?: {
+		webSearch?: RagContextWebSearch;
+		pdfChat?: RagContext;
+		// Add more context types as needed
+	};
 	preprompt?: string;
 	files?: File[];
 }
 
 export async function buildPrompt({
 	messages,
 	model,
-	ragContext,
+	ragContexts,
 	preprompt,
 	id,
 }: buildPromptOptions): Promise<string> {
-	if (ragContext) {
-		const { type: ragType } = ragContext;
-		messages = RAGs[ragType].buildPrompt(messages, ragContext as RagContextWebSearch);
+	if (ragContexts) {
+		for (const [ragKey, ragContext] of Object.entries(ragContexts)) {
+			if (ragKey == "webSearch" && ragContext) {
+				messages = RAGs.webSearch.buildPrompt(messages, ragContext as RagContextWebSearch);
+			}
+			if (ragKey == "pdfChat" && ragContext) {
+				messages = RAGs.pdfChat.buildPrompt(messages, ragContext as RagContext);
+			}
+		}
 	}
 
 	// section to handle potential files input
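Note on the dispatch above: each branch of the `Object.entries` loop already names its key, so the same behaviour can be written as two direct property checks. A minimal, purely illustrative sketch (not part of this PR); with property access the `as RagContextWebSearch` cast also becomes unnecessary, because the check already narrows the type:

```ts
// Equivalent dispatch inside buildPrompt, written without the Object.entries loop.
// `messages`, `ragContexts`, and `RAGs` are the same identifiers used in the diff above.
if (ragContexts?.webSearch) {
	messages = RAGs.webSearch.buildPrompt(messages, ragContexts.webSearch);
}
if (ragContexts?.pdfChat) {
	messages = RAGs.pdfChat.buildPrompt(messages, ragContexts.pdfChat);
}
```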
src/lib/server/endpoints/aws/endpointAws.ts (2 changes: 1 addition & 1 deletion)
@@ -39,7 +39,7 @@ export async function endpointAws(
 	return async ({ conversation }) => {
 		const prompt = await buildPrompt({
 			messages: conversation.messages,
-			ragContext: conversation.messages[conversation.messages.length - 1].ragContext,
+			ragContexts: conversation.messages[conversation.messages.length - 1].ragContexts,
 			preprompt: conversation.preprompt,
 			model,
 		});
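The same one-line change appears here and in the llamacpp, ollama, openai, and tgi endpoints that follow: the prompt builder is handed the whole `ragContexts` bag from the last message instead of a single `ragContext`. The shared shape, extracted only for illustration:

```ts
// Pattern shared by the text-completion endpoints in this PR (illustrative extraction).
const lastMessage = conversation.messages[conversation.messages.length - 1];
const prompt = await buildPrompt({
	messages: conversation.messages,
	ragContexts: lastMessage.ragContexts, // was: ragContext: lastMessage.ragContext
	preprompt: conversation.preprompt,
	model,
});
```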
src/lib/server/endpoints/llamacpp/endpointLlamacpp.ts (2 changes: 1 addition & 1 deletion)
@@ -22,7 +22,7 @@ export function endpointLlamacpp(
 	return async ({ conversation }) => {
 		const prompt = await buildPrompt({
 			messages: conversation.messages,
-			ragContext: conversation.messages[conversation.messages.length - 1].ragContext,
+			ragContexts: conversation.messages[conversation.messages.length - 1].ragContexts,
 			preprompt: conversation.preprompt,
 			model,
 		});
src/lib/server/endpoints/ollama/endpointOllama.ts (2 changes: 1 addition & 1 deletion)
@@ -17,7 +17,7 @@ export function endpointOllama(input: z.input<typeof endpointOllamaParametersSch
 	return async ({ conversation }) => {
 		const prompt = await buildPrompt({
 			messages: conversation.messages,
-			ragContext: conversation.messages[conversation.messages.length - 1].ragContext,
+			ragContexts: conversation.messages[conversation.messages.length - 1].ragContexts,
 			preprompt: conversation.preprompt,
 			model,
 		});
src/lib/server/endpoints/openai/endpointOai.ts (6 changes: 3 additions & 3 deletions)
@@ -41,7 +41,7 @@ export async function endpointOai(
 			model: model.id ?? model.name,
 			prompt: await buildPrompt({
 				messages: conversation.messages,
-				ragContext: conversation.messages[conversation.messages.length - 1].ragContext,
+				ragContexts: conversation.messages[conversation.messages.length - 1].ragContexts,
 				preprompt: conversation.preprompt,
 				model,
 			}),
@@ -57,9 +57,9 @@ export async function endpointOai(
 	} else if (completion === "chat_completions") {
 		return async ({ conversation }) => {
 			let messages = conversation.messages;
-			const ragContext = conversation.messages[conversation.messages.length - 1].ragContext;
+			const ragContext = conversation.messages[conversation.messages.length - 1].ragContexts;
 
-			if (ragContext && ragContext.type === "webSearch") {
+			if (ragContext && ragContext.webSearch) {
 				const webSearchContext = ragContext as RagContextWebSearch;
 				const lastMsg = messages.slice(-1)[0];
 				const messagesWithoutLastUsrMsg = messages.slice(0, -1);
src/lib/server/endpoints/tgi/endpointTgi.ts (2 changes: 1 addition & 1 deletion)
@@ -18,7 +18,7 @@ export function endpointTgi(input: z.input<typeof endpointTgiParametersSchema>):
 	return async ({ conversation }) => {
 		const prompt = await buildPrompt({
 			messages: conversation.messages,
-			ragContext: conversation.messages[conversation.messages.length - 1].ragContext,
+			ragContexts: conversation.messages[conversation.messages.length - 1].ragContexts,
 			preprompt: conversation.preprompt,
 			model,
 			id: conversation._id,
src/lib/server/rag/rag.ts (12 changes: 4 additions & 8 deletions)
@@ -24,13 +24,9 @@ export interface RAG<T extends RagContext = RagContext> {
 }
 
 type RAGUnion = RAG<RagContext> | RAG<RagContextWebSearch>;
-
-// list of all rags
-export const RAGs: {
-	[Key in RAGType]: RAGUnion;
-} = {
-	webSearch: ragWebsearch,
-	pdfChat: ragPdfchat,
-};
+namespace RAGs {
+	export const webSearch = ragWebsearch;
+	export const pdfChat = ragPdfchat;
+}
 
 export default RAGs;
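A note on the `namespace` form, since it replaces a keyed record: the declaration compiles to an ordinary object whose properties are the two RAG implementations, so both the dotted access used in `buildPrompt.ts` and the bracketed access used in the conversation route keep resolving to the same values. A simplified sketch of what it amounts to at runtime (the real emit wraps this in an IIFE; the other identifiers are the ones from this PR, assumed in scope):

```ts
// Roughly equivalent runtime shape of `namespace RAGs { ... }` above; simplified sketch.
const RAGs = {
	webSearch: ragWebsearch,
	pdfChat: ragPdfchat,
};

// Both access styles seen in this PR reach the same members:
RAGs.webSearch.buildPrompt(messages, webSearchContext);
RAGs["pdfChat"].retrieveRagContext(conv, newPrompt, update);
```

One thing the namespace version gives up is the old mapped type's compile-time guarantee that every `RAGType` key has an entry.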
src/lib/types/Message.ts (5 changes: 4 additions & 1 deletion)
@@ -9,7 +9,10 @@ export type Message = Partial<Timestamps> & {
 	content: string;
 	updates?: MessageUpdate[];
 	webSearchId?: RagContextWebSearch["_id"]; // legacy version
-	ragContext?: RagContext;
+	ragContexts?: {
+		webSearch?: RagContextWebSearch;
+		pdfChat?: RagContext;
+	};
 	score?: -1 | 0 | 1;
 	files?: string[]; // can contain either the hash of the file or the b64 encoded image data on the client side when uploading
 };
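For illustration, a stored assistant message carrying both context kinds would look roughly like this; the field values are invented, and `webSearchResults`/`pdfSearchResults` are assumed to be in scope (for example, returned by the `retrieveRagContext` calls in `+server.ts` below):

```ts
import type { Message } from "$lib/types/Message"; // import path assumed from the file location

// Invented example; only fields visible in this diff are used.
const assistantMessage: Message = {
	id: crypto.randomUUID(),
	from: "assistant",
	content: "Summary of the uploaded PDF plus the latest search results…",
	ragContexts: {
		webSearch: webSearchResults, // RagContextWebSearch
		pdfChat: pdfSearchResults, // RagContext
	},
};
```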
src/routes/conversation/[id]/+server.ts (34 changes: 18 additions & 16 deletions)
@@ -233,25 +233,27 @@ export async function POST({ request, locals, params, getClientAddress }) {
 		}
 	);
 
-	let webSearchResults: RagContextWebSearch | undefined;
-
-	if (webSearch) {
-		webSearchResults = (await RAGs["webSearch"].retrieveRagContext(
-			conv,
-			newPrompt,
-			update
-		)) as RagContextWebSearch;
-	}
-
-	messages[messages.length - 1].ragContext = webSearchResults;
+	//
+	const webSearchResults = webSearch
+		? await RAGs["webSearch"].retrieveRagContext(conv, newPrompt, update)
+		: undefined;
 
-	let pdfSearchResults: RagContext | undefined;
 	const pdfSearch = await collections.files.findOne({ filename: `${convId.toString()}-pdf` });
-	if (pdfSearch) {
-		pdfSearchResults = await RAGs["pdfChat"].retrieveRagContext(conv, newPrompt, update);
+	const pdfSearchResults = pdfSearch
+		? await RAGs["pdfChat"].retrieveRagContext(conv, newPrompt, update)
+		: undefined;
+
+	const lastMessage = messages[messages.length - 1];
+	lastMessage.ragContexts = lastMessage.ragContexts || {}; // Ensure the object exists
+
+	if (webSearchResults) {
+		lastMessage.ragContexts["webSearch"] = webSearchResults as RagContextWebSearch;
 	}
 
-	messages[messages.length - 1].ragContext = pdfSearchResults;
+	if (pdfSearchResults) {
+		lastMessage.ragContexts["pdfChat"] = pdfSearchResults;
+	}
+	//
 
 	conv.messages = messages;
 
@@ -279,7 +281,7 @@ export async function POST({ request, locals, params, getClientAddress }) {
 				{
 					from: "assistant",
 					content: output.token.text.trimStart(),
-					ragContext: webSearchResults,
+					ragContexts: lastMessage.ragContexts,
 					updates: updates,
 					id: (responseId as Message["id"]) || crypto.randomUUID(),
 					createdAt: new Date(),
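The attachment step in the first hunk (create `ragContexts` on the last message, then copy in whichever results resolved) is the part a future context type would extend. A hypothetical helper that captures the same logic, shown only to summarise the flow; no such function exists in this PR, and the type import path is a guess:

```ts
import type { Message } from "$lib/types/Message";
import type { RagContext, RagContextWebSearch } from "$lib/types/rag"; // path assumed

// Hypothetical helper mirroring the route's attachment step.
function attachRagContexts(
	message: Message,
	results: { webSearch?: RagContextWebSearch; pdfChat?: RagContext }
): void {
	message.ragContexts = message.ragContexts ?? {};
	if (results.webSearch) {
		message.ragContexts.webSearch = results.webSearch;
	}
	if (results.pdfChat) {
		message.ragContexts.pdfChat = results.pdfChat;
	}
}

// Usage, with the variables from the hunk above:
// attachRagContexts(messages[messages.length - 1], { webSearch: webSearchResults, pdfChat: pdfSearchResults });
```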