Add Shap-E #3742

Merged
merged 131 commits into from Jul 6, 2023
Conversation

@yiyixuxu (Collaborator) commented Jun 11, 2023

original repo: https://github.com/openai/shap-e

text-to-3D

import torch

from diffusers import ShapEPipeline

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

batch_size = 1
guidance_scale = 15.0
prompt = "a shark"
torch.manual_seed(0)

repo = "YiYiXu/shap-e"
pipe = ShapEPipeline.from_pretrained(repo)
pipe = pipe.to(device)

generator = torch.Generator(device=device).manual_seed(0)
images = pipe(
    prompt,
    num_images_per_prompt=batch_size,
    generator=generator,
    guidance_scale=guidance_scale,
    num_inference_steps=64,
    frame_size=256,
    output_type='pil',
).images

pipe.save_gif(images[0], "shark.gif")
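
By the time the PR merged, the gif helper became a utility function (see export_to_gif in the commit log below). A minimal sketch of that form, assuming a diffusers version that exposes diffusers.utils.export_to_gif:

from diffusers.utils import export_to_gif

# export_to_gif writes a list of PIL frames out as an animated gif;
# images[0] here is the list of rendered frames for the first prompt.
export_to_gif(images[0], "shark.gif")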

[output GIF: yiyi_test_pipeline_example_1_out]

Generated from the original code:
[output GIF: yiyi_run_model_decode_latent_images_1_out_0]

image-to-3D

from PIL import Image
import torch

from diffusers import ShapEImg2ImgPipeline

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

batch_size = 1
guidance_scale = 3.0

image = Image.open("corgi.png")

repo = "YiYiXu/shap-e-img2img"
pipe = ShapEImg2ImgPipeline.from_pretrained(repo)
pipe = pipe.to(device)

generator = torch.Generator(device=device).manual_seed(0)
images = pipe(
    image,
    num_images_per_prompt=batch_size,
    generator=generator,
    guidance_scale=guidance_scale,
    num_inference_steps=64,
    frame_size=256,
    output_type='pil',
).images

pipe.save_gif(images[0], "corgi_3d.gif")

Input image:
[image: corgi]

3D output:
[output GIF: yiyi_test_pipeline_img2img_example_1_out]

For reference, this is the 3D render generated with the original repo with the same inputs and seed:
[output GIF: yiyi_run_model_img23d_out_0]

To-do:

  • refactor based on feedback
  • add image_to_3d
  • investigate more on numerical differences (compare image qualities)
  • tests & doc

adding conversion script

add pipeline

add step_index from pipeline, + remove permute

add zero pad token

remove copy from statement for betas_for_alpha_bar function
@yiyixuxu (Collaborator, Author) commented Jun 12, 2023

[output GIF: shark]

@patrickvonplaten
The generations seem maybe OK (I still need to compare more), but I've been really struggling to match the numerical outputs of the pipeline closely to the original repo. See below the testing script that compares pipeline outputs - it returns 0.07. We would normally like to see this number below 1e-3, no? Or should I wait until we add the decoder and compare the decoded outputs instead?

When I compared the model forward pass (see the equivalency test for the model forward pass below), the results matched nicely, with a max element difference below 1e-5. But the model seems sensitive to differences in the text embedding inputs: if I run the same test with text embeddings generated from the original CLIP model vs. the transformers CLIP model that we use, the difference increases to 1e-3, and this discrepancy seems to be further amplified during the sampling process.

Not sure what to do here and would appreciate any feedback/advice :)

Equivalency test for pipeline outputs

This script returns 0.07:

import torch
import numpy as np

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

batch_size = 4
guidance_scale = 15.
sigma_min = 1e-3
prompt = "a shark"

# diffusers 

from diffusers import ShapEPipeline
repo = "YiYiXu/shap-e"
pipe = ShapEPipeline.from_pretrained(repo)
pipe = pipe.to(device)

generator = torch.Generator(device=device).manual_seed(0)
latents_d = pipe(
    prompt,
    num_images_per_prompt=batch_size,
    generator=generator,
    guidance_scale=guidance_scale,
    num_inference_steps=64,
    sigma_min=sigma_min,
).latents

# original
from shap_e.diffusion.sample import sample_latents
from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
from shap_e.models.download import load_model, load_config

model = load_model('text300M', device=device)
diffusion = diffusion_from_config(load_config('diffusion'))

latents, _ = sample_latents(
    batch_size=batch_size,
    model=model,
    diffusion=diffusion,
    guidance_scale=guidance_scale,
    model_kwargs=dict(texts=[prompt] * batch_size),
    progress=True,
    clip_denoised=True,
    use_fp16=False,
    use_karras=True,
    karras_steps=64,
    sigma_min=sigma_min,
    sigma_max=160,
    s_churn=0,
)

# compare
print("max diff latents:") 
print(np.abs(latents.reshape(4, 1024,1024).detach().cpu().numpy() - latents_d.detach().cpu().numpy()).max())

Equivalency test for model forward pass

  • TEST2 compares the model outputs given exactly the same inputs; the maximum element difference is 4.6790e-06
  • TEST1 compares the model outputs given embeddings generated with the original CLIP model (used by the original repo) vs. the transformers CLIP model (used by diffusers); the maximum element difference is 1e-3
  • I also compared the generated text embeddings directly - the difference is 0.00019

import torch
import numpy as np
import clip

from diffusers.models.prior_transformer import PriorTransformer

from transformers import CLIPTextModelWithProjection, CLIPTokenizer
from shap_e.models.download import load_model

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# create original model
model = load_model('text300M', device=device)
transformer = model.wrapped

# create diffusers model
path_shape = "YiYiXu/shap-e"
transformer_d = PriorTransformer.from_pretrained(path_shape, subfolder="prior").to(device)

# inputs
batch_size = 1
torch.manual_seed(0)
x = torch.randn([batch_size, 1024, 1024], device=device)
t = torch.tensor([0] * batch_size, device=device)
prompt = ["a shark"] * batch_size

# create embeddings using original clip model
clip_name = "ViT-L/14"
download_root = "/home/yiyi_huggingface_co/shap-e/shap_e_model_cache"

clip_model, _ = clip.load(clip_name, device=device, download_root=download_root)
tokenize = clip.tokenize

embeddings = clip_model.encode_text(
    tokenize(list(prompt), truncate=True).to(device)
).float()
embeddings = embeddings / torch.linalg.norm(embeddings, dim=-1, keepdim=True)

# create embeddings using transformer clip 
repo = "openai/clip-vit-large-patch14"
d_text_encoder = CLIPTextModelWithProjection.from_pretrained(repo).to(device)
d_tokenizer = CLIPTokenizer.from_pretrained(repo)

tokens = d_tokenizer(
    prompt,
    padding="max_length",
    max_length=d_tokenizer.model_max_length,
    truncation=True,
    return_tensors="pt",
).input_ids
embeddings_d = d_text_encoder(tokens.to(device)).text_embeds.float()
embeddings_d = embeddings_d / torch.linalg.norm(embeddings_d, dim=-1, keepdim=True)

# compare the embeddings: 0.00019
print(f"compare embeddings: {np.abs(embeddings.detach().cpu().numpy() - embeddings_d.detach().cpu().numpy()).max()}")


# TEST1: compare the output using respective embeddings: 0.0012

# original output
out = transformer(x, t, embeddings=embeddings)

# diffusers output
out_d = transformer_d(x.permute(0, 2, 1), 0, embeddings_d, return_dict=False)[0]

print()
print("test1 result")  # 0.0012
print((out_d.permute(0, 2, 1) - out).abs().max())

# TEST2: compare the outputs using the same embeddings: 4.6790e-06

# original output
out = transformer(x, t, embeddings=embeddings)

# diffusers output
out_d = transformer_d(x.permute(0, 2, 1), 0, embeddings, return_dict=False)[0]

print()
print("test2 result")  # 4.6790e-06
print((out_d.permute(0, 2, 1) - out).abs().max())

Testing script comparing the pipeline outputs with sigma_min = 1

I also ran the equivalency test for the pipeline with sigmas from 160 ~ 1 (vs. the first pipeline test, which was run with the default sigma range 160 ~ 1e-3). This test returns 5e-4, so the comparison may become unstable when sigma gets really small.

import torch
import numpy as np

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

batch_size = 4
guidance_scale = 15.
sigma_min = 1  # vs. the default 1e-3 used in the first test
prompt = "a shark"

# diffusers 

from diffusers import ShapEPipeline
repo = "YiYiXu/shap-e"
pipe = ShapEPipeline.from_pretrained(repo)
pipe = pipe.to(device)

generator = torch.Generator(device=device).manual_seed(0)
latents_d = pipe(
    prompt,
    num_images_per_prompt=batch_size,
    generator=generator,
    guidance_scale=guidance_scale,
    num_inference_steps=64,
    sigma_min=sigma_min,
).latents

# create prompt_embeds from the diffusers pipeline and pass them to the original model
prompt_embeds = pipe._encode_prompt(
    [prompt], device, batch_size, True
)

# original
from shap_e.diffusion.sample import sample_latents
from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
from shap_e.models.download import load_model, load_config

model = load_model('text300M', device=device)
diffusion = diffusion_from_config(load_config('diffusion'))

latents, _ = sample_latents(
    batch_size=batch_size,
    model=model,
    diffusion=diffusion,
    guidance_scale=guidance_scale,
    model_kwargs=dict(embeddings=prompt_embeds[4:]),  # keep only the conditional half of the CFG embeddings
    progress=True,
    clip_denoised=True,
    use_fp16=False,
    use_karras=True,
    karras_steps=64,
    sigma_min=sigma_min,
    sigma_max=160,
    s_churn=0,
)

# compare
print("max diff latents:") 
print(np.abs(latents.reshape(4, 1024,1024).detach().cpu().numpy() - latents_d.detach().cpu().numpy()).max()) #0.00050234795
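
As context for the sigma_min sensitivity above: with use_karras=True the sampler follows the Karras et al. (2022) noise schedule, which packs the final steps tightly toward sigma_min. A minimal sketch of that schedule (assuming the common default rho=7.0), just to show how small the last sigmas get:

import numpy as np

# Karras et al. (2022) schedule: interpolate linearly in sigma^(1/rho) space.
def karras_sigmas(n, sigma_min, sigma_max, rho=7.0):
    ramp = np.linspace(0, 1, n)
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    return (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho

print(karras_sigmas(64, 1e-3, 160)[-3:])  # last sigmas approach 1e-3
print(karras_sigmas(64, 1, 160)[-3:])     # last sigmas approach 1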

@@ -220,6 +241,22 @@ def _sigma_to_t(self, sigma, log_sigmas):
        t = t.reshape(sigma.shape)
        return t

    # YiYi Notes: taken from the original repo; will refactor and not introduce a dependency on scipy
    def _sigma_to_t_yiyi(self, sigma):
Contributor

Sounds good!

@patrickvonplaten (Contributor)

In general the design looks good to me! I just noticed that we don't have any prior transformer tests, so I added them here: #3796. That PR also allows disabling the PT 2 attention processor, which should help with precision issues.

Could you maybe merge #3796 into your PR once it's merged, and then use set_default_attn_processor() to improve precision in the tests?

Think we're on a good way here to have a powerful new model class in diffusers 🚀
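
For reference, a minimal sketch of that suggestion (assuming #3796 is merged so set_default_attn_processor() is available, and that the prior weights live in the "prior" subfolder of the YiYiXu/shap-e repo):

from diffusers.models.prior_transformer import PriorTransformer

# Load the prior and swap every attention module back to the default
# processor, avoiding the PyTorch 2 SDPA kernels when comparing outputs.
prior = PriorTransformer.from_pretrained("YiYiXu/shap-e", subfolder="prior")
prior.set_default_attn_processor()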

@HuggingFaceDocBuilderDev commented Jun 20, 2023

The documentation is not available anymore as the PR was closed or merged.

@patrickvonplaten changed the title from "[WIP]add Shap-E" to "Add Shap-E" on Jul 6, 2023
Comment on lines +571 to +573
255.0,
255.0,
255.0,
Member

👍

@patrickvonplaten merged commit 45f6d52 into main Jul 6, 2023
@patrickvonplaten deleted the shap-ee branch July 6, 2023 13:20
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
* refactor prior_transformer

adding conversion script

add pipeline

add step_index from pipeline, + remove permute

add zero pad token

remove copy from statement for betas_for_alpha_bar function

* add

* add

* update conversion script for renderer model

* refactor camera a little bit

* clean up

* style

* fix copies

* Update src/diffusers/schedulers/scheduling_heun_discrete.py

Co-authored-by: Patrick von Platen <[email protected]>

* Update src/diffusers/pipelines/shap_e/pipeline_shap_e.py

Co-authored-by: Patrick von Platen <[email protected]>

* Update src/diffusers/pipelines/shap_e/pipeline_shap_e.py

Co-authored-by: Patrick von Platen <[email protected]>

* alpha_transform_type

* remove step_index argument

* remove get_sigmas_karras

* remove _yiyi_sigma_to_t

* move the rescale prompt_embeds from prior_transformer to pipeline

* replace baddbmm with einsum to match origial repo

* Revert "replace baddbmm with einsum to match origial repo"

This reverts commit 3f6b435.

* add step_index to scale_model_input

* Revert "move the rescale prompt_embeds from prior_transformer to pipeline"

This reverts commit 5b5a8e6.

* move rescale from prior_transformer to pipeline

* correct step_index in scale_model_input

* remove print lines

* refactor prior - reduce arguments

* make style

* add prior_image

* arg embedding_proj_norm -> norm_embedding_proj

* add pre-norm for proj_embedding

* move rescale prompt from pipeline to _encode_prompt

* add img2img pipeline

* style

* copies

* Update src/diffusers/models/prior_transformer.py

Co-authored-by: Patrick von Platen <[email protected]>

* Update src/diffusers/models/prior_transformer.py

Co-authored-by: Patrick von Platen <[email protected]>

* Update src/diffusers/models/prior_transformer.py

Co-authored-by: Patrick von Platen <[email protected]>

* Update src/diffusers/models/prior_transformer.py

add arg: encoder_hid_proj

Co-authored-by: Patrick von Platen <[email protected]>

* Update src/diffusers/models/prior_transformer.py

add new config: norm_in_type

Co-authored-by: Patrick von Platen <[email protected]>

* Update src/diffusers/models/prior_transformer.py

add new config: added_emb_type

Co-authored-by: Patrick von Platen <[email protected]>

* Update src/diffusers/models/prior_transformer.py

rename out_dim -> clip_embed_dim

Co-authored-by: Patrick von Platen <[email protected]>

* Update src/diffusers/models/prior_transformer.py

rename config: out_dim -> clip_embed_dim

Co-authored-by: Patrick von Platen <[email protected]>

* Update src/diffusers/models/prior_transformer.py

Co-authored-by: Patrick von Platen <[email protected]>

* Update src/diffusers/models/prior_transformer.py

Co-authored-by: Patrick von Platen <[email protected]>

* finish refactor prior_tranformer

* make style

* refactor renderer

* fix

* make style

* refactor img2img

* remove params_proj

* add test

* add upcast_softmax to prior_transformer

* enable num_images_per_prompt, add save_gif utility

* add

* add fast test

* make style

* add slow test

* style

* add test for img2img

* refactor

* enable batching

* style

* refactor scheduler

* update test

* style

* attempt to solve batch related tests timeout

* add doc

* Update src/diffusers/pipelines/shap_e/pipeline_shap_e.py

Co-authored-by: Patrick von Platen <[email protected]>

* Update src/diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py

Co-authored-by: Patrick von Platen <[email protected]>

* hardcode rendering related config

* update betas_for_alpha_bar on ddpm_scheduler

* fix copies

* fix

* export_to_gif

* style

* second attempt to speed up batching tests

* add doc page to index

* Remove intermediate clipping

* 3rd attempt to speed up batching tests

* Remvoe time index

* simplify scheduler

* Fix more

* Fix more

* fix more

* make style

* fix schedulers

* fix some more tests

* finish

* add one more test

* Apply suggestions from code review

Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: Pedro Cuenca <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>

* style

* apply feedbacks

* style

* fix copies

* add one example

* style

* add example for img2img

* fix doc

* fix more doc strings

* size -> frame_size

* style

* update doc

* style

* fix on doc

* update repo name

* improve the usage example in shap-e img2img

* add usage examples in the shap-e docs.

* consolidate examples.

* minor fix.

* update doc

* Apply suggestions from code review

* Apply suggestions from code review

* remove upcast

* Make sure background is white

* Update src/diffusers/pipelines/shap_e/pipeline_shap_e.py

* Apply suggestions from code review

* Finish

* Apply suggestions from code review

* Update src/diffusers/pipelines/shap_e/pipeline_shap_e.py

* Make style

---------

Co-authored-by: yiyixuxu <yixu310@gmail,com>
Co-authored-by: Patrick von Platen <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: Pedro Cuenca <[email protected]>
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024
(same commit message as above)