
Enable ONNX export of GPU-loaded SVD/SVD-XT UNet models #6562


Open

wants to merge 2 commits into base: main

Conversation


@rajeevsrao rajeevsrao commented Jan 13, 2024

What does this PR do?

- Unpack num_frames scalar if created as a (CPU) tensor in the forward path
- Avoids mixed use of CPU and CUDA tensors, which is unsupported by torch.nn ops

File "/usr/local/lib/python3.10/dist-packages/diffusers/models/unet_spatio_temporal_condition.py", line 422, in forward
    emb = emb.repeat_interleave(num_frames, dim=0)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper_CUDA__index_select)
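
A minimal standalone repro of this device mismatch, outside the export path (a sketch; requires a CUDA device, and the shapes are illustrative):

import torch

emb = torch.randn(2, 320, device="cuda")  # the time embedding lives on the GPU
num_frames = torch.tensor(14)             # 0-dim CPU tensor, as produced during tracing
# Raises: Expected all tensors to be on the same device ... cuda:0 and cpu!
emb.repeat_interleave(num_frames, dim=0)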

@rajeevsrao (Author)

@patrickvonplaten @sayakpaul please review.

@sayakpaul (Member)

It should reside in optimum. Cc: @echarlaix

@echarlaix (Contributor)

Hi @rajeevsrao, could you share the script you used for the export?

@echarlaix (Contributor)

> It should reside in optimum. Cc: @echarlaix

You mean patching the model in optimum? Depending on the modifications needed, it could make sense to have it in diffusers instead.

@rajeevsrao (Author) commented Jan 22, 2024

> It should reside in optimum. Cc: @echarlaix

> You mean patching the model in optimum? Depending on the modifications needed, it could make sense to have it in diffusers instead.

ONNX is a popular interchange format. @sayakpaul I think that diffusers should also support exporting models into ONNX, especially given that this is an easy/harmless fix.

@rajeevsrao (Author)

> Hi @rajeevsrao, could you share the script you used for the export?

Here is the ONNX export script for reference:

from diffusers.models import UNetSpatioTemporalConditionModel
import torch

model_name = "svd"
if model_name == "svd-xt":
    pipeline = 'stabilityai/stable-video-diffusion-img2vid-xt'
    num_frames = 25
else:
    pipeline = 'stabilityai/stable-video-diffusion-img2vid'
    num_frames = 14

device = 'cuda'
dtype = torch.float16
model = UNetSpatioTemporalConditionModel.from_pretrained(pipeline,
    subfolder="unet",
    use_safetensors=True,
    variant='fp16',
    torch_dtype=dtype).to(device)

batch_size = 2
out_channels = 4
cross_attention_dim = 1024
latent_height = 576 // 8
latent_width = 1024 // 8

input_names = ['sample', 'timestep', 'encoder_hidden_states', 'added_time_ids']
inputs = (
    torch.randn(batch_size, num_frames, 2*out_channels, latent_height, latent_width, dtype=dtype, device=device),
    torch.tensor([1.], dtype=torch.float32, device=device),
    torch.randn(batch_size, 1, cross_attention_dim, dtype=dtype, device=device),
    torch.randn(batch_size, 3, dtype=dtype, device=device),
)
output_names = ['latent']
dynamic_axes = {
    'sample': {0: '2B', 1: 'num_frames', 3: 'H', 4: 'W'},
    'encoder_hidden_states': {0: '2B'},
    'added_time_ids': {0: '2B'}
}

with torch.inference_mode(), torch.autocast(device):
    torch.onnx.export(model,
        inputs,
        model_name+"_unet.onnx",
        export_params=True,
        opset_version=18,
        do_constant_folding=True,
        input_names=input_names,
        output_names=output_names,
        dynamic_axes=dynamic_axes,
    )



* Unpack num_frames scalar if created as a (CPU) tensor in forward path
  Avoids mixed use of CPU and CUDA tensors which is unsupported by torch.nn ops

Signed-off-by: Rajeev Rao <[email protected]>
@rajeevsrao (Author)

@sayakpaul @echarlaix please suggest next steps. Thanks.

@@ -397,6 +397,8 @@ def forward(

# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
batch_size, num_frames = sample.shape[:2]
if torch.is_tensor(num_frames):
    num_frames = num_frames.item()
Contributor

How can num_frames ever be a tensor? Can you give an example?

Contributor

@patrickvonplaten num_frames is created as a CPU tensor during the tracing step of the ONNX export. I have also provided a script to reproduce this behavior in the comment below.
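
For illustration, a minimal sketch (independent of diffusers) showing that shape unpacking yields Python ints eagerly but 0-dim CPU tensors under tracing, which is what torch.onnx.export uses:

import torch

class Probe(torch.nn.Module):
    def forward(self, x):
        batch_size, num_frames = x.shape[:2]
        # Eager: <class 'int'>; traced: <class 'torch.Tensor'> (on CPU)
        print(type(num_frames))
        return x * 1

m = Probe()
m(torch.randn(2, 14, 8))                   # prints <class 'int'>
torch.jit.trace(m, torch.randn(2, 14, 8))  # prints <class 'torch.Tensor'>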

github-actions bot

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot added the stale Issues that haven't received updates label Feb 19, 2024
@sayakpaul (Member)

@rajeevsrao do you still plan to work on this?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@asfiyab-nvidia (Contributor)

Reviving this PR.

The main issue observed is that the type of num_frames changes depending on how the forward pass is executed.

Elaborating using the two cases below. To investigate further, print type(num_frames) where it is unpacked in unet_spatio_temporal_condition.py (see the traceback in the description).

Case 1: Inference

During inference, the type of num_frames is <class 'int'>. Inference script used:

import torch
from diffusers.utils import load_image
from diffusers import StableVideoDiffusionPipeline
pipe = StableVideoDiffusionPipeline.from_pretrained('stabilityai/stable-video-diffusion-img2vid', torch_dtype=torch.float16, variant="fp16").to("cuda")
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket.png")
image = image.resize((1024, 576))
frames = pipe(image, decode_chunk_size=8).frames[0]

Case 2: ONNX export

As the error specifically lies with the UNet, I'm exporting just the unet model using the script below. While tracing for the ONNX export, num_frames is created as a <class 'torch.Tensor'> on the CPU; running the export on GPU then results in the error from the description.

import torch
from diffusers.models import UNetSpatioTemporalConditionModel

dtype = torch.float16
device = 'cuda'
model = UNetSpatioTemporalConditionModel.from_pretrained('stabilityai/stable-video-diffusion-img2vid',
    subfolder="unet",
    use_safetensors=True,
    variant='fp16',
    torch_dtype=dtype).to(device)

batch_size = 2
num_frames = 14
out_channels = 4
cross_attention_dim = 1024
latent_height = 576 // 8
latent_width = 1024 // 8

inputs = (
    torch.randn(batch_size, num_frames, 2*out_channels, latent_height, latent_width, dtype=dtype, device=device),
    torch.tensor([1.], dtype=torch.float32, device=device),
    torch.randn(batch_size, 1, cross_attention_dim, dtype=dtype, device=device),
    torch.randn(batch_size, 3, dtype=dtype, device=device),
)

with torch.inference_mode(), torch.autocast(device):
    torch.onnx.export(model,
        inputs,
        "svd/svd_unet.onnx",
        export_params=True,
        opset_version=18,
        do_constant_folding=True
    )

The PR aims to correct this inconsistency in the type of num_frames between inference and tracing.

@sayakpaul (Member)

Cc: @yiyixuxu. I am okay with the changes here since ONNX is very popular. LMK. @DN6, you too.

@sayakpaul (Member)

@asfiyab-nvidia would you suggest anything being done differently in this PR?

@asfiyab-nvidia (Contributor)

> @asfiyab-nvidia would you suggest anything being done differently in this PR?

The main goal is to align the types. An alternative to the change suggested in the PR is to unconditionally cast the variable to a torch tensor:

num_frames = torch.tensor(num_frames).to(sample.device)

However, since num_frames is used as a scalar elsewhere, I'd vote for the recommendation in the PR: cast it to a scalar if it is found to be a tensor.
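
To make the two options concrete, a sketch of both (names follow the forward() diff above; the torch.as_tensor call is an illustrative variant of the torch.tensor(...).to(...) line):

# Option A: always promote to a tensor on the sample's device
num_frames = torch.as_tensor(num_frames, device=sample.device)

# Option B (the PR): collapse a traced tensor back to a Python scalar
if torch.is_tensor(num_frames):
    num_frames = num_frames.item()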

@yiyixuxu (Collaborator) commented Mar 1, 2024

cc @echarlaix here again

I'm fine with the change if we agree it's the best way to support ONNX export.

@github-actions github-actions bot removed the stale Issues that haven't received updates label Mar 2, 2024
@asfiyab-nvidia (Contributor)

Hi, following up on this PR.

@sayakpaul (Member)

@echarlaix a gentle ping.

@yiyixuxu yiyixuxu added the ONNX label Mar 9, 2024
@@ -397,6 +397,8 @@ def forward(

# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
batch_size, num_frames = sample.shape[:2]
if torch.is_tensor(num_frames):
num_frames = num_frames.item()
Contributor

This will result in num_frames being set as a constant in the ONNX graph. Is this expected (I see that it's defined in the unet config), or are there cases where this value might vary, @yiyixuxu @sayakpaul? If so, we should move the tensor to the expected device instead.

@sayakpaul (Member) commented Mar 11, 2024

No, num_frames can change, as it's a parameter of the pipeline call.

Contributor

@echarlaix you're right: with the suggested change, num_frames is traced as a constant, and that is undesirable. Would the change below be more suitable?

        if torch.is_tensor(num_frames):
            num_frames = num_frames.to(sample.device)

Contributor

@echarlaix pinging to follow up on this. I ran the ONNX export of the UNet model using the above fix. The export runs successfully; however, the model fails ONNXRuntime inference with the error below:

ort_s = ort.InferenceSession(model_path)
  File "/usr/local/lib/python3.10/dist-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 419, in __init__
    self._create_inference_session(providers, provider_options, disabled_optimizers)
  File "/usr/local/lib/python3.10/dist-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 452, in _create_inference_session
    sess = C.InferenceSession(session_options, self._model_path, True, self._read_config_from_model)
onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Load model from svd/svd_unet.onnx failed:Type Error: Type parameter (T) of Optype (Where) bound to different types (tensor(float) and tensor(float16) in node (/up_blocks.3/attentions.2/time_mixer/Where).

I'd appreciate your input on how we can move this PR along. At the moment, the ONNX export for the spatio-temporal UNet is broken. There are two ways to enable the export:

  1. Make num_frames a scalar. This results in num_frames being an ONNX constant that cannot change during inference.
  2. Move num_frames to the sample's device during ONNX export. The export succeeds, but ONNXRuntime inference fails.

Option 1 sets num_frames to a constant but passes ONNXRuntime inference; I'm leaning towards it, as it results in immediate usability of the exported model. Happy to hear if you have alternative suggestions.
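
For reference, a minimal ONNXRuntime smoke test along these lines (a sketch; the input/output names and shapes follow the export script earlier in this thread, and the file path is illustrative):

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("svd/svd_unet.onnx", providers=["CPUExecutionProvider"])
feeds = {
    "sample": np.random.randn(2, 14, 8, 72, 128).astype(np.float16),
    "timestep": np.array([1.0], dtype=np.float32),
    "encoder_hidden_states": np.random.randn(2, 1, 1024).astype(np.float16),
    "added_time_ids": np.random.randn(2, 3).astype(np.float16),
}
latent = sess.run(["latent"], feeds)[0]
print(latent.shape)  # expect (2, 14, 4, 72, 128)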

@echarlaix (Contributor) commented Mar 27, 2024

> @echarlaix pinging to follow up on this. I ran the ONNX export of the UNet model using the above fix. The export runs successfully; however, the model fails ONNXRuntime inference with the error below

Did you try casting num_frames to a different data type? Also, were you able to locate where in the graph this issue is coming from?

Contributor

@echarlaix num_frames can only be an integer type, as it is the repeats input to repeat_interleave. int64 and int32 are the only cast options, and neither resolves the issue.
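
A quick check of that dtype constraint (a sketch, on CPU):

import torch

emb = torch.randn(2, 320)
reps = torch.tensor(14, dtype=torch.int64)       # integer repeats are accepted
print(emb.repeat_interleave(reps, dim=0).shape)  # torch.Size([28, 320])
# Float repeats, e.g. torch.tensor(14.0), are rejected by repeat_interleave.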

Contributor

@echarlaix following up on this.

@asfiyab-nvidia (Contributor)

Hi @echarlaix @sayakpaul, requesting updates based on the latest comments. Thanks.


github-actions bot commented May 3, 2024

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot added the stale Issues that haven't received updates label May 3, 2024
@yiyixuxu yiyixuxu removed the stale Issues that haven't received updates label May 3, 2024
github-actions bot

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot added the stale Issues that haven't received updates label Sep 14, 2024
@LeoZDong commented Mar 7, 2025

Could someone review this? Thanks!

@a-r-r-o-w a-r-r-o-w requested a review from sayakpaul March 7, 2025 19:17
@sayakpaul sayakpaul removed the stale Issues that haven't received updates label Mar 8, 2025
@sayakpaul sayakpaul requested a review from echarlaix March 8, 2025 03:02
@sayakpaul (Member)

I think this will need to be reviewed by someone from the Optimum team. Cc: @echarlaix again.
