diff --git a/examples/community/pipeline_hunyuandit_differential_img2img.py b/examples/community/pipeline_hunyuandit_differential_img2img.py index 3ece670e5bde..8cf2830f25ab 100644 --- a/examples/community/pipeline_hunyuandit_differential_img2img.py +++ b/examples/community/pipeline_hunyuandit_differential_img2img.py @@ -1008,6 +1008,8 @@ def __call__( self.transformer.inner_dim // self.transformer.num_heads, grid_crops_coords, (grid_height, grid_width), + device=device, + output_type="pt", ) style = torch.tensor([0], device=device) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 0f4b555a2d71..f3c57103f9b8 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -957,7 +957,57 @@ def get_3d_rotary_pos_embed_allegro( return freqs_t, freqs_h, freqs_w, grid_t, grid_h, grid_w -def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True): +def get_2d_rotary_pos_embed( + embed_dim, crops_coords, grid_size, use_real=True, device: Optional[torch.device] = None, output_type: str = "np" +): + """ + RoPE for image tokens with 2d structure. + + Args: + embed_dim: (`int`): + The embedding dimension size + crops_coords (`Tuple[int]`) + The top-left and bottom-right coordinates of the crop. + grid_size (`Tuple[int]`): + The grid size of the positional embedding. + use_real (`bool`): + If True, return real part and imaginary part separately. Otherwise, return complex numbers. + device: (`torch.device`, **optional**): + The device used to create tensors. + + Returns: + `torch.Tensor`: positional embedding with shape `( grid_size * grid_size, embed_dim/2)`. + """ + if output_type == "np": + deprecation_message = ( + "`get_2d_sincos_pos_embed` uses `torch` and supports `device`." + " `from_numpy` is no longer required." + " Pass `output_type='pt' to use the new version now." + ) + deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False) + return _get_2d_rotary_pos_embed_np( + embed_dim=embed_dim, + crops_coords=crops_coords, + grid_size=grid_size, + use_real=use_real, + ) + start, stop = crops_coords + # scale end by (stepsāˆ’1)/steps matches np.linspace(..., endpoint=False) + grid_h = torch.linspace( + start[0], stop[0] * (grid_size[0] - 1) / grid_size[0], grid_size[0], device=device, dtype=torch.float32 + ) + grid_w = torch.linspace( + start[1], stop[1] * (grid_size[1] - 1) / grid_size[1], grid_size[1], device=device, dtype=torch.float32 + ) + grid = torch.meshgrid(grid_w, grid_h, indexing="xy") + grid = torch.stack(grid, dim=0) # [2, W, H] + + grid = grid.reshape([2, 1, *grid.shape[1:]]) + pos_embed = get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real) + return pos_embed + + +def _get_2d_rotary_pos_embed_np(embed_dim, crops_coords, grid_size, use_real=True): """ RoPE for image tokens with 2d structure. diff --git a/src/diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py b/src/diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py index 45e17f3de1e2..c8464f8108ea 100644 --- a/src/diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +++ b/src/diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py @@ -925,7 +925,11 @@ def __call__( base_size = 512 // 8 // self.transformer.config.patch_size grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size) image_rotary_emb = get_2d_rotary_pos_embed( - self.transformer.inner_dim // self.transformer.num_heads, grid_crops_coords, (grid_height, grid_width) + self.transformer.inner_dim // self.transformer.num_heads, + grid_crops_coords, + (grid_height, grid_width), + device=device, + output_type="pt", ) style = torch.tensor([0], device=device) diff --git a/src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py b/src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py index bda718cb197d..6f542cb59f46 100644 --- a/src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +++ b/src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py @@ -798,7 +798,11 @@ def __call__( base_size = 512 // 8 // self.transformer.config.patch_size grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size) image_rotary_emb = get_2d_rotary_pos_embed( - self.transformer.inner_dim // self.transformer.num_heads, grid_crops_coords, (grid_height, grid_width) + self.transformer.inner_dim // self.transformer.num_heads, + grid_crops_coords, + (grid_height, grid_width), + device=device, + output_type="pt", ) style = torch.tensor([0], device=device) diff --git a/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py b/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py index 408992378538..dea1f12696b2 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py @@ -818,7 +818,11 @@ def __call__( base_size = 512 // 8 // self.transformer.config.patch_size grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size) image_rotary_emb = get_2d_rotary_pos_embed( - self.transformer.inner_dim // self.transformer.num_heads, grid_crops_coords, (grid_height, grid_width) + self.transformer.inner_dim // self.transformer.num_heads, + grid_crops_coords, + (grid_height, grid_width), + device=device, + output_type="pt", ) style = torch.tensor([0], device=device)