From e4ff64eb1e3772d6ec0319ca99834a93ac4254d1 Mon Sep 17 00:00:00 2001 From: Oguz Ulgen Date: Fri, 1 Mar 2024 12:32:07 -0800 Subject: [PATCH 01/12] Add tutorial for user defined triton kernels --- en-wordlist.txt | 4 + ...ile_user_defined_triton_kernel_tutorial.py | 130 ++++++++++++++++++ 2 files changed, 134 insertions(+) create mode 100644 intermediate_source/torch_compile_user_defined_triton_kernel_tutorial.py diff --git a/en-wordlist.txt b/en-wordlist.txt index bffd5f84d91..2b17f5a55c2 100644 --- a/en-wordlist.txt +++ b/en-wordlist.txt @@ -34,6 +34,7 @@ Chatbots Chen Colab Colorectal +Composibility Conda Conv ConvNet @@ -270,6 +271,7 @@ approximators autodiff autoencoder autograd +autotune autotuner backend backends @@ -303,6 +305,7 @@ composable concat conda config +configs contrastive conv convolutional @@ -551,6 +554,7 @@ torchviz traceback tradeoff tradeoffs +triton uint umap uncomment diff --git a/intermediate_source/torch_compile_user_defined_triton_kernel_tutorial.py b/intermediate_source/torch_compile_user_defined_triton_kernel_tutorial.py new file mode 100644 index 00000000000..b60ad6bbaa6 --- /dev/null +++ b/intermediate_source/torch_compile_user_defined_triton_kernel_tutorial.py @@ -0,0 +1,130 @@ +# -*- coding: utf-8 -*- + +""" +Using User Defined Triton Kernels with ``torch.compile`` +================================= +**Author:** `Oguz Ulgen `_ +""" + +###################################################################### +# This tutorial explains how to use user defined triton kernels with ``torch.compile``. +# +# .. note:: +# This tutorial requires PyTorch 2.3 or later and a GPU that supports Triton. +# + +import torch +from torch.utils._triton import has_triton + +###################################################################### +# Basic Usage +# ------------ +# +# In this example, we will use a simple vector addition kernel from the triton documentation +# with ``torch.compile``. +# Reference: https://triton-lang.org/main/getting-started/tutorials/01-vector-add.html +# + +if not has_triton: + print("Skipping because triton is not supported on this device.") +else: + import triton + from triton import language as tl + + @triton.jit + def add_kernel( + in_ptr0, + in_ptr1, + out_ptr, + n_elements, + BLOCK_SIZE: "tl.constexpr", + ): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(in_ptr0 + offsets, mask=mask) + y = tl.load(in_ptr1 + offsets, mask=mask) + output = x + y + tl.store(out_ptr + offsets, output, mask=mask) + + @torch.compile(fullgraph=True) + def add_fn(x, y): + output = torch.zeros_like(x) + n_elements = output.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) + add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=4) + return output + + x = torch.randn(4, device="cuda") + y = torch.randn(4, device="cuda") + out = add_fn(x, y) + print(f"Vector addition of\nX:\t{x}\nY:\t{y}\nis equal to\n{out}") + +###################################################################### +# Advanced Usage +# ------------ +# +# It is also possible to triton.autotune with ``torch.compile``. +# +# .. note:: +# +# ``torch.compile`` only supports configs and key arguments to ``triton.autotune``. + +if not has_triton: + print("Skipping because triton is not supported on this device.") +else: + import triton + from triton import language as tl + + @triton.autotune( + configs=[ + triton.Config({"BLOCK_SIZE": 4}, num_stages=3, num_warps=8), + triton.Config({"BLOCK_SIZE": 4}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_SIZE": 2}, num_stages=3, num_warps=8), + triton.Config({"BLOCK_SIZE": 2}, num_stages=4, num_warps=4), + ], + key=[], + ) + @triton.jit + def add_kernel_autotuned( + in_ptr0, + in_ptr1, + out_ptr, + n_elements, + BLOCK_SIZE: "tl.constexpr", + ): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(in_ptr0 + offsets, mask=mask) + y = tl.load(in_ptr1 + offsets, mask=mask) + output = x + y + tl.store(out_ptr + offsets, output, mask=mask) + + @torch.compile(fullgraph=True) + def add_fn(x, y): + output = torch.zeros_like(x) + n_elements = output.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) + add_kernel_autotuned[grid](x, y, output, n_elements) + return output + + x = torch.randn(4, device="cuda") + y = torch.randn(4, device="cuda") + out = add_fn(x, y) + print(f"Vector addition of\nX:\t{x}\nY:\t{y}\nis equal to\n{out}") + +###################################################################### +# Composibility and Limitations +# ------------ +# +# As for PyTorch 2.3, the user defined triton kernel support in ``torch.compile`` +# composes with dynamic shapes, ``torch.autograd.Function``, JIT inductor and +# AOT inductor. +# +# The support for tensor subclasses and other advanced features currently do +# not exist. +# Support for ``triton.heuristics`` exists when it is used by itself but not +# when it is used in combination with ``triton.autotune``. From a3f7939c1075735e564463843841c993178a03f5 Mon Sep 17 00:00:00 2001 From: Svetlana Karslioglu Date: Tue, 19 Mar 2024 15:52:57 -0700 Subject: [PATCH 02/12] Update intermediate_source/torch_compile_user_defined_triton_kernel_tutorial.py Merge a small fix to kick off the build --- .../torch_compile_user_defined_triton_kernel_tutorial.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/intermediate_source/torch_compile_user_defined_triton_kernel_tutorial.py b/intermediate_source/torch_compile_user_defined_triton_kernel_tutorial.py index b60ad6bbaa6..ea6307d5f9c 100644 --- a/intermediate_source/torch_compile_user_defined_triton_kernel_tutorial.py +++ b/intermediate_source/torch_compile_user_defined_triton_kernel_tutorial.py @@ -20,7 +20,7 @@ # Basic Usage # ------------ # -# In this example, we will use a simple vector addition kernel from the triton documentation +# In this example, we will use a simple vector addition kernel from the Triton documentation # with ``torch.compile``. # Reference: https://triton-lang.org/main/getting-started/tutorials/01-vector-add.html # From 565e2c68e1c69d50b051c743272453d7a637ec6f Mon Sep 17 00:00:00 2001 From: Svetlana Karslioglu Date: Thu, 21 Mar 2024 10:09:12 -0700 Subject: [PATCH 03/12] Update intermediate_source/torch_compile_user_defined_triton_kernel_tutorial.py small change to kick off the build --- .../torch_compile_user_defined_triton_kernel_tutorial.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/intermediate_source/torch_compile_user_defined_triton_kernel_tutorial.py b/intermediate_source/torch_compile_user_defined_triton_kernel_tutorial.py index ea6307d5f9c..3eee526bccc 100644 --- a/intermediate_source/torch_compile_user_defined_triton_kernel_tutorial.py +++ b/intermediate_source/torch_compile_user_defined_triton_kernel_tutorial.py @@ -7,7 +7,7 @@ """ ###################################################################### -# This tutorial explains how to use user defined triton kernels with ``torch.compile``. +# This tutorial explains how to use user defined Triton kernels with ``torch.compile``. # # .. note:: # This tutorial requires PyTorch 2.3 or later and a GPU that supports Triton. From caa4c6012eafc7f9c4bc53940d8191189b3185d2 Mon Sep 17 00:00:00 2001 From: Oguz Ulgen Date: Thu, 21 Mar 2024 13:15:48 -0700 Subject: [PATCH 04/12] update --- .../torch_compile_user_defined_triton_kernel_tutorial.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/intermediate_source/torch_compile_user_defined_triton_kernel_tutorial.py b/intermediate_source/torch_compile_user_defined_triton_kernel_tutorial.py index 3eee526bccc..3a78c29238a 100644 --- a/intermediate_source/torch_compile_user_defined_triton_kernel_tutorial.py +++ b/intermediate_source/torch_compile_user_defined_triton_kernel_tutorial.py @@ -126,5 +126,6 @@ def add_fn(x, y): # # The support for tensor subclasses and other advanced features currently do # not exist. -# Support for ``triton.heuristics`` exists when it is used by itself but not -# when it is used in combination with ``triton.autotune``. +# Support for ``triton.heuristics`` exists when it is used by itself or before +# ``triton.autotune``; however, support for using ``triton.heuristic`` after +# ``triton.autotune`` is not yet supported. From 538ed7d4218a97f594c421fe5ead89e444513b2a Mon Sep 17 00:00:00 2001 From: Oguz Ulgen Date: Thu, 21 Mar 2024 16:20:38 -0700 Subject: [PATCH 05/12] update --- .jenkins/metadata.json | 3 +++ .../torch_compile_user_defined_triton_kernel_tutorial.py | 4 ++-- 2 files changed, 5 insertions(+), 2 deletions(-) rename {intermediate_source => recipes_source}/torch_compile_user_defined_triton_kernel_tutorial.py (99%) diff --git a/.jenkins/metadata.json b/.jenkins/metadata.json index a039b63f17e..599696fcdc9 100644 --- a/.jenkins/metadata.json +++ b/.jenkins/metadata.json @@ -42,6 +42,9 @@ "intermediate_source/scaled_dot_product_attention_tutorial.py": { "needs": "linux.g5.4xlarge.nvidia.gpu" }, + "recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py": { + "needs": "linux.g5.4xlarge.nvidia.gpu" + }, "prototype_source/gpu_quantization_torchao_tutorial.py": { "needs": "linux.g5.4xlarge.nvidia.gpu" } diff --git a/intermediate_source/torch_compile_user_defined_triton_kernel_tutorial.py b/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py similarity index 99% rename from intermediate_source/torch_compile_user_defined_triton_kernel_tutorial.py rename to recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py index 3a78c29238a..a13100d2045 100644 --- a/intermediate_source/torch_compile_user_defined_triton_kernel_tutorial.py +++ b/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py @@ -25,7 +25,7 @@ # Reference: https://triton-lang.org/main/getting-started/tutorials/01-vector-add.html # -if not has_triton: +if not has_triton(): print("Skipping because triton is not supported on this device.") else: import triton @@ -71,7 +71,7 @@ def add_fn(x, y): # # ``torch.compile`` only supports configs and key arguments to ``triton.autotune``. -if not has_triton: +if not has_triton(): print("Skipping because triton is not supported on this device.") else: import triton From 9d05ab47f944806fbf05d5905db493972fec75e0 Mon Sep 17 00:00:00 2001 From: Oguz Ulgen Date: Fri, 22 Mar 2024 10:09:57 -0700 Subject: [PATCH 06/12] update --- recipes_source/recipes_index.rst | 9 +++ ...ile_user_defined_triton_kernel_tutorial.py | 65 ++++++++++++++----- 2 files changed, 58 insertions(+), 16 deletions(-) diff --git a/recipes_source/recipes_index.rst b/recipes_source/recipes_index.rst index 92863bcec50..e7119c276ba 100644 --- a/recipes_source/recipes_index.rst +++ b/recipes_source/recipes_index.rst @@ -300,6 +300,15 @@ Recipes are bite-sized, actionable examples of how to use specific PyTorch featu :link: ../recipes/compiling_optimizer.html :tags: Model-Optimization +.. Using User-Defined Triton Kernels with ``torch.compile`` + +.. customcarditem:: + :header: Using User-Defined Triton Kernels with ``torch.compile`` + :card_description: Learn how to use user-defined kernels with ``torch.compile`` + :image: ../_static/img/thumbnails/cropped/generic-pytorch-logo.png + :link: ../recipes/torch_compile_user_defined_triton_kernel_tutorial.html + :tags: Model-Optimization + .. Intel(R) Extension for PyTorch* .. customcarditem:: diff --git a/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py b/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py index a13100d2045..c0748f95ed9 100644 --- a/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py +++ b/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py @@ -1,16 +1,32 @@ # -*- coding: utf-8 -*- """ -Using User Defined Triton Kernels with ``torch.compile`` +Using User-Defined Triton Kernels with ``torch.compile`` ================================= **Author:** `Oguz Ulgen `_ """ ###################################################################### -# This tutorial explains how to use user defined Triton kernels with ``torch.compile``. +# This tutorial explains how to use user-defined Triton kernels with ``torch.compile``. +# User-defined Triton kernels can be used to optimize specific parts of your +# model's computation. These kernels are written in Triton's language, which is designed +# to make it easier to achieve peak hardware performance. By using user-defined Triton +# kernels with ``torch.compile``, you can integrate these optimized computations into +# your PyTorch model, potentially achieving significant performance improvements. # -# .. note:: -# This tutorial requires PyTorch 2.3 or later and a GPU that supports Triton. +# This recipes demonstrates how you can use user-defined Triton kernels with ``torch.compile``. +# +# Prerequisites +# ------------------- +# +# Before starting this recipe, make sure that you have the following: +# +# * Basic understanding of ``torch.compile`` and Triton. See: +# * `torch.compiler API documentation `__ +# * `Introduction to torch.compile `__ +# * `Triton language documentation `__ +# * PyTorch 2.3 or later +# * A GPU that supports Triton # import torch @@ -22,7 +38,7 @@ # # In this example, we will use a simple vector addition kernel from the Triton documentation # with ``torch.compile``. -# Reference: https://triton-lang.org/main/getting-started/tutorials/01-vector-add.html +# For reference, see `Triton documentation `__. # if not has_triton(): @@ -63,9 +79,15 @@ def add_fn(x, y): ###################################################################### # Advanced Usage -# ------------ +# ------------------------------------------------------------------- # -# It is also possible to triton.autotune with ``torch.compile``. +# Triton's autotune feature is a powerful tool that automatically optimizes the configuration +# parameters of your Triton kernels. It explores a range of possible configurations and +# selects the one that delivers the best performance for your specific use case. +# +# When used with ``torch.compile``, ``triton.autotune`` can help ensure that your PyTorch +# model is running as efficiently as possible. Here is an example of using ``torch.compile`` +# and ``triton.autotune``. # # .. note:: # @@ -118,14 +140,25 @@ def add_fn(x, y): ###################################################################### # Composibility and Limitations -# ------------ +# -------------------------------------------------------------------- +# +# As of PyTorch 2.3, the support for user-defined Triton kernels in ``torch.compile`` +# includes dynamic shapes, ``torch.autograd.Function``, JIT inductor, and AOT inductor. +# You can use these features together to build complex, high-performance models. +# +# However, there are certain limitations to be aware of: # -# As for PyTorch 2.3, the user defined triton kernel support in ``torch.compile`` -# composes with dynamic shapes, ``torch.autograd.Function``, JIT inductor and -# AOT inductor. +# * **Tensor Subclasses:** Currently, there is no support for +# tensor subclasses and other advanced features. +# * **Triton Features:** While ``triton.heuristics`` can be used either standalone or +# before ``triton.autotune``, it cannot be used after ```triton.autotune``. This +# implies that if ``triton.heuristics`` and ``triton.autotune`` are to be used +# together, ``triton.heuristics`` must be used first. # -# The support for tensor subclasses and other advanced features currently do -# not exist. -# Support for ``triton.heuristics`` exists when it is used by itself or before -# ``triton.autotune``; however, support for using ``triton.heuristic`` after -# ``triton.autotune`` is not yet supported. +# Conclusion +# ------------------------------------------------------------------- +# In this recipe, we explored how to utilize user-defined Triton kernels +# with ``torch.compile``. We delved into the basic usage of a simple +# vector addition kernel and advanced usage involving Triton's autotune +# feature. We also discussed the composability of user-defined Triton +# kernels with other PyTorch features and highlighted some current limitations. From 5ca9fbf057b87615a67d7aa37131ef1fe827d066 Mon Sep 17 00:00:00 2001 From: Oguz Ulgen Date: Fri, 22 Mar 2024 10:46:04 -0700 Subject: [PATCH 07/12] update --- .../torch_compile_user_defined_triton_kernel_tutorial.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py b/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py index c0748f95ed9..dd7d4454f4d 100644 --- a/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py +++ b/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py @@ -34,7 +34,7 @@ ###################################################################### # Basic Usage -# ------------ +# -------------------- # # In this example, we will use a simple vector addition kernel from the Triton documentation # with ``torch.compile``. @@ -162,3 +162,9 @@ def add_fn(x, y): # vector addition kernel and advanced usage involving Triton's autotune # feature. We also discussed the composability of user-defined Triton # kernels with other PyTorch features and highlighted some current limitations. +# +# See Also +# ------------------------------------------------------------------- +# +# * `Compiling the Optimizers: `__. +# * `Implementing High-Performance Transformers with Scaled Dot Product Attention: `__. From bee572f713d0e2a75134e2c5e12f4e375910e4a3 Mon Sep 17 00:00:00 2001 From: Svetlana Karslioglu Date: Fri, 22 Mar 2024 12:25:55 -0700 Subject: [PATCH 08/12] Update torch_compile_user_defined_triton_kernel_tutorial.py Minor editorial fixes. --- ...compile_user_defined_triton_kernel_tutorial.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py b/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py index dd7d4454f4d..ab67b356448 100644 --- a/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py +++ b/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py @@ -2,12 +2,11 @@ """ Using User-Defined Triton Kernels with ``torch.compile`` -================================= +========================================================= **Author:** `Oguz Ulgen `_ """ ###################################################################### -# This tutorial explains how to use user-defined Triton kernels with ``torch.compile``. # User-defined Triton kernels can be used to optimize specific parts of your # model's computation. These kernels are written in Triton's language, which is designed # to make it easier to achieve peak hardware performance. By using user-defined Triton @@ -22,9 +21,11 @@ # Before starting this recipe, make sure that you have the following: # # * Basic understanding of ``torch.compile`` and Triton. See: -# * `torch.compiler API documentation `__ -# * `Introduction to torch.compile `__ -# * `Triton language documentation `__ +# +# * `torch.compiler API documentation `__ +# * `Introduction to torch.compile `__ +# * `Triton language documentation `__ +# # * PyTorch 2.3 or later # * A GPU that supports Triton # @@ -156,7 +157,7 @@ def add_fn(x, y): # together, ``triton.heuristics`` must be used first. # # Conclusion -# ------------------------------------------------------------------- +# ----------- # In this recipe, we explored how to utilize user-defined Triton kernels # with ``torch.compile``. We delved into the basic usage of a simple # vector addition kernel and advanced usage involving Triton's autotune @@ -164,7 +165,7 @@ def add_fn(x, y): # kernels with other PyTorch features and highlighted some current limitations. # # See Also -# ------------------------------------------------------------------- +# --------- # # * `Compiling the Optimizers: `__. # * `Implementing High-Performance Transformers with Scaled Dot Product Attention: `__. From 8236f8d00b0b4e5114f920ac65e64a4187c90e09 Mon Sep 17 00:00:00 2001 From: Svetlana Karslioglu Date: Fri, 22 Mar 2024 12:41:02 -0700 Subject: [PATCH 09/12] Update torch_compile_user_defined_triton_kernel_tutorial.py Minor formatting fix --- .../torch_compile_user_defined_triton_kernel_tutorial.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py b/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py index ab67b356448..0aa44752850 100644 --- a/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py +++ b/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py @@ -150,11 +150,11 @@ def add_fn(x, y): # However, there are certain limitations to be aware of: # # * **Tensor Subclasses:** Currently, there is no support for -# tensor subclasses and other advanced features. +# tensor subclasses and other advanced features. # * **Triton Features:** While ``triton.heuristics`` can be used either standalone or -# before ``triton.autotune``, it cannot be used after ```triton.autotune``. This -# implies that if ``triton.heuristics`` and ``triton.autotune`` are to be used -# together, ``triton.heuristics`` must be used first. +# before ``triton.autotune``, it cannot be used after ```triton.autotune``. This +# implies that if ``triton.heuristics`` and ``triton.autotune`` are to be used +# together, ``triton.heuristics`` must be used first. # # Conclusion # ----------- From 68be9d9267c9681692992ecdc463971a8a66c709 Mon Sep 17 00:00:00 2001 From: Svetlana Karslioglu Date: Fri, 19 Apr 2024 11:06:40 -0700 Subject: [PATCH 10/12] Apply suggestions from code review Editorial fixes --- .../torch_compile_user_defined_triton_kernel_tutorial.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py b/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py index 0aa44752850..ac23ca42de7 100644 --- a/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py +++ b/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py @@ -167,5 +167,5 @@ def add_fn(x, y): # See Also # --------- # -# * `Compiling the Optimizers: `__. -# * `Implementing High-Performance Transformers with Scaled Dot Product Attention: `__. +# * `Compiling the Optimizers: `__ +# * `Implementing High-Performance Transformers with Scaled Dot Product Attention`__ From 2a872687653a9ba6d82877b89bfcb704c9423535 Mon Sep 17 00:00:00 2001 From: Svetlana Karslioglu Date: Fri, 19 Apr 2024 13:29:58 -0700 Subject: [PATCH 11/12] Update metadata.json --- .jenkins/metadata.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.jenkins/metadata.json b/.jenkins/metadata.json index 33ea8ae72da..f3d03f3cf1c 100644 --- a/.jenkins/metadata.json +++ b/.jenkins/metadata.json @@ -29,7 +29,7 @@ "needs": "linux.16xlarge.nvidia.gpu" }, "intermediate_source/torchvision_tutorial.py": { - "needs": "linux.g5.4xlarge.nvidia.gpu", + "needs": "linux.16xlarge.nvidia.gpu", "_comment": "does not require a5g but needs to run before gpu_quantization_torchao_tutorial.py." }, "advanced_source/coding_ddpg.py": { From 8ee52f752f77f58ab9ea2094757ea7da64de7f50 Mon Sep 17 00:00:00 2001 From: Svetlana Karslioglu Date: Fri, 19 Apr 2024 15:46:54 -0700 Subject: [PATCH 12/12] Update .jenkins/metadata.json --- .jenkins/metadata.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.jenkins/metadata.json b/.jenkins/metadata.json index f3d03f3cf1c..33ea8ae72da 100644 --- a/.jenkins/metadata.json +++ b/.jenkins/metadata.json @@ -29,7 +29,7 @@ "needs": "linux.16xlarge.nvidia.gpu" }, "intermediate_source/torchvision_tutorial.py": { - "needs": "linux.16xlarge.nvidia.gpu", + "needs": "linux.g5.4xlarge.nvidia.gpu", "_comment": "does not require a5g but needs to run before gpu_quantization_torchao_tutorial.py." }, "advanced_source/coding_ddpg.py": {