@@ -109,14 +109,30 @@ def prompt_clean(text):
109
109
110
110
111
111
def _normalize_latent_dist(latent_dist, latents_mean: torch.Tensor, latents_std: torch.Tensor) -> None:
    # Shift/scale the distribution's statistics in place, then re-derive std/var
    # from the updated logvar so the distribution stays internally consistent.
    # NOTE(review): applying the same (x - mean) * std transform to `logvar` as to
    # `mean` looks unusual mathematically, but it is the upstream behavior and is
    # preserved as-is — confirm against the pipeline's VAE latent statistics.
    latent_dist.mean = (latent_dist.mean - latents_mean) * latents_std
    latent_dist.logvar = torch.clamp((latent_dist.logvar - latents_mean) * latents_std, -30.0, 20.0)
    latent_dist.std = torch.exp(0.5 * latent_dist.logvar)
    latent_dist.var = torch.exp(latent_dist.logvar)


def retrieve_latents(
    encoder_output: torch.Tensor,
    latents_mean: torch.Tensor,
    latents_std: torch.Tensor,
    generator: Optional[torch.Generator] = None,
    sample_mode: str = "sample",
):
    """Extract normalized latents from a VAE encoder output.

    The latents are normalized with the VAE's per-channel statistics as
    ``(x - latents_mean) * latents_std`` before being returned.

    Args:
        encoder_output: Output of ``vae.encode``; expected to expose either a
            ``latent_dist`` attribute (a Gaussian with ``mean``/``logvar``) or a
            precomputed ``latents`` tensor.
        latents_mean: Per-channel latent mean, broadcastable to the latents.
        latents_std: Per-channel latent scale, broadcastable to the latents.
        generator: Optional RNG used when stochastically sampling the latents.
        sample_mode: ``"sample"`` draws from the distribution; ``"argmax"``
            returns its mode.

    Returns:
        The normalized latent tensor.

    Raises:
        AttributeError: If ``encoder_output`` exposes neither ``latent_dist``
            nor ``latents``.
    """
    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
        _normalize_latent_dist(encoder_output.latent_dist, latents_mean, latents_std)
        return encoder_output.latent_dist.sample(generator)
    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
        _normalize_latent_dist(encoder_output.latent_dist, latents_mean, latents_std)
        return encoder_output.latent_dist.mode()
    elif hasattr(encoder_output, "latents"):
        # Latents were precomputed; only the normalization remains to be applied.
        return (encoder_output.latents - latents_mean) * latents_std
    else:
        raise AttributeError("Could not access latents of provided encoder_output")
@@ -385,13 +401,6 @@ def prepare_latents(
385
401
)
386
402
video_condition = video_condition .to (device = device , dtype = dtype )
387
403
388
- if isinstance (generator , list ):
389
- latent_condition = [retrieve_latents (self .vae .encode (video_condition ), g ) for g in generator ]
390
- latents = latent_condition = torch .cat (latent_condition )
391
- else :
392
- latent_condition = retrieve_latents (self .vae .encode (video_condition ), generator )
393
- latent_condition = latent_condition .repeat (batch_size , 1 , 1 , 1 , 1 )
394
-
395
404
latents_mean = (
396
405
torch .tensor (self .vae .config .latents_mean )
397
406
.view (1 , self .vae .config .z_dim , 1 , 1 , 1 )
@@ -401,7 +410,14 @@ def prepare_latents(
401
410
latents .device , latents .dtype
402
411
)
403
412
404
- latent_condition = (latent_condition - latents_mean ) * latents_std
413
+ if isinstance (generator , list ):
414
+ latent_condition = [
415
+ retrieve_latents (self .vae .encode (video_condition ), latents_mean , latents_std , g ) for g in generator
416
+ ]
417
+ latent_condition = torch .cat (latent_condition )
418
+ else :
419
+ latent_condition = retrieve_latents (self .vae .encode (video_condition ), latents_mean , latents_std , generator )
420
+ latent_condition = latent_condition .repeat (batch_size , 1 , 1 , 1 , 1 )
405
421
406
422
mask_lat_size = torch .ones (batch_size , 1 , num_frames , latent_height , latent_width )
407
423
mask_lat_size [:, :, list (range (1 , num_frames ))] = 0
0 commit comments