
Commit a517f66

a-r-r-o-w, DN6, and patrickvonplaten authored
AnimateDiff Video to Video (#6328)
* begin animatediff img2video and video2video
* revert animatediff to original implementation
* add img2video as pipeline
* update
* add vid2vid pipeline
* update imports
* update
* remove copied from line for check_inputs
* update
* update examples
* add multi-batch support
* fix __init__.py files
* move img2vid to community
* update community readme and examples
* fix
* make fix-copies
* add vid2vid batch params
* apply suggestions from review
  Co-Authored-By: Dhruv Nair <[email protected]>
* add test for animatediff vid2vid
* torch.stack -> torch.cat
  Co-Authored-By: Dhruv Nair <[email protected]>
* make style
* docs for vid2vid
* update
* fix prepare_latents
* fix docs
* remove img2vid
* update README to :main
* remove slow test
* refactor pipeline output
* update docs
* update docs
* merge community readme from :main
* final fix i promise
* add support for url in animatediff example
* update example
* update callbacks to latest implementation
* Update src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py
  Co-authored-by: Patrick von Platen <[email protected]>
* Update src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py
  Co-authored-by: Patrick von Platen <[email protected]>
* fix merge
* Apply suggestions from code review
* remove callback and callback_steps as suggested in review
* Update tests/pipelines/animatediff/test_animatediff_video2video.py
  Co-authored-by: Patrick von Platen <[email protected]>
* fix import error caused due to unet refactor in #6630
* fix numpy import error after tensor2vid refactor in #6626
* make fix-copies
* fix numpy error
* fix progress bar test

---------

Co-authored-by: Dhruv Nair <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>
1 parent 16748d1 commit a517f66

File tree

10 files changed, +1402 −12 lines

docs/source/en/api/pipelines/animatediff.md (+111)

@@ -25,13 +25,16 @@ The abstract of the paper is the following:
 | Pipeline | Tasks | Demo
 |---|---|:---:|
 | [AnimateDiffPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/animatediff/pipeline_animatediff.py) | *Text-to-Video Generation with AnimateDiff* |
+| [AnimateDiffVideoToVideoPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py) | *Video-to-Video Generation with AnimateDiff* |
 
 ## Available checkpoints
 
 Motion Adapter checkpoints can be found under [guoyww](https://huggingface.co/guoyww/). These checkpoints are meant to work with any model based on Stable Diffusion 1.4/1.5.
 
 ## Usage example
 
+### AnimateDiffPipeline
+
 AnimateDiff works with a MotionAdapter checkpoint and a Stable Diffusion model checkpoint. The MotionAdapter is a collection of Motion Modules that are responsible for adding coherent motion across image frames. These modules are applied after the Resnet and Attention blocks in Stable Diffusion UNet.
 
 The following example demonstrates how to use a *MotionAdapter* checkpoint with Diffusers for inference based on StableDiffusion-1.4/1.5.
@@ -98,6 +101,114 @@ AnimateDiff tends to work better with finetuned Stable Diffusion models. If you
 
 </Tip>
 
+### AnimateDiffVideoToVideoPipeline
+
+AnimateDiff can also be used to generate visually similar videos from an initial video, or to edit the style, characters, background, or other aspects of an existing clip, letting you explore creative variations seamlessly.
+
+```python
+import imageio
+import requests
+import torch
+from diffusers import AnimateDiffVideoToVideoPipeline, DDIMScheduler, MotionAdapter
+from diffusers.utils import export_to_gif
+from io import BytesIO
+from PIL import Image
+
+# Load the motion adapter
+adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16)
+# Load an SD 1.5-based finetuned model
+model_id = "SG161222/Realistic_Vision_V5.1_noVAE"
+pipe = AnimateDiffVideoToVideoPipeline.from_pretrained(model_id, motion_adapter=adapter, torch_dtype=torch.float16)
+scheduler = DDIMScheduler.from_pretrained(
+    model_id,
+    subfolder="scheduler",
+    clip_sample=False,
+    timestep_spacing="linspace",
+    beta_schedule="linear",
+    steps_offset=1,
+)
+pipe.scheduler = scheduler
+
+# Enable memory savings
+pipe.enable_vae_slicing()
+pipe.enable_model_cpu_offload()
+
+# Helper function to load videos
+def load_video(file_path: str):
+    images = []
+
+    if file_path.startswith(("http://", "https://")):
+        # If file_path is a URL, download the video into memory
+        response = requests.get(file_path)
+        response.raise_for_status()
+        content = BytesIO(response.content)
+        vid = imageio.get_reader(content)
+    else:
+        # Otherwise, assume it is a local file path
+        vid = imageio.get_reader(file_path)
+
+    for frame in vid:
+        pil_image = Image.fromarray(frame)
+        images.append(pil_image)
+
+    return images
+
+video = load_video("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-vid2vid-input-1.gif")
+
+output = pipe(
+    video=video,
+    prompt="panda playing a guitar, on a boat, in the ocean, high quality",
+    negative_prompt="bad quality, worse quality",
+    guidance_scale=7.5,
+    num_inference_steps=25,
+    strength=0.5,
+    generator=torch.Generator("cpu").manual_seed(42),
+)
+frames = output.frames[0]
+export_to_gif(frames, "animation.gif")
+```
+
+Here are some sample outputs:
+
+<table>
+    <tr>
+        <th align=center>Source Video</th>
+        <th align=center>Output Video</th>
+    </tr>
+    <tr>
+        <td align=center>
+            raccoon playing a guitar
+            <br />
+            <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-vid2vid-input-1.gif"
+                alt="raccoon playing a guitar"
+                style="width: 300px;" />
+        </td>
+        <td align=center>
+            panda playing a guitar
+            <br/>
+            <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-vid2vid-output-1.gif"
+                alt="panda playing a guitar"
+                style="width: 300px;" />
+        </td>
+    </tr>
+    <tr>
+        <td align=center>
+            closeup of margot robbie, fireworks in the background, high quality
+            <br />
+            <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-vid2vid-input-2.gif"
+                alt="closeup of margot robbie, fireworks in the background, high quality"
+                style="width: 300px;" />
+        </td>
+        <td align=center>
+            closeup of tony stark, robert downey jr, fireworks
+            <br/>
+            <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-vid2vid-output-2.gif"
+                alt="closeup of tony stark, robert downey jr, fireworks"
+                style="width: 300px;" />
+        </td>
+    </tr>
+</table>
+
 ## Using Motion LoRAs
 
 Motion LoRAs are a collection of LoRAs that work with the `guoyww/animatediff-motion-adapter-v1-5-2` checkpoint. These LoRAs are responsible for adding specific types of motion to the animations.
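A note on the `strength` argument used in the new docs example: as in image-to-image generation, it trades faithfulness to the source video against freedom to change it, by controlling how much of the denoising schedule is actually run on the encoded input frames. The snippet below is a minimal sketch of that convention, assuming the video-to-video pipeline follows the same timestep logic as diffusers' existing img2img pipelines; the names are illustrative, not the pipeline's internals.

```python
# Rough sketch: how `strength` typically maps to the number of denoising steps
# actually executed on the noised input-video latents (img2img convention;
# the exact internals of AnimateDiffVideoToVideoPipeline may differ).
num_inference_steps = 25
strength = 0.5

init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
t_start = max(num_inference_steps - init_timestep, 0)

# Only the final `init_timestep` steps of the schedule are run, so lower strength
# stays closer to the source video.
print(f"steps run: {num_inference_steps - t_start} of {num_inference_steps}")  # steps run: 12 of 25
```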

src/diffusers/__init__.py (+2)

@@ -208,6 +208,7 @@
             "AmusedInpaintPipeline",
             "AmusedPipeline",
             "AnimateDiffPipeline",
+            "AnimateDiffVideoToVideoPipeline",
             "AudioLDM2Pipeline",
             "AudioLDM2ProjectionModel",
             "AudioLDM2UNet2DConditionModel",
@@ -569,6 +570,7 @@
         AmusedInpaintPipeline,
         AmusedPipeline,
         AnimateDiffPipeline,
+        AnimateDiffVideoToVideoPipeline,
         AudioLDM2Pipeline,
         AudioLDM2ProjectionModel,
         AudioLDM2UNet2DConditionModel,
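These two additions register the new pipeline both in the lazily built top-level namespace and in the eager import branch, so user code can import it directly from `diffusers`. A minimal sketch of what that enables, reusing the checkpoints from the docs example above:

```python
import torch

# With the new export in place, the pipeline is importable from the top-level package.
from diffusers import AnimateDiffVideoToVideoPipeline, MotionAdapter

adapter = MotionAdapter.from_pretrained(
    "guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16
)
pipe = AnimateDiffVideoToVideoPipeline.from_pretrained(
    "SG161222/Realistic_Vision_V5.1_noVAE",  # SD 1.5-based checkpoint used in the docs example
    motion_adapter=adapter,
    torch_dtype=torch.float16,
)
```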

src/diffusers/pipelines/__init__.py (+5 −2)

@@ -109,7 +109,10 @@
         ]
     )
     _import_structure["amused"] = ["AmusedImg2ImgPipeline", "AmusedInpaintPipeline", "AmusedPipeline"]
-    _import_structure["animatediff"] = ["AnimateDiffPipeline"]
+    _import_structure["animatediff"] = [
+        "AnimateDiffPipeline",
+        "AnimateDiffVideoToVideoPipeline",
+    ]
     _import_structure["audioldm"] = ["AudioLDMPipeline"]
     _import_structure["audioldm2"] = [
         "AudioLDM2Pipeline",
@@ -341,7 +344,7 @@
         from ..utils.dummy_torch_and_transformers_objects import *
     else:
         from .amused import AmusedImg2ImgPipeline, AmusedInpaintPipeline, AmusedPipeline
-        from .animatediff import AnimateDiffPipeline
+        from .animatediff import AnimateDiffPipeline, AnimateDiffVideoToVideoPipeline
         from .audioldm import AudioLDMPipeline
         from .audioldm2 import (
             AudioLDM2Pipeline,

src/diffusers/pipelines/animatediff/__init__.py (+6 −3)

@@ -11,7 +11,7 @@
 
 
 _dummy_objects = {}
-_import_structure = {}
+_import_structure = {"pipeline_output": ["AnimateDiffPipelineOutput"]}
 
 try:
     if not (is_transformers_available() and is_torch_available()):
@@ -21,7 +21,8 @@
 
     _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
 else:
-    _import_structure["pipeline_animatediff"] = ["AnimateDiffPipeline", "AnimateDiffPipelineOutput"]
+    _import_structure["pipeline_animatediff"] = ["AnimateDiffPipeline"]
+    _import_structure["pipeline_animatediff_video2video"] = ["AnimateDiffVideoToVideoPipeline"]
 
 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     try:
@@ -31,7 +32,9 @@
         from ...utils.dummy_torch_and_transformers_objects import *
 
     else:
-        from .pipeline_animatediff import AnimateDiffPipeline, AnimateDiffPipelineOutput
+        from .pipeline_animatediff import AnimateDiffPipeline
+        from .pipeline_animatediff_video2video import AnimateDiffVideoToVideoPipeline
+        from .pipeline_output import AnimateDiffPipelineOutput
 
 else:
     import sys
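The `_import_structure` mapping here drives diffusers' lazy-import machinery, while the `TYPE_CHECKING`/`DIFFUSERS_SLOW_IMPORT` branch provides the matching eager imports. The practical effect of the refactor is that `AnimateDiffPipelineOutput` now comes from the new `pipeline_output` module while all three public names remain importable from the subpackage; a small sketch:

```python
# All three names still resolve from the animatediff subpackage after this change;
# AnimateDiffPipelineOutput is now sourced from pipeline_output rather than
# pipeline_animatediff.
from diffusers.pipelines.animatediff import (
    AnimateDiffPipeline,
    AnimateDiffPipelineOutput,
    AnimateDiffVideoToVideoPipeline,
)
```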

src/diffusers/pipelines/animatediff/pipeline_animatediff.py (+1 −7)

@@ -14,7 +14,6 @@
 
 import inspect
 import math
-from dataclasses import dataclass
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import numpy as np
@@ -37,7 +36,6 @@
 )
 from ...utils import (
     USE_PEFT_BACKEND,
-    BaseOutput,
     deprecate,
     logging,
     replace_example_docstring,
@@ -46,6 +44,7 @@
 )
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import AnimateDiffPipelineOutput
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -153,11 +152,6 @@ def _freq_mix_3d(x: torch.Tensor, noise: torch.Tensor, LPF: torch.Tensor) -> tor
     return x_mixed
 
 
-@dataclass
-class AnimateDiffPipelineOutput(BaseOutput):
-    frames: Union[torch.Tensor, np.ndarray]
-
-
 class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin):
     r"""
     Pipeline for text-to-video generation.