 import inspect
 from dataclasses import dataclass
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import PIL.Image
 import torch

         >>> # Generate video
         >>> generator = torch.Generator("cuda").manual_seed(0)
+        >>> # Text-only conditioning is also supported without the need to pass `conditions`
         >>> video = pipe(
         ...     conditions=[condition1, condition2],
         ...     prompt=prompt,
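
A minimal prompt-only sketch of the mode the added docstring line points at (illustrative only: the checkpoint name, resolution, and frame count below are assumptions, not values taken from this change):

import torch
from diffusers import LTXConditionPipeline
from diffusers.utils import export_to_video

# Assumed checkpoint; any LTX-Video checkpoint that ships this pipeline should behave the same.
pipe = LTXConditionPipeline.from_pretrained("Lightricks/LTX-Video-0.9.5", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# No `conditions`, `image`, or `video` are passed: pure text-to-video generation.
video = pipe(
    prompt="A calm mountain lake at sunrise, mist drifting over the water",
    negative_prompt="worst quality, blurry, jittery, distorted",
    width=704,
    height=480,
    num_frames=161,  # must be of the form k * 8 + 1 (see the frame-count check further down)
    num_inference_steps=40,
    generator=torch.Generator("cuda").manual_seed(0),
).frames[0]
export_to_video(video, "output.mp4", fps=24)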
@@ -223,7 +224,7 @@ def retrieve_latents(

 class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin):
     r"""
-    Pipeline for image-to-video generation.
+    Pipeline for text/image/video-to-video generation.

     Reference: https://github.com/Lightricks/LTX-Video

@@ -482,9 +483,6 @@ def check_inputs(
         if conditions is not None and (image is not None or video is not None):
             raise ValueError("If `conditions` is provided, `image` and `video` must not be provided.")

-        if conditions is None and (image is None and video is None):
-            raise ValueError("If `conditions` is not provided, `image` or `video` must be provided.")
-
         if conditions is None:
             if isinstance(image, list) and isinstance(frame_index, list) and len(image) != len(frame_index):
                 raise ValueError(
@@ -642,9 +640,9 @@ def add_noise_to_image_conditioning_latents(

     def prepare_latents(
         self,
-        conditions: List[torch.Tensor],
-        condition_strength: List[float],
-        condition_frame_index: List[int],
+        conditions: Optional[List[torch.Tensor]] = None,
+        condition_strength: Optional[List[float]] = None,
+        condition_frame_index: Optional[List[int]] = None,
         batch_size: int = 1,
         num_channels_latents: int = 128,
         height: int = 512,
@@ -654,85 +652,88 @@ def prepare_latents(
         generator: Optional[torch.Generator] = None,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
-    ) -> None:
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]:
         num_latent_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
         latent_height = height // self.vae_spatial_compression_ratio
         latent_width = width // self.vae_spatial_compression_ratio

         shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width)
         latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)

-        condition_latent_frames_mask = torch.zeros((batch_size, num_latent_frames), device=device, dtype=torch.float32)
-
-        extra_conditioning_latents = []
-        extra_conditioning_video_ids = []
-        extra_conditioning_mask = []
-        extra_conditioning_num_latents = 0
-        for data, strength, frame_index in zip(conditions, condition_strength, condition_frame_index):
-            condition_latents = retrieve_latents(self.vae.encode(data), generator=generator)
-            condition_latents = self._normalize_latents(
-                condition_latents, self.vae.latents_mean, self.vae.latents_std
-            ).to(device, dtype=dtype)
-
-            num_data_frames = data.size(2)
-            num_cond_frames = condition_latents.size(2)
-
-            if frame_index == 0:
-                latents[:, :, :num_cond_frames] = torch.lerp(
-                    latents[:, :, :num_cond_frames], condition_latents, strength
-                )
-                condition_latent_frames_mask[:, :num_cond_frames] = strength
+        if len(conditions) > 0:
+            condition_latent_frames_mask = torch.zeros(
+                (batch_size, num_latent_frames), device=device, dtype=torch.float32
+            )

-            else:
-                if num_data_frames > 1:
-                    if num_cond_frames < num_prefix_latent_frames:
-                        raise ValueError(
-                            f"Number of latent frames must be at least {num_prefix_latent_frames} but got {num_data_frames}."
-                        )
-
-                    if num_cond_frames > num_prefix_latent_frames:
-                        start_frame = frame_index // self.vae_temporal_compression_ratio + num_prefix_latent_frames
-                        end_frame = start_frame + num_cond_frames - num_prefix_latent_frames
-                        latents[:, :, start_frame:end_frame] = torch.lerp(
-                            latents[:, :, start_frame:end_frame],
-                            condition_latents[:, :, num_prefix_latent_frames:],
-                            strength,
-                        )
-                        condition_latent_frames_mask[:, start_frame:end_frame] = strength
-                        condition_latents = condition_latents[:, :, :num_prefix_latent_frames]
-
-            noise = randn_tensor(condition_latents.shape, generator=generator, device=device, dtype=dtype)
-            condition_latents = torch.lerp(noise, condition_latents, strength)
-
-            condition_video_ids = self._prepare_video_ids(
-                batch_size,
-                condition_latents.size(2),
-                latent_height,
-                latent_width,
-                patch_size=self.transformer_spatial_patch_size,
-                patch_size_t=self.transformer_temporal_patch_size,
-                device=device,
-            )
-            condition_video_ids = self._scale_video_ids(
-                condition_video_ids,
-                scale_factor=self.vae_spatial_compression_ratio,
-                scale_factor_t=self.vae_temporal_compression_ratio,
-                frame_index=frame_index,
-                device=device,
-            )
-            condition_latents = self._pack_latents(
-                condition_latents,
-                self.transformer_spatial_patch_size,
-                self.transformer_temporal_patch_size,
-            )
-            condition_conditioning_mask = torch.full(
-                condition_latents.shape[:2], strength, device=device, dtype=dtype
-            )
+            extra_conditioning_latents = []
+            extra_conditioning_video_ids = []
+            extra_conditioning_mask = []
+            extra_conditioning_num_latents = 0
+            for data, strength, frame_index in zip(conditions, condition_strength, condition_frame_index):
+                condition_latents = retrieve_latents(self.vae.encode(data), generator=generator)
+                condition_latents = self._normalize_latents(
+                    condition_latents, self.vae.latents_mean, self.vae.latents_std
+                ).to(device, dtype=dtype)
+
+                num_data_frames = data.size(2)
+                num_cond_frames = condition_latents.size(2)
+
+                if frame_index == 0:
+                    latents[:, :, :num_cond_frames] = torch.lerp(
+                        latents[:, :, :num_cond_frames], condition_latents, strength
+                    )
+                    condition_latent_frames_mask[:, :num_cond_frames] = strength
+
+                else:
+                    if num_data_frames > 1:
+                        if num_cond_frames < num_prefix_latent_frames:
+                            raise ValueError(
+                                f"Number of latent frames must be at least {num_prefix_latent_frames} but got {num_data_frames}."
+                            )
+
+                        if num_cond_frames > num_prefix_latent_frames:
+                            start_frame = frame_index // self.vae_temporal_compression_ratio + num_prefix_latent_frames
+                            end_frame = start_frame + num_cond_frames - num_prefix_latent_frames
+                            latents[:, :, start_frame:end_frame] = torch.lerp(
+                                latents[:, :, start_frame:end_frame],
+                                condition_latents[:, :, num_prefix_latent_frames:],
+                                strength,
+                            )
+                            condition_latent_frames_mask[:, start_frame:end_frame] = strength
+                            condition_latents = condition_latents[:, :, :num_prefix_latent_frames]
+
+                noise = randn_tensor(condition_latents.shape, generator=generator, device=device, dtype=dtype)
+                condition_latents = torch.lerp(noise, condition_latents, strength)
+
+                condition_video_ids = self._prepare_video_ids(
+                    batch_size,
+                    condition_latents.size(2),
+                    latent_height,
+                    latent_width,
+                    patch_size=self.transformer_spatial_patch_size,
+                    patch_size_t=self.transformer_temporal_patch_size,
+                    device=device,
+                )
+                condition_video_ids = self._scale_video_ids(
+                    condition_video_ids,
+                    scale_factor=self.vae_spatial_compression_ratio,
+                    scale_factor_t=self.vae_temporal_compression_ratio,
+                    frame_index=frame_index,
+                    device=device,
+                )
+                condition_latents = self._pack_latents(
+                    condition_latents,
+                    self.transformer_spatial_patch_size,
+                    self.transformer_temporal_patch_size,
+                )
+                condition_conditioning_mask = torch.full(
+                    condition_latents.shape[:2], strength, device=device, dtype=dtype
+                )

-            extra_conditioning_latents.append(condition_latents)
-            extra_conditioning_video_ids.append(condition_video_ids)
-            extra_conditioning_mask.append(condition_conditioning_mask)
-            extra_conditioning_num_latents += condition_latents.size(1)
+                extra_conditioning_latents.append(condition_latents)
+                extra_conditioning_video_ids.append(condition_video_ids)
+                extra_conditioning_mask.append(condition_conditioning_mask)
+                extra_conditioning_num_latents += condition_latents.size(1)

         video_ids = self._prepare_video_ids(
             batch_size,
@@ -743,7 +744,10 @@ def prepare_latents(
             patch_size=self.transformer_spatial_patch_size,
             device=device,
         )
-        conditioning_mask = condition_latent_frames_mask.gather(1, video_ids[:, 0])
+        if len(conditions) > 0:
+            conditioning_mask = condition_latent_frames_mask.gather(1, video_ids[:, 0])
+        else:
+            conditioning_mask, extra_conditioning_num_latents = None, 0
         video_ids = self._scale_video_ids(
             video_ids,
             scale_factor=self.vae_spatial_compression_ratio,
@@ -755,7 +759,7 @@ def prepare_latents(
             latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
         )

-        if len(extra_conditioning_latents) > 0:
+        if len(conditions) > 0 and len(extra_conditioning_latents) > 0:
             latents = torch.cat([*extra_conditioning_latents, latents], dim=1)
             video_ids = torch.cat([*extra_conditioning_video_ids, video_ids], dim=2)
             conditioning_mask = torch.cat([*extra_conditioning_mask, conditioning_mask], dim=1)
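
The shapes flowing through prepare_latents follow directly from the VAE compression ratios used above. A standalone sketch of that arithmetic (the ratios of 32 spatial / 8 temporal and patch size 1 are assumptions matching common LTX configurations, not values read from this diff):

def latent_geometry(height=512, width=704, num_frames=161, spatial_ratio=32, temporal_ratio=8):
    # Mirrors the formulas at the top of prepare_latents.
    num_latent_frames = (num_frames - 1) // temporal_ratio + 1
    latent_height = height // spatial_ratio
    latent_width = width // spatial_ratio
    # Packed token count seen by the transformer when the patch size is 1.
    num_tokens = num_latent_frames * latent_height * latent_width
    return num_latent_frames, latent_height, latent_width, num_tokens

print(latent_geometry())  # (21, 16, 22, 7392)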
@@ -955,7 +959,7 @@ def __call__(
             frame_index = [condition.frame_index for condition in conditions]
             image = [condition.image for condition in conditions]
             video = [condition.video for condition in conditions]
-        else:
+        elif image is not None or video is not None:
             if not isinstance(image, list):
                 image = [image]
                 num_conditions = 1
@@ -999,32 +1003,34 @@ def __call__(
         vae_dtype = self.vae.dtype

         conditioning_tensors = []
-        for condition_image, condition_video, condition_frame_index, condition_strength in zip(
-            image, video, frame_index, strength
-        ):
-            if condition_image is not None:
-                condition_tensor = (
-                    self.video_processor.preprocess(condition_image, height, width)
-                    .unsqueeze(2)
-                    .to(device, dtype=vae_dtype)
-                )
-            elif condition_video is not None:
-                condition_tensor = self.video_processor.preprocess_video(condition_video, height, width)
-                num_frames_input = condition_tensor.size(2)
-                num_frames_output = self.trim_conditioning_sequence(
-                    condition_frame_index, num_frames_input, num_frames
-                )
-                condition_tensor = condition_tensor[:, :, :num_frames_output]
-                condition_tensor = condition_tensor.to(device, dtype=vae_dtype)
-            else:
-                raise ValueError("Either `image` or `video` must be provided in the `LTXVideoCondition`.")
-
-            if condition_tensor.size(2) % self.vae_temporal_compression_ratio != 1:
-                raise ValueError(
-                    f"Number of frames in the video must be of the form (k * {self.vae_temporal_compression_ratio} + 1) "
-                    f"but got {condition_tensor.size(2)} frames."
-                )
-            conditioning_tensors.append(condition_tensor)
+        is_conditioning_image_or_video = image is not None or video is not None
+        if is_conditioning_image_or_video:
+            for condition_image, condition_video, condition_frame_index, condition_strength in zip(
+                image, video, frame_index, strength
+            ):
+                if condition_image is not None:
+                    condition_tensor = (
+                        self.video_processor.preprocess(condition_image, height, width)
+                        .unsqueeze(2)
+                        .to(device, dtype=vae_dtype)
+                    )
+                elif condition_video is not None:
+                    condition_tensor = self.video_processor.preprocess_video(condition_video, height, width)
+                    num_frames_input = condition_tensor.size(2)
+                    num_frames_output = self.trim_conditioning_sequence(
+                        condition_frame_index, num_frames_input, num_frames
+                    )
+                    condition_tensor = condition_tensor[:, :, :num_frames_output]
+                    condition_tensor = condition_tensor.to(device, dtype=vae_dtype)
+                else:
+                    raise ValueError("Either `image` or `video` must be provided for conditioning.")
+
+                if condition_tensor.size(2) % self.vae_temporal_compression_ratio != 1:
+                    raise ValueError(
+                        f"Number of frames in the video must be of the form (k * {self.vae_temporal_compression_ratio} + 1) "
+                        f"but got {condition_tensor.size(2)} frames."
+                    )
+                conditioning_tensors.append(condition_tensor)

         # 4. Prepare latent variables
         num_channels_latents = self.transformer.config.in_channels
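
The check above only accepts conditioning clips whose frame count has the form k * temporal_ratio + 1. A small hedged helper (not part of the pipeline) for trimming an arbitrary clip to the nearest valid length, assuming the usual temporal compression ratio of 8:

def nearest_valid_num_frames(num_frames: int, temporal_ratio: int = 8) -> int:
    # Round down to the closest value of the form k * temporal_ratio + 1.
    if num_frames < 1:
        raise ValueError("Need at least one frame.")
    return ((num_frames - 1) // temporal_ratio) * temporal_ratio + 1

assert nearest_valid_num_frames(1) == 1      # a single image is always valid
assert nearest_valid_num_frames(24) == 17    # a 24-frame clip would be trimmed to 17
assert nearest_valid_num_frames(161) == 161  # already of the form 20 * 8 + 1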
@@ -1045,7 +1051,7 @@ def __call__(
         video_coords = video_coords.float()
         video_coords[:, 0] = video_coords[:, 0] * (1.0 / frame_rate)

-        init_latents = latents.clone()
+        init_latents = latents.clone() if is_conditioning_image_or_video else None

         if self.do_classifier_free_guidance:
             video_coords = torch.cat([video_coords, video_coords], dim=0)
@@ -1065,15 +1071,15 @@ def __call__(
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
         self._num_timesteps = len(timesteps)

-        # 7. Denoising loop
+        # 6. Denoising loop
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 if self.interrupt:
                     continue

                 self._current_timestep = t

-                if image_cond_noise_scale > 0:
+                if image_cond_noise_scale > 0 and init_latents is not None:
                     # Add timestep-dependent noise to the hard-conditioning latents
                     # This helps with motion continuity, especially when conditioned on a single frame
                     latents = self.add_noise_to_image_conditioning_latents(
@@ -1086,16 +1092,18 @@ def __call__(
                     )

                 latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
-                conditioning_mask_model_input = (
-                    torch.cat([conditioning_mask, conditioning_mask])
-                    if self.do_classifier_free_guidance
-                    else conditioning_mask
-                )
+                if is_conditioning_image_or_video:
+                    conditioning_mask_model_input = (
+                        torch.cat([conditioning_mask, conditioning_mask])
+                        if self.do_classifier_free_guidance
+                        else conditioning_mask
+                    )
                 latent_model_input = latent_model_input.to(prompt_embeds.dtype)

                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latent_model_input.shape[0]).unsqueeze(-1).float()
-                timestep = torch.min(timestep, (1 - conditioning_mask_model_input) * 1000.0)
+                if is_conditioning_image_or_video:
+                    timestep = torch.min(timestep, (1 - conditioning_mask_model_input) * 1000.0)

                 noise_pred = self.transformer(
                     hidden_states=latent_model_input,
@@ -1115,8 +1123,11 @@ def __call__(
                 denoised_latents = self.scheduler.step(
                     noise_pred, t, latents, per_token_timesteps=timestep, return_dict=False
                 )[0]
-                tokens_to_denoise_mask = (t / 1000 - 1e-6 < (1.0 - conditioning_mask)).unsqueeze(-1)
-                latents = torch.where(tokens_to_denoise_mask, denoised_latents, latents)
+                if is_conditioning_image_or_video:
+                    tokens_to_denoise_mask = (t / 1000 - 1e-6 < (1.0 - conditioning_mask)).unsqueeze(-1)
+                    latents = torch.where(tokens_to_denoise_mask, denoised_latents, latents)
+                else:
+                    latents = denoised_latents

                 if callback_on_step_end is not None:
                     callback_kwargs = {}
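
The two guarded blocks above are what keep hard-conditioned tokens effectively clean while free tokens follow the scheduler. A toy illustration of that masking with made-up values (not the pipeline's actual tensors):

import torch

t = torch.tensor(800.0)  # current timestep on the 0-1000 scale
conditioning_mask = torch.tensor([1.0, 0.8, 0.0, 0.0])  # per-token conditioning strength

# Conditioned tokens get their timestep clamped toward 0, i.e. treated as already denoised.
timestep = torch.min(t.expand(4), (1 - conditioning_mask) * 1000.0)
print(timestep)  # tensor([  0., 200., 800., 800.])

# Only tokens whose conditioning is weaker than the current noise level are updated.
tokens_to_denoise_mask = t / 1000 - 1e-6 < (1.0 - conditioning_mask)
print(tokens_to_denoise_mask)  # tensor([False, False,  True,  True])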
@@ -1134,7 +1145,9 @@ def __call__(
                 if XLA_AVAILABLE:
                     xm.mark_step()

-        latents = latents[:, extra_conditioning_num_latents:]
+        if is_conditioning_image_or_video:
+            latents = latents[:, extra_conditioning_num_latents:]
+
         latents = self._unpack_latents(
             latents,
             latent_num_frames,