From 5681a91425048db9a3d2b8f5959c7770646106cf Mon Sep 17 00:00:00 2001
From: kaustubh-darekar
Date: Mon, 16 Dec 2024 06:21:44 +0000
Subject: [PATCH] Error handling for model format in backend & frontend env

---
 backend/src/llm.py        | 210 ++++++++++++++++++++------------
 backend/src/ragas_eval.py |   1 +
 2 files changed, 113 insertions(+), 98 deletions(-)

diff --git a/backend/src/llm.py b/backend/src/llm.py
index f19648ed6..381a38a68 100644
--- a/backend/src/llm.py
+++ b/backend/src/llm.py
@@ -16,95 +16,106 @@ def get_llm(model: str):
     """Retrieve the specified language model based on the model name."""
-    env_key = "LLM_MODEL_CONFIG_" + model
+    model = model.lower().strip()
+    env_key = f"LLM_MODEL_CONFIG_{model}"
     env_value = os.environ.get(env_key)
-    logging.info("Model: {}".format(env_key))
+
+    if not env_value:
+        err = f"Environment variable '{env_key}' is missing or does not follow the expected format"
+        logging.error(err)
+        raise Exception(err)
 
-    if "gemini" in model:
-        model_name = env_value
-        credentials, project_id = google.auth.default()
-        llm = ChatVertexAI(
-            model_name=model_name,
-            #convert_system_message_to_human=True,
-            credentials=credentials,
-            project=project_id,
-            temperature=0,
-            safety_settings={
-                HarmCategory.HARM_CATEGORY_UNSPECIFIED: HarmBlockThreshold.BLOCK_NONE,
-                HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
-                HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
-                HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
-                HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
-            },
-        )
-    elif "openai" in model:
-        model_name, api_key = env_value.split(",")
-        llm = ChatOpenAI(
-            api_key=api_key,
-            model=model_name,
-            temperature=0,
-        )
+    logging.info("Model: {}".format(env_key))
+    try:
+        if "gemini" in model:
+            model_name = env_value
+            credentials, project_id = google.auth.default()
+            llm = ChatVertexAI(
+                model_name=model_name,
+                #convert_system_message_to_human=True,
+                credentials=credentials,
+                project=project_id,
+                temperature=0,
+                safety_settings={
+                    HarmCategory.HARM_CATEGORY_UNSPECIFIED: HarmBlockThreshold.BLOCK_NONE,
+                    HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
+                    HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
+                    HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
+                    HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
+                },
+            )
+        elif "openai" in model:
+            model_name, api_key = env_value.split(",")
+            llm = ChatOpenAI(
+                api_key=api_key,
+                model=model_name,
+                temperature=0,
+            )
 
-    elif "azure" in model:
-        model_name, api_endpoint, api_key, api_version = env_value.split(",")
-        llm = AzureChatOpenAI(
-            api_key=api_key,
-            azure_endpoint=api_endpoint,
-            azure_deployment=model_name,  # takes precedence over model parameter
-            api_version=api_version,
-            temperature=0,
-            max_tokens=None,
-            timeout=None,
-        )
+        elif "azure" in model:
+            model_name, api_endpoint, api_key, api_version = env_value.split(",")
+            llm = AzureChatOpenAI(
+                api_key=api_key,
+                azure_endpoint=api_endpoint,
+                azure_deployment=model_name,  # takes precedence over model parameter
+                api_version=api_version,
+                temperature=0,
+                max_tokens=None,
+                timeout=None,
+            )
 
-    elif "anthropic" in model:
-        model_name, api_key = env_value.split(",")
-        llm = ChatAnthropic(
-            api_key=api_key, model=model_name, temperature=0, timeout=None
-        )
+        elif "anthropic" in model:
+            model_name, api_key = env_value.split(",")
+            llm = ChatAnthropic(
+                api_key=api_key, model=model_name, temperature=0, timeout=None
+            )
 
-    elif "fireworks" in model:
-        model_name, api_key = env_value.split(",")
-        llm = ChatFireworks(api_key=api_key, model=model_name)
-
-    elif "groq" in model:
-        model_name, base_url, api_key = env_value.split(",")
-        llm = ChatGroq(api_key=api_key, model_name=model_name, temperature=0)
-
-    elif "bedrock" in model:
-        model_name, aws_access_key, aws_secret_key, region_name = env_value.split(",")
-        bedrock_client = boto3.client(
-            service_name="bedrock-runtime",
-            region_name=region_name,
-            aws_access_key_id=aws_access_key,
-            aws_secret_access_key=aws_secret_key,
-        )
+        elif "fireworks" in model:
+            model_name, api_key = env_value.split(",")
+            llm = ChatFireworks(api_key=api_key, model=model_name)
+
+        elif "groq" in model:
+            model_name, base_url, api_key = env_value.split(",")
+            llm = ChatGroq(api_key=api_key, model_name=model_name, temperature=0)
+
+        elif "bedrock" in model:
+            model_name, aws_access_key, aws_secret_key, region_name = env_value.split(",")
+            bedrock_client = boto3.client(
+                service_name="bedrock-runtime",
+                region_name=region_name,
+                aws_access_key_id=aws_access_key,
+                aws_secret_access_key=aws_secret_key,
+            )
 
-        llm = ChatBedrock(
-            client=bedrock_client, model_id=model_name, model_kwargs=dict(temperature=0)
-        )
+            llm = ChatBedrock(
+                client=bedrock_client, model_id=model_name, model_kwargs=dict(temperature=0)
+            )
 
-    elif "ollama" in model:
-        model_name, base_url = env_value.split(",")
-        llm = ChatOllama(base_url=base_url, model=model_name)
+        elif "ollama" in model:
+            model_name, base_url = env_value.split(",")
+            llm = ChatOllama(base_url=base_url, model=model_name)
 
-    elif "diffbot" in model:
-        #model_name = "diffbot"
-        model_name, api_key = env_value.split(",")
-        llm = DiffbotGraphTransformer(
-            diffbot_api_key=api_key,
-            extract_types=["entities", "facts"],
-        )
-
-    else:
-        model_name, api_endpoint, api_key = env_value.split(",")
-        llm = ChatOpenAI(
-            api_key=api_key,
-            base_url=api_endpoint,
-            model=model_name,
-            temperature=0,
-        )
-
+        elif "diffbot" in model:
+            #model_name = "diffbot"
+            model_name, api_key = env_value.split(",")
+            llm = DiffbotGraphTransformer(
+                diffbot_api_key=api_key,
+                extract_types=["entities", "facts"],
+            )
+
+        else:
+            model_name, api_endpoint, api_key = env_value.split(",")
+            llm = ChatOpenAI(
+                api_key=api_key,
+                base_url=api_endpoint,
+                model=model_name,
+                temperature=0,
+            )
+    except Exception as e:
+        err = f"Error while creating LLM '{model}': {str(e)}"
+        logging.error(err)
+        raise Exception(err)
 
     logging.info(f"Model created - Model Version: {model}")
     return llm, model_name
@@ -179,21 +190,24 @@ async def get_graph_document_list(
 
 async def get_graph_from_llm(model, chunkId_chunkDoc_list, allowedNodes, allowedRelationship):
-
-    llm, model_name = get_llm(model)
-    combined_chunk_document_list = get_combined_chunks(chunkId_chunkDoc_list)
-    #combined_chunk_document_list = get_chunk_id_as_doc_metadata(chunkId_chunkDoc_list)
-
-    if allowedNodes is None or allowedNodes=="":
-        allowedNodes =[]
-    else:
-        allowedNodes = allowedNodes.split(',')
-    if allowedRelationship is None or allowedRelationship=="":
-        allowedRelationship=[]
-    else:
-        allowedRelationship = allowedRelationship.split(',')
+    try:
+        llm, model_name = get_llm(model)
+        combined_chunk_document_list = get_combined_chunks(chunkId_chunkDoc_list)
 
-    graph_document_list = await get_graph_document_list(
-        llm, combined_chunk_document_list, allowedNodes, allowedRelationship
-    )
-    return graph_document_list
allowedNodes=="": + allowedNodes =[] + else: + allowedNodes = allowedNodes.split(',') + if allowedRelationship is None or allowedRelationship=="": + allowedRelationship=[] + else: + allowedRelationship = allowedRelationship.split(',') + + graph_document_list = await get_graph_document_list( + llm, combined_chunk_document_list, allowedNodes, allowedRelationship + ) + return graph_document_list + except Exception as e: + err = f"Error during extracting graph with llm: {e}" + logging.error(err) + raise diff --git a/backend/src/ragas_eval.py b/backend/src/ragas_eval.py index 29f7e026c..251ab71c0 100644 --- a/backend/src/ragas_eval.py +++ b/backend/src/ragas_eval.py @@ -17,6 +17,7 @@ load_dotenv() EMBEDDING_MODEL = os.getenv("RAGAS_EMBEDDING_MODEL") +logging.info(f"Loading embedding model '{EMBEDDING_MODEL}' for ragas evaluation") EMBEDDING_FUNCTION, _ = load_embedding_model(EMBEDDING_MODEL) def get_ragas_metrics(question: str, context: list, answer: list, model: str):