Skip to content

Commit a68503f

Browse files
[Docs] Add TGATE in section optimization (#7639)
* Create tgate.md * Update _toctree.yml * Update docs/source/en/optimization/tgate.md Co-authored-by: Steven Liu <[email protected]> * Update docs/source/en/optimization/tgate.md Co-authored-by: Steven Liu <[email protected]> * Update docs/source/en/optimization/tgate.md Co-authored-by: Steven Liu <[email protected]> * Update docs/source/en/optimization/tgate.md Co-authored-by: Steven Liu <[email protected]> * Update docs/source/en/optimization/tgate.md Co-authored-by: Steven Liu <[email protected]> * Update docs/source/en/optimization/tgate.md Co-authored-by: Steven Liu <[email protected]> * Update docs/source/en/optimization/tgate.md Co-authored-by: Steven Liu <[email protected]> * Update docs/source/en/optimization/tgate.md Co-authored-by: Steven Liu <[email protected]> * Update docs/source/en/optimization/tgate.md Co-authored-by: Steven Liu <[email protected]> * Update docs/source/en/optimization/tgate.md Co-authored-by: Steven Liu <[email protected]> * Update tgate.md * Update tgate.md --------- Co-authored-by: Steven Liu <[email protected]>
1 parent 9d50f7e commit a68503f

File tree

2 files changed

+181
-0
lines changed

2 files changed

+181
-0
lines changed

Diff for: docs/source/en/_toctree.yml

+2
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,8 @@
172172
title: Token merging
173173
- local: optimization/deepcache
174174
title: DeepCache
175+
- local: optimization/tgate
176+
title: TGATE
175177
title: General optimizations
176178
- sections:
177179
- local: using-diffusers/stable_diffusion_jax_how_to

Diff for: docs/source/en/optimization/tgate.md

+179
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
# T-GATE
2+
3+
[T-GATE](https://github.com/HaozheLiu-ST/T-GATE/tree/main) accelerates inference for [Stable Diffusion](../api/pipelines/stable_diffusion/overview), [PixArt](../api/pipelines/pixart), and [Latency Consistency Model](../api/pipelines/latent_consistency_models.md) pipelines by skipping the cross-attention calculation once it converges. This method doesn't require any additional training and it can speed up inference from 10-50%. T-GATE is also compatible with other optimization methods like [DeepCache](./deepcache).
4+
5+
Before you begin, make sure you install T-GATE.
6+
7+
```bash
8+
pip install tgate
9+
pip install -U pytorch diffusers transformers accelerate DeepCache
10+
```
11+
12+
13+
To use T-GATE with a pipeline, you need to use its corresponding loader.
14+
15+
| Pipeline | T-GATE Loader |
16+
|---|---|
17+
| PixArt | TgatePixArtLoader |
18+
| Stable Diffusion XL | TgateSDXLLoader |
19+
| Stable Diffusion XL + DeepCache | TgateSDXLDeepCacheLoader |
20+
| Stable Diffusion | TgateSDLoader |
21+
| Stable Diffusion + DeepCache | TgateSDDeepCacheLoader |
22+
23+
Next, create a `TgateLoader` with a pipeline, the gate step (the time step to stop calculating the cross attention), and the number of inference steps. Then call the `tgate` method on the pipeline with a prompt, gate step, and the number of inference steps.
24+
25+
Let's see how to enable this for several different pipelines.
26+
27+
<hfoptions id="pipelines">
28+
<hfoption id="PixArt">
29+
30+
Accelerate `PixArtAlphaPipeline` with T-GATE:
31+
32+
```py
33+
import torch
34+
from diffusers import PixArtAlphaPipeline
35+
from tgate import TgatePixArtLoader
36+
37+
pipe = PixArtAlphaPipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16)
38+
pipe = TgatePixArtLoader(
39+
pipe,
40+
gate_step=8,
41+
num_inference_steps=25,
42+
).to("cuda")
43+
44+
image = pipe.tgate(
45+
"An alpaca made of colorful building blocks, cyberpunk.",
46+
gate_step=gate_step,
47+
num_inference_steps=inference_step,
48+
).images[0]
49+
```
50+
</hfoption>
51+
<hfoption id="Stable Diffusion XL">
52+
53+
Accelerate `StableDiffusionXLPipeline` with T-GATE:
54+
55+
```py
56+
import torch
57+
from diffusers import StableDiffusionXLPipeline
58+
from diffusers import DPMSolverMultistepScheduler
59+
60+
pipe = StableDiffusionXLPipeline.from_pretrained(
61+
"stabilityai/stable-diffusion-xl-base-1.0",
62+
torch_dtype=torch.float16,
63+
variant="fp16",
64+
use_safetensors=True,
65+
)
66+
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
67+
68+
from tgate import TgateSDXLLoader
69+
gate_step = 10
70+
inference_step = 25
71+
pipe = TgateSDXLLoader(
72+
pipe,
73+
gate_step=gate_step,
74+
num_inference_steps=inference_step,
75+
).to("cuda")
76+
77+
image = pipe.tgate(
78+
"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k.",
79+
gate_step=gate_step,
80+
num_inference_steps=inference_step
81+
).images[0]
82+
```
83+
</hfoption>
84+
<hfoption id="StableDiffusionXL with DeepCache">
85+
86+
Accelerate `StableDiffusionXLPipeline` with [DeepCache](https://github.com/horseee/DeepCache) and T-GATE:
87+
88+
```py
89+
import torch
90+
from diffusers import StableDiffusionXLPipeline
91+
from diffusers import DPMSolverMultistepScheduler
92+
93+
pipe = StableDiffusionXLPipeline.from_pretrained(
94+
"stabilityai/stable-diffusion-xl-base-1.0",
95+
torch_dtype=torch.float16,
96+
variant="fp16",
97+
use_safetensors=True,
98+
)
99+
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
100+
101+
from tgate import TgateSDXLDeepCacheLoader
102+
gate_step = 10
103+
inference_step = 25
104+
pipe = TgateSDXLDeepCacheLoader(
105+
pipe,
106+
cache_interval=3,
107+
cache_branch_id=0,
108+
).to("cuda")
109+
110+
image = pipe.tgate(
111+
"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k.",
112+
gate_step=gate_step,
113+
num_inference_steps=inference_step
114+
).images[0]
115+
```
116+
</hfoption>
117+
<hfoption id="Latent Consistency Model">
118+
119+
Accelerate `latent-consistency/lcm-sdxl` with T-GATE:
120+
121+
```py
122+
import torch
123+
from diffusers import StableDiffusionXLPipeline
124+
from diffusers import UNet2DConditionModel, LCMScheduler
125+
from diffusers import DPMSolverMultistepScheduler
126+
127+
unet = UNet2DConditionModel.from_pretrained(
128+
"latent-consistency/lcm-sdxl",
129+
torch_dtype=torch.float16,
130+
variant="fp16",
131+
)
132+
pipe = StableDiffusionXLPipeline.from_pretrained(
133+
"stabilityai/stable-diffusion-xl-base-1.0",
134+
unet=unet,
135+
torch_dtype=torch.float16,
136+
variant="fp16",
137+
)
138+
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
139+
140+
from tgate import TgateSDXLLoader
141+
gate_step = 1
142+
inference_step = 4
143+
pipe = TgateSDXLLoader(
144+
pipe,
145+
gate_step=gate_step,
146+
num_inference_steps=inference_step,
147+
lcm=True
148+
).to("cuda")
149+
150+
image = pipe.tgate(
151+
"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k.",
152+
gate_step=gate_step,
153+
num_inference_steps=inference_step
154+
).images[0]
155+
```
156+
</hfoption>
157+
</hfoptions>
158+
159+
T-GATE also supports [`StableDiffusionPipeline`] and [PixArt-alpha/PixArt-LCM-XL-2-1024-MS](https://hf.co/PixArt-alpha/PixArt-LCM-XL-2-1024-MS).
160+
161+
## Benchmarks
162+
| Model | MACs | Param | Latency | Zero-shot 10K-FID on MS-COCO |
163+
|-----------------------|----------|-----------|---------|---------------------------|
164+
| SD-1.5 | 16.938T | 859.520M | 7.032s | 23.927 |
165+
| SD-1.5 w/ T-GATE | 9.875T | 815.557M | 4.313s | 20.789 |
166+
| SD-2.1 | 38.041T | 865.785M | 16.121s | 22.609 |
167+
| SD-2.1 w/ T-GATE | 22.208T | 815.433 M | 9.878s | 19.940 |
168+
| SD-XL | 149.438T | 2.570B | 53.187s | 24.628 |
169+
| SD-XL w/ T-GATE | 84.438T | 2.024B | 27.932s | 22.738 |
170+
| Pixart-Alpha | 107.031T | 611.350M | 61.502s | 38.669 |
171+
| Pixart-Alpha w/ T-GATE | 65.318T | 462.585M | 37.867s | 35.825 |
172+
| DeepCache (SD-XL) | 57.888T | - | 19.931s | 23.755 |
173+
| DeepCache w/ T-GATE | 43.868T | - | 14.666s | 23.999 |
174+
| LCM (SD-XL) | 11.955T | 2.570B | 3.805s | 25.044 |
175+
| LCM w/ T-GATE | 11.171T | 2.024B | 3.533s | 25.028 |
176+
| LCM (Pixart-Alpha) | 8.563T | 611.350M | 4.733s | 36.086 |
177+
| LCM w/ T-GATE | 7.623T | 462.585M | 4.543s | 37.048 |
178+
179+
The latency is tested on an NVIDIA 1080TI, MACs and Params are calculated with [calflops](https://github.com/MrYxJ/calculate-flops.pytorch), and the FID is calculated with [PytorchFID](https://github.com/mseitzer/pytorch-fid).

0 commit comments

Comments
 (0)