Enable ONNX export of GPU loaded SVD/SVD-XT UNet models #6562
Conversation
@patrickvonplaten @sayakpaul please review.
It should reside in optimum. Cc: @echarlaix
Hi @rajeevsrao, could you share the script you used for the export?
You mean patching the model in optimum? Depending on the modifications needed, it could make sense to have it in optimum.
ONNX is a popular interchange format. @sayakpaul I think that diffusers should also support exporting models into ONNX, especially given that this is an easy/harmless fix.
Here is the ONNX export script for reference:
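A minimal sketch of such an export; the model id, wrapper, dummy input shapes, and opset version here are illustrative assumptions rather than the exact script:

```python
import torch
from diffusers import UNetSpatioTemporalConditionModel

class UNetWrapper(torch.nn.Module):
    """Unwraps the dataclass output so torch.onnx.export sees plain tensors."""
    def __init__(self, unet):
        super().__init__()
        self.unet = unet

    def forward(self, sample, timestep, encoder_hidden_states, added_time_ids):
        return self.unet(
            sample, timestep, encoder_hidden_states, added_time_ids,
            return_dict=False,
        )[0]

# Load the SVD UNet onto the GPU in fp16 (model id and dtype are assumptions).
unet = UNetSpatioTemporalConditionModel.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid-xt",
    subfolder="unet",
    torch_dtype=torch.float16,
).to("cuda")
unet.eval()

# Dummy inputs for forward(); shapes are illustrative:
# sample is (batch, frames, channels, height, width).
batch, frames = 1, 14
sample = torch.randn(batch, frames, 8, 64, 64, dtype=torch.float16, device="cuda")
timestep = torch.tensor(999, device="cuda")
encoder_hidden_states = torch.randn(batch, 1, 1024, dtype=torch.float16, device="cuda")
added_time_ids = torch.randn(batch, 3, dtype=torch.float16, device="cuda")

with torch.no_grad():
    torch.onnx.export(
        UNetWrapper(unet),
        (sample, timestep, encoder_hidden_states, added_time_ids),
        "svd_unet.onnx",
        input_names=["sample", "timestep", "encoder_hidden_states", "added_time_ids"],
        output_names=["out_sample"],
        dynamic_axes={"sample": {0: "batch", 1: "frames"}},
        opset_version=17,
    )
```

The `frames` dynamic axis is what keeps the frame count variable at runtime, and it is also what exercises the tracing behavior discussed later in this thread.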
* Unpack num_frames scalar if created as a (CPU) tensor in forward path
* Avoids mixed use of CPU and CUDA tensors which is unsupported by torch.nn ops

Signed-off-by: Rajeev Rao <[email protected]>
@sayakpaul @echarlaix please suggest next steps. Thanks.
```diff
@@ -397,6 +397,8 @@ def forward(
         # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
         batch_size, num_frames = sample.shape[:2]
+        if torch.is_tensor(num_frames):
```
How can `num_frames` ever be a tensor? Can you give an example?
@patrickvonplaten `num_frames` is created as a CPU tensor during the tracing step of the ONNX export. I have also provided a script to reproduce this behavior in the comment below.
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
@rajeevsrao do you still plan to work on this?
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Reviving this PR. The main issue observed is that the type of `num_frames` is inconsistent between regular inference and ONNX export. Elaborating using 2 cases below; printing the type of `num_frames` inside `forward` shows the difference.

Case 1: Inference

During inference, the type of `num_frames` is a plain Python int, since indexing `sample.shape` in eager mode yields ints.
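A quick check of that eager-mode behavior (the latent shape is illustrative):

```python
import torch

# Eager mode: indexing a tensor's shape yields plain Python ints.
sample = torch.randn(2, 14, 8, 64, 64)
batch_size, num_frames = sample.shape[:2]
print(type(num_frames))  # <class 'int'>
```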
Case 2: ONNX export

As the error specifically lies with the UNET, I'm exporting just the unet model (an export along the lines of the sketch above). While tracing for the ONNX export, `num_frames` is instead created as a CPU tensor; a minimal probe of this behavior is sketched below.
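A small standalone probe of the tracing behavior, using a toy module in place of the full UNET (names and shapes are illustrative):

```python
import torch

class Probe(torch.nn.Module):
    def forward(self, sample):
        batch_size, num_frames = sample.shape[:2]
        # Eager mode prints <class 'int'>; during torch.onnx.export
        # tracing the shape access is captured as a (CPU) tensor,
        # matching the behavior described in this thread.
        print(type(num_frames), torch.is_tensor(num_frames))
        return sample.flatten(0, 1)

sample = torch.randn(1, 14, 8, 64, 64)
Probe()(sample)  # eager: <class 'int'> False
torch.onnx.export(
    Probe(), (sample,), "probe.onnx",
    input_names=["sample"],
    dynamic_axes={"sample": {1: "frames"}},
)
```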
The PR aims to correct the inconsistency in the type of `num_frames`.
@asfiyab-nvidia would you suggest anything being done differently in this PR?
The main goal is to align the types. An alternative to the change suggested in the PR is to unconditionally cast the variable to a torch tensor. However, since the variable is a plain int in the regular inference path, the PR unpacks the traced tensor back to a scalar instead.
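A sketch of that alternative cast, with an illustrative latent in place of the real `sample` (this is the option discussed, not what the PR implements):

```python
import torch

sample = torch.randn(1, 14, 8, 64, 64)  # illustrative latent
batch_size, num_frames = sample.shape[:2]
# Always carry num_frames as a tensor on sample's device.
num_frames = torch.as_tensor(num_frames, device=sample.device)
```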
cc @echarlaix here again
I'm fine with the change if it's agreed that this is the best way to support ONNX export.
Hi, following up on this PR.
@echarlaix a gentle ping.
```diff
@@ -397,6 +397,8 @@ def forward(
         # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
         batch_size, num_frames = sample.shape[:2]
+        if torch.is_tensor(num_frames):
+            num_frames = num_frames.item()
```
This will result in `num_frames` being set as a constant in the ONNX graph. Is this expected (I see that it's defined in the unet config), or are there cases where this value might vary @yiyixuxu @sayakpaul? If yes, we should move the tensor to the expected device instead.
No, `num_frames` can change as it's exposed in the pipeline call:

diffusers/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py, line 339 at a1cb106:

```python
num_frames: Optional[int] = None,
```
@echarlaix you're right, with the suggested change `num_frames` is traced as a constant and that is undesirable. Would the below change be more suitable?

```python
if torch.is_tensor(num_frames):
    num_frames = num_frames.to(sample.device)
```
@echarlaix pinging to follow up on this. I ran the ONNX export of the UNET model using the above fix. The export runs successfully. However, the model fails ONNXRuntime inference with the error below:

```
ort_s = ort.InferenceSession(model_path)
  File "/usr/local/lib/python3.10/dist-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 419, in __init__
    self._create_inference_session(providers, provider_options, disabled_optimizers)
  File "/usr/local/lib/python3.10/dist-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 452, in _create_inference_session
    sess = C.InferenceSession(session_options, self._model_path, True, self._read_config_from_model)
onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Load model from svd/svd_unet.onnx failed:Type Error: Type parameter (T) of Optype (Where) bound to different types (tensor(float) and tensor(float16) in node (/up_blocks.3/attentions.2/time_mixer/Where).
```
I'd appreciate your input on how we can move this PR along. At the moment, the ONNX export for the SpatioTemporal UNET is broken. There are 2 ways to enable the export:

1. Make `num_frames` a scalar. This results in `num_frames` being an ONNX constant and not changeable during inference.
2. Move `num_frames` to device during ONNX export. The ONNXRuntime inference fails.

Option 1 sets `num_frames` to a constant, but passes ONNXRuntime inference. I'm leaning towards this as it results in immediate usability of the exported model. Happy to hear if you have alternative suggestions to the options suggested.
> @echarlaix pinging to follow up on this. I ran the ONNX export of the UNET model using the above fix. The export runs successfully. However, the model fails ONNXRuntime inference with the error below

Did you try to cast `num_frames` to a different data type? Also, were you able to locate where in the graph this issue is coming from?
@echarlaix `num_frames` can only be an integer type as it is the `repeats` input to `repeat_interleave`. int64 and int32 are the only cast options and neither resolves the issue.
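For context, `repeat_interleave` accepts `repeats` only as an int or an integer tensor, which is why float casts are off the table; a small illustration:

```python
import torch

emb = torch.randn(2, 1280)      # illustrative embedding batch
num_frames = torch.tensor(14)   # repeats: int or integer tensor only
out = emb.repeat_interleave(num_frames, dim=0)
print(out.shape)                # torch.Size([28, 1280])

# A float repeats tensor raises a RuntimeError:
# emb.repeat_interleave(torch.tensor(14.0), dim=0)
```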
@echarlaix following up on this
Hi @echarlaix @sayakpaul, requesting updates based on the latest comments. Thanks!
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Could someone review this? Thanks!
I think this will need to be reviewed by someone from the Optimum team. Cc: @echarlaix again
What does this PR do?

Unpack the `num_frames` scalar if it was created as a (CPU) tensor in the forward path. This avoids mixed use of CPU and CUDA tensors, which is unsupported by torch.nn ops.