Skip to content

Commit cef0e36

Browse files
authored
[RF inversion community pipeline] add eta_decay (#10199)
* add decay * add decay * style
1 parent ec9bfa9 commit cef0e36

File tree

1 file changed

+5
-6
lines changed

1 file changed

+5
-6
lines changed

examples/community/pipeline_flux_rf_inversion.py

+5-6
Original file line numberDiff line numberDiff line change
@@ -648,6 +648,8 @@ def __call__(
648648
height: Optional[int] = None,
649649
width: Optional[int] = None,
650650
eta: float = 1.0,
651+
decay_eta: Optional[bool] = False,
652+
eta_decay_power: Optional[float] = 1.0,
651653
strength: float = 1.0,
652654
start_timestep: float = 0,
653655
stop_timestep: float = 0.25,
@@ -880,12 +882,9 @@ def __call__(
880882
v_t = -noise_pred
881883
v_t_cond = (y_0 - latents) / (1 - t_i)
882884
eta_t = eta if start_timestep <= i < stop_timestep else 0.0
883-
if start_timestep <= i < stop_timestep:
884-
# controlled vector field
885-
v_hat_t = v_t + eta * (v_t_cond - v_t)
886-
887-
else:
888-
v_hat_t = v_t
885+
if decay_eta:
886+
eta_t = eta_t * (1 - i / num_inference_steps) ** eta_decay_power # Decay eta over the loop
887+
v_hat_t = v_t + eta_t * (v_t_cond - v_t)
889888

890889
# SDE Eq: 17 from https://arxiv.org/pdf/2410.10792
891890
latents = latents + v_hat_t * (sigmas[i] - sigmas[i + 1])

0 commit comments

Comments
 (0)