diff --git a/src/diffusers/models/unets/unet_spatio_temporal_condition.py b/src/diffusers/models/unets/unet_spatio_temporal_condition.py index 39a8009d5af9..36eb72963ec5 100644 --- a/src/diffusers/models/unets/unet_spatio_temporal_condition.py +++ b/src/diffusers/models/unets/unet_spatio_temporal_condition.py @@ -397,6 +397,8 @@ def forward( # broadcast to batch dimension in a way that's compatible with ONNX/Core ML batch_size, num_frames = sample.shape[:2] + if torch.is_tensor(num_frames): + num_frames = num_frames.item() timesteps = timesteps.expand(batch_size) t_emb = self.time_proj(timesteps)