File tree 1 file changed +5
-6
lines changed
1 file changed +5
-6
lines changed Original file line number Diff line number Diff line change @@ -648,6 +648,8 @@ def __call__(
648
648
height : Optional [int ] = None ,
649
649
width : Optional [int ] = None ,
650
650
eta : float = 1.0 ,
651
+ decay_eta : Optional [bool ] = False ,
652
+ eta_decay_power : Optional [float ] = 1.0 ,
651
653
strength : float = 1.0 ,
652
654
start_timestep : float = 0 ,
653
655
stop_timestep : float = 0.25 ,
@@ -880,12 +882,9 @@ def __call__(
880
882
v_t = - noise_pred
881
883
v_t_cond = (y_0 - latents ) / (1 - t_i )
882
884
eta_t = eta if start_timestep <= i < stop_timestep else 0.0
883
- if start_timestep <= i < stop_timestep :
884
- # controlled vector field
885
- v_hat_t = v_t + eta * (v_t_cond - v_t )
886
-
887
- else :
888
- v_hat_t = v_t
885
+ if decay_eta :
886
+ eta_t = eta_t * (1 - i / num_inference_steps ) ** eta_decay_power # Decay eta over the loop
887
+ v_hat_t = v_t + eta_t * (v_t_cond - v_t )
889
888
890
889
# SDE Eq: 17 from https://arxiv.org/pdf/2410.10792
891
890
latents = latents + v_hat_t * (sigmas [i ] - sigmas [i + 1 ])
You can’t perform that action at this time.
0 commit comments