
Add Kandinsky 2.1 #3308


Merged
merged 188 commits into from
May 25, 2023

Conversation

yiyixuxu
Collaborator

@yiyixuxu yiyixuxu commented May 1, 2023

This PR adds Kandinsky 2.1 to diffusers.

#2985

original codebase: https://github.com/ai-forever/Kandinsky-2

to-do:

  • add image prior diffusion (with prior_tokenizer, prior_text_encoder, prior_scheduler)
  • text_proj
  • add unet
  • add text-to-image diffusion pipeline
  • add image_encoder, text_encoder, tokenizer
  • add inpainting pipeline
  • MoVQ decoder
  • MoVQ encoder
  • tests
  • docs
  • add img2img pipeline [WIP] Add img2img #3426
text-to-image generation

from diffusers import KandinskyPipeline, KandinskyPriorPipeline
import torch

device = "cuda"

# inputs
prompt = "red cat, 4k photo"

# create the prior pipeline
pipe_prior = KandinskyPriorPipeline.from_pretrained("YiYiXu/Kandinsky-prior")
pipe_prior.to(device)

# use the prior to generate image embeddings from the prompt
generator = torch.Generator(device=device).manual_seed(0)
out = pipe_prior(prompt, generator=generator)
image_emb = out.images
zero_image_emb = out.zero_embeds

# create the text-to-image pipeline
pipe = KandinskyPipeline.from_pretrained("YiYiXu/Kandinsky")
pipe.to(device)

generator = torch.Generator(device=device).manual_seed(0)
out = pipe(
    prompt,
    image_embeds=image_emb,
    negative_image_embeds=zero_image_emb,
    height=768,
    width=768,
    num_inference_steps=100,
    generator=generator,
)

image = out.images[0]
image.save("cat.png")

(output image: yiyi_test_pipe_kandinsky5_out_new)

use the inpainting pipeline to add a hat

from diffusers import KandinskyInpaintPipeline, KandinskyPriorPipeline
from diffusers.utils import load_image
import numpy as np
import torch

device = "cuda"

REPO_PRIOR = "YiYiXu/Kandinsky-prior"
REPO_INPAINT = "YiYiXu/Kandinsky-inpaint"

# inputs
prompt = "a hat"

# create the prior pipeline
pipe_prior = KandinskyPriorPipeline.from_pretrained(REPO_PRIOR)
pipe_prior.to(device)

# use the prior to generate image embeddings from the prompt
generator = torch.Generator(device=device).manual_seed(0)
out = pipe_prior(prompt, generator=generator)
image_emb = out.images
zero_image_emb = out.zero_embeds

# create the inpainting pipeline
pipe = KandinskyInpaintPipeline.from_pretrained(REPO_INPAINT)
pipe.to(device)

init_image = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
    "/kandinsky/cat.png"
)

# build a mask over the top of the image where the hat should go
mask = np.ones((768, 768), dtype=np.float32)
mask[:250, 250:-250] = 0

generator = torch.Generator(device=device).manual_seed(0)
out = pipe(
    prompt,
    image=init_image,
    mask_image=mask,
    image_embeds=image_emb,
    negative_image_embeds=zero_image_emb,
    height=768,
    width=768,
    num_inference_steps=150,
    generator=generator,
)

image = out.images[0]
image.save("cat_with_hat.png")

(output image: cat_with_hat)

image-to-image generation

from diffusers import KandinskyImg2ImgPipeline, KandinskyPriorPipeline, DDIMScheduler
from diffusers.utils import load_image
import torch

REPO_PRIOR = "YiYiXu/Kandinsky-prior"
REPO = "YiYiXu/Kandinsky"

model_dtype = torch.float16

prompt = "A red cartoon frog, 4k"

init_image = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
    "/kandinsky/frog.png"
)

# create the prior pipeline
pipe_prior = KandinskyPriorPipeline.from_pretrained(REPO_PRIOR, torch_dtype=model_dtype)
pipe_prior.to("cuda")

# use the prior to generate image embeddings from the prompt
generator = torch.Generator(device="cuda").manual_seed(0)
out = pipe_prior(prompt, num_inference_steps=25, generator=generator)
image_emb = out.images
zero_image_emb = out.zero_embeds

# create the img2img pipeline
pipe = KandinskyImg2ImgPipeline.from_pretrained(REPO, torch_dtype=model_dtype)

# swap in a DDIM scheduler
ddim_config = {
    "num_train_timesteps": 1000,
    "beta_schedule": "linear",
    "beta_start": 0.00085,
    "beta_end": 0.012,
    "clip_sample": False,
    "set_alpha_to_one": False,
    "steps_offset": 0,
    "prediction_type": "epsilon",
    "thresholding": False,
}
pipe.scheduler = DDIMScheduler(**ddim_config)
pipe.to("cuda")

generator = torch.Generator(device="cuda").manual_seed(0)
out = pipe(
    prompt=prompt,
    image=init_image,
    image_embeds=image_emb,
    negative_image_embeds=zero_image_emb,
    height=768,
    width=768,
    num_inference_steps=100,
    strength=0.2,
    generator=generator,
)

out.images[0].save("img2img_frog_out.png")

(input image: frog; output image: yiyi_test_img2img1_out)

image mixing

from diffusers import KandinskyPriorPipeline, KandinskyPipeline
from diffusers.utils import load_image
import torch

REPO_PRIOR = "YiYiXu/Kandinsky-prior"
REPO = "YiYiXu/Kandinsky"

model_dtype = torch.float16

# create the prior pipeline; we will use it to create image_emb and zero_image_emb
pipe_prior = KandinskyPriorPipeline.from_pretrained(REPO_PRIOR, torch_dtype=model_dtype)
pipe_prior.to("cuda")

img1 = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
    "/kandinsky/cat.png"
)

img2 = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
    "/kandinsky/starry_night.jpeg"
)

# interpolate between a text prompt and two images
images_texts = ["a cat", img1, img2]
weights = [0.3, 0.3, 0.4]
image_emb, zero_image_emb = pipe_prior.interpolate(images_texts, weights)

pipe = KandinskyPipeline.from_pretrained(REPO, torch_dtype=model_dtype)
pipe.to("cuda")

generator = torch.Generator(device="cuda").manual_seed(0)
out = pipe(
    "",
    image_embeds=image_emb,
    negative_image_embeds=zero_image_emb,
    height=768,
    width=768,
    num_inference_steps=150,
    generator=generator,
)

image = out.images[0]
image.save("starry_cat.png")

(output image: starry_cat)

@yiyixuxu yiyixuxu mentioned this pull request May 1, 2023
@isamu-isozaki
Contributor

@yiyixuxu Very cool! Just thought to mention that we already ported MoVQ in huggingface/open-muse, which might help.

@yiyixuxu
Collaborator Author

yiyixuxu commented May 1, 2023

@isamu-isozaki thanks!

@ayushtues
Contributor

ayushtues commented May 2, 2023

Very cool @yiyixuxu! Can you tell me how you tested whether the prior pipeline works and whether the weights load correctly? Would it be useful to have a temporary Jupyter notebook handy for testing the pipeline and its individual components (checking that weights load, hacky debugging, etc.)?

@ayushtues
Contributor

ayushtues commented May 2, 2023

For the MoVQ, it is practically the same as the VQVAE model already in diffusers (https://github.com/huggingface/diffusers/blob/kandinsky/src/diffusers/models/vq_model.py): the Encoder and the VectorQuantizer are exactly the same. In the decoder, it just uses a different custom normalization layer (SpatialNorm), which takes an extra embedding as input, instead of the GroupNorm used in the VQVAE. The rest of the decoder implementation is also exactly the same, so the changes needed are in the attention/resnet building blocks: they match the ones already in diffusers except for the normalization layer (they use GroupNorm and would need to use SpatialNorm now).

We can either parametrize the attention/resnet building blocks and the VQVAE in diffusers to support a different normalization layer and an additional embedding input, or copy them with minimal changes into the Kandinsky pipeline if we feel the normalization layer is not general enough to justify changing the existing implementations.

Would love to hear opinions on this!
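
For illustration, a rough sketch of what such a SpatialNorm layer could look like (hypothetical and simplified, not necessarily the exact shape it will take in diffusers): it group-normalizes the feature map and then modulates it with a scale and shift computed from the extra embedding by 1x1 convolutions.

import torch
import torch.nn as nn
import torch.nn.functional as F

class SpatialNorm(nn.Module):
    """Spatially conditioned normalization: GroupNorm modulated by an extra embedding."""

    def __init__(self, f_channels: int, zq_channels: int, num_groups: int = 32):
        super().__init__()
        self.norm_layer = nn.GroupNorm(num_groups=num_groups, num_channels=f_channels, eps=1e-6, affine=True)
        # 1x1 convs project the conditioning embedding to a per-pixel scale and shift
        self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1)
        self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1)

    def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor:
        # resize the conditioning embedding to the feature map's spatial size
        zq = F.interpolate(zq, size=f.shape[-2:], mode="nearest")
        norm_f = self.norm_layer(f)
        return norm_f * self.conv_y(zq) + self.conv_b(zq)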

@yiyixuxu
Collaborator Author

yiyixuxu commented May 2, 2023

@ayushtues let's add SpatialNorm to the blocks

@yiyixuxu
Collaborator Author

yiyixuxu commented May 2, 2023

@ayushtues this is an example script I use to do a quick comparison along the way.

Note that this might not work for you, because I had to go into the original repo and hardcode a few things to make sure we can reproduce results (including changing the noise construction to match diffusers' and passing a generator down; I don't think you will need to do this for the decoder). This is just an example so that you can use a similar process.

import numpy as np
import torch
from kandinsky2 import get_kandinsky2

# original codebase
model = get_kandinsky2('cuda', task_type='text2img', model_version='2.1', use_flash_attention=False)

prompt = "red cat, 4k photo"
batch_size = 1
guidance_scale = 4
prior_cf_scale = 4
prior_steps = "5"
negative_prior_prompt = ""

# generate clip embeddings
image_emb = model.generate_clip_emb(
    prompt,
    batch_size=batch_size,
    prior_cf_scale=prior_cf_scale,
    prior_steps=prior_steps,
    negative_prior_prompt=negative_prior_prompt,
)
print(f"image_emb:{image_emb.shape},{image_emb.sum()}")


# diffusers
from diffusers import KandinskyPipeline, PriorTransformer
import diffusers

pipe_prior = KandinskyPipeline.from_pretrained("YiYiXu/test-kandinsky")
pipe_prior.to("cuda")

generator = torch.Generator(device="cuda").manual_seed(0)
image_emb_d = pipe_prior(
    prompt,
    generator=generator,
)

print(f"image_embeddings:{image_emb_d.shape},{image_emb_d.sum()}")
print("compare results:")
print(np.max(np.abs(image_emb_d.detach().cpu().numpy() - image_emb.detach().cpu().numpy())))

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented May 4, 2023

The documentation is not available anymore as the PR was closed or merged.

@ayushtues ayushtues mentioned this pull request May 4, 2023
@ayushtues
Contributor

ayushtues commented May 4, 2023

Started a PR #3330 for adding the decoder. I was able to load the pretrained weights of the MoVQ model into the diffusers-based VQModel with minimal changes. Next I need to make sure the forward passes are also the same.

@ayushtues
Contributor

ayushtues commented May 4, 2023

Okay, the outputs of the forward pass are within 1e-4 of each other for the MoVQ decoder and within 1e-5 for the MoVQ encoder, so that should be fine.

I can integrate it into the pipeline next; meanwhile I added a PR for the weights in the diffusers model repo, @yiyixuxu.
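
For context, the check is just the largest element-wise difference between the original MoVQ output and the ported VQModel output on the same input; a minimal hypothetical helper (the tensor names are placeholders, not from either codebase):

import torch

def max_abs_diff(a: torch.Tensor, b: torch.Tensor) -> float:
    # largest element-wise difference between two model outputs
    return (a.float() - b.float()).abs().max().item()

# hypothetical usage: original_out from the ai-forever MoVQ decoder,
# ported_out from the diffusers VQModel decoder, both fed the same latents
# assert max_abs_diff(original_out, ported_out) < 1e-4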

@seruva19 seruva19 mentioned this pull request May 4, 2023
@yiyixuxu
Collaborator Author

yiyixuxu commented May 5, 2023

@ayushtues

Thanks for adding the decoder so fast! super awesome job! 😇🤗👏👍

I think we can wrap up Kandinsky soon! A few tasks left (I ranked them from easy to difficult based on my subjective judgment 😂) - let me know if you are interested in taking any of these. I will help you as much as you need of course :)

  1. image encoder: I think it's only used to generate the zero image embedding for unconditional tokens here https://github.com/ai-forever/Kandinsky-2/blob/main/kandinsky2/kandinsky2_1_model.py#L323; it's just the CLIP image encoder (see the short sketch after this list):
from transformers import CLIPVisionModelWithProjection
clip_image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=torch.float16).to("cuda")
  2. second text_encoder and tokenizer for the text-to-image diffusion process with the UNet
  3. second scheduler for the text-to-image diffusion process: I think we can just use UnCLIPScheduler https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_unclip.py out of the box
  4. put everything together into the pipeline!
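
A rough sketch of how the zero image embedding could be generated with that encoder (hypothetical, just to illustrate item 1; the tensor shape assumes the standard CLIP ViT-L/14 input size):

import torch
from transformers import CLIPVisionModelWithProjection

image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    "openai/clip-vit-large-patch14", torch_dtype=torch.float16
).to("cuda")

# the unconditional ("zero") image embedding is just the CLIP embedding of an all-zeros image
zero_img = torch.zeros(1, 3, 224, 224, dtype=torch.float16, device="cuda")
with torch.no_grad():
    zero_image_emb = image_encoder(zero_img).image_embeds  # shape (1, 768)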

@@ -0,0 +1,77 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
Contributor

Cool!

@ayushtues
Contributor

@yiyixuxu I can take up 1 and 2, and later help with 4 when the parts are ready to be combined in the pipeline. I'm not so familiar with how schedulers integrate into diffusers, so I'll leave 3 to you, but I definitely want to review it and learn how they fit into the pipeline.

@yiyixuxu
Collaborator Author

yiyixuxu commented May 7, 2023

@ayushtues great!

@ayushtues
Contributor

ayushtues commented May 9, 2023

@yiyixuxu where do you think we should put the MultilingualCLIP model? Since it's not directly available in HF, should we add it in a separate file in pipelines/Kandinsky?

@ayushtues
Contributor

Meanwhile, I started another PR for tasks 1 and 2: #3373

@patrickvonplaten patrickvonplaten changed the title [WIP] add Kandinsky 2.1 Add Kandinsky 2.1 May 25, 2023
@patrickvonplaten
Contributor

Good to merge!

@alexblattner

Can this use ControlNet?

@patrickvonplaten
Contributor

We should try training ControlNet on it!

yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
add kandinsky2.1

---------

Co-authored-by: yiyixuxu <yixu310@gmail,com>
Co-authored-by: Ayush Mangal <[email protected]>
Co-authored-by: ayushmangal <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024
add kandinsky2.1

---------

Co-authored-by: yiyixuxu <yixu310@gmail,com>
Co-authored-by: Ayush Mangal <[email protected]>
Co-authored-by: ayushmangal <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>