Skip to content

Commit be82fcb

Browse files
committed
Generlize RAG
1 parent d013181 commit be82fcb

File tree

11 files changed

+66
-32
lines changed

11 files changed

+66
-32
lines changed

src/lib/server/pdfSearch.ts renamed to src/lib/server/rag/pdfchat/rag.ts

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,20 @@
1-
import type { PdfSearch } from "$lib/types/PdfChat";
21
import { createEmbeddings, findSimilarSentences } from "$lib/server/embeddings";
32
import type { Conversation } from "$lib/types/Conversation";
43
import type { MessageUpdate } from "$lib/types/MessageUpdate";
5-
import { downloadPdfEmbeddings } from "./files/downloadFile";
4+
import { downloadPdfEmbeddings } from "../../files/downloadFile";
65
import { Tensor } from "@xenova/transformers";
6+
import type { RAG } from "../RAG";
7+
import type { RagContext } from "$lib/types/rag";
78

89
// todo: embed the prompt, download the embeddings, serialize them, and find the closest sentences, and get their texts, lets go
9-
export async function runPdfSearch(
10+
async function runPdfSearch(
1011
conv: Conversation,
1112
prompt: string,
1213
updatePad: (upd: MessageUpdate) => void
1314
) {
14-
const pdfSearch: PdfSearch = {
15+
const pdfSearch: RagContext = {
1516
context: "",
17+
type: "pdfchat",
1618
createdAt: new Date(),
1719
updatedAt: new Date(),
1820
};
@@ -50,3 +52,8 @@ export async function runPdfSearch(
5052

5153
return pdfSearch;
5254
}
55+
56+
export const ragPdfchat: RAG = {
57+
type: "pdfchat",
58+
retrieveRagContext: runPdfSearch,
59+
}

src/lib/server/rag/rag.ts

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import type { Conversation } from "$lib/types/Conversation";
2+
import type { MessageUpdate } from "$lib/types/MessageUpdate";
3+
import type { RagContext } from "$lib/types/rag";
4+
import { ragPdfchat } from "./pdfchat/rag";
5+
import { ragWebsearch } from "./websearch/rag";
6+
7+
type RetrieveRagContext<T=RagContext> = (conv: Conversation, prompt: string, updatePad: (upd: MessageUpdate) => void) => Promise<T>;
8+
9+
export type RAGType = "websearch" | "pdfchat";
10+
11+
export interface RAG {
12+
type: RAGType, // should be one of websearch or somethign, make it stronger typed
13+
retrieveRagContext: RetrieveRagContext,
14+
buildPrompt?: string,
15+
}
16+
17+
// list of all rags
18+
export const RAGs: {
19+
[Key in RAGType]: RAG;
20+
} = {
21+
websearch: ragWebsearch,
22+
pdfchat: ragPdfchat,
23+
};
24+
25+
export default RAGs;

src/lib/server/websearch/generateQuery.ts renamed to src/lib/server/rag/websearch/generateQuery.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import type { Message } from "$lib/types/Message";
22
import { format } from "date-fns";
3-
import { generateFromDefaultEndpoint } from "../generateFromDefaultEndpoint";
3+
import { generateFromDefaultEndpoint } from "../../generateFromDefaultEndpoint";
44
import { WEBSEARCH_ALLOWLIST, WEBSEARCH_BLOCKLIST } from "$env/static/private";
55
import { z } from "zod";
66
import JSON5 from "json5";

src/lib/server/websearch/runWebSearch.ts renamed to src/lib/server/rag/websearch/rag.ts

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,22 @@
1-
import { searchWeb } from "$lib/server/websearch/searchWeb";
1+
import { searchWeb } from "$lib/server/rag/websearch/searchWeb";
22
import type { Message } from "$lib/types/Message";
33
import type { WebSearch, WebSearchSource } from "$lib/types/WebSearch";
4-
import { generateQuery } from "$lib/server/websearch/generateQuery";
5-
import { parseWeb } from "$lib/server/websearch/parseWeb";
4+
import { generateQuery } from "$lib/server/rag/websearch/generateQuery";
5+
import { parseWeb } from "$lib/server/rag/websearch/parseWeb";
66
import { chunk } from "$lib/utils/chunk";
77
import { findSimilarSentences } from "$lib/server/sentenceSimilarity";
88
import type { Conversation } from "$lib/types/Conversation";
99
import type { MessageUpdate } from "$lib/types/MessageUpdate";
1010
import { getWebSearchProvider } from "./searchWeb";
1111
import { defaultEmbeddingModel, embeddingModels } from "$lib/server/embeddingModels";
12+
import type { RAG } from "../RAG";
1213

1314
const MAX_N_PAGES_SCRAPE = 10 as const;
1415
const MAX_N_PAGES_EMBED = 5 as const;
1516

1617
const DOMAIN_BLOCKLIST = ["youtube.com", "twitter.com"];
1718

18-
export async function runWebSearch(
19+
async function runWebSearch(
1920
conv: Conversation,
2021
prompt: string,
2122
updatePad: (upd: MessageUpdate) => void
@@ -25,6 +26,7 @@ export async function runWebSearch(
2526
})() satisfies Message[];
2627

2728
const webSearch: WebSearch = {
29+
type: "websearch",
2830
prompt: prompt,
2931
searchQuery: "",
3032
results: [],
@@ -130,3 +132,8 @@ export async function runWebSearch(
130132

131133
return webSearch;
132134
}
135+
136+
export const ragWebsearch: RAG = {
137+
type: "websearch",
138+
retrieveRagContext: runWebSearch,
139+
}

src/lib/server/websearch/searchWeb.ts renamed to src/lib/server/rag/websearch/searchWeb.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
import type { YouWebSearch } from "../../types/WebSearch";
2-
import { WebSearchProvider } from "../../types/WebSearch";
1+
import type { YouWebSearch } from "../../../types/WebSearch";
2+
import { WebSearchProvider } from "../../../types/WebSearch";
33
import {
44
SERPAPI_KEY,
55
SERPER_API_KEY,

src/lib/types/PdfChat.ts

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,3 @@
1-
import type { ObjectId } from "mongodb";
2-
import type { Timestamps } from "./Timestamps";
3-
4-
export interface PdfSearch extends Timestamps {
5-
_id?: ObjectId;
6-
context: string;
7-
}
8-
91
/* eslint-disable no-shadow */
102
export enum PdfUploadStatus {
113
Ready = "Ready",

src/lib/types/WebSearch.ts

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,9 @@
1-
import type { ObjectId } from "mongodb";
2-
import type { Conversation } from "./Conversation";
3-
import type { Timestamps } from "./Timestamps";
4-
5-
export interface WebSearch extends Timestamps {
6-
_id?: ObjectId;
7-
convId?: Conversation["_id"];
1+
import type { RagContext } from "./rag";
82

3+
export interface WebSearch extends RagContext {
94
prompt: string;
10-
115
searchQuery: string;
126
results: WebSearchSource[];
13-
context: string;
147
contextSources: WebSearchSource[];
158
}
169

src/lib/types/rag.ts

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
import type { ObjectId } from "mongodb";
2+
import type { Conversation } from "./Conversation";
3+
import type { Timestamps } from "./Timestamps";
4+
import type { RAGType } from "$lib/server/rag/RAG";
5+
6+
export interface RagContext extends Timestamps {
7+
_id?: ObjectId;
8+
convId?: Conversation["_id"];
9+
type: RAGType
10+
context: string;
11+
}

src/routes/conversation/[id]/+server.ts

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,13 @@ import { error } from "@sveltejs/kit";
88
import { ObjectId } from "mongodb";
99
import { z } from "zod";
1010
import type { MessageUpdate } from "$lib/types/MessageUpdate";
11-
import { runWebSearch } from "$lib/server/websearch/runWebSearch";
12-
import { runPdfSearch } from "$lib/server/pdfSearch";
1311
import type { WebSearch } from "$lib/types/WebSearch";
1412
import type { PdfSearch } from "$lib/types/PdfChat";
1513
import { abortedGenerations } from "$lib/server/abortedGenerations";
1614
import { summarize } from "$lib/server/summarize";
1715
import { uploadImgFile } from "$lib/server/files/uploadFile";
1816
import sizeof from "image-size";
17+
import RAGs from "$lib/server/rag/RAG";
1918

2019
export async function POST({ request, locals, params, getClientAddress }) {
2120
const id = z.string().parse(params.id);
@@ -237,15 +236,15 @@ export async function POST({ request, locals, params, getClientAddress }) {
237236
let webSearchResults: WebSearch | undefined;
238237

239238
if (webSearch) {
240-
webSearchResults = await runWebSearch(conv, newPrompt, update);
239+
webSearchResults = await RAGs["websearch"].retrieveRagContext(conv, newPrompt, update);
241240
}
242241

243242
messages[messages.length - 1].webSearch = webSearchResults;
244243

245244
let pdfSearchResults: PdfSearch | undefined;
246245
const pdfSearch = await collections.files.findOne({ filename: `${convId.toString()}-pdf` });
247246
if (pdfSearch) {
248-
pdfSearchResults = await runPdfSearch(conv, newPrompt, update);
247+
pdfSearchResults = await RAGs["websearch"].retrieveRagContext(conv, newPrompt, update);
249248
}
250249

251250
messages[messages.length - 1].pdfSearch = pdfSearchResults;

0 commit comments

Comments
 (0)