|
@@ -14,7 +14,7 @@
 from typing import Any, Dict, List
 
 import modules.sd_hijack
-from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet
+from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors
 from modules.sd_hijack import model_hijack
 from modules.shared import opts, cmd_opts, state
 import modules.shared as shared
@@ -538,6 +538,40 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
     return x
 
 
+def decode_latent_batch(model, batch, target_device=None, check_for_nans=False):
+    samples = []
+
+    for i in range(batch.shape[0]):
+        sample = decode_first_stage(model, batch[i:i + 1])[0]
+
+        if check_for_nans:
+            try:
+                devices.test_for_nans(sample, "vae")
+            except devices.NansException as e:
+                if devices.dtype_vae == torch.float32 or not shared.opts.auto_vae_precision:
+                    raise e
+
+                errors.print_error_explanation(
+                    "A tensor with all NaNs was produced in VAE.\n"
+                    "Web UI will now convert VAE into 32-bit float and retry.\n"
+                    "To disable this behavior, disable the 'Automatically revert VAE to 32-bit floats' setting.\n"
+                    "To always start with 32-bit VAE, use --no-half-vae commandline flag."
+                )
+
+                devices.dtype_vae = torch.float32
+                model.first_stage_model.to(devices.dtype_vae)
+                batch = batch.to(devices.dtype_vae)
+
+                sample = decode_first_stage(model, batch[i:i + 1])[0]
+
+        if target_device is not None:
+            sample = sample.to(target_device)
+
+        samples.append(sample)
+
+    return samples
+
+
 def decode_first_stage(model, x):
     x = model.decode_first_stage(x.to(devices.dtype_vae))
 
@@ -758,10 +792,7 @@ def infotext(iteration=0, position_in_batch=0, use_main_prompt=False):
             with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
                 samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
 
-            x_samples_ddim = [decode_first_stage(p.sd_model, samples_ddim[i:i+1].to(dtype=devices.dtype_vae))[0].cpu() for i in range(samples_ddim.size(0))]
-            for x in x_samples_ddim:
-                devices.test_for_nans(x, "vae")
-
+            x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)
             x_samples_ddim = torch.stack(x_samples_ddim).float()
             x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
 
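For context, a minimal usage sketch of the new helper outside of `process_images_inner`. It assumes a running Web UI environment with a checkpoint already loaded as `shared.sd_model`, and uses a randomly generated `latents` tensor as a stand-in for sampler output; the tensor shape and the stand-in data are illustrative assumptions, not part of this change.

```python
import torch

from modules import devices, shared
from modules.processing import decode_latent_batch

# Stand-in latent batch (illustrative only); real latents come from a sampler.
latents = torch.randn(2, 4, 64, 64, device=devices.device, dtype=devices.dtype_vae)

# Decode each latent through the VAE. With check_for_nans=True, a NaN result from
# a half-precision VAE triggers the fallback added above: the VAE is converted to
# float32 and the sample is decoded again (unless auto_vae_precision is disabled).
images = decode_latent_batch(shared.sd_model, latents, target_device=devices.cpu, check_for_nans=True)

# Same post-processing as the call site in this diff: stack and map [-1, 1] to [0, 1].
images = torch.stack(images).float()
images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0)
```

Passing `target_device=devices.cpu` moves each decoded sample off the GPU as it is produced, matching how the call site above keeps VRAM usage bounded per batch.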
|
|