Skip to content

Commit fa9e35f

Browse files
Added input pretubation (#3292)
* Added input pretubation * Fixed spelling
1 parent 4bae76e commit fa9e35f

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

examples/text_to_image/train_text_to_image.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,9 @@ def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight
112112

113113
def parse_args():
114114
parser = argparse.ArgumentParser(description="Simple example of a training script.")
115+
parser.add_argument(
116+
"--input_pertubation", type=float, default=0, help="The scale of input pretubation. Recommended 0.1."
117+
)
115118
parser.add_argument(
116119
"--pretrained_model_name_or_path",
117120
type=str,
@@ -801,15 +804,19 @@ def collate_fn(examples):
801804
noise += args.noise_offset * torch.randn(
802805
(latents.shape[0], latents.shape[1], 1, 1), device=latents.device
803806
)
804-
807+
if args.input_pertubation:
808+
new_noise = noise + args.input_pertubation * torch.randn_like(noise)
805809
bsz = latents.shape[0]
806810
# Sample a random timestep for each image
807811
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
808812
timesteps = timesteps.long()
809813

810814
# Add noise to the latents according to the noise magnitude at each timestep
811815
# (this is the forward diffusion process)
812-
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
816+
if args.input_pertubation:
817+
noisy_latents = noise_scheduler.add_noise(latents, new_noise, timesteps)
818+
else:
819+
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
813820

814821
# Get the text embedding for conditioning
815822
encoder_hidden_states = text_encoder(batch["input_ids"])[0]

0 commit comments

Comments
 (0)