1
1
import logging
2
2
import os
3
- from typing import Dict , Optional
3
+ from typing import Optional
4
4
5
5
from llama_index .core .callbacks import CallbackManager
6
6
from llama_index .core .ingestion .api_utils import (
7
7
get_client as llama_cloud_get_client ,
8
8
)
9
9
from llama_index .indices .managed .llama_cloud import LlamaCloudIndex
10
- from pydantic import BaseModel , Field
10
+ from pydantic import BaseModel , Field , validator
11
11
12
12
logger = logging .getLogger ("uvicorn" )
13
13
14
14
15
+ class LlamaCloudConfig (BaseModel ):
16
+ # Private attributes
17
+ api_key : str = Field (
18
+ default = os .getenv ("LLAMA_CLOUD_API_KEY" ),
19
+ exclude = True , # Exclude from the model representation
20
+ )
21
+ base_url : Optional [str ] = Field (
22
+ default = os .getenv ("LLAMA_CLOUD_BASE_URL" ),
23
+ exclude = True ,
24
+ )
25
+ organization_id : Optional [str ] = Field (
26
+ default = os .getenv ("LLAMA_CLOUD_ORGANIZATION_ID" ),
27
+ exclude = True ,
28
+ )
29
+ # Configuration attributes, can be set by the user
30
+ pipeline : str = Field (
31
+ description = "The name of the pipeline to use" ,
32
+ default = os .getenv ("LLAMA_CLOUD_INDEX_NAME" ),
33
+ )
34
+ project : str = Field (
35
+ description = "The name of the LlamaCloud project" ,
36
+ default = os .getenv ("LLAMA_CLOUD_PROJECT_NAME" ),
37
+ )
38
+
39
+ # Validate and throw error if the env variables are not set before starting the app
40
+ @validator ("pipeline" , "project" , "api_key" , pre = True , always = True )
41
+ @classmethod
42
+ def validate_env_vars (cls , value ):
43
+ if value is None :
44
+ raise ValueError (
45
+ "Please set LLAMA_CLOUD_INDEX_NAME, LLAMA_CLOUD_PROJECT_NAME and LLAMA_CLOUD_API_KEY"
46
+ " to your environment variables or config them in .env file"
47
+ )
48
+ return value
49
+
50
+ def to_index_kwargs (self ) -> dict :
51
+ return {
52
+ "name" : self .pipeline ,
53
+ "project_name" : self .project ,
54
+ "api_key" : self .api_key ,
55
+ "base_url" : self .base_url ,
56
+ "organization_id" : self .organization_id ,
57
+ }
58
+
59
+ def to_client_kwargs (self ) -> dict :
60
+ return {
61
+ "api_key" : self .api_key ,
62
+ "base_url" : self .base_url ,
63
+ }
64
+
65
+
15
66
class IndexConfig (BaseModel ):
16
- llama_cloud_pipeline_config : Optional [ Dict ] = Field (
17
- default = None ,
67
+ llama_cloud_pipeline_config : LlamaCloudConfig = Field (
68
+ default = LlamaCloudConfig () ,
18
69
alias = "llamaCloudPipeline" ,
19
70
)
20
71
callback_manager : Optional [CallbackManager ] = Field (
@@ -25,36 +76,14 @@ class IndexConfig(BaseModel):
25
76
def get_index (config : IndexConfig = None ):
26
77
if config is None :
27
78
config = IndexConfig ()
28
- name = config .llama_cloud_pipeline_config .get (
29
- "pipeline" , os .getenv ("LLAMA_CLOUD_INDEX_NAME" )
30
- )
31
- project_name = config .llama_cloud_pipeline_config .get (
32
- "project" , os .getenv ("LLAMA_CLOUD_PROJECT_NAME" )
33
- )
34
- api_key = os .getenv ("LLAMA_CLOUD_API_KEY" )
35
- base_url = os .getenv ("LLAMA_CLOUD_BASE_URL" )
36
- organization_id = os .getenv ("LLAMA_CLOUD_ORGANIZATION_ID" )
37
-
38
- if name is None or project_name is None or api_key is None :
39
- raise ValueError (
40
- "Please set LLAMA_CLOUD_INDEX_NAME, LLAMA_CLOUD_PROJECT_NAME and LLAMA_CLOUD_API_KEY"
41
- " to your environment variables or config them in .env file"
42
- )
43
-
44
- index = LlamaCloudIndex (
45
- name = name ,
46
- project_name = project_name ,
47
- api_key = api_key ,
48
- base_url = base_url ,
49
- organization_id = organization_id ,
50
- callback_manager = config .callback_manager ,
51
- )
79
+ index_kwargs = config .llama_cloud_pipeline_config .to_index_kwargs ()
80
+ index_kwargs ["callback_manager" ] = config .callback_manager
81
+
82
+ index = LlamaCloudIndex (** index_kwargs )
52
83
53
84
return index
54
85
55
86
56
87
def get_client ():
57
- return llama_cloud_get_client (
58
- os .getenv ("LLAMA_CLOUD_API_KEY" ),
59
- os .getenv ("LLAMA_CLOUD_BASE_URL" ),
60
- )
88
+ config = LlamaCloudConfig ()
89
+ return llama_cloud_get_client (** config .to_client_kwargs ())
0 commit comments