Skip to content

Use torch in get_2d_rotary_pos_embed #10155

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1008,6 +1008,8 @@ def __call__(
self.transformer.inner_dim // self.transformer.num_heads,
grid_crops_coords,
(grid_height, grid_width),
device=device,
output_type="pt",
)

style = torch.tensor([0], device=device)
Expand Down
52 changes: 51 additions & 1 deletion src/diffusers/models/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -957,7 +957,57 @@ def get_3d_rotary_pos_embed_allegro(
return freqs_t, freqs_h, freqs_w, grid_t, grid_h, grid_w


def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True):
def get_2d_rotary_pos_embed(
embed_dim, crops_coords, grid_size, use_real=True, device: Optional[torch.device] = None, output_type: str = "np"
):
"""
RoPE for image tokens with 2d structure.

Args:
embed_dim: (`int`):
The embedding dimension size
crops_coords (`Tuple[int]`)
The top-left and bottom-right coordinates of the crop.
grid_size (`Tuple[int]`):
The grid size of the positional embedding.
use_real (`bool`):
If True, return real part and imaginary part separately. Otherwise, return complex numbers.
device: (`torch.device`, **optional**):
The device used to create tensors.

Returns:
`torch.Tensor`: positional embedding with shape `( grid_size * grid_size, embed_dim/2)`.
"""
if output_type == "np":
deprecation_message = (
"`get_2d_sincos_pos_embed` uses `torch` and supports `device`."
" `from_numpy` is no longer required."
" Pass `output_type='pt' to use the new version now."
)
deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
return _get_2d_rotary_pos_embed_np(
embed_dim=embed_dim,
crops_coords=crops_coords,
grid_size=grid_size,
use_real=use_real,
)
start, stop = crops_coords
# scale end by (steps−1)/steps matches np.linspace(..., endpoint=False)
grid_h = torch.linspace(
start[0], stop[0] * (grid_size[0] - 1) / grid_size[0], grid_size[0], device=device, dtype=torch.float32
)
grid_w = torch.linspace(
start[1], stop[1] * (grid_size[1] - 1) / grid_size[1], grid_size[1], device=device, dtype=torch.float32
)
grid = torch.meshgrid(grid_w, grid_h, indexing="xy")
grid = torch.stack(grid, dim=0) # [2, W, H]

grid = grid.reshape([2, 1, *grid.shape[1:]])
pos_embed = get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real)
return pos_embed


def _get_2d_rotary_pos_embed_np(embed_dim, crops_coords, grid_size, use_real=True):
"""
RoPE for image tokens with 2d structure.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -925,7 +925,11 @@ def __call__(
base_size = 512 // 8 // self.transformer.config.patch_size
grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size)
image_rotary_emb = get_2d_rotary_pos_embed(
self.transformer.inner_dim // self.transformer.num_heads, grid_crops_coords, (grid_height, grid_width)
self.transformer.inner_dim // self.transformer.num_heads,
grid_crops_coords,
(grid_height, grid_width),
device=device,
output_type="pt",
)

style = torch.tensor([0], device=device)
Expand Down
6 changes: 5 additions & 1 deletion src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py
Original file line number Diff line number Diff line change
Expand Up @@ -798,7 +798,11 @@ def __call__(
base_size = 512 // 8 // self.transformer.config.patch_size
grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size)
image_rotary_emb = get_2d_rotary_pos_embed(
self.transformer.inner_dim // self.transformer.num_heads, grid_crops_coords, (grid_height, grid_width)
self.transformer.inner_dim // self.transformer.num_heads,
grid_crops_coords,
(grid_height, grid_width),
device=device,
output_type="pt",
)

style = torch.tensor([0], device=device)
Expand Down
6 changes: 5 additions & 1 deletion src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py
Original file line number Diff line number Diff line change
Expand Up @@ -818,7 +818,11 @@ def __call__(
base_size = 512 // 8 // self.transformer.config.patch_size
grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size)
image_rotary_emb = get_2d_rotary_pos_embed(
self.transformer.inner_dim // self.transformer.num_heads, grid_crops_coords, (grid_height, grid_width)
self.transformer.inner_dim // self.transformer.num_heads,
grid_crops_coords,
(grid_height, grid_width),
device=device,
output_type="pt",
)

style = torch.tensor([0], device=device)
Expand Down
Loading