Use torch in get_2d_rotary_pos_embed #10155


Merged: 2 commits merged into huggingface:main on Dec 18, 2024

Conversation

@hlky (Contributor) commented Dec 9, 2024

What does this PR do?

Refactors get_2d_rotary_pos_embed to use torch instead of numpy, and adds a device argument so that the tensors can be created directly on e.g. cuda.

Usage of get_2d_rotary_pos_embed in the HunyuanDiT pipelines is updated to pass device.

The torch and numpy versions match numerically; see the sketch and reproduction below.
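As a rough sketch of the new call shape (illustrative only; the values mirror the 1024x1024 HunyuanDiT setup used in the reproduction and are not the exact pipeline diff), the device argument lets the cos/sin tensors be created directly on the accelerator:

import torch
from diffusers.models.embeddings import get_2d_rotary_pos_embed

# Illustrative values: 1024x1024 image, patch_size=2, attention_head_dim=88.
embed_dim = 88                               # inner_dim // num_attention_heads
grid_height = grid_width = 1024 // 8 // 2    # latent grid of 64 x 64 patches
grid_crops_coords = ((0, 0), (grid_height, grid_width))  # full grid, no crop

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

cos, sin = get_2d_rotary_pos_embed(
    embed_dim,
    grid_crops_coords,
    (grid_height, grid_width),
    device=device,        # new argument: tensors are created on this device
    output_type="pt",
)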

Reproduction
from diffusers.models.embeddings import get_2d_rotary_pos_embed
import torch


def get_resize_crop_region_for_grid(src, tgt_size):
    th = tw = tgt_size
    h, w = src

    r = h / w

    # resize
    if r > 1:
        resize_height = th
        resize_width = int(round(th / h * w))
    else:
        resize_width = tw
        resize_height = int(round(tw / w * h))

    crop_top = int(round((th - resize_height) / 2.0))
    crop_left = int(round((tw - resize_width) / 2.0))

    return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)


height, width = 1024, 1024
height = int((height // 16) * 16)
width = int((width // 16) * 16)
num_attention_heads = 16
attention_head_dim = 88
patch_size = 2
inner_dim = num_attention_heads * attention_head_dim
grid_height = height // 8 // patch_size
grid_width = width // 8 // patch_size
base_size = 512 // 8 // patch_size
grid_crops_coords = get_resize_crop_region_for_grid(
    (grid_height, grid_width), base_size
)
image_rotary_emb_np = get_2d_rotary_pos_embed(
    inner_dim // num_attention_heads,
    grid_crops_coords,
    (grid_height, grid_width),
    output_type="np",
)

image_rotary_emb = get_2d_rotary_pos_embed(
    inner_dim // num_attention_heads,
    grid_crops_coords,
    (grid_height, grid_width),
    output_type="pt",
)

torch.testing.assert_close(image_rotary_emb[0], image_rotary_emb_np[0])
torch.testing.assert_close(image_rotary_emb[1], image_rotary_emb_np[1])

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@sayakpaul @yiyixuxu

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@hlky (Contributor, Author) commented Dec 10, 2024

Downstream usage should be OK; this was already returning torch.Tensor and integrations handle device casting.
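As a hypothetical illustration of that downstream pattern (not a specific integration; it reuses the variables from the reproduction above and assumes a CUDA device is available), callers simply move the returned tensors onto their working device:

# Hypothetical downstream pattern: the function already returns torch tensors,
# so existing callers can cast them onto their own device as before.
cos, sin = get_2d_rotary_pos_embed(
    inner_dim // num_attention_heads,
    grid_crops_coords,
    (grid_height, grid_width),
    output_type="pt",
)
cos, sin = cos.to("cuda"), sin.to("cuda")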

@yiyixuxu (Collaborator) commented:

Did we run the Hunyuan test?

@hlky (Contributor, Author) commented Dec 10, 2024

The checkpoint used in the slow test returns a 404:

"XCLiu/HunyuanDiT-0523", revision="refs/pr/2", torch_dtype=torch.float16

https://huggingface.co/XCLiu/HunyuanDiT-0523

@yiyixuxu (Collaborator) commented:

Just running its docstring example manually would be fine for now; we should update the test too.

@hlky (Contributor, Author) commented Dec 10, 2024

There is a slight change to the image:

import torch
from diffusers import HunyuanDiTPipeline
pipe = HunyuanDiTPipeline.from_pretrained(
    "Tencent-Hunyuan/HunyuanDiT-Diffusers", torch_dtype=torch.float16
)
pipe.to("cuda")
prompt = "An astronaut riding a horse"
image = pipe(prompt, generator=torch.Generator("cuda").manual_seed(0)).images[0]

Main:
[image generated on main]

PR:
[image generated with this PR]
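If we want to quantify that difference rather than eyeball it, something along these lines would do (a sketch; it assumes the two outputs were saved locally as main.png and pr.png, which is not part of this PR):

# Sketch: measure the pixel-level difference between the two generations.
# Assumes the main-branch and PR images were saved as main.png / pr.png.
import numpy as np
from PIL import Image

img_main = np.asarray(Image.open("main.png").convert("RGB"), dtype=np.float32)
img_pr = np.asarray(Image.open("pr.png").convert("RGB"), dtype=np.float32)

diff = np.abs(img_main - img_pr)
print("max abs pixel diff:", diff.max())
print("mean abs pixel diff:", diff.mean())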

@hlky (Contributor, Author) commented Dec 10, 2024

Unclear why, though; I'll run the test again. Edit: I haven't run the reproduction on CUDA, which might account for the difference.

>>> torch.abs(image_rotary_emb[0].flatten() - image_rotary_emb_np[0].flatten()).max()
tensor(0.)
>>> torch.abs(image_rotary_emb[1].flatten() - image_rotary_emb_np[1].flatten()).max()
tensor(0.)

@hlky (Contributor, Author) commented Dec 10, 2024

Yes, there is a very minor difference when we create the tensors on CUDA.

>>> torch.abs(image_rotary_emb[0].cpu().flatten() - image_rotary_emb_np[0].flatten()).max()
tensor(1.7881e-07)
>>> torch.abs(image_rotary_emb[1].cpu().flatten() - image_rotary_emb_np[1].flatten()).max()
tensor(1.7881e-07)

It's below PyTorch's tolerance for float32, though: https://pytorch.org/docs/stable/testing.html
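For reference, a sketch of that check (it reuses the variables from the reproduction above and requires a CUDA device): torch.testing.assert_close uses rtol=1.3e-6 and atol=1e-5 as its float32 defaults, so the ~1.8e-7 gap observed here passes.

# Sketch: build the embeddings directly on CUDA and compare against the numpy
# path from the reproduction; the difference is well inside the float32 defaults.
image_rotary_emb_cuda = get_2d_rotary_pos_embed(
    inner_dim // num_attention_heads,
    grid_crops_coords,
    (grid_height, grid_width),
    device="cuda",
    output_type="pt",
)
torch.testing.assert_close(image_rotary_emb_cuda[0].cpu(), image_rotary_emb_np[0])
torch.testing.assert_close(image_rotary_emb_cuda[1].cpu(), image_rotary_emb_np[1])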

cc @yiyixuxu

@DN6 added the roadmap (Add to current release roadmap) label on Dec 11, 2024
@hlky force-pushed the np-get-2d-rotary-pos-embed branch from af5ecd9 to f2e7731 on December 17, 2024 at 16:46
@hlky (Contributor, Author) commented Dec 17, 2024

I've added output_type and a deprecation message, as in #10156.
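For context, the deprecation path presumably looks roughly like the sketch below, using diffusers' deprecate helper (the exact wording and removal version are placeholders, not the merged code):

# Sketch of the deprecation pattern; message and version are placeholders.
from diffusers.utils import deprecate

output_type = "np"  # the legacy default a caller might still rely on
if output_type == "np":
    deprecation_message = (
        "`get_2d_rotary_pos_embed` now uses `torch` instead of `numpy`. "
        "Pass `output_type='pt'` to use the new code path."
    )
    deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)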

@yiyixuxu (Collaborator) left a review comment:

thanks @hlky!

@yiyixuxu (Collaborator) commented:

It's a bit of a surprise that something numerically < 1e-7 would cause a visual difference like this, but it's not worse, is it? My eyes are not very good at judging this.

@hlky (Contributor, Author) commented Dec 17, 2024

It is surprising; we can run more tests before merge if you want. The visual difference is acceptable, imo.

@yiyixuxu merged commit 0ac52d6 into huggingface:main on Dec 18, 2024 (11 of 12 checks passed)
@yiyixuxu (Collaborator) commented:

Thanks @hlky, great work as always :)
I think this pipeline might not be completely deterministic to begin with (but I don't have time to look into it now).
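A quick way to sanity-check that would be to run the same seed twice and diff the raw outputs (a sketch only; it reuses pipe and prompt from the example above and assumes the pipeline supports output_type="np" like other diffusers pipelines):

# Sketch: two generations with an identical seed; any nonzero difference here
# would indicate nondeterminism independent of this PR's change.
import numpy as np

img_a = pipe(prompt, generator=torch.Generator("cuda").manual_seed(0), output_type="np").images[0]
img_b = pipe(prompt, generator=torch.Generator("cuda").manual_seed(0), output_type="np").images[0]
print("max abs diff between identical-seed runs:", np.abs(img_a - img_b).max())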

Foundsheep pushed a commit to Foundsheep/diffusers that referenced this pull request Dec 23, 2024
* Use `torch` in `get_2d_rotary_pos_embed`

* Add deprecation
sayakpaul pushed a commit that referenced this pull request Dec 23, 2024
* Use `torch` in `get_2d_rotary_pos_embed`

* Add deprecation
Labels: close-to-merge, roadmap (Add to current release roadmap)