Skip to content

Commit df49bdf

Browse files
committed
add LlamaCloudConfig model
1 parent b711c02 commit df49bdf

File tree

1 file changed

+61
-32
lines changed
  • templates/components/vectordbs/python/llamacloud

1 file changed

+61
-32
lines changed
Lines changed: 61 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,71 @@
11
import logging
22
import os
3-
from typing import Dict, Optional
3+
from typing import Optional
44

55
from llama_index.core.callbacks import CallbackManager
66
from llama_index.core.ingestion.api_utils import (
77
get_client as llama_cloud_get_client,
88
)
99
from llama_index.indices.managed.llama_cloud import LlamaCloudIndex
10-
from pydantic import BaseModel, Field
10+
from pydantic import BaseModel, Field, validator
1111

1212
logger = logging.getLogger("uvicorn")
1313

1414

15+
class LlamaCloudConfig(BaseModel):
16+
# Private attributes
17+
api_key: str = Field(
18+
default=os.getenv("LLAMA_CLOUD_API_KEY"),
19+
exclude=True, # Exclude from the model representation
20+
)
21+
base_url: Optional[str] = Field(
22+
default=os.getenv("LLAMA_CLOUD_BASE_URL"),
23+
exclude=True,
24+
)
25+
organization_id: Optional[str] = Field(
26+
default=os.getenv("LLAMA_CLOUD_ORGANIZATION_ID"),
27+
exclude=True,
28+
)
29+
# Configuration attributes, can be set by the user
30+
pipeline: str = Field(
31+
description="The name of the pipeline to use",
32+
default=os.getenv("LLAMA_CLOUD_INDEX_NAME"),
33+
)
34+
project: str = Field(
35+
description="The name of the LlamaCloud project",
36+
default=os.getenv("LLAMA_CLOUD_PROJECT_NAME"),
37+
)
38+
39+
# Validate and throw error if the env variables are not set before starting the app
40+
@validator("pipeline", "project", "api_key", pre=True, always=True)
41+
@classmethod
42+
def validate_env_vars(cls, value):
43+
if value is None:
44+
raise ValueError(
45+
"Please set LLAMA_CLOUD_INDEX_NAME, LLAMA_CLOUD_PROJECT_NAME and LLAMA_CLOUD_API_KEY"
46+
" to your environment variables or config them in .env file"
47+
)
48+
return value
49+
50+
def to_index_kwargs(self) -> dict:
51+
return {
52+
"name": self.pipeline,
53+
"project_name": self.project,
54+
"api_key": self.api_key,
55+
"base_url": self.base_url,
56+
"organization_id": self.organization_id,
57+
}
58+
59+
def to_client_kwargs(self) -> dict:
60+
return {
61+
"api_key": self.api_key,
62+
"base_url": self.base_url,
63+
}
64+
65+
1566
class IndexConfig(BaseModel):
16-
llama_cloud_pipeline_config: Optional[Dict] = Field(
17-
default=None,
67+
llama_cloud_pipeline_config: LlamaCloudConfig = Field(
68+
default=LlamaCloudConfig(),
1869
alias="llamaCloudPipeline",
1970
)
2071
callback_manager: Optional[CallbackManager] = Field(
@@ -25,36 +76,14 @@ class IndexConfig(BaseModel):
2576
def get_index(config: IndexConfig = None):
2677
if config is None:
2778
config = IndexConfig()
28-
name = config.llama_cloud_pipeline_config.get(
29-
"pipeline", os.getenv("LLAMA_CLOUD_INDEX_NAME")
30-
)
31-
project_name = config.llama_cloud_pipeline_config.get(
32-
"project", os.getenv("LLAMA_CLOUD_PROJECT_NAME")
33-
)
34-
api_key = os.getenv("LLAMA_CLOUD_API_KEY")
35-
base_url = os.getenv("LLAMA_CLOUD_BASE_URL")
36-
organization_id = os.getenv("LLAMA_CLOUD_ORGANIZATION_ID")
37-
38-
if name is None or project_name is None or api_key is None:
39-
raise ValueError(
40-
"Please set LLAMA_CLOUD_INDEX_NAME, LLAMA_CLOUD_PROJECT_NAME and LLAMA_CLOUD_API_KEY"
41-
" to your environment variables or config them in .env file"
42-
)
43-
44-
index = LlamaCloudIndex(
45-
name=name,
46-
project_name=project_name,
47-
api_key=api_key,
48-
base_url=base_url,
49-
organization_id=organization_id,
50-
callback_manager=config.callback_manager,
51-
)
79+
index_kwargs = config.llama_cloud_pipeline_config.to_index_kwargs()
80+
index_kwargs["callback_manager"] = config.callback_manager
81+
82+
index = LlamaCloudIndex(**index_kwargs)
5283

5384
return index
5485

5586

5687
def get_client():
57-
return llama_cloud_get_client(
58-
os.getenv("LLAMA_CLOUD_API_KEY"),
59-
os.getenv("LLAMA_CLOUD_BASE_URL"),
60-
)
88+
config = LlamaCloudConfig()
89+
return llama_cloud_get_client(**config.to_client_kwargs())

0 commit comments

Comments
 (0)