1
1
# code adapted from https://github.com/exx8/differential-diffusion
2
2
3
3
import torch
4
- import inspect
5
4
6
5
class DifferentialDiffusion ():
7
6
@classmethod
@@ -13,82 +12,28 @@ def INPUT_TYPES(s):
13
12
CATEGORY = "_for_testing"
14
13
INIT = False
15
14
16
- @classmethod
17
- def IS_CHANGED (s , * args , ** kwargs ):
18
- DifferentialDiffusion .INIT = s .INIT = True
19
- return ""
20
-
21
- def __init__ (self ) -> None :
22
- DifferentialDiffusion .INIT = False
23
- self .sigmas : torch .Tensor = None
24
- self .thresholds : torch .Tensor = None
25
- self .mask_i = None
26
- self .valid_sigmas = False
27
- self .varying_sigmas_samplers = ["dpmpp_2s" , "dpmpp_sde" , "dpm_2" , "heun" , "restart" ]
28
-
29
15
def apply (self , model ):
30
16
model = model .clone ()
31
- model .model_options [ "denoise_mask_function" ] = self .forward
17
+ model .set_model_denoise_mask_function ( self .forward )
32
18
return (model ,)
33
-
34
- def init_sigmas (self , sigma : torch .Tensor , denoise_mask : torch .Tensor ):
35
- self .__init__ ()
36
- self .sigmas , sampler = find_outer_instance ("sigmas" , callback = get_sigmas_and_sampler ) or (None , "" )
37
- self .valid_sigmas = not ("sample_" not in sampler or any (s in sampler for s in self .varying_sigmas_samplers )) or "generic" in sampler
38
- if self .sigmas is None :
39
- self .sigmas = sigma [:1 ].repeat (2 )
40
- self .sigmas [- 1 ].zero_ ()
41
- self .sigmas_min = self .sigmas .min ()
42
- self .sigmas_max = self .sigmas .max ()
43
- self .thresholds = torch .linspace (1 , 0 , self .sigmas .shape [0 ], dtype = sigma .dtype , device = sigma .device )
44
- self .thresholds_min_len = self .thresholds .shape [0 ] - 1
45
- if self .valid_sigmas :
46
- thresholds = self .thresholds [:- 1 ].reshape (- 1 , 1 , 1 , 1 , 1 )
47
- mask = denoise_mask .unsqueeze (0 )
48
- mask = (mask >= thresholds ).to (denoise_mask .dtype )
49
- self .mask_i = iter (mask )
50
-
51
- def forward (self , sigma : torch .Tensor , denoise_mask : torch .Tensor ):
52
- if self .sigmas is None or DifferentialDiffusion .INIT :
53
- self .init_sigmas (sigma , denoise_mask )
54
- if self .valid_sigmas :
55
- try :
56
- return next (self .mask_i )
57
- except StopIteration :
58
- self .valid_sigmas = False
59
- if self .thresholds_min_len > 1 :
60
- nearest_idx = (self .sigmas - sigma [0 ]).abs ().argmin ()
61
- if not self .thresholds_min_len > nearest_idx :
62
- nearest_idx = - 2
63
- threshold = self .thresholds [nearest_idx ]
64
- else :
65
- threshold = (sigma [0 ] - self .sigmas_min ) / (self .sigmas_max - self .sigmas_min )
66
- return (denoise_mask >= threshold ).to (denoise_mask .dtype )
67
19
68
- def get_sigmas_and_sampler (frame , target ):
69
- found = frame .f_locals [target ]
70
- if isinstance (found , torch .Tensor ) and found [- 1 ] < 0.1 :
71
- return found , frame .f_code .co_name
72
- return False
20
+ def forward (self , sigma : torch .Tensor , denoise_mask : torch .Tensor , extra_options : dict ):
21
+ model = extra_options ["model" ]
22
+ step_sigmas = extra_options ["sigmas" ]
23
+ sigma_to = model .inner_model .model_sampling .sigma_min
24
+ if step_sigmas [- 1 ] > sigma_to :
25
+ sigma_to = step_sigmas [- 1 ]
26
+ sigma_from = step_sigmas [0 ]
27
+
28
+ ts_from = model .inner_model .model_sampling .timestep (sigma_from )
29
+ ts_to = model .inner_model .model_sampling .timestep (sigma_to )
30
+ current_ts = model .inner_model .model_sampling .timestep (sigma )
31
+
32
+ threshold = (current_ts - ts_to ) / (ts_from - ts_to )
33
+
34
+ return (denoise_mask >= threshold ).to (denoise_mask .dtype )
73
35
74
- def find_outer_instance (target : str , target_type = None , callback = None ):
75
- frame = inspect .currentframe ()
76
- i = 0
77
- while frame and i < 100 :
78
- if target in frame .f_locals :
79
- if callback is not None :
80
- res = callback (frame , target )
81
- if res :
82
- return res
83
- else :
84
- found = frame .f_locals [target ]
85
- if isinstance (found , target_type ):
86
- return found
87
- frame = frame .f_back
88
- i += 1
89
- return None
90
36
91
-
92
37
NODE_CLASS_MAPPINGS = {
93
38
"DifferentialDiffusion" : DifferentialDiffusion ,
94
39
}
0 commit comments