Skip to content

[docs] AnimateDiff FreeNoise #9414

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Sep 11, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 73 additions & 0 deletions docs/source/en/api/pipelines/animatediff.md
Original file line number Diff line number Diff line change
Expand Up @@ -914,6 +914,79 @@ export_to_gif(frames, "animatelcm-motion-lora.gif")
</tr>
</table>

## Using FreeNoise

[FreeNoise: Tuning-Free Longer Video Diffusion via Noise Rescheduling](https://arxiv.org/abs/2310.15169) by Haonan Qiu, Menghan Xia, Yong Zhang, Yingqing He, Xintao Wang, Ying Shan, Ziwei Liu.

FreeNoise is a sampling mechanism that allows the generation of longer videos with short-video generation models by employing noise-rescheduling, temporal attention over sliding windows, and weighted averaging of latent frames. It also can be used with multiple prompts to allow for interpolated video generations. More details are available in the paper.

The currently supported AnimateDiff pipelines that can be used with FreeNoise are:
- AnimateDiffPipeline
- AnimateDiffControlNetPipeline
- AnimateDiffVideoToVideoPipeline
- AnimateDiffVideoToVideoControlNetPipeline

```python
import torch
from diffusers import AutoencoderKL, AnimateDiffPipeline, LCMScheduler, MotionAdapter
from diffusers.utils import export_to_video, load_image

# Load pipeline
dtype = torch.float16
motion_adapter = MotionAdapter.from_pretrained("wangfuyun/AnimateLCM", torch_dtype=dtype)
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=dtype)

pipe = AnimateDiffPipeline.from_pretrained("emilianJR/epiCRealism", motion_adapter=motion_adapter, vae=vae, torch_dtype=dtype)
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config, beta_schedule="linear")

pipe.load_lora_weights(
"wangfuyun/AnimateLCM", weight_name="AnimateLCM_sd15_t2v_lora.safetensors", adapter_name="lcm_lora"
)
pipe.set_adapters(["lcm_lora"], [0.8])

# Enable FreeNoise for long prompt generation
pipe.enable_free_noise(context_length=16, context_stride=4)
pipe.to("cuda")

# Can be a single prompt, or a dictionary with frame timesteps
prompt = {
0: "A caterpillar on a leaf, high quality, photorealistic",
40: "A caterpillar transforming into a cocoon, on a leaf, near flowers, photorealistic",
80: "A cocoon on a leaf, flowers in the backgrond, photorealistic",
120: "A cocoon maturing and a butterfly being born, flowers and leaves visible in the background, photorealistic",
160: "A beautiful butterfly, vibrant colors, sitting on a leaf, flowers in the background, photorealistic",
200: "A beautiful butterfly, flying away in a forest, photorealistic",
240: "A cyberpunk butterfly, neon lights, glowing",
}
negative_prompt = "bad quality, worst quality, jpeg artifacts"

# Run inference
output = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
num_frames=256,
guidance_scale=2.5,
num_inference_steps=10,
generator=torch.Generator("cpu").manual_seed(0),
)

# Save video
frames = output.frames[0]
export_to_video(frames, "output.mp4", fps=16)
```

#### FreeNoise memory savings

Since FreeNoise processes multiple frames together, there are parts in the modeling where the memory required exceeds that available on normal consumer GPUs. The main memory bottlenecks that we identified are spatial and temporal attention blocks, upsampling and downsampling blocks, resnet blocks and feed-forward layers. Since most of these blocks operate effectively only on the channel/embedding dimension, one can perform chunked inference across the batch dimensions. The batch dimension in AnimateDiff are either spatial (`[B x F, H x W, C]`) or temporal (`B x H x W, F, C`) in nature (note that it may seem counter-intuitive, but the batch dimension here are correct, because spatial blocks process across the `B x F` dimension while the temporal blocks process across the `B x H x W` dimension). We introduce a `SplitInferenceModule` that makes it easier to chunk across any dimension and perform inference. This saves a lot of memory but comes at the cost of requiring more time for inference.

```diff
# Load pipeline and adapters
# ...
+ pipe.enable_free_noise_split_inference()
+ pipe.unet.enable_forward_chunking(16)
```

The call to `pipe.enable_free_noise_split_inference` method accepts two parameters: `spatial_split_size` (defaults to `256`) and `temporal_split_size` (defaults to `16`). These can be configured based on how much VRAM you have available. A lower split size results in lower memory usage but slower inference, whereas a larger split size results in faster inference at the cost of more memory.

## Using `from_single_file` with the MotionAdapter

Expand Down
Loading