From b93af8ec8ad42fd108e5a0adfc849e00835f5949 Mon Sep 17 00:00:00 2001 From: OuadiElfarouki Date: Tue, 1 Apr 2025 08:36:39 +0100 Subject: [PATCH 1/9] Initial implementation of XPU inductor codegen logic through SYCL Cutlass for gemm operator --- torch/_inductor/autotune_process.py | 89 ++ torch/_inductor/codegen/common.py | 4 +- torch/_inductor/codegen/xpu/cutlass_utils.py | 250 +++++ torch/_inductor/codegen/xpu/gemm_template.py | 915 ++++++++++++++++++ .../codegen/xpu/sycl_cpp_scheduling.py | 116 +++ torch/_inductor/codegen/xpu/sycl_kernel.py | 517 ++++++++++ torch/_inductor/codegen/xpu/sycl_template.py | 274 ++++++ .../codegen/xpu_combined_scheduling.py | 127 +++ torch/_inductor/config.py | 14 + torch/_inductor/ir.py | 21 +- torch/_inductor/kernel/mm.py | 5 + torch/_inductor/scheduler.py | 4 +- torch/_inductor/utils.py | 24 + 13 files changed, 2357 insertions(+), 3 deletions(-) create mode 100644 torch/_inductor/codegen/xpu/cutlass_utils.py create mode 100644 torch/_inductor/codegen/xpu/gemm_template.py create mode 100644 torch/_inductor/codegen/xpu/sycl_cpp_scheduling.py create mode 100644 torch/_inductor/codegen/xpu/sycl_kernel.py create mode 100644 torch/_inductor/codegen/xpu/sycl_template.py create mode 100644 torch/_inductor/codegen/xpu_combined_scheduling.py diff --git a/torch/_inductor/autotune_process.py b/torch/_inductor/autotune_process.py index 5b0369ab98e36..8bafb009b12f5 100644 --- a/torch/_inductor/autotune_process.py +++ b/torch/_inductor/autotune_process.py @@ -24,6 +24,7 @@ from torch._inductor.codecache import ( CppCodeCache, CUDACodeCache, + SYCLCodeCache, DLLWrapper, get_hash, PyCodeCache, @@ -934,6 +935,94 @@ def cleanup_run_fn(self) -> None: def __str__(self) -> str: return f"{self.kernel_name=}" +class SYCLBenchmarkRequest(GPUDeviceBenchmarkMixin, BenchmarkRequest): + # Important: Instances of this class have to be serializable + # across process boundaries. Do not put Tensors in here! 
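+    # Modeled on CUDABenchmarkRequest earlier in this file; the SYCL-specific
+    # parts are the code cache, the sycl::queue argument, and the workspace
+    # handling below.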
+    # TODO (SYCL) : Complete the benchmark request class to enable full autotuning
+    def __init__(
+        self,
+        kernel_name: str,
+        input_tensor_meta: Union[TensorMeta, list[TensorMeta]],
+        output_tensor_meta: Union[TensorMeta, list[TensorMeta]],
+        extra_args: Iterable[Any],
+        source_code: str,
+    ) -> None:
+        super().__init__(kernel_name, input_tensor_meta, output_tensor_meta, extra_args)
+        self.source_code = source_code
+        self.workspace_size: int = 0
+        self.workspace: Optional[torch.Tensor] = None
+        self.DLL: Optional[DLLWrapper] = None
+        self._workspace_size_updated = False
+        self.hash_key: str = ""
+        self.source_file: str = ""
+        self.hash_key, self.source_file = SYCLCodeCache.write(self.source_code, "so")
+
+    def precompile(self):
+        # Prepopulate SYCLCodeCache
+        autotuning_log.debug("Precompiling %s", self)
+        SYCLCodeCache.compile(self.source_code, "so")
+        autotuning_log.debug("Done precompiling %s", self)
+
+    def make_run_fn(
+        self, *input_tensors: torch.Tensor, output_tensor: torch.Tensor
+    ) -> Callable[[], None]:
+        self.ensure_dll_loaded()
+        self.update_workspace_size()
+        args = [
+            c_void_p(tensor.data_ptr())
+            for tensor in list(input_tensors) + [output_tensor]
+        ]
+        autotuning_log.debug(
+            "make_run_fn: self.kernel_name=%s, self.source_file=%s, self.hash_key=%s, self.DLL=%s, args=%s, self.extra_args=%s",
+            self.kernel_name,
+            self.source_file,
+            self.hash_key,
+            self.DLL,
+            args,
+            self.extra_args,
+        )
+        queue_ptr = c_void_p(torch.xpu.current_stream().sycl_queue)
+        run_method = getattr(self.DLL, self.kernel_name)
+        workspace_ptr = c_void_p(0)
+        if self.workspace_size > 0:
+            self.workspace = torch.zeros(
+                (self.workspace_size + 7) // 8,
+                dtype=torch.float64,
+                device=output_tensor.device,
+            )
+            workspace_ptr = c_void_p(self.workspace.data_ptr())
+
+        # Generate partial function.
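+        # The bound callable forwards, in order: the tensor pointers, any extra
+        # scalar args, a null workspace-size pointer (the size was already
+        # queried via update_workspace_size), the workspace pointer, and the
+        # SYCL queue pointer.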
+        return functools.partial(
+            run_method,
+            *args,
+            *self.extra_args,
+            None,  # null workspace size ptr
+            workspace_ptr,  # set workspace ptr
+            queue_ptr,
+        )
+
+    def update_workspace_size(self) -> None:
+        if self._workspace_size_updated:
+            return
+        # Hardcoded temporarily for testing with known kernels
+        self.workspace_size = 4096  # Fixed size for PoC
+        self._workspace_size_updated = True
+        # TODO (SYCL) : Implement comprehensive workspace updating mechanism
+
+    def ensure_dll_loaded(self):
+        if self.DLL is None:
+            self.DLL, self.hash_key, self.source_file = SYCLCodeCache.load(
+                self.source_code, "so"
+            )
+
+    def cleanup_run_fn(self) -> None:
+        if self.DLL is not None:
+            self.DLL.close()
+        self.workspace = None
+
+    def __str__(self) -> str:
+        return f"{self.kernel_name=}, {self.source_file=}, {self.hash_key=}"
+
+
 def benchmark_in_sub_process(
     choices: list[TritonTemplateCaller],
diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py
index 152d2ef36197f..33d2cc4554eab 100644
--- a/torch/_inductor/codegen/common.py
+++ b/torch/_inductor/codegen/common.py
@@ -390,6 +390,7 @@ def init_backend_registration() -> None:
     from .cpp_wrapper_cpu_array_ref import CppWrapperCpuArrayRef
     from .cpp_wrapper_gpu import CppWrapperGpu
     from .cuda_combined_scheduling import CUDACombinedScheduling
+    from .xpu_combined_scheduling import SYCLCombinedScheduling
     from .halide import HalideScheduling
     from .mps import MetalScheduling
     from .triton import TritonScheduling
@@ -424,9 +425,10 @@ def init_backend_registration() -> None:
     if get_scheduling_for_device("xpu") is None:
+        # SYCLCombinedScheduling combines Triton and SYCL C++ scheduling for XPU devices via delegation
         register_backend_for_device(
             "xpu",
-            TritonScheduling,
+            SYCLCombinedScheduling,
             PythonWrapperCodegen,
             CppWrapperGpu,
         )
diff --git a/torch/_inductor/codegen/xpu/cutlass_utils.py b/torch/_inductor/codegen/xpu/cutlass_utils.py
new file mode 100644
index 0000000000000..a8656425ca7d8
--- /dev/null
+++ b/torch/_inductor/codegen/xpu/cutlass_utils.py
@@ -0,0 +1,250 @@
+# mypy: allow-untyped-defs
+import functools
+import logging
+import os
+import sys
+from dataclasses import dataclass
+from typing import Any, Optional
+
+import sympy
+
+import torch
+from torch._inductor.utils import clear_on_fresh_inductor_cache
+
+from ... import config
+from ...ir import Layout
+from ...runtime.runtime_utils import cache_dir
+from ...virtualized import V
+
+
+log = logging.getLogger(__name__)
+
+
+@functools.lru_cache(None)
+def try_import_cutlass() -> bool:
+    """
+    Currently only supporting a user-specified cutlass_dir or falling back to the
+    default ../third_party/cutlass/ (build-from-source setups).
+    """
+    # Symlink the CUTLASS python scripts into a temp dir and add that dir to the Python search path.
+
+    cutlass_py_full_path = os.path.abspath(
+        os.path.join(config.sycl.cutlass_dir, "python/cutlass_library")
+    )
+    tmp_cutlass_py_full_path = os.path.abspath(
+        os.path.join(cache_dir(), "torch_cutlass_library")
+    )
+    dst_link = os.path.join(tmp_cutlass_py_full_path, "cutlass_library")
+
+    if os.path.isdir(cutlass_py_full_path):
+        if tmp_cutlass_py_full_path not in sys.path:
+            if os.path.exists(dst_link):
+                assert os.path.islink(
+                    dst_link
+                ), f"{dst_link} is not a symlink. Try to remove {dst_link} manually and try again."
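+                # The symlink may be left over from a previous run; verify that
+                # it still points at the configured CUTLASS checkout.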
+                assert os.path.realpath(os.readlink(dst_link)) == os.path.realpath(
+                    cutlass_py_full_path
+                ), f"Symlink at {dst_link} does not point to {cutlass_py_full_path}"
+            else:
+                os.makedirs(tmp_cutlass_py_full_path, exist_ok=True)
+                os.symlink(cutlass_py_full_path, dst_link)
+            sys.path.append(tmp_cutlass_py_full_path)
+        try:
+            import cutlass_library.generator  # noqa: F401
+            import cutlass_library.library  # noqa: F401
+            import cutlass_library.manifest  # noqa: F401
+
+            return True
+        except ImportError as e:
+            log.debug(
+                "Failed to import CUTLASS packages: %s, ignoring the CUTLASS backend.",
+                str(e),
+            )
+    else:
+        log.debug(
+            "Failed to import CUTLASS packages: CUTLASS repo does not exist: %s",
+            cutlass_py_full_path,
+        )
+    return False
+
+
+@functools.lru_cache(8)
+def _normalize_sycl_arch(arch: str) -> str:
+    if int(arch) == 11:
+        return "11"
+    else:
+        raise NotImplementedError(f"Unsupported sycl arch: {arch}")
+
+
+@dataclass
+class CUTLASSArgs:
+    """
+    CUTLASS args used to initialize a CUTLASS Manifest.
+    """
+
+    architectures: Optional[str] = "11"
+    # instantiation_level: Optional[str] = None
+
+    operations = "all"
+    build_dir = ""
+    curr_build_dir = ""
+    generator_target = ""
+    kernels = "all"
+    ignore_kernels = ""
+    exclude_kernels = ""
+    # UNUSED at the moment, part of Manifest class in cutlass_library
+    kernel_filter_file: None = None
+    selected_kernel_list: None = None
+    interface_dir: None = None
+    filter_by_cc = False
+    disable_full_archs_compilation = False
+
+    def __post_init__(self):
+        if self.architectures is None:
+            raise RuntimeError(f"{self.architectures=} is None!")
+        self.architectures = _normalize_sycl_arch(self.architectures)
+
+
+@clear_on_fresh_inductor_cache
+@functools.lru_cache(None)
+def _gen_ops_cached(arch) -> list[Any]:
+    # Import cutlass python scripts.
+    assert try_import_cutlass()
+    import cutlass_library.generator as cutlass_generator
+    import cutlass_library.manifest as cutlass_manifest
+
+    if arch is None:
+        log.error(
+            "Cannot detect XPU arch %s. "
+            "Will discard all cutlass ops. "
+            "Please consider setting _inductor.xpu.arch",
+            arch,
+        )
+        return []
+    arch = _normalize_sycl_arch(arch)
+
+    args = CUTLASSArgs(
+        architectures=arch,
+    )
+    manifest = cutlass_manifest.Manifest(args)
+
+    sycl_version = "2025.0.1"  # Placeholder, unused in GeneratePVC
+
+    if arch == "11":
+        cutlass_generator.GeneratePVC(manifest, sycl_version)
+    else:
+        log.error("Invalid XPU arch")
+        return []
+    return manifest.operations
+
+
+def gen_ops() -> list[Any]:
+    """
+    Generates all supported CUTLASS operations.
+    """
+    # Currently limited to PVC (arch 1100), hardcoding the arch
+    # TODO (SYCL): get_xpu_arch()
+    arch = "11"
+    return _gen_ops_cached(arch)
+
+
+def torch_dtype_to_cutlass_type(
+    torch_dtype: torch.dtype,
+) -> "cutlass_library.library.DataType":  # type: ignore[name-defined]  # noqa: F821
+    # Import cutlass python scripts.
+    assert try_import_cutlass()
+    import cutlass_library  # type: ignore[import]
+
+    if torch_dtype == torch.bfloat16:
+        return cutlass_library.library.DataType.bf16
+    elif torch_dtype == torch.float:
+        return cutlass_library.library.DataType.f32
+    else:
+        raise NotImplementedError(f"Unsupported data type: {torch_dtype}")
+
+
+def dtype_match(
+    torch_dtype: Optional[torch.dtype],
+    cutlass_dtype: "cutlass_library.library.DataType",  # type: ignore[name-defined]  # noqa: F821
+) -> bool:
+    # Import cutlass python scripts.
+ assert try_import_cutlass() + import cutlass_library + + if torch_dtype == torch.bfloat16: + return cutlass_dtype == cutlass_library.library.DataType.bf16 + elif torch_dtype == torch.float: + return cutlass_dtype == cutlass_library.library.DataType.f32 + else: + return False + + +def get_accumulator_dtype( + input_torch_dtypes: list[torch.dtype], +) -> Optional[torch.dtype]: + """ + Given a pair of input torch dtypes, returns the inferred accumulator torch dtype. + """ + # TODO (SYCL) : Extend this once other accumulator & input types are supported + if len(input_torch_dtypes) != 2: + return None + + if all(dtype == torch.bfloat16 for dtype in input_torch_dtypes): + return torch.float + else: + raise NotImplementedError(f"Unsupported data types: {input_torch_dtypes}") + + +def get_alignments(torch_dtype: torch.dtype) -> list[int]: + """ + Returns all possible valid CUTLASS alignments in terms of the number of elements for a given dtype. + """ + # TODO (SYCL): Extend for other types & double-check alignments + if torch_dtype == torch.bfloat16: + return [8, 4, 2, 1] + elif torch_dtype == torch.float: + return [4, 2, 1] + else: + raise NotImplementedError(f"unsupported {torch_dtype=} for alignments") + + +def get_max_alignment(inductor_layout: Layout) -> int: + """ + Returns the max alignment (in terms of number of elements) for a given Inductor Layout. + """ + + dtype = inductor_layout.dtype + size = inductor_layout.size + offset = inductor_layout.offset + + def is_static_int(number): + return isinstance(number, (int, sympy.Integer)) + + def a_factor_of(x, alignment): + if is_static_int(x) and is_static_int(alignment): + return x % alignment == 0 + rem = sympy.Mod(x, alignment) + return V.graph.sizevars.evaluate_expr(sympy.Eq(rem, 0)) + + try: + contiguous_dim = inductor_layout.stride.index(1) + except ValueError: + # No dim with stride 1 found, return 1 + return 1 + alignments = get_alignments(dtype) + for alignment in alignments: + if not a_factor_of(size[contiguous_dim], alignment) or not a_factor_of( + offset, alignment + ): + continue + if all( + (dim == contiguous_dim) + or a_factor_of(inductor_layout.stride[dim], alignment) + for dim in range(len(size)) + ): + return alignment + return 1 + + +# TODO (SYCL) : Add helpers for CUTLASS kernels testing & benchmarking once standalone +# runner is enabled. diff --git a/torch/_inductor/codegen/xpu/gemm_template.py b/torch/_inductor/codegen/xpu/gemm_template.py new file mode 100644 index 0000000000000..d40520b242b3b --- /dev/null +++ b/torch/_inductor/codegen/xpu/gemm_template.py @@ -0,0 +1,915 @@ +# mypy: allow-untyped-defs +import copy +import enum +import logging +import re +from abc import ABC, abstractmethod +from typing import Optional, Union + +from ... import ir +from ...config import sycl as inductor_sycl_config +from ...ir import ( + Buffer, + ChoiceCaller, + FixedLayout, + IRNode, + Layout, + SYCLTemplateBuffer, +) +from ...virtualized import V +from ..common import IndentedBuffer +from . import cutlass_utils +from .sycl_kernel import SYCLTemplateKernel +from .sycl_template import CUTLASSTemplate + + +log = logging.getLogger(__name__) + +# Jinja template for GEMM Kernel, used by the CUTLASSGemm3xTemplate class below. +GEMM_TEMPLATE_CUTLASS_3X = r""" +{{template.header().getvalue()}} +{{template.globals().getvalue()}} +{{instance_definition}} +// When workspace_size is not a nullptr, populates requested workspace_size and returns. +// Otherwise, computes the Gemm kernel using the given workspace ptr. 
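+// C linkage keeps the emitted symbol unmangled so the compiled shared library
+// can be loaded and the kernel resolved by name via ctypes (see DLLWrapper use
+// in SYCLBenchmarkRequest.make_run_fn).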
+extern "C" {
+PT_EXPORT {{kernel_call_signature}} {
+  try {
+  int B = {{kernel.size(Y, 0, -3, default_value=1)}};
+  using ElementComputeEpilogue = {{instance_type}}::ElementAccumulator;
+  using coord_t = cutlass::gemm::GemmCoord::Index;
+  static cutlass::KernelHardwareInfo hw_info;
+
+  const int device_id = 0;
+
+  if (hw_info.sm_count == 0) {
+    hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(device_id);
+    CUTLASS_TRACE_HOST("Query result for SM count per device: " << hw_info.sm_count);
+  }
+  {{instance_type}}::Arguments arguments;
+  {{template.render_gemm_arguments(argument_template, epilogue_template, X, W, Bias,
+                                    Y, alpha, beta, kernel, epilogue_args)}}
+  {{instance_type}} gemm_op;
+  if (workspace_size) {
+    *workspace_size = gemm_op.get_workspace_size(arguments);
+    return 0;
+  }
+  // check for null pointers after workspace size, since querying workspace size doesn't require valid data pointers
+#ifndef CUTLASS_BACKEND_DISABLE_CHECKS
+  {
+    auto status = gemm_op.can_implement(arguments);
+    CUTLASS_CHECK(status);
+  }
+#endif
+#ifdef CUTLASS_DEBUG_TRACE_LEVEL
+#if CUTLASS_DEBUG_TRACE_LEVEL == 1
+  {
+    // Print the maximum number of active blocks per SM for the kernel if CUTLASS_DEBUG_TRACE_LEVEL == 1.
+    // No explicit print statement is needed; it happens inside the call.
+    gemm_op.maximum_active_blocks();
+  }
+#endif
+#endif
+  {
+    auto status = gemm_op.initialize(arguments, workspace);
+    CUTLASS_CHECK(status);
+  }
+  {
+    auto status = gemm_op.run();
+    CUTLASS_CHECK(status);
+    syclcompat::wait();
+  }
+  }
+  catch (std::exception& e) {
+    std::cerr << "Runtime error: " << e.what() << std::endl;
+    return -1;
+  }
+  catch (...) {
+    return -1;
+  }
+  return 0;
+}
+}
+
+// configuration name: {{op_conf_name}}
+"""
+
+# Jinja template for Cutlass 3.x GEMM Kernel arguments, used by the CUTLASSGemmTemplate class below.
+GEMM_ARGS_CUTLASS_3X = r"""
+  // Initialize GemmUniversal3xInstance arguments.
+  arguments = {
+    {{template.gemm_mode()}}, // GemmUniversalMode mode
+    {
+      static_cast<coord_t>({{M}}),
+      static_cast<coord_t>({{N}}),
+      static_cast<coord_t>(K),
+      static_cast<coord_t>(B)
+    }, // ProblemShape problem_shape
+    {
+      {{template.cutlass_type_cast(X, kernel.ptr(X))}}, // ElementA const* ptr_A
+      {
+        {{template.cute_int(kernel.stride(X, -2), "stride_x0")}},
+        {{template.cute_int(kernel.stride(X, -1), "stride_x1")}},
+        {{template.cute_int(kernel.stride(X, -3), "batch_stride_x")}}
+      }, // StrideA dA
+      {{template.cutlass_type_cast(W, kernel.ptr(W))}}, // ElementB const* ptr_B
+      {
+        {{template.cute_int(kernel.stride(W, -1), "stride_w1")}},
+        {{template.cute_int(kernel.stride(W, -2), "stride_w0")}},
+        {{template.cute_int(kernel.stride(W, -3), "batch_stride_w")}}
+      }, // StrideB dB
+    }, // MainloopArguments mainloop
+    {{epilogue_arguments}},
+    hw_info
+  };
+"""
+
+# Jinja template for Cutlass 3.x GEMM Kernel arguments if epilogue fusion is applied,
+# used by the CUTLASSGemmTemplate class below.
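+# The linear-combination epilogue computes D = alpha * accumulator + beta * C,
+# with the optional Bias node bound to operand C.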
+GEMM_ARGS_CUTLASS_3X_EPILOGUE = r""" + { + {{epilogue_args}}, // thread, typename FusionCallbacks::Arguments ( EVT ) or ThreadEpilogueOp::Params (non-EVT ) + {{template.cutlass_type_cast(Bias, kernel.ptr(Bias))}}, // ElementC const* ptr_C + { + {{template.cute_int(kernel.stride(Bias, -2, 1), "stride_bias0")}}, + {{template.cute_int(kernel.stride(Bias, -1, 1), "stride_bias1")}}, + {{template.cute_int(kernel.stride(Bias, -3), "batch_stride_bias")}} + }, // StrideC dC + {{template.cutlass_type_cast(Y, kernel.ptr(Y))}}, // ElementD const* ptr_D + { + {{template.cute_int(kernel.stride(Y, -2), "stride_y0")}}, + {{template.cute_int(kernel.stride(Y, -1), "stride_y1")}}, + {{template.cute_int(kernel.stride(Y, -3), "batch_stride_y")}} + }, // StrideD dD + }, // EpilogueArguments epilogue +""" + + +class CUTLASSGemmTemplate(CUTLASSTemplate, ABC): + """ + CUTLASS GEMM Template, which is used to generate CUTLASS GEMM kernels + including those which allow flexible fusions with epilogues. + """ + + def __init__( + self, + input_nodes: list[Buffer], + layout: Layout, + alpha: float, + beta: float, + input_reorder: Optional[list[int]] = None, + ) -> None: + """ + Args: + input_nodes (List[Buffer]): List of input nodes of the GEMM kernel. + layout (Layout): Layout type of the resulting output node. + alpha (float): The scaling factor for the product of the inputs in the GEMM operation. + beta (float): The scaling factor applied to the output matrix. + input_reorder (Optional[List[int]]): Specifies the reordering of the input nodes. If not provided, + no reordering is performed. Defaults to None. + """ + super().__init__("cutlass_gemm", input_nodes, layout, input_reorder) + self.alpha = alpha + self.beta = beta + assert len(input_nodes) == 2 or len(input_nodes) == 3 + assert self._are_inputs_layout_compatible( + [node.get_layout() for node in input_nodes] + ) + + @staticmethod + @abstractmethod + def add_cutlass_gemm_choices( + choices: list[ChoiceCaller], + layout: ir.Layout, + input_nodes: list[Buffer], + alpha: Union[float, int] = 1, + beta: Union[float, int] = 0, + input_reorder: Optional[list[int]] = None, + **extra_kwargs, + ) -> None: + raise NotImplementedError + + @staticmethod + @abstractmethod + def _get_supported_ops() -> "list[cutlass_library.gemm_operation.GemmOperation]": # type: ignore[name-defined] # noqa: F821 + raise NotImplementedError + + @abstractmethod + def _get_template(self) -> str: + raise NotImplementedError + + @abstractmethod + def _get_template_args( + self, + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> tuple[str, Optional[str]]: + raise NotImplementedError + + @abstractmethod + def _are_inputs_layout_compatible(self, layouts: list[Layout]) -> bool: + raise NotImplementedError + + @abstractmethod + def _shape_match( + self, + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> bool: + raise NotImplementedError + + @abstractmethod + def _alignment_match( + self, + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> bool: + raise NotImplementedError + + @abstractmethod + def _set_bias_layout_and_alignment( + self, + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> bool: + raise NotImplementedError + + @abstractmethod + def _define_gemm_instance( + self, + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> tuple[str, str]: + raise NotImplementedError + + 
@abstractmethod + def _get_extra_inputs_and_names( + self, + op: "cutlass_gemm_op.GemmOperation" = None, # type: ignore[name-defined] # noqa: F821 + ) -> tuple[Optional[Buffer], list[Optional[Buffer]], list[str]]: + raise NotImplementedError + + def _add_cutlass_gemm_choices( + self, + choices: list[ChoiceCaller], + layout: ir.Layout, + input_nodes: list[Buffer], + alpha: Union[float, int] = 1, + beta: Union[float, int] = 0, + input_reorder: Optional[list[int]] = None, + **extra_kwargs, + ) -> None: + """ + Adds Cutlass GEMM configurations choices to the auto-tuning list. + + This function mutates the passed list of choices by appending the choices for Cutlass GEMM configs to it. + + Args: + choices (list): The list to which choices are appended. + layout (ir.Layout): The layout configuration. + input_nodes (list): The list of input nodes. + alpha (float,int): Scaling factor, defaults to 1. + beta (float,int): Offset, defaults to 0. + input_reorder (list, optional): Order of the inputs, defaults to None. + **extra_kwargs: Additional keyword arguments. + + """ + + ops = self.gen_ops() + for name, op in ops: + for swizzle in inductor_sycl_config.cutlass_max_profiling_swizzle_options: + description = f"{name} swizzle={swizzle}" + self.maybe_append_choice( + choices, description=description, op=op, swizzle=swizzle + ) + break # TODO (SYCL) : Currently limited to one config + break # TODO (SYCL) : Currently limited to one config + + if len(ops) == 0: + input_layouts = [node.get_layout() for node in input_nodes] + input_strides = [node.get_stride() for node in input_nodes] + output_layout = layout + warning_msg = f"No suitable Cutlass GEMM configs found, fallbacks used ( {len(ops)=}, {output_layout=}, {input_layouts=}, {input_strides=} )" # noqa: B950 + log.warning(warning_msg) + log.debug( + "Added %d Cutlass gemm configs.", + len(ops), + ) + + def header(self) -> IndentedBuffer: + """ + Returns a buffer containing SYCL C++ code for the header section of the CUTLASS GEMM template. + This section primarily includes the necessary header files. + + Returns: + IndentedBuffer: An instance of IndentedBuffer that contains the generated SYCL C++ header code. + """ + res = super().header() + res.splice( + """ + #include "cutlass/gemm/gemm.h" + #include "cutlass/gemm/device/gemm_universal.h" + #include "cutlass/gemm/device/gemm_universal_adapter.h" + #include "cutlass/gemm/kernel/gemm_universal.hpp" + #include "cutlass/gemm/collective/collective_builder.hpp" + #include "cutlass/epilogue/collective/collective_builder.hpp" + #include "cutlass/epilogue/collective/default_epilogue.hpp" + #include "cutlass/epilogue/thread/linear_combination.h" + #include "cutlass/epilogue/thread/activation.h" + #include "cutlass/gemm/dispatch_policy.hpp" + #include "cutlass/gemm/kernel/tile_scheduler.hpp" + #include "cutlass/tensor_ref.h" + #include "cutlass/util/distribution.h" + #include "cutlass/util/packed_stride.hpp" + #include "cutlass/util/tensor_view_io.h" + """ + ) + return res + + @staticmethod + def cutlass_layout(torch_layout: ir.Layout) -> "Optional[cutlass_lib.LayoutType]": # type: ignore[name-defined] # noqa: F821 + """ + Converts an ir.Layout instance into the corresponding cutlass_library.LayoutType enum value + (RowMajor, ColumnMajor, or None if no matching value is found ). + + Args: + torch_layout (ir.Layout): The layout that needs to be looked up. + + Returns: + cutlass_lib.LayoutType: The converted layout corresponding to the `torch_layout` or None if no matching + value is found. 
+ """ + assert cutlass_utils.try_import_cutlass() + import cutlass_library.library as cutlass_lib + + if V.graph.sizevars.statically_known_equals(torch_layout.stride[-1], 1): + return cutlass_lib.LayoutType.RowMajor + elif V.graph.sizevars.statically_known_equals(torch_layout.stride[-2], 1): + return cutlass_lib.LayoutType.ColumnMajor + else: + return None + + @staticmethod + def flip_cutlass_layout( + cutlass_layout: "cutlass_lib.LayoutType", # type: ignore[name-defined] # noqa: F821 + ) -> "cutlass_lib.LayoutType": # type: ignore[name-defined] # noqa: F821 + """Helper method: Flips a given cutlass layout (cutlass_lib.LayoutType) from RowMajor + to ColumnMajor or vice versa""" + assert cutlass_utils.try_import_cutlass() + import cutlass_library.library as cutlass_lib + + if cutlass_layout == cutlass_lib.LayoutType.RowMajor: + return cutlass_lib.LayoutType.ColumnMajor + else: + return cutlass_lib.LayoutType.RowMajor + + @staticmethod + def layout_match( + torch_layout: ir.Layout, + cutlass_layout: "cutlass_lib.LayoutType", # type: ignore[name-defined] # noqa: F821 + ) -> bool: + """Helper Method: Determines whether a given torch layout matches a given Cutlass layout""" + return CUTLASSGemmTemplate.cutlass_layout(torch_layout) == cutlass_layout + + @staticmethod + def set_alignment(torch_layout, op_element) -> bool: + """ + Helper method to update the alignment of a given CUTLASS GEMM op operand's element. + + This method modifies the alignment of the given Cutlass GEMM op operand's element to match the + layout of the corresponding ir.Buffer node. + + Args: + torch_layout: The layout of the corresponding ir.Buffer node. + op_element: The Cutlass GEMM op operand's element whose alignment is to be updated. + + Returns: + bool: True if the alignment was successfully updated, False otherwise. + """ + alignment = cutlass_utils.get_max_alignment(torch_layout) + op_element.alignment = alignment + return True + + def _dtype_match( + self, + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> bool: + """ + Checking dtypes of A, B, acc, D here. + + Empirically speaking, CUTLASS2x ops have same dtype for C and D. + """ + X = self.input_nodes[0] + W = self.input_nodes[1] + + accumulator_torch_dtype = cutlass_utils.get_accumulator_dtype( + [X.get_dtype(), W.get_dtype()], + ) + if not ( + cutlass_utils.dtype_match(X.get_dtype(), op.A.element) + and cutlass_utils.dtype_match(W.get_dtype(), op.B.element) + and cutlass_utils.dtype_match( + self.output_node.get_layout().dtype, op.D.element + ) + and cutlass_utils.dtype_match( + accumulator_torch_dtype, op.accumulator_type() + ) + ): + return False + + return True + + def filter_op( + self, + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> "cutlass_library.gemm_op.GemmOperation": # type: ignore[name-defined] # noqa: F821 + """ + Helper method: + + Determines whether a given Cutlass GEMM op definition is suitable for the current + input / output of the operation that this template is supposed to implement. + + Takes memory layout, dtype and support for EVT operations into account, + and filters potentially problematic ops. + + Returns None if the op is not suitable, otherwise returns the op to be used, which might + have been mutated. + """ + + # TODO (SYCL) : Extend this to account for other parameters such as regex filtering etc.. + + assert cutlass_utils.try_import_cutlass() + + if op.gemm_kind not in self._get_supported_ops(): + return None + + # Filter ops by dtypes. 
+ if not self._dtype_match(op): + return None + + # Layout compatibility check + X = self.input_nodes[0] + W = self.input_nodes[1] + + if not ( + self.layout_match(X.get_layout(), op.A.layout) + and self.layout_match(W.get_layout(), op.B.layout) + ): + return None + + # Update Op + op = copy.deepcopy(op) + + # Set output layout + op.D.layout = CUTLASSGemmTemplate.cutlass_layout(self.output_node.get_layout()) + + # Set alignments - crucial for performance and correctness + status = ( + self.set_alignment(X.get_layout(), op.A) + and self.set_alignment(W.get_layout(), op.B) + and self.set_alignment(self.output_node.get_layout(), op.D) + ) + if not status: + return None + + # Set epilogue accumulator type + op.element_epilogue = op.accumulator_type() + + # Handle bias if present (for linear combination epilogue) + status = self._set_bias_layout_and_alignment(op) + if not status: + return None + + return op + + def gen_ops(self) -> "list[tuple[str, cutlass_gemm_op.GemmOperation]]": # type: ignore[name-defined] # noqa: F821 + """ + Creates a list of Cutlass GemmOperation instances that match the operation this template is designed to represent. + The matching is carried out with respect to the input and output specifications of the operation. + + No function arguments. + + Returns: + List[Tuple[str, cutlass_gemm_op.GemmOperation]]: A list of (cutlass_name, GemmOperation) + tuples that are compatible with the operation requirements of this template. + """ + assert cutlass_utils.try_import_cutlass() + import cutlass_library.gemm_operation as cutlass_gemm_op + import cutlass_library.library as cutlass_lib + + ops = cutlass_utils.gen_ops()[cutlass_lib.OperationKind.Gemm] + res: dict[str, cutlass_gemm_op.GemmOperation] = {} + for op_dict in ops.values(): + for op_list in op_dict.values(): + for op in op_list: + assert isinstance(op, cutlass_gemm_op.GemmOperation) + filter_res = self.filter_op(op) + if ( + filter_res is not None + and filter_res.configuration_name() != op.configuration_name() + ): + log.debug( + "Detected change in configuration name. Original " + "name: %s, filtered configuration name: %s", + op.configuration_name(), + filter_res.configuration_name(), + ) + if ( + filter_res is not None + and res.get(filter_res.configuration_name(), None) is None + ): + res[filter_res.configuration_name()] = filter_res + log.info("Got cutlass configs: total number of ops: %d, ", len(res)) + sorted_res = sorted(res.items()) + return sorted_res[: inductor_sycl_config.cutlass_max_profiling_configs] + + def gemm_mode(self) -> str: + """ + Returns a Cutlass GEMM mode string for the current operation, dependent on whether this op implements + a batched GEMM or a simple GEMM without batch dimension. + + Returns: + str: A string indicating the Cutlass GEMM mode. If the output node has more than two dimensions, + "cutlass::gemm::GemmUniversalMode::kBatched" is returned, otherwise + "cutlass::gemm::GemmUniversalMode::kGemm" is returned. + """ + sizes = self.output_node.get_size() + if len(sizes) > 2: + return "cutlass::gemm::GemmUniversalMode::kBatched" + else: + return "cutlass::gemm::GemmUniversalMode::kGemm" + + def render( # type: ignore[override] + self, + kernel: SYCLTemplateKernel, + op: "cutlass_gemm_op.GemmOperation" = None, # type: ignore[name-defined] # noqa: F821 + template_buffer_node: Optional[SYCLTemplateBuffer] = None, + **kwargs, + ) -> str: + """ + The primary entry point for the code rendering process used in this template. 
+ Renders the Cutlass based SYCL C++ code for the GEMM Kernel that this template is designed to implement, + including potentially fused epilogues. + + Args: + kernel (SYCLTemplateKernel): The kernel to be rendered. + op (cutlass_gemm_op.GemmOperation, optional): A GEMM operation that is required to be compatible with the + input and output definitions as well as a possible epilogue. Defaults to None. + **kwargs: Additional keyword arguments. Currently unused. + + Returns: + str: Cutlass based SYCL C++ code fragment as a string, to be used by the current + SYCLTemplateKernel or autotuning code. + + Note: + All inputs and their corresponding buffer addresses and names take precedence over previously + passed inputs to the template at construction time. However, they should be layout compatible. + """ + + assert cutlass_utils.try_import_cutlass() + import cutlass_library.gemm_operation as cutlass_gemm_op + import cutlass_library.library as cutlass_lib + + assert isinstance( + op, cutlass_gemm_op.GemmOperation + ), "op argument is required and has to be an instance of GemmOperation" + + assert len(self.input_nodes) >= 2 and self.output_node is not None + X, W = self.input_nodes[0], self.input_nodes[1] + if not isinstance(X.layout, FixedLayout): + raise NotImplementedError("X.layout is not fixed") + if not isinstance(W.layout, FixedLayout): + raise NotImplementedError("W.layout is not fixed") + + Y = self.output_node + if template_buffer_node is not None: + Y = template_buffer_node + + Bias, extra_inputs, extra_names = self._get_extra_inputs_and_names(op) + + # Define Kernel call signature + # Important: This step also populates Kernel name to node mapping data structures, + # which are required further below ( for example by the template renderer ) + inputs = [X, W, Bias, *extra_inputs] + names = ["X", "W", "Bias", *extra_names] + ["Y"] + names_str = ",".join(names) + if self.input_reorder is not None: + input_reorder = self.input_reorder + else: + input_reorder = None + kernel_call_signature = kernel.def_kernel( + inputs=inputs, outputs=[Y], names_str=names_str, input_reorder=input_reorder # type: ignore[arg-type] + ) + + # Make op mutable without affecting others + op = copy.deepcopy(op) + if Bias is not None: + assert Bias.get_layout().dtype == X.get_layout().dtype + # This might have been set to void during filtering, when the assumption was still that there's no C + # operand + op.C.element = op.A.element + + argument_template, epilogue_template = self._get_template_args(op) + + epilogue_args = f"{{ElementComputeEpilogue({self.alpha}), ElementComputeEpilogue({self.beta})}}" + + instance_definition, instance_type = self._define_gemm_instance(op) + + options = dict( + alpha=self.alpha, + beta=self.beta, + X=X, + W=W, + Y=Y, + kernel_call_signature=kernel_call_signature, + Bias=Bias, + epilogue_template=epilogue_template, + argument_template=argument_template, + template=self, + kernel=kernel, + instance_definition=instance_definition, + instance_type=instance_type, + input_reorder=self.input_reorder, + epilogue_args=epilogue_args, + test_call_statement="", # TODO (SYCL) : Enable once Standalone runner is implemented + op_conf_name=op.configuration_name(), + ) + options.update(dict(zip(extra_names, extra_inputs))) + + return self._template_from_string(self._get_template()).render(**options) + + +class CUTLASS3xGemmTemplate(CUTLASSGemmTemplate): + def __init__( + self, + input_nodes: list[Buffer], + layout: Layout, + alpha: float, + beta: float, + input_reorder: Optional[list[int]] = None, + 
):
+        super().__init__(input_nodes, layout, alpha, beta, input_reorder)
+
+    @staticmethod
+    def add_cutlass_gemm_choices(
+        choices: list[ChoiceCaller],
+        layout: ir.Layout,
+        input_nodes: list[Buffer],
+        alpha: Union[float, int] = 1,
+        beta: Union[float, int] = 0,
+        input_reorder: Optional[list[int]] = None,
+        **extra_kwargs,
+    ) -> None:
+        template = CUTLASS3xGemmTemplate(
+            input_nodes, layout, alpha, beta, input_reorder
+        )
+        template._add_cutlass_gemm_choices(
+            choices, layout, input_nodes, alpha, beta, input_reorder, **extra_kwargs
+        )
+
+    @staticmethod
+    def _get_supported_ops() -> "list[cutlass_library.gemm_operation.GemmOperation]":
+        import cutlass_library.library as cutlass_lib
+
+        return [cutlass_lib.GemmKind.Universal3x]
+
+    def _get_template(self) -> str:
+        return GEMM_TEMPLATE_CUTLASS_3X
+
+    def _get_template_args(
+        self,
+        op: "cutlass_library.gemm_op.GemmOperation",
+    ) -> tuple[str, Optional[str]]:
+        return (GEMM_ARGS_CUTLASS_3X, GEMM_ARGS_CUTLASS_3X_EPILOGUE)
+
+    def _are_inputs_layout_compatible(self, layouts: list[Layout]) -> bool:
+        """
+        Evaluates whether input layouts are compatible for General Matrix Multiply (GEMM).
+
+        This function checks compatibility of A, B, and possibly C operand layouts for
+        a General Matrix Multiply (GEMM) operation, expressed as 'alpha * matmul(A, B) + beta * C'.
+        It verifies requirements such as matching data types, minimum rank, and suitability
+        for broadcasting, as defined by PyTorch operations like `torch.matmul`, `torch.aten.mm`,
+        `addmm`, `bmm`, `baddbmm`, etc.
+
+        Args:
+            layouts (List[Layout]): List containing 2 or 3 Layout objects representing
+                the input matrices A, B, and possibly C.
+
+        Returns:
+            bool: True if layouts are GEMM compatible, otherwise False.
+        """
+        assert len(layouts) == 2 or len(layouts) == 3
+        # Check if A and B are compatible
+        A_layout, B_layout = layouts[:2]
+        if len(A_layout.size) < 1:
+            return False
+        if len(B_layout.size) < 1:
+            return False
+        A_size = list(V.graph.sizevars.size_hints(A_layout.size))
+        B_size = list(V.graph.sizevars.size_hints(B_layout.size))
+        if len(A_size) < 2:
+            A_size.insert(0, 1)
+        if len(B_size) < 2:
+            B_size.insert(1, 1)
+        # Are batch dims broadcastable?
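+        # Pad the shorter batch shape with leading 1s so batch dims can be
+        # compared positionally in the checks below.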
+        while len(A_size) < len(B_size):
+            A_size.insert(0, 1)
+        while len(B_size) < len(A_size):
+            B_size.insert(0, 1)
+        K = max(A_size[-1], B_size[-2])
+        M = A_size[-2]
+        N = B_size[-1]
+        if K != A_size[-1] and A_size[-1] != 1:
+            return False
+        if K != B_size[-2] and B_size[-2] != 1:
+            return False
+        # check batch dim broadcastable
+        for i in range(len(A_size) - 2):
+            if A_size[i] != B_size[i] and A_size[i] != 1 and B_size[i] != 1:
+                return False
+        if len(layouts) == 3:
+            C_layout = layouts[2]
+            C_size = [int(i) for i in C_layout.size]
+            while len(C_size) < len(A_size):
+                C_size.insert(0, 1)
+            # check batch dims
+            for i in range(len(A_size) - 2):
+                bd = max(A_size[i], B_size[i])
+                if bd != C_size[i] and C_size[i] != 1:
+                    return False
+            if len(C_size) > len(A_size):
+                # This may happen if the last elements of C are contiguous and
+                # their multiplied size equals the last dim size of B
+                if M != C_size[len(A_size) - 2] and C_size[len(A_size) - 2] != 1:
+                    return False
+                remaining_size = 1
+                for i in range(len(A_size) - 1, len(C_size)):
+                    remaining_size *= C_size[i]
+                if N != remaining_size and remaining_size != 1:
+                    return False
+                return True
+            assert len(C_size) == len(A_size)
+            if M != C_size[-2] and C_size[-2] != 1:
+                return False
+            if N != C_size[-1] and C_size[-1] != 1:
+                return False
+        return True
+
+    def _shape_match(
+        self,
+        op: "cutlass_library.gemm_op.GemmOperation",  # type: ignore[name-defined]  # noqa: F821
+    ) -> bool:
+        return True
+
+    def _alignment_match(
+        self,
+        op: "cutlass_library.gemm_op.GemmOperation",  # type: ignore[name-defined]  # noqa: F821
+    ) -> bool:
+        return True
+
+    def _set_bias_layout_and_alignment(
+        self,
+        op: "cutlass_library.gemm_op.GemmOperation",  # type: ignore[name-defined]  # noqa: F821
+    ) -> bool:
+        has_bias = len(self.input_nodes) >= 3 and self.input_nodes[2] is not None
+        if has_bias:
+            bias = self.input_nodes[2]
+            bias_layout = CUTLASSGemmTemplate.cutlass_layout(bias.get_layout())
+            op.C.layout = bias_layout
+            status = self.set_alignment(bias.get_layout(), op.C)
+            if not status:
+                return False
+        return True
+
+    def _dtype_match(
+        self,
+        op: "cutlass_library.gemm_op.GemmOperation",  # type: ignore[name-defined]  # noqa: F821
+    ) -> bool:
+        """
+        Checking dtypes of C (i.e. bias) here, since that is the one not checked in the base class.
+        """
+
+        if not super()._dtype_match(op):
+            return False
+
+        assert cutlass_utils.try_import_cutlass()
+        from cutlass_library.library import DataType  # type: ignore[import]
+
+        has_bias = len(self.input_nodes) >= 3 and self.input_nodes[2] is not None
+
+        if op.C.element == DataType.void:
+            if has_bias:
+                # op expects no bias, but bias exists
+                return False
+        else:
+            # op expects bias. Needs to check if bias exists and is of the right dtype
+            if not (
+                has_bias
+                and cutlass_utils.dtype_match(
+                    self.input_nodes[2].get_dtype(), op.C.element
+                )
+            ):
+                return False
+
+        return True
+
+    def _define_gemm_instance(
+        self,
+        op: "cutlass_library.gemm_op.GemmOperation",  # type: ignore[name-defined]  # noqa: F821
+    ) -> tuple[str, str]:
+        """Defines and renders the Cutlass / SYCL C++ code for a given GEMM operation instance.
+
+        This function uses the Cutlass library to generate key parts of the codegen process. General Matrix Multiply
+        forms a core part of a number of scientific applications, so this efficient and adaptable implementation is
+        crucial.
+
+        Args:
+            op (cutlass_library.gemm_op.GemmOperation): This is the core GEMM operation that we are defining and rendering.
+ + Returns: + Tuple[str, str]: A tuple where the first part is a string that constitutes the defined GEMM operation in C++ + code (render) and the second part is the string that specifies the operation type. + """ + assert cutlass_utils.try_import_cutlass() + import cutlass_library.gemm_operation as cutlass_gemm_op + import cutlass_library.library as cutlass_lib + + emitter = cutlass_gemm_op.EmitGemmUniversal3xInstance() + if not hasattr(op, "epilogue_functor") or not isinstance( + op.epilogue_functor, enum.Enum + ): + op = copy.deepcopy(op) + op.epilogue_functor = cutlass_lib.EpilogueFunctor.LinearCombination + op_def = emitter.emit(op) + pattern = re.compile(r"\s*struct\s(.*?)\s:") + decl = [line for line in op_def.split("\n") if "struct " in line][-1] + + match = pattern.match(decl) + if match is None: + raise RuntimeError("Invalid Gemm config: \n" + op_def) + op_type = match.groups()[0] + if op.gemm_kind == cutlass_lib.GemmKind.Universal3x: + op_def += f"\n using {op_type}_device_type = cutlass::gemm::device::GemmUniversalAdapter<{op_type}>;\n" + op_type = f"{op_type}_device_type" + return op_def, op_type + + def _get_extra_inputs_and_names( + self, + op: "cutlass_gemm_op.GemmOperation" = None, # type: ignore[name-defined] # noqa: F821 + ) -> tuple[Optional[Buffer], list[Optional[Buffer]], list[str]]: + Bias = None if len(self.input_nodes) == 2 else self.input_nodes[2] + inputs: list[Optional[Buffer]] = [] + names: list[str] = [] + return (Bias, inputs, names) + + def render_gemm_arguments( + self, + argument_template: str, + epilogue_template: str, + X: IRNode, + W: IRNode, + Bias: IRNode, + Y: IRNode, + alpha: float, + beta: float, + kernel: SYCLTemplateKernel, + epilogue_args, + ) -> str: + """ + Render the Cutlass SYCL C++ code required for passing arguments to the GEMM operation. + + Args: + argument_template (str): Template for the GEMM operation arguments. + epilogue_template (str): Template for the epilogue arguments. + X (IRNode): The X input tensor. + W (IRNode): The W input tensor. + Bias (IRNode): The bias tensor. + Y (IRNode): The output tensor. + alpha (float): Scaling factor for the product of the inputs. + beta (float): Scaling factor for the output tensor. + kernel (SYCLTemplateKernel): SYCL Template kernel for the operation. + epilogue_args (any): Additional arguments for the epilogue state. + + Returns: + str: A block of SYCL C++ code as a string, ready to be used as arguments for the GEMM operation. + + """ + options = dict( + alpha=alpha, + beta=beta, + X=X, + W=W, + Y=Y, + Bias=Bias, + template=self, + kernel=kernel, + M="M", + N="N", + epilogue_args=epilogue_args, + ) + assert epilogue_template is not None + + epilogue_arguments = self._template_from_string(epilogue_template).render( + **options + ) + arguments = self._template_from_string(argument_template).render( + epilogue_arguments=epilogue_arguments, **options + ) + + return arguments diff --git a/torch/_inductor/codegen/xpu/sycl_cpp_scheduling.py b/torch/_inductor/codegen/xpu/sycl_cpp_scheduling.py new file mode 100644 index 0000000000000..fd100347a7682 --- /dev/null +++ b/torch/_inductor/codegen/xpu/sycl_cpp_scheduling.py @@ -0,0 +1,116 @@ +# mypy: allow-untyped-defs +import logging +from collections.abc import Sequence +from typing import cast + +from torch.utils._ordered_set import OrderedSet + +from ...._dynamo.utils import counters +from ... 
import config +from ...codecache import code_hash, get_path +from ...ir import SYCLTemplateBuffer +from ...scheduler import BaseSchedulerNode, BaseScheduling, SchedulerNode +from ...utils import get_fused_kernel_name, get_kernel_metadata, sympy_product +from ...virtualized import V +from ..common import BackendFeature, IndentedBuffer + + +log = logging.getLogger(__name__) + + +class SYCLCPPScheduling(BaseScheduling): + """ + Partial Scheduling implementation for SYCL C++ Kernels. + This class is intended to be used in combination with TritonScheduling, + and delegated to by SYCLCombinedScheduling. + + It handles fusion decisions and SYCL C++ specific template code generation. + """ + + @classmethod + def get_backend_features(cls, device) -> OrderedSet[BackendFeature]: + return OrderedSet() + + def group_fn(self, sizes): + return tuple(V.graph.sizevars.simplify(sympy_product(s)) for s in sizes) + + @staticmethod + def is_sycl_cpp_template(node: BaseSchedulerNode) -> bool: + return isinstance(node, SchedulerNode) and isinstance( + node.node, SYCLTemplateBuffer + ) + + def can_fuse_vertical( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> bool: + return False + + def define_kernel(self, src_code: str, node_schedule) -> str: + wrapper = V.graph.wrapper_code + if src_code in wrapper.src_to_kernel: + kernel_name = wrapper.src_to_kernel[src_code] + else: + fused_name = ( + get_fused_kernel_name(node_schedule, config.triton.descriptive_names) + if config.triton.descriptive_names + else "" + ) + kernel_name = "_".join(["sycl", fused_name, wrapper.next_kernel_suffix()]) + # use the original src_code as the key + wrapper.src_to_kernel[src_code] = kernel_name + src_code = src_code.replace("KERNEL_NAME", kernel_name) + + _, _, kernel_path = get_path(code_hash(src_code), "py") + + compile_wrapper = IndentedBuffer() + compile_wrapper.writeline("async_compile.sycl(r'''") + compile_wrapper.splice(src_code, strip=True) + compile_wrapper.writeline( + f"''', 'so', aot_compile={str(V.graph.aot_mode)})" + ) + + metadata_comment = f"# kernel path: {kernel_path}" + origins, detailed_origins = get_kernel_metadata(node_schedule, wrapper) + metadata_comment += "\n" + origins + "\n" + detailed_origins + wrapper.define_kernel( + kernel_name, compile_wrapper.getvalue(), metadata_comment + ) + return kernel_name + + def codegen_template( + self, + template_node: BaseSchedulerNode, + epilogue_nodes: Sequence[BaseSchedulerNode], + prologue_nodes: Sequence[BaseSchedulerNode], + ): + """ + Codegen a SYCL template, possibly with fused epilogues + """ + counters["inductor"]["sycl_epilogue_fusion_counter"] += len(epilogue_nodes) + assert self.is_sycl_cpp_template( + template_node + ), "Template node passed to SYCLScheduler.codegen_template must be a SchedulerNode that wraps a SYCLTemplateBuffer" + template_node = cast(SchedulerNode, template_node) + _, (_numel, rnumel) = template_node.group + assert rnumel == 1 + ctb: SYCLTemplateBuffer = cast(SYCLTemplateBuffer, template_node.node) + kernel, render = ctb.make_kernel_render(ctb) + with kernel: + template_node.mark_run() + src_code = render() + + with V.set_kernel_handler(kernel): + node_schedule = [template_node] + kernel_name = self.define_kernel(src_code, node_schedule) + + # debug printing values of intermediate tensors + _, call_args, arg_signatures, _ = kernel.args.python_argdefs() + debug_printer_manager = V.graph.wrapper_code.debug_printer + debug_printer_manager.set_printer_args( + call_args, kernel_name, arg_signatures, kernel + ) + with 
debug_printer_manager: + kernel.call_kernel(kernel_name, ctb) + + V.graph.removed_buffers |= kernel.removed_buffers + self.free_buffers_in_scheduler() diff --git a/torch/_inductor/codegen/xpu/sycl_kernel.py b/torch/_inductor/codegen/xpu/sycl_kernel.py new file mode 100644 index 0000000000000..e17f1835e173b --- /dev/null +++ b/torch/_inductor/codegen/xpu/sycl_kernel.py @@ -0,0 +1,517 @@ +# mypy: allow-untyped-defs +import logging +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union + +from sympy import Expr, symbols + +from torch import dtype as torch_dtype + + +if TYPE_CHECKING: + from .sycl_template import ArgInfo + +from ...autotune_process import SYCLBenchmarkRequest +from ...ir import ( + Buffer, + ChoiceCaller, + IRNode, + Layout, + PrimitiveInfoType, + SYCLTemplateBuffer, + TensorBox, +) +from ...utils import sympy_product +from ...virtualized import V +from ..common import ( + IndentedBuffer, + Kernel, + OpOverrides, + WorkspaceArg, + WorkspaceZeroMode, +) +from ..cpp_utils import DTYPE_TO_CPP, CppPrinter + + +if TYPE_CHECKING: + from torch._inductor.codegen.xpu.sycl_template import SYCLTemplate + +log = logging.getLogger(__name__) + +cexpr = CppPrinter().doprint + + +def _normalize_idx(index: int, total_length: int) -> int: + return index if index >= 0 else index + total_length + + +ValidLayoutSymbols = Literal["M", "N", "K", "lda", "ldb", "ldc", "ldd"] +ValidLayoutAttrs = Literal["size", "stride"] + + +@dataclass(frozen=True) +class LayoutArg: + node: IRNode + symbol: ValidLayoutSymbols + attr: ValidLayoutAttrs + dim: int + + def matches(self, node, attr, dim) -> bool: + return self.node == node and self.attr == attr and self.dim == dim + + +class SYCLKernel(Kernel): + """ + Baseclass for SYCL / Cutlass based Kernels + """ + + overrides = OpOverrides # type: ignore[assignment] + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.layout_args: dict[str, LayoutArg] = {} + # Mapping from arg name to IRNode. 
+ self.named_nodes: dict[str, IRNode] = {} + + def find_symbol( + self, node: IRNode, attr: ValidLayoutAttrs, dim: int + ) -> Optional[str]: + arg = self.find_layout_arg(node, attr, dim) + return arg.symbol if arg else None + + def find_layout_arg( + self, node: IRNode, attr: ValidLayoutAttrs, dim: int + ) -> Optional[LayoutArg]: + matches = [ + arg for arg in self.layout_args.values() if arg.matches(node, attr, dim) + ] + assert len(matches) <= 1, matches + return None if len(matches) == 0 else matches[0] + + def add_layout_arg( + self, symbol: ValidLayoutSymbols, node: IRNode, attr: ValidLayoutAttrs, dim: int + ): + arg = LayoutArg(node, symbol, attr, dim) + self.layout_args.setdefault(symbol, arg) + + def init_layout_args(self) -> None: + X = self.named_nodes["X"] + W = self.named_nodes["W"] + Y = self.named_nodes["Y"] + Bias = self.named_nodes.get("Bias", None) + mdim = _normalize_idx(-2, len(X.get_size())) + ndim = _normalize_idx(-1, len(W.get_size())) + kdim = _normalize_idx(-1, len(X.get_size())) + self.add_layout_arg("M", X, "size", mdim) + self.add_layout_arg("N", W, "size", ndim) + self.add_layout_arg("K", X, "size", kdim) + + lda_dim = self.find_ld_idx(X) + ldb_dim = self.find_ld_idx(W) + ldc_dim = self.find_ld_idx(Bias) if Bias else None + ldd_dim = self.find_ld_idx(Y) + self.add_layout_arg("lda", X, "stride", lda_dim) + self.add_layout_arg("ldb", W, "stride", ldb_dim) + if Bias is not None and ldc_dim is not None: + self.add_layout_arg("ldc", Bias, "stride", ldc_dim) + self.add_layout_arg("ldd", Y, "stride", ldd_dim) + + def get_layout_args(self) -> tuple[Union[Expr, int], ...]: + X = self.named_nodes["X"] + W = self.named_nodes["W"] + Y = self.named_nodes["Y"] + Bias = self.named_nodes.get("Bias", None) + mdim = _normalize_idx(-2, len(X.get_size())) + ndim = _normalize_idx(-1, len(W.get_size())) + kdim = _normalize_idx(-1, len(X.get_size())) + + def get_ld(node) -> Union[Expr, int]: + dim = self.find_ld_idx(node) + return node.get_stride()[dim] + + M = X.get_size()[mdim] + N = W.get_size()[ndim] + K = X.get_size()[kdim] + LDA = get_ld(X) + LDB = get_ld(W) + LDC = get_ld(Bias) if Bias else 0 + LDD = get_ld(Y) + return (M, N, K, LDA, LDB, LDC, LDD) + + @staticmethod + def find_ld_idx(node: IRNode) -> int: + strides = node.get_stride() + # Handle 1D tensor case + if V.graph.sizevars.statically_known_equals(strides[-1], 1): + return _normalize_idx(-2, len(strides)) + + assert V.graph.sizevars.statically_known_equals(strides[-2], 1), strides[-2] + return _normalize_idx(-1, len(strides)) + + +class SYCLTemplateKernel(SYCLKernel): + """ + Template kernels defined by SYCL / Cutlass in C++. + """ + + # TODO (SYCL): The SYCL queue is not being used + _EXTRA_CPP_ARGS = "size_t* workspace_size, uint8_t* workspace, sycl::queue stream" + + def __init__( + self, + kernel_name: str, + runtime_arg_info: list["ArgInfo"], + runtime_arg_values: list[Any], + ) -> None: + """ + Initializes a new instance of the SYCLTemplateKernel class. + + Args: + kernel_name (str): The name of the kernel. + """ + super().__init__() + self.kernel_name = kernel_name + self.runtime_arg_info = runtime_arg_info + self.runtime_arg_values = runtime_arg_values + + def check_not_null(self, node: IRNode) -> str: + """ + Generates code to check that a node is not null. 
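+        Returns an empty string if the node is None or has no bound argument name.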
+ """ + if node is None: + return "" + + size_str = self.size(node, 0, -1) + name_str = self.arg_name(node) + if name_str is None: + return "" + + res = IndentedBuffer(initial_indent=2) + res.tabwidth = 1 + res.splice( + f""" + {{ + if (!{name_str}) {{ + int64_t {name_str}_size = {size_str}; + if ({name_str}_size > 0) {{ + throw std::runtime_error("input {name_str} is null but size is not 0!"); + }} + }} + }} + """ + ) + return res.getvalue() + + def get_signature(self) -> str: + return self.signature + + def def_kernel( + self, + inputs: list[IRNode], + outputs: list[IRNode], + names_str: str = "", + input_reorder: Optional[list[int]] = None, + ) -> str: + """ + Hook called from template code to generate function definition and + needed args. + + Args: + inputs: List of input IRNodes + outputs: List of output IRNodes + names_str: Comma separated list of input + output argument names. + input_reorder: The actual order of input nodes. + e.g. The template might have input argument defined as [X, W, Bias], + and the actual input passed into this template could be [Bias, X, W]. + In this case, the `input_reorder` would be [2, 0, 1]. + """ + names = [x.strip() for x in names_str.strip().split(",")] + if len(inputs) + len(outputs) != len(names): + raise RuntimeError( + f"{len(inputs) + len(outputs)=} != {len(names)=}, {inputs=}, {outputs=}, {names=}" + ) + + if input_reorder is not None: + assert len(inputs) == len(input_reorder) + else: + input_reorder = list(range(len(inputs))) + + for idx in input_reorder: + name = names[idx] + node = inputs[idx] + if node is not None: + self.named_nodes[name] = node + self.args.input_buffers[node.get_name()] = name + + for name, node in zip(names[len(inputs) : len(inputs) + len(outputs)], outputs): + if node is not None: + self.named_nodes[name] = node + self.args.output_buffers[node.get_name()] = name + + arg_defs, *_ = self.args.cpp_argdefs() + + self.init_layout_args() + size_args = [ + f"const int {s}" for s in ("M", "N", "K", "lda", "ldb", "ldc", "ldd") + ] + + runtime_arg_decls = ",".join( + [f"{arg.ty} {arg.name}" for arg in self.runtime_arg_info] + ) + if runtime_arg_decls: + runtime_arg_decls += ", " + + signature = f"int {self.kernel_name}({', '.join(arg_defs + size_args)}, {runtime_arg_decls}{self._EXTRA_CPP_ARGS})" + self.signature = signature + return signature + + def call_kernel( + self, + name: str, + node: "SYCLTemplateBuffer", # type: ignore[name-defined] + ) -> None: + """ + Generates code to call the kernel through V.graph.wrapper_code. + used from within torch._inductor.wrapper.PythonWrapperCodegen + + name: Name of kernel function. + node: The SYCLTemplateBuffer node which contains information about the kernel, it's fused epilogue nodes + as well as all required inputs and outputs. 
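+
+        A workspace buffer is allocated (and freed afterwards) when the node
+        reports a non-zero workspace size.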
+ """ + wrapper = V.graph.wrapper_code + + # TODO (SYCL) : Extend support to cpp_wrapper once AOTInductor is supported + + arg_types: list[Any] + _, call_args, _, arg_types = self.args.python_argdefs() + + layout_args = self.get_layout_args() + call_args.extend(layout_args) # type: ignore[arg-type] + for arg in self.runtime_arg_values: + call_args.append(arg) + arg_types.extend("int" for a in layout_args) + for arg in self.runtime_arg_info: + arg_types.append(arg.ty) + # dynamo wraps unspec variable as 0d CPU tensor, need convert to scalar + for i in range(len(call_args)): + if V.graph.is_unspec_arg(call_args[i]): + call_args[i] = call_args[i] + ".item()" + elif isinstance(arg_types[i], torch_dtype): + call_args[i] = f"c_void_p({call_args[i]}.data_ptr())" + + # workspace_size ptr is NULL to mark this call is not intended for retrieving workspace_size. + # workspace_size should have already been retrieved prior to this call. + # workspace_size is here. + call_args.append("None") + + if node.get_workspace_size() > 0: + ws = WorkspaceArg( + count=node.get_workspace_size(), + device=V.graph.get_current_device_or_throw(), + zero_mode=WorkspaceZeroMode.UNINITIALIZED, + outer_name=WorkspaceArg.unique_name(), + ) + wrapper.generate_workspace_allocation(ws) + workspace = str(ws.outer_name) + call_args.append(f"c_void_p({workspace}.data_ptr())") + else: + ws = None + call_args.append("None") + + wrapper.generate_kernel_call( + name, + call_args, + gpu=True, + triton=False, + arg_types=arg_types, + ) + if ws: + wrapper.generate_workspace_deallocation(ws) + + def dtype(self, node: IRNode) -> Optional[str]: + """ + Generates code which represents dtype of a given node. + """ + + if node is None: + return "void" + return DTYPE_TO_CPP.get(node.get_layout().dtype) + + def cutlass_dtype(self, node: IRNode, default_dtype="void") -> Optional[str]: + # Helper method, called into from CUTLASSGemmTemplate + if node is None: + return default_dtype + from torch._inductor.codegen.xpu.sycl_template import CUTLASSTemplate + + return CUTLASSTemplate._DTYPE_TO_CUTLASS[node.get_layout().dtype] + + def max_valid_index(self, node: IRNode, default=-1): + # Helper method, called into from CUTLASSGemmTemplate + if node is None: + return default + max_valid_offset = 0 + for i in range(len(node.get_size())): + max_valid_offset += (node.get_size()[i] - 1) * node.get_stride()[i] + return max_valid_offset + + def offset(self, node: IRNode) -> str: + """ + Generates code which represents offset of a given node. + """ + + if node is None: + return "0" + return str(node.get_layout().offset) # type: ignore[union-attr] + + def ptr(self, node: IRNode) -> str: + """ + Generates code which represents pointer of a given node. + """ + + if node is None: + return "nullptr" + arg_name = self.arg_name(node) + if arg_name is None: + return "nullptr" + offset = self.offset(node) + return arg_name if offset == "0" else f"{arg_name} + {offset}" + + def size( + self, + node: IRNode, + start_index: int, + end_index: Optional[int] = None, + default_value: int = 0, + ) -> str: + """ + Hook called from template code to get the size of an arg. + Generates code which represents size of a given node in [start_index, end_index). + If node is None, returns default_value. + + TODO: Will add needed args to pass it in if it is dynamic. 
+ """ + + if node is None: + return str(default_value) + + start_index = _normalize_idx(start_index, len(node.get_size())) + if end_index is None: + end_index = start_index + end_index = _normalize_idx(end_index, len(node.get_size())) + sizes = [ + self.find_symbol(node, "size", dim=i) or node.get_size()[i] + for i in range(start_index, end_index + 1) + ] + if len(sizes) == 0: + return str(default_value) + + sizes = [symbols(v) if isinstance(v, str) else v for v in sizes] + val = sympy_product(sizes) + return val + + def stride(self, node: IRNode, index: int, default_value: int = 0) -> str: + """ + Hook called from template code to get the stride of an arg. + Generates code which represents stride of a given node at index. + If node is None, returns default_value. + + TODO: Will add needed args to pass it in if it is dynamic. + """ + + if node is None: + return str(default_value) + + index = _normalize_idx(index, len(node.get_size())) + if index < 0: + return str(default_value) + + stride = node.get_stride()[index] + if V.graph.sizevars.statically_known_leq(stride, 1): + return str(stride) + return self.find_symbol(node, "stride", dim=index) or str(stride) + + +class SYCLTemplateCaller(ChoiceCaller): + """ + SYCLTemplateCaller + + This class represents a caller for SYCL template kernels. It is a subclass of ChoiceCaller. + Attributes: + name (str): The name of the caller. + category (str): The category of the caller. + bmreq (SYCLBenchmarkRequest): The benchmark request for the caller (currently incomplete). + template_buffer (SYCLTemplateBuffer): The template buffer for the caller. + """ + + def __init__( + self, + name: str, + category: str, + input_nodes: list[Buffer], + layout: Layout, + make_kernel_render: Callable[[SYCLTemplateBuffer, Optional[list[IRNode]]], str], + bmreq: SYCLBenchmarkRequest, + template: "SYCLTemplate", # type: ignore[name-defined] + info_kwargs: Optional[dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType]]]], # type: ignore[type-arg] + description: str, + ) -> None: + super().__init__(name, input_nodes, layout, description) + self.category = category + self.make_kernel_render = make_kernel_render + self.bmreq = bmreq + self.template = template + self.info_kwargs = info_kwargs + + def precompile(self) -> None: + assert self.bmreq is not None + self.bmreq.precompile() + + def benchmark(self, *args, out) -> float: + # TODO (SYCL) : Enable benchmarking once supported + return 0.001 + + def __str__(self) -> str: + return f"SYCLTemplateCaller(source_file={self.bmreq.source_file})" + + def call_name(self) -> str: + return f"sycl_template_kernels.{self.name}" + + def hash_key(self) -> str: + return "-".join( + [ + self.category, + self.bmreq.hash_key, + ] + ) + + def info_dict(self) -> dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType]]]: + """Information returned here is logged to the autotune log file when that is enabled.""" + if self.info_kwargs is not None and "op" in self.info_kwargs: + op: Any = self.info_kwargs["op"] + return { + "backend": "XPU", + "op_type": type(op).__name__, + "op_conf_name": str(op.configuration_name()), + "op_arch": str(op.arch), + "tile_shape": str(op.tile_description.tile_shape), + "epilogue_schedule": str(op.epilogue_schedule), + "kernel_schedule": str(op.kernel_schedule), + "element_accumulator": str(op.accumulator_type()), + "op_name": str(op.procedural_name()), + "instruction_shape": str( + op.tile_description.math_instruction.instruction_shape + ), + } + else: + return {"backend": "XPU", "op_type": "unknown"} + + def 
output_node(self) -> TensorBox: + return TensorBox.create( + SYCLTemplateBuffer( + layout=self.layout, + inputs=self.input_nodes, + make_kernel_render=self.make_kernel_render, + workspace_size=self.bmreq.workspace_size, + template=self.template, + ) + ) diff --git a/torch/_inductor/codegen/xpu/sycl_template.py b/torch/_inductor/codegen/xpu/sycl_template.py new file mode 100644 index 0000000000000..16cd98ff633d6 --- /dev/null +++ b/torch/_inductor/codegen/xpu/sycl_template.py @@ -0,0 +1,274 @@ +# mypy: allow-untyped-defs +import functools +import itertools +from dataclasses import dataclass +from typing import Any, Optional +from typing_extensions import override +from unittest.mock import patch + +import sympy + +import torch +from torch._logging import getArtifactLogger + +from ...autotune_process import SYCLBenchmarkRequest, TensorMeta +from ...ir import Buffer, IRNode, Layout, SYCLTemplateBuffer +from ...utils import IndentedBuffer, unique +from ...virtualized import V +from ..common import KernelTemplate +from .sycl_kernel import SYCLTemplateCaller, SYCLTemplateKernel + + +autotuning_log = getArtifactLogger(__name__, "autotuning") + + +@dataclass(frozen=True) +class ArgInfo: + name: str + ty: str + + +class SYCLTemplate(KernelTemplate): + index_counter = itertools.count() + + def __init__( + self, + name: str, + input_nodes: list[Buffer], + layout: Layout, + input_reorder: Optional[list[int]] = None, + ) -> None: + """ + + Baseclass for SYCL C++ Templates, derived from KernelTemplate. Not to be instantiated directly. + + Args: + name (str): The name of the SYCLTemplate object. + input_nodes (List[IRNode]): A list of input IRNodes. + layout (Layout): The layout of the output buffer / tensor. + input_reorder (Optional[List[int]]): An optional list that specifies the order of the input nodes. + + """ + super().__init__(name) + self.input_nodes = input_nodes + self.output_node: Buffer = Buffer(name="buf_out", layout=layout) + self.input_reorder = input_reorder + self.layout = layout + + def generate( # type: ignore[override] + self, + description, + **kwargs, + ) -> SYCLTemplateCaller: + """ + Generates the SYCL template caller object for the given GEMM template and operation. This SYCLTemplateCaller + may be used to call and benchmark the generated SYCL kernel in a standalone manner to enable Autotuning. + + Args: + kwargs: Additional keyword arguments. + + Returns: + A SYCLTemplateCaller object representing the generated SYCL template caller. 
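+
+        A minimal usage sketch (assuming `op` is a cutlass_library GemmOperation that
+        already passed filtering):
+
+            caller = template.generate(description="...", op=op, swizzle=1)
+            choices.append(caller)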
+ """ + kernel_name = f"xpu_{self.name}" + with patch.object( + V.graph, "get_dtype", self._fake_get_dtype(self.output_node) + ), SYCLTemplateKernel( + kernel_name=kernel_name, + runtime_arg_info=self.get_runtime_arg_info(), + runtime_arg_values=self.get_runtime_arg_values(**kwargs), + ) as kernel: + code = self.render(kernel=kernel, **kwargs) + _, call_args, _, _ = kernel.args.python_argdefs() + autotuning_log.debug("Generated Code:\n%s", code) + autotuning_log.debug( + "Args: cpp_argdefs: %s, python_argdefs: %s", + kernel.args.cpp_argdefs(), + kernel.args.python_argdefs(), + ) + + input_reorder = ( + self.input_reorder + if self.input_reorder is not None + else list(range(len(self.input_nodes))) + ) + expected_args = list( + unique(self.input_nodes[idx].get_name() for idx in input_reorder) + ) + expected_args.extend([self.output_node.get_name()]) + assert list(call_args)[: len(expected_args)] == expected_args, ( + call_args, + expected_args, + ) + V.graph.sizevars.size_hints(map(sympy.expand, call_args[len(expected_args) :])) + size_args = V.graph.sizevars.size_hints(kernel.get_layout_args()) + extra_args = tuple(list(size_args) + self.get_runtime_arg_values(**kwargs)) + + kernel_hash_name = f"xpu_{self.name}_{next(self.index_counter)}" + + # create the BenchmarkRequest + bmreq = SYCLBenchmarkRequest( + kernel_name=kernel_name, + input_tensor_meta=TensorMeta.from_irnodes(self.input_nodes), + output_tensor_meta=TensorMeta.from_irnodes(self.output_node), + extra_args=extra_args, + source_code=code, + ) + + def make_kernel_render( + template_node: SYCLTemplateBuffer, + epilogue_nodes: Optional[list[IRNode]] = None, + ): + kernel = SYCLTemplateKernel( + kernel_name="KERNEL_NAME", + runtime_arg_info=self.get_runtime_arg_info(), + runtime_arg_values=self.get_runtime_arg_values(**kwargs), + ) + render = functools.partial( + self.render, + kernel=kernel, + template_buffer_node=template_node, + epilogue_nodes=epilogue_nodes, + **kwargs, # includes "op" argument in case of CUTLASSGemmTemplate + ) + return kernel, render + + return SYCLTemplateCaller( + kernel_hash_name, + self.name, + self.input_nodes, + self.output_node.get_layout(), + make_kernel_render, + bmreq, + self, + kwargs, + description, + ) + + def header(self) -> IndentedBuffer: + res = IndentedBuffer() + res.splice( + """ + #include + #include + #include + #include + #include + """ + ) + return res + + def globals(self) -> IndentedBuffer: + res = IndentedBuffer() + res.splice( + """ + // We compile all models with -fvisibility=hidden. Any symbols that need to be + // exposed in the final shared library must be declared with PT_EXPORT to make + // them visible. + #ifdef __GNUC__ // Applies to any compiler with GNU extensions (clang and g++) + #define PT_EXPORT __attribute__((__visibility__("default"))) + #else + #ifdef _WIN32 + #define PT_EXPORT __declspec(dllexport) + #else + #define PT_EXPORT + #endif + #endif + """ + ) + return res + + def render(self, **kwargs) -> str: + raise NotImplementedError + + def get_runtime_arg_info(self) -> list[ArgInfo]: + return [] + + def get_runtime_arg_values(self, **kwargs) -> list[Any]: + return [] + + +class CUTLASSTemplate(SYCLTemplate): + """ + CUTLASSTemplate is a class that provides a template for generating CUTLASS Templates. Used as a baseclass for the + CUTLASSGemmTemplate, providing functionality that might also be relevant for non-GEMM CUTLASS Kernels. 
+ """ + + def header(self) -> IndentedBuffer: + res = super().header() + res.splice( + """ + #include "cute/tensor.hpp" + #include "cutlass/cutlass.h" + #include "cutlass/numeric_types.h" + #include "cutlass/tensor_ref.h" + #include "cutlass/util/host_tensor.h" + #include "cutlass/util/reference/host/tensor_fill.h" + #include "cutlass/util/device_memory.h" + """ + ) + return res + + def globals(self) -> IndentedBuffer: + res = super().globals() + res.splice( + """ + using namespace cute; + #define CUTLASS_CHECK(status) \\ + { \\ + cutlass::Status error = status; \\ + if (error != cutlass::Status::kSuccess) { \\ + auto msg = std::string("[") + __FILE__ + "] Got cutlass error: " + \\ + cutlassGetStatusString(error) + " at: " + std::to_string(__LINE__); \\ + throw std::runtime_error(msg); \\ + } \\ + } + + // Used as pass-through functor in EVT just for type casting / rounding + template + struct identity_op { + CUTLASS_HOST_DEVICE + T operator()(T val) const { return val; } + }; + + """ + ) + return res + + def cute_int(self, int_str: str, var_name: str) -> str: + res = "" + if int_str in ("1", "1L"): + res = "cute::Int<1>{}" + else: + res = int_str + + return f"{res} /* {var_name} */" + + _DTYPE_TO_CUTLASS = { + torch.float32: "float", + torch.float64: "double", + torch.float16: "cutlass::half_t", + torch.int32: "int32_t", + torch.int16: "int16_t", + torch.int8: "int8_t", + torch.uint8: "uint8_t", + torch.bool: "bool", + torch.bfloat16: "cutlass::bfloat16_t", + } + + def cutlass_type_cast(self, node: IRNode, ptr: str) -> str: + if node is None: + return ptr + else: + return f"({self._DTYPE_TO_CUTLASS.get(node.get_dtype())}*)({ptr})" + + @override + def get_runtime_arg_info(self) -> list[ArgInfo]: + return [ArgInfo("swizzle", "const uint8_t")] + + @override + def get_runtime_arg_values(self, **kwargs) -> list[Any]: + """ + Helper method to retrieve runtime args from generate kwargs + """ + return [kwargs[arg.name] for arg in self.get_runtime_arg_info()] diff --git a/torch/_inductor/codegen/xpu_combined_scheduling.py b/torch/_inductor/codegen/xpu_combined_scheduling.py new file mode 100644 index 0000000000000..ece79f312b4c4 --- /dev/null +++ b/torch/_inductor/codegen/xpu_combined_scheduling.py @@ -0,0 +1,127 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +from typing import Any, Optional, TYPE_CHECKING, Union + +from ..scheduler import ( + BaseSchedulerNode, + BaseScheduling, + FusedSchedulerNode, + Scheduler, + SchedulerNode, +) +from .xpu.sycl_cpp_scheduling import SYCLCPPScheduling +from .triton import TritonScheduling + + +if TYPE_CHECKING: + from collections.abc import Sequence + from typing_extensions import TypeAlias + + from sympy import Expr + + import torch + from torch.utils._ordered_set import OrderedSet + + from .common import BackendFeature + + _IntLike: TypeAlias = Union[int, Expr] + + +class SYCLCombinedScheduling(BaseScheduling): + """ + Scheduler for SYCL Kernels, which delegates calls as appropriate + to the SYCL-C++ (CUTLASS) and Triton Schedulers, which both work for XPU devices + and use a unified-wrapper for codegen. 
+ """ + + def __init__(self, scheduler: Optional[Scheduler]) -> None: + super().__init__(scheduler) + self._triton_scheduling = TritonScheduling(scheduler) + self._sycl_cpp_scheduling = SYCLCPPScheduling(scheduler) + + @classmethod + def get_backend_features(cls, device: torch.device) -> OrderedSet[BackendFeature]: + return TritonScheduling.get_backend_features(device) + + @classmethod + def raise_if_unavailable( + cls, device: Union[str, torch.device, None] = None + ) -> None: + TritonScheduling.raise_if_unavailable(device) + + def choose_node_backend(self, node: BaseSchedulerNode) -> BaseScheduling: + if self._sycl_cpp_scheduling.is_sycl_cpp_template(node): + return self._sycl_cpp_scheduling + return self._triton_scheduling + + def can_fuse_vertical( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> bool: + if self._sycl_cpp_scheduling.can_fuse_vertical(node1, node2): + return True + return self._triton_scheduling.can_fuse_vertical(node1, node2) + + def can_fuse_horizontal( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> bool: + for node in (node1, node2): + if self._sycl_cpp_scheduling.is_sycl_cpp_template(node): + return self._sycl_cpp_scheduling.can_fuse_horizontal( + node1, node2 + ) # always False at the moment + return self._triton_scheduling.can_fuse_horizontal(node1, node2) + + def group_fn( + self, sizes: Sequence[Sequence[_IntLike]] + ) -> tuple[tuple[_IntLike, ...], ...]: + return self._triton_scheduling.group_fn(sizes) + + def codegen_template( + self, + template_node: BaseSchedulerNode, + epilogue_nodes: Sequence[BaseSchedulerNode], + prologue_nodes: Sequence[BaseSchedulerNode], + ) -> Optional[str]: + if self._sycl_cpp_scheduling.is_sycl_cpp_template(template_node): + assert not epilogue_nodes + assert not prologue_nodes + return self._sycl_cpp_scheduling.codegen_template( + template_node, epilogue_nodes, prologue_nodes + ) + else: + return self._triton_scheduling.codegen_template( + template_node, epilogue_nodes, prologue_nodes + ) + + def codegen_node(self, node: Union[FusedSchedulerNode, SchedulerNode]) -> None: + return self._triton_scheduling.codegen_node(node) + + def codegen_sync(self) -> None: + return self._triton_scheduling.codegen_sync() + + def flush(self) -> None: + return self._triton_scheduling.flush() + + def codegen_combo_kernel(self, *args: Any, **kwargs: Any) -> None: + return self._triton_scheduling.codegen_combo_kernel(*args, **kwargs) + + def benchmark_fused_nodes( + self, nodes: Sequence[BaseSchedulerNode] + ) -> tuple[float, str]: + return self._triton_scheduling.benchmark_fused_nodes(nodes) + + def benchmark_codegened_module(self, module): + return self._triton_scheduling.benchmark_codegened_module(module) + + def generate_kernel_code_from_nodes( + self, nodes: Sequence[Any], benchmark_kernel: bool = False + ) -> str: + return self._triton_scheduling.generate_kernel_code_from_nodes( + nodes, benchmark_kernel + ) + + def benchmark_combo_kernel( + self, node_list: Sequence[BaseSchedulerNode] + ) -> tuple[float, float, list[Optional[str]]]: + return self._triton_scheduling.benchmark_combo_kernel(node_list) diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index dac8fc186c85c..d501ca39c6e1e 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -1417,6 +1417,20 @@ class sycl: # Whether to use fast math. use_fast_math = False + + # Minimum value of M*N*K to consider the CUTLASS backend for GEMM ops. 
+ cutlass_backend_min_gemm_size: int = 1 + + # Configures the maximum number of CUTLASS configs to profile in max_autotune. + # By default it's None, so that all CUTLASS configs are tuned. + cutlass_max_profiling_configs: Optional[int] = None + + # The L2 swizzle values to consider when profiling CUTLASS configs in max_autotune. + cutlass_max_profiling_swizzle_options: list[int] = [1, 2, 4] + + # TODO (SYCL) : Enable the standalone GEMM runner for testing later + generate_test_runner: bool = False + # Backend to use for CPU codegen either "cpp" or "triton" (experimental) or "halide" (experimental) diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 8d9bb1b283cdc..e3886c4751196 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -101,11 +101,14 @@ from torch.fx.node import Node from .codegen.cuda.cuda_template import CUDATemplate + from .codegen.xpu.sycl_template import SYCLTemplate + from .graph import GraphLowering from .utils import IndentedBuffer else: CUDATemplate: TypeAlias = object + SYCLTemplate: TypeAlias = object try: @@ -4535,7 +4538,7 @@ class ChoiceCaller: During autotuning, self.benchmark() is first called to get benchmark result, and if this choice is selected, self.output_node() is called to get the output_node. - Children classes: TritonTemplateCaller, CUDATemplateCaller. + Children classes: TritonTemplateCaller, CUDATemplateCaller, SYCLTemplateCaller. """ def __init__( @@ -4670,6 +4673,22 @@ def __init__( # type: ignore[no-untyped-def] def get_workspace_size(self): # type: ignore[no-untyped-def] return self.workspace_size if self.workspace_size is not None else 0 +class SYCLTemplateBuffer(TemplateBuffer): + def __init__( # type: ignore[no-untyped-def] + self, + layout, + inputs, + make_kernel_render, + workspace_size: int, + template: SYCLTemplate, + ) -> None: + super().__init__(layout, inputs, make_kernel_render) + # Global memory (in bytes) needed for this template. + self.workspace_size = workspace_size + self.template = template + + def get_workspace_size(self): # type: ignore[no-untyped-def] + return self.workspace_size if self.workspace_size is not None else 0 class CppTemplateBuffer(TemplateBuffer): def __init__(self, layout, inputs, make_kernel_render, template, choice) -> None: # type: ignore[no-untyped-def] diff --git a/torch/_inductor/kernel/mm.py b/torch/_inductor/kernel/mm.py index 2bd8cbe4dbb3a..e5caf6ea52e35 100644 --- a/torch/_inductor/kernel/mm.py +++ b/torch/_inductor/kernel/mm.py @@ -19,6 +19,7 @@ from .. 
import config as inductor_config, ir from ..codegen.cuda.gemm_template import CUTLASS2xGemmTemplate, CUTLASS3xGemmTemplate from ..codegen.rocm.ck_universal_gemm_template import CKGemmTemplate +from ..codegen.xpu.gemm_template import CUTLASS3xGemmTemplate as SYCLCUTLASS3xGemmTemplate from ..codegen.wrapper import PythonWrapperCodegen from ..ir import FlexibleLayout, is_triton from ..lowering import register_lowering @@ -34,6 +35,7 @@ use_ck_gemm_template, use_cpp_gemm_template, use_cutlass_template, + use_cutlass_sycl_template, use_max_autotune, use_triton_template, use_triton_tma_template, @@ -419,6 +421,9 @@ def tuned_mm(mat1, mat2, *, layout=None): if is_nonzero and use_cutlass_template(layout, m, n, k): CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(choices, layout, [mat1, mat2]) + if is_nonzero and use_cutlass_sycl_template(layout, m, n, k): + SYCLCUTLASS3xGemmTemplate.add_cutlass_gemm_choices(choices, layout, [mat1, mat2]) + if is_nonzero and use_ck_gemm_template(layout, m, n, k): CKGemmTemplate.add_ck_gemm_choices(choices, layout, [mat1, mat2]) diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index 9266588833257..64c014342292b 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -4300,9 +4300,11 @@ def _codegen(self, nodes: list[BaseSchedulerNode]) -> None: node = typing.cast(ForeachKernelSchedulerNode, node) backend_ = self.get_backend(device) from .codegen.cuda_combined_scheduling import CUDACombinedScheduling + from .codegen.xpu_combined_scheduling import SYCLCombinedScheduling + from .codegen.simd import SIMDScheduling - if isinstance(backend_, (SIMDScheduling, CUDACombinedScheduling)): + if isinstance(backend_, (SIMDScheduling, CUDACombinedScheduling, SYCLCombinedScheduling)): backend = backend_ else: raise AssertionError(f"{type(self)=}") diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index e8a800b5487a5..9bf0f48024904 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -1411,6 +1411,30 @@ def use_cutlass_template(layout: Layout, m: int, n: int, k: int) -> bool: return False return res +def use_cutlass_sycl_template(layout: Layout, m: int, n: int, k: int) -> bool: + from .virtualized import V + + gemm_size = V.graph.sizevars.size_hint(m * n * k, fallback=-1) + if gemm_size <= 0 or gemm_size < config.sycl.cutlass_backend_min_gemm_size: + return False + from .codegen.xpu.cutlass_utils import try_import_cutlass + + layout_dtypes = [torch.bfloat16] # TODO (SYCL) : Extend to the rest of dtypes + res = ( + _use_template_for_gpu(layout, layout_dtypes) + and use_max_autotune() + and _use_autotune_backend("CUTLASS") + ) + + if res: + if not try_import_cutlass(): + log.warning( + "Failed to import CUTLASS lib. Please check whether " + "_inductor.config.cutlass_dir is set correctly. " + "Skipping CUTLASS backend for now." 
+ ) + return False + return res @functools.lru_cache(None) def _rocm_native_device_arch_name(device: str) -> str: From 5bdeddddedc06cc7dc3f3baac947bf5c3604948e Mon Sep 17 00:00:00 2001 From: OuadiElfarouki Date: Wed, 2 Apr 2025 07:03:20 +0100 Subject: [PATCH 2/9] Added xpu guard on cuda cutlass selection for mm --- torch/_inductor/codegen/xpu/gemm_template.py | 8 ++++---- torch/_inductor/utils.py | 15 +++++++++++++-- 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/torch/_inductor/codegen/xpu/gemm_template.py b/torch/_inductor/codegen/xpu/gemm_template.py index d40520b242b3b..fddf6d3ad94a7 100644 --- a/torch/_inductor/codegen/xpu/gemm_template.py +++ b/torch/_inductor/codegen/xpu/gemm_template.py @@ -39,7 +39,7 @@ using ElementComputeEpilogue = {{instance_type}}::ElementAccumulator; using coord_t = cutlass::gemm::GemmCoord::Index; static cutlass::KernelHardwareInfo hw_info; - + const int device_id = 0; if (hw_info.sm_count == 0) { @@ -563,7 +563,7 @@ def render( # type: ignore[override] assert cutlass_utils.try_import_cutlass() import cutlass_library.gemm_operation as cutlass_gemm_op - import cutlass_library.library as cutlass_lib + import cutlass_library.library as cutlass_lib # noqa: F401 assert isinstance( op, cutlass_gemm_op.GemmOperation @@ -663,7 +663,7 @@ def add_cutlass_gemm_choices( ) @staticmethod - def _get_supported_ops() -> "list[cutlass_library.gemm_operation.GemmOperation]": + def _get_supported_ops() -> "list[cutlass_library.gemm_operation.GemmOperation]": # type: ignore[name-defined] # noqa: F821 import cutlass_library.library as cutlass_lib return [cutlass_lib.GemmKind.Universal3x] @@ -673,7 +673,7 @@ def _get_template(self) -> str: def _get_template_args( self, - op: "cutlass_library.gemm_op.GemmOperation", + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 ) -> tuple[str, Optional[str]]: return (GEMM_ARGS_CUTLASS_3X, GEMM_ARGS_CUTLASS_3X_EPILOGUE) diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 9bf0f48024904..dc9bbad09716f 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -1269,6 +1269,14 @@ def is_big_gpu(index_or_device: Union[int, torch.device] = 0) -> bool: return False return True +@functools.lru_cache(None) +def _is_xpu(index_or_device: Union[int, torch.device] = 0) -> bool: + if isinstance(index_or_device, torch.device): + device = index_or_device + else: + device = torch.device(get_gpu_type(), index_or_device) + + return device.type == "xpu" @functools.lru_cache def get_max_num_sms() -> int: @@ -1390,8 +1398,8 @@ def use_cutlass_template(layout: Layout, m: int, n: int, k: int) -> bool: return False from .codegen.cuda.cutlass_utils import try_import_cutlass - # Do not use cutlass template on ROCm - if torch.version.hip: + # Do not use CUDA cutlass template on ROCm or SYCL + if torch.version.hip or _is_xpu(layout): return False layout_dtypes = [torch.float16, torch.bfloat16, torch.float32, torch.int32] @@ -1419,6 +1427,9 @@ def use_cutlass_sycl_template(layout: Layout, m: int, n: int, k: int) -> bool: return False from .codegen.xpu.cutlass_utils import try_import_cutlass + if not _is_xpu(layout): + return False + layout_dtypes = [torch.bfloat16] # TODO (SYCL) : Extend to the rest of dtypes res = ( _use_template_for_gpu(layout, layout_dtypes) From 7a447d7036f9fa5acefb44a356d2c457def5ffdf Mon Sep 17 00:00:00 2001 From: Ouadie EL FAROUKI Date: Fri, 4 Apr 2025 09:01:15 +0100 Subject: [PATCH 3/9] Fixed some bugs --- torch/_inductor/codegen/xpu/cutlass_utils.py | 12 
++++++++---- torch/_inductor/codegen/xpu/sycl_kernel.py | 1 - torch/_inductor/utils.py | 4 ++-- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/torch/_inductor/codegen/xpu/cutlass_utils.py b/torch/_inductor/codegen/xpu/cutlass_utils.py index a8656425ca7d8..bde570263ece5 100644 --- a/torch/_inductor/codegen/xpu/cutlass_utils.py +++ b/torch/_inductor/codegen/xpu/cutlass_utils.py @@ -29,7 +29,7 @@ def try_import_cutlass() -> bool: # Copy CUTLASS python scripts to a temp dir and add the temp dir to Python search path. cutlass_py_full_path = os.path.abspath( - os.path.join(config.sycl.cutlass_dir, "python/cutlass_library") + os.path.join(config.cutlass_dir, "python/cutlass_library") ) tmp_cutlass_py_full_path = os.path.abspath( os.path.join(cache_dir(), "torch_cutlass_library") @@ -82,8 +82,9 @@ class CUTLASSArgs: CUTLASS args used to initialize a CUTLASS Manifest. """ - architectures = "11" - # instantiation_level: Optional[str] = None + architectures: Optional[str] = None + cuda_version: Optional[str] = None # Unused in generator.py for PVC + instantiation_level: Optional[str] = None # Unused YET in generator.py for PVC operations = "all" build_dir = "" @@ -123,12 +124,15 @@ def _gen_ops_cached(arch) -> list[Any]: return [] arch = _normalize_sycl_arch(arch) + sycl_version = "2025.0.1" # Placeholder, Unused in GeneratePVC + args = CUTLASSArgs( architectures=arch, + instantiation_level = "0", # TODO (SYCL) : Make it config param once enabled in cutlass_library/generator.py + cuda_version = sycl_version, ) manifest = cutlass_manifest.Manifest(args) - sycl_version = "2025.0.1" # Placeholder, Unused in GeneratePVC if arch == "11": cutlass_generator.GeneratePVC(manifest, sycl_version) diff --git a/torch/_inductor/codegen/xpu/sycl_kernel.py b/torch/_inductor/codegen/xpu/sycl_kernel.py index e17f1835e173b..f3f857f322a36 100644 --- a/torch/_inductor/codegen/xpu/sycl_kernel.py +++ b/torch/_inductor/codegen/xpu/sycl_kernel.py @@ -321,7 +321,6 @@ def call_kernel( wrapper.generate_kernel_call( name, call_args, - gpu=True, triton=False, arg_types=arg_types, ) diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index dc9bbad09716f..d6b7208416e7f 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -1399,7 +1399,7 @@ def use_cutlass_template(layout: Layout, m: int, n: int, k: int) -> bool: from .codegen.cuda.cutlass_utils import try_import_cutlass # Do not use CUDA cutlass template on ROCm or SYCL - if torch.version.hip or _is_xpu(layout): + if torch.version.hip or _is_xpu(layout.device): return False layout_dtypes = [torch.float16, torch.bfloat16, torch.float32, torch.int32] @@ -1427,7 +1427,7 @@ def use_cutlass_sycl_template(layout: Layout, m: int, n: int, k: int) -> bool: return False from .codegen.xpu.cutlass_utils import try_import_cutlass - if not _is_xpu(layout): + if not _is_xpu(layout.device): return False layout_dtypes = [torch.bfloat16] # TODO (SYCL) : Extend to the rest of dtypes From a033420e1f7bb0a8652df0b17854037666897cee Mon Sep 17 00:00:00 2001 From: Ouadie EL FAROUKI Date: Fri, 4 Apr 2025 16:27:46 +0100 Subject: [PATCH 4/9] Output type workarounds for an initial bfloat16 -> float32 working PoC --- torch/_inductor/codegen/xpu/gemm_template.py | 41 ++++++++++++-------- torch/_inductor/codegen/xpu/sycl_template.py | 1 + torch/_inductor/config.py | 2 +- 3 files changed, 26 insertions(+), 18 deletions(-) diff --git a/torch/_inductor/codegen/xpu/gemm_template.py b/torch/_inductor/codegen/xpu/gemm_template.py index fddf6d3ad94a7..eff20261aec58 100644 
--- a/torch/_inductor/codegen/xpu/gemm_template.py +++ b/torch/_inductor/codegen/xpu/gemm_template.py @@ -21,7 +21,7 @@ from . import cutlass_utils from .sycl_kernel import SYCLTemplateKernel from .sycl_template import CUTLASSTemplate - +import torch log = logging.getLogger(__name__) @@ -277,8 +277,6 @@ def _add_cutlass_gemm_choices( self.maybe_append_choice( choices, description=description, op=op, swizzle=swizzle ) - break # TODO (SYCL) : Currently limited to one config - break # TODO (SYCL) : Currently limited to one config if len(ops) == 0: input_layouts = [node.get_layout() for node in input_nodes] @@ -403,6 +401,8 @@ def _dtype_match( if not ( cutlass_utils.dtype_match(X.get_dtype(), op.A.element) and cutlass_utils.dtype_match(W.get_dtype(), op.B.element) + # TODO (SYCL) : Careful with this dtypes check of the output, as it would + # return false without the workaround in CUTLASS3xGemmTemplate.__init__() and cutlass_utils.dtype_match( self.output_node.get_layout().dtype, op.D.element ) @@ -643,7 +643,11 @@ def __init__( beta: float, input_reorder: Optional[list[int]] = None, ): - super().__init__(input_nodes, layout, alpha, beta, input_reorder) + # TODO (SYCL) : This is a workaround hardcoding output type (layout) to float32 + # Should be removed once not limited to the bfloat input->float32 accum cutlass configurations + float_layout = copy.deepcopy(layout) + float_layout.dtype = torch.float32 + super().__init__(input_nodes, float_layout, alpha, beta, input_reorder) @staticmethod def add_cutlass_gemm_choices( @@ -793,19 +797,22 @@ def _dtype_match( has_bias = len(self.input_nodes) >= 3 and self.input_nodes[2] is not None - if op.C.element == DataType.void: - if has_bias: - # op expects no bias, but bias exists - return False - else: - # op expects bias. Needs to check if bias exists and is of the right dtype - if not ( - has_bias - and cutlass_utils.dtype_match( - self.input_nodes[2].get_dtype(), op.C.element - ) - ): - return False + # TODO (SYCL) : Extend this once more output dtypes are supported, + # AND No source (C) is supported + + # if op.C.element == DataType.void: + # if has_bias: + # # op expects no bias, but bias exists + # return False + # else: + # # op expects bias. Needs to check if bias exists and is of the right dtype + # if not ( + # has_bias + # and cutlass_utils.dtype_match( + # self.input_nodes[2].get_dtype(), op.C.element + # ) + # ): + # return False return True diff --git a/torch/_inductor/codegen/xpu/sycl_template.py b/torch/_inductor/codegen/xpu/sycl_template.py index 16cd98ff633d6..c87775a1387b6 100644 --- a/torch/_inductor/codegen/xpu/sycl_template.py +++ b/torch/_inductor/codegen/xpu/sycl_template.py @@ -214,6 +214,7 @@ def globals(self) -> IndentedBuffer: res.splice( """ using namespace cute; + using bfloat16 = cutlass::bfloat16_t; // TODO (SYCL) Workaround the cpp bfloat16 not mapping to bfloat16_t #define CUTLASS_CHECK(status) \\ { \\ cutlass::Status error = status; \\ diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index d501ca39c6e1e..f22af5911df7c 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -1426,7 +1426,7 @@ class sycl: cutlass_max_profiling_configs: Optional[int] = None # The L2 swizzle values to consider when profiling CUTLASS configs in max_autotune. 
- cutlass_max_profiling_swizzle_options: list[int] = [1, 2, 4] + cutlass_max_profiling_swizzle_options: list[int] = [1] # Currently set to 1 value until benchmarking is supported # TODO (SYCL) : Enable the standalone GEMM runner for testing later generate_test_runner: bool = False From ceb32d20852eb0c57fdb6a431dd84b9811108404 Mon Sep 17 00:00:00 2001 From: Ouadie EL FAROUKI Date: Sun, 6 Apr 2025 14:36:51 +0100 Subject: [PATCH 5/9] Apply suggestions from code review - Formatting Co-authored-by: Lukas Sommer --- torch/_inductor/codegen/xpu/gemm_template.py | 8 ++++---- torch/_inductor/codegen/xpu/sycl_template.py | 15 ++++++++------- torch/_inductor/config.py | 2 +- 3 files changed, 13 insertions(+), 12 deletions(-) diff --git a/torch/_inductor/codegen/xpu/gemm_template.py b/torch/_inductor/codegen/xpu/gemm_template.py index eff20261aec58..f87af388d38df 100644 --- a/torch/_inductor/codegen/xpu/gemm_template.py +++ b/torch/_inductor/codegen/xpu/gemm_template.py @@ -452,10 +452,10 @@ def filter_op( ): return None - # Update Op + # Update op op = copy.deepcopy(op) - # Set output layout + # Set output layout. op.D.layout = CUTLASSGemmTemplate.cutlass_layout(self.output_node.get_layout()) # Set alignments - crucial for performance and correctness @@ -643,7 +643,7 @@ def __init__( beta: float, input_reorder: Optional[list[int]] = None, ): - # TODO (SYCL) : This is a workaround hardcoding output type (layout) to float32 + # TODO (SYCL) : This is a workaround hardcoding output type (layout) to float32 # Should be removed once not limited to the bfloat input->float32 accum cutlass configurations float_layout = copy.deepcopy(layout) float_layout.dtype = torch.float32 @@ -797,7 +797,7 @@ def _dtype_match( has_bias = len(self.input_nodes) >= 3 and self.input_nodes[2] is not None - # TODO (SYCL) : Extend this once more output dtypes are supported, + # TODO (SYCL) : Extend this once more output dtypes are supported, # AND No source (C) is supported # if op.C.element == DataType.void: diff --git a/torch/_inductor/codegen/xpu/sycl_template.py b/torch/_inductor/codegen/xpu/sycl_template.py index c87775a1387b6..306c8398aebed 100644 --- a/torch/_inductor/codegen/xpu/sycl_template.py +++ b/torch/_inductor/codegen/xpu/sycl_template.py @@ -71,13 +71,14 @@ def generate( # type: ignore[override] A SYCLTemplateCaller object representing the generated SYCL template caller. """ kernel_name = f"xpu_{self.name}" - with patch.object( - V.graph, "get_dtype", self._fake_get_dtype(self.output_node) - ), SYCLTemplateKernel( - kernel_name=kernel_name, - runtime_arg_info=self.get_runtime_arg_info(), - runtime_arg_values=self.get_runtime_arg_values(**kwargs), - ) as kernel: + with ( + patch.object( V.graph, "get_dtype", self._fake_get_dtype(self.output_node)), + SYCLTemplateKernel( + kernel_name=kernel_name, + runtime_arg_info=self.get_runtime_arg_info(), + runtime_arg_values=self.get_runtime_arg_values(**kwargs), + ) as kernel + ): code = self.render(kernel=kernel, **kwargs) _, call_args, _, _ = kernel.args.python_argdefs() autotuning_log.debug("Generated Code:\n%s", code) diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index f22af5911df7c..c2c7a32a180c7 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -1426,7 +1426,7 @@ class sycl: cutlass_max_profiling_configs: Optional[int] = None # The L2 swizzle values to consider when profiling CUTLASS configs in max_autotune. 
- cutlass_max_profiling_swizzle_options: list[int] = [1] # Currently set to 1 value until benchmarking is supported + cutlass_max_profiling_swizzle_options: list[int] = [1] # TODO(SYCL): Currently set to 1 value until benchmarking is supported # TODO (SYCL) : Enable the standalone GEMM runner for testing later generate_test_runner: bool = False From e38ce367344f8cb7dd9e57e5366d848e5130938d Mon Sep 17 00:00:00 2001 From: OuadiElfarouki Date: Sun, 6 Apr 2025 15:14:53 +0100 Subject: [PATCH 6/9] Fixed .py files formatting & imports ordering --- torch/_inductor/autotune_process.py | 8 +++++--- torch/_inductor/codegen/common.py | 2 +- torch/_inductor/codegen/xpu/cutlass_utils.py | 17 ++++++++--------- torch/_inductor/codegen/xpu/gemm_template.py | 16 ++++++++++------ .../codegen/xpu/sycl_cpp_scheduling.py | 6 +++--- torch/_inductor/codegen/xpu/sycl_kernel.py | 8 +++++--- torch/_inductor/codegen/xpu/sycl_template.py | 4 ++-- .../codegen/xpu_combined_scheduling.py | 2 +- torch/_inductor/ir.py | 3 ++- torch/_inductor/scheduler.py | 8 +++++--- torch/_inductor/utils.py | 9 +++++++-- 11 files changed, 49 insertions(+), 34 deletions(-) diff --git a/torch/_inductor/autotune_process.py b/torch/_inductor/autotune_process.py index 8bafb009b12f5..c6b9a42d5589d 100644 --- a/torch/_inductor/autotune_process.py +++ b/torch/_inductor/autotune_process.py @@ -24,10 +24,10 @@ from torch._inductor.codecache import ( CppCodeCache, CUDACodeCache, - SYCLCodeCache, DLLWrapper, get_hash, PyCodeCache, + SYCLCodeCache, ) from torch._inductor.utils import get_gpu_type, is_gpu from torch._logging import getArtifactLogger @@ -935,6 +935,7 @@ def cleanup_run_fn(self) -> None: def __str__(self) -> str: return f"{self.kernel_name=}" + class SYCLBenchmarkRequest(GPUDeviceBenchmarkMixin, BenchmarkRequest): # Important: Instances of this class have to be serializable # across process boundaries. Do not put Tensors in here! 
@@ -956,13 +957,13 @@ def __init__( self.hash_key: str = "" self.source_file: str = "" self.hash_key, self.source_file = SYCLCodeCache.write(self.source_code, "so") - + def precompile(self): # Prepopulate SYCLCodeCache autotuning_log.debug("Precompiling %s", self) SYCLCodeCache.compile(self.source_code, "so") autotuning_log.debug("Done precompiling %s", self) - + def make_run_fn( self, *input_tensors: torch.Tensor, output_tensor: torch.Tensor ) -> Callable[[], None]: @@ -1024,6 +1025,7 @@ def cleanup_run_fn(self) -> None: def __str__(self) -> str: return f"{self.kernel_name=}, {self.source_file=}, {self.hash_key=}" + def benchmark_in_sub_process( choices: list[TritonTemplateCaller], ) -> dict[TritonTemplateCaller, float]: diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index 33d2cc4554eab..bf7b69d0c2ef4 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -390,11 +390,11 @@ def init_backend_registration() -> None: from .cpp_wrapper_cpu_array_ref import CppWrapperCpuArrayRef from .cpp_wrapper_gpu import CppWrapperGpu from .cuda_combined_scheduling import CUDACombinedScheduling - from .xpu_combined_scheduling import SYCLCombinedScheduling from .halide import HalideScheduling from .mps import MetalScheduling from .triton import TritonScheduling from .wrapper import PythonWrapperCodegen + from .xpu_combined_scheduling import SYCLCombinedScheduling if get_scheduling_for_device("cpu") is None: cpu_backends = { diff --git a/torch/_inductor/codegen/xpu/cutlass_utils.py b/torch/_inductor/codegen/xpu/cutlass_utils.py index bde570263ece5..95f79d0ee1c70 100644 --- a/torch/_inductor/codegen/xpu/cutlass_utils.py +++ b/torch/_inductor/codegen/xpu/cutlass_utils.py @@ -39,9 +39,9 @@ def try_import_cutlass() -> bool: if os.path.isdir(cutlass_py_full_path): if tmp_cutlass_py_full_path not in sys.path: if os.path.exists(dst_link): - assert os.path.islink( - dst_link - ), f"{dst_link} is not a symlink. Try to remove {dst_link} manually and try again." + assert os.path.islink(dst_link), ( + f"{dst_link} is not a symlink. Try to remove {dst_link} manually and try again." 
+ ) assert os.path.realpath(os.readlink(dst_link)) == os.path.realpath( cutlass_py_full_path ), f"Symlink at {dst_link} does not point to {cutlass_py_full_path}" @@ -83,8 +83,8 @@ class CUTLASSArgs: """ architectures: Optional[str] = None - cuda_version: Optional[str] = None # Unused in generator.py for PVC - instantiation_level: Optional[str] = None # Unused YET in generator.py for PVC + cuda_version: Optional[str] = None # Unused in generator.py for PVC + instantiation_level: Optional[str] = None # Unused YET in generator.py for PVC operations = "all" build_dir = "" @@ -124,16 +124,15 @@ def _gen_ops_cached(arch) -> list[Any]: return [] arch = _normalize_sycl_arch(arch) - sycl_version = "2025.0.1" # Placeholder, Unused in GeneratePVC + sycl_version = "2025.0.1" # Placeholder, Unused in GeneratePVC args = CUTLASSArgs( architectures=arch, - instantiation_level = "0", # TODO (SYCL) : Make it config param once enabled in cutlass_library/generator.py - cuda_version = sycl_version, + instantiation_level="0", # TODO (SYCL) : Make it config param once enabled in cutlass_library/generator.py + cuda_version=sycl_version, ) manifest = cutlass_manifest.Manifest(args) - if arch == "11": cutlass_generator.GeneratePVC(manifest, sycl_version) else: diff --git a/torch/_inductor/codegen/xpu/gemm_template.py b/torch/_inductor/codegen/xpu/gemm_template.py index f87af388d38df..71465140773c4 100644 --- a/torch/_inductor/codegen/xpu/gemm_template.py +++ b/torch/_inductor/codegen/xpu/gemm_template.py @@ -6,6 +6,8 @@ from abc import ABC, abstractmethod from typing import Optional, Union +import torch + from ... import ir from ...config import sycl as inductor_sycl_config from ...ir import ( @@ -21,7 +23,7 @@ from . import cutlass_utils from .sycl_kernel import SYCLTemplateKernel from .sycl_template import CUTLASSTemplate -import torch + log = logging.getLogger(__name__) @@ -565,9 +567,9 @@ def render( # type: ignore[override] import cutlass_library.gemm_operation as cutlass_gemm_op import cutlass_library.library as cutlass_lib # noqa: F401 - assert isinstance( - op, cutlass_gemm_op.GemmOperation - ), "op argument is required and has to be an instance of GemmOperation" + assert isinstance(op, cutlass_gemm_op.GemmOperation), ( + "op argument is required and has to be an instance of GemmOperation" + ) assert len(self.input_nodes) >= 2 and self.output_node is not None X, W = self.input_nodes[0], self.input_nodes[1] @@ -593,7 +595,10 @@ def render( # type: ignore[override] else: input_reorder = None kernel_call_signature = kernel.def_kernel( - inputs=inputs, outputs=[Y], names_str=names_str, input_reorder=input_reorder # type: ignore[arg-type] + inputs=inputs, # type: ignore[arg-type] + outputs=[Y], + names_str=names_str, + input_reorder=input_reorder, # type: ignore[arg-type] ) # Make op mutable without affecting others @@ -793,7 +798,6 @@ def _dtype_match( return False assert cutlass_utils.try_import_cutlass() - from cutlass_library.library import DataType # type: ignore[import] has_bias = len(self.input_nodes) >= 3 and self.input_nodes[2] is not None diff --git a/torch/_inductor/codegen/xpu/sycl_cpp_scheduling.py b/torch/_inductor/codegen/xpu/sycl_cpp_scheduling.py index fd100347a7682..a4d49c1f1b980 100644 --- a/torch/_inductor/codegen/xpu/sycl_cpp_scheduling.py +++ b/torch/_inductor/codegen/xpu/sycl_cpp_scheduling.py @@ -87,9 +87,9 @@ def codegen_template( Codegen a SYCL template, possibly with fused epilogues """ counters["inductor"]["sycl_epilogue_fusion_counter"] += len(epilogue_nodes) - assert 
self.is_sycl_cpp_template( - template_node - ), "Template node passed to SYCLScheduler.codegen_template must be a SchedulerNode that wraps a SYCLTemplateBuffer" + assert self.is_sycl_cpp_template(template_node), ( + "Template node passed to SYCLScheduler.codegen_template must be a SchedulerNode that wraps a SYCLTemplateBuffer" + ) template_node = cast(SchedulerNode, template_node) _, (_numel, rnumel) = template_node.group assert rnumel == 1 diff --git a/torch/_inductor/codegen/xpu/sycl_kernel.py b/torch/_inductor/codegen/xpu/sycl_kernel.py index f3f857f322a36..2b93b3e5213ad 100644 --- a/torch/_inductor/codegen/xpu/sycl_kernel.py +++ b/torch/_inductor/codegen/xpu/sycl_kernel.py @@ -1,7 +1,7 @@ # mypy: allow-untyped-defs import logging from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union +from typing import Any, Callable, Literal, Optional, TYPE_CHECKING, Union from sympy import Expr, symbols @@ -30,7 +30,7 @@ WorkspaceArg, WorkspaceZeroMode, ) -from ..cpp_utils import DTYPE_TO_CPP, CppPrinter +from ..cpp_utils import CppPrinter, DTYPE_TO_CPP if TYPE_CHECKING: @@ -451,7 +451,9 @@ def __init__( make_kernel_render: Callable[[SYCLTemplateBuffer, Optional[list[IRNode]]], str], bmreq: SYCLBenchmarkRequest, template: "SYCLTemplate", # type: ignore[name-defined] - info_kwargs: Optional[dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType]]]], # type: ignore[type-arg] + info_kwargs: Optional[ + dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType]]] + ], # type: ignore[type-arg] description: str, ) -> None: super().__init__(name, input_nodes, layout, description) diff --git a/torch/_inductor/codegen/xpu/sycl_template.py b/torch/_inductor/codegen/xpu/sycl_template.py index 306c8398aebed..f89e450745eea 100644 --- a/torch/_inductor/codegen/xpu/sycl_template.py +++ b/torch/_inductor/codegen/xpu/sycl_template.py @@ -72,12 +72,12 @@ def generate( # type: ignore[override] """ kernel_name = f"xpu_{self.name}" with ( - patch.object( V.graph, "get_dtype", self._fake_get_dtype(self.output_node)), + patch.object(V.graph, "get_dtype", self._fake_get_dtype(self.output_node)), SYCLTemplateKernel( kernel_name=kernel_name, runtime_arg_info=self.get_runtime_arg_info(), runtime_arg_values=self.get_runtime_arg_values(**kwargs), - ) as kernel + ) as kernel, ): code = self.render(kernel=kernel, **kwargs) _, call_args, _, _ = kernel.args.python_argdefs() diff --git a/torch/_inductor/codegen/xpu_combined_scheduling.py b/torch/_inductor/codegen/xpu_combined_scheduling.py index ece79f312b4c4..7570f845642e3 100644 --- a/torch/_inductor/codegen/xpu_combined_scheduling.py +++ b/torch/_inductor/codegen/xpu_combined_scheduling.py @@ -10,8 +10,8 @@ Scheduler, SchedulerNode, ) -from .xpu.sycl_cpp_scheduling import SYCLCPPScheduling from .triton import TritonScheduling +from .xpu.sycl_cpp_scheduling import SYCLCPPScheduling if TYPE_CHECKING: diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index e3886c4751196..b0169e3767c3a 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -102,7 +102,6 @@ from .codegen.cuda.cuda_template import CUDATemplate from .codegen.xpu.sycl_template import SYCLTemplate - from .graph import GraphLowering from .utils import IndentedBuffer @@ -4673,6 +4672,7 @@ def __init__( # type: ignore[no-untyped-def] def get_workspace_size(self): # type: ignore[no-untyped-def] return self.workspace_size if self.workspace_size is not None else 0 + class SYCLTemplateBuffer(TemplateBuffer): def __init__( # type: ignore[no-untyped-def] 
self, @@ -4690,6 +4690,7 @@ def __init__( # type: ignore[no-untyped-def] def get_workspace_size(self): # type: ignore[no-untyped-def] return self.workspace_size if self.workspace_size is not None else 0 + class CppTemplateBuffer(TemplateBuffer): def __init__(self, layout, inputs, make_kernel_render, template, choice) -> None: # type: ignore[no-untyped-def] super().__init__(layout, inputs, make_kernel_render) diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index 64c014342292b..430813c0a8d0d 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -4300,11 +4300,13 @@ def _codegen(self, nodes: list[BaseSchedulerNode]) -> None: node = typing.cast(ForeachKernelSchedulerNode, node) backend_ = self.get_backend(device) from .codegen.cuda_combined_scheduling import CUDACombinedScheduling - from .codegen.xpu_combined_scheduling import SYCLCombinedScheduling - from .codegen.simd import SIMDScheduling + from .codegen.xpu_combined_scheduling import SYCLCombinedScheduling - if isinstance(backend_, (SIMDScheduling, CUDACombinedScheduling, SYCLCombinedScheduling)): + if isinstance( + backend_, + (SIMDScheduling, CUDACombinedScheduling, SYCLCombinedScheduling), + ): backend = backend_ else: raise AssertionError(f"{type(self)=}") diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index d6b7208416e7f..8dc3e617bcfaf 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -1269,6 +1269,7 @@ def is_big_gpu(index_or_device: Union[int, torch.device] = 0) -> bool: return False return True + @functools.lru_cache(None) def _is_xpu(index_or_device: Union[int, torch.device] = 0) -> bool: if isinstance(index_or_device, torch.device): @@ -1278,6 +1279,7 @@ def _is_xpu(index_or_device: Union[int, torch.device] = 0) -> bool: return device.type == "xpu" + @functools.lru_cache def get_max_num_sms() -> int: return torch.cuda.get_device_properties("cuda").multi_processor_count @@ -1419,24 +1421,26 @@ def use_cutlass_template(layout: Layout, m: int, n: int, k: int) -> bool: return False return res + def use_cutlass_sycl_template(layout: Layout, m: int, n: int, k: int) -> bool: from .virtualized import V gemm_size = V.graph.sizevars.size_hint(m * n * k, fallback=-1) if gemm_size <= 0 or gemm_size < config.sycl.cutlass_backend_min_gemm_size: return False - from .codegen.xpu.cutlass_utils import try_import_cutlass if not _is_xpu(layout.device): return False - layout_dtypes = [torch.bfloat16] # TODO (SYCL) : Extend to the rest of dtypes + layout_dtypes = [torch.bfloat16] # TODO (SYCL) : Extend to the rest of dtypes res = ( _use_template_for_gpu(layout, layout_dtypes) and use_max_autotune() and _use_autotune_backend("CUTLASS") ) + from .codegen.xpu.cutlass_utils import try_import_cutlass + if res: if not try_import_cutlass(): log.warning( @@ -1447,6 +1451,7 @@ def use_cutlass_sycl_template(layout: Layout, m: int, n: int, k: int) -> bool: return False return res + @functools.lru_cache(None) def _rocm_native_device_arch_name(device: str) -> str: return torch.cuda.get_device_properties(device).gcnArchName From cc172171ea1bafe2138ea741fe30b00f12609bd8 Mon Sep 17 00:00:00 2001 From: OuadiElfarouki Date: Sun, 6 Apr 2025 16:24:04 +0100 Subject: [PATCH 7/9] Addressed PR reviews --- torch/_inductor/codegen/xpu/cutlass_utils.py | 24 ++++++++++++++------ torch/_inductor/codegen/xpu/gemm_template.py | 11 +++++++-- 2 files changed, 26 insertions(+), 9 deletions(-) diff --git a/torch/_inductor/codegen/xpu/cutlass_utils.py b/torch/_inductor/codegen/xpu/cutlass_utils.py 
index 95f79d0ee1c70..573d95cb9e2be 100644
--- a/torch/_inductor/codegen/xpu/cutlass_utils.py
+++ b/torch/_inductor/codegen/xpu/cutlass_utils.py
@@ -158,12 +158,14 @@ def torch_dtype_to_cutlass_type(
     assert try_import_cutlass()
     import cutlass_library  # type: ignore[import]
 
-    if torch_dtype == torch.bfloat16:
-        return cutlass_library.library.DataType.bf16
-    elif torch_dtype == torch.float:
+    if torch_dtype == torch.float:
         return cutlass_library.library.DataType.f32
+    elif torch_dtype == torch.half:
+        return cutlass_library.library.DataType.f16
+    elif torch_dtype == torch.bfloat16:
+        return cutlass_library.library.DataType.bf16
     else:
-        raise NotImplementedError(f"Unsupported data type: {torch_dtype}")
+        raise NotImplementedError(f"Unsupported data type: {torch_dtype=}")
 
 
 def dtype_match(
@@ -174,10 +176,18 @@ def dtype_match(
     assert try_import_cutlass()
     import cutlass_library
 
-    if torch_dtype == torch.bfloat16:
-        return cutlass_dtype == cutlass_library.library.DataType.bf16
-    elif torch_dtype == torch.float:
+    if torch_dtype == torch.float:
         return cutlass_dtype == cutlass_library.library.DataType.f32
+    elif torch_dtype == torch.half:
+        return cutlass_dtype == cutlass_library.library.DataType.f16
+    elif torch_dtype == torch.bfloat16:
+        return cutlass_dtype == cutlass_library.library.DataType.bf16
+    elif torch_dtype == torch.int8:
+        return cutlass_dtype == cutlass_library.library.DataType.s8
+    elif torch_dtype == torch.uint8:
+        return cutlass_dtype == cutlass_library.library.DataType.u8
+    elif torch_dtype == torch.int32:
+        return cutlass_dtype == cutlass_library.library.DataType.s32
     else:
         return False
diff --git a/torch/_inductor/codegen/xpu/gemm_template.py b/torch/_inductor/codegen/xpu/gemm_template.py
index 71465140773c4..ab86167da3f50 100644
--- a/torch/_inductor/codegen/xpu/gemm_template.py
+++ b/torch/_inductor/codegen/xpu/gemm_template.py
@@ -42,6 +42,8 @@
     using coord_t = cutlass::gemm::GemmCoord::Index;
     static cutlass::KernelHardwareInfo hw_info;
 
+    // TODO (SYCL) : device_id here is only used for hw info and doesn't necessarily mean
+    // it's linked to the SYCL queue. It's hardcoded to 0 in the CUDA version as well.
     const int device_id = 0;
 
     if (hw_info.sm_count == 0) {
@@ -73,10 +75,14 @@
     #endif
 #endif
   {
+    // TODO (SYCL): Pass the SYCL queue (currently last arg of `kernel_call_signature` above)
+    // once supported on CUTLASS side. Variable name to respect the naming in: _EXTRA_CPP_ARGS (sycl_kernel.py)
     auto status = gemm_op.initialize(arguments, workspace);
     CUTLASS_CHECK(status);
   }
   {
+    // TODO (SYCL): Pass the SYCL queue once supported on CUTLASS side.
+ // Variable name to respect the naming in: _EXTRA_CPP_ARGS (sycl_kernel.py) auto status = gemm_op.run(); CUTLASS_CHECK(status); syclcompat::wait(); @@ -124,6 +130,7 @@ {{epilogue_arguments}}, hw_info }; + // TODO (SYCL) : setup max_swizzle_size in arguments.scheduler once supported """ # Jinja template for Cutlass 3.x GEMM Kernel arguments if epilogue fusion is applied, @@ -799,11 +806,11 @@ def _dtype_match( assert cutlass_utils.try_import_cutlass() - has_bias = len(self.input_nodes) >= 3 and self.input_nodes[2] is not None - # TODO (SYCL) : Extend this once more output dtypes are supported, # AND No source (C) is supported + # has_bias = len(self.input_nodes) >= 3 and self.input_nodes[2] is not None + # if op.C.element == DataType.void: # if has_bias: # # op expects no bias, but bias exists From 995a2be5560f5b8355efbf870b7c8d553a82068f Mon Sep 17 00:00:00 2001 From: OuadiElfarouki Date: Thu, 10 Apr 2025 14:05:29 +0100 Subject: [PATCH 8/9] Addressed review comments : disabled swizzle / RT cutlass parameters --- torch/_inductor/codegen/xpu/gemm_template.py | 16 ++++------------ torch/_inductor/codegen/xpu/sycl_template.py | 5 +++-- torch/_inductor/config.py | 3 --- 3 files changed, 7 insertions(+), 17 deletions(-) diff --git a/torch/_inductor/codegen/xpu/gemm_template.py b/torch/_inductor/codegen/xpu/gemm_template.py index ab86167da3f50..e29d24d5ca027 100644 --- a/torch/_inductor/codegen/xpu/gemm_template.py +++ b/torch/_inductor/codegen/xpu/gemm_template.py @@ -6,7 +6,7 @@ from abc import ABC, abstractmethod from typing import Optional, Union -import torch +from torch import float32 from ... import ir from ...config import sycl as inductor_sycl_config @@ -130,7 +130,6 @@ {{epilogue_arguments}}, hw_info }; - // TODO (SYCL) : setup max_swizzle_size in arguments.scheduler once supported """ # Jinja template for Cutlass 3.x GEMM Kernel arguments if epilogue fusion is applied, @@ -281,11 +280,8 @@ def _add_cutlass_gemm_choices( ops = self.gen_ops() for name, op in ops: - for swizzle in inductor_sycl_config.cutlass_max_profiling_swizzle_options: - description = f"{name} swizzle={swizzle}" - self.maybe_append_choice( - choices, description=description, op=op, swizzle=swizzle - ) + description = f"{name}" # TODO (SYCL): Include RT arg (like swizzling) once supported, in the description and function arguments + self.maybe_append_choice(choices, description=description, op=op) if len(ops) == 0: input_layouts = [node.get_layout() for node in input_nodes] @@ -320,10 +316,6 @@ def header(self) -> IndentedBuffer: #include "cutlass/epilogue/thread/activation.h" #include "cutlass/gemm/dispatch_policy.hpp" #include "cutlass/gemm/kernel/tile_scheduler.hpp" - #include "cutlass/tensor_ref.h" - #include "cutlass/util/distribution.h" - #include "cutlass/util/packed_stride.hpp" - #include "cutlass/util/tensor_view_io.h" """ ) return res @@ -658,7 +650,7 @@ def __init__( # TODO (SYCL) : This is a workaround hardcoding output type (layout) to float32 # Should be removed once not limited to the bfloat input->float32 accum cutlass configurations float_layout = copy.deepcopy(layout) - float_layout.dtype = torch.float32 + float_layout.dtype = float32 super().__init__(input_nodes, float_layout, alpha, beta, input_reorder) @staticmethod diff --git a/torch/_inductor/codegen/xpu/sycl_template.py b/torch/_inductor/codegen/xpu/sycl_template.py index f89e450745eea..b304611c4c346 100644 --- a/torch/_inductor/codegen/xpu/sycl_template.py +++ b/torch/_inductor/codegen/xpu/sycl_template.py @@ -75,7 +75,7 @@ def generate( # 
type: ignore[override]
             patch.object(V.graph, "get_dtype", self._fake_get_dtype(self.output_node)),
             SYCLTemplateKernel(
                 kernel_name=kernel_name,
-                runtime_arg_info=self.get_runtime_arg_info(),
+                runtime_arg_info=self.get_runtime_arg_info(),  # TODO (SYCL) : Currently empty
                 runtime_arg_values=self.get_runtime_arg_values(**kwargs),
             ) as kernel,
         ):
@@ -266,11 +266,12 @@ def cutlass_type_cast(self, node: IRNode, ptr: str) -> str:
 
     @override
     def get_runtime_arg_info(self) -> list[ArgInfo]:
-        return [ArgInfo("swizzle", "const uint8_t")]
+        return []  # TODO (SYCL) : Add relevant RT params once supported (like swizzling)
 
     @override
     def get_runtime_arg_values(self, **kwargs) -> list[Any]:
         """
         Helper method to retrieve runtime args from generate kwargs
+        # TODO (SYCL) : Currently returning empty list until a RT arg/param is added
         """
         return [kwargs[arg.name] for arg in self.get_runtime_arg_info()]
diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py
index c2c7a32a180c7..52bc7d3c19415 100644
--- a/torch/_inductor/config.py
+++ b/torch/_inductor/config.py
@@ -1425,9 +1425,6 @@ class sycl:
     # By default it's None, so that all CUTLASS configs are tuned.
     cutlass_max_profiling_configs: Optional[int] = None
 
-    # The L2 swizzle values to consider when profiling CUTLASS configs in max_autotune.
-    cutlass_max_profiling_swizzle_options: list[int] = [1]  # TODO(SYCL): Currently set to 1 value until benchmarking is supported
-
     # TODO (SYCL) : Enable the standalone GEMM runner for testing later
     generate_test_runner: bool = False
 
From 021deb00ba547e66bac6761dded775bb0a6f847a Mon Sep 17 00:00:00 2001
From: OuadiElfarouki
Date: Thu, 10 Apr 2025 16:53:17 +0100
Subject: [PATCH 9/9] Addressed review comments : Template, ops filters &
 workspace size

---
 torch/_inductor/autotune_process.py          |  9 ++++-----
 torch/_inductor/codegen/xpu/gemm_template.py | 13 ++++++++++++-
 2 files changed, 16 insertions(+), 6 deletions(-)

diff --git a/torch/_inductor/autotune_process.py b/torch/_inductor/autotune_process.py
index c6b9a42d5589d..28950ec6a3262 100644
--- a/torch/_inductor/autotune_process.py
+++ b/torch/_inductor/autotune_process.py
@@ -950,7 +950,7 @@ def __init__(
     ) -> None:
         super().__init__(kernel_name, input_tensor_meta, output_tensor_meta, extra_args)
         self.source_code = source_code
-        self.workspace_size: int = 0
+        self.workspace_size: int = 0  # TODO (SYCL): workspace size remains 0
         self.workspace: Optional[torch.Tensor] = None
         self.DLL: Optional[DLLWrapper] = None
         self._workspace_size_updated = False
@@ -968,7 +968,7 @@ def make_run_fn(
         self, *input_tensors: torch.Tensor, output_tensor: torch.Tensor
     ) -> Callable[[], None]:
         self.ensure_dll_loaded()
-        self.update_workspace_size()
+        self.update_workspace_size()  # TODO (SYCL): No effect while workspace_size is unused (remains 0)
         args = [
             c_void_p(tensor.data_ptr())
             for tensor in list(input_tensors) + [output_tensor]
@@ -1006,10 +1006,9 @@ def make_run_fn(
     def update_workspace_size(self) -> None:
         if self._workspace_size_updated:
             return
-        # Harcoded temporarily for testing with known kernels
-        self.workspace_size = 4096  # Fixed size for PoC
+        # TODO (SYCL): Hardcoded to zero since no SLM is used on PVC at the moment
+        self.workspace_size = 0
         self._workspace_size_updated = True
-        # TODO (SYCL) : Implement comprehensive workspace updating mechanism
 
     def ensure_dll_loaded(self):
         if self.DLL is None:
diff --git a/torch/_inductor/codegen/xpu/gemm_template.py b/torch/_inductor/codegen/xpu/gemm_template.py
index e29d24d5ca027..1f15c08986a88 100644
--- 
a/torch/_inductor/codegen/xpu/gemm_template.py +++ b/torch/_inductor/codegen/xpu/gemm_template.py @@ -85,7 +85,7 @@ // Variable name to respect the naming in: _EXTRA_CPP_ARGS (sycl_kernel.py) auto status = gemm_op.run(); CUTLASS_CHECK(status); - syclcompat::wait(); + syclcompat::wait_and_throw(); } } catch (std::exception& e) { @@ -447,12 +447,23 @@ def filter_op( X = self.input_nodes[0] W = self.input_nodes[1] + # Filter ops according to the shape match. + if not self._shape_match(op): + return None + if not ( self.layout_match(X.get_layout(), op.A.layout) and self.layout_match(W.get_layout(), op.B.layout) ): return None + # Filter ops by alignment. + if not self._alignment_match(op): + log.debug( + "Skipping due to alignment mismatch. op: %s", op.configuration_name() + ) + return None + # Update op op = copy.deepcopy(op)
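+        # Note: this filtering (and the whole SYCL CUTLASS path) is only reached when
+        # autotuning is configured to consider CUTLASS on an XPU device; a usage
+        # sketch (config names as defined in torch/_inductor/config.py):
+        #
+        #   torch._inductor.config.max_autotune = True
+        #   torch._inductor.config.max_autotune_gemm_backends = "CUTLASS"
+        #   torch.compile(torch.mm)(a, b)  # a, b: bfloat16 tensors on an "xpu" device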