-
Notifications
You must be signed in to change notification settings - Fork 6k
Add Shap-E #3742
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add Shap-E #3742
Conversation
adding conversion script add pipeline add step_index from pipeline, + remove permute add zero pad token remove copy from statement for betas_for_alpha_bar function
@patrickvonplaten When I compared the model forward pass (see equivalency test for model forward pass), the results matched nicely with the max element difference less than Not sure what to do here and appreciate any feedback/advices:) equivalency test for pipeline outputsthis script returns import torch
import numpy as np
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch_size = 4
guidance_scale = 15.
sigma_min = 1e-3
prompt = "a shark"
# diffusers
from diffusers import ShapEPipeline
repo = "YiYiXu/shap-e"
pipe = ShapEPipeline.from_pretrained(repo)
pipe = pipe.to(device)
generator = torch.Generator(device="cuda").manual_seed(0)
latents_d = pipe(prompt, num_images_per_prompt=batch_size, generator=generator, guidance_scale=guidance_scale,num_inference_steps= 64, sigma_min=sigma_min).latents
# original
from shap_e.diffusion.sample import sample_latents
from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
from shap_e.models.download import load_model, load_config
model = load_model('text300M', device=device)
diffusion = diffusion_from_config(load_config('diffusion'))
latents, _ = sample_latents(
batch_size=batch_size,
model=model,
diffusion=diffusion,
guidance_scale=guidance_scale,
model_kwargs=dict(texts=[prompt] * batch_size),
progress=True,
clip_denoised=True,
use_fp16=False,
use_karras=True,
karras_steps=64,
sigma_min=sigma_min,
sigma_max=160,
s_churn=0,
)
# compare
print("max diff latents:")
print(np.abs(latents.reshape(4, 1024,1024).detach().cpu().numpy() - latents_d.detach().cpu().numpy()).max()) equivalency test for model forward pass
import torch
import numpy as np
import clip
from diffusers.models.prior_transformer import PriorTransformer
from transformers import CLIPTextModelWithProjection, CLIPTokenizer
from shap_e.models.download import load_model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# create original model
model = load_model('text300M', device=device)
transformer = model.wrapped
# create diffusers model
path_shape = "YiYiXu/shap-e"
transformer_d = PriorTransformer.from_pretrained(path_shape, subfolder="prior").to(device)
# inputs
batch_size = 1
torch.manual_seed(0)
x = torch.randn([batch_size, 1024, 1024], device=device)
t = torch.tensor([0] * batch_size, device=device)
prompt = ["a shark"] * batch_size
# create embeddings using original clip model
clip_name = "ViT-L/14"
download_root= "/home/yiyi_huggingface_co/shap-e/shap_e_model_cache"
clip_model, _ = clip.load(clip_name, device=device, download_root=download_root)
tokenize = clip.tokenize
embeddings = clip_model.encode_text(
tokenize(list(prompt), truncate=True).to(device)
).float()
embeddings = embeddings / torch.linalg.norm(embeddings, dim=-1, keepdim=True)
# create embeddings using transformer clip
repo = "openai/clip-vit-large-patch14"
d_text_encoder = CLIPTextModelWithProjection.from_pretrained(repo).to(device)
d_tokenizer = CLIPTokenizer.from_pretrained(repo)
tokens = d_tokenizer(prompt, padding="max_length", max_length=d_tokenizer.model_max_length, truncation=True, return_tensors="pt",).input_ids
embeddings_d= d_text_encoder(tokens.to(device)).text_embeds.float()
embeddings_d = embeddings_d / torch.linalg.norm(embeddings_d, dim=-1, keepdim=True)
# compare the embeddings : 0.00019
print(f" compare embeddings: {np.abs(embeddings.detach().cpu().numpy() - embeddings_d.detach().cpu().numpy()).max()}")
# TEST1: compare the output using respective embeddings: 0.0012
# original output
out = transformer(x,t, embeddings = embeddings)
# diffusers output
out_d = transformer_d(x.permute(0,2,1), 0, embeddings_d, return_dict=False)[0]
print(" ")
print(" test1 result") # 0.0012
print((out_d.permute(0, 2, 1) - out).abs().max())
# TEST2: compare the outputs using same embedding: 4.6790e-06
# original output
out = transformer(x,t, embeddings = embeddings)
# diffusers output
out_d = transformer_d(x.permute(0,2,1), 0, embeddings, return_dict=False)[0]
print(" ")
print(" test2 result") # 4.6790e-06
print((out_d.permute(0, 2, 1) - out).abs().max()) testing script compare the pipeline output with
|
@@ -220,6 +241,22 @@ def _sigma_to_t(self, sigma, log_sigmas): | |||
t = t.reshape(sigma.shape) | |||
return t | |||
|
|||
# YiYi Notes: Taking from the origional repo, will refactor and not introduce dependency on spicy | |||
def _sigma_to_t_yiyi(self, sigma): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sounds good!
In the general the design looks good to me! I just noticed that we don't have any prior transformer tests so I added them here: #3796. This PR also allows to disable the PT 2 attention processor which should help with precision issues. Could you maybe merge #3796 into your PR and once it's merged and then use Thing we're on a good way here to have a powerful new model class in |
The documentation is not available anymore as the PR was closed or merged. |
Co-authored-by: Patrick von Platen <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>
255.0, | ||
255.0, | ||
255.0, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
👍
* refactor prior_transformer adding conversion script add pipeline add step_index from pipeline, + remove permute add zero pad token remove copy from statement for betas_for_alpha_bar function * add * add * update conversion script for renderer model * refactor camera a little bit * clean up * style * fix copies * Update src/diffusers/schedulers/scheduling_heun_discrete.py Co-authored-by: Patrick von Platen <[email protected]> * Update src/diffusers/pipelines/shap_e/pipeline_shap_e.py Co-authored-by: Patrick von Platen <[email protected]> * Update src/diffusers/pipelines/shap_e/pipeline_shap_e.py Co-authored-by: Patrick von Platen <[email protected]> * alpha_transform_type * remove step_index argument * remove get_sigmas_karras * remove _yiyi_sigma_to_t * move the rescale prompt_embeds from prior_transformer to pipeline * replace baddbmm with einsum to match origial repo * Revert "replace baddbmm with einsum to match origial repo" This reverts commit 3f6b435. * add step_index to scale_model_input * Revert "move the rescale prompt_embeds from prior_transformer to pipeline" This reverts commit 5b5a8e6. * move rescale from prior_transformer to pipeline * correct step_index in scale_model_input * remove print lines * refactor prior - reduce arguments * make style * add prior_image * arg embedding_proj_norm -> norm_embedding_proj * add pre-norm for proj_embedding * move rescale prompt from pipeline to _encode_prompt * add img2img pipeline * style * copies * Update src/diffusers/models/prior_transformer.py Co-authored-by: Patrick von Platen <[email protected]> * Update src/diffusers/models/prior_transformer.py Co-authored-by: Patrick von Platen <[email protected]> * Update src/diffusers/models/prior_transformer.py Co-authored-by: Patrick von Platen <[email protected]> * Update src/diffusers/models/prior_transformer.py add arg: encoder_hid_proj Co-authored-by: Patrick von Platen <[email protected]> * Update src/diffusers/models/prior_transformer.py add new config: norm_in_type Co-authored-by: Patrick von Platen <[email protected]> * Update src/diffusers/models/prior_transformer.py add new config: added_emb_type Co-authored-by: Patrick von Platen <[email protected]> * Update src/diffusers/models/prior_transformer.py rename out_dim -> clip_embed_dim Co-authored-by: Patrick von Platen <[email protected]> * Update src/diffusers/models/prior_transformer.py rename config: out_dim -> clip_embed_dim Co-authored-by: Patrick von Platen <[email protected]> * Update src/diffusers/models/prior_transformer.py Co-authored-by: Patrick von Platen <[email protected]> * Update src/diffusers/models/prior_transformer.py Co-authored-by: Patrick von Platen <[email protected]> * finish refactor prior_tranformer * make style * refactor renderer * fix * make style * refactor img2img * remove params_proj * add test * add upcast_softmax to prior_transformer * enable num_images_per_prompt, add save_gif utility * add * add fast test * make style * add slow test * style * add test for img2img * refactor * enable batching * style * refactor scheduler * update test * style * attempt to solve batch related tests timeout * add doc * Update src/diffusers/pipelines/shap_e/pipeline_shap_e.py Co-authored-by: Patrick von Platen <[email protected]> * Update src/diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py Co-authored-by: Patrick von Platen <[email protected]> * hardcode rendering related config * update betas_for_alpha_bar on ddpm_scheduler * fix copies * fix * export_to_gif * style * second attempt to speed up batching tests * add doc page to index * Remove intermediate clipping * 3rd attempt to speed up batching tests * Remvoe time index * simplify scheduler * Fix more * Fix more * fix more * make style * fix schedulers * fix some more tests * finish * add one more test * Apply suggestions from code review Co-authored-by: Sayak Paul <[email protected]> Co-authored-by: Pedro Cuenca <[email protected]> Co-authored-by: Patrick von Platen <[email protected]> * style * apply feedbacks * style * fix copies * add one example * style * add example for img2img * fix doc * fix more doc strings * size -> frame_size * style * update doc * style * fix on doc * update repo name * improve the usage example in shap-e img2img * add usage examples in the shap-e docs. * consolidate examples. * minor fix. * update doc * Apply suggestions from code review * Apply suggestions from code review * remove upcast * Make sure background is white * Update src/diffusers/pipelines/shap_e/pipeline_shap_e.py * Apply suggestions from code review * Finish * Apply suggestions from code review * Update src/diffusers/pipelines/shap_e/pipeline_shap_e.py * Make style --------- Co-authored-by: yiyixuxu <yixu310@gmail,com> Co-authored-by: Patrick von Platen <[email protected]> Co-authored-by: Sayak Paul <[email protected]> Co-authored-by: Pedro Cuenca <[email protected]>
* refactor prior_transformer adding conversion script add pipeline add step_index from pipeline, + remove permute add zero pad token remove copy from statement for betas_for_alpha_bar function * add * add * update conversion script for renderer model * refactor camera a little bit * clean up * style * fix copies * Update src/diffusers/schedulers/scheduling_heun_discrete.py Co-authored-by: Patrick von Platen <[email protected]> * Update src/diffusers/pipelines/shap_e/pipeline_shap_e.py Co-authored-by: Patrick von Platen <[email protected]> * Update src/diffusers/pipelines/shap_e/pipeline_shap_e.py Co-authored-by: Patrick von Platen <[email protected]> * alpha_transform_type * remove step_index argument * remove get_sigmas_karras * remove _yiyi_sigma_to_t * move the rescale prompt_embeds from prior_transformer to pipeline * replace baddbmm with einsum to match origial repo * Revert "replace baddbmm with einsum to match origial repo" This reverts commit 3f6b435. * add step_index to scale_model_input * Revert "move the rescale prompt_embeds from prior_transformer to pipeline" This reverts commit 5b5a8e6. * move rescale from prior_transformer to pipeline * correct step_index in scale_model_input * remove print lines * refactor prior - reduce arguments * make style * add prior_image * arg embedding_proj_norm -> norm_embedding_proj * add pre-norm for proj_embedding * move rescale prompt from pipeline to _encode_prompt * add img2img pipeline * style * copies * Update src/diffusers/models/prior_transformer.py Co-authored-by: Patrick von Platen <[email protected]> * Update src/diffusers/models/prior_transformer.py Co-authored-by: Patrick von Platen <[email protected]> * Update src/diffusers/models/prior_transformer.py Co-authored-by: Patrick von Platen <[email protected]> * Update src/diffusers/models/prior_transformer.py add arg: encoder_hid_proj Co-authored-by: Patrick von Platen <[email protected]> * Update src/diffusers/models/prior_transformer.py add new config: norm_in_type Co-authored-by: Patrick von Platen <[email protected]> * Update src/diffusers/models/prior_transformer.py add new config: added_emb_type Co-authored-by: Patrick von Platen <[email protected]> * Update src/diffusers/models/prior_transformer.py rename out_dim -> clip_embed_dim Co-authored-by: Patrick von Platen <[email protected]> * Update src/diffusers/models/prior_transformer.py rename config: out_dim -> clip_embed_dim Co-authored-by: Patrick von Platen <[email protected]> * Update src/diffusers/models/prior_transformer.py Co-authored-by: Patrick von Platen <[email protected]> * Update src/diffusers/models/prior_transformer.py Co-authored-by: Patrick von Platen <[email protected]> * finish refactor prior_tranformer * make style * refactor renderer * fix * make style * refactor img2img * remove params_proj * add test * add upcast_softmax to prior_transformer * enable num_images_per_prompt, add save_gif utility * add * add fast test * make style * add slow test * style * add test for img2img * refactor * enable batching * style * refactor scheduler * update test * style * attempt to solve batch related tests timeout * add doc * Update src/diffusers/pipelines/shap_e/pipeline_shap_e.py Co-authored-by: Patrick von Platen <[email protected]> * Update src/diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py Co-authored-by: Patrick von Platen <[email protected]> * hardcode rendering related config * update betas_for_alpha_bar on ddpm_scheduler * fix copies * fix * export_to_gif * style * second attempt to speed up batching tests * add doc page to index * Remove intermediate clipping * 3rd attempt to speed up batching tests * Remvoe time index * simplify scheduler * Fix more * Fix more * fix more * make style * fix schedulers * fix some more tests * finish * add one more test * Apply suggestions from code review Co-authored-by: Sayak Paul <[email protected]> Co-authored-by: Pedro Cuenca <[email protected]> Co-authored-by: Patrick von Platen <[email protected]> * style * apply feedbacks * style * fix copies * add one example * style * add example for img2img * fix doc * fix more doc strings * size -> frame_size * style * update doc * style * fix on doc * update repo name * improve the usage example in shap-e img2img * add usage examples in the shap-e docs. * consolidate examples. * minor fix. * update doc * Apply suggestions from code review * Apply suggestions from code review * remove upcast * Make sure background is white * Update src/diffusers/pipelines/shap_e/pipeline_shap_e.py * Apply suggestions from code review * Finish * Apply suggestions from code review * Update src/diffusers/pipelines/shap_e/pipeline_shap_e.py * Make style --------- Co-authored-by: yiyixuxu <yixu310@gmail,com> Co-authored-by: Patrick von Platen <[email protected]> Co-authored-by: Sayak Paul <[email protected]> Co-authored-by: Pedro Cuenca <[email protected]>
original repo: https://github.com/openai/shap-e
text-to-3D
generated from original code

image-to-3D
image:

3d

as a reference, this is the 3d render generated with original repo with same inputs and seed

To-do: