
Commit 3ba2f92

Merge pull request CompVis#36 from enzymezoo-code/inpainting_1.0
Updating ipynb with colab-convert
2 parents 5776e29 + 9a62dce commit 3ba2f92

File tree

1 file changed: +113 -84 lines


Deforum_Stable_Diffusion.ipynb

+113 -84
@@ -164,14 +164,10 @@
 "def add_noise(sample: torch.Tensor, noise_amt: float):\n",
 "    return sample + torch.randn(sample.shape, device=sample.device) * noise_amt\n",
 "\n",
-"def get_output_folder(output_path,batch_folder=None):\n",
-"    yearMonth = time.strftime('%Y-%m/')\n",
-"    out_path = os.path.join(output_path,yearMonth)\n",
+"def get_output_folder(output_path, batch_folder):\n",
+"    out_path = os.path.join(output_path,time.strftime('%Y-%m/'))\n",
 "    if batch_folder != \"\":\n",
-"        out_path = os.path.join(out_path,batch_folder)\n",
-"    # we will also make sure the path suffix is a slash if linux and a backslash if windows\n",
-"    if out_path[-1] != os.path.sep:\n",
-"        out_path += os.path.sep\n",
+"        out_path = os.path.join(out_path, batch_folder)\n",
 "    os.makedirs(out_path, exist_ok=True)\n",
 "    return out_path\n",
 "\n",
@@ -203,14 +199,19 @@
 "    mask = torch.from_numpy(mask)\n",
 "    return mask\n",
 "\n",
-"def maintain_colors(prev_img, color_match_sample, hsv=False):\n",
-"    if hsv:\n",
+"def maintain_colors(prev_img, color_match_sample, mode):\n",
+"    if mode == 'Match Frame 0 RGB':\n",
+"        return match_histograms(prev_img, color_match_sample, multichannel=True)\n",
+"    elif mode == 'Match Frame 0 HSV':\n",
 "        prev_img_hsv = cv2.cvtColor(prev_img, cv2.COLOR_RGB2HSV)\n",
 "        color_match_hsv = cv2.cvtColor(color_match_sample, cv2.COLOR_RGB2HSV)\n",
 "        matched_hsv = match_histograms(prev_img_hsv, color_match_hsv, multichannel=True)\n",
 "        return cv2.cvtColor(matched_hsv, cv2.COLOR_HSV2RGB)\n",
-"    else:\n",
-"        return match_histograms(prev_img, color_match_sample, multichannel=True)\n",
+"    else: # Match Frame 0 LAB\n",
+"        prev_img_lab = cv2.cvtColor(prev_img, cv2.COLOR_RGB2LAB)\n",
+"        color_match_lab = cv2.cvtColor(color_match_sample, cv2.COLOR_RGB2LAB)\n",
+"        matched_lab = match_histograms(prev_img_lab, color_match_lab, multichannel=True)\n",
+"        return cv2.cvtColor(matched_lab, cv2.COLOR_LAB2RGB)\n",
 "\n",
 "\n",
 "def make_callback(sampler_name, dynamic_threshold=None, static_threshold=None, mask=None, init_latent=None, sigmas=None, sampler=None, masked_noise_modifier=1.0): \n",
@@ -330,57 +331,56 @@
 "    with torch.no_grad():\n",
 "        with precision_scope(\"cuda\"):\n",
 "            with model.ema_scope():\n",
-"                for n in range(args.n_samples):\n",
-"                    for prompts in data:\n",
-"                        uc = None\n",
-"                        if args.scale != 1.0:\n",
-"                            uc = model.get_learned_conditioning(batch_size * [\"\"])\n",
-"                        if isinstance(prompts, tuple):\n",
-"                            prompts = list(prompts)\n",
-"                        c = model.get_learned_conditioning(prompts)\n",
-"\n",
-"                        if args.init_c != None:\n",
-"                            c = args.init_c\n",
-"\n",
-"                        if args.sampler in [\"klms\",\"dpm2\",\"dpm2_ancestral\",\"heun\",\"euler\",\"euler_ancestral\"]:\n",
-"                            samples = sampler_fn(\n",
-"                                c=c, \n",
-"                                uc=uc, \n",
-"                                args=args, \n",
-"                                model_wrap=model_wrap, \n",
-"                                init_latent=init_latent, \n",
-"                                t_enc=t_enc, \n",
-"                                device=device, \n",
-"                                cb=callback)\n",
+"                for prompts in data:\n",
+"                    uc = None\n",
+"                    if args.scale != 1.0:\n",
+"                        uc = model.get_learned_conditioning(batch_size * [\"\"])\n",
+"                    if isinstance(prompts, tuple):\n",
+"                        prompts = list(prompts)\n",
+"                    c = model.get_learned_conditioning(prompts)\n",
+"\n",
+"                    if args.init_c != None:\n",
+"                        c = args.init_c\n",
+"\n",
+"                    if args.sampler in [\"klms\",\"dpm2\",\"dpm2_ancestral\",\"heun\",\"euler\",\"euler_ancestral\"]:\n",
+"                        samples = sampler_fn(\n",
+"                            c=c, \n",
+"                            uc=uc, \n",
+"                            args=args, \n",
+"                            model_wrap=model_wrap, \n",
+"                            init_latent=init_latent, \n",
+"                            t_enc=t_enc, \n",
+"                            device=device, \n",
+"                            cb=callback)\n",
+"                    else:\n",
+"                        # args.sampler == 'plms' or args.sampler == 'ddim':\n",
+"                        if init_latent is not None and args.strength > 0:\n",
+"                            z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(device))\n",
 "                        else:\n",
-"                            # args.sampler == 'plms' or args.sampler == 'ddim':\n",
-"                            if init_latent is not None and args.strength > 0:\n",
-"                                z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(device))\n",
-"                            else:\n",
-"                                z_enc = torch.randn([args.n_samples, args.C, args.H // args.f, args.W // args.f], device=device)\n",
-"                            samples = sampler.decode(z_enc, \n",
-"                                                     c, \n",
-"                                                     t_enc, \n",
-"                                                     unconditional_guidance_scale=args.scale,\n",
-"                                                     unconditional_conditioning=uc,\n",
-"                                                     img_callback=callback)\n",
-"\n",
-"                        if return_latent:\n",
-"                            results.append(samples.clone())\n",
-"\n",
-"                        x_samples = model.decode_first_stage(samples)\n",
-"                        if return_sample:\n",
-"                            results.append(x_samples.clone())\n",
-"\n",
-"                        x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)\n",
-"\n",
-"                        if return_c:\n",
-"                            results.append(c.clone())\n",
-"\n",
-"                        for x_sample in x_samples:\n",
-"                            x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')\n",
-"                            image = Image.fromarray(x_sample.astype(np.uint8))\n",
-"                            results.append(image)\n",
+"                            z_enc = torch.randn([args.n_samples, args.C, args.H // args.f, args.W // args.f], device=device)\n",
+"                        samples = sampler.decode(z_enc, \n",
+"                                                 c, \n",
+"                                                 t_enc, \n",
+"                                                 unconditional_guidance_scale=args.scale,\n",
+"                                                 unconditional_conditioning=uc,\n",
+"                                                 img_callback=callback)\n",
+"\n",
+"                    if return_latent:\n",
+"                        results.append(samples.clone())\n",
+"\n",
+"                    x_samples = model.decode_first_stage(samples)\n",
+"                    if return_sample:\n",
+"                        results.append(x_samples.clone())\n",
+"\n",
+"                    x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)\n",
+"\n",
+"                    if return_c:\n",
+"                        results.append(c.clone())\n",
+"\n",
+"                    for x_sample in x_samples:\n",
+"                        x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')\n",
+"                        image = Image.fromarray(x_sample.astype(np.uint8))\n",
+"                        results.append(image)\n",
 "    return results\n",
 "\n",
 "def sample_from_cv2(sample: np.ndarray) -> torch.Tensor:\n",
@@ -569,13 +569,14 @@
 "    scale_schedule = \"0: (1.0)\"#@param {type:\"string\"}\n",
 "\n",
 "    #@markdown ####**Coherence:**\n",
-"    color_coherence = 'MatchFrame0' #@param ['None', 'MatchFrame0'] {type:'string'}\n",
+"    color_coherence = 'Match Frame 0 HSV' #@param ['None', 'Match Frame 0 HSV', 'Match Frame 0 LAB', 'Match Frame 0 RGB'] {type:'string'}\n",
 "\n",
 "    #@markdown ####**Video Input:**\n",
 "    video_init_path ='/content/video_in.mp4'#@param {type:\"string\"}\n",
 "    extract_nth_frame = 1#@param {type:\"number\"}\n",
 "\n",
 "    #@markdown ####**Interpolation:**\n",
+"    interpolate_key_frames = False #@param {type:\"boolean\"}\n",
 "    interpolate_x_frames = 4 #@param {type:\"number\"}\n",
 "\n",
 "    return locals()\n",
@@ -657,6 +658,7 @@
 "id": "2ujwkGZTcGev"
 },
 "source": [
+"\n",
 "prompts = [\n",
 "    \"a beautiful forest by Asher Brown Durand, trending on Artstation\", #the first prompt I want\n",
 "    \"a beautiful portrait of a woman by Artgerm, trending on Artstation\", #the second prompt I want\n",
@@ -665,9 +667,9 @@
 "\n",
 "animation_prompts = {\n",
 "    0: \"a beautiful apple, trending on Artstation\",\n",
-"    10: \"a beautiful banana, trending on Artstation\",\n",
-"    100: \"a beautiful coconut, trending on Artstation\",\n",
-"    101: \"a beautiful durian, trending on Artstation\",\n",
+"    20: \"a beautiful banana, trending on Artstation\",\n",
+"    30: \"a beautiful coconut, trending on Artstation\",\n",
+"    40: \"a beautiful durian, trending on Artstation\",\n",
 "}"
 ],
 "outputs": [],
@@ -726,7 +728,7 @@
 "    seed_behavior = \"iter\" #@param [\"iter\",\"fixed\",\"random\"]\n",
 "\n",
 "    #@markdown **Grid Settings**\n",
-"    make_grid = True #@param {type:\"boolean\"}\n",
+"    make_grid = False #@param {type:\"boolean\"}\n",
 "    grid_rows = 2 #@param \n",
 "\n",
 "    precision = 'autocast' \n",
@@ -893,11 +895,11 @@
 "            )\n",
 "\n",
 "        # apply color matching\n",
-"        if anim_args.color_coherence == 'MatchFrame0':\n",
+"        if anim_args.color_coherence != 'None':\n",
 "            if color_match_sample is None:\n",
 "                color_match_sample = prev_img.copy()\n",
 "            else:\n",
-"                prev_img = maintain_colors(prev_img, color_match_sample, (frame_idx%2) == 0)\n",
+"                prev_img = maintain_colors(prev_img, color_match_sample, anim_args.color_coherence)\n",
 "\n",
 "        # apply scaling\n",
 "        scaled_sample = prev_img * scale\n",
@@ -999,25 +1001,52 @@
 "\n",
 "    frame_idx = 0\n",
 "\n",
-"    for i in range(len(prompts_c_s)-1):\n",
-"        for j in range(anim_args.interpolate_x_frames+1):\n",
-"            # interpolate the text embedding\n",
-"            prompt1_c = prompts_c_s[i]\n",
-"            prompt2_c = prompts_c_s[i+1] \n",
-"            args.init_c = prompt1_c.add(prompt2_c.sub(prompt1_c).mul(j * 1/(anim_args.interpolate_x_frames+1)))\n",
+"    if anim_args.interpolate_key_frames:\n",
+"        for i in range(len(prompts_c_s)-1):\n",
+"            dist_frames = list(animation_prompts.items())[i+1][0] - list(animation_prompts.items())[i][0]\n",
+"            if dist_frames <= 0:\n",
+"                print(\"key frames duplicated or reversed. interpolation skipped.\")\n",
+"                return\n",
+"            else:\n",
+"                for j in range(dist_frames):\n",
+"                    # interpolate the text embedding\n",
+"                    prompt1_c = prompts_c_s[i]\n",
+"                    prompt2_c = prompts_c_s[i+1] \n",
+"                    args.init_c = prompt1_c.add(prompt2_c.sub(prompt1_c).mul(j * 1/dist_frames))\n",
 "\n",
-"            # sample the diffusion model\n",
-"            results = generate(args)\n",
-"            image = results[0]\n",
+"                    # sample the diffusion model\n",
+"                    results = generate(args)\n",
+"                    image = results[0]\n",
 "\n",
-"            filename = f\"{args.timestring}_{frame_idx:05}.png\"\n",
-"            image.save(os.path.join(args.outdir, filename))\n",
-"            frame_idx += 1\n",
+"                    filename = f\"{args.timestring}_{frame_idx:05}.png\"\n",
+"                    image.save(os.path.join(args.outdir, filename))\n",
+"                    frame_idx += 1\n",
 "\n",
-"            display.clear_output(wait=True)\n",
-"            display.display(image)\n",
+"                    display.clear_output(wait=True)\n",
+"                    display.display(image)\n",
 "\n",
-"            args.seed = next_seed(args)\n",
+"                    args.seed = next_seed(args)\n",
+"\n",
+"    else:\n",
+"        for i in range(len(prompts_c_s)-1):\n",
+"            for j in range(anim_args.interpolate_x_frames+1):\n",
+"                # interpolate the text embedding\n",
+"                prompt1_c = prompts_c_s[i]\n",
+"                prompt2_c = prompts_c_s[i+1] \n",
+"                args.init_c = prompt1_c.add(prompt2_c.sub(prompt1_c).mul(j * 1/(anim_args.interpolate_x_frames+1)))\n",
+"\n",
+"                # sample the diffusion model\n",
+"                results = generate(args)\n",
+"                image = results[0]\n",
+"\n",
+"                filename = f\"{args.timestring}_{frame_idx:05}.png\"\n",
+"                image.save(os.path.join(args.outdir, filename))\n",
+"                frame_idx += 1\n",
+"\n",
+"                display.clear_output(wait=True)\n",
+"                display.display(image)\n",
+"\n",
+"                args.seed = next_seed(args)\n",
 "\n",
 "    # generate the last prompt\n",
 "    args.init_c = prompts_c_s[-1]\n",
@@ -1110,7 +1139,7 @@
 "accelerator": "GPU",
 "colab": {
   "collapsed_sections": [],
-  "name": "Deforum_Stable_Diffusion_+_Interpolation.ipynb",
+  "name": "Deforum_Stable_Diffusion.ipynb",
   "provenance": [],
   "private_outputs": true
 },
