Commit 629b756

Merge pull request #288 from testIgnor/main
Add Custom Tiled VAE Sizes to NeverOOM
2 parents 542450b + 20362cf, commit 629b756

File tree: 3 files changed (+57 −11 lines)

extensions-builtin/sd_forge_neveroom/scripts/forge_never_oom.py (+36 −3)
@@ -1,4 +1,6 @@
 import gradio as gr
+import torch
+import modules.devices as devices
 
 from modules import scripts
 from ldm_patched.modules import model_management
@@ -17,23 +19,54 @@ def title(self):
     def show(self, is_img2img):
         return scripts.AlwaysVisible
 
+    """
+    The following two functions are pulled directly from
+    pkuliyi2015/multidiffusion-upscaler-for-automatic1111
+    """
+    def get_rcmd_enc_tsize(self):
+        if torch.cuda.is_available() and devices.device not in ['cpu', devices.cpu]:
+            total_memory = torch.cuda.get_device_properties(devices.device).total_memory // 2**20
+            if total_memory > 16*1000: ENCODER_TILE_SIZE = 3072
+            elif total_memory > 12*1000: ENCODER_TILE_SIZE = 2048
+            elif total_memory > 8*1000: ENCODER_TILE_SIZE = 1536
+            else: ENCODER_TILE_SIZE = 960
+        else: ENCODER_TILE_SIZE = 512
+        return ENCODER_TILE_SIZE
+
+    def get_rcmd_dec_tsize(self):
+        if torch.cuda.is_available() and devices.device not in ['cpu', devices.cpu]:
+            total_memory = torch.cuda.get_device_properties(devices.device).total_memory // 2**20
+            if total_memory > 30*1000: DECODER_TILE_SIZE = 256
+            elif total_memory > 16*1000: DECODER_TILE_SIZE = 192
+            elif total_memory > 12*1000: DECODER_TILE_SIZE = 128
+            elif total_memory > 8*1000: DECODER_TILE_SIZE = 96
+            else: DECODER_TILE_SIZE = 64
+        else: DECODER_TILE_SIZE = 64
+        return DECODER_TILE_SIZE
+
     def ui(self, *args, **kwargs):
         with gr.Accordion(open=False, label=self.title()):
             unet_enabled = gr.Checkbox(label='Enabled for UNet (always maximize offload)', value=False)
             vae_enabled = gr.Checkbox(label='Enabled for VAE (always tiled)', value=False)
-        return unet_enabled, vae_enabled
+            encoder_tile_size = gr.Slider(label='Encoder Tile Size', minimum=256, maximum=4096, step=16, value=self.get_rcmd_enc_tsize())
+            decoder_tile_size = gr.Slider(label='Decoder Tile Size', minimum=48, maximum=512, step=16, value=self.get_rcmd_dec_tsize())
+        return unet_enabled, vae_enabled, encoder_tile_size, decoder_tile_size
 
     def process(self, p, *script_args, **kwargs):
-        unet_enabled, vae_enabled = script_args
+        unet_enabled, vae_enabled, encoder_tile_size, decoder_tile_size = script_args
 
         if unet_enabled:
             print('NeverOOM Enabled for UNet (always maximize offload)')
 
         if vae_enabled:
             print('NeverOOM Enabled for VAE (always tiled)')
+            print('With tile sizes')
+            print(f'Encode:\t x:{encoder_tile_size}\t y:{encoder_tile_size}')
+            print(f'Decode:\t x:{decoder_tile_size}\t y:{decoder_tile_size}')
 
         model_management.VAE_ALWAYS_TILED = vae_enabled
-
+        model_management.VAE_ENCODE_TILE_SIZE_X = model_management.VAE_ENCODE_TILE_SIZE_Y = encoder_tile_size
+        model_management.VAE_DECODE_TILE_SIZE_X = model_management.VAE_DECODE_TILE_SIZE_Y = decoder_tile_size
         if self.previous_unet_enabled != unet_enabled:
             model_management.unload_all_models()
             if unet_enabled:
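
For context, the recommended-size heuristic above is a plain VRAM lookup. Below is a standalone sketch of the same logic; the function name `recommended_vae_tile_sizes` is hypothetical and merges the two `get_rcmd_*` methods for illustration, with thresholds copied from the diff. The encoder works in pixel space and the decoder in latent space, which is why the two scales differ by roughly the VAE's 8x downscale factor.

```python
import torch

def recommended_vae_tile_sizes(device="cuda"):
    """Hypothetical standalone version of the two heuristics above.

    Returns (encoder_tile_size, decoder_tile_size); thresholds are
    total VRAM in MiB, exactly as in the diff.
    """
    if not torch.cuda.is_available():
        return 512, 64  # conservative defaults for CPU / unknown devices

    total_mib = torch.cuda.get_device_properties(device).total_memory // 2**20

    # Encoder tiles are measured in pixels...
    if total_mib > 16 * 1000:
        enc = 3072
    elif total_mib > 12 * 1000:
        enc = 2048
    elif total_mib > 8 * 1000:
        enc = 1536
    else:
        enc = 960

    # ...decoder tiles in latent units, hence the smaller numbers.
    if total_mib > 30 * 1000:
        dec = 256
    elif total_mib > 16 * 1000:
        dec = 192
    elif total_mib > 12 * 1000:
        dec = 128
    else:
        dec = 96 if total_mib > 8 * 1000 else 64

    return enc, dec
```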

ldm_patched/modules/model_management.py (+8 −3)
@@ -204,6 +204,11 @@ def is_nvidia():
 VAE_DTYPE = torch.float32
 
 VAE_ALWAYS_TILED = False
+# include VAE tile size parameters
+VAE_ENCODE_TILE_SIZE_X = 512
+VAE_ENCODE_TILE_SIZE_Y = 512
+VAE_DECODE_TILE_SIZE_X = 64
+VAE_DECODE_TILE_SIZE_Y = 64
 
 def set_fp16_accumulation_if_available():
     if args.allow_fp16_accumulation:
@@ -506,7 +511,7 @@ def load_models_gpu(models, memory_required=0):
 
         if vram_set_state == VRAMState.NO_VRAM:
             async_kept_memory = 0
-
+
         loaded_model.model_load(async_kept_memory)
         current_loaded_models.insert(0, loaded_model)
 
@@ -692,7 +697,7 @@ def cast_to_device(tensor, device, dtype, copy=False):
         return tensor.to(device, non_blocking=non_blocking).to(dtype, non_blocking=non_blocking)
     else:
         return tensor.to(device, dtype, copy=copy, non_blocking=non_blocking)
-
+
 def sage_attention_enabled():
     return args.use_sage_attention
 
@@ -758,7 +763,7 @@ def get_free_memory(dev=None, torch_free_too=False):
         return (mem_free_total, mem_free_torch)
     else:
         return mem_free_total
-
+
 def mac_version():
     try:
         return tuple(int(n) for n in platform.mac_ver()[0].split("."))
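
These new module-level globals are the hand-off point between the extension and the VAE code path: `forge_never_oom.py` writes them in `process()`, and `sd.py` reads them on every tiled encode/decode. A minimal sketch of setting them directly, bypassing the UI sliders (the values are illustrative, not recommendations):

```python
from ldm_patched.modules import model_management

# Force tiled VAE with custom tile sizes; this mirrors what the
# NeverOOM script does once per generation. Values are illustrative.
model_management.VAE_ALWAYS_TILED = True
model_management.VAE_ENCODE_TILE_SIZE_X = model_management.VAE_ENCODE_TILE_SIZE_Y = 1536
model_management.VAE_DECODE_TILE_SIZE_X = model_management.VAE_DECODE_TILE_SIZE_Y = 96
```

Plain globals keep the patch small, but they are process-wide state: two concurrent generations with different slider values would race on them.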

ldm_patched/modules/sd.py (+13 −5)
@@ -58,7 +58,7 @@ def load_clip_weights(model, sd):
 
 def load_lora_for_models(model, clip, lora, strength_model, strength_clip, filename='default'):
     model_flag = type(model.model).__name__ if model is not None else 'default'
-
+
     # Only build key maps for components we'll actually use
     key_map = {}
     if model is not None and strength_model != 0:
@@ -72,7 +72,7 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip, filen
 
     # Load LoRA weights
     loaded = ldm_patched.modules.lora.load_lora(lora, key_map)
-
+
     # Only clone and patch if we have relevant weights
     if model is not None and strength_model != 0:
         new_modelpatcher = model.clone()
@@ -261,7 +261,11 @@ def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
 
     def decode_inner(self, samples_in):
         if model_management.VAE_ALWAYS_TILED:
-            return self.decode_tiled(samples_in).to(self.output_device)
+            return self.decode_tiled(
+                samples_in,
+                tile_x = model_management.VAE_DECODE_TILE_SIZE_X,
+                tile_y = model_management.VAE_DECODE_TILE_SIZE_Y
+            ).to(self.output_device)
 
         try:
             memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
@@ -295,8 +299,12 @@ def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap = 16):
 
     def encode_inner(self, pixel_samples):
         if model_management.VAE_ALWAYS_TILED:
-            return self.encode_tiled(pixel_samples)
-
+            return self.encode_tiled(
+                pixel_samples,
+                tile_x = model_management.VAE_ENCODE_TILE_SIZE_X,
+                tile_y = model_management.VAE_ENCODE_TILE_SIZE_Y
+            )
+
         regulation = self.patcher.model_options.get("model_vae_regulation", None)
 
         pixel_samples = pixel_samples.movedim(-1,1)
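
For intuition on the slider ranges: the decoder operates on latents, so its default 64-unit tile corresponds to a 512x512 pixel region for the usual 8x SD VAE, in line with the encoder's 512-pixel default. A rough sketch of how tile count scales with tile size; the fixed-stride `ceil` split is a simplification, not the exact scheduling inside `decode_tiled`/`encode_tiled_`:

```python
import math

def tile_grid(extent, tile, overlap):
    """Rough tile count for a square extent x extent plane.

    Simplified: assumes a fixed stride of (tile - overlap), which only
    approximates the scheduling in decode_tiled()/encode_tiled_().
    """
    stride = max(tile - overlap, 1)
    per_axis = math.ceil(extent / stride)
    return per_axis * per_axis

# A 2048x2048 image is a 256x256 latent for an 8x VAE.
print(tile_grid(256, 64, 16))    # decode at the default 64-latent tile -> 36 tiles
print(tile_grid(2048, 512, 64))  # encode at the default 512-pixel tile -> 25 tiles
```

Halving the tile size roughly quarters the per-tile activation memory at the cost of about four times as many tiles plus overlap overhead, which is the trade-off the new sliders expose.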
