
Commit 4ea9f89

Wan Pipeline scaling fix, type hint warning, multi generator fix (#11007)
* Wan Pipeline scaling fix, type hint warning, multi generator fix

* Apply suggestions from code review
1 parent 733b44a commit 4ea9f89

1 file changed (+26 −10 lines)

src/diffusers/pipelines/wan/pipeline_wan_i2v.py

@@ -109,14 +109,30 @@ def prompt_clean(text):
 
 
 def retrieve_latents(
-    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+    encoder_output: torch.Tensor,
+    latents_mean: torch.Tensor,
+    latents_std: torch.Tensor,
+    generator: Optional[torch.Generator] = None,
+    sample_mode: str = "sample",
 ):
     if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+        encoder_output.latent_dist.mean = (encoder_output.latent_dist.mean - latents_mean) * latents_std
+        encoder_output.latent_dist.logvar = torch.clamp(
+            (encoder_output.latent_dist.logvar - latents_mean) * latents_std, -30.0, 20.0
+        )
+        encoder_output.latent_dist.std = torch.exp(0.5 * encoder_output.latent_dist.logvar)
+        encoder_output.latent_dist.var = torch.exp(encoder_output.latent_dist.logvar)
         return encoder_output.latent_dist.sample(generator)
     elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        encoder_output.latent_dist.mean = (encoder_output.latent_dist.mean - latents_mean) * latents_std
+        encoder_output.latent_dist.logvar = torch.clamp(
+            (encoder_output.latent_dist.logvar - latents_mean) * latents_std, -30.0, 20.0
+        )
+        encoder_output.latent_dist.std = torch.exp(0.5 * encoder_output.latent_dist.logvar)
+        encoder_output.latent_dist.var = torch.exp(encoder_output.latent_dist.logvar)
         return encoder_output.latent_dist.mode()
     elif hasattr(encoder_output, "latents"):
-        return encoder_output.latents
+        return (encoder_output.latents - latents_mean) * latents_std
     else:
         raise AttributeError("Could not access latents of provided encoder_output")
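To see the new signature in isolation, here is a minimal sketch that exercises the retrieve_latents defined in the hunk above with a toy stand-in for the VAE encoder output. FakeLatentDist, FakeEncoderOutput, and the shapes and statistics below are illustrative assumptions, not part of this commit.

# Toy sketch (not from the commit): exercising the new retrieve_latents signature.
import torch


class FakeLatentDist:
    """Stand-in exposing only the attributes retrieve_latents touches."""

    def __init__(self, mean, logvar):
        self.mean = mean
        self.logvar = logvar
        self.std = torch.exp(0.5 * logvar)
        self.var = torch.exp(logvar)

    def sample(self, generator=None):
        # Reparameterized draw from N(mean, std^2)
        return self.mean + self.std * torch.randn(self.mean.shape, generator=generator)

    def mode(self):
        return self.mean


class FakeEncoderOutput:
    def __init__(self, latent_dist):
        self.latent_dist = latent_dist


latents_mean = torch.zeros(1, 16, 1, 1, 1)  # toy per-channel shift
latents_std = torch.ones(1, 16, 1, 1, 1)    # toy per-channel scale
output = FakeEncoderOutput(FakeLatentDist(torch.randn(2, 16, 1, 4, 4), torch.zeros(2, 16, 1, 4, 4)))

# The normalization now happens inside retrieve_latents, so the caller
# no longer rescales the returned latents afterwards.
latents = retrieve_latents(output, latents_mean, latents_std, sample_mode="argmax")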

@@ -385,13 +401,6 @@ def prepare_latents(
         )
         video_condition = video_condition.to(device=device, dtype=dtype)
 
-        if isinstance(generator, list):
-            latent_condition = [retrieve_latents(self.vae.encode(video_condition), g) for g in generator]
-            latents = latent_condition = torch.cat(latent_condition)
-        else:
-            latent_condition = retrieve_latents(self.vae.encode(video_condition), generator)
-            latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1)
-
         latents_mean = (
             torch.tensor(self.vae.config.latents_mean)
             .view(1, self.vae.config.z_dim, 1, 1, 1)
@@ -401,7 +410,14 @@ def prepare_latents(
             latents.device, latents.dtype
         )
 
-        latent_condition = (latent_condition - latents_mean) * latents_std
+        if isinstance(generator, list):
+            latent_condition = [
+                retrieve_latents(self.vae.encode(video_condition), latents_mean, latents_std, g) for g in generator
+            ]
+            latent_condition = torch.cat(latent_condition)
+        else:
+            latent_condition = retrieve_latents(self.vae.encode(video_condition), latents_mean, latents_std, generator)
+            latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1)
 
         mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
         mask_lat_size[:, :, list(range(1, num_frames))] = 0
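The "multi generator fix" in the title is the isinstance(generator, list) branch above, which now concatenates one encoded condition per generator instead of assigning into latents. A hedged usage sketch that would exercise that path through WanImageToVideoPipeline follows; the checkpoint id, input image, and call arguments are assumptions for illustration and are not taken from this commit.

# Hedged usage sketch: drive the list-of-generators path fixed above.
# The checkpoint id and arguments are illustrative, not from this commit.
import torch
from diffusers import WanImageToVideoPipeline
from diffusers.utils import load_image

pipe = WanImageToVideoPipeline.from_pretrained(
    "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers", torch_dtype=torch.bfloat16
)
pipe.to("cuda")

image = load_image("astronaut.png")  # any RGB image; the path is a placeholder

# Two videos from one prompt, each with its own generator, so prepare_latents
# receives a list and takes the isinstance(generator, list) branch.
generators = [torch.Generator("cuda").manual_seed(seed) for seed in (0, 1)]
frames = pipe(
    image=image,
    prompt="an astronaut riding a horse on the beach",
    num_videos_per_prompt=2,
    num_frames=33,
    generator=generators,
).frames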
