| 1 | +# T-GATE |
| 2 | + |
| 3 | +[T-GATE](https://github.com/HaozheLiu-ST/T-GATE/tree/main) accelerates inference for [Stable Diffusion](../api/pipelines/stable_diffusion/overview), [PixArt](../api/pipelines/pixart), and [Latent Consistency Model](../api/pipelines/latent_consistency_models.md) pipelines by skipping the cross-attention calculation once it converges. This method doesn't require any additional training and can speed up inference by 10-50%. T-GATE is also compatible with other optimization methods like [DeepCache](./deepcache). |
| 4 | + |
| 5 | +Before you begin, make sure you install T-GATE. |
| 6 | + |
| 7 | +```bash |
| 8 | +pip install tgate |
| 9 | +pip install -U torch diffusers transformers accelerate DeepCache |
| 10 | +``` |
| 11 | + |
| 12 | + |
| 13 | +To use T-GATE with a pipeline, you need to use its corresponding loader. |
| 14 | + |
| 15 | +| Pipeline | T-GATE Loader | |
| 16 | +|---|---| |
| 17 | +| PixArt | TgatePixArtLoader | |
| 18 | +| Stable Diffusion XL | TgateSDXLLoader | |
| 19 | +| Stable Diffusion XL + DeepCache | TgateSDXLDeepCacheLoader | |
| 20 | +| Stable Diffusion | TgateSDLoader | |
| 21 | +| Stable Diffusion + DeepCache | TgateSDDeepCacheLoader | |
| 22 | + |
| 23 | +Next, create the loader for your pipeline, passing it the pipeline, the gate step (the denoising step after which cross-attention is no longer calculated), and the number of inference steps. Then call the `tgate` method on the returned pipeline with a prompt, the gate step, and the number of inference steps. |
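
To make the role of the gate step concrete, here is a toy sketch in plain Python. It is not T-GATE's implementation: `ToyGatedAttention` and `fake_cross_attention` are hypothetical names, and the rule "recompute before `gate_step`, reuse the last output afterwards" is an assumption based only on the description above.

```python
from typing import Callable, List, Optional


class ToyGatedAttention:
    """Hypothetical stand-in for a cross-attention layer with a gate step.

    Assumption: before `gate_step` the layer is recomputed at every step;
    from `gate_step` onward the last computed output is reused.
    """

    def __init__(self, attn_fn: Callable[[List[float]], List[float]], gate_step: int):
        self.attn_fn = attn_fn
        self.gate_step = gate_step
        self.cached: Optional[List[float]] = None
        self.num_computed = 0  # how many times attn_fn actually ran

    def __call__(self, step: int, hidden: List[float]) -> List[float]:
        # Recompute while the gate is open (or if nothing is cached yet).
        if step < self.gate_step or self.cached is None:
            self.cached = self.attn_fn(hidden)
            self.num_computed += 1
        # Once the gate is closed, return the cached output unchanged.
        return self.cached


def fake_cross_attention(hidden: List[float]) -> List[float]:
    # Placeholder computation; a real layer would attend to text embeddings.
    return [2.0 * h for h in hidden]


layer = ToyGatedAttention(fake_cross_attention, gate_step=10)
for step in range(25):
    layer(step, [float(step)])

print(layer.num_computed)  # 10 of the 25 steps ran the attention function
```

With `gate_step=10` and 25 steps, the placeholder attention function runs only for the first 10 steps, which is where the savings come from in this simplified picture.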
| 24 | + |
| 25 | +Let's see how to enable this for several different pipelines. |
| 26 | + |
| 27 | +<hfoptions id="pipelines"> |
| 28 | +<hfoption id="PixArt"> |
| 29 | + |
| 30 | +Accelerate `PixArtAlphaPipeline` with T-GATE: |
| 31 | + |
| 32 | +```py |
| 33 | +import torch |
| 34 | +from diffusers import PixArtAlphaPipeline |
| 35 | +from tgate import TgatePixArtLoader |
| 36 | + |
| 37 | +pipe = PixArtAlphaPipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16) |
| 38 | +pipe = TgatePixArtLoader( |
| 39 | + pipe, |
| 40 | + gate_step=8, |
| 41 | + num_inference_steps=25, |
| 42 | +).to("cuda") |
| 43 | + |
| 44 | +image = pipe.tgate( |
| 45 | + "An alpaca made of colorful building blocks, cyberpunk.", |
| 46 | + gate_step=8, |
| 47 | + num_inference_steps=25, |
| 48 | +).images[0] |
| 49 | +``` |
| 50 | +</hfoption> |
| 51 | +<hfoption id="Stable Diffusion XL"> |
| 52 | + |
| 53 | +Accelerate `StableDiffusionXLPipeline` with T-GATE: |
| 54 | + |
| 55 | +```py |
| 56 | +import torch |
| 57 | +from diffusers import StableDiffusionXLPipeline |
| 58 | +from diffusers import DPMSolverMultistepScheduler |
| 59 | + |
| 60 | +pipe = StableDiffusionXLPipeline.from_pretrained( |
| 61 | + "stabilityai/stable-diffusion-xl-base-1.0", |
| 62 | + torch_dtype=torch.float16, |
| 63 | + variant="fp16", |
| 64 | + use_safetensors=True, |
| 65 | +) |
| 66 | +pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) |
| 67 | + |
| 68 | +from tgate import TgateSDXLLoader |
| 69 | +gate_step = 10 |
| 70 | +inference_step = 25 |
| 71 | +pipe = TgateSDXLLoader( |
| 72 | + pipe, |
| 73 | + gate_step=gate_step, |
| 74 | + num_inference_steps=inference_step, |
| 75 | +).to("cuda") |
| 76 | + |
| 77 | +image = pipe.tgate( |
| 78 | + "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k.", |
| 79 | + gate_step=gate_step, |
| 80 | + num_inference_steps=inference_step |
| 81 | +).images[0] |
| 82 | +``` |
| 83 | +</hfoption> |
| 84 | +<hfoption id="StableDiffusionXL with DeepCache"> |
| 85 | + |
| 86 | +Accelerate `StableDiffusionXLPipeline` with [DeepCache](https://github.com/horseee/DeepCache) and T-GATE: |
| 87 | + |
| 88 | +```py |
| 89 | +import torch |
| 90 | +from diffusers import StableDiffusionXLPipeline |
| 91 | +from diffusers import DPMSolverMultistepScheduler |
| 92 | + |
| 93 | +pipe = StableDiffusionXLPipeline.from_pretrained( |
| 94 | + "stabilityai/stable-diffusion-xl-base-1.0", |
| 95 | + torch_dtype=torch.float16, |
| 96 | + variant="fp16", |
| 97 | + use_safetensors=True, |
| 98 | +) |
| 99 | +pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) |
| 100 | + |
| 101 | +from tgate import TgateSDXLDeepCacheLoader |
| 102 | +gate_step = 10 |
| 103 | +inference_step = 25 |
| 104 | +pipe = TgateSDXLDeepCacheLoader( |
| 105 | + pipe, |
| 106 | + cache_interval=3, |
| 107 | + cache_branch_id=0, |
| 108 | +).to("cuda") |
| 109 | + |
| 110 | +image = pipe.tgate( |
| 111 | + "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k.", |
| 112 | + gate_step=gate_step, |
| 113 | + num_inference_steps=inference_step |
| 114 | +).images[0] |
| 115 | +``` |
| 116 | +</hfoption> |
| 117 | +<hfoption id="Latent Consistency Model"> |
| 118 | + |
| 119 | +Accelerate `latent-consistency/lcm-sdxl` with T-GATE: |
| 120 | + |
| 121 | +```py |
| 122 | +import torch |
| 123 | +from diffusers import StableDiffusionXLPipeline |
| 124 | +from diffusers import UNet2DConditionModel, LCMScheduler |
| 125 | +from diffusers import DPMSolverMultistepScheduler |
| 126 | + |
| 127 | +unet = UNet2DConditionModel.from_pretrained( |
| 128 | + "latent-consistency/lcm-sdxl", |
| 129 | + torch_dtype=torch.float16, |
| 130 | + variant="fp16", |
| 131 | +) |
| 132 | +pipe = StableDiffusionXLPipeline.from_pretrained( |
| 133 | + "stabilityai/stable-diffusion-xl-base-1.0", |
| 134 | + unet=unet, |
| 135 | + torch_dtype=torch.float16, |
| 136 | + variant="fp16", |
| 137 | +) |
| 138 | +pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config) |
| 139 | + |
| 140 | +from tgate import TgateSDXLLoader |
| 141 | +gate_step = 1 |
| 142 | +inference_step = 4 |
| 143 | +pipe = TgateSDXLLoader( |
| 144 | + pipe, |
| 145 | + gate_step=gate_step, |
| 146 | + num_inference_steps=inference_step, |
| 147 | + lcm=True |
| 148 | +).to("cuda") |
| 149 | + |
| 150 | +image = pipe.tgate( |
| 151 | + "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k.", |
| 152 | + gate_step=gate_step, |
| 153 | + num_inference_steps=inference_step |
| 154 | +).images[0] |
| 155 | +``` |
| 156 | +</hfoption> |
| 157 | +</hfoptions> |
| 158 | + |
| 159 | +T-GATE also supports [`StableDiffusionPipeline`] and [PixArt-alpha/PixArt-LCM-XL-2-1024-MS](https://hf.co/PixArt-alpha/PixArt-LCM-XL-2-1024-MS). |
| 160 | + |
| 161 | +## Benchmarks |
| 162 | +| Model | MACs | Param | Latency | Zero-shot 10K-FID on MS-COCO | |
| 163 | +|-----------------------|----------|-----------|---------|---------------------------| |
| 164 | +| SD-1.5 | 16.938T | 859.520M | 7.032s | 23.927 | |
| 165 | +| SD-1.5 w/ T-GATE | 9.875T | 815.557M | 4.313s | 20.789 | |
| 166 | +| SD-2.1 | 38.041T | 865.785M | 16.121s | 22.609 | |
| 167 | +| SD-2.1 w/ T-GATE | 22.208T | 815.433M | 9.878s | 19.940 | |
| 168 | +| SD-XL | 149.438T | 2.570B | 53.187s | 24.628 | |
| 169 | +| SD-XL w/ T-GATE | 84.438T | 2.024B | 27.932s | 22.738 | |
| 170 | +| Pixart-Alpha | 107.031T | 611.350M | 61.502s | 38.669 | |
| 171 | +| Pixart-Alpha w/ T-GATE | 65.318T | 462.585M | 37.867s | 35.825 | |
| 172 | +| DeepCache (SD-XL) | 57.888T | - | 19.931s | 23.755 | |
| 173 | +| DeepCache (SD-XL) w/ T-GATE | 43.868T | - | 14.666s | 23.999 | |
| 174 | +| LCM (SD-XL) | 11.955T | 2.570B | 3.805s | 25.044 | |
| 175 | +| LCM (SD-XL) w/ T-GATE | 11.171T | 2.024B | 3.533s | 25.028 | |
| 176 | +| LCM (Pixart-Alpha) | 8.563T | 611.350M | 4.733s | 36.086 | |
| 177 | +| LCM (Pixart-Alpha) w/ T-GATE | 7.623T | 462.585M | 4.543s | 37.048 | |
| 178 | + |
| 179 | +Latency is measured on an NVIDIA 1080 Ti. MACs and Params are calculated with [calflops](https://github.com/MrYxJ/calculate-flops.pytorch), and FID is calculated with [PytorchFID](https://github.com/mseitzer/pytorch-fid). |
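
As a quick way to read the latency column, the sketch below computes the relative latency reduction for each baseline/T-GATE pair. The numbers are copied from the table above; the helper name `latency_reduction` and the dictionary layout are illustrative choices, not part of T-GATE.

```python
# Latency in seconds, copied from the benchmark table: (baseline, with T-GATE).
LATENCIES = {
    "SD-1.5": (7.032, 4.313),
    "SD-2.1": (16.121, 9.878),
    "SD-XL": (53.187, 27.932),
    "Pixart-Alpha": (61.502, 37.867),
    "DeepCache (SD-XL)": (19.931, 14.666),
    "LCM (SD-XL)": (3.805, 3.533),
    "LCM (Pixart-Alpha)": (4.733, 4.543),
}


def latency_reduction(baseline: float, gated: float) -> float:
    """Return the percentage of baseline latency saved by the gated run."""
    return (baseline - gated) / baseline * 100.0


for model, (baseline, gated) in LATENCIES.items():
    print(f"{model}: {latency_reduction(baseline, gated):.1f}% lower latency")
```

For the SD-XL row this works out to roughly a 47% latency reduction, while the LCM rows, which already use very few steps, save far less.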