@@ -112,6 +112,9 @@ def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight
112
112
113
113
def parse_args ():
114
114
parser = argparse .ArgumentParser (description = "Simple example of a training script." )
115
+ parser .add_argument (
116
+ "--input_pertubation" , type = float , default = 0 , help = "The scale of input pretubation. Recommended 0.1."
117
+ )
115
118
parser .add_argument (
116
119
"--pretrained_model_name_or_path" ,
117
120
type = str ,
@@ -801,15 +804,19 @@ def collate_fn(examples):
801
804
noise += args .noise_offset * torch .randn (
802
805
(latents .shape [0 ], latents .shape [1 ], 1 , 1 ), device = latents .device
803
806
)
804
-
807
+ if args .input_pertubation :
808
+ new_noise = noise + args .input_pertubation * torch .randn_like (noise )
805
809
bsz = latents .shape [0 ]
806
810
# Sample a random timestep for each image
807
811
timesteps = torch .randint (0 , noise_scheduler .config .num_train_timesteps , (bsz ,), device = latents .device )
808
812
timesteps = timesteps .long ()
809
813
810
814
# Add noise to the latents according to the noise magnitude at each timestep
811
815
# (this is the forward diffusion process)
812
- noisy_latents = noise_scheduler .add_noise (latents , noise , timesteps )
816
+ if args .input_pertubation :
817
+ noisy_latents = noise_scheduler .add_noise (latents , new_noise , timesteps )
818
+ else :
819
+ noisy_latents = noise_scheduler .add_noise (latents , noise , timesteps )
813
820
814
821
# Get the text embedding for conditioning
815
822
encoder_hidden_states = text_encoder (batch ["input_ids" ])[0 ]
0 commit comments