Skip to content

Commit 8553d08

Browse files
committed
Add tutorial about inductor caching
ghstack-source-id: 380379506af15164aeea1456ffa437ca2f5d1b33 Pull Request resolved: #2951
1 parent f2b8a1b commit 8553d08

File tree

2 files changed

+64
-0
lines changed

2 files changed

+64
-0
lines changed

Diff for: recipes_source/recipes_index.rst

+9
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,15 @@ Recipes are bite-sized, actionable examples of how to use specific PyTorch featu
317317
:link: ../recipes/torch_compile_user_defined_triton_kernel_tutorial.html
318318
:tags: Model-Optimization
319319

320+
.. Compile Time Caching in ``torch.compile``
321+
322+
.. customcarditem::
323+
:header: Compile Time Caching in ``torch.compile``
324+
:card_description: Learn how to configure compile time caching in ``torch.compile``
325+
:image: ../_static/img/thumbnails/cropped/generic-pytorch-logo.png
326+
:link: ../recipes/torch_compile_caching_tutorial.html
327+
:tags: Model-Optimization
328+
320329
.. Intel(R) Extension for PyTorch*
321330
322331
.. customcarditem::

Diff for: recipes_source/torch_compile_caching_tutorial.rst

+55
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
Compile Time Caching in ``torch.compile``
2+
=========================================================
3+
**Authors:** `Oguz Ulgen <https://github.com/oulgen>`_ and `Sam Larsen <https://github.com/masnesral>`_
4+
5+
Introduction
6+
------------------
7+
8+
PyTorch Inductor implements several caches to reduce compilation latency. These caches are transparent to the user.
9+
This recipes demonstrates how you to configure various parts of the caching in ``torch.compile``.
10+
11+
Prerequisites
12+
-------------------
13+
14+
Before starting this recipe, make sure that you have the following:
15+
16+
* Basic understanding of ``torch.compile``. See:
17+
18+
* `torch.compiler API documentation <https://pytorch.org/docs/stable/torch.compiler.html#torch-compiler>`__
19+
* `Introduction to torch.compile <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html>`__
20+
21+
* PyTorch 2.4 or later
22+
23+
Inductor Cache Settings
24+
----------------------------
25+
26+
Most of these caches are in-memory, only used within the same process, and are transparent to the user. An exception is the FX graph cache that stores compiled FX graphs. This cache allows Inductor to avoid recompilation across process boundaries when it encounters the same graph with the same Tensor input shapes (and the same configuration, etc.). The default implementation stores compiled artifacts in the system temp directory. An optional feature also supports sharing those artifacts within a cluster by storing them in Redis.
27+
28+
There are a few settings relevant to caching and to FX graph caching in particular. The settings are accessible via environment variables, or can be hard-coded in Inductor’s config file.
29+
30+
TORCHINDUCTOR_FX_GRAPH_CACHE
31+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
32+
This setting enables the local FX graph cache feature, i.e., by storing artifacts on the host’s temp directory. ``1`` enables, and any other value disables. By default, the disk location is per username, but users can enable sharing across usernames by specifying ``TORCHINDUCTOR_CACHE_DIR`` (below).
33+
34+
TORCHINDUCTOR_CACHE_DIR
35+
~~~~~~~~~~~~~~~~~~~~~~~~
36+
This setting specifies the location of all on-disk caches. By default, the location is in the system temp directory under ``torchinductor_<username>``, e.g., ``/tmp/torchinductor_myusername``.
37+
38+
Note that if ``TRITON_CACHE_DIR`` is not set in the environment, Inductor sets the Triton cache directory to this same temp location (under the Triton subdirectory).
39+
40+
TORCHINDUCTOR_FX_GRAPH_REMOTE_CACHE
41+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
42+
This setting enables the remote FX graph cache feature. The current implementation uses Redis. ``1`` enables, and any other value disables. The following environment variables configure the host and port of the Redis server:
43+
44+
``TORCHINDUCTOR_REDIS_HOST`` (defaults to ``localhost``)
45+
``TORCHINDUCTOR_REDIS_PORT`` (defaults to ``6379``)
46+
47+
Note that if Inductor locates a remote cache entry, it stores the compiled artifact in the local on-disk cache; that local artifact would be served on subsequent runs on the same machine.
48+
49+
TORCHINDUCTOR_AUTOTUNE_REMOTE_CACHE
50+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
51+
This setting enables a remote cache for Inductor’s autotuner. As with the remote FX graph cache, the current implementation uses Redis. ``1`` enables, and any other value disables. The same host / port environment variables listed above apply to this cache.
52+
53+
TORCHINDUCTOR_FORCE_DISABLE_CACHES
54+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
55+
Set this value to ``1`` to disable all Inductor caching. This setting is useful to, e.g., experiment with cold-start compile time, or to force recompilation for debugging purposes.

0 commit comments

Comments
 (0)