feat: Add FLUX-1.dev model to the model zoo #3382

Merged 11 commits on Feb 11, 2025

2 changes: 2 additions & 0 deletions docsrc/index.rst
@@ -141,6 +141,7 @@ Model Zoo
* :ref:`torch_export_gpt2`
* :ref:`torch_export_llama2`
* :ref:`torch_export_sam2`
* :ref:`torch_export_flux_dev`
* :ref:`notebooks`

.. toctree::
@@ -157,6 +158,7 @@ Model Zoo
tutorials/_rendered_examples/dynamo/torch_export_gpt2
tutorials/_rendered_examples/dynamo/torch_export_llama2
tutorials/_rendered_examples/dynamo/torch_export_sam2
tutorials/_rendered_examples/dynamo/torch_export_flux_dev
tutorials/notebooks

Python API Documentation
Binary file added docsrc/tutorials/_rendered_examples/dog_code.png
3 changes: 2 additions & 1 deletion examples/dynamo/README.rst
@@ -20,4 +20,5 @@ Model Zoo
* :ref:`_torch_compile_gpt2`: Compiling a GPT2 model using ``torch.compile``
* :ref:`_torch_export_gpt2`: Compiling a GPT2 model using AOT workflow (`ir=dynamo`)
* :ref:`_torch_export_llama2`: Compiling a Llama2 model using AOT workflow (`ir=dynamo`)
* :ref:`_torch_export_sam2`: Compiling SAM2 model using AOT workflow (`ir=dynamo`)
* :ref:`_torch_export_flux_dev`: Compiling FLUX.1-dev model using AOT workflow (`ir=dynamo`)
150 changes: 150 additions & 0 deletions examples/dynamo/torch_export_flux_dev.py
@@ -0,0 +1,150 @@
"""
.. _torch_export_flux_dev:

Compiling FLUX.1-dev model using the Torch-TensorRT dynamo backend
===================================================================

This example illustrates the state-of-the-art model `FLUX.1-dev <https://huggingface.co/black-forest-labs/FLUX.1-dev>`_ optimized using
Torch-TensorRT.

**FLUX.1 [dev]** is a 12 billion parameter rectified flow transformer capable of generating images from text descriptions. It is an open-weight, guidance-distilled model for non-commercial applications.

Install the following dependencies before compilation:

.. code-block:: sh

    pip install sentencepiece=="0.2.0" transformers=="4.48.2" accelerate=="1.3.0" diffusers=="0.32.2"

The ``FLUX.1-dev`` pipeline consists of different components such as ``transformer``, ``vae``, ``text_encoder``, ``tokenizer`` and ``scheduler``. In this example,
we demonstrate optimizing the ``transformer`` component of the model, which typically consumes more than 95% of the end-to-end diffusion latency.
"""

# %%
# Import the following libraries
# --------------------------------
import torch
import torch_tensorrt
from diffusers import FluxPipeline
from torch.export._trace import _export
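
# %%
# (Optional) Quick environment check: print the installed versions so mismatches
# with the pinned dependencies above are easy to spot.
import diffusers
import transformers

print(f"torch: {torch.__version__}, torch_tensorrt: {torch_tensorrt.__version__}")
print(f"diffusers: {diffusers.__version__}, transformers: {transformers.__version__}")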

# %%
# Define the FLUX.1-dev model
# -----------------------------
# Load the ``FLUX.1-dev`` pretrained pipeline using the ``FluxPipeline`` class.
# ``FluxPipeline`` includes the different components such as ``transformer``, ``vae``, ``text_encoder``, ``tokenizer`` and ``scheduler`` necessary
# to generate an image. We load the weights in ``FP16`` precision using the ``torch_dtype`` argument.
DEVICE = "cuda:0"
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.float16,
)
pipe.to(DEVICE).to(torch.float16)
# Store the config and transformer backbone
config = pipe.transformer.config
backbone = pipe.transformer


# %%
# Export the backbone using torch.export
# --------------------------------------------------
# Define the dummy inputs and their respective dynamic shapes. We export the transformer backbone with dynamic shapes, using a ``batch_size=2``
# due to `0/1 specialization <https://docs.google.com/document/d/16VPOa3d-Liikf48teAOmxLc92rgvJdfosIy-yoT38Io/edit?fbclid=IwAR3HNwmmexcitV0pbZm_x1a4ykdXZ9th_eJWK-3hBtVgKnrkmemz6Pm5jRQ&tab=t.0#heading=h.ez923tomjvyk>`_.
batch_size = 2
BATCH = torch.export.Dim("batch", min=1, max=2)
SEQ_LEN = torch.export.Dim("seq_len", min=1, max=512)
# These particular min, max values for the img_id input are recommended by torch dynamo during the export of the model.
# To see this recommendation, you can try exporting with min=1, max=4096
IMG_ID = torch.export.Dim("img_id", min=3586, max=4096)
dynamic_shapes = {
    "hidden_states": {0: BATCH},
    "encoder_hidden_states": {0: BATCH, 1: SEQ_LEN},
    "pooled_projections": {0: BATCH},
    "timestep": {0: BATCH},
    "txt_ids": {0: SEQ_LEN},
    "img_ids": {0: IMG_ID},
    "guidance": {0: BATCH},
}
# The guidance factor is of type torch.float32
dummy_inputs = {
    "hidden_states": torch.randn((batch_size, 4096, 64), dtype=torch.float16).to(
        DEVICE
    ),
    "encoder_hidden_states": torch.randn(
        (batch_size, 512, 4096), dtype=torch.float16
    ).to(DEVICE),
    "pooled_projections": torch.randn((batch_size, 768), dtype=torch.float16).to(
        DEVICE
    ),
    "timestep": torch.tensor([1.0, 1.0], dtype=torch.float16).to(DEVICE),
    "txt_ids": torch.randn((512, 3), dtype=torch.float16).to(DEVICE),
    "img_ids": torch.randn((4096, 3), dtype=torch.float16).to(DEVICE),
    "guidance": torch.tensor([1.0, 1.0], dtype=torch.float32).to(DEVICE),
}
# This will create an exported program which is going to be compiled with Torch-TensorRT
ep = _export(
    backbone,
    args=(),
    kwargs=dummy_inputs,
    dynamic_shapes=dynamic_shapes,
    strict=False,
    allow_complex_guards_as_runtime_asserts=True,
)
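
# %%
# (Optional) Sanity check: the exported program can be run eagerly before compilation.
# This is a minimal sketch (assuming there is enough free GPU memory for one more FP16
# forward pass) to confirm the export preserved the backbone's behavior on the dummy inputs.
with torch.no_grad():
    ep_out = ep.module()(**dummy_inputs)
print(f"Exported backbone output shape: {ep_out[0].shape}")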

# %%
# Torch-TensorRT compilation
# ---------------------------
# .. note::
#    The compilation requires a GPU with high memory (> 80GB) since TensorRT stores the weights in FP32 precision. This is a known issue and will be resolved in the future.
#
# We enable ``FP32`` matmul accumulation using ``use_fp32_acc=True`` to preserve accuracy by introducing casts to ``FP32`` nodes.
# We also enable explicit typing to ensure TensorRT respects the datatypes set by the user, which is a requirement for FP32 matmul accumulation.
# Since this is a 12 billion parameter model, compilation takes around 20-30 minutes on an H100 GPU. The model is completely convertible and results in
# a single TensorRT engine.
trt_gm = torch_tensorrt.dynamo.compile(
    ep,
    inputs=dummy_inputs,
    enabled_precisions={torch.float32},
    truncate_double=True,
    min_block_size=1,
    use_fp32_acc=True,
    use_explicit_typing=True,
)
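
# %%
# (Optional) Rough latency check of the compiled backbone on the dummy inputs. This is
# an illustrative sketch, not a rigorous benchmark; it assumes the compiled module
# accepts the same keyword inputs that were used for compilation.
start_evt = torch.cuda.Event(enable_timing=True)
end_evt = torch.cuda.Event(enable_timing=True)
with torch.no_grad():
    # Warm-up iterations before timing
    for _ in range(3):
        trt_gm(**dummy_inputs)
    start_evt.record()
    trt_gm(**dummy_inputs)
    end_evt.record()
torch.cuda.synchronize()
print(f"Compiled transformer latency: {start_evt.elapsed_time(end_evt):.1f} ms")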

# %%
# Post Processing
# ---------------------------
# Release the GPU memory occupied by the exported program and the original ``pipe.transformer``,
# then set the transformer in the Flux pipeline to the Torch-TensorRT compiled model.
backbone.to("cpu")
del ep
pipe.transformer = trt_gm
pipe.transformer.config = config
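
# %%
# (Optional) Reclaim cached CUDA memory before running the diffusion loop. This small
# sketch uses standard PyTorch and Python utilities and can help when running close to
# the GPU memory limit after compilation.
import gc

gc.collect()
torch.cuda.empty_cache()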

# %%
# Image generation using prompt
# --------------------------------
# Provide a prompt and the file name of the image to be generated. Here we use the
# prompt ``A golden retriever holding a sign to code``.


# Function which generates images from the flux pipeline
def generate_image(pipe, prompt, image_name):
    seed = 42
    image = pipe(
        prompt,
        output_type="pil",
        num_inference_steps=20,
        generator=torch.Generator("cuda").manual_seed(seed),
    ).images[0]
    image.save(f"{image_name}.png")
    print(f"Generated image saved as {image_name}.png")


generate_image(pipe, ["A golden retriever holding a sign to code"], "dog_code")

# %%
# The generated image is as shown below
#
# .. image:: dog_code.png
#
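
# %%
# (Optional) Rough end-to-end timing of the accelerated pipeline: a minimal sketch
# illustrating that the denoising loop (the compiled transformer) dominates the total
# generation time. Absolute numbers vary with GPU and step count.
import time

torch.cuda.synchronize()
t0 = time.perf_counter()
pipe(
    ["A golden retriever holding a sign to code"],
    output_type="pil",
    num_inference_steps=20,
    generator=torch.Generator("cuda").manual_seed(42),
)
torch.cuda.synchronize()
print(f"End-to-end generation took {time.perf_counter() - t0:.1f} s")
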
@@ -15,7 +15,10 @@ def remove_assert_scalar(
"""Remove assert_scalar ops in the graph"""
count = 0
for node in gm.graph.nodes:
if node.target == torch.ops.aten._assert_scalar.default:
if (
node.target == torch.ops.aten._assert_scalar.default
or node == torch.ops.aten._assert_tensor_metadata.default
):
gm.graph.erase_node(node)
count += 1

17 changes: 12 additions & 5 deletions py/torch_tensorrt/dynamo/utils.py
@@ -243,12 +243,16 @@ def prepare_inputs(
     inputs: Input | torch.Tensor | Sequence[Any] | Dict[Any, Any],
     disable_memory_format_check: bool = False,
 ) -> Any:
-    if isinstance(inputs, Input):
+    if inputs is None:
+        return None
+
+    elif isinstance(inputs, Input):
         return inputs

-    elif isinstance(inputs, torch.Tensor):
+    elif isinstance(inputs, (torch.Tensor, int, float, bool)):
         return Input.from_tensor(
-            inputs, disable_memory_format_check=disable_memory_format_check
+            torch.tensor(inputs),
+            disable_memory_format_check=disable_memory_format_check,
         )

     elif isinstance(inputs, (list, tuple)):
@@ -395,10 +399,13 @@ def unwrap_tensor_dtype(tensor: Union[torch.Tensor, FakeTensor, torch.SymInt]) -
"""
Returns the dtype of torch.tensor or FakeTensor. For symbolic integers, we return int64
"""
if isinstance(tensor, (torch.Tensor, FakeTensor)):
return tensor.dtype
if isinstance(tensor, (torch.Tensor, FakeTensor, int, float, bool)):
return torch.tensor(tensor).dtype
elif isinstance(tensor, torch.SymInt):
return torch.int64
elif tensor is None:
# Case where we explicitly pass one of the inputs to be None (eg: FLUX.1-dev)
return None
else:
raise ValueError(f"Found invalid tensor type {type(tensor)}")
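
# Expected behavior of the extended helper (an illustrative sketch, not exercised by this diff):
#   unwrap_tensor_dtype(torch.ones(2, dtype=torch.float16))  -> torch.float16
#   unwrap_tensor_dtype(3)     -> torch.int64
#   unwrap_tensor_dtype(1.0)   -> torch.float32
#   unwrap_tensor_dtype(None)  -> None  (an input explicitly passed as None, e.g. for FLUX.1-dev)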
