1
+ import os
2
+ import logging
3
+ import time
4
+ from typing import Dict , Tuple , Optional
5
+ import boto3
6
+ from datasets import Dataset
7
+ from dotenv import load_dotenv
8
+ from langchain_anthropic import ChatAnthropic
9
+ from langchain_aws import ChatBedrock
10
+ from langchain_community .chat_models import ChatOllama
11
+ from langchain_experimental .graph_transformers .diffbot import DiffbotGraphTransformer
12
+ from langchain_fireworks import ChatFireworks
13
+ from langchain_google_vertexai import (
14
+ ChatVertexAI ,
15
+ HarmBlockThreshold ,
16
+ HarmCategory ,
17
+ )
18
+ from langchain_groq import ChatGroq
19
+ from langchain_openai import AzureChatOpenAI , ChatOpenAI
20
+ from ragas import evaluate
21
+ from ragas .metrics import answer_relevancy , context_utilization , faithfulness
22
+ from src .shared .common_fn import load_embedding_model
23
+
24
+ load_dotenv ()
25
+
26
+ # Constants for clarity and maintainability
27
+ RAGAS_MODEL_VERSIONS = {
28
+ "openai-gpt-3.5" : "gpt-3.5-turbo-16k" ,
29
+ "gemini-1.0-pro" : "gemini-1.0-pro-001" ,
30
+ "gemini-1.5-pro" : "gemini-1.5-pro-002" ,
31
+ "gemini-1.5-flash" : "gemini-1.5-flash-002" ,
32
+ "openai-gpt-4" : "gpt-4-turbo-2024-04-09" ,
33
+ "openai-gpt-4o-mini" : "gpt-4o-mini-2024-07-18" ,
34
+ "openai-gpt-4o" : "gpt-4o-mini-2024-07-18" ,
35
+ "diffbot" : "gpt-4-turbo-2024-04-09" ,
36
+ "groq-llama3" : "groq_llama3_70b" ,
37
+ }
38
+
39
+ EMBEDDING_MODEL = os .getenv ("EMBEDDING_MODEL" )
40
+ EMBEDDING_FUNCTION , _ = load_embedding_model (EMBEDDING_MODEL )
41
+
42
+
43
def get_ragas_llm(model: str) -> Tuple[object, str]:
    """Instantiate the evaluation LLM for the given model identifier.

    The provider is chosen by substring match on ``model`` (e.g. "gemini",
    "openai", "azure", ...).  Credentials/endpoints are read either from
    dedicated env vars (OpenAI, Diffbot) or from a comma-separated
    ``LLM_MODEL_CONFIG_<model>`` env var.

    Args:
        model: Application-level model identifier (e.g. "openai-gpt-4o").

    Returns:
        Tuple of ``(llm instance, resolved provider model name)``.

    Raises:
        ValueError: Unsupported model, or the required LLM_MODEL_CONFIG_* env
            var is missing or malformed.
        KeyError: ``model`` has no entry in RAGAS_MODEL_VERSIONS (gemini /
            openai / diffbot paths).
    """
    env_key = f"LLM_MODEL_CONFIG_{model}"
    env_value = os.environ.get(env_key)
    logging.info(f"Loading model configuration: {env_key}")

    def _config_parts(expected: int) -> list:
        # FIX: previously a missing env var led to `None.split(",")` ->
        # AttributeError, which bypassed the (ValueError, KeyError) handler.
        if env_value is None:
            raise ValueError(f"Environment variable '{env_key}' is not set")
        parts = [part.strip() for part in env_value.split(",")]
        if len(parts) != expected:
            raise ValueError(
                f"'{env_key}' must contain {expected} comma-separated values, "
                f"got {len(parts)}"
            )
        return parts

    try:
        if "gemini" in model:
            # FIX: google.auth was used but never imported (NameError at
            # runtime); import locally so only the gemini path needs it.
            import google.auth

            credentials, project_id = google.auth.default()
            model_name = RAGAS_MODEL_VERSIONS[model]
            llm = ChatVertexAI(
                model_name=model_name,
                credentials=credentials,
                project=project_id,
                temperature=0,
                safety_settings={
                    # Safety filtering disabled for all categories; review
                    # this before using in production systems.
                    HarmCategory.HARM_CATEGORY_UNSPECIFIED: HarmBlockThreshold.BLOCK_NONE,
                    HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
                    HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
                    HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
                    HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
                },
            )
        elif "openai" in model:
            model_name = RAGAS_MODEL_VERSIONS[model]
            llm = ChatOpenAI(
                api_key=os.environ.get("OPENAI_API_KEY"), model=model_name, temperature=0
            )
        elif "azure" in model:
            model_name, api_endpoint, api_key, api_version = _config_parts(4)
            llm = AzureChatOpenAI(
                api_key=api_key,
                azure_endpoint=api_endpoint,
                azure_deployment=model_name,
                api_version=api_version,
                temperature=0,
            )
        elif "anthropic" in model:
            model_name, api_key = _config_parts(2)
            llm = ChatAnthropic(api_key=api_key, model=model_name, temperature=0)
        elif "fireworks" in model:
            model_name, api_key = _config_parts(2)
            llm = ChatFireworks(api_key=api_key, model=model_name)
        elif "groq" in model:
            # base_url is parsed to validate the config format but ChatGroq
            # is not given it (matches original behavior).
            model_name, _base_url, api_key = _config_parts(3)
            llm = ChatGroq(api_key=api_key, model_name=model_name, temperature=0)
        elif "bedrock" in model:
            model_name, aws_access_key, aws_secret_key, region_name = _config_parts(4)
            bedrock_client = boto3.client(
                service_name="bedrock-runtime",
                region_name=region_name,
                aws_access_key_id=aws_access_key,
                aws_secret_access_key=aws_secret_key,
            )
            llm = ChatBedrock(
                client=bedrock_client, model_id=model_name, model_kwargs={"temperature": 0}
            )
        elif "ollama" in model:
            model_name, base_url = _config_parts(2)
            llm = ChatOllama(base_url=base_url, model=model_name)
        elif "diffbot" in model:
            # FIX: model_name was never assigned on this branch, so the
            # `return llm, model_name` below raised NameError for diffbot.
            model_name = RAGAS_MODEL_VERSIONS[model]
            llm = DiffbotGraphTransformer(
                diffbot_api_key=os.environ.get("DIFFBOT_API_KEY"),
                extract_types=["entities", "facts"],
            )
        else:
            raise ValueError(f"Unsupported model: {model}")

        logging.info(f"Model loaded - Model Version: {model}")
        return llm, model_name
    except (ValueError, KeyError) as e:
        logging.error(f"Error loading LLM: {e}")
        raise
117
+
118
+
119
def get_ragas_metrics(
    question: str, context: str, answer: str, model: str
) -> Optional[Dict[str, float]]:
    """Evaluate a single QA triple with RAGAS.

    Builds a one-row dataset from the question/context/answer, runs the
    faithfulness, answer-relevancy and context-utilization metrics with the
    LLM resolved for ``model``, and returns the scores rounded to 4 decimals.
    Returns None if any step fails (the error is logged, never raised).
    """
    try:
        started = time.time()

        eval_dataset = Dataset.from_dict(
            {"question": [question], "answer": [answer], "contexts": [[context]]}
        )
        logging.info("Dataset created successfully.")

        llm, model_name = get_ragas_llm(model=model)
        logging.info(f"Evaluating with model: {model_name}")

        result = evaluate(
            dataset=eval_dataset,
            metrics=[faithfulness, answer_relevancy, context_utilization],
            llm=llm,
            embeddings=EMBEDDING_FUNCTION,
        )

        # One input row -> one result record; extract and round its scores.
        metric_columns = ["faithfulness", "answer_relevancy", "context_utilization"]
        frame = result.to_pandas()[metric_columns].round(4)
        score_dict = frame.to_dict(orient="records")[0]

        elapsed = time.time() - started
        logging.info(f"Evaluation completed in: {elapsed:.2f} seconds")
        return score_dict
    except Exception as e:
        logging.exception(f"Error during metrics evaluation: {e}")
        return None
0 commit comments