Skip to content

Commit fea9360

Browse files
author
Mishig
authored
Generlize RAG (#689)
* Generlize RAG * wip * fix casting
1 parent d013181 commit fea9360

24 files changed

+216
-282
lines changed

src/lib/buildPrompt.ts

Lines changed: 10 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,80 +1,32 @@
11
import type { BackendModel } from "./server/models";
22
import type { Message } from "./types/Message";
3-
import { format } from "date-fns";
4-
import type { WebSearch } from "./types/WebSearch";
5-
import type { PdfSearch } from "./types/PdfChat";
63
import { downloadImgFile } from "./server/files/downloadFile";
74
import type { Conversation } from "./types/Conversation";
5+
import RAGs from "./server/rag/rag";
6+
import type { RagContext } from "./types/rag";
7+
8+
export type BuildPromptMessage = Pick<Message, "from" | "content" | "files">;
89

910
interface buildPromptOptions {
10-
messages: Pick<Message, "from" | "content" | "files">[];
11+
messages: BuildPromptMessage[];
1112
id?: Conversation["_id"];
1213
model: BackendModel;
1314
locals?: App.Locals;
14-
webSearch?: WebSearch;
15-
pdfSearch?: PdfSearch;
15+
ragContext?: RagContext;
1616
preprompt?: string;
1717
files?: File[];
1818
}
1919

2020
export async function buildPrompt({
2121
messages,
2222
model,
23-
webSearch,
24-
pdfSearch,
23+
ragContext,
2524
preprompt,
2625
id,
2726
}: buildPromptOptions): Promise<string> {
28-
if (webSearch && webSearch.context) {
29-
const lastMsg = messages.slice(-1)[0];
30-
const messagesWithoutLastUsrMsg = messages.slice(0, -1);
31-
const previousUserMessages = messages.filter((el) => el.from === "user").slice(0, -1);
32-
33-
const previousQuestions =
34-
previousUserMessages.length > 0
35-
? `Previous questions: \n${previousUserMessages
36-
.map(({ content }) => `- ${content}`)
37-
.join("\n")}`
38-
: "";
39-
const currentDate = format(new Date(), "MMMM d, yyyy");
40-
messages = [
41-
...messagesWithoutLastUsrMsg,
42-
{
43-
from: "user",
44-
content: `I searched the web using the query: ${webSearch.searchQuery}. Today is ${currentDate} and here are the results:
45-
=====================
46-
${webSearch.context}
47-
=====================
48-
${previousQuestions}
49-
Answer the question: ${lastMsg.content}
50-
`,
51-
},
52-
];
53-
} else if (pdfSearch && pdfSearch.context) {
54-
const lastMsg = messages.slice(-1)[0];
55-
const messagesWithoutLastUsrMsg = messages.slice(0, -1);
56-
const previousUserMessages = messages.filter((el) => el.from === "user").slice(0, -1);
57-
58-
const previousQuestions =
59-
previousUserMessages.length > 0
60-
? `Previous questions: \n${previousUserMessages
61-
.map(({ content }) => `- ${content}`)
62-
.join("\n")}`
63-
: "";
64-
65-
messages = [
66-
...messagesWithoutLastUsrMsg,
67-
{
68-
from: "user",
69-
content: `Below are the information I extracted from a PDF file that might be useful:
70-
=====================
71-
${pdfSearch.context}
72-
=====================
73-
${previousQuestions}
74-
Answer the question: ${lastMsg.content}
75-
`,
76-
},
77-
];
27+
if (ragContext) {
28+
const { type: ragType } = ragContext;
29+
messages = RAGs[ragType].buildPrompt(messages, ragContext);
7830
}
7931

8032
// section to handle potential files input

src/lib/components/OpenPdfSearchResults.svelte

Lines changed: 0 additions & 114 deletions
This file was deleted.

src/lib/components/OpenWebSearchResults.svelte renamed to src/lib/components/OpenRAGResults.svelte

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
<script lang="ts">
2-
import type { WebSearchUpdate } from "$lib/types/MessageUpdate";
2+
import type { RAGUpdate } from "$lib/types/MessageUpdate";
3+
import type { RAGType } from "$lib/types/rag";
34
import CarbonCaretRight from "~icons/carbon/caret-right";
45
56
import CarbonCheckmark from "~icons/carbon/checkmark-filled";
@@ -9,11 +10,22 @@
910
1011
export let loading = false;
1112
export let classNames = "";
12-
export let webSearchMessages: WebSearchUpdate[] = [];
13+
export let ragUpdates: RAGUpdate[] = [];
14+
15+
const TITLE_MAPPING: Record<RAGType, string> = {
16+
webSearch: "Web search",
17+
pdfChat: "PDF Chat",
18+
};
1319
1420
let detailsOpen: boolean;
1521
let error: boolean;
16-
$: error = webSearchMessages[webSearchMessages.length - 1]?.messageType === "error";
22+
$: error = ragUpdates[ragUpdates.length - 1]?.messageType === "error";
23+
24+
$: ragType = ragUpdates[0].type;
25+
26+
$: loading =
27+
ragUpdates.length > 0 &&
28+
!["sources", "done"].includes(ragUpdates[ragUpdates.length - 1].messageType);
1729
</script>
1830

1931
<details
@@ -31,21 +43,21 @@
3143
<CarbonCheckmark class="my-auto text-gray-500" />
3244
{/if}
3345
<span class="px-2 font-medium" class:text-red-700={error} class:dark:text-red-500={error}>
34-
Web search
46+
{TITLE_MAPPING[ragType]}
3547
</span>
3648
<div class="my-auto transition-all" class:rotate-90={detailsOpen}>
3749
<CarbonCaretRight />
3850
</div>
3951
</summary>
4052

4153
<div class="content px-5 pb-5 pt-4">
42-
{#if webSearchMessages.length === 0}
54+
{#if ragUpdates.length === 0}
4355
<div class="mx-auto w-fit">
4456
<EosIconsLoading class="mb-3 h-4 w-4" />
4557
</div>
4658
{:else}
4759
<ol>
48-
{#each webSearchMessages as message}
60+
{#each ragUpdates as message}
4961
{#if message.messageType === "update"}
5062
<li class="group border-l pb-6 last:!border-transparent last:pb-0 dark:border-gray-800">
5163
<div class="flex items-start">

src/lib/components/chat/ChatMessage.svelte

Lines changed: 14 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@
1616
import { PUBLIC_SEP_TOKEN } from "$lib/constants/publicSepToken";
1717
import type { Model } from "$lib/types/Model";
1818
19-
import OpenWebSearchResults from "../OpenWebSearchResults.svelte";
20-
import OpenPdfSearchResults from "../OpenPdfSearchResults.svelte";
21-
import type { RAGUpdate, WebSearchUpdate, PdfSearchUpdate } from "$lib/types/MessageUpdate";
19+
import OpenRAGResults from "../OpenRAGResults.svelte";
20+
import type { RAGUpdate } from "$lib/types/MessageUpdate";
21+
import { ragTypes } from "$lib/types/rag";
2222
2323
function sanitizeMd(md: string) {
2424
let ret = md
@@ -61,6 +61,7 @@
6161
let loadingEl: IconLoading;
6262
let pendingTimeout: ReturnType<typeof setTimeout>;
6363
let isCopied = false;
64+
let ragIsLoading = false;
6465
6566
const renderer = new marked.Renderer();
6667
// For code blocks with simple backticks
@@ -107,29 +108,17 @@
107108
}
108109
});
109110
110-
let searchUpdates: WebSearchUpdate[] = [];
111+
let ragUpdates: RAGUpdate[] = [];
111112
112-
$: searchUpdates = ((RAGMessages.filter(({ type }) => type === "webSearch").length > 0
113-
? RAGMessages.filter(({ type }) => type === "webSearch")
114-
: message.updates?.filter(({ type }) => type === "webSearch")) ?? []) as WebSearchUpdate[];
115-
116-
let pdfUpdates: PdfSearchUpdate[] = [];
117-
118-
$: pdfUpdates = ((RAGMessages.filter(({ type }) => type === "pdfSearch").length > 0
119-
? RAGMessages.filter(({ type }) => type === "pdfSearch")
120-
: message.updates?.filter(({ type }) => type === "pdfSearch")) ?? []) as PdfSearchUpdate[];
113+
$: ragUpdates = ((RAGMessages.length > 0
114+
? RAGMessages
115+
: message.updates?.filter(({ type }) => ragTypes.includes(type))) ?? []) as RAGUpdate[];
121116
122117
$: downloadLink =
123118
message.from === "user" ? `${$page.url.pathname}/message/${message.id}/prompt` : undefined;
124119
125-
let webSearchIsDone = true;
126-
127-
$: webSearchIsDone =
128-
searchUpdates.length > 0 && searchUpdates[searchUpdates.length - 1].messageType === "sources";
129-
130120
$: webSearchSources =
131-
searchUpdates &&
132-
searchUpdates?.filter(({ messageType }) => messageType === "sources")?.[0]?.sources;
121+
ragUpdates && ragUpdates?.filter(({ messageType }) => messageType === "sources")?.[0]?.sources;
133122
134123
$: if (isCopied) {
135124
setTimeout(() => {
@@ -153,21 +142,14 @@
153142
<div
154143
class="relative min-h-[calc(2rem+theme(spacing[3.5])*2)] min-w-[60px] break-words rounded-2xl border border-gray-100 bg-gradient-to-br from-gray-50 px-5 py-3.5 text-gray-600 prose-pre:my-2 dark:border-gray-800 dark:from-gray-800/40 dark:text-gray-300"
155144
>
156-
{#if searchUpdates && searchUpdates.length > 0}
157-
<OpenWebSearchResults
158-
classNames={tokens.length ? "mb-3.5" : ""}
159-
webSearchMessages={searchUpdates}
160-
loading={!(searchUpdates[searchUpdates.length - 1]?.messageType === "sources")}
161-
/>
162-
{/if}
163-
{#if pdfUpdates && pdfUpdates.length > 0}
164-
<OpenPdfSearchResults
145+
{#if ragUpdates && ragUpdates.length > 0}
146+
<OpenRAGResults
165147
classNames={tokens.length ? "mb-3.5" : ""}
166-
pdfSearchMessages={pdfUpdates}
167-
loading={!(pdfUpdates[pdfUpdates.length - 1]?.messageType === "done")}
148+
{ragUpdates}
149+
bind:loading={ragIsLoading}
168150
/>
169151
{/if}
170-
{#if !message.content && (webSearchIsDone || (RAGMessages && RAGMessages.length === 0))}
152+
{#if !message.content && (!ragIsLoading || (RAGMessages && RAGMessages.length === 0))}
171153
<IconLoading />
172154
{/if}
173155

src/lib/server/endpoints/aws/endpointAws.ts

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,7 @@ export async function endpointAws(
3939
return async ({ conversation }) => {
4040
const prompt = await buildPrompt({
4141
messages: conversation.messages,
42-
webSearch: conversation.messages[conversation.messages.length - 1].webSearch,
43-
pdfSearch: conversation.messages[conversation.messages.length - 1].pdfSearch,
42+
ragContext: conversation.messages[conversation.messages.length - 1].ragContext,
4443
preprompt: conversation.preprompt,
4544
model,
4645
});

src/lib/server/endpoints/llamacpp/endpointLlamacpp.ts

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,7 @@ export function endpointLlamacpp(
2222
return async ({ conversation }) => {
2323
const prompt = await buildPrompt({
2424
messages: conversation.messages,
25-
webSearch: conversation.messages[conversation.messages.length - 1].webSearch,
26-
pdfSearch: conversation.messages[conversation.messages.length - 1].pdfSearch,
25+
ragContext: conversation.messages[conversation.messages.length - 1].ragContext,
2726
preprompt: conversation.preprompt,
2827
model,
2928
});

src/lib/server/endpoints/ollama/endpointOllama.ts

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,7 @@ export function endpointOllama(input: z.input<typeof endpointOllamaParametersSch
1717
return async ({ conversation }) => {
1818
const prompt = await buildPrompt({
1919
messages: conversation.messages,
20-
webSearch: conversation.messages[conversation.messages.length - 1].webSearch,
21-
pdfSearch: conversation.messages[conversation.messages.length - 1].pdfSearch,
20+
ragContext: conversation.messages[conversation.messages.length - 1].ragContext,
2221
preprompt: conversation.preprompt,
2322
model,
2423
});

src/lib/server/endpoints/openai/endpointOai.ts

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,7 @@ export async function endpointOai(
4040
model: model.id ?? model.name,
4141
prompt: await buildPrompt({
4242
messages: conversation.messages,
43-
webSearch: conversation.messages[conversation.messages.length - 1].webSearch,
44-
pdfSearch: conversation.messages[conversation.messages.length - 1].pdfSearch,
43+
ragContext: conversation.messages[conversation.messages.length - 1].ragContext,
4544
preprompt: conversation.preprompt,
4645
model,
4746
}),

src/lib/server/endpoints/tgi/endpointTgi.ts

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,7 @@ export function endpointTgi(input: z.input<typeof endpointTgiParametersSchema>):
1818
return async ({ conversation }) => {
1919
const prompt = await buildPrompt({
2020
messages: conversation.messages,
21-
webSearch: conversation.messages[conversation.messages.length - 1].webSearch,
22-
pdfSearch: conversation.messages[conversation.messages.length - 1].pdfSearch,
21+
ragContext: conversation.messages[conversation.messages.length - 1].ragContext,
2322
preprompt: conversation.preprompt,
2423
model,
2524
id: conversation._id,

0 commit comments

Comments
 (0)