@@ -957,7 +957,57 @@ def get_3d_rotary_pos_embed_allegro(
957
957
return freqs_t , freqs_h , freqs_w , grid_t , grid_h , grid_w
958
958
959
959
960
- def get_2d_rotary_pos_embed (embed_dim , crops_coords , grid_size , use_real = True ):
960
def get_2d_rotary_pos_embed(
    embed_dim, crops_coords, grid_size, use_real=True, device: Optional[torch.device] = None, output_type: str = "np"
):
    """
    RoPE for image tokens with 2d structure.

    Args:
    embed_dim: (`int`):
        The embedding dimension size
    crops_coords (`Tuple[int]`)
        The top-left and bottom-right coordinates of the crop, unpacked as `(start, stop)` where each is indexed
        as `[0]` for the height axis and `[1]` for the width axis.
    grid_size (`Tuple[int]`):
        The grid size of the positional embedding, indexed as `[0]` for height and `[1]` for width.
    use_real (`bool`):
        If True, return real part and imaginary part separately. Otherwise, return complex numbers.
    device: (`torch.device`, **optional**):
        The device used to create tensors. Only used when `output_type` is not `"np"`.
    output_type (`str`, *optional*, defaults to `"np"`):
        `"np"` emits a deprecation warning and delegates to the legacy `_get_2d_rotary_pos_embed_np`
        implementation (`device` is ignored on that path). Any other value (use `"pt"`) selects the torch
        implementation below.

    Returns:
        `torch.Tensor`: positional embedding with shape `( grid_size * grid_size, embed_dim/2)`, as produced by
        `get_2d_rotary_pos_embed_from_grid` (torch path) or `_get_2d_rotary_pos_embed_np` (legacy path).
    """
    if output_type == "np":
        deprecation_message = (
            "`get_2d_rotary_pos_embed` uses `torch` and supports `device`."
            " `from_numpy` is no longer required."
            " Pass `output_type='pt'` to use the new version now."
        )
        deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
        return _get_2d_rotary_pos_embed_np(
            embed_dim=embed_dim,
            crops_coords=crops_coords,
            grid_size=grid_size,
            use_real=use_real,
        )

    start, stop = crops_coords
    # `torch.linspace` always includes its end point, whereas the reference behaviour is
    # np.linspace(start, stop, steps, endpoint=False), i.e. start + i * (stop - start) / steps.
    # Shrinking the *span* (stop - start) by (steps - 1) / steps reproduces that for any `start`;
    # scaling `stop` alone would only be correct when `start` is 0.
    end_h = start[0] + (stop[0] - start[0]) * (grid_size[0] - 1) / grid_size[0]
    end_w = start[1] + (stop[1] - start[1]) * (grid_size[1] - 1) / grid_size[1]
    grid_h = torch.linspace(start[0], end_h, grid_size[0], device=device, dtype=torch.float32)
    grid_w = torch.linspace(start[1], end_w, grid_size[1], device=device, dtype=torch.float32)

    # Width goes first; with indexing="xy" each output has shape [H, W].
    grid = torch.meshgrid(grid_w, grid_h, indexing="xy")
    grid = torch.stack(grid, dim=0)  # [2, H, W]

    # Insert a singleton axis after the coordinate axis: [2, 1, H, W].
    grid = grid.reshape([2, 1, *grid.shape[1:]])
    pos_embed = get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real)
    return pos_embed
1008
+
1009
+
1010
+ def _get_2d_rotary_pos_embed_np (embed_dim , crops_coords , grid_size , use_real = True ):
961
1011
"""
962
1012
RoPE for image tokens with 2d structure.
963
1013
0 commit comments