Enable ONNX export of GPU loaded SVD/SVD-XT UNet models #6562

Open

wants to merge 2 commits into base: main
2 changes: 2 additions & 0 deletions src/diffusers/models/unets/unet_spatio_temporal_condition.py
@@ -397,6 +397,8 @@ def forward(

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        batch_size, num_frames = sample.shape[:2]
        if torch.is_tensor(num_frames):
Contributor

How can num_frames ever be a tensor? Can you give an example?

Contributor

@patrickvonplaten num_frames is created as a CPU tensor during the tracing step of the ONNX export. I have also provided a script to reproduce this behavior in the comment below.

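For illustration (this is not the repro script referenced above), a minimal sketch of why the guard is needed: the TorchScript tracer that torch.onnx.export relies on records shape indexing as graph ops, so num_frames can come back as a 0-dim CPU tensor instead of a Python int.

    import torch

    def probe(sample):
        batch_size, num_frames = sample.shape[:2]
        # Plain ints in eager mode; under tracing, shape indexing is recorded
        # as a graph op and may yield a 0-dim CPU tensor, which is exactly
        # what the torch.is_tensor(num_frames) guard above accounts for.
        print("num_frames is a tensor:", torch.is_tensor(num_frames))
        return sample.flatten(0, 1)

    dummy = torch.randn(1, 14, 8, 32, 32)
    probe(dummy)                   # eager call
    torch.jit.trace(probe, dummy)  # the tracer the TorchScript-based ONNX export relies on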
            num_frames = num_frames.item()
Contributor

This will result in num_frames being set as a constant in the ONNX graph. Is this expected (I see that it's defined in the unet config), or are there cases where this value might vary @yiyixuxu @sayakpaul? If so, we should move the tensor to the expected device instead.

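A hedged sketch of the concern above, using a toy module instead of the SVD UNet: with the TorchScript-based exporter, calling .item() on a traced shape value converts it to a plain Python number, so the exported graph repeats by a fixed count no matter what frame dimension is fed at inference time.

    import torch

    class ToyRepeat(torch.nn.Module):
        def forward(self, sample, emb):
            num_frames = sample.shape[1]
            if torch.is_tensor(num_frames):
                # .item() turns the traced value into a Python int, which the
                # exporter then freezes into the graph as a constant.
                num_frames = num_frames.item()
            return emb.repeat_interleave(num_frames, dim=0)

    torch.onnx.export(
        ToyRepeat(),
        (torch.randn(1, 14, 8, 16, 16), torch.randn(1, 4)),
        "toy_repeat.onnx",
        input_names=["sample", "emb"],
        dynamic_axes={"sample": {1: "num_frames"}},
    )
    # Inspecting toy_repeat.onnx should show the repeat count baked in as 14,
    # independent of the dynamic num_frames axis declared above.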
@sayakpaul (Member), Mar 11, 2024

No, num_frames can change as it's in the pipeline call.

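For context, a hedged sketch of where num_frames enters from the user side (the checkpoint id, device, and frame count are illustrative):

    import torch
    from diffusers import StableVideoDiffusionPipeline
    from diffusers.utils import load_image

    pipe = StableVideoDiffusionPipeline.from_pretrained(
        "stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16
    ).to("cuda")
    image = load_image("conditioning_frame.png")  # any conditioning image

    # num_frames is a per-call argument, so it is not fixed by the UNet config alone.
    frames = pipe(image, num_frames=14).frames[0]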
Contributor

@echarlaix you're right, with the suggested change num_frames is traced as a constant, which is undesirable. Would the change below be more suitable?

        if torch.is_tensor(num_frames):
            num_frames = num_frames.to(sample.device)

Contributor

@echarlaix pinging to follow up on this. I ran the ONNX export of the UNet model using the above fix. The export runs successfully. However, the model fails ONNXRuntime inference with the error below:

ort_s = ort.InferenceSession(model_path)
  File "/usr/local/lib/python3.10/dist-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 419, in __init__
    self._create_inference_session(providers, provider_options, disabled_optimizers)
  File "/usr/local/lib/python3.10/dist-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 452, in _create_inference_session
    sess = C.InferenceSession(session_options, self._model_path, True, self._read_config_from_model)
onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Load model from svd/svd_unet.onnx failed:Type Error: Type parameter (T) of Optype (Where) bound to different types (tensor(float) and tensor(float16) in node (/up_blocks.3/attentions.2/time_mixer/Where).

I'd appreciate your input on how we can move this PR along. At the moment, the ONNX export for the spatio-temporal UNet is broken. There are two ways to enable the export:

  1. Make num_frames a scalar. This results in num_frames being an ONNX constant that cannot change during inference.
  2. Move num_frames to the model's device during the ONNX export. The export succeeds, but ONNXRuntime inference fails.

Option 1 makes num_frames a constant but passes ONNXRuntime inference; I'm leaning towards it as it makes the exported model immediately usable. Happy to hear if you have alternative suggestions.

@echarlaix (Contributor), Mar 27, 2024

> @echarlaix pinging to follow up on this. I ran the ONNX export of the UNet model using the above fix. The export runs successfully. However, the model fails ONNXRuntime inference with the error below

Did you try to cast num_frames to a different data type? Also, were you able to locate where in the graph this issue is coming from?

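For reference, one way to answer the second question, sketched with the onnx Python package (the model path is the one from the error message above):

    import onnx

    model = onnx.load("svd/svd_unet.onnx")
    for node in model.graph.node:
        # Find the Where op named in the ONNXRuntime error and print which
        # tensors feed it, to trace where the float/float16 inputs originate.
        if node.op_type == "Where" and "time_mixer" in node.name:
            print(node.name, list(node.input), "->", list(node.output))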
Contributor

@echarlaix num_frames can only be an integer type, as it is the repeats input to repeat_interleave. int64 and int32 are the only cast options, and neither resolves the issue.

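A quick check of that constraint (behavior of torch.repeat_interleave, which only accepts integer repeats):

    import torch

    emb = torch.randn(2, 4)
    print(emb.repeat_interleave(3, dim=0).shape)                # Python int: torch.Size([6, 4])
    print(emb.repeat_interleave(torch.tensor(3), dim=0).shape)  # int64 scalar tensor also works
    # A floating-point repeats tensor is rejected:
    # emb.repeat_interleave(torch.tensor(3.0), dim=0)  # raises a dtype error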
Contributor

@echarlaix following up on this

        timesteps = timesteps.expand(batch_size)

        t_emb = self.time_proj(timesteps)