Skip to content

Commit 6af50ff

Browse files
committed
Enable ONNX export of GPU loaded SVD/SVD-XT UNet models
* Unpack num_frames scalar if created as a (CPU) tensor in forward path Avoids mixed use of CPU and CUDA tensors which is unsupported by torch.nn ops Signed-off-by: Rajeev Rao <[email protected]>
1 parent 79df503 commit 6af50ff

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

Diff for: src/diffusers/models/unet_spatio_temporal_condition.py

+2
Original file line numberDiff line numberDiff line change
@@ -397,6 +397,8 @@ def forward(
397397

398398
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
399399
batch_size, num_frames = sample.shape[:2]
400+
if torch.is_tensor(num_frames):
401+
num_frames = num_frames.item()
400402
timesteps = timesteps.expand(batch_size)
401403

402404
t_emb = self.time_proj(timesteps)

0 commit comments

Comments
 (0)