diff --git a/docs/source/en/optimization/torch2.0.md b/docs/source/en/optimization/torch2.0.md
index 01ea00310a75..cc69eceff3af 100644
--- a/docs/source/en/optimization/torch2.0.md
+++ b/docs/source/en/optimization/torch2.0.md
@@ -78,6 +78,33 @@ For more information and different options about `torch.compile`, refer to the [
 > [!TIP]
 > Learn more about other ways PyTorch 2.0 can help optimize your model in the [Accelerate inference of text-to-image diffusion models](../tutorials/fast_diffusion) tutorial.
 
+### Regional compilation
+
+Compiling the whole model gives the compiler a very large problem space to optimize, yet models are often composed of multiple repeated blocks. [Regional compilation](https://pytorch.org/tutorials/recipes/regional_compilation.html) compiles one repeated block first (a transformer encoder block, for example) so that the compiler can reuse its cached, optimized code for the remaining identical blocks, often massively reducing the cold-start compilation time observed on the first inference call.
+
+Enabling regional compilation might require simple yet intrusive changes to the modeling code. However, 🤗 Accelerate provides the [`compile_regions()`](https://huggingface.co/docs/accelerate/main/en/usage_guides/compilation#how-to-use-regional-compilation) utility, which automatically compiles the repeated blocks of the provided `nn.Module` sequentially and the rest of the model separately. This helps reduce cold-start time while keeping most, if not all, of the speedup you would get from full compilation.
+
+```py
+# Make sure you're on the latest `accelerate`: `pip install -U accelerate`.
+from accelerate.utils import compile_regions
+
+# `pipe` is any loaded diffusers pipeline, e.g. from `DiffusionPipeline.from_pretrained(...)`.
+pipe.unet = compile_regions(pipe.unet, mode="reduce-overhead", fullgraph=True)
+```
+
+As you may have noticed, `compile_regions()` takes the same arguments as `torch.compile()`, giving you the same flexibility.
+
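+Conceptually, `compile_regions()` automates what you would otherwise write by hand. Here is a minimal sketch of that manual pattern, assuming the repeated submodules live in a hypothetical `blocks` attribute:
+
+```py
+import torch
+
+# Compiling each repeated block individually lets the compiler reuse the
+# code it generated for the first block on the identical sibling blocks.
+for i, block in enumerate(model.blocks):
+    model.blocks[i] = torch.compile(block, mode="reduce-overhead", fullgraph=True)
+```
+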
 ## Benchmark
 
 We conducted a comprehensive benchmark with PyTorch 2.0's efficient attention implementation and `torch.compile` across different GPUs and batch sizes for five of our most used pipelines. The code is benchmarked on 🤗 Diffusers v0.17.0.dev0 to optimize `torch.compile` usage (see [here](https://github.com/huggingface/diffusers/pull/3313) for more details).