Skip to content

Commit 6f90473

Browse files
kartikpersistentvasanthasaikallurikaustubh-darekara-s-poorna
authored
Raga's Evaluation Metrics (#787)
* added Multi modes selection * ragas eval * added response * multimodes state mangement * fix: state handling of chat details of the default mode * Added the ChatModeSwitch Component * modes switch statemangement * added the chatmodes switch in both view * removed the copied text * Handled the error scenario * fix: speech issue between modes * ragas evaluation metric show * Output return type changed * fix: Handled activespeech speech and othermessage modes switch * used requestanimationframe instead of setTimeOut * removed the commented code * Added ragas to requirements * Integrated the metric api * ragas response updated, llm list updated * resolved syntax error in score * Added the Metrics Table * fix: Long text UI Issue * code optimization for evaluation * added the download button for downloading the info * key name change * Optimized the downloadClickHandler --------- Co-authored-by: vasanthasaikalluri <[email protected]> Co-authored-by: kaustubh-darekar <[email protected]> Co-authored-by: a-s-poorna <[email protected]>
1 parent 3b5fdeb commit 6f90473

File tree

9 files changed

+550
-55
lines changed

9 files changed

+550
-55
lines changed

backend/requirements.txt

+1
Original file line numberDiff line numberDiff line change
@@ -179,4 +179,5 @@ PyMuPDF==1.24.5
179179
pypandoc==1.13
180180
graphdatascience==1.10
181181
Secweb==1.11.0
182+
ragas==0.1.14
182183

backend/score.py

+19-1
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@
3333
from Secweb.XContentTypeOptions import XContentTypeOptions
3434
from Secweb.XFrameOptions import XFrame
3535

36+
from src.ragas_eval import *
37+
3638
logger = CustomLogger()
3739
CHUNK_DIR = os.path.join(os.path.dirname(__file__), "chunks")
3840
MERGED_DIR = os.path.join(os.path.dirname(__file__), "merged_files")
@@ -706,7 +708,23 @@ async def retry_processing(uri=Form(), userName=Form(), password=Form(), databas
706708
logging.exception(f'{error_message}')
707709
return create_api_response(job_status, message=message, error=error_message)
708710
finally:
709-
gc.collect()
711+
gc.collect()
712+
713+
@app.post('/metric')
714+
async def calculate_metric(question=Form(),context=Form(),answer=Form(),model=Form()):
715+
try:
716+
result = await asyncio.to_thread(get_ragas_metrics,question,context,answer,model)
717+
if result is None:
718+
return create_api_response('Failed', message='Failed to calculate metrics.',error="Ragas evaluation returned null")
719+
return create_api_response('Success',data=result,message=f"Status set to Reprocess for filename : {result}")
720+
except Exception as e:
721+
job_status = "Failed"
722+
message="Error while calculating evaluation metrics"
723+
error_message = str(e)
724+
logging.exception(f'{error_message}')
725+
return create_api_response(job_status, message=message, error=error_message)
726+
finally:
727+
gc.collect()
710728

711729
if __name__ == "__main__":
712730
uvicorn.run(app)

backend/src/QA_integration.py

+34-9
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@
2222
from langchain_text_splitters import TokenTextSplitter
2323
from langchain_core.messages import HumanMessage, AIMessage
2424
from langchain.chains import GraphCypherQAChain
25-
from langchain_community.chat_message_histories import ChatMessageHistory
25+
from langchain_community.chat_message_histories import ChatMessageHistory
26+
from langchain_core.callbacks import StdOutCallbackHandler, BaseCallbackHandler
2627

2728
# LangChain chat models
2829
from langchain_openai import ChatOpenAI, AzureChatOpenAI
@@ -38,13 +39,12 @@
3839
from src.shared.common_fn import load_embedding_model
3940
from src.shared.constants import *
4041
from src.graphDB_dataAccess import graphDBdataAccess
42+
from src.ragas_eval import get_ragas_metrics
4143
load_dotenv()
4244

4345
EMBEDDING_MODEL = os.getenv('EMBEDDING_MODEL')
4446
EMBEDDING_FUNCTION , _ = load_embedding_model(EMBEDDING_MODEL)
4547

46-
47-
4848
class SessionChatHistory:
4949
history_dict = {}
5050

@@ -58,6 +58,17 @@ def get_chat_history(cls, session_id):
5858
logging.info(f"Retrieved existing ChatMessageHistory Local for session ID: {session_id}")
5959
return cls.history_dict[session_id]
6060

61+
class CustomCallback(BaseCallbackHandler):
62+
63+
def __init__(self):
64+
self.transformed_question = None
65+
66+
def on_llm_end(
67+
self,response, **kwargs: Any
68+
) -> None:
69+
logging.info("question transformed")
70+
self.transformed_question = response.generations[0][0].text.strip()
71+
6172
def get_history_by_session_id(session_id):
6273
try:
6374
return SessionChatHistory.get_chat_history(session_id)
@@ -250,21 +261,25 @@ def process_documents(docs, question, messages, llm, model,chat_mode_settings):
250261
logging.error(f"Error processing documents: {e}")
251262
raise
252263

253-
return content, result, total_tokens
264+
return content, result, total_tokens, formatted_docs
254265

255266
def retrieve_documents(doc_retriever, messages):
256267

257268
start_time = time.time()
258269
try:
259-
docs = doc_retriever.invoke({"messages": messages})
270+
handler = CustomCallback()
271+
docs = doc_retriever.invoke({"messages": messages},{"callbacks":[handler]})
272+
transformed_question = handler.transformed_question
273+
if transformed_question:
274+
logging.info(f"Transformed question : {transformed_question}")
260275
doc_retrieval_time = time.time() - start_time
261276
logging.info(f"Documents retrieved in {doc_retrieval_time:.2f} seconds")
262277

263278
except Exception as e:
264279
logging.error(f"Error retrieving documents: {e}")
265280
raise
266281

267-
return docs
282+
return docs,transformed_question
268283

269284
def create_document_retriever_chain(llm, retriever):
270285
try:
@@ -408,14 +423,19 @@ def process_chat_response(messages, history, question, model, graph, document_na
408423
try:
409424
llm, doc_retriever, model_version = setup_chat(model, graph, document_names, chat_mode_settings)
410425

411-
docs = retrieve_documents(doc_retriever, messages)
426+
docs,transformed_question = retrieve_documents(doc_retriever, messages)
412427

413428
if docs:
414-
content, result, total_tokens = process_documents(docs, question, messages, llm, model, chat_mode_settings)
429+
content, result, total_tokens,formatted_docs = process_documents(docs, question, messages, llm, model, chat_mode_settings)
415430
else:
416431
content = "I couldn't find any relevant documents to answer your question."
417432
result = {"sources": list(), "nodedetails": list(), "entities": list()}
418433
total_tokens = 0
434+
formatted_docs = ""
435+
436+
question = transformed_question if transformed_question else question
437+
# metrics = get_ragas_metrics(question,formatted_docs,content)
438+
# print(metrics)
419439

420440
ai_response = AIMessage(content=content)
421441
messages.append(ai_response)
@@ -424,19 +444,22 @@ def process_chat_response(messages, history, question, model, graph, document_na
424444
summarization_thread.start()
425445
logging.info("Summarization thread started.")
426446
# summarize_and_log(history, messages, llm)
427-
447+
metric_details = {"question":question,"contexts":formatted_docs,"answer":content}
428448
return {
429449
"session_id": "",
430450
"message": content,
431451
"info": {
452+
# "metrics" : metrics,
432453
"sources": result["sources"],
433454
"model": model_version,
434455
"nodedetails": result["nodedetails"],
435456
"total_tokens": total_tokens,
436457
"response_time": 0,
437458
"mode": chat_mode_settings["mode"],
438459
"entities": result["entities"],
460+
"metric_details": metric_details,
439461
},
462+
440463
"user": "chatbot"
441464
}
442465

@@ -446,13 +469,15 @@ def process_chat_response(messages, history, question, model, graph, document_na
446469
"session_id": "",
447470
"message": "Something went wrong",
448471
"info": {
472+
"metrics" : [],
449473
"sources": [],
450474
"nodedetails": [],
451475
"total_tokens": 0,
452476
"response_time": 0,
453477
"error": f"{type(e).__name__}: {str(e)}",
454478
"mode": chat_mode_settings["mode"],
455479
"entities": [],
480+
"metric_details": {},
456481
},
457482
"user": "chatbot"
458483
}

backend/src/ragas_eval.py

+150
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
import os
2+
import logging
3+
import time
4+
from typing import Dict, Tuple, Optional
5+
import boto3
6+
from datasets import Dataset
7+
from dotenv import load_dotenv
8+
from langchain_anthropic import ChatAnthropic
9+
from langchain_aws import ChatBedrock
10+
from langchain_community.chat_models import ChatOllama
11+
from langchain_experimental.graph_transformers.diffbot import DiffbotGraphTransformer
12+
from langchain_fireworks import ChatFireworks
13+
from langchain_google_vertexai import (
14+
ChatVertexAI,
15+
HarmBlockThreshold,
16+
HarmCategory,
17+
)
18+
from langchain_groq import ChatGroq
19+
from langchain_openai import AzureChatOpenAI, ChatOpenAI
20+
from ragas import evaluate
21+
from ragas.metrics import answer_relevancy, context_utilization, faithfulness
22+
from src.shared.common_fn import load_embedding_model
23+
24+
load_dotenv()
25+
26+
# Constants for clarity and maintainability
27+
RAGAS_MODEL_VERSIONS = {
28+
"openai-gpt-3.5": "gpt-3.5-turbo-16k",
29+
"gemini-1.0-pro": "gemini-1.0-pro-001",
30+
"gemini-1.5-pro": "gemini-1.5-pro-002",
31+
"gemini-1.5-flash": "gemini-1.5-flash-002",
32+
"openai-gpt-4": "gpt-4-turbo-2024-04-09",
33+
"openai-gpt-4o-mini": "gpt-4o-mini-2024-07-18",
34+
"openai-gpt-4o": "gpt-4o-mini-2024-07-18",
35+
"diffbot": "gpt-4-turbo-2024-04-09",
36+
"groq-llama3": "groq_llama3_70b",
37+
}
38+
39+
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL")
40+
EMBEDDING_FUNCTION, _ = load_embedding_model(EMBEDDING_MODEL)
41+
42+
43+
def get_ragas_llm(model: str) -> Tuple[object, str]:
44+
"""Retrieves the specified language model. Improved error handling and structure."""
45+
env_key = f"LLM_MODEL_CONFIG_{model}"
46+
env_value = os.environ.get(env_key)
47+
logging.info(f"Loading model configuration: {env_key}")
48+
try:
49+
if "gemini" in model:
50+
credentials, project_id = google.auth.default()
51+
model_name = RAGAS_MODEL_VERSIONS[model]
52+
llm = ChatVertexAI(
53+
model_name=model_name,
54+
credentials=credentials,
55+
project=project_id,
56+
temperature=0,
57+
safety_settings={
58+
#setting safety to NONE for all categories. Consider reviewing this for production systems
59+
HarmCategory.HARM_CATEGORY_UNSPECIFIED: HarmBlockThreshold.BLOCK_NONE,
60+
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
61+
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
62+
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
63+
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
64+
},
65+
)
66+
elif "openai" in model:
67+
model_name = RAGAS_MODEL_VERSIONS[model]
68+
llm = ChatOpenAI(
69+
api_key=os.environ.get("OPENAI_API_KEY"), model=model_name, temperature=0
70+
)
71+
72+
elif "azure" in model:
73+
model_name, api_endpoint, api_key, api_version = env_value.split(",")
74+
llm = AzureChatOpenAI(
75+
api_key=api_key,
76+
azure_endpoint=api_endpoint,
77+
azure_deployment=model_name,
78+
api_version=api_version,
79+
temperature=0,
80+
)
81+
elif "anthropic" in model:
82+
model_name, api_key = env_value.split(",")
83+
llm = ChatAnthropic(api_key=api_key, model=model_name, temperature=0)
84+
elif "fireworks" in model:
85+
model_name, api_key = env_value.split(",")
86+
llm = ChatFireworks(api_key=api_key, model=model_name)
87+
elif "groq" in model:
88+
model_name, base_url, api_key = env_value.split(",")
89+
llm = ChatGroq(api_key=api_key, model_name=model_name, temperature=0)
90+
elif "bedrock" in model:
91+
model_name, aws_access_key, aws_secret_key, region_name = env_value.split(",")
92+
bedrock_client = boto3.client(
93+
service_name="bedrock-runtime",
94+
region_name=region_name,
95+
aws_access_key_id=aws_access_key,
96+
aws_secret_access_key=aws_secret_key,
97+
)
98+
llm = ChatBedrock(
99+
client=bedrock_client, model_id=model_name, model_kwargs=dict(temperature=0)
100+
)
101+
elif "ollama" in model:
102+
model_name, base_url = env_value.split(",")
103+
llm = ChatOllama(base_url=base_url, model=model_name)
104+
elif "diffbot" in model:
105+
llm = DiffbotGraphTransformer(
106+
diffbot_api_key=os.environ.get("DIFFBOT_API_KEY"),
107+
extract_types=["entities", "facts"],
108+
)
109+
else:
110+
raise ValueError(f"Unsupported model: {model}")
111+
112+
logging.info(f"Model loaded - Model Version: {model}")
113+
return llm, model_name
114+
except (ValueError, KeyError) as e:
115+
logging.error(f"Error loading LLM: {e}")
116+
raise
117+
118+
119+
def get_ragas_metrics(
120+
question: str, context: str, answer: str, model: str
121+
) -> Optional[Dict[str, float]]:
122+
"""Calculates RAGAS metrics."""
123+
try:
124+
start_time = time.time()
125+
dataset = Dataset.from_dict(
126+
{"question": [question], "answer": [answer], "contexts": [[context]]}
127+
)
128+
logging.info("Dataset created successfully.")
129+
130+
llm, model_name = get_ragas_llm(model=model)
131+
logging.info(f"Evaluating with model: {model_name}")
132+
133+
score = evaluate(
134+
dataset=dataset,
135+
metrics=[faithfulness, answer_relevancy, context_utilization],
136+
llm=llm,
137+
embeddings=EMBEDDING_FUNCTION,
138+
)
139+
140+
score_dict = (
141+
score.to_pandas()[["faithfulness", "answer_relevancy", "context_utilization"]]
142+
.round(4)
143+
.to_dict(orient="records")[0]
144+
)
145+
end_time = time.time()
146+
logging.info(f"Evaluation completed in: {end_time - start_time:.2f} seconds")
147+
return score_dict
148+
except Exception as e:
149+
logging.exception(f"Error during metrics evaluation: {e}")
150+
return None

0 commit comments

Comments
 (0)