18
18
import torch
19
19
import torch .nn as nn
20
20
import torch .nn .functional as F
21
- from einops import rearrange , repeat
22
21
from torchvision import transforms
23
22
24
23
from ...configuration_utils import ConfigMixin , register_to_config
@@ -282,7 +281,7 @@ def __init__(
282
281
283
282
def forward (self , hidden_states : torch .Tensor , fps : Optional [int ] = None ) -> Tuple [torch .Tensor , torch .Tensor ]:
284
283
batch_size , num_channels , num_frames , height , width = hidden_states .shape
285
- rope_sizes = [num_frames // self .patch_size [0 ], height // self .patch_size [1 ], width // self .patch_size [2 ]]
284
+ pe_size = [num_frames // self .patch_size [0 ], height // self .patch_size [1 ], width // self .patch_size [2 ]]
286
285
287
286
h_theta = 10000.0 * self .h_ntk_factor
288
287
w_theta = 10000.0 * self .w_ntk_factor
@@ -296,28 +295,19 @@ def forward(self, hidden_states: torch.Tensor, fps: Optional[int] = None) -> Tup
296
295
w_spatial_freqs = 1.0 / (w_theta ** dim_w_range )
297
296
temporal_freqs = 1.0 / (t_theta ** dim_t_range )
298
297
299
- emb_h = torch .outer (seq [: rope_sizes [1 ]], h_spatial_freqs )
300
- emb_w = torch .outer (seq [: rope_sizes [2 ]], w_spatial_freqs )
298
+ emb_h = torch .outer (seq [: pe_size [1 ]], h_spatial_freqs )[ None , :, None , :]. repeat ( pe_size [ 0 ], 1 , pe_size [ 2 ], 1 )
299
+ emb_w = torch .outer (seq [: pe_size [2 ]], w_spatial_freqs )[ None , None , :, :]. repeat ( pe_size [ 0 ], pe_size [ 1 ], 1 , 1 )
301
300
302
301
# Apply sequence scaling in temporal dimension
303
302
if fps is None :
304
303
# Images
305
- emb_t = torch .outer (seq [: rope_sizes [0 ]], temporal_freqs )
304
+ emb_t = torch .outer (seq [: pe_size [0 ]], temporal_freqs )
306
305
else :
307
306
# Videos
308
- emb_t = torch .outer (seq [: rope_sizes [0 ]] / fps * self .base_fps , temporal_freqs )
309
-
310
- freqs = torch .cat (
311
- [
312
- repeat (emb_t , "t d -> t h w d" , h = rope_sizes [1 ], w = rope_sizes [2 ]),
313
- repeat (emb_h , "h d -> t h w d" , t = rope_sizes [0 ], w = rope_sizes [2 ]),
314
- repeat (emb_w , "w d -> t h w d" , t = rope_sizes [0 ], h = rope_sizes [1 ]),
315
- ]
316
- * 2 ,
317
- dim = - 1 ,
318
- )
307
+ emb_t = torch .outer (seq [: pe_size [0 ]] / fps * self .base_fps , temporal_freqs )
319
308
320
- freqs = rearrange (freqs , "t h w d -> (t h w) d" ).float ()
309
+ emb_t = emb_t [:, None , None , :].repeat (1 , pe_size [1 ], pe_size [2 ], 1 )
310
+ freqs = torch .cat ([emb_t , emb_h , emb_w ] * 2 , dim = - 1 ).flatten (0 , 2 ).float ()
321
311
cos = torch .cos (freqs )
322
312
sin = torch .sin (freqs )
323
313
return cos , sin
0 commit comments