Skip to content

Commit e4173df

Browse files
committed
remove einops
1 parent 4f1161d commit e4173df

File tree

1 file changed

+7
-17
lines changed

1 file changed

+7
-17
lines changed

src/diffusers/models/transformers/transformer_cosmos.py

+7 -17
Original file line number | Diff line number | Diff line change
@@ -18,7 +18,6 @@
1818
import torch
1919
import torch.nn as nn
2020
import torch.nn.functional as F
21-
from einops import rearrange, repeat
2221
from torchvision import transforms
2322

2423
from ...configuration_utils import ConfigMixin, register_to_config
@@ -282,7 +281,7 @@ def __init__(
282281

283282
def forward(self, hidden_states: torch.Tensor, fps: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
284283
batch_size, num_channels, num_frames, height, width = hidden_states.shape
285-
rope_sizes = [num_frames // self.patch_size[0], height // self.patch_size[1], width // self.patch_size[2]]
284+
pe_size = [num_frames // self.patch_size[0], height // self.patch_size[1], width // self.patch_size[2]]
286285

287286
h_theta = 10000.0 * self.h_ntk_factor
288287
w_theta = 10000.0 * self.w_ntk_factor
@@ -296,28 +295,19 @@ def forward(self, hidden_states: torch.Tensor, fps: Optional[int] = None) -> Tup
296295
w_spatial_freqs = 1.0 / (w_theta**dim_w_range)
297296
temporal_freqs = 1.0 / (t_theta**dim_t_range)
298297

299-
emb_h = torch.outer(seq[: rope_sizes[1]], h_spatial_freqs)
300-
emb_w = torch.outer(seq[: rope_sizes[2]], w_spatial_freqs)
298+
emb_h = torch.outer(seq[: pe_size[1]], h_spatial_freqs)[None, :, None, :].repeat(pe_size[0], 1, pe_size[2], 1)
299+
emb_w = torch.outer(seq[: pe_size[2]], w_spatial_freqs)[None, None, :, :].repeat(pe_size[0], pe_size[1], 1, 1)
301300

302301
# Apply sequence scaling in temporal dimension
303302
if fps is None:
304303
# Images
305-
emb_t = torch.outer(seq[: rope_sizes[0]], temporal_freqs)
304+
emb_t = torch.outer(seq[: pe_size[0]], temporal_freqs)
306305
else:
307306
# Videos
308-
emb_t = torch.outer(seq[: rope_sizes[0]] / fps * self.base_fps, temporal_freqs)
309-
310-
freqs = torch.cat(
311-
[
312-
repeat(emb_t, "t d -> t h w d", h=rope_sizes[1], w=rope_sizes[2]),
313-
repeat(emb_h, "h d -> t h w d", t=rope_sizes[0], w=rope_sizes[2]),
314-
repeat(emb_w, "w d -> t h w d", t=rope_sizes[0], h=rope_sizes[1]),
315-
]
316-
* 2,
317-
dim=-1,
318-
)
307+
emb_t = torch.outer(seq[: pe_size[0]] / fps * self.base_fps, temporal_freqs)
319308

320-
freqs = rearrange(freqs, "t h w d -> (t h w) d").float()
309+
emb_t = emb_t[:, None, None, :].repeat(1, pe_size[1], pe_size[2], 1)
310+
freqs = torch.cat([emb_t, emb_h, emb_w] * 2, dim=-1).flatten(0, 2).float()
321311
cos = torch.cos(freqs)
322312
sin = torch.sin(freqs)
323313
return cos, sin

0 commit comments

Comments
 (0)