# -*- coding: utf-8 -*-

"""
Using User Defined Triton Kernels with ``torch.compile``
=========================================================
**Author:** `Oguz Ulgen <https://github.com/oulgen>`_
"""

######################################################################
# This tutorial explains how to use user-defined Triton kernels with ``torch.compile``.
#
# .. note::
#   This tutorial requires PyTorch 2.3 or later and a GPU that supports Triton.
#

import torch
from torch.utils._triton import has_triton

######################################################################
# Basic Usage
# ------------
#
# In this example, we will use a simple vector addition kernel from the Triton
# documentation with ``torch.compile``.
# Reference: https://triton-lang.org/main/getting-started/tutorials/01-vector-add.html
#

if not has_triton():
    print("Skipping because triton is not supported on this device.")
else:
    import triton
    from triton import language as tl

    @triton.jit
    def add_kernel(
        in_ptr0,
        in_ptr1,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        # Each program instance handles one BLOCK_SIZE-wide chunk of the input.
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        # Mask out-of-bounds lanes when n_elements is not a multiple of BLOCK_SIZE.
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        output = x + y
        tl.store(out_ptr + offsets, output, mask=mask)

    @torch.compile(fullgraph=True)
    def add_fn(x, y):
        output = torch.zeros_like(x)
        n_elements = output.numel()
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=4)
        return output

    x = torch.randn(4, device="cuda")
    y = torch.randn(4, device="cuda")
    out = add_fn(x, y)
    print(f"Vector addition of\nX:\t{x}\nY:\t{y}\nis equal to\n{out}")
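
    # As a quick sanity check, the compiled Triton kernel should match eager
    # element-wise addition.
    torch.testing.assert_close(out, x + y)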

######################################################################
# Advanced Usage
# --------------
#
# It is also possible to use ``triton.autotune`` with ``torch.compile``.
#
# .. note::
#
#   ``torch.compile`` only supports the ``configs`` and ``key`` arguments to ``triton.autotune``.

if not has_triton():
    print("Skipping because triton is not supported on this device.")
else:
    import triton
    from triton import language as tl

    @triton.autotune(
        configs=[
            triton.Config({"BLOCK_SIZE": 4}, num_stages=3, num_warps=8),
            triton.Config({"BLOCK_SIZE": 4}, num_stages=4, num_warps=4),
            triton.Config({"BLOCK_SIZE": 2}, num_stages=3, num_warps=8),
            triton.Config({"BLOCK_SIZE": 2}, num_stages=4, num_warps=4),
        ],
        key=[],
    )
    @triton.jit
    def add_kernel_autotuned(
        in_ptr0,
        in_ptr1,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        output = x + y
        tl.store(out_ptr + offsets, output, mask=mask)

    @torch.compile(fullgraph=True)
    def add_fn(x, y):
        output = torch.zeros_like(x)
        n_elements = output.numel()
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        # BLOCK_SIZE is not passed here; the autotuner selects it from ``configs``.
        add_kernel_autotuned[grid](x, y, output, n_elements)
        return output

    x = torch.randn(4, device="cuda")
    y = torch.randn(4, device="cuda")
    out = add_fn(x, y)
    print(f"Vector addition of\nX:\t{x}\nY:\t{y}\nis equal to\n{out}")

######################################################################
# Composability and Limitations
# -----------------------------
#
# As of PyTorch 2.3, the user-defined Triton kernel support in ``torch.compile``
# composes with dynamic shapes, ``torch.autograd.Function``, JIT inductor, and
# AOT inductor; a minimal sketch of the ``torch.autograd.Function`` case follows
# below.
#
# Support for tensor subclasses and other advanced features does not currently
# exist.
# Support for ``triton.heuristics`` exists when it is used by itself, but not
# when it is used in combination with ``triton.autotune``.
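
######################################################################
# As a minimal sketch of the ``torch.autograd.Function`` composition mentioned
# above (the ``_TritonAdd`` helper below is an illustrative name, not part of
# any library), the ``add_kernel`` defined earlier can be called from a custom
# autograd ``Function`` and compiled with ``torch.compile``. Because vector
# addition passes gradients through unchanged, the backward simply returns the
# incoming gradient for both inputs.

if not has_triton():
    print("Skipping because triton is not supported on this device.")
else:
    import triton

    class _TritonAdd(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, y):
            # Same launch logic as ``add_fn`` above, wrapped in an autograd Function.
            output = torch.zeros_like(x)
            n_elements = output.numel()
            grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
            add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=4)
            return output

        @staticmethod
        def backward(ctx, grad_output):
            # d(x + y)/dx = d(x + y)/dy = 1, so the upstream gradient flows
            # through to both inputs unchanged.
            return grad_output, grad_output

    @torch.compile(fullgraph=True)
    def add_fn_autograd(x, y):
        return _TritonAdd.apply(x, y)

    x = torch.randn(4, device="cuda", requires_grad=True)
    y = torch.randn(4, device="cuda", requires_grad=True)
    out = add_fn_autograd(x, y)
    out.sum().backward()
    print(f"Gradients of the output sum w.r.t.\nX:\t{x.grad}\nY:\t{y.grad}")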