Commit c4bff62

Adding ShieldGemma 2 notebook to Responsible AI Toolkit docs (#564)
* Adding ShieldGemma 2 notebook to Responsible AI Toolkit docs
* Addressing review feedback
* Fixing linter errors
1 parent 7868875 commit c4bff62

File tree

1 file changed: +243 -0 lines changed

@@ -0,0 +1,243 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "cLCmbOz_5tWH"
},
"source": [
"##### Copyright 2025 Google LLC"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "vdPaBz5y5LHW"
},
"outputs": [],
"source": [
"#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# https://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3Zd1278P5wt_"
},
"source": [
"# Evaluating content safety with ShieldGemma 2 and Hugging Face Transformers"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2b40722aa1a9"
},
"source": [
"<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
" <td>\n",
" <a target=\"_blank\" href=\"https://ai.google.dev/responsible/docs/safeguards/shieldgemma2_on_huggingface\"><img src=\"https://ai.google.dev/static/site-assets/images/docs/notebook-site-button.png\" height=\"32\" width=\"32\" />View on ai.google.dev</a>\n",
" </td>\n",
" <td>\n",
" <a target=\"_blank\" href=\"https://colab.research.google.com/github/google/generative-ai-docs/blob/main/site/en/responsible/docs/safeguards/shieldgemma2_on_huggingface.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
" </td>\n",
" <td>\n",
" <a target=\"_blank\" href=\"https://github.com/google/generative-ai-docs/blob/main/site/en/responsible/docs/safeguards/shieldgemma2_on_huggingface.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
" </td>\n",
"</table>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4IlgEYUj7xdW"
},
"source": [
"The **ShieldGemma 2** model is trained to detect the key harms detailed in the [model card](https://ai.google.dev/gemma/docs/shieldgemma/model_card_2). This guide demonstrates how to use ShieldGemma 2 with Hugging Face Transformers to check images for content safety policy violations.\n",
"\n",
"Note that `ShieldGemma 2` is trained to classify only one harm type at a time, so you will need to make a separate call to `ShieldGemma 2` for each harm type you want to check against. If you have additional harm types to cover, you can use model tuning techniques to adapt `ShieldGemma 2` to your own policies."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RhlnMQoK9fZG"
},
"source": [
"# Supported safety checks\n",
"\n",
"**ShieldGemma 2** is built on Gemma 3's 4B IT checkpoint and is trained to detect and predict violations of the key harm types listed below:\n",
"\n",
"* **Dangerous Content**: The image shall not contain content that facilitates or encourages activities that could cause real-world harm (e.g., building firearms and explosive devices, promotion of terrorism, instructions for suicide).\n",
"\n",
"* **Sexually Explicit**: The image shall not contain content that depicts explicit or graphic sexual acts (e.g., pornography, erotic nudity, depictions of rape or sexual assault).\n",
"\n",
"* **Violence/Gore**: The image shall not contain content that depicts shocking, sensational, or gratuitous violence (e.g., excessive blood and gore, gratuitous violence against animals, extreme injury or moment of death).\n",
"\n",
"These policies serve as a foundation, but you can also provide customized safety policies as input to the model, allowing for fine-grained control over specific use-case requirements; a sketch at the end of the usage example below shows how to pass a custom policy."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "t3aq-ToeAmRM"
},
"source": [
"# Supported Use Case\n",
"\n",
"ShieldGemma 2 should be used as an input filter to vision language models, as an output filter of image generation systems, or both. ShieldGemma 2 offers the following key advantages:\n",
"\n",
"* **Policy-Aware Classification**: ShieldGemma 2 accepts both a user-defined safety policy and an image as input, providing classifications for both real and generated images, tailored to the specific policy guidelines.\n",
"* **Probability-Based Output and Thresholding**: ShieldGemma 2 outputs a probability score for its predictions, allowing downstream users to flexibly tune the classification threshold based on their specific use cases and risk tolerance. This enables a more nuanced and adaptable approach to safety classification; a short sketch after this section demonstrates the thresholding pattern.\n",
"\n",
"The input and output formats are as follows:\n",
"* **Input**: Image + prompt instruction with policy definition\n",
"* **Output**: Probability of the 'Yes'/'No' tokens, where 'Yes' means that the image violates the specified policy. The higher the score for the 'Yes' token, the higher the model's confidence that the image violates the specified policy."
]
},
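{
"cell_type": "markdown",
"metadata": {},
"source": [
"The following cell is a minimal, self-contained sketch of the thresholding pattern described above. The probability values are hypothetical stand-ins for real model output (the usage example below shows how to obtain actual scores), and the 0.5 threshold is an illustrative default rather than a recommended setting."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"# Hypothetical probabilities for three (image, policy) pairs; in real use these\n",
"# come from `scores.probabilities` in the usage example below. Column 0 holds\n",
"# the probability of the 'Yes' (violation) token, column 1 the 'No' token.\n",
"probabilities = torch.tensor([\n",
"    [0.92, 0.08],\n",
"    [0.15, 0.85],\n",
"    [0.51, 0.49],\n",
"])\n",
"\n",
"# Tune the threshold for your own use case and risk tolerance: lower it to\n",
"# catch more potential violations at the cost of more false positives.\n",
"threshold = 0.5\n",
"\n",
"violated = probabilities[:, 0] > threshold\n",
"print(violated)  # tensor([ True, False,  True])"
]
},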
{
"cell_type": "markdown",
"metadata": {
"id": "0WhRozADVJos"
},
"source": [
"# Usage example"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "K_XERopLUZhk"
},
"outputs": [],
"source": [
"# @title Install Hugging Face Transformers v4.50+\n",
"! pip install -q 'transformers>=4.50.0'"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Qg-Hy0ffbwvE"
},
"outputs": [],
"source": [
"# @title Authenticate with Hugging Face Hub\n",
"# @markdown ShieldGemma is a gated model. To access the weights, you must accept\n",
"# @markdown the license on Hugging Face Hub under your account and then provide\n",
"# @markdown an [Access Token](https://huggingface.co/docs/hub/en/security-tokens)\n",
"# @markdown to authenticate with the Hugging Face Hub API. If using Colab, the\n",
"# @markdown easiest way to do this is by creating a read-only token specifically\n",
"# @markdown for Colab and setting it as the value of the `HF_TOKEN` secret;\n",
"# @markdown this token will then be reusable across all Colab notebooks. Other\n",
"# @markdown Python notebook platforms may provide a similar mechanism. For those\n",
"# @markdown that do not, un-comment the lines in this cell to install the\n",
"# @markdown Hugging Face Hub CLI and log in interactively.\n",
"# ! pip install -q 'huggingface_hub[cli]'\n",
"# ! huggingface-cli login"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "40Rm46Xt7wqW"
},
"outputs": [],
"source": [
"from transformers import AutoProcessor, AutoModelForImageClassification\n",
"import torch\n",
"\n",
"model_id = \"google/shieldgemma-2-4b-it\"\n",
"\n",
"processor = AutoProcessor.from_pretrained(model_id)\n",
"model = AutoModelForImageClassification.from_pretrained(model_id)\n",
"model.to(torch.device(\"cuda\"))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "a436de5a4e95"
},
"outputs": [],
"source": [
"from PIL import Image\n",
"import requests\n",
"\n",
"# The image included in this Colab is benign and will not violate any of\n",
"# ShieldGemma's built-in content policies. Change this URL or otherwise update\n",
"# this code to use an image that may be violative.\n",
"url = \"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg\"\n",
"image = Image.open(requests.get(url, stream=True).raw)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "AK1PrHnYz4fv"
},
"outputs": [],
"source": [
"inputs = processor(images=[image], return_tensors=\"pt\").to(torch.device(\"cuda\"))\n",
"\n",
"with torch.no_grad():\n",
"    scores = model(**inputs)\n",
"\n",
"# `scores` is a `ShieldGemma2ImageClassifierOutputWithNoAttention` instance\n",
"# containing the logits and probabilities associated with the model predicting\n",
"# the `Yes` or `No` tokens as the response to the prompt batch, captured in the\n",
"# following properties.\n",
"#\n",
"# * `logits` (`torch.Tensor` of shape `(batch_size, 2)`): The first position\n",
"#   along dim=1 contains the logits for the `Yes` token and the second\n",
"#   position contains the logits for the `No` token.\n",
"# * `probabilities` (`torch.Tensor` of shape `(batch_size, 2)`): The first\n",
"#   position along dim=1 is the probability of predicting the `Yes` token\n",
"#   and the second position is the probability of predicting the `No` token.\n",
"#\n",
"# When used with the `ShieldGemma2Processor`, the `batch_size` will be equal to\n",
"# `len(images) * len(policies)`, and the order within the batch will be\n",
"# img1_policy1, ... img1_policyN, ... imgM_policyN.\n",
"print(scores.logits)\n",
"print(scores.probabilities)\n",
"\n",
"# ShieldGemma prompts are constructed such that predicting the `Yes` token means\n",
"# the content violates the policy. If you are only interested in the violative\n",
"# condition, you can extract only that slice from the output tensors.\n",
"p_violated = scores.probabilities[:, 0]\n",
"print(p_violated)\n"
]
},
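{
"cell_type": "markdown",
"metadata": {},
"source": [
"As noted in the safety checks section above, you can also supply your own policy definitions. The following cell is a minimal sketch of that path, reusing the `processor`, `model`, and `image` objects defined above. It assumes the `custom_policies` and `policies` arguments of `ShieldGemma2Processor` in Transformers v4.50+ and the built-in policy key `dangerous`; the `no_medical_advice` policy name and its definition are hypothetical examples, not part of ShieldGemma 2's built-in policies."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: check one image against a built-in policy and a hypothetical custom\n",
"# policy in a single batch. `custom_policies` maps a policy name to its\n",
"# natural-language definition; `policies` selects which definitions to run.\n",
"custom_policies = {\n",
"    \"no_medical_advice\": (\n",
"        \"No Medical Advice: The image shall not contain content that\"\n",
"        \" provides medical diagnoses or treatment instructions.\"\n",
"    ),\n",
"}\n",
"\n",
"inputs = processor(\n",
"    images=[image],\n",
"    custom_policies=custom_policies,\n",
"    policies=[\"dangerous\", \"no_medical_advice\"],\n",
"    return_tensors=\"pt\",\n",
").to(torch.device(\"cuda\"))\n",
"\n",
"with torch.no_grad():\n",
"    scores = model(**inputs)\n",
"\n",
"# With one image, rows follow the order of the `policies` list\n",
"# (img1_policy1, img1_policy2, ...); apply the thresholding pattern from above.\n",
"for policy, p_yes in zip([\"dangerous\", \"no_medical_advice\"],\n",
"                         scores.probabilities[:, 0]):\n",
"    print(f\"{policy}: P(violation) = {p_yes.item():.3f}\")"
]
}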
],
"metadata": {
"accelerator": "GPU",
"colab": {
"name": "shieldgemma2_on_huggingface.ipynb",
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
