Skip to content

Commit 23c947a

Browse files
committed
automatically switch to 32-bit float VAE if the generated picture has NaNs.
1 parent 0e47c36 commit 23c947a

File tree

3 files changed

+39
-6
lines changed

3 files changed

+39
-6
lines changed

CHANGELOG.md

+2-1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@
2929
* speedup extra networks listing
3030
* added `[none]` filename token.
3131
* removed thumbs extra networks view mode (use settings tab to change width/height/scale to get thumbs)
32-
* add always_discard_next_to_last_sigma option to XYZ plot
32+
* add always_discard_next_to_last_sigma option to XYZ plot
33+
* automatically switch to 32-bit float VAE if the generated picture has NaNs without the need for `--no-half-vae` commandline flag.
3334

3435
### Extensions and API:
3536
* api endpoints: /sdapi/v1/server-kill, /sdapi/v1/server-restart, /sdapi/v1/server-stop

modules/processing.py

+36-5
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from typing import Any, Dict, List
1515

1616
import modules.sd_hijack
17-
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet
17+
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors
1818
from modules.sd_hijack import model_hijack
1919
from modules.shared import opts, cmd_opts, state
2020
import modules.shared as shared
@@ -538,6 +538,40 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
538538
return x
539539

540540

541+
def decode_latent_batch(model, batch, target_device=None, check_for_nans=False):
542+
samples = []
543+
544+
for i in range(batch.shape[0]):
545+
sample = decode_first_stage(model, batch[i:i + 1])[0]
546+
547+
if check_for_nans:
548+
try:
549+
devices.test_for_nans(sample, "vae")
550+
except devices.NansException as e:
551+
if devices.dtype_vae == torch.float32 or not shared.opts.auto_vae_precision:
552+
raise e
553+
554+
errors.print_error_explanation(
555+
"A tensor with all NaNs was produced in VAE.\n"
556+
"Web UI will now convert VAE into 32-bit float and retry.\n"
557+
"To disable this behavior, disable the 'Automaticlly revert VAE to 32-bit floats' setting.\n"
558+
"To always start with 32-bit VAE, use --no-half-vae commandline flag."
559+
)
560+
561+
devices.dtype_vae = torch.float32
562+
model.first_stage_model.to(devices.dtype_vae)
563+
batch = batch.to(devices.dtype_vae)
564+
565+
sample = decode_first_stage(model, batch[i:i + 1])[0]
566+
567+
if target_device is not None:
568+
sample = sample.to(target_device)
569+
570+
samples.append(sample)
571+
572+
return samples
573+
574+
541575
def decode_first_stage(model, x):
542576
x = model.decode_first_stage(x.to(devices.dtype_vae))
543577

@@ -758,10 +792,7 @@ def infotext(iteration=0, position_in_batch=0, use_main_prompt=False):
758792
with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
759793
samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
760794

761-
x_samples_ddim = [decode_first_stage(p.sd_model, samples_ddim[i:i+1].to(dtype=devices.dtype_vae))[0].cpu() for i in range(samples_ddim.size(0))]
762-
for x in x_samples_ddim:
763-
devices.test_for_nans(x, "vae")
764-
795+
x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)
765796
x_samples_ddim = torch.stack(x_samples_ddim).float()
766797
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
767798

modules/shared.py

+1
Original file line numberDiff line numberDiff line change
@@ -427,6 +427,7 @@ def list_samplers():
427427
"comma_padding_backtrack": OptionInfo(20, "Prompt word wrap length limit", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1}).info("in tokens - for texts shorter than specified, if they don't fit into 75 token limit, move them to the next 75 token chunk"),
428428
"CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#clip-skip").info("ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer"),
429429
"upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
430+
"auto_vae_precision": OptionInfo(True, "Automaticlly revert VAE to 32-bit floats").info("triggers when a tensor with NaNs is produced in VAE; disabling the option in this case will result in a black square image"),
430431
"randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors"),
431432
}))
432433

0 commit comments

Comments
 (0)