Skip to content

Commit 12c1080

Browse files
Simplify differential diffusion code.
1 parent 727021b commit 12c1080

File tree

3 files changed

+23
-74
lines changed

3 files changed

+23
-74
lines changed

comfy/model_patcher.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,9 @@ def set_model_sampler_post_cfg_function(self, post_cfg_function, disable_cfg1_op
6767
def set_model_unet_function_wrapper(self, unet_wrapper_function):
6868
self.model_options["model_function_wrapper"] = unet_wrapper_function
6969

70+
def set_model_denoise_mask_function(self, denoise_mask_function):
71+
self.model_options["denoise_mask_function"] = denoise_mask_function
72+
7073
def set_model_patch(self, patch, name):
7174
to = self.model_options["transformer_options"]
7275
if "patches" not in to:

comfy/samplers.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -272,13 +272,14 @@ def forward(self, *args, **kwargs):
272272
return self.apply_model(*args, **kwargs)
273273

274274
class KSamplerX0Inpaint(torch.nn.Module):
275-
def __init__(self, model):
275+
def __init__(self, model, sigmas):
276276
super().__init__()
277277
self.inner_model = model
278+
self.sigmas = sigmas
278279
def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, model_options={}, seed=None):
279280
if denoise_mask is not None:
280281
if "denoise_mask_function" in model_options:
281-
denoise_mask = model_options["denoise_mask_function"](sigma, denoise_mask)
282+
denoise_mask = model_options["denoise_mask_function"](sigma, denoise_mask, extra_options={"model": self.inner_model, "sigmas": self.sigmas})
282283
latent_mask = 1. - denoise_mask
283284
x = x * denoise_mask + self.inner_model.inner_model.model_sampling.noise_scaling(sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1)), self.noise, self.latent_image) * latent_mask
284285
out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, model_options=model_options, seed=seed)
@@ -528,7 +529,7 @@ def __init__(self, sampler_function, extra_options={}, inpaint_options={}):
528529

529530
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
530531
extra_args["denoise_mask"] = denoise_mask
531-
model_k = KSamplerX0Inpaint(model_wrap)
532+
model_k = KSamplerX0Inpaint(model_wrap, sigmas)
532533
model_k.latent_image = latent_image
533534
if self.inpaint_options.get("random", False): #TODO: Should this be the default?
534535
generator = torch.manual_seed(extra_args.get("seed", 41) + 1)

comfy_extras/nodes_differential_diffusion.py

Lines changed: 16 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
# code adapted from https://github.com/exx8/differential-diffusion
22

33
import torch
4-
import inspect
54

65
class DifferentialDiffusion():
76
@classmethod
@@ -13,82 +12,28 @@ def INPUT_TYPES(s):
1312
CATEGORY = "_for_testing"
1413
INIT = False
1514

16-
@classmethod
17-
def IS_CHANGED(s, *args, **kwargs):
18-
DifferentialDiffusion.INIT = s.INIT = True
19-
return ""
20-
21-
def __init__(self) -> None:
22-
DifferentialDiffusion.INIT = False
23-
self.sigmas: torch.Tensor = None
24-
self.thresholds: torch.Tensor = None
25-
self.mask_i = None
26-
self.valid_sigmas = False
27-
self.varying_sigmas_samplers = ["dpmpp_2s", "dpmpp_sde", "dpm_2", "heun", "restart"]
28-
2915
def apply(self, model):
3016
model = model.clone()
31-
model.model_options["denoise_mask_function"] = self.forward
17+
model.set_model_denoise_mask_function(self.forward)
3218
return (model,)
33-
34-
def init_sigmas(self, sigma: torch.Tensor, denoise_mask: torch.Tensor):
35-
self.__init__()
36-
self.sigmas, sampler = find_outer_instance("sigmas", callback=get_sigmas_and_sampler) or (None, "")
37-
self.valid_sigmas = not ("sample_" not in sampler or any(s in sampler for s in self.varying_sigmas_samplers)) or "generic" in sampler
38-
if self.sigmas is None:
39-
self.sigmas = sigma[:1].repeat(2)
40-
self.sigmas[-1].zero_()
41-
self.sigmas_min = self.sigmas.min()
42-
self.sigmas_max = self.sigmas.max()
43-
self.thresholds = torch.linspace(1, 0, self.sigmas.shape[0], dtype=sigma.dtype, device=sigma.device)
44-
self.thresholds_min_len = self.thresholds.shape[0] - 1
45-
if self.valid_sigmas:
46-
thresholds = self.thresholds[:-1].reshape(-1, 1, 1, 1, 1)
47-
mask = denoise_mask.unsqueeze(0)
48-
mask = (mask >= thresholds).to(denoise_mask.dtype)
49-
self.mask_i = iter(mask)
50-
51-
def forward(self, sigma: torch.Tensor, denoise_mask: torch.Tensor):
52-
if self.sigmas is None or DifferentialDiffusion.INIT:
53-
self.init_sigmas(sigma, denoise_mask)
54-
if self.valid_sigmas:
55-
try:
56-
return next(self.mask_i)
57-
except StopIteration:
58-
self.valid_sigmas = False
59-
if self.thresholds_min_len > 1:
60-
nearest_idx = (self.sigmas - sigma[0]).abs().argmin()
61-
if not self.thresholds_min_len > nearest_idx:
62-
nearest_idx = -2
63-
threshold = self.thresholds[nearest_idx]
64-
else:
65-
threshold = (sigma[0] - self.sigmas_min) / (self.sigmas_max - self.sigmas_min)
66-
return (denoise_mask >= threshold).to(denoise_mask.dtype)
6719

68-
def get_sigmas_and_sampler(frame, target):
69-
found = frame.f_locals[target]
70-
if isinstance(found, torch.Tensor) and found[-1] < 0.1:
71-
return found, frame.f_code.co_name
72-
return False
20+
def forward(self, sigma: torch.Tensor, denoise_mask: torch.Tensor, extra_options: dict):
21+
model = extra_options["model"]
22+
step_sigmas = extra_options["sigmas"]
23+
sigma_to = model.inner_model.model_sampling.sigma_min
24+
if step_sigmas[-1] > sigma_to:
25+
sigma_to = step_sigmas[-1]
26+
sigma_from = step_sigmas[0]
27+
28+
ts_from = model.inner_model.model_sampling.timestep(sigma_from)
29+
ts_to = model.inner_model.model_sampling.timestep(sigma_to)
30+
current_ts = model.inner_model.model_sampling.timestep(sigma)
31+
32+
threshold = (current_ts - ts_to) / (ts_from - ts_to)
33+
34+
return (denoise_mask >= threshold).to(denoise_mask.dtype)
7335

74-
def find_outer_instance(target: str, target_type=None, callback=None):
75-
frame = inspect.currentframe()
76-
i = 0
77-
while frame and i < 100:
78-
if target in frame.f_locals:
79-
if callback is not None:
80-
res = callback(frame, target)
81-
if res:
82-
return res
83-
else:
84-
found = frame.f_locals[target]
85-
if isinstance(found, target_type):
86-
return found
87-
frame = frame.f_back
88-
i += 1
89-
return None
9036

91-
9237
NODE_CLASS_MAPPINGS = {
9338
"DifferentialDiffusion": DifferentialDiffusion,
9439
}

0 commit comments

Comments
 (0)