Skip to content

Commit 2d8e089

Browse files
williambermandg845
authored andcommitted
[{Up,Down}sample1d] explicit view kernel size as number elements in flattened indices (huggingface#3479)
explicit view kernel size as number elements in flattened indices
1 parent 147da83 commit 2d8e089

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

src/diffusers/models/unet_1d_blocks.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,8 @@ def forward(self, hidden_states):
300300
hidden_states = F.pad(hidden_states, (self.pad,) * 2, self.pad_mode)
301301
weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]])
302302
indices = torch.arange(hidden_states.shape[1], device=hidden_states.device)
303-
weight[indices, indices] = self.kernel.to(weight)
303+
kernel = self.kernel.to(weight)[None, :].expand(hidden_states.shape[1], -1)
304+
weight[indices, indices] = kernel
304305
return F.conv1d(hidden_states, weight, stride=2)
305306

306307

@@ -316,7 +317,8 @@ def forward(self, hidden_states, temb=None):
316317
hidden_states = F.pad(hidden_states, ((self.pad + 1) // 2,) * 2, self.pad_mode)
317318
weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]])
318319
indices = torch.arange(hidden_states.shape[1], device=hidden_states.device)
319-
weight[indices, indices] = self.kernel.to(weight)
320+
kernel = self.kernel.to(weight)[None, :].expand(hidden_states.shape[1], -1)
321+
weight[indices, indices] = kernel
320322
return F.conv_transpose1d(hidden_states, weight, stride=2, padding=self.pad * 2 + 1)
321323

322324

0 commit comments

Comments
 (0)