Add tutorial for user defined triton kernels #2783

Merged · 17 commits · Apr 20, 2024
3 changes: 3 additions & 0 deletions .jenkins/metadata.json
@@ -45,6 +45,9 @@
"intermediate_source/scaled_dot_product_attention_tutorial.py": {
"needs": "linux.g5.4xlarge.nvidia.gpu"
},
"recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py": {
"needs": "linux.g5.4xlarge.nvidia.gpu"
},
"prototype_source/gpu_quantization_torchao_tutorial.py": {
"needs": "linux.g5.4xlarge.nvidia.gpu"
}
4 changes: 4 additions & 0 deletions en-wordlist.txt
@@ -34,6 +34,7 @@ Chatbots
Chen
Colab
Colorectal
Composability
Conda
Conv
ConvNet
@@ -270,6 +271,7 @@ approximators
autodiff
autoencoder
autograd
autotune
autotuner
backend
backends
@@ -303,6 +305,7 @@ composable
concat
conda
config
configs
contrastive
conv
convolutional
@@ -551,6 +554,7 @@ torchviz
traceback
tradeoff
tradeoffs
triton
uint
umap
uncomment
9 changes: 9 additions & 0 deletions recipes_source/recipes_index.rst
@@ -307,6 +307,15 @@ Recipes are bite-sized, actionable examples of how to use specific PyTorch featu
   :link: ../recipes/compiling_optimizer.html
   :tags: Model-Optimization

.. Using User-Defined Triton Kernels with ``torch.compile``

.. customcarditem::
   :header: Using User-Defined Triton Kernels with ``torch.compile``
   :card_description: Learn how to use user-defined Triton kernels with ``torch.compile``
   :image: ../_static/img/thumbnails/cropped/generic-pytorch-logo.png
   :link: ../recipes/torch_compile_user_defined_triton_kernel_tutorial.html
   :tags: Model-Optimization

.. Intel(R) Extension for PyTorch*

.. customcarditem::
171 changes: 171 additions & 0 deletions recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py
@@ -0,0 +1,171 @@
# -*- coding: utf-8 -*-

"""
Using User-Defined Triton Kernels with ``torch.compile``
=========================================================
**Author:** `Oguz Ulgen <https://github.com/oulgen>`_
"""

######################################################################
# User-defined Triton kernels can be used to optimize specific parts of your
# model's computation. These kernels are written in Triton's language, which is designed
# to make it easier to achieve peak hardware performance. By using user-defined Triton
# kernels with ``torch.compile``, you can integrate these optimized computations into
# your PyTorch model, potentially achieving significant performance improvements.
#
# This recipe demonstrates how you can use user-defined Triton kernels with ``torch.compile``.
#
# Prerequisites
# -------------------
#
# Before starting this recipe, make sure that you have the following:
#
# * Basic understanding of ``torch.compile`` and Triton. See:
#
#   * `torch.compiler API documentation <https://pytorch.org/docs/stable/torch.compiler.html#torch-compiler>`__
#   * `Introduction to torch.compile <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html>`__
#   * `Triton language documentation <https://triton-lang.org/main/index.html>`__
#
# * PyTorch 2.3 or later
# * A GPU that supports Triton
#

import torch
from torch.utils._triton import has_triton
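
######################################################################
# As a quick sanity check for the prerequisites above, you can print your
# PyTorch version and whether Triton is usable on this machine. This check
# is an illustrative addition to the recipe, not a required step.

print(f"PyTorch version: {torch.__version__}")
print(f"Triton available: {has_triton()}")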

######################################################################
# Basic Usage
# --------------------
#
# In this example, we will use a simple vector addition kernel from the Triton documentation
# with ``torch.compile``.
# For reference, see `Triton documentation <https://triton-lang.org/main/getting-started/tutorials/01-vector-add.html>`__.
#

if not has_triton():
    print("Skipping because triton is not supported on this device.")
else:
    import triton
    from triton import language as tl

    @triton.jit
    def add_kernel(
        in_ptr0,
        in_ptr1,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        # Each program instance handles one contiguous block of elements.
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        # Mask out-of-bounds lanes in the final, possibly partial block.
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        output = x + y
        tl.store(out_ptr + offsets, output, mask=mask)

    @torch.compile(fullgraph=True)
    def add_fn(x, y):
        output = torch.zeros_like(x)
        n_elements = output.numel()
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=4)
        return output

    x = torch.randn(4, device="cuda")
    y = torch.randn(4, device="cuda")
    out = add_fn(x, y)
    print(f"Vector addition of\nX:\t{x}\nY:\t{y}\nis equal to\n{out}")

######################################################################
# Advanced Usage
# -------------------------------------------------------------------
#
# Triton's autotune feature is a powerful tool that automatically optimizes the configuration
# parameters of your Triton kernels. It explores a range of possible configurations and
# selects the one that delivers the best performance for your specific use case.
#
# When used with ``torch.compile``, ``triton.autotune`` can help ensure that your PyTorch
# model is running as efficiently as possible. Here is an example of using ``torch.compile``
# and ``triton.autotune``.
#
# .. note::
#
#    ``torch.compile`` supports only the ``configs`` and ``key`` arguments to ``triton.autotune``.

if not has_triton():
    print("Skipping because triton is not supported on this device.")
else:
    import triton
    from triton import language as tl

    @triton.autotune(
        configs=[
            triton.Config({"BLOCK_SIZE": 4}, num_stages=3, num_warps=8),
            triton.Config({"BLOCK_SIZE": 4}, num_stages=4, num_warps=4),
            triton.Config({"BLOCK_SIZE": 2}, num_stages=3, num_warps=8),
            triton.Config({"BLOCK_SIZE": 2}, num_stages=4, num_warps=4),
        ],
        key=[],
    )
    @triton.jit
    def add_kernel_autotuned(
        in_ptr0,
        in_ptr1,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        output = x + y
        tl.store(out_ptr + offsets, output, mask=mask)

    @torch.compile(fullgraph=True)
    def add_fn(x, y):
        output = torch.zeros_like(x)
        n_elements = output.numel()
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        # ``BLOCK_SIZE`` is omitted here; the autotuner picks it from ``configs``.
        add_kernel_autotuned[grid](x, y, output, n_elements)
        return output

    x = torch.randn(4, device="cuda")
    y = torch.randn(4, device="cuda")
    out = add_fn(x, y)
    print(f"Vector addition of\nX:\t{x}\nY:\t{y}\nis equal to\n{out}")

######################################################################
# Composability and Limitations
# --------------------------------------------------------------------
#
# As of PyTorch 2.3, the support for user-defined Triton kernels in ``torch.compile``
# includes dynamic shapes, ``torch.autograd.Function``, JIT inductor, and AOT inductor.
# You can use these features together to build complex, high-performance models;
# a short, illustrative sketch of the ``torch.autograd.Function`` case follows.
#
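# As one illustration, here is a minimal sketch (an addition for this
# write-up, not part of the original recipe) that wraps the ``add_kernel``
# defined in the Basic Usage section in a ``torch.autograd.Function`` so the
# Triton kernel participates in autograd. For ``x + y``, the gradient with
# respect to each input is simply the upstream gradient.

if has_triton():

    class _TritonAdd(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, y):
            output = torch.zeros_like(x)
            n_elements = output.numel()
            grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
            add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=4)
            return output

        @staticmethod
        def backward(ctx, grad_output):
            # d(x + y)/dx = d(x + y)/dy = 1, so the upstream gradient flows
            # through unchanged to both inputs.
            return grad_output, grad_output

    @torch.compile(fullgraph=True)
    def add_with_grad(x, y):
        return _TritonAdd.apply(x, y)

    a = torch.randn(4, device="cuda", requires_grad=True)
    b = torch.randn(4, device="cuda", requires_grad=True)
    add_with_grad(a, b).sum().backward()
    print(f"Gradient with respect to a: {a.grad}")

######################################################################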
# However, there are certain limitations to be aware of:
#
# * **Tensor Subclasses:** Currently, there is no support for
#   tensor subclasses and other advanced features.
# * **Triton Features:** While ``triton.heuristics`` can be used either standalone or
#   before ``triton.autotune``, it cannot be used after ``triton.autotune``. In
#   other words, if ``triton.heuristics`` and ``triton.autotune`` are to be used
#   together, ``triton.heuristics`` must be applied first; see the decorator-ordering
#   sketch after this list.
#
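# To make the ordering concrete, here is a minimal sketch (an illustrative
# addition; the ``EVEN_N`` heuristic is hypothetical and exists only to
# demonstrate the decorator order). Decorators apply bottom-up, so
# ``triton.heuristics`` sits below ``triton.autotune`` and above
# ``triton.jit``:

if has_triton():

    @triton.autotune(
        configs=[
            triton.Config({"BLOCK_SIZE": 4}, num_stages=3, num_warps=8),
            triton.Config({"BLOCK_SIZE": 2}, num_stages=4, num_warps=4),
        ],
        key=[],
    )
    @triton.heuristics(
        {"EVEN_N": lambda args: args["n_elements"] % 2 == 0}
    )
    @triton.jit
    def add_kernel_ordered(
        in_ptr0,
        in_ptr1,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
        EVEN_N: "tl.constexpr",  # supplied by the heuristic; unused here
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        tl.store(out_ptr + offsets, x + y, mask=mask)

######################################################################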
# Conclusion
# -----------
# In this recipe, we explored how to utilize user-defined Triton kernels
# with ``torch.compile``. We delved into the basic usage of a simple
# vector addition kernel and advanced usage involving Triton's autotune
# feature. We also discussed the composability of user-defined Triton
# kernels with other PyTorch features and highlighted some current limitations.
#
# See Also
# ---------
#
# * `Compiling the Optimizers <https://pytorch.org/tutorials/recipes/compiling_optimizer.html>`__
# * `Implementing High-Performance Transformers with Scaled Dot Product Attention <https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html>`__