Skip to content

Commit 28dca3c

Browse files
stevhliupatrickvonplaten
authored and
Jimmy
committed
[docs] AutoPipeline tutorial (huggingface#4273)
* first draft * tidy api * apply feedback * mdx to md * apply feedback * Apply suggestions from code review Co-authored-by: Patrick von Platen <[email protected]> --------- Co-authored-by: Patrick von Platen <[email protected]>
1 parent f712f56 commit 28dca3c

File tree

4 files changed

+202
-58
lines changed

4 files changed

+202
-58
lines changed

docs/source/en/_toctree.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
title: Overview
1414
- local: using-diffusers/write_own_pipeline
1515
title: Understanding models and schedulers
16+
- local: tutorials/autopipeline
17+
title: AutoPipeline
1618
- local: tutorials/basic_training
1719
title: Train a diffusion model
1820
title: Tutorials

docs/source/en/api/pipelines/auto_pipeline.md

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -12,35 +12,41 @@ specific language governing permissions and limitations under the License.
1212

1313
# AutoPipeline
1414

15-
In many cases, one checkpoint can be used for multiple tasks. For example, you may be able to use the same checkpoint for Text-to-Image, Image-to-Image, and Inpainting. However, you'll need to know the pipeline class names linked to your checkpoint.
15+
`AutoPipeline` is designed to:
1616

17-
AutoPipeline is designed to make it easy for you to use multiple pipelines in your workflow. We currently provide 3 AutoPipeline classes to perform three different tasks, i.e. [`AutoPipelineForText2Image`], [`AutoPipelineForImage2Image`], and [`AutoPipelineForInpainting`]. You'll need to choose the AutoPipeline class based on the task you want to perform and use it to automatically retrieve the relevant pipeline given the name/path to the pre-trained weights.
17+
1. make it easy for you to load a checkpoint for a task without knowing the specific pipeline class to use
18+
2. use multiple pipelines in your workflow
1819

19-
For example, to perform Image-to-Image with the SD1.5 checkpoint, you can do
20+
Based on the task, the `AutoPipeline` class automatically retrieves the relevant pipeline given the name or path to the pretrained weights with the `from_pretrained()` method.
2021

21-
```python
22-
from diffusers import PipelineForImageToImage
22+
To seamlessly switch between tasks with the same checkpoint without reallocating additional memory, use the `from_pipe()` method to transfer the components from the original pipeline to the new one.
2323

24-
pipe_i2i = PipelineForImageoImage.from_pretrained("runwayml/stable-diffusion-v1-5")
24+
```py
25+
from diffusers import AutoPipelineForText2Image
26+
import torch
27+
28+
pipeline = AutoPipelineForText2Image.from_pretrained(
29+
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True
30+
).to("cuda")
31+
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
32+
33+
image = pipeline(prompt, num_inference_steps=25).images[0]
2534
```
2635

27-
It will also help you switch between tasks seamlessly using the same checkpoint without reallocating additional memory. For example, to re-use the Image-to-Image pipeline we just created for inpainting, you can do
36+
<Tip>
2837

29-
```python
30-
from diffusers import PipelineForInpainting
38+
Check out the [AutoPipeline](/tutorials/autopipeline) tutorial to learn how to use this API!
3139

32-
pipe_inpaint = AutoPipelineForInpainting.from_pipe(pipe_i2i)
33-
```
34-
All the components will be transferred to the inpainting pipeline with zero cost.
40+
</Tip>
3541

42+
`AutoPipeline` supports text-to-image, image-to-image, and inpainting for the following diffusion models:
3643

37-
Currently AutoPipeline support the Text-to-Image, Image-to-Image, and Inpainting tasks for below diffusion models:
38-
- [stable Diffusion](./stable_diffusion)
39-
- [Stable Diffusion Controlnet](./api/pipelines/controlnet)
40-
- [Stable Diffusion XL](./stable_diffusion/stable_diffusion_xl)
41-
- [IF](./if)
44+
- [Stable Diffusion](./stable_diffusion)
45+
- [ControlNet](./api/pipelines/controlnet)
46+
- [Stable Diffusion XL (SDXL)](./stable_diffusion/stable_diffusion_xl)
47+
- [DeepFloyd IF](./if)
4248
- [Kandinsky](./kandinsky)
43-
- [Kandinsky 2.2](./kandinsky)
49+
- [Kandinsky 2.2](./kandinsky#kandinsky-22)
4450

4551

4652
## AutoPipelineForText2Image
Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
# AutoPipeline
2+
3+
🤗 Diffusers is able to complete many different tasks, and you can often reuse the same pretrained weights for multiple tasks such as text-to-image, image-to-image, and inpainting. If you're new to the library and diffusion models though, it may be difficult to know which pipeline to use for a task. For example, if you're using the [runwayml/stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5) checkpoint for text-to-image, you might not know that you could also use it for image-to-image and inpainting by loading the checkpoint with the [`StableDiffusionImg2ImgPipeline`] and [`StableDiffusionInpaintPipeline`] classes respectively.
4+
5+
The `AutoPipeline` class is designed to simplify the variety of pipelines in 🤗 Diffusers. It is a generic, *task-first* pipeline that lets you focus on the task. The `AutoPipeline` automatically detects the correct pipeline class to use, which makes it easier to load a checkpoint for a task without knowing the specific pipeline class name.
6+
7+
<Tip>
8+
9+
Take a look at the [AutoPipeline](./pipelines/auto_pipeline) reference to see which tasks are supported. Currently, it supports text-to-image, image-to-image, and inpainting.
10+
11+
</Tip>
12+
13+
This tutorial shows you how to use an `AutoPipeline` to automatically infer the pipeline class to load for a specific task, given the pretrained weights.
14+
15+
## Choose an AutoPipeline for your task
16+
17+
Start by picking a checkpoint. For example, if you're interested in text-to-image with the [runwayml/stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5) checkpoint, use [`AutoPipelineForText2Image`]:
18+
19+
```py
20+
from diffusers import AutoPipelineForText2Image
21+
import torch
22+
23+
pipeline = AutoPipelineForText2Image.from_pretrained(
24+
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True
25+
).to("cuda")
26+
prompt = "peasant and dragon combat, wood cutting style, viking era, bevel with rune"
27+
28+
image = pipeline(prompt, num_inference_steps=25).images[0]
29+
```
30+
31+
<div class="flex justify-center">
32+
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/autopipeline-text2img.png" alt="generated image of peasant fighting dragon in wood cutting style"/>
33+
</div>
34+
35+
Under the hood, [`AutoPipelineForText2Image`]:
36+
37+
1. automatically detects a `"stable-diffusion"` class from the [`model_index.json`](https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/model_index.json) file
38+
2. loads the corresponding text-to-image [`StableDiffusionPipline`] based on the `"stable-diffusion"` class name
39+
40+
Likewise, for image-to-image, [`AutoPipelineForImage2Image`] detects a `"stable-diffusion"` checkpoint from the `model_index.json` file and it'll load the corresponding [`StableDiffusionImg2ImgPipeline`] behind the scenes. You can also pass any additional arguments specific to the pipeline class such as `strength`, which determines the amount of noise or variation added to an input image:
41+
42+
```py
43+
from diffusers import AutoPipelineForImage2Image
44+
45+
pipeline = AutoPipelineForImage2Image.from_pretrained(
46+
"runwayml/stable-diffusion-v1-5",
47+
torch_dtype=torch.float16,
48+
use_safetensors=True,
49+
).to("cuda")
50+
prompt = "a portrait of a dog wearing a pearl earring"
51+
52+
url = "https://upload.wikimedia.org/wikipedia/commons/thumb/0/0f/1665_Girl_with_a_Pearl_Earring.jpg/800px-1665_Girl_with_a_Pearl_Earring.jpg"
53+
54+
response = requests.get(url)
55+
image = Image.open(BytesIO(response.content)).convert("RGB")
56+
image.thumbnail((768, 768))
57+
58+
image = pipeline(prompt, image, num_inference_steps=200, strength=0.75, guidance_scale=10.5).images[0]
59+
```
60+
61+
<div class="flex justify-center">
62+
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/autopipeline-img2img.png" alt="generated image of a vermeer portrait of a dog wearing a pearl earring"/>
63+
</div>
64+
65+
And if you want to do inpainting, then [`AutoPipelineForInpainting`] loads the underlying [`StableDiffusionInpaintPipeline`] class in the same way:
66+
67+
```py
68+
from diffusers import AutoPipelineForInpainting
69+
from diffusers.utils import load_image
70+
71+
pipeline = AutoPipelineForInpainting.from_pretrained(
72+
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, use_safetensors=True
73+
).to("cuda")
74+
75+
img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
76+
mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
77+
78+
init_image = load_image(img_url).convert("RGB")
79+
mask_image = load_image(mask_url).convert("RGB")
80+
81+
prompt = "A majestic tiger sitting on a bench"
82+
image = pipeline(prompt, image=init_image, mask_image=mask_image, num_inference_steps=50, strength=0.80).images[0]
83+
```
84+
85+
<div class="flex justify-center">
86+
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/autopipeline-inpaint.png" alt="generated image of a tiger sitting on a bench"/>
87+
</div>
88+
89+
If you try to load an unsupported checkpoint, it'll throw an error:
90+
91+
```py
92+
from diffusers import AutoPipelineForImage2Image
93+
import torch
94+
95+
pipeline = AutoPipelineForImage2Image.from_pretrained(
96+
"openai/shap-e-img2img", torch_dtype=torch.float16, use_safetensors=True
97+
)
98+
"ValueError: AutoPipeline can't find a pipeline linked to ShapEImg2ImgPipeline for None"
99+
```
100+
101+
## Use multiple pipelines
102+
103+
For some workflows or if you're loading many pipelines, it is more memory-efficient to reuse the same components from a checkpoint instead of reloading them which would unnecessarily consume additional memory. For example, if you're using a checkpoint for text-to-image and you want to use it again for image-to-image, use the [`~AutoPipelineForImage2Image.from_pipe`] method. This method creates a new pipeline from the components of a previously loaded pipeline at no additional memory cost.
104+
105+
The [`~AutoPipelineForImage2Image.from_pipe`] method detects the original pipeline class and maps it to the new pipeline class corresponding to the task you want to do. For example, if you load a `"stable-diffusion"` class pipeline for text-to-image:
106+
107+
```py
108+
from diffusers import AutoPipelineForText2Image, AutoPipelineForImage2Image
109+
110+
pipeline_text2img = AutoPipelineForText2Image.from_pretrained(
111+
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True
112+
)
113+
print(type(pipeline_text2img))
114+
"<class 'diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline'>"
115+
```
116+
117+
Then [`~AutoPipelineForImage2Image.from_pipe`] maps the original `"stable-diffusion"` pipeline class to [`StableDiffusionImg2ImgPipeline`]:
118+
119+
```py
120+
pipeline_img2img = AutoPipelineForImage2Image.from_pipe(pipeline_text2img)
121+
print(type(pipeline_img2img))
122+
"<class 'diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline'>"
123+
```
124+
125+
If you passed an optional argument - like disabling the safety checker - to the original pipeline, this argument is also passed on to the new pipeline:
126+
127+
```py
128+
from diffusers import AutoPipelineForText2Image, AutoPipelineForImage2Image
129+
130+
pipeline_text2img = AutoPipelineForText2Image.from_pretrained(
131+
"runwayml/stable-diffusion-v1-5",
132+
torch_dtype=torch.float16,
133+
use_safetensors=True,
134+
requires_safety_checker=False,
135+
).to("cuda")
136+
137+
pipeline_img2img = AutoPipelineForImage2Image.from_pipe(pipeline_text2img)
138+
print(pipe.config.requires_safety_checker)
139+
"False"
140+
```
141+
142+
You can overwrite any of the arguments and even configuration from the original pipeline if you want to change the behavior of the new pipeline. For example, to turn the safety checker back on and add the `strength` argument:
143+
144+
```py
145+
pipeline_img2img = AutoPipelineForImage2Image.from_pipe(pipeline_text2img, requires_safety_checker=True, strength=0.3)
146+
```

src/diffusers/pipelines/auto_pipeline.py

Lines changed: 30 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -158,16 +158,11 @@ def _get_signature_keys(obj):
158158
class AutoPipelineForText2Image(ConfigMixin):
159159
r"""
160160
161-
AutoPipeline for text-to-image generation.
161+
[`AutoPipelineForText2Image`] is a generic pipeline class that instantiates a text-to-image pipeline class. The
162+
specific underlying pipeline class is automatically selected from either the
163+
[`~AutoPipelineForText2Image.from_pretrained`] or [`~AutoPipelineForText2Image.from_pipe`] methods.
162164
163-
[`AutoPipelineForText2Image`] is a generic pipeline class that will be instantiated as one of the text-to-image
164-
pipeline class in diffusers.
165-
166-
The pipeline type (for example [`StableDiffusionPipeline`]) is automatically selected when created with the
167-
AutoPipelineForText2Image.from_pretrained(pretrained_model_name_or_path) or
168-
AutoPipelineForText2Image.from_pipe(pipeline) class methods .
169-
170-
This class cannot be instantiated using __init__() (throws an error).
165+
This class cannot be instantiated using `__init__()` (throws an error).
171166
172167
Class attributes:
173168
@@ -297,7 +292,7 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
297292
>>> from diffusers import AutoPipelineForText2Image
298293
299294
>>> pipeline = AutoPipelineForText2Image.from_pretrained("runwayml/stable-diffusion-v1-5")
300-
>>> print(pipeline.__class__)
295+
>>> image = pipeline(prompt).images[0]
301296
```
302297
"""
303298
config = cls.load_config(pretrained_model_or_path)
@@ -328,13 +323,14 @@ def from_pipe(cls, pipeline, **kwargs):
328323
an instantiated `DiffusionPipeline` object
329324
330325
```py
331-
>>> from diffusers import AutoPipelineForText2Image, AutoPipelineForImageToImage
326+
>>> from diffusers import AutoPipelineForText2Image, AutoPipelineForImage2Image
332327
333328
>>> pipe_i2i = AutoPipelineForImage2Image.from_pretrained(
334329
... "runwayml/stable-diffusion-v1-5", requires_safety_checker=False
335330
... )
336331
337-
>>> pipe_t2i = AutoPipelineForText2Image.from_pipe(pipe_t2i)
332+
>>> pipe_t2i = AutoPipelineForText2Image.from_pipe(pipe_i2i)
333+
>>> image = pipe_t2i(prompt).images[0]
338334
```
339335
"""
340336

@@ -401,16 +397,11 @@ def from_pipe(cls, pipeline, **kwargs):
401397
class AutoPipelineForImage2Image(ConfigMixin):
402398
r"""
403399
404-
AutoPipeline for image-to-image generation.
405-
406-
[`AutoPipelineForImage2Image`] is a generic pipeline class that will be instantiated as one of the image-to-image
407-
pipeline classes in diffusers.
408-
409-
The pipeline type (for example [`StableDiffusionImg2ImgPipeline`]) is automatically selected when created with the
410-
`AutoPipelineForImage2Image.from_pretrained(pretrained_model_name_or_path)` or
411-
`AutoPipelineForImage2Image.from_pipe(pipeline)` class methods.
400+
[`AutoPipelineForImage2Image`] is a generic pipeline class that instantiates an image-to-image pipeline class. The
401+
specific underlying pipeline class is automatically selected from either the
402+
[`~AutoPipelineForImage2Image.from_pretrained`] or [`~AutoPipelineForImage2Image.from_pipe`] methods.
412403
413-
This class cannot be instantiated using __init__() (throws an error).
404+
This class cannot be instantiated using `__init__()` (throws an error).
414405
415406
Class attributes:
416407
@@ -438,7 +429,8 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
438429
2. Find the image-to-image pipeline linked to the pipeline class using pattern matching on pipeline class
439430
name.
440431
441-
If a `controlnet` argument is passed, it will instantiate a StableDiffusionControlNetImg2ImgPipeline object.
432+
If a `controlnet` argument is passed, it will instantiate a [`StableDiffusionControlNetImg2ImgPipeline`]
433+
object.
442434
443435
The pipeline is set in evaluation mode (`model.eval()`) by default.
444436
@@ -537,10 +529,10 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
537529
Examples:
538530
539531
```py
540-
>>> from diffusers import AutoPipelineForText2Image
532+
>>> from diffusers import AutoPipelineForImage2Image
541533
542-
>>> pipeline = AutoPipelineForImageToImage.from_pretrained("runwayml/stable-diffusion-v1-5")
543-
>>> print(pipeline.__class__)
534+
>>> pipeline = AutoPipelineForImage2Image.from_pretrained("runwayml/stable-diffusion-v1-5")
535+
>>> image = pipeline(prompt, image).images[0]
544536
```
545537
"""
546538
config = cls.load_config(pretrained_model_or_path)
@@ -573,13 +565,14 @@ def from_pipe(cls, pipeline, **kwargs):
573565
Examples:
574566
575567
```py
576-
>>> from diffusers import AutoPipelineForText2Image, AutoPipelineForImageToImage
568+
>>> from diffusers import AutoPipelineForText2Image, AutoPipelineForImage2Image
577569
578570
>>> pipe_t2i = AutoPipelineForText2Image.from_pretrained(
579571
... "runwayml/stable-diffusion-v1-5", requires_safety_checker=False
580572
... )
581573
582-
>>> pipe_i2i = AutoPipelineForImageToImage.from_pipe(pipe_t2i)
574+
>>> pipe_i2i = AutoPipelineForImage2Image.from_pipe(pipe_t2i)
575+
>>> image = pipe_i2i(prompt, image).images[0]
583576
```
584577
"""
585578

@@ -646,16 +639,11 @@ def from_pipe(cls, pipeline, **kwargs):
646639
class AutoPipelineForInpainting(ConfigMixin):
647640
r"""
648641
649-
AutoPipeline for inpainting generation.
650-
651-
[`AutoPipelineForInpainting`] is a generic pipeline class that will be instantiated as one of the inpainting
652-
pipeline class in diffusers.
642+
[`AutoPipelineForInpainting`] is a generic pipeline class that instantiates an inpainting pipeline class. The
643+
specific underlying pipeline class is automatically selected from either the
644+
[`~AutoPipelineForInpainting.from_pretrained`] or [`~AutoPipelineForInpainting.from_pipe`] methods.
653645
654-
The pipeline type (for example [`IFInpaintingPipeline`]) is automatically selected when created with the
655-
AutoPipelineForInpainting.from_pretrained(pretrained_model_name_or_path) or
656-
AutoPipelineForInpainting.from_pipe(pipeline) class methods .
657-
658-
This class cannot be instantiated using __init__() (throws an error).
646+
This class cannot be instantiated using `__init__()` (throws an error).
659647
660648
Class attributes:
661649
@@ -682,7 +670,8 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
682670
config object
683671
2. Find the inpainting pipeline linked to the pipeline class using pattern matching on pipeline class name.
684672
685-
If a `controlnet` argument is passed, it will instantiate a StableDiffusionControlNetInpaintPipeline object.
673+
If a `controlnet` argument is passed, it will instantiate a [`StableDiffusionControlNetInpaintPipeline`]
674+
object.
686675
687676
The pipeline is set in evaluation mode (`model.eval()`) by default.
688677
@@ -781,10 +770,10 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
781770
Examples:
782771
783772
```py
784-
>>> from diffusers import AutoPipelineForText2Image
773+
>>> from diffusers import AutoPipelineForInpainting
785774
786-
>>> pipeline = AutoPipelineForImageToImage.from_pretrained("runwayml/stable-diffusion-v1-5")
787-
>>> print(pipeline.__class__)
775+
>>> pipeline = AutoPipelineForInpainting.from_pretrained("runwayml/stable-diffusion-v1-5")
776+
>>> image = pipeline(prompt, image=init_image, mask_image=mask_image).images[0]
788777
```
789778
"""
790779
config = cls.load_config(pretrained_model_or_path)
@@ -824,6 +813,7 @@ def from_pipe(cls, pipeline, **kwargs):
824813
... )
825814
826815
>>> pipe_inpaint = AutoPipelineForInpainting.from_pipe(pipe_t2i)
816+
>>> image = pipe_inpaint(prompt, image=init_image, mask_image=mask_image).images[0]
827817
```
828818
"""
829819
original_config = dict(pipeline.config)

0 commit comments

Comments
 (0)