@@ -6,7 +6,7 @@ import { assetRelocator, copy } from "../helpers/copy";
6
6
import { callPackageManager } from "../helpers/install" ;
7
7
import { templatesDir } from "./dir" ;
8
8
import { PackageManager } from "./get-pkg-manager" ;
9
- import { InstallTemplateArgs } from "./types" ;
9
+ import { InstallTemplateArgs , ModelProvider , TemplateVectorDB } from "./types" ;
10
10
11
11
/**
12
12
* Install a LlamaIndex internal template to a given `root` directory.
@@ -27,6 +27,7 @@ export const installTSTemplate = async ({
27
27
dataSources,
28
28
useLlamaParse,
29
29
useCase,
30
+ modelConfig,
30
31
} : InstallTemplateArgs & { backend : boolean } ) => {
31
32
console . log ( bold ( `Using ${ packageManager } .` ) ) ;
32
33
@@ -181,6 +182,12 @@ export const installTSTemplate = async ({
181
182
cwd : path . join ( compPath , "loaders" , "typescript" , loaderFolder ) ,
182
183
} ) ;
183
184
185
+ // copy provider settings
186
+ await copy ( "**" , enginePath , {
187
+ parents : true ,
188
+ cwd : path . join ( compPath , "providers" , "typescript" , modelConfig . provider ) ,
189
+ } ) ;
190
+
184
191
// Select and copy engine code based on data sources and tools
185
192
let engine ;
186
193
tools = tools ?? [ ] ;
@@ -239,6 +246,8 @@ export const installTSTemplate = async ({
239
246
ui,
240
247
observability,
241
248
vectorDb,
249
+ backend,
250
+ modelConfig,
242
251
} ) ;
243
252
244
253
if (
@@ -249,6 +258,68 @@ export const installTSTemplate = async ({
249
258
}
250
259
} ;
251
260
261
+ const providerDependencies : {
262
+ [ key in ModelProvider ] ?: Record < string , string > ;
263
+ } = {
264
+ openai : {
265
+ "@llamaindex/openai" : "^0.1.52" ,
266
+ } ,
267
+ gemini : {
268
+ "@llamaindex/google" : "^0.0.7" ,
269
+ } ,
270
+ ollama : {
271
+ "@llamaindex/ollama" : "^0.0.40" ,
272
+ } ,
273
+ mistral : {
274
+ "@llamaindex/mistral" : "^0.0.5" ,
275
+ } ,
276
+ "azure-openai" : {
277
+ "@llamaindex/openai" : "^0.1.52" ,
278
+ } ,
279
+ groq : {
280
+ "@llamaindex/groq" : "^0.0.51" ,
281
+ "@llamaindex/huggingface" : "^0.0.36" , // groq uses huggingface as default embedding model
282
+ } ,
283
+ anthropic : {
284
+ "@llamaindex/anthropic" : "^0.1.0" ,
285
+ "@llamaindex/huggingface" : "^0.0.36" , // anthropic uses huggingface as default embedding model
286
+ } ,
287
+ } ;
288
+
289
+ const vectorDbDependencies : Record < TemplateVectorDB , Record < string , string > > = {
290
+ astra : {
291
+ "@llamaindex/astra" : "^0.0.5" ,
292
+ } ,
293
+ chroma : {
294
+ "@llamaindex/chroma" : "^0.0.5" ,
295
+ } ,
296
+ llamacloud : { } ,
297
+ milvus : {
298
+ "@zilliz/milvus2-sdk-node" : "^2.4.6" ,
299
+ "@llamaindex/milvus" : "^0.1.0" ,
300
+ } ,
301
+ mongo : {
302
+ mongodb : "6.7.0" ,
303
+ "@llamaindex/mongodb" : "^0.0.5" ,
304
+ } ,
305
+ none : { } ,
306
+ pg : {
307
+ pg : "^8.12.0" ,
308
+ pgvector : "^0.2.0" ,
309
+ "@llamaindex/postgres" : "^0.0.33" ,
310
+ } ,
311
+ pinecone : {
312
+ "@llamaindex/pinecone" : "^0.0.5" ,
313
+ } ,
314
+ qdrant : {
315
+ "@qdrant/js-client-rest" : "^1.11.0" ,
316
+ "@llamaindex/qdrant" : "^0.1.0" ,
317
+ } ,
318
+ weaviate : {
319
+ "@llamaindex/weaviate" : "^0.0.5" ,
320
+ } ,
321
+ } ;
322
+
252
323
async function updatePackageJson ( {
253
324
root,
254
325
appName,
@@ -258,6 +329,8 @@ async function updatePackageJson({
258
329
ui,
259
330
observability,
260
331
vectorDb,
332
+ backend,
333
+ modelConfig,
261
334
} : Pick <
262
335
InstallTemplateArgs ,
263
336
| "root"
@@ -267,8 +340,10 @@ async function updatePackageJson({
267
340
| "ui"
268
341
| "observability"
269
342
| "vectorDb"
343
+ | "modelConfig"
270
344
> & {
271
345
relativeEngineDestPath : string ;
346
+ backend : boolean ;
272
347
} ) : Promise < any > {
273
348
const packageJsonFile = path . join ( root , "package.json" ) ;
274
349
const packageJson : any = JSON . parse (
@@ -308,32 +383,25 @@ async function updatePackageJson({
308
383
} ;
309
384
}
310
385
311
- if ( vectorDb === "pg" ) {
386
+ if ( backend ) {
312
387
packageJson . dependencies = {
313
388
...packageJson . dependencies ,
314
- pg : "^8.12.0" ,
315
- pgvector : "^0.2.0" ,
389
+ "@llamaindex/readers" : "^2.0.0" ,
316
390
} ;
317
- }
318
391
319
- if ( vectorDb === "qdrant" ) {
320
- packageJson . dependencies = {
321
- ...packageJson . dependencies ,
322
- "@qdrant/js-client-rest" : "^1.11.0" ,
323
- } ;
324
- }
325
- if ( vectorDb === "mongo" ) {
326
- packageJson . dependencies = {
327
- ...packageJson . dependencies ,
328
- mongodb : "^6.7.0" ,
329
- } ;
330
- }
392
+ if ( vectorDb && vectorDb in vectorDbDependencies ) {
393
+ packageJson . dependencies = {
394
+ ...packageJson . dependencies ,
395
+ ...vectorDbDependencies [ vectorDb ] ,
396
+ } ;
397
+ }
331
398
332
- if ( vectorDb === "milvus" ) {
333
- packageJson . dependencies = {
334
- ...packageJson . dependencies ,
335
- "@zilliz/milvus2-sdk-node" : "^2.4.6" ,
336
- } ;
399
+ if ( modelConfig . provider && modelConfig . provider in providerDependencies ) {
400
+ packageJson . dependencies = {
401
+ ...packageJson . dependencies ,
402
+ ...providerDependencies [ modelConfig . provider ] ,
403
+ } ;
404
+ }
337
405
}
338
406
339
407
if ( observability === "traceloop" ) {
0 commit comments