diff --git a/docsrc/index.rst b/docsrc/index.rst index 1c0c9a0d9e..946407433c 100644 --- a/docsrc/index.rst +++ b/docsrc/index.rst @@ -110,6 +110,7 @@ Tutorials tutorials/_rendered_examples/dynamo/torch_compile_resnet_example tutorials/_rendered_examples/dynamo/torch_compile_transformers_example tutorials/_rendered_examples/dynamo/torch_compile_advanced_usage + tutorials/_rendered_examples/dynamo/torch_compile_stable_diffusion Python API Documenation ------------------------ @@ -206,4 +207,4 @@ Legacy Further Information (TorchScript) * `GTC 2021 Fall Talk `_ * `PyTorch Ecosystem Day 2021 `_ * `PyTorch Developer Conference 2021 `_ -* `PyTorch Developer Conference 2022 `_ \ No newline at end of file +* `PyTorch Developer Conference 2022 `_ diff --git a/docsrc/tutorials/images/majestic_castle.png b/docsrc/tutorials/images/majestic_castle.png new file mode 100644 index 0000000000..bac6073a90 Binary files /dev/null and b/docsrc/tutorials/images/majestic_castle.png differ diff --git a/examples/dynamo/README.rst b/examples/dynamo/README.rst index fa863952e7..d895cc0113 100644 --- a/examples/dynamo/README.rst +++ b/examples/dynamo/README.rst @@ -9,3 +9,4 @@ a number of ways you can leverage this backend to accelerate inference. * :ref:`torch_compile_resnet`: Compiling a ResNet model using the Torch Compile Frontend for ``torch_tensorrt.compile`` * :ref:`torch_compile_transformer`: Compiling a Transformer model using ``torch.compile`` * :ref:`torch_compile_advanced_usage`: Advanced usage including making a custom backend to use directly with the ``torch.compile`` API +* :ref:`torch_compile_stable_diffusion`: Compiling a Stable Diffusion model using ``torch.compile`` diff --git a/examples/dynamo/torch_compile_stable_diffusion.py b/examples/dynamo/torch_compile_stable_diffusion.py new file mode 100644 index 0000000000..0511e5a363 --- /dev/null +++ b/examples/dynamo/torch_compile_stable_diffusion.py @@ -0,0 +1,55 @@ +""" +.. _torch_compile_stable_diffusion: + +Torch Compile Stable Diffusion +====================================================== + +This interactive script is intended as a sample of the Torch-TensorRT workflow with `torch.compile` on a Stable Diffusion model. A sample output is featured below: + +.. image:: /tutorials/images/majestic_castle.png + :width: 512px + :height: 512px + :scale: 50 % + :align: right +""" + +# %% +# Imports and Model Definition +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +import torch +from diffusers import DiffusionPipeline + +import torch_tensorrt + +model_id = "CompVis/stable-diffusion-v1-4" +device = "cuda:0" + +# Instantiate Stable Diffusion Pipeline with FP16 weights +pipe = DiffusionPipeline.from_pretrained( + model_id, revision="fp16", torch_dtype=torch.float16 +) +pipe = pipe.to(device) + +backend = "torch_tensorrt" + +# Optimize the UNet portion with Torch-TensorRT +pipe.unet = torch.compile( + pipe.unet, + backend=backend, + options={ + "truncate_long_and_double": True, + "precision": torch.float16, + }, + dynamic=False, +) + +# %% +# Inference +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +prompt = "a majestic castle in the clouds" +image = pipe(prompt).images[0] + +image.save("images/majestic_castle.png") +image.show()