Skip to content

Commit 3032969

Browse files
authored
Google suts (#559)
* basic google sut * Add google-specific request class * add google-specific responses * add more suts * undo accidental commit * sut sets response text to REFUSAL_RESPONSE constant when stopped early for safety * forgot to make content optional * 2 diff suts: default and disabled safety settings
1 parent 12dff34 commit 3032969

File tree

7 files changed

+582
-4
lines changed

7 files changed

+582
-4
lines changed

plugins/google/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Plugin for interacting with the Google API.
Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
import google.generativeai as genai # type: ignore
2+
from abc import abstractmethod
3+
from google.generativeai.types import HarmCategory, HarmBlockThreshold # type: ignore
4+
from pydantic import BaseModel
5+
from typing import Dict, List, Optional
6+
7+
from modelgauge.general import APIException
8+
from modelgauge.prompt import TextPrompt
9+
from modelgauge.secret_values import InjectSecret, RequiredSecret, SecretDescription
10+
from modelgauge.sut import REFUSAL_RESPONSE, PromptResponseSUT, SUTCompletion, SUTResponse
11+
from modelgauge.sut_capabilities import AcceptsTextPrompt
12+
from modelgauge.sut_decorator import modelgauge_sut
13+
from modelgauge.sut_registry import SUTS
14+
15+
# Short alias for the protobuf enum describing why generation stopped.
FinishReason = genai.protos.Candidate.FinishReason

# Harm categories used when building Gemini safety settings.
# NOTE(review): HarmCategory defines additional values; presumably these four
# are the ones Gemini models accept — confirm against the genai docs.
GEMINI_HARM_CATEGORIES = [
    HarmCategory.HARM_CATEGORY_HATE_SPEECH,
    HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
    HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
    HarmCategory.HARM_CATEGORY_HARASSMENT,
]
22+
23+
24+
class GoogleAiApiKey(RequiredSecret):
    """Required secret holding an API key for Google AI Studio."""

    @classmethod
    def description(cls) -> SecretDescription:
        """Describe where this secret is scoped and how to obtain one."""
        how_to_get_one = "See https://aistudio.google.com/app/apikey"
        return SecretDescription(scope="google_ai", key="api_key", instructions=how_to_get_one)
32+
33+
34+
class GoogleGenAiConfig(BaseModel):
    """Generation config for Google Gen AI requests.

    Mirrors https://ai.google.dev/api/generate-content#v1beta.GenerationConfig.
    Every field is optional; None fields are excluded from the outgoing request.
    """

    stop_sequences: list[str] | None = None
    max_output_tokens: int | None = None
    temperature: float | None = None
    top_p: float | None = None
    top_k: int | None = None
    presence_penalty: float | None = None
    frequency_penalty: float | None = None
47+
48+
49+
class GoogleGenAiRequest(BaseModel):
    """Payload for a generate_content call.

    A safety_settings of None leaves the model's default safety behavior in place.
    """

    contents: str
    generation_config: GoogleGenAiConfig
    safety_settings: dict[HarmCategory, HarmBlockThreshold] | None = None
53+
54+
55+
class GoogleGenAiResponse(BaseModel):
    """The subset of a generate_content response that this plugin consumes."""

    class Candidate(BaseModel):
        # content can be absent, e.g. when generation stopped early.
        content: dict | None = None
        finish_reason: int

    candidates: list[Candidate]
    usage_metadata: dict
62+
63+
64+
class GoogleGenAiBaseSUT(PromptResponseSUT[GoogleGenAiRequest, GoogleGenAiResponse]):
    """Base SUT for Google Generative AI models.

    Subclasses decide which finish reasons count as refusals and which
    safety settings (if any) accompany each request.
    """

    def __init__(self, uid: str, model_name: str, api_key: GoogleAiApiKey):
        super().__init__(uid)
        self.model_name = model_name
        # Client is created lazily in evaluate() so constructing the SUT
        # (e.g. during registration) needs no network access.
        self.model: Optional[genai.GenerativeModel] = None
        # NOTE(review): genai.configure sets a process-wide API key; the last
        # SUT constructed wins if different keys were ever injected — confirm.
        genai.configure(api_key=api_key.value)

    @property
    @abstractmethod
    def refusal_finish_reasons(self) -> List:
        """List of finish reasons that should be treated as refusal responses."""
        pass

    @property
    @abstractmethod
    def safety_settings(self) -> Optional[Dict[HarmCategory, HarmBlockThreshold]]:
        """Safety settings that are included in SUT requests, or None for defaults."""
        pass

    def _load_client(self) -> genai.GenerativeModel:
        """Create the generative model client for this SUT's model name."""
        return genai.GenerativeModel(self.model_name)

    def translate_text_prompt(self, prompt: TextPrompt) -> GoogleGenAiRequest:
        """Map a TextPrompt and its options onto a Google Gen AI request."""
        generation_config = GoogleGenAiConfig(
            stop_sequences=prompt.options.stop_sequences,
            max_output_tokens=prompt.options.max_tokens,
            temperature=prompt.options.temperature,
            top_p=prompt.options.top_p,
            top_k=prompt.options.top_k_per_token,
            presence_penalty=prompt.options.presence_penalty,
            frequency_penalty=prompt.options.frequency_penalty,
        )
        return GoogleGenAiRequest(
            contents=prompt.text, generation_config=generation_config, safety_settings=self.safety_settings
        )

    def evaluate(self, request: GoogleGenAiRequest) -> GoogleGenAiResponse:
        """Send the request to the model, lazily creating the client first."""
        if self.model is None:
            # Handle lazy init.
            self.model = self._load_client()
        # exclude_none keeps unset generation options and safety settings
        # out of the API call entirely.
        response = self.model.generate_content(**request.model_dump(exclude_none=True))
        # Convert to pydantic model
        return GoogleGenAiResponse(**response.to_dict())

    def translate_response(self, request: GoogleGenAiRequest, response: GoogleGenAiResponse) -> SUTResponse:
        """Turn candidates into SUTCompletions, mapping refusal finish reasons to REFUSAL_RESPONSE.

        Raises:
            APIException: if a candidate has no content but its finish reason
                is not one the subclass treats as a refusal.
        """
        completions = []
        for candidate in response.candidates:
            if candidate.finish_reason in self.refusal_finish_reasons:
                completions.append(SUTCompletion(text=REFUSAL_RESPONSE))
            elif candidate.content is not None:
                completions.append(SUTCompletion(text=candidate.content["parts"][0]["text"]))
            else:
                # Fixed grammar in the error message ("it's" -> "its").
                raise APIException(
                    f"Unexpected candidate in response from GoogleGenAiSUT {self.uid}: {candidate}. "
                    f"The candidate does not have any content,"
                    f" but its finish reason {candidate.finish_reason} does not qualify as a refusal."
                )
        return SUTResponse(completions=completions)
122+
123+
124+
@modelgauge_sut(capabilities=[AcceptsTextPrompt])
class GoogleGenAiDefaultSUT(GoogleGenAiBaseSUT):
    """SUT for a Google Generative AI model using the model's default safety settings.

    Any safety-related finish reason is reported as a refusal response."""

    @property
    def refusal_finish_reasons(self):
        # Taken from https://cloud.google.com/vertex-ai/docs/reference/rest/v1/GenerateContentResponse#finishreason
        blocked_reasons = (
            FinishReason.SAFETY,
            FinishReason.RECITATION,
            FinishReason.BLOCKLIST,
            FinishReason.PROHIBITED_CONTENT,
            FinishReason.SPII,
            FinishReason.OTHER,
        )
        return list(blocked_reasons)

    @property
    def safety_settings(self) -> Optional[Dict[HarmCategory, HarmBlockThreshold]]:
        # None leaves the API's default harm-blocking thresholds in effect.
        return None
145+
146+
147+
@modelgauge_sut(capabilities=[AcceptsTextPrompt])
class GoogleGeminiDisabledSafetySettingsSUT(GoogleGenAiBaseSUT):
    """SUT for a Google Gemini model with the harm block threshold removed for all Gemini-specific harm categories."""

    @property
    def refusal_finish_reasons(self):
        """Model should not refuse for any safety-related reasons."""
        return []

    @property
    def safety_settings(self) -> Optional[Dict[HarmCategory, HarmBlockThreshold]]:
        # Explicitly disable blocking for every harm category Gemini supports.
        return dict.fromkeys(GEMINI_HARM_CATEGORIES, HarmBlockThreshold.BLOCK_NONE)
159+
160+
161+
# Register two SUT variants per Gemini model: one using the API's default
# safety settings, and one ("-disabled_safety") with blocking turned off
# for all Gemini harm categories.
gemini_models = ["gemini-1.5-flash", "gemini-1.0-pro", "gemini-1.5-pro"]
for model in gemini_models:
    SUTS.register(GoogleGenAiDefaultSUT, model, model, InjectSecret(GoogleAiApiKey))
    SUTS.register(
        GoogleGeminiDisabledSafetySettingsSUT, f"{model}-disabled_safety", model, InjectSecret(GoogleAiApiKey)
    )

plugins/google/pyproject.toml

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# Poetry packaging metadata for the modelgauge Google plugin.
[tool.poetry]
name = "modelgauge-google"
version = "0.6.3"
description = ""
authors = ["MLCommons AI Safety <[email protected]>"]
readme = "README.md"
packages = [{include = "modelgauge"}]

[tool.poetry.dependencies]
python = "^3.10"
# Google Generative AI SDK used by this plugin's SUTs.
google-generativeai = "^0.8.0"

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

0 commit comments

Comments
 (0)