Use torch in get_2d_sincos_pos_embed and get_3d_sincos_pos_embed #10156
Draft because I noticed
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
ohh thanks!!
Downstream usage (there's probably more) This will break https://github.com/kijai/ComfyUI-CogVideoXWrapper/blob/795f8b05659dfa5ec6b216fb698bcca6fda34fdb/embeddings.py#L64-L71 because we return
just to be extra safe, maybe we can deprecate it instead
let me know what you think!
Yes makes sense to deprecate it, by
sure!
Added here c5bd771
thanks! I left one comment
code change looks good to me! does this only affect latte? can we run the doc string example for affected models to make sure no output change before merge
src/diffusers/models/embeddings.py
@@ -141,6 +156,66 @@ def get_3d_sincos_pos_embed(
    return pos_embed


def get_3d_sincos_pos_embed_np(
make it a private method _get_3d_sincos_pos_embed_np and deprecate here too, in case anyone wants to use it directly
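The request above (make the numpy path private and have it warn as well) can be sketched with the standard `warnings` module. This is only an illustration under stated assumptions: `get_pos_embed`, `_get_pos_embed_legacy`, and the `use_legacy` flag are hypothetical names, not the diffusers API, and the real PR may use a different deprecation helper.

```python
import warnings


def _get_pos_embed_legacy(n):
    """Hypothetical private legacy path; it warns so direct callers are notified too."""
    warnings.warn(
        "_get_pos_embed_legacy is deprecated; call get_pos_embed(n, use_legacy=False).",
        FutureWarning,
        stacklevel=2,
    )
    return [float(i) for i in range(n)]


def get_pos_embed(n, use_legacy=True):
    """Hypothetical public entry point; the default keeps the old behaviour."""
    if use_legacy:
        # Old return type is preserved, so downstream code keeps working for now.
        return _get_pos_embed_legacy(n)
    # New code path (a tuple stands in for the new return type in this sketch).
    return tuple(float(i) for i in range(n))
```

Keeping the legacy behaviour as the default means downstream callers are not broken immediately; they see a `FutureWarning` and can opt in to the new path when ready.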
Affected models:
- get_3d_sincos_pos_embed: CogVideoXPatchEmbed
- get_2d_sincos_pos_embed_from_grid: used by get_3d_sincos_pos_embed, get_2d_sincos_pos_embed
- get_1d_sincos_pos_embed_from_grid: used by get_3d_sincos_pos_embed, get_2d_sincos_pos_embed_from_grid, LatteTransformer3DModel
- get_2d_sincos_pos_embed: used by get_3d_sincos_pos_embed, CogView3PlusPatchEmbed, PatchEmbed
- PatchEmbed: used by HunyuanDiT2DControlNetModel, SD3ControlNetModel, DiTTransformer2DModel, HunyuanDiT2DModel, LatteTransformer3DModel, PixArtTransformer2DModel, Transformer2DModel (deprecated? it's split now), AllegroTransformer3DModel, MochiTransformer3DModel, SD3Transformer2DModel, UniDiffuserModel/UTransformer2DModel (uvit)

I'll test docstring examples for each model; it might take some time as there are a few large/slow models.
Allegro
Hash matched.
AllegroPipeline.2.mp4, AllegroPipeline.3.mp4
I just let it run, but we should follow this #10212 (comment) in the future.
wow!!! thank you @hlky |
thank you!
#10156)
* Use torch in get_2d_sincos_pos_embed
* Use torch in get_3d_sincos_pos_embed
* get_1d_sincos_pos_embed_from_grid in LatteTransformer3DModel
* deprecate
* move deprecate, make private
What does this PR do?
Refactors get_2d_sincos_pos_embed and get_3d_sincos_pos_embed to use torch instead of numpy, and adds device argument so that tensors can be created on e.g. cuda.
Usage of get_2d_sincos_pos_embed and get_3d_sincos_pos_embed is updated to pass device where applicable (we don't specify device during initialization, so we don't pass device to the function when it is called from __init__; the device from weights would just be cpu).
The torch and numpy versions match numerically.
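The numerical-match claim can be illustrated with a minimal sketch. It is not the diffusers code: `sincos_1d_np`, `sincos_1d_torch`, and `max_abs_diff` are hypothetical stand-ins for a 1D sin/cos embedding, written once in numpy and once in torch with an optional `device` argument. The torch half is skipped if torch is not installed.

```python
import numpy as np

try:
    import torch
except ImportError:  # the torch half of the comparison is skipped without torch
    torch = None


def sincos_1d_np(embed_dim, pos):
    """NumPy reference: sin/cos embedding of positions `pos`, shape (M, embed_dim)."""
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float64) / (embed_dim / 2.0)
    omega = 1.0 / 10000**omega  # (embed_dim / 2,) frequencies
    out = np.einsum("m,d->md", pos.reshape(-1).astype(np.float64), omega)
    return np.concatenate([np.sin(out), np.cos(out)], axis=1)


def sincos_1d_torch(embed_dim, pos, device=None):
    """Torch port of the same formula; tensors are created on `device`."""
    assert embed_dim % 2 == 0
    omega = torch.arange(embed_dim // 2, dtype=torch.float64, device=device)
    omega = 1.0 / 10000 ** (omega / (embed_dim / 2.0))
    pos = pos.reshape(-1).to(device=device, dtype=torch.float64)
    out = torch.outer(pos, omega)  # (M, embed_dim / 2)
    return torch.cat([torch.sin(out), torch.cos(out)], dim=1)


def max_abs_diff(embed_dim=8, num_pos=5, device=None):
    """Largest element-wise gap between the two versions (requires torch)."""
    ref = sincos_1d_np(embed_dim, np.arange(num_pos))
    out = sincos_1d_torch(embed_dim, torch.arange(num_pos), device=device)
    return float(np.abs(out.cpu().numpy() - ref).max())
```

With torch installed, `max_abs_diff()` should return a value at floating-point noise level; passing `device="cuda"` would build the tensors on the GPU, assuming one is available.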
Reproduction `get_2d_sincos_pos_embed`
Reproduction `get_3d_sincos_pos_embed`
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.