Skip to content

Commit 34430b9

Browse files
Handle GIFs correct in gr.Image preprocessing (#8589)
* handle gifs correct in image preprocessing * add changeset * fix * add test * add test * docstring * add docs * image * revert * change * add changeset --------- Co-authored-by: gradio-pr-bot <[email protected]>
1 parent 797621b commit 34430b9

File tree

6 files changed

+61
-47
lines changed

6 files changed

+61
-47
lines changed

.changeset/fluffy-crabs-sleep.md

+6
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
---
2+
"gradio": patch
3+
"website": patch
4+
---
5+
6+
fix:Handle GIFs correct in `gr.Image` preprocessing

gradio/components/image.py

+7-6
Original file line numberDiff line numberDiff line change
@@ -71,12 +71,12 @@ def __init__(
7171
"""
7272
Parameters:
7373
value: A PIL Image, numpy array, path or URL for the default value that Image component is going to take. If callable, the function will be called whenever the app loads to set the initial value of the component.
74-
format: Format to save image if it does not already have a valid format (e.g. if the image is being returned to the frontend as a numpy array or PIL Image). The format should be supported by the PIL library. This parameter has no effect on SVG files.
74+
format: File format (e.g. "png" or "gif") to save image if it does not already have a valid format (e.g. if the image is being returned to the frontend as a numpy array or PIL Image). The format should be supported by the PIL library. This parameter has no effect on SVG files.
7575
height: The height of the displayed image, specified in pixels if a number is passed, or in CSS units if a string is passed.
7676
width: The width of the displayed image, specified in pixels if a number is passed, or in CSS units if a string is passed.
77-
image_mode: "RGB" if color, or "L" if black and white. See https://pillow.readthedocs.io/en/stable/handbook/concepts.html for other supported image modes and their meaning.
77+
image_mode: "RGB" if color, or "L" if black and white. See https://pillow.readthedocs.io/en/stable/handbook/concepts.html for other supported image modes and their meaning. This parameter has no effect on SVG or GIF files.
7878
sources: List of sources for the image. "upload" creates a box where user can drop an image file, "webcam" allows user to take snapshot from their webcam, "clipboard" allows users to paste an image from the clipboard. If None, defaults to ["upload", "webcam", "clipboard"] if streaming is False, otherwise defaults to ["webcam"].
79-
type: The format the image is converted before being passed into the prediction function. "numpy" converts the image to a numpy array with shape (height, width, 3) and values from 0 to 255, "pil" converts the image to a PIL image object, "filepath" passes a str path to a temporary file containing the image. If the image is SVG, the `type` is ignored and the filepath of the SVG is returned.
79+
type: The format the image is converted before being passed into the prediction function. "numpy" converts the image to a numpy array with shape (height, width, 3) and values from 0 to 255, "pil" converts the image to a PIL image object, "filepath" passes a str path to a temporary file containing the image. If the image is SVG, the `type` is ignored and the filepath of the SVG is returned. To support animated GIFs in input, the `type` should be set to "filepath" or "pil".
8080
label: The label for this component. Appears above the component and is also used as the header if there are a table of examples for this component. If None and used in a `gr.Interface`, the label will be the name of the parameter this component is assigned to.
8181
every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute.
8282
show_label: if True, will display label.
@@ -181,9 +181,10 @@ def preprocess(
181181
warnings.warn(
182182
f"Failed to transpose image {file_path} based on EXIF data."
183183
)
184-
with warnings.catch_warnings():
185-
warnings.simplefilter("ignore")
186-
im = im.convert(self.image_mode)
184+
if suffix.lower() != "gif" and im is not None:
185+
with warnings.catch_warnings():
186+
warnings.simplefilter("ignore")
187+
im = im.convert(self.image_mode)
187188
return image_utils.format_image(
188189
im,
189190
cast(Literal["numpy", "pil", "filepath"], self.type),

gradio/processing_utils.py

+17-26
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import httpx
1818
import numpy as np
1919
from gradio_client import utils as client_utils
20-
from PIL import Image, ImageOps, PngImagePlugin
20+
from PIL import Image, ImageOps, ImageSequence, PngImagePlugin
2121

2222
from gradio import utils, wasm_utils
2323
from gradio.data_classes import FileData, GradioModel, GradioRootModel, JsonData
@@ -138,7 +138,7 @@ def encode_plot_to_base64(plt, format: str = "png"):
138138
plt.savefig(output_bytes, format=fmt)
139139
bytes_data = output_bytes.getvalue()
140140
base64_str = str(base64.b64encode(bytes_data), "utf-8")
141-
return output_base64(base64_str, fmt)
141+
return f"data:image/{format or 'png'};base64,{base64_str}"
142142

143143

144144
def get_pil_exif_bytes(pil_image):
@@ -158,34 +158,25 @@ def get_pil_metadata(pil_image):
158158

159159
def encode_pil_to_bytes(pil_image, format="png"):
160160
with BytesIO() as output_bytes:
161-
if format == "png":
162-
params = {"pnginfo": get_pil_metadata(pil_image)}
161+
if format.lower() == "gif":
162+
frames = [frame.copy() for frame in ImageSequence.Iterator(pil_image)]
163+
frames[0].save(
164+
output_bytes,
165+
format=format,
166+
save_all=True,
167+
append_images=frames[1:],
168+
loop=0,
169+
)
163170
else:
164-
exif = get_pil_exif_bytes(pil_image)
165-
params = {"exif": exif} if exif else {}
166-
pil_image.save(output_bytes, format, **params)
171+
if format.lower() == "png":
172+
params = {"pnginfo": get_pil_metadata(pil_image)}
173+
else:
174+
exif = get_pil_exif_bytes(pil_image)
175+
params = {"exif": exif} if exif else {}
176+
pil_image.save(output_bytes, format, **params)
167177
return output_bytes.getvalue()
168178

169179

170-
def encode_pil_to_base64(pil_image, format="png"):
171-
bytes_data = encode_pil_to_bytes(pil_image, format)
172-
base64_str = str(base64.b64encode(bytes_data), "utf-8")
173-
return output_base64(base64_str, format)
174-
175-
176-
def encode_array_to_base64(image_array, format="png"):
177-
with BytesIO() as output_bytes:
178-
pil_image = Image.fromarray(_convert(image_array, np.uint8, force_copy=False))
179-
pil_image.save(output_bytes, format)
180-
bytes_data = output_bytes.getvalue()
181-
base64_str = str(base64.b64encode(bytes_data), "utf-8")
182-
return output_base64(base64_str, format)
183-
184-
185-
def output_base64(data, format=None) -> str:
186-
return f"data:image/{format or 'png'};base64,{data}"
187-
188-
189180
def hash_file(file_path: str | Path, chunk_num_blocks: int = 128) -> str:
190181
sha1 = hashlib.sha1()
191182
with open(file_path, "rb") as f:

gradio/test_data/rectangles.gif

1.66 KB
Loading

js/_website/src/lib/templates/gradio/03_components/image.svx

+23
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,29 @@ def predict(···) -> np.ndarray | PIL.Image.Image | str | Path | None
6969
<ShortcutTable shortcuts={obj.string_shortcuts} />
7070
{/if}
7171

72+
73+
### `GIF` and `SVG` Image Formats
74+
75+
The `gr.Image` component can process or display any image format that is [supported by the PIL library](https://pillow.readthedocs.io/en/stable/handbook/image-file-formats.html), including animated GIFs. In addition, it also supports the SVG image format.
76+
77+
When the `gr.Image` component is used as an input component, the image is converted into a `str` filepath, a `PIL.Image` object, or a `numpy.array`, depending on the `type` parameter. However, animated GIF and SVG images are treated differently:
78+
79+
* Animated `GIF` images can only be converted to `str` filepaths or `PIL.Image` objects. If they are converted to a `numpy.array` (which is the default behavior), only the first frame will be used. So if your demo expects an input `GIF` image, make sure to set the `type` parameter accordingly, e.g.
80+
81+
```py
82+
import gradio as gr
83+
84+
demo = gr.Interface(
85+
fn=lambda x:x,
86+
inputs=gr.Image(type="filepath"),
87+
outputs=gr.Image()
88+
)
89+
90+
demo.launch()
91+
```
92+
93+
* For `SVG` images, the `type` parameter is ignored altogether and the image is always returned as an image filepath. This is because `SVG` images cannot be processed as `PIL.Image` or `numpy.array` objects.
94+
7295
{#if obj.demos && obj.demos.length > 0}
7396
<!--- Demos -->
7497
### Demos

test/test_processing_utils.py

+8-15
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import os
22
import shutil
33
import tempfile
4-
from copy import deepcopy
54
from pathlib import Path
65
from unittest.mock import patch
76

@@ -114,20 +113,6 @@ def test_encode_plot_to_base64(self):
114113
"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAo"
115114
)
116115

117-
def test_encode_array_to_base64(self):
118-
img = Image.open("gradio/test_data/test_image.png")
119-
img = img.convert("RGB")
120-
numpy_data = np.asarray(img, dtype=np.uint8)
121-
output_base64 = processing_utils.encode_array_to_base64(numpy_data)
122-
assert output_base64 == deepcopy(media_data.ARRAY_TO_BASE64_IMAGE)
123-
124-
def test_encode_pil_to_base64(self):
125-
img = Image.open("gradio/test_data/test_image.png")
126-
img = img.convert("RGB")
127-
img.info = {} # Strip metadata
128-
output_base64 = processing_utils.encode_pil_to_base64(img)
129-
assert output_base64 == deepcopy(media_data.ARRAY_TO_BASE64_IMAGE)
130-
131116
def test_save_pil_to_file_keeps_pnginfo(self, gradio_temp_dir):
132117
input_img = Image.open("gradio/test_data/test_image.png")
133118
input_img = input_img.convert("RGB")
@@ -141,6 +126,14 @@ def test_save_pil_to_file_keeps_pnginfo(self, gradio_temp_dir):
141126

142127
assert output_img.info == input_img.info
143128

129+
def test_save_pil_to_file_keeps_all_gif_frames(self, gradio_temp_dir):
130+
input_img = Image.open("gradio/test_data/rectangles.gif")
131+
file_obj = processing_utils.save_pil_to_cache(
132+
input_img, cache_dir=gradio_temp_dir, format="gif"
133+
)
134+
output_img = Image.open(file_obj)
135+
assert output_img.n_frames == input_img.n_frames == 3
136+
144137
def test_np_pil_encode_to_the_same(self, gradio_temp_dir):
145138
arr = np.random.randint(0, 255, size=(100, 100, 3), dtype=np.uint8)
146139
pil = Image.fromarray(arr)

0 commit comments

Comments
 (0)