Commit 3dae479

Add tutorial for user defined triton kernels

1 parent 6d7a843 commit 3dae479

1 file changed: +130 -0

# -*- coding: utf-8 -*-

"""
Using User Defined Triton Kernels with ``torch.compile``
========================================================
**Author:** `Oguz Ulgen <https://github.com/oulgen>`_
"""

######################################################################
# This tutorial explains how to use user-defined Triton kernels with ``torch.compile``.
#
# .. note::
#
#    This tutorial requires PyTorch 2.3 or later and a GPU that supports Triton.
#

import torch
from torch.utils._triton import has_triton

######################################################################
# Basic Usage
# -----------
#
# In this example, we use a simple vector addition kernel from the Triton documentation
# with ``torch.compile``.
# Reference: https://triton-lang.org/main/getting-started/tutorials/01-vector-add.html
#

if not has_triton():
    print("Skipping because triton is not supported on this device.")
else:
    import triton
    from triton import language as tl

    @triton.jit
    def add_kernel(
        in_ptr0,
        in_ptr1,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        # Each program instance processes one contiguous block of the inputs.
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        # Mask out-of-bounds lanes so the last, possibly partial, block is safe.
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        output = x + y
        tl.store(out_ptr + offsets, output, mask=mask)

    @torch.compile(fullgraph=True)
    def add_fn(x, y):
        output = torch.zeros_like(x)
        n_elements = output.numel()
        # The grid is computed from the BLOCK_SIZE chosen at launch time.
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=4)
        return output

    x = torch.randn(4, device="cuda")
    y = torch.randn(4, device="cuda")
    out = add_fn(x, y)
    print(f"Vector addition of\nX:\t{x}\nY:\t{y}\nis equal to\n{out}")

######################################################################
# Advanced Usage
# --------------
#
# It is also possible to use ``triton.autotune`` with ``torch.compile``.
#
# .. note::
#
#    ``torch.compile`` only supports the ``configs`` and ``key`` arguments to
#    ``triton.autotune``.

if not has_triton():
    print("Skipping because triton is not supported on this device.")
else:
    import triton
    from triton import language as tl

    @triton.autotune(
        configs=[
            triton.Config({"BLOCK_SIZE": 4}, num_stages=3, num_warps=8),
            triton.Config({"BLOCK_SIZE": 4}, num_stages=4, num_warps=4),
            triton.Config({"BLOCK_SIZE": 2}, num_stages=3, num_warps=8),
            triton.Config({"BLOCK_SIZE": 2}, num_stages=4, num_warps=4),
        ],
        # With an empty key, autotuning runs once and the result is reused.
        key=[],
    )
    @triton.jit
    def add_kernel_autotuned(
        in_ptr0,
        in_ptr1,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        output = x + y
        tl.store(out_ptr + offsets, output, mask=mask)

    @torch.compile(fullgraph=True)
    def add_fn(x, y):
        output = torch.zeros_like(x)
        n_elements = output.numel()
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        # BLOCK_SIZE is not passed here; the autotuner supplies it.
        add_kernel_autotuned[grid](x, y, output, n_elements)
        return output

    x = torch.randn(4, device="cuda")
    y = torch.randn(4, device="cuda")
    out = add_fn(x, y)
    print(f"Vector addition of\nX:\t{x}\nY:\t{y}\nis equal to\n{out}")

######################################################################
# Composability and Limitations
# -----------------------------
#
# As of PyTorch 2.3, the user-defined Triton kernel support in ``torch.compile``
# composes with dynamic shapes, ``torch.autograd.Function``, JIT inductor, and
# AOT inductor. The first two are sketched below.
#
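
######################################################################
# As a quick illustration of the dynamic-shapes point (this check is our
# addition, not part of the original tutorial): the grid is computed from
# ``n_elements`` at call time, so the autotuned ``add_fn`` above can be reused
# for a different input length without any changes.

if not has_triton():
    print("Skipping because triton is not supported on this device.")
else:
    x = torch.randn(17, device="cuda")
    y = torch.randn(17, device="cuda")
    out = add_fn(x, y)
    # The output should match eager-mode addition at the new length as well.
    torch.testing.assert_close(out, x + y)
    print(f"Vector addition of length {x.numel()} also matches eager mode.")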
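
######################################################################
# Here is a minimal sketch of composing a user-defined Triton kernel with
# ``torch.autograd.Function`` (the ``AddFn`` name and the pass-through
# backward are our own illustration, not part of the original tutorial).
# Since addition has a gradient of one with respect to each input, the
# backward simply returns ``grad_output`` for both inputs.

if not has_triton():
    print("Skipping because triton is not supported on this device.")
else:
    class AddFn(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, y):
            output = torch.zeros_like(x)
            n_elements = output.numel()
            grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
            # Reuse the vector addition kernel from the Basic Usage section.
            add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=4)
            return output

        @staticmethod
        def backward(ctx, grad_output):
            # d(x + y)/dx = d(x + y)/dy = 1, so gradients pass through unchanged.
            return grad_output, grad_output

    @torch.compile(fullgraph=True)
    def add_fn_with_autograd(x, y):
        return AddFn.apply(x, y)

    x = torch.randn(4, device="cuda", requires_grad=True)
    y = torch.randn(4, device="cuda", requires_grad=True)
    out = add_fn_with_autograd(x, y)
    out.sum().backward()
    print(f"Gradient of x:\t{x.grad}\nGradient of y:\t{y.grad}")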

######################################################################
# Support for tensor subclasses and other advanced features does not
# currently exist.
# Support for ``triton.heuristics`` exists when it is used by itself, but not
# when it is used in combination with ``triton.autotune``; standalone usage of
# ``triton.heuristics`` is sketched below.
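
######################################################################
# A minimal sketch of standalone ``triton.heuristics`` usage (the kernel name
# and the particular heuristic are our own illustration, not part of the
# original tutorial): the heuristic derives ``BLOCK_SIZE`` from ``n_elements``
# instead of hard-coding it at the call site.

if not has_triton():
    print("Skipping because triton is not supported on this device.")
else:
    @triton.heuristics(
        # Round the problem size up to the next power of two for the block size.
        {"BLOCK_SIZE": lambda args: triton.next_power_of_2(args["n_elements"])}
    )
    @triton.jit
    def add_kernel_with_heuristics(
        in_ptr0,
        in_ptr1,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        tl.store(out_ptr + offsets, x + y, mask=mask)

    @torch.compile(fullgraph=True)
    def add_fn_with_heuristics(x, y):
        output = torch.zeros_like(x)
        n_elements = output.numel()
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        # BLOCK_SIZE is supplied by the heuristic, not at the call site.
        add_kernel_with_heuristics[grid](x, y, output, n_elements)
        return output

    x = torch.randn(4, device="cuda")
    y = torch.randn(4, device="cuda")
    print(add_fn_with_heuristics(x, y))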
