
Commit 733b44a

hlky and sayakpaul authored
[hybrid inference 🍯🐝] Add VAE encode (#11017)
* [hybrid inference 🍯🐝] Add VAE encode
* _toctree: add vae encode
* Add endpoints, tests
* vae_encode docs
* vae encode benchmarks
* api reference
* changelog
* Update docs/source/en/hybrid_inference/overview.md
* update

Co-authored-by: Sayak Paul <[email protected]>
1 parent 8b4f8ba commit 733b44a

File tree: 8 files changed, +546 −22 lines


docs/source/en/_toctree.yml

+2
@@ -81,6 +81,8 @@
       title: Overview
     - local: hybrid_inference/vae_decode
       title: VAE Decode
+    - local: hybrid_inference/vae_encode
+      title: VAE Encode
     - local: hybrid_inference/api_reference
       title: API Reference
   title: Hybrid Inference

docs/source/en/hybrid_inference/api_reference.md

+4
@@ -3,3 +3,7 @@
 ## Remote Decode

 [[autodoc]] utils.remote_utils.remote_decode
+
+## Remote Encode
+
+[[autodoc]] utils.remote_utils.remote_encode

docs/source/en/hybrid_inference/overview.md

+8-2
@@ -36,7 +36,7 @@ Hybrid Inference offers a fast and simple way to offload local generation requir
 ## Available Models

 * **VAE Decode 🖼️:** Quickly decode latent representations into high-quality images without compromising performance or workflow speed.
-* **VAE Encode 🔢 (coming soon):** Efficiently encode images into latent representations for generation and training.
+* **VAE Encode 🔢:** Efficiently encode images into latent representations for generation and training.
 * **Text Encoders 📃 (coming soon):** Compute text embeddings for your prompts quickly and accurately, ensuring a smooth and high-quality workflow.

 ---
@@ -46,9 +46,15 @@ Hybrid Inference offers a fast and simple way to offload local generation requir
 * **[SD.Next](https://github.com/vladmandic/sdnext):** All-in-one UI with direct support for Hybrid Inference.
 * **[ComfyUI-HFRemoteVae](https://github.com/kijai/ComfyUI-HFRemoteVae):** ComfyUI node for Hybrid Inference.

+## Changelog
+
+- March 10 2025: Added VAE encode
+- March 2 2025: Initial release with VAE decoding
+
 ## Contents

-The documentation is organized into two sections:
+The documentation is organized into three sections:

 * **VAE Decode** Learn the basics of how to use VAE Decode with Hybrid Inference.
+* **VAE Encode** Learn the basics of how to use VAE Encode with Hybrid Inference.
 * **API Reference** Dive into task-specific settings and parameters.
docs/source/en/hybrid_inference/vae_encode.md

+183

@@ -0,0 +1,183 @@
# Getting Started: VAE Encode with Hybrid Inference

VAE encode is used for training, image-to-image, and image-to-video - turning images or videos into latent representations.

## Memory

These tables demonstrate the VRAM requirements for VAE encode with SD v1 and SD XL on different GPUs.

For the majority of these GPUs, the memory usage % dictates that other models (text encoders, UNet/Transformer) must be offloaded, or that tiled encoding must be used, which increases the time taken and impacts quality. A local tiled-encoding sketch follows the tables.
<details><summary>SD v1.5</summary>

| GPU | Resolution | Time (seconds) | Memory (%) | Tiled Time (seconds) | Tiled Memory (%) |
|:------------------------------|:-------------|-----------------:|-------------:|--------------------:|-------------------:|
| NVIDIA GeForce RTX 4090 | 512x512 | 0.015 | 3.51901 | 0.015 | 3.51901 |
| NVIDIA GeForce RTX 4090 | 256x256 | 0.004 | 1.3154 | 0.005 | 1.3154 |
| NVIDIA GeForce RTX 4090 | 2048x2048 | 0.402 | 47.1852 | 0.496 | 3.51901 |
| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.078 | 12.2658 | 0.094 | 3.51901 |
| NVIDIA GeForce RTX 4080 SUPER | 512x512 | 0.023 | 5.30105 | 0.023 | 5.30105 |
| NVIDIA GeForce RTX 4080 SUPER | 256x256 | 0.006 | 1.98152 | 0.006 | 1.98152 |
| NVIDIA GeForce RTX 4080 SUPER | 2048x2048 | 0.574 | 71.08 | 0.656 | 5.30105 |
| NVIDIA GeForce RTX 4080 SUPER | 1024x1024 | 0.111 | 18.4772 | 0.14 | 5.30105 |
| NVIDIA GeForce RTX 3090 | 512x512 | 0.032 | 3.52782 | 0.032 | 3.52782 |
| NVIDIA GeForce RTX 3090 | 256x256 | 0.01 | 1.31869 | 0.009 | 1.31869 |
| NVIDIA GeForce RTX 3090 | 2048x2048 | 0.742 | 47.3033 | 0.954 | 3.52782 |
| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.136 | 12.2965 | 0.207 | 3.52782 |
| NVIDIA GeForce RTX 3080 | 512x512 | 0.036 | 8.51761 | 0.036 | 8.51761 |
| NVIDIA GeForce RTX 3080 | 256x256 | 0.01 | 3.18387 | 0.01 | 3.18387 |
| NVIDIA GeForce RTX 3080 | 2048x2048 | 0.863 | 86.7424 | 1.191 | 8.51761 |
| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.157 | 29.6888 | 0.227 | 8.51761 |
| NVIDIA GeForce RTX 3070 | 512x512 | 0.051 | 10.6941 | 0.051 | 10.6941 |
| NVIDIA GeForce RTX 3070 | 256x256 | 0.015 | 3.99743 | 0.015 | 3.99743 |
| NVIDIA GeForce RTX 3070 | 2048x2048 | 1.217 | 96.054 | 1.482 | 10.6941 |
| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.223 | 37.2751 | 0.327 | 10.6941 |

</details>

<details><summary>SDXL</summary>

| GPU | Resolution | Time (seconds) | Memory Consumed (%) | Tiled Time (seconds) | Tiled Memory (%) |
|:------------------------------|:-------------|-----------------:|----------------------:|-----------------------:|-------------------:|
| NVIDIA GeForce RTX 4090 | 512x512 | 0.029 | 4.95707 | 0.029 | 4.95707 |
| NVIDIA GeForce RTX 4090 | 256x256 | 0.007 | 2.29666 | 0.007 | 2.29666 |
| NVIDIA GeForce RTX 4090 | 2048x2048 | 0.873 | 66.3452 | 0.863 | 15.5649 |
| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.142 | 15.5479 | 0.143 | 15.5479 |
| NVIDIA GeForce RTX 4080 SUPER | 512x512 | 0.044 | 7.46735 | 0.044 | 7.46735 |
| NVIDIA GeForce RTX 4080 SUPER | 256x256 | 0.01 | 3.4597 | 0.01 | 3.4597 |
| NVIDIA GeForce RTX 4080 SUPER | 2048x2048 | 1.317 | 87.1615 | 1.291 | 23.447 |
| NVIDIA GeForce RTX 4080 SUPER | 1024x1024 | 0.213 | 23.4215 | 0.214 | 23.4215 |
| NVIDIA GeForce RTX 3090 | 512x512 | 0.058 | 5.65638 | 0.058 | 5.65638 |
| NVIDIA GeForce RTX 3090 | 256x256 | 0.016 | 2.45081 | 0.016 | 2.45081 |
| NVIDIA GeForce RTX 3090 | 2048x2048 | 1.755 | 77.8239 | 1.614 | 18.4193 |
| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.265 | 18.4023 | 0.265 | 18.4023 |
| NVIDIA GeForce RTX 3080 | 512x512 | 0.064 | 13.6568 | 0.064 | 13.6568 |
| NVIDIA GeForce RTX 3080 | 256x256 | 0.018 | 5.91728 | 0.018 | 5.91728 |
| NVIDIA GeForce RTX 3080 | 2048x2048 | OOM | OOM | 1.866 | 44.4717 |
| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.302 | 44.4308 | 0.302 | 44.4308 |
| NVIDIA GeForce RTX 3070 | 512x512 | 0.093 | 17.1465 | 0.093 | 17.1465 |
| NVIDIA GeForce RTX 3070 | 256x256 | 0.025 | 7.42931 | 0.026 | 7.42931 |
| NVIDIA GeForce RTX 3070 | 2048x2048 | OOM | OOM | 2.674 | 55.8355 |
| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.443 | 55.7841 | 0.443 | 55.7841 |

</details>
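For comparison with the tiled columns above, here is a minimal local tiled-encoding sketch (not part of this commit). It assumes an SD v1.5 checkpoint and a CUDA GPU, and uses `AutoencoderKL.enable_tiling()` to cap peak VRAM at the cost of extra time and a possible quality impact.

```python
# Minimal local tiled-encoding sketch (assumes an SD v1.5 checkpoint and a CUDA GPU).
import torch
from diffusers import AutoencoderKL
from diffusers.image_processor import VaeImageProcessor
from diffusers.utils import load_image

vae = AutoencoderKL.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="vae", torch_dtype=torch.float16
).to("cuda")
vae.enable_tiling()  # encode in tiles to cap peak VRAM, trading some speed/quality

image = load_image(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
)
pixels = VaeImageProcessor().preprocess(image).to("cuda", torch.float16)  # [1, 3, H, W] in [-1, 1]

with torch.no_grad():
    latent = vae.encode(pixels).latent_dist.sample() * vae.config.scaling_factor
```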
## Available VAEs

| | **Endpoint** | **Model** |
|:-:|:-----------:|:--------:|
| **Stable Diffusion v1** | [https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud](https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud) | [`stabilityai/sd-vae-ft-mse`](https://hf.co/stabilityai/sd-vae-ft-mse) |
| **Stable Diffusion XL** | [https://xjqqhmyn62rog84g.us-east-1.aws.endpoints.huggingface.cloud](https://xjqqhmyn62rog84g.us-east-1.aws.endpoints.huggingface.cloud) | [`madebyollin/sdxl-vae-fp16-fix`](https://hf.co/madebyollin/sdxl-vae-fp16-fix) |
| **Flux** | [https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud](https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud) | [`black-forest-labs/FLUX.1-schnell`](https://hf.co/black-forest-labs/FLUX.1-schnell) |

> [!TIP]
> Model support can be requested [here](https://github.com/huggingface/diffusers/issues/new?template=remote-vae-pilot-feedback.yml).

## Code

> [!TIP]
> Install `diffusers` from `main` to run the code: `pip install git+https://github.com/huggingface/diffusers@main`

A helper method simplifies interacting with Hybrid Inference.

```python
from diffusers.utils.remote_utils import remote_encode
```

### Basic example

Let's encode an image, then decode it to demonstrate.

<figure class="image flex flex-col items-center justify-center text-center m-0 w-full">
    <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"/>
</figure>
<details><summary>Code</summary>

```python
from diffusers.utils import load_image
from diffusers.utils.remote_utils import remote_decode

image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg?download=true")

latent = remote_encode(
    endpoint="https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud/",
    image=image,
    scaling_factor=0.3611,
    shift_factor=0.1159,
)

decoded = remote_decode(
    endpoint="https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/",
    tensor=latent,
    scaling_factor=0.3611,
    shift_factor=0.1159,
)
```

</details>

<figure class="image flex flex-col items-center justify-center text-center m-0 w-full">
    <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/remote_vae/decoded.png"/>
</figure>

### Generation

Now let's look at a generation example: we'll encode the image, generate, then remotely decode too!
<details><summary>Code</summary>

```python
import torch
from diffusers import StableDiffusionImg2ImgPipeline
from diffusers.utils import load_image
from diffusers.utils.remote_utils import remote_decode, remote_encode

pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
    variant="fp16",
    vae=None,  # skip loading a local VAE; encode and decode run remotely
).to("cuda")

init_image = load_image(
    "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
)
init_image = init_image.resize((768, 512))

init_latent = remote_encode(
    endpoint="https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud/",
    image=init_image,
    scaling_factor=0.18215,
)

prompt = "A fantasy landscape, trending on artstation"
latent = pipe(
    prompt=prompt,
    image=init_latent,
    strength=0.75,
    output_type="latent",  # keep latents so they can be decoded remotely
).images

image = remote_decode(
    endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/",
    tensor=latent,
    scaling_factor=0.18215,
)
image.save("fantasy_landscape.jpg")
```

</details>

<figure class="image flex flex-col items-center justify-center text-center m-0 w-full">
    <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/remote_vae/fantasy_landscape.png"/>
</figure>
## Integrations

* **[SD.Next](https://github.com/vladmandic/sdnext):** All-in-one UI with direct support for Hybrid Inference.
* **[ComfyUI-HFRemoteVae](https://github.com/kijai/ComfyUI-HFRemoteVae):** ComfyUI node for Hybrid Inference.

src/diffusers/utils/constants.py

+11
@@ -56,3 +56,14 @@

 if USE_PEFT_BACKEND and _CHECK_PEFT:
     dep_version_check("peft")
+
+
+DECODE_ENDPOINT_SD_V1 = "https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/"
+DECODE_ENDPOINT_SD_XL = "https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud/"
+DECODE_ENDPOINT_FLUX = "https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/"
+DECODE_ENDPOINT_HUNYUAN_VIDEO = "https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud/"
+
+
+ENCODE_ENDPOINT_SD_V1 = "https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud/"
+ENCODE_ENDPOINT_SD_XL = "https://xjqqhmyn62rog84g.us-east-1.aws.endpoints.huggingface.cloud/"
+ENCODE_ENDPOINT_FLUX = "https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud/"
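
Not part of the diff, but for orientation: a sketch of how these constants could pair with `remote_encode`. The SD v1 scaling factor 0.18215 is taken from the docstring added in `remote_utils.py` below; the printed shape is indicative only.

```python
# Sketch: pairing the new ENCODE_* constants with remote_encode (assumes diffusers installed from main).
from diffusers.utils import load_image
from diffusers.utils.constants import ENCODE_ENDPOINT_SD_V1
from diffusers.utils.remote_utils import remote_encode

image = load_image(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
)
latent = remote_encode(endpoint=ENCODE_ENDPOINT_SD_V1, image=image, scaling_factor=0.18215)
print(latent.shape)  # expected to be a 4-channel SD latent, e.g. [1, 4, H/8, W/8]
```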

src/diffusers/utils/remote_utils.py

+97-6
@@ -55,7 +55,7 @@ def detect_image_type(data: bytes) -> str:
     return "unknown"
 
 
-def check_inputs(
+def check_inputs_decode(
     endpoint: str,
     tensor: "torch.Tensor",
     processor: Optional[Union["VaeImageProcessor", "VideoProcessor"]] = None,
@@ -89,7 +89,7 @@ def check_inputs(
     )
 
 
-def postprocess(
+def postprocess_decode(
     response: requests.Response,
     processor: Optional[Union["VaeImageProcessor", "VideoProcessor"]] = None,
     output_type: Literal["mp4", "pil", "pt"] = "pil",
@@ -142,7 +142,7 @@ def postprocess(
     return output
 
 
-def prepare(
+def prepare_decode(
     tensor: "torch.Tensor",
     processor: Optional[Union["VaeImageProcessor", "VideoProcessor"]] = None,
     do_scaling: bool = True,
@@ -293,7 +293,7 @@ def remote_decode(
             standard_warn=False,
         )
         output_tensor_type = "binary"
-    check_inputs(
+    check_inputs_decode(
         endpoint,
         tensor,
         processor,
@@ -309,7 +309,7 @@
         height,
         width,
     )
-    kwargs = prepare(
+    kwargs = prepare_decode(
         tensor=tensor,
         processor=processor,
         do_scaling=do_scaling,
@@ -324,11 +324,102 @@
     response = requests.post(endpoint, **kwargs)
     if not response.ok:
         raise RuntimeError(response.json())
-    output = postprocess(
+    output = postprocess_decode(
         response=response,
         processor=processor,
         output_type=output_type,
         return_type=return_type,
         partial_postprocess=partial_postprocess,
     )
     return output
+
+
+def check_inputs_encode(
+    endpoint: str,
+    image: Union["torch.Tensor", Image.Image],
+    scaling_factor: Optional[float] = None,
+    shift_factor: Optional[float] = None,
+):
+    pass
+
+
+def postprocess_encode(
+    response: requests.Response,
+):
+    output_tensor = response.content
+    parameters = response.headers
+    shape = json.loads(parameters["shape"])
+    dtype = parameters["dtype"]
+    torch_dtype = DTYPE_MAP[dtype]
+    output_tensor = torch.frombuffer(bytearray(output_tensor), dtype=torch_dtype).reshape(shape)
+    return output_tensor
+
+
+def prepare_encode(
+    image: Union["torch.Tensor", Image.Image],
+    scaling_factor: Optional[float] = None,
+    shift_factor: Optional[float] = None,
+):
+    headers = {}
+    parameters = {}
+    if scaling_factor is not None:
+        parameters["scaling_factor"] = scaling_factor
+    if shift_factor is not None:
+        parameters["shift_factor"] = shift_factor
+    if isinstance(image, torch.Tensor):
+        data = safetensors.torch._tobytes(image, "tensor")
+        parameters["shape"] = list(image.shape)
+        parameters["dtype"] = str(image.dtype).split(".")[-1]
+    else:
+        buffer = io.BytesIO()
+        image.save(buffer, format="PNG")
+        data = buffer.getvalue()
+    return {"data": data, "params": parameters, "headers": headers}
+
+
+def remote_encode(
+    endpoint: str,
+    image: Union["torch.Tensor", Image.Image],
+    scaling_factor: Optional[float] = None,
+    shift_factor: Optional[float] = None,
+) -> "torch.Tensor":
+    """
+    Hugging Face Hybrid Inference that allows running VAE encode remotely.
+
+    Args:
+        endpoint (`str`):
+            Endpoint for Remote Encode.
+        image (`torch.Tensor` or `PIL.Image.Image`):
+            Image to be encoded.
+        scaling_factor (`float`, *optional*):
+            Scaling is applied when passed e.g. [`latents * self.vae.config.scaling_factor`].
+            - SD v1: 0.18215
+            - SD XL: 0.13025
+            - Flux: 0.3611
+            If `None`, input must be passed with scaling applied.
+        shift_factor (`float`, *optional*):
+            Shift is applied when passed e.g. `latents - self.vae.config.shift_factor`.
+            - Flux: 0.1159
+            If `None`, input must be passed with shift applied.
+
+    Returns:
+        output (`torch.Tensor`).
+    """
+    check_inputs_encode(
+        endpoint,
+        image,
+        scaling_factor,
+        shift_factor,
+    )
+    kwargs = prepare_encode(
+        image=image,
+        scaling_factor=scaling_factor,
+        shift_factor=shift_factor,
+    )
+    response = requests.post(endpoint, **kwargs)
+    if not response.ok:
+        raise RuntimeError(response.json())
+    output = postprocess_encode(
+        response=response,
+    )
+    return output
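
As `prepare_encode` shows, `remote_encode` also accepts a `torch.Tensor`, sending its raw bytes along with `shape`/`dtype` parameters. A sketch of that path follows; the exact preprocessing the endpoint expects for tensor input (here, `VaeImageProcessor`-style `[-1, 1]` NCHW) is an assumption, not something this commit documents.

```python
# Sketch: remote_encode with a torch.Tensor input (tensor bytes + shape/dtype are sent, per prepare_encode above).
from diffusers.image_processor import VaeImageProcessor
from diffusers.utils import load_image
from diffusers.utils.remote_utils import remote_encode

image = load_image(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
)
pixels = VaeImageProcessor().preprocess(image)  # float tensor, [1, 3, H, W], values in [-1, 1] (assumed input format)

latent = remote_encode(
    endpoint="https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud/",  # SD v1 encode endpoint from this commit
    image=pixels,
    scaling_factor=0.18215,
)
```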
