@@ -300,7 +300,8 @@ def forward(self, hidden_states):
300
300
hidden_states = F .pad (hidden_states , (self .pad ,) * 2 , self .pad_mode )
301
301
weight = hidden_states .new_zeros ([hidden_states .shape [1 ], hidden_states .shape [1 ], self .kernel .shape [0 ]])
302
302
indices = torch .arange (hidden_states .shape [1 ], device = hidden_states .device )
303
- weight [indices , indices ] = self .kernel .to (weight )
303
+ kernel = self .kernel .to (weight )[None , :].expand (hidden_states .shape [1 ], - 1 )
304
+ weight [indices , indices ] = kernel
304
305
return F .conv1d (hidden_states , weight , stride = 2 )
305
306
306
307
@@ -316,7 +317,8 @@ def forward(self, hidden_states, temb=None):
316
317
hidden_states = F .pad (hidden_states , ((self .pad + 1 ) // 2 ,) * 2 , self .pad_mode )
317
318
weight = hidden_states .new_zeros ([hidden_states .shape [1 ], hidden_states .shape [1 ], self .kernel .shape [0 ]])
318
319
indices = torch .arange (hidden_states .shape [1 ], device = hidden_states .device )
319
- weight [indices , indices ] = self .kernel .to (weight )
320
+ kernel = self .kernel .to (weight )[None , :].expand (hidden_states .shape [1 ], - 1 )
321
+ weight [indices , indices ] = kernel
320
322
return F .conv_transpose1d (hidden_states , weight , stride = 2 , padding = self .pad * 2 + 1 )
321
323
322
324
0 commit comments