
Commit 13e4849

[LTX0.9.5] Refactor LTXConditionPipeline for text-only conditioning (#11174)
* Refactor `LTXConditionPipeline` to add text-only conditioning
* style
* up
* Refactor `LTXConditionPipeline` to streamline condition handling and improve clarity
* Improve condition checks
* Simplify latents handling based on conditioning type
* Refactor rope_interpolation_scale preparation for clarity and efficiency
* Update LTXConditionPipeline docstring to clarify supported input types
* Add LTX Video 0.9.5 model to documentation
* Clarify documentation to indicate support for text-only conditioning without passing `conditions`
* refactor: comment out unused parameters in LTXConditionPipeline
* fix: restore previously commented parameters in LTXConditionPipeline
* fix: remove unused parameters from LTXConditionPipeline
* refactor: remove unnecessary lines in LTXConditionPipeline
1 parent 94f2c48 commit 13e4849

File tree: 2 files changed, +133 −119 lines

* docs/source/en/api/pipelines/ltx_video.md (+1)
* src/diffusers/pipelines/ltx/pipeline_ltx_condition.py (+132 −119)

Diff for: docs/source/en/api/pipelines/ltx_video.md (+1)

@@ -32,6 +32,7 @@ Available models:
 |:-------------:|:-----------------:|
 | [`LTX Video 0.9.0`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.safetensors) | `torch.bfloat16` |
 | [`LTX Video 0.9.1`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.1.safetensors) | `torch.bfloat16` |
+| [`LTX Video 0.9.5`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.5.safetensors) | `torch.bfloat16` |
 
 Note: The recommended dtype is for the transformer component. The VAE and text encoders can be either `torch.float32`, `torch.bfloat16` or `torch.float16` but the recommended dtype is `torch.bfloat16` as used in the original repository.

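To make the dtype note above concrete, here is a minimal loading sketch. It assumes a diffusers-format repo id (`Lightricks/LTX-Video-0.9.5`), which is not named in this commit; the table only lists the single-file `.safetensors` checkpoint.

```python
import torch
from diffusers import LTXConditionPipeline

# Minimal sketch: bfloat16 is the recommended dtype for the transformer; per the note
# above, the VAE and text encoder could also run in float32 or float16.
# "Lightricks/LTX-Video-0.9.5" is an assumed repo id, not taken from this diff.
pipe = LTXConditionPipeline.from_pretrained(
    "Lightricks/LTX-Video-0.9.5",
    torch_dtype=torch.bfloat16,
)
pipe.to("cuda")
```
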
Diff for: src/diffusers/pipelines/ltx/pipeline_ltx_condition.py (+132 −119)

@@ -14,7 +14,7 @@
 
 import inspect
 from dataclasses import dataclass
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import PIL.Image
 import torch
@@ -75,6 +75,7 @@
 
         >>> # Generate video
         >>> generator = torch.Generator("cuda").manual_seed(0)
+        >>> # Text-only conditioning is also supported without the need to pass `conditions`
         >>> video = pipe(
         ...     conditions=[condition1, condition2],
         ...     prompt=prompt,
@@ -223,7 +224,7 @@ def retrieve_latents(
 
 class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin):
     r"""
-    Pipeline for image-to-video generation.
+    Pipeline for text/image/video-to-video generation.
 
     Reference: https://github.com/Lightricks/LTX-Video
 
@@ -482,9 +483,6 @@ def check_inputs(
         if conditions is not None and (image is not None or video is not None):
             raise ValueError("If `conditions` is provided, `image` and `video` must not be provided.")
 
-        if conditions is None and (image is None and video is None):
-            raise ValueError("If `conditions` is not provided, `image` or `video` must be provided.")
-
         if conditions is None:
             if isinstance(image, list) and isinstance(frame_index, list) and len(image) != len(frame_index):
                 raise ValueError(
@@ -642,9 +640,9 @@ def add_noise_to_image_conditioning_latents(
 
     def prepare_latents(
         self,
-        conditions: List[torch.Tensor],
-        condition_strength: List[float],
-        condition_frame_index: List[int],
+        conditions: Optional[List[torch.Tensor]] = None,
+        condition_strength: Optional[List[float]] = None,
+        condition_frame_index: Optional[List[int]] = None,
         batch_size: int = 1,
         num_channels_latents: int = 128,
         height: int = 512,
@@ -654,85 +652,88 @@ def prepare_latents(
         generator: Optional[torch.Generator] = None,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
-    ) -> None:
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]:
         num_latent_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
         latent_height = height // self.vae_spatial_compression_ratio
         latent_width = width // self.vae_spatial_compression_ratio
 
         shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width)
         latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
 
-        condition_latent_frames_mask = torch.zeros((batch_size, num_latent_frames), device=device, dtype=torch.float32)
-
-        extra_conditioning_latents = []
-        extra_conditioning_video_ids = []
-        extra_conditioning_mask = []
-        extra_conditioning_num_latents = 0
-        for data, strength, frame_index in zip(conditions, condition_strength, condition_frame_index):
-            condition_latents = retrieve_latents(self.vae.encode(data), generator=generator)
-            condition_latents = self._normalize_latents(
-                condition_latents, self.vae.latents_mean, self.vae.latents_std
-            ).to(device, dtype=dtype)
-
-            num_data_frames = data.size(2)
-            num_cond_frames = condition_latents.size(2)
-
-            if frame_index == 0:
-                latents[:, :, :num_cond_frames] = torch.lerp(
-                    latents[:, :, :num_cond_frames], condition_latents, strength
-                )
-                condition_latent_frames_mask[:, :num_cond_frames] = strength
+        if len(conditions) > 0:
+            condition_latent_frames_mask = torch.zeros(
+                (batch_size, num_latent_frames), device=device, dtype=torch.float32
+            )
 
-            else:
-                if num_data_frames > 1:
-                    if num_cond_frames < num_prefix_latent_frames:
-                        raise ValueError(
-                            f"Number of latent frames must be at least {num_prefix_latent_frames} but got {num_data_frames}."
-                        )
-
-                    if num_cond_frames > num_prefix_latent_frames:
-                        start_frame = frame_index // self.vae_temporal_compression_ratio + num_prefix_latent_frames
-                        end_frame = start_frame + num_cond_frames - num_prefix_latent_frames
-                        latents[:, :, start_frame:end_frame] = torch.lerp(
-                            latents[:, :, start_frame:end_frame],
-                            condition_latents[:, :, num_prefix_latent_frames:],
-                            strength,
-                        )
-                        condition_latent_frames_mask[:, start_frame:end_frame] = strength
-                        condition_latents = condition_latents[:, :, :num_prefix_latent_frames]
-
-                noise = randn_tensor(condition_latents.shape, generator=generator, device=device, dtype=dtype)
-                condition_latents = torch.lerp(noise, condition_latents, strength)
-
-                condition_video_ids = self._prepare_video_ids(
-                    batch_size,
-                    condition_latents.size(2),
-                    latent_height,
-                    latent_width,
-                    patch_size=self.transformer_spatial_patch_size,
-                    patch_size_t=self.transformer_temporal_patch_size,
-                    device=device,
-                )
-                condition_video_ids = self._scale_video_ids(
-                    condition_video_ids,
-                    scale_factor=self.vae_spatial_compression_ratio,
-                    scale_factor_t=self.vae_temporal_compression_ratio,
-                    frame_index=frame_index,
-                    device=device,
-                )
-                condition_latents = self._pack_latents(
-                    condition_latents,
-                    self.transformer_spatial_patch_size,
-                    self.transformer_temporal_patch_size,
-                )
-                condition_conditioning_mask = torch.full(
-                    condition_latents.shape[:2], strength, device=device, dtype=dtype
-                )
+            extra_conditioning_latents = []
+            extra_conditioning_video_ids = []
+            extra_conditioning_mask = []
+            extra_conditioning_num_latents = 0
+            for data, strength, frame_index in zip(conditions, condition_strength, condition_frame_index):
+                condition_latents = retrieve_latents(self.vae.encode(data), generator=generator)
+                condition_latents = self._normalize_latents(
+                    condition_latents, self.vae.latents_mean, self.vae.latents_std
+                ).to(device, dtype=dtype)
+
+                num_data_frames = data.size(2)
+                num_cond_frames = condition_latents.size(2)
+
+                if frame_index == 0:
+                    latents[:, :, :num_cond_frames] = torch.lerp(
+                        latents[:, :, :num_cond_frames], condition_latents, strength
+                    )
+                    condition_latent_frames_mask[:, :num_cond_frames] = strength
+
+                else:
+                    if num_data_frames > 1:
+                        if num_cond_frames < num_prefix_latent_frames:
+                            raise ValueError(
+                                f"Number of latent frames must be at least {num_prefix_latent_frames} but got {num_data_frames}."
+                            )
+
+                        if num_cond_frames > num_prefix_latent_frames:
+                            start_frame = frame_index // self.vae_temporal_compression_ratio + num_prefix_latent_frames
+                            end_frame = start_frame + num_cond_frames - num_prefix_latent_frames
+                            latents[:, :, start_frame:end_frame] = torch.lerp(
+                                latents[:, :, start_frame:end_frame],
+                                condition_latents[:, :, num_prefix_latent_frames:],
+                                strength,
+                            )
+                            condition_latent_frames_mask[:, start_frame:end_frame] = strength
+                            condition_latents = condition_latents[:, :, :num_prefix_latent_frames]
+
+                    noise = randn_tensor(condition_latents.shape, generator=generator, device=device, dtype=dtype)
+                    condition_latents = torch.lerp(noise, condition_latents, strength)
+
+                    condition_video_ids = self._prepare_video_ids(
+                        batch_size,
+                        condition_latents.size(2),
+                        latent_height,
+                        latent_width,
+                        patch_size=self.transformer_spatial_patch_size,
+                        patch_size_t=self.transformer_temporal_patch_size,
+                        device=device,
+                    )
+                    condition_video_ids = self._scale_video_ids(
+                        condition_video_ids,
+                        scale_factor=self.vae_spatial_compression_ratio,
+                        scale_factor_t=self.vae_temporal_compression_ratio,
+                        frame_index=frame_index,
+                        device=device,
+                    )
+                    condition_latents = self._pack_latents(
+                        condition_latents,
+                        self.transformer_spatial_patch_size,
+                        self.transformer_temporal_patch_size,
+                    )
+                    condition_conditioning_mask = torch.full(
+                        condition_latents.shape[:2], strength, device=device, dtype=dtype
+                    )
 
-                extra_conditioning_latents.append(condition_latents)
-                extra_conditioning_video_ids.append(condition_video_ids)
-                extra_conditioning_mask.append(condition_conditioning_mask)
-                extra_conditioning_num_latents += condition_latents.size(1)
+                    extra_conditioning_latents.append(condition_latents)
+                    extra_conditioning_video_ids.append(condition_video_ids)
+                    extra_conditioning_mask.append(condition_conditioning_mask)
+                    extra_conditioning_num_latents += condition_latents.size(1)
 
         video_ids = self._prepare_video_ids(
             batch_size,
@@ -743,7 +744,10 @@ def prepare_latents(
             patch_size=self.transformer_spatial_patch_size,
             device=device,
         )
-        conditioning_mask = condition_latent_frames_mask.gather(1, video_ids[:, 0])
+        if len(conditions) > 0:
+            conditioning_mask = condition_latent_frames_mask.gather(1, video_ids[:, 0])
+        else:
+            conditioning_mask, extra_conditioning_num_latents = None, 0
         video_ids = self._scale_video_ids(
             video_ids,
             scale_factor=self.vae_spatial_compression_ratio,
@@ -755,7 +759,7 @@ def prepare_latents(
             latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
         )
 
-        if len(extra_conditioning_latents) > 0:
+        if len(conditions) > 0 and len(extra_conditioning_latents) > 0:
             latents = torch.cat([*extra_conditioning_latents, latents], dim=1)
             video_ids = torch.cat([*extra_conditioning_video_ids, video_ids], dim=2)
             conditioning_mask = torch.cat([*extra_conditioning_mask, conditioning_mask], dim=1)
@@ -955,7 +959,7 @@ def __call__(
             frame_index = [condition.frame_index for condition in conditions]
             image = [condition.image for condition in conditions]
            video = [condition.video for condition in conditions]
-        else:
+        elif image is not None or video is not None:
            if not isinstance(image, list):
                 image = [image]
                 num_conditions = 1
@@ -999,32 +1003,34 @@ def __call__(
         vae_dtype = self.vae.dtype
 
         conditioning_tensors = []
-        for condition_image, condition_video, condition_frame_index, condition_strength in zip(
-            image, video, frame_index, strength
-        ):
-            if condition_image is not None:
-                condition_tensor = (
-                    self.video_processor.preprocess(condition_image, height, width)
-                    .unsqueeze(2)
-                    .to(device, dtype=vae_dtype)
-                )
-            elif condition_video is not None:
-                condition_tensor = self.video_processor.preprocess_video(condition_video, height, width)
-                num_frames_input = condition_tensor.size(2)
-                num_frames_output = self.trim_conditioning_sequence(
-                    condition_frame_index, num_frames_input, num_frames
-                )
-                condition_tensor = condition_tensor[:, :, :num_frames_output]
-                condition_tensor = condition_tensor.to(device, dtype=vae_dtype)
-            else:
-                raise ValueError("Either `image` or `video` must be provided in the `LTXVideoCondition`.")
-
-            if condition_tensor.size(2) % self.vae_temporal_compression_ratio != 1:
-                raise ValueError(
-                    f"Number of frames in the video must be of the form (k * {self.vae_temporal_compression_ratio} + 1) "
-                    f"but got {condition_tensor.size(2)} frames."
-                )
-            conditioning_tensors.append(condition_tensor)
+        is_conditioning_image_or_video = image is not None or video is not None
+        if is_conditioning_image_or_video:
+            for condition_image, condition_video, condition_frame_index, condition_strength in zip(
+                image, video, frame_index, strength
+            ):
+                if condition_image is not None:
+                    condition_tensor = (
+                        self.video_processor.preprocess(condition_image, height, width)
+                        .unsqueeze(2)
+                        .to(device, dtype=vae_dtype)
+                    )
+                elif condition_video is not None:
+                    condition_tensor = self.video_processor.preprocess_video(condition_video, height, width)
+                    num_frames_input = condition_tensor.size(2)
+                    num_frames_output = self.trim_conditioning_sequence(
+                        condition_frame_index, num_frames_input, num_frames
+                    )
+                    condition_tensor = condition_tensor[:, :, :num_frames_output]
+                    condition_tensor = condition_tensor.to(device, dtype=vae_dtype)
+                else:
+                    raise ValueError("Either `image` or `video` must be provided for conditioning.")
+
+                if condition_tensor.size(2) % self.vae_temporal_compression_ratio != 1:
+                    raise ValueError(
+                        f"Number of frames in the video must be of the form (k * {self.vae_temporal_compression_ratio} + 1) "
+                        f"but got {condition_tensor.size(2)} frames."
+                    )
+                conditioning_tensors.append(condition_tensor)
 
         # 4. Prepare latent variables
         num_channels_latents = self.transformer.config.in_channels
@@ -1045,7 +1051,7 @@ def __call__(
         video_coords = video_coords.float()
         video_coords[:, 0] = video_coords[:, 0] * (1.0 / frame_rate)
 
-        init_latents = latents.clone()
+        init_latents = latents.clone() if is_conditioning_image_or_video else None
 
         if self.do_classifier_free_guidance:
             video_coords = torch.cat([video_coords, video_coords], dim=0)
@@ -1065,15 +1071,15 @@ def __call__(
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
         self._num_timesteps = len(timesteps)
 
-        # 7. Denoising loop
+        # 6. Denoising loop
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 if self.interrupt:
                     continue
 
                 self._current_timestep = t
 
-                if image_cond_noise_scale > 0:
+                if image_cond_noise_scale > 0 and init_latents is not None:
                     # Add timestep-dependent noise to the hard-conditioning latents
                     # This helps with motion continuity, especially when conditioned on a single frame
                     latents = self.add_noise_to_image_conditioning_latents(
@@ -1086,16 +1092,18 @@ def __call__(
                     )
 
                 latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
-                conditioning_mask_model_input = (
-                    torch.cat([conditioning_mask, conditioning_mask])
-                    if self.do_classifier_free_guidance
-                    else conditioning_mask
-                )
+                if is_conditioning_image_or_video:
+                    conditioning_mask_model_input = (
+                        torch.cat([conditioning_mask, conditioning_mask])
+                        if self.do_classifier_free_guidance
+                        else conditioning_mask
+                    )
                 latent_model_input = latent_model_input.to(prompt_embeds.dtype)
 
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latent_model_input.shape[0]).unsqueeze(-1).float()
-                timestep = torch.min(timestep, (1 - conditioning_mask_model_input) * 1000.0)
+                if is_conditioning_image_or_video:
+                    timestep = torch.min(timestep, (1 - conditioning_mask_model_input) * 1000.0)
 
                 noise_pred = self.transformer(
                     hidden_states=latent_model_input,
@@ -1115,8 +1123,11 @@ def __call__(
                 denoised_latents = self.scheduler.step(
                     -noise_pred, t, latents, per_token_timesteps=timestep, return_dict=False
                 )[0]
-                tokens_to_denoise_mask = (t / 1000 - 1e-6 < (1.0 - conditioning_mask)).unsqueeze(-1)
-                latents = torch.where(tokens_to_denoise_mask, denoised_latents, latents)
+                if is_conditioning_image_or_video:
+                    tokens_to_denoise_mask = (t / 1000 - 1e-6 < (1.0 - conditioning_mask)).unsqueeze(-1)
+                    latents = torch.where(tokens_to_denoise_mask, denoised_latents, latents)
+                else:
+                    latents = denoised_latents
 
                 if callback_on_step_end is not None:
                     callback_kwargs = {}
@@ -1134,7 +1145,9 @@ def __call__(
                 if XLA_AVAILABLE:
                     xm.mark_step()
 
-        latents = latents[:, extra_conditioning_num_latents:]
+        if is_conditioning_image_or_video:
+            latents = latents[:, extra_conditioning_num_latents:]
+
         latents = self._unpack_latents(
             latents,
             latent_num_frames,

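The net effect of the refactor is that a prompt-only call now works: `check_inputs` no longer rejects the case where neither `conditions` nor `image`/`video` is given, and the denoising loop skips the conditioning-mask logic. A minimal text-only sketch under the same assumed repo id as above; the resolution, frame count, and sampler settings are illustrative, not values prescribed by this commit.

```python
import torch
from diffusers import LTXConditionPipeline
from diffusers.utils import export_to_video

# Assumed diffusers-format repo id, as in the loading sketch earlier.
pipe = LTXConditionPipeline.from_pretrained("Lightricks/LTX-Video-0.9.5", torch_dtype=torch.bfloat16)
pipe.to("cuda")

prompt = "A serene mountain lake at sunrise, mist drifting over the water"

# No `conditions`, `image`, or `video` are passed: after this commit the pipeline
# skips the conditioning branches (conditioning_mask and init_latents stay None)
# and denoises from pure noise guided only by the text embeddings.
video = pipe(
    prompt=prompt,
    negative_prompt="worst quality, inconsistent motion, blurry, jittery, distorted",
    width=704,               # illustrative resolution
    height=480,              # illustrative resolution
    num_frames=161,          # illustrative frame count
    num_inference_steps=40,
    generator=torch.Generator("cuda").manual_seed(0),
).frames[0]

export_to_video(video, "output.mp4", fps=24)
```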