|
164 | 164 | "def add_noise(sample: torch.Tensor, noise_amt: float):\n",
|
165 | 165 | " return sample + torch.randn(sample.shape, device=sample.device) * noise_amt\n",
|
166 | 166 | "\n",
|
167 | | - "def get_output_folder(output_path,batch_folder=None):\n", |
168 | | - "    yearMonth = time.strftime('%Y-%m/')\n", |
169 | | - "    out_path = os.path.join(output_path,yearMonth)\n", |
| 167 | + "def get_output_folder(output_path, batch_folder):\n", |
| 168 | + "    out_path = os.path.join(output_path,time.strftime('%Y-%m/'))\n", |
170 | 169 | " if batch_folder != \"\":\n",
|
171 | | - "        out_path = os.path.join(out_path,batch_folder)\n", |
172 | | - "    # we will also make sure the path suffix is a slash if linux and a backslash if windows\n", |
173 | | - "    if out_path[-1] != os.path.sep:\n", |
174 | | - "        out_path += os.path.sep\n", |
| 170 | + "        out_path = os.path.join(out_path, batch_folder)\n", |
175 | 171 | " os.makedirs(out_path, exist_ok=True)\n",
|
176 | 172 | " return out_path\n",
|
177 | 173 | "\n",
|
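For readability, here is the new helper from the `+` lines above, transcribed out of the notebook-JSON escaping into plain Python. The old trailing-separator fixup is dropped because os.path.join and os.makedirs do not require a trailing separator.

    import os
    import time

    def get_output_folder(output_path, batch_folder):
        # Nest outputs under a year-month folder, then an optional batch folder.
        out_path = os.path.join(output_path, time.strftime('%Y-%m/'))
        if batch_folder != "":
            out_path = os.path.join(out_path, batch_folder)
        os.makedirs(out_path, exist_ok=True)
        return out_path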
|
203 | 199 | " mask = torch.from_numpy(mask)\n",
|
204 | 200 | " return mask\n",
|
205 | 201 | "\n",
|
206 | | - "def maintain_colors(prev_img, color_match_sample, hsv=False):\n", |
207 | | - "    if hsv:\n", |
| 202 | + "def maintain_colors(prev_img, color_match_sample, mode):\n", |
| 203 | + " if mode == 'Match Frame 0 RGB':\n", |
| 204 | + " return match_histograms(prev_img, color_match_sample, multichannel=True)\n", |
| 205 | + " elif mode == 'Match Frame 0 HSV':\n", |
208 | 206 | " prev_img_hsv = cv2.cvtColor(prev_img, cv2.COLOR_RGB2HSV)\n",
|
209 | 207 | " color_match_hsv = cv2.cvtColor(color_match_sample, cv2.COLOR_RGB2HSV)\n",
|
210 | 208 | " matched_hsv = match_histograms(prev_img_hsv, color_match_hsv, multichannel=True)\n",
|
211 | 209 | " return cv2.cvtColor(matched_hsv, cv2.COLOR_HSV2RGB)\n",
|
212 | | - "    else:\n", |
213 | | - "        return match_histograms(prev_img, color_match_sample, multichannel=True)\n", |
| 210 | + " else: # Match Frame 0 LAB\n", |
| 211 | + " prev_img_lab = cv2.cvtColor(prev_img, cv2.COLOR_RGB2LAB)\n", |
| 212 | + " color_match_lab = cv2.cvtColor(color_match_sample, cv2.COLOR_RGB2LAB)\n", |
| 213 | + " matched_lab = match_histograms(prev_img_lab, color_match_lab, multichannel=True)\n", |
| 214 | + " return cv2.cvtColor(matched_lab, cv2.COLOR_LAB2RGB)\n", |
214 | 215 | "\n",
|
215 | 216 | "\n",
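Transcribed from the diff: maintain_colors now dispatches on a mode string instead of an hsv boolean, adding a LAB branch. A standalone sketch follows; note that multichannel=True matches the notebook as written but is deprecated in scikit-image >= 0.19, where channel_axis=-1 replaces it (version assumption flagged in the comment).

    import cv2
    from skimage.exposure import match_histograms

    def maintain_colors(prev_img, color_match_sample, mode):
        # Assumes scikit-image < 0.19; newer releases want channel_axis=-1
        # instead of multichannel=True.
        if mode == 'Match Frame 0 RGB':
            return match_histograms(prev_img, color_match_sample, multichannel=True)
        elif mode == 'Match Frame 0 HSV':
            prev_img_hsv = cv2.cvtColor(prev_img, cv2.COLOR_RGB2HSV)
            color_match_hsv = cv2.cvtColor(color_match_sample, cv2.COLOR_RGB2HSV)
            matched_hsv = match_histograms(prev_img_hsv, color_match_hsv, multichannel=True)
            return cv2.cvtColor(matched_hsv, cv2.COLOR_HSV2RGB)
        else:  # 'Match Frame 0 LAB'
            prev_img_lab = cv2.cvtColor(prev_img, cv2.COLOR_RGB2LAB)
            color_match_lab = cv2.cvtColor(color_match_sample, cv2.COLOR_RGB2LAB)
            matched_lab = match_histograms(prev_img_lab, color_match_lab, multichannel=True)
            return cv2.cvtColor(matched_lab, cv2.COLOR_LAB2RGB)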
|
216 | 217 | "def make_callback(sampler_name, dynamic_threshold=None, static_threshold=None, mask=None, init_latent=None, sigmas=None, sampler=None, masked_noise_modifier=1.0): \n",
|
|
330 | 331 | " with torch.no_grad():\n",
|
331 | 332 | " with precision_scope(\"cuda\"):\n",
|
332 | 333 | " with model.ema_scope():\n",
|
333 | | - "                for n in range(args.n_samples):\n", |
334 | | - "                    for prompts in data:\n", |
335 | | - "                        uc = None\n", |
336 | | - "                        if args.scale != 1.0:\n", |
337 | | - "                            uc = model.get_learned_conditioning(batch_size * [\"\"])\n", |
338 | | - "                        if isinstance(prompts, tuple):\n", |
339 | | - "                            prompts = list(prompts)\n", |
340 | | - "                        c = model.get_learned_conditioning(prompts)\n", |
341 | | - "\n", |
342 | | - "                        if args.init_c != None:\n", |
343 | | - "                            c = args.init_c\n", |
344 | | - "\n", |
345 | | - "                        if args.sampler in [\"klms\",\"dpm2\",\"dpm2_ancestral\",\"heun\",\"euler\",\"euler_ancestral\"]:\n", |
346 | | - "                            samples = sampler_fn(\n", |
347 | | - "                                c=c, \n", |
348 | | - "                                uc=uc, \n", |
349 | | - "                                args=args, \n", |
350 | | - "                                model_wrap=model_wrap, \n", |
351 | | - "                                init_latent=init_latent, \n", |
352 | | - "                                t_enc=t_enc, \n", |
353 | | - "                                device=device, \n", |
354 | | - "                                cb=callback)\n", |
| 334 | + " for prompts in data:\n", |
| 335 | + " uc = None\n", |
| 336 | + " if args.scale != 1.0:\n", |
| 337 | + " uc = model.get_learned_conditioning(batch_size * [\"\"])\n", |
| 338 | + " if isinstance(prompts, tuple):\n", |
| 339 | + " prompts = list(prompts)\n", |
| 340 | + " c = model.get_learned_conditioning(prompts)\n", |
| 341 | + "\n", |
| 342 | + " if args.init_c != None:\n", |
| 343 | + " c = args.init_c\n", |
| 344 | + "\n", |
| 345 | + " if args.sampler in [\"klms\",\"dpm2\",\"dpm2_ancestral\",\"heun\",\"euler\",\"euler_ancestral\"]:\n", |
| 346 | + " samples = sampler_fn(\n", |
| 347 | + " c=c, \n", |
| 348 | + " uc=uc, \n", |
| 349 | + " args=args, \n", |
| 350 | + " model_wrap=model_wrap, \n", |
| 351 | + " init_latent=init_latent, \n", |
| 352 | + " t_enc=t_enc, \n", |
| 353 | + " device=device, \n", |
| 354 | + " cb=callback)\n", |
| 355 | + " else:\n", |
| 356 | + " # args.sampler == 'plms' or args.sampler == 'ddim':\n", |
| 357 | + " if init_latent is not None and args.strength > 0:\n", |
| 358 | + " z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(device))\n", |
355 | 359 | " else:\n",
|
356 | | - "                            # args.sampler == 'plms' or args.sampler == 'ddim':\n", |
357 | | - "                            if init_latent is not None and args.strength > 0:\n", |
358 | | - "                                z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(device))\n", |
359 | | - "                            else:\n", |
360 | | - "                                z_enc = torch.randn([args.n_samples, args.C, args.H // args.f, args.W // args.f], device=device)\n", |
361 | | - "                            samples = sampler.decode(z_enc, \n", |
362 | | - "                                                     c, \n", |
363 | | - "                                                     t_enc, \n", |
364 | | - "                                                     unconditional_guidance_scale=args.scale,\n", |
365 | | - "                                                     unconditional_conditioning=uc,\n", |
366 | | - "                                                     img_callback=callback)\n", |
367 | | - "\n", |
368 | | - "                        if return_latent:\n", |
369 | | - "                            results.append(samples.clone())\n", |
370 | | - "\n", |
371 | | - "                        x_samples = model.decode_first_stage(samples)\n", |
372 | | - "                        if return_sample:\n", |
373 | | - "                            results.append(x_samples.clone())\n", |
374 | | - "\n", |
375 | | - "                        x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)\n", |
376 | | - "\n", |
377 | | - "                        if return_c:\n", |
378 | | - "                            results.append(c.clone())\n", |
379 | | - "\n", |
380 | | - "                        for x_sample in x_samples:\n", |
381 | | - "                            x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')\n", |
382 | | - "                            image = Image.fromarray(x_sample.astype(np.uint8))\n", |
383 | | - "                            results.append(image)\n", |
| 360 | + " z_enc = torch.randn([args.n_samples, args.C, args.H // args.f, args.W // args.f], device=device)\n", |
| 361 | + " samples = sampler.decode(z_enc, \n", |
| 362 | + " c, \n", |
| 363 | + " t_enc, \n", |
| 364 | + " unconditional_guidance_scale=args.scale,\n", |
| 365 | + " unconditional_conditioning=uc,\n", |
| 366 | + " img_callback=callback)\n", |
| 367 | + "\n", |
| 368 | + " if return_latent:\n", |
| 369 | + " results.append(samples.clone())\n", |
| 370 | + "\n", |
| 371 | + " x_samples = model.decode_first_stage(samples)\n", |
| 372 | + " if return_sample:\n", |
| 373 | + " results.append(x_samples.clone())\n", |
| 374 | + "\n", |
| 375 | + " x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)\n", |
| 376 | + "\n", |
| 377 | + " if return_c:\n", |
| 378 | + " results.append(c.clone())\n", |
| 379 | + "\n", |
| 380 | + " for x_sample in x_samples:\n", |
| 381 | + " x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')\n", |
| 382 | + " image = Image.fromarray(x_sample.astype(np.uint8))\n", |
| 383 | + " results.append(image)\n", |
384 | 384 | " return results\n",
|
385 | 385 | "\n",
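Background on the uc/c pair threaded through both sampler branches: under classifier-free guidance the model predicts noise once for the empty-prompt conditioning uc and once for the prompt conditioning c, and the two predictions are blended by args.scale. A minimal illustrative sketch (the names eps_uncond/eps_cond are hypothetical, not from the notebook):

    import torch

    def cfg_blend(eps_uncond: torch.Tensor, eps_cond: torch.Tensor, scale: float) -> torch.Tensor:
        # Move the unconditional prediction toward the conditional one by a
        # factor of scale; scale == 1.0 reduces to eps_cond, which is why the
        # code above skips computing uc in that case.
        return eps_uncond + scale * (eps_cond - eps_uncond)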
|
386 | 386 | "def sample_from_cv2(sample: np.ndarray) -> torch.Tensor:\n",
|
|
569 | 569 | " scale_schedule = \"0: (1.0)\"#@param {type:\"string\"}\n",
|
570 | 570 | "\n",
|
571 | 571 | " #@markdown ####**Coherence:**\n",
|
572 | | - "    color_coherence = 'MatchFrame0' #@param ['None', 'MatchFrame0'] {type:'string'}\n", |
| 572 | + " color_coherence = 'Match Frame 0 HSV' #@param ['None', 'Match Frame 0 HSV', 'Match Frame 0 LAB', 'Match Frame 0 RGB'] {type:'string'}\n", |
573 | 573 | "\n",
|
574 | 574 | " #@markdown ####**Video Input:**\n",
|
575 | 575 | " video_init_path ='/content/video_in.mp4'#@param {type:\"string\"}\n",
|
576 | 576 | " extract_nth_frame = 1#@param {type:\"number\"}\n",
|
577 | 577 | "\n",
|
578 | 578 | " #@markdown ####**Interpolation:**\n",
|
| 579 | + " interpolate_key_frames = False #@param {type:\"boolean\"}\n", |
579 | 580 | " interpolate_x_frames = 4 #@param {type:\"number\"}\n",
|
580 | 581 | "\n",
|
581 | 582 | " return locals()\n",
|
|
657 | 658 | "id": "2ujwkGZTcGev"
|
658 | 659 | },
|
659 | 660 | "source": [
|
| 661 | + "\n", |
660 | 662 | "prompts = [\n",
|
661 | 663 | " \"a beautiful forest by Asher Brown Durand, trending on Artstation\", #the first prompt I want\n",
|
662 | 664 | " \"a beautiful portrait of a woman by Artgerm, trending on Artstation\", #the second prompt I want\n",
|
|
665 | 667 | "\n",
|
666 | 668 | "animation_prompts = {\n",
|
667 | 669 | " 0: \"a beautiful apple, trending on Artstation\",\n",
|
668 | | - "    10: \"a beautiful banana, trending on Artstation\",\n", |
669 | | - "    100: \"a beautiful coconut, trending on Artstation\",\n", |
670 | | - "    101: \"a beautiful durian, trending on Artstation\",\n", |
| 670 | + " 20: \"a beautiful banana, trending on Artstation\",\n", |
| 671 | + " 30: \"a beautiful coconut, trending on Artstation\",\n", |
| 672 | + " 40: \"a beautiful durian, trending on Artstation\",\n", |
671 | 673 | "}"
|
672 | 674 | ],
|
673 | 675 | "outputs": [],
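The keyframes were retimed from (0, 10, 100, 101) to an even spacing; with the new interpolate_key_frames option, these gaps directly determine how many in-between frames are rendered per prompt pair. A quick worked example under that assumption:

    # Frame gaps implied by the keyframes above (illustrative check).
    keys = [0, 20, 30, 40]
    dist_frames = [b - a for a, b in zip(keys, keys[1:])]
    print(dist_frames)  # [20, 10, 10] interpolation steps per prompt pair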
|
|
726 | 728 | " seed_behavior = \"iter\" #@param [\"iter\",\"fixed\",\"random\"]\n",
|
727 | 729 | "\n",
|
728 | 730 | " #@markdown **Grid Settings**\n",
|
729 | | - "    make_grid = True #@param {type:\"boolean\"}\n", |
| 731 | + " make_grid = False #@param {type:\"boolean\"}\n", |
730 | 732 | " grid_rows = 2 #@param \n",
|
731 | 733 | "\n",
|
732 | 734 | " precision = 'autocast' \n",
|
|
893 | 895 | " )\n",
|
894 | 896 | "\n",
|
895 | 897 | " # apply color matching\n",
|
896 | | - "        if anim_args.color_coherence == 'MatchFrame0':\n", |
| 898 | + " if anim_args.color_coherence != 'None':\n", |
897 | 899 | " if color_match_sample is None:\n",
|
898 | 900 | " color_match_sample = prev_img.copy()\n",
|
899 | 901 | " else:\n",
|
900 | | - "                prev_img = maintain_colors(prev_img, color_match_sample, (frame_idx%2) == 0)\n", |
| 902 | + " prev_img = maintain_colors(prev_img, color_match_sample, anim_args.color_coherence)\n", |
901 | 903 | "\n",
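Worth noting: the removed call passed (frame_idx%2) == 0 as the old hsv flag, so matching alternated between HSV and RGB on every other frame. The new call forwards the user-selected mode so every frame is matched in one consistent color space, e.g.:

    # Consistent per-run color space (mode string comes from the settings cell).
    prev_img = maintain_colors(prev_img, color_match_sample, 'Match Frame 0 LAB')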
|
902 | 904 | " # apply scaling\n",
|
903 | 905 | " scaled_sample = prev_img * scale\n",
|
|
999 | 1001 | "\n",
|
1000 | 1002 | " frame_idx = 0\n",
|
1001 | 1003 | "\n",
|
1002 | | - "    for i in range(len(prompts_c_s)-1):\n", |
1003 | | - "        for j in range(anim_args.interpolate_x_frames+1):\n", |
1004 | | - "            # interpolate the text embedding\n", |
1005 | | - "            prompt1_c = prompts_c_s[i]\n", |
1006 | | - "            prompt2_c = prompts_c_s[i+1] \n", |
1007 | | - "            args.init_c = prompt1_c.add(prompt2_c.sub(prompt1_c).mul(j * 1/(anim_args.interpolate_x_frames+1)))\n", |
| 1004 | + " if anim_args.interpolate_key_frames:\n", |
| 1005 | + " for i in range(len(prompts_c_s)-1):\n", |
| 1006 | + " dist_frames = list(animation_prompts.items())[i+1][0] - list(animation_prompts.items())[i][0]\n", |
| 1007 | + " if dist_frames <= 0:\n", |
| 1008 | + " print(\"key frames duplicated or reversed. interpolation skipped.\")\n", |
| 1009 | + " return\n", |
| 1010 | + " else:\n", |
| 1011 | + " for j in range(dist_frames):\n", |
| 1012 | + " # interpolate the text embedding\n", |
| 1013 | + " prompt1_c = prompts_c_s[i]\n", |
| 1014 | + " prompt2_c = prompts_c_s[i+1] \n", |
| 1015 | + " args.init_c = prompt1_c.add(prompt2_c.sub(prompt1_c).mul(j * 1/dist_frames))\n", |
1008 | 1016 | "\n",
|
1009 | | - "            # sample the diffusion model\n", |
1010 | | - "            results = generate(args)\n", |
1011 | | - "            image = results[0]\n", |
| 1017 | + " # sample the diffusion model\n", |
| 1018 | + " results = generate(args)\n", |
| 1019 | + " image = results[0]\n", |
1012 | 1020 | "\n",
|
1013 | | - "            filename = f\"{args.timestring}_{frame_idx:05}.png\"\n", |
1014 | | - "            image.save(os.path.join(args.outdir, filename))\n", |
1015 | | - "            frame_idx += 1\n", |
| 1021 | + " filename = f\"{args.timestring}_{frame_idx:05}.png\"\n", |
| 1022 | + " image.save(os.path.join(args.outdir, filename))\n", |
| 1023 | + " frame_idx += 1\n", |
1016 | 1024 | "\n",
|
1017 | | - "            display.clear_output(wait=True)\n", |
1018 | | - "            display.display(image)\n", |
| 1025 | + " display.clear_output(wait=True)\n", |
| 1026 | + " display.display(image)\n", |
1019 | 1027 | "\n",
|
1020 | | - "            args.seed = next_seed(args)\n", |
| 1028 | + " args.seed = next_seed(args)\n", |
| 1029 | + "\n", |
| 1030 | + " else:\n", |
| 1031 | + " for i in range(len(prompts_c_s)-1):\n", |
| 1032 | + " for j in range(anim_args.interpolate_x_frames+1):\n", |
| 1033 | + " # interpolate the text embedding\n", |
| 1034 | + " prompt1_c = prompts_c_s[i]\n", |
| 1035 | + " prompt2_c = prompts_c_s[i+1] \n", |
| 1036 | + " args.init_c = prompt1_c.add(prompt2_c.sub(prompt1_c).mul(j * 1/(anim_args.interpolate_x_frames+1)))\n", |
| 1037 | + "\n", |
| 1038 | + " # sample the diffusion model\n", |
| 1039 | + " results = generate(args)\n", |
| 1040 | + " image = results[0]\n", |
| 1041 | + "\n", |
| 1042 | + " filename = f\"{args.timestring}_{frame_idx:05}.png\"\n", |
| 1043 | + " image.save(os.path.join(args.outdir, filename))\n", |
| 1044 | + " frame_idx += 1\n", |
| 1045 | + "\n", |
| 1046 | + " display.clear_output(wait=True)\n", |
| 1047 | + " display.display(image)\n", |
| 1048 | + "\n", |
| 1049 | + " args.seed = next_seed(args)\n", |
1021 | 1050 | "\n",
|
1022 | 1051 | " # generate the last prompt\n",
|
1023 | 1052 | " args.init_c = prompts_c_s[-1]\n",
|
|
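Both interpolation branches share the same arithmetic: a linear blend between consecutive prompt embeddings, c1 + (c2 - c1) * (j / steps), where steps is either the keyframe distance or interpolate_x_frames + 1. A minimal sketch of that step in isolation:

    import torch

    def lerp_conditioning(c1: torch.Tensor, c2: torch.Tensor, j: int, steps: int) -> torch.Tensor:
        # j == 0 returns c1 exactly; j == steps would land on c2.
        return c1.add(c2.sub(c1).mul(j * 1 / steps))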
1110 | 1139 | "accelerator": "GPU",
|
1111 | 1140 | "colab": {
|
1112 | 1141 | "collapsed_sections": [],
|
1113 | | - "name": "Deforum_Stable_Diffusion_+_Interpolation.ipynb", |
| 1142 | + "name": "Deforum_Stable_Diffusion.ipynb", |
1114 | 1143 | "provenance": [],
|
1115 | 1144 | "private_outputs": true
|
1116 | 1145 | },
|
|