Initial support of SYCL CUTLASS for XPU backend through Inductor #2


90 changes: 90 additions & 0 deletions torch/_inductor/autotune_process.py
@@ -27,6 +27,7 @@
DLLWrapper,
get_hash,
PyCodeCache,
SYCLCodeCache,
)
from torch._inductor.utils import get_gpu_type, is_gpu
from torch._logging import getArtifactLogger
@@ -935,6 +936,95 @@ def __str__(self) -> str:
return f"{self.kernel_name=}"


class SYCLBenchmarkRequest(GPUDeviceBenchmarkMixin, BenchmarkRequest):
# Important: Instances of this class have to be serializable
# across process boundaries. Do not put Tensors in here!
    # TODO (SYCL): Complete the benchmark request class to enable full autotuning
def __init__(
self,
kernel_name: str,
input_tensor_meta: Union[TensorMeta, list[TensorMeta]],
output_tensor_meta: Union[TensorMeta, list[TensorMeta]],
extra_args: Iterable[Any],
source_code: str,
) -> None:
super().__init__(kernel_name, input_tensor_meta, output_tensor_meta, extra_args)
self.source_code = source_code
self.workspace_size: int = 0 # TODO (SYCL): workspace size remains 0
self.workspace: Optional[torch.Tensor] = None
self.DLL: Optional[DLLWrapper] = None
self._workspace_size_updated = False
self.hash_key: str = ""
self.source_file: str = ""
self.hash_key, self.source_file = SYCLCodeCache.write(self.source_code, "so")

def precompile(self):
# Prepopulate SYCLCodeCache
autotuning_log.debug("Precompiling %s", self)
SYCLCodeCache.compile(self.source_code, "so")
autotuning_log.debug("Done precompiling %s", self)

def make_run_fn(
self, *input_tensors: torch.Tensor, output_tensor: torch.Tensor
) -> Callable[[], None]:
self.ensure_dll_loaded()
        self.update_workspace_size()  # TODO (SYCL): currently a no-op; workspace_size is unused and stays 0
args = [
c_void_p(tensor.data_ptr())
for tensor in list(input_tensors) + [output_tensor]
]
autotuning_log.debug(
"make_run_fn: self.kernel_name=%s, self.source_file=%s, self.hash_key=%s, self.DLL=%s, args=%s, self.extra_args=%s",
self.kernel_name,
self.source_file,
self.hash_key,
self.DLL,
args,
self.extra_args,
)
queue_ptr = c_void_p(torch.xpu.current_stream().sycl_queue)
run_method = getattr(self.DLL, self.kernel_name)
workspace_ptr = c_void_p(0)
if self.workspace_size > 0:
self.workspace = torch.zeros(
(self.workspace_size + 7) // 8,
dtype=torch.float64,
device=output_tensor.device,
)
workspace_ptr = c_void_p(self.workspace.data_ptr())

# Generate partial function.
return functools.partial(
run_method,
*args,
*self.extra_args,
None, # null workspace size ptr
            workspace_ptr,  # workspace ptr
queue_ptr,
)

def update_workspace_size(self) -> None:
if self._workspace_size_updated:
return
        # TODO (SYCL): Hardcoded to zero since no SLM is used on PVC at the moment
self.workspace_size = 0
self._workspace_size_updated = True

def ensure_dll_loaded(self):
if self.DLL is None:
self.DLL, self.hash_key, self.source_file = SYCLCodeCache.load(
self.source_code, "so"
)

def cleanup_run_fn(self) -> None:
if self.DLL is not None:
self.DLL.close()
self.workspace = None

def __str__(self) -> str:
return f"{self.kernel_name=}, {self.source_file=}, {self.hash_key=}"


def benchmark_in_sub_process(
choices: list[TritonTemplateCaller],
) -> dict[TritonTemplateCaller, float]:
4 changes: 3 additions & 1 deletion torch/_inductor/codegen/common.py
@@ -394,6 +394,7 @@ def init_backend_registration() -> None:
from .mps import MetalScheduling
from .triton import TritonScheduling
from .wrapper import PythonWrapperCodegen
from .xpu_combined_scheduling import SYCLCombinedScheduling

if get_scheduling_for_device("cpu") is None:
cpu_backends = {
@@ -424,9 +425,10 @@ def init_backend_registration() -> None:
)

if get_scheduling_for_device("xpu") is None:
# SYCLCombinedScheduling combines Triton and SYCL C++ scheduling for XPU devices via delegation
register_backend_for_device(
"xpu",
TritonScheduling,
SYCLCombinedScheduling,
PythonWrapperCodegen,
CppWrapperGpu,
)
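
For illustration only (not part of the diff): what the changed registration means for the registry lookup, assuming the register_backend_for_device / get_scheduling_for_device pair in common.py behaves as used above; the xpu_combined_scheduling module path is inferred from the relative import in this hunk.

# Hypothetical sketch: after init_backend_registration(), the XPU backend
# resolves to the combined Triton + SYCL C++ scheduling class.
from torch._inductor.codegen.common import (
    get_scheduling_for_device,
    init_backend_registration,
)
from torch._inductor.codegen.xpu_combined_scheduling import SYCLCombinedScheduling

init_backend_registration()
assert get_scheduling_for_device("xpu") is SYCLCombinedScheduling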
263 changes: 263 additions & 0 deletions torch/_inductor/codegen/xpu/cutlass_utils.py
@@ -0,0 +1,263 @@
# mypy: allow-untyped-defs
import functools
import logging
import os
import sys
from dataclasses import dataclass
from typing import Any, Optional

import sympy

import torch
from torch._inductor.utils import clear_on_fresh_inductor_cache

from ... import config
from ...ir import Layout
from ...runtime.runtime_utils import cache_dir
from ...virtualized import V


log = logging.getLogger(__name__)


@functools.lru_cache(None)
def try_import_cutlass() -> bool:
"""
    Currently only supports a user-specified cutlass_dir, falling back to the
    default ../third_party/cutlass/ (build-from-source setups).
"""
# Copy CUTLASS python scripts to a temp dir and add the temp dir to Python search path.

cutlass_py_full_path = os.path.abspath(
os.path.join(config.cutlass_dir, "python/cutlass_library")
)
tmp_cutlass_py_full_path = os.path.abspath(
os.path.join(cache_dir(), "torch_cutlass_library")
)
dst_link = os.path.join(tmp_cutlass_py_full_path, "cutlass_library")

if os.path.isdir(cutlass_py_full_path):
if tmp_cutlass_py_full_path not in sys.path:
if os.path.exists(dst_link):
assert os.path.islink(dst_link), (
f"{dst_link} is not a symlink. Try to remove {dst_link} manually and try again."
)
assert os.path.realpath(os.readlink(dst_link)) == os.path.realpath(
cutlass_py_full_path
), f"Symlink at {dst_link} does not point to {cutlass_py_full_path}"
else:
os.makedirs(tmp_cutlass_py_full_path, exist_ok=True)
os.symlink(cutlass_py_full_path, dst_link)
sys.path.append(tmp_cutlass_py_full_path)
try:
import cutlass_library.generator # noqa: F401
import cutlass_library.library # noqa: F401
import cutlass_library.manifest # noqa: F401

return True
except ImportError as e:
log.debug(
"Failed to import CUTLASS packages: %s, ignoring the CUTLASS backend.",
str(e),
)
else:
log.debug(
"Failed to import CUTLASS packages: CUTLASS repo does not exist: %s",
cutlass_py_full_path,
)
return False
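
For illustration only (not part of the diff): the guard pattern this helper enables; the soft-failure branch below is a sketch, whereas the helpers later in this file assert on the result instead.

# Sketch: check the cached import result before touching cutlass_library.
if try_import_cutlass():
    import cutlass_library.generator as cutlass_generator  # resolvable via the temp symlink
else:
    log.debug("cutlass_library unavailable; skipping SYCL CUTLASS kernels.")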


@functools.lru_cache(8)
Review comment: Does this need a cache?

def _normalize_sycl_arch(arch: str) -> str:
if int(arch) == 11:
return "11"
else:
raise NotImplementedError(f"Unsupported sycl arch: {arch}")


@dataclass
class CUTLASSArgs:
"""
CUTLASS args used to initialize a CUTLASS Manifest.
"""

architectures: Optional[str] = None
cuda_version: Optional[str] = None # Unused in generator.py for PVC
instantiation_level: Optional[str] = None # Unused YET in generator.py for PVC

operations = "all"
build_dir = ""
curr_build_dir = ""
generator_target = ""
kernels = "all"
ignore_kernels = ""
exclude_kernels = ""
# UNUSED at the moment, part of Manifest class in cutlass_library
kernel_filter_file: None = None
selected_kernel_list: None = None
interface_dir: None = None
filter_by_cc = False
disable_full_archs_compilation = False

def __post_init__(self):
if self.architectures is None:
raise RuntimeError(f"{self.architectures=} is None!")
self.architectures = _normalize_sycl_arch(self.architectures)


@clear_on_fresh_inductor_cache
@functools.lru_cache(None)
def _gen_ops_cached(arch) -> list[Any]:
# Import cutlass python scripts.
assert try_import_cutlass()
import cutlass_library.generator as cutlass_generator
import cutlass_library.manifest as cutlass_manifest

if arch is None:
log.error(
"Cannot detect XPU arch %s."
"Will discard all cutlass ops. "
"Please consider setting _inductor.xpu.arch",
arch,
)
return []
arch = _normalize_sycl_arch(arch)

sycl_version = "2025.0.1" # Placeholder, Unused in GeneratePVC

args = CUTLASSArgs(
architectures=arch,
instantiation_level="0", # TODO (SYCL) : Make it config param once enabled in cutlass_library/generator.py
cuda_version=sycl_version,
)
manifest = cutlass_manifest.Manifest(args)

if arch == "11":
cutlass_generator.GeneratePVC(manifest, sycl_version)
else:
log.error("Invalid XPU arch")
return []
return manifest.operations


def gen_ops() -> list[Any]:
"""
Generates all supported CUTLASS operations.
"""
    # Currently limited to PVC (arch 1100), hardcoding the arch
    # TODO (SYCL): get_xpu_arch()
arch = "11"
return _gen_ops_cached(arch)
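
For illustration only (not part of the diff): a minimal sketch of calling the generation entry point; only the size of the returned manifest.operations container is inspected, since its exact structure comes from cutlass_library.

# Hypothetical sketch: trigger (cached) CUTLASS op generation for the
# hardcoded PVC arch and log how much came back.
ops = gen_ops()
if not ops:
    log.debug("No CUTLASS operations generated; SYCL CUTLASS choices will be skipped.")
else:
    log.debug("cutlass_library produced %d operation group(s) for arch 11", len(ops))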


def torch_dtype_to_cutlass_type(
torch_dtype: torch.dtype,
) -> "cutlass_library.library.DataType": # type: ignore[name-defined] # noqa: F821
# Import cutlass python scripts.
assert try_import_cutlass()
import cutlass_library # type: ignore[import]

if torch_dtype == torch.float:
return cutlass_library.library.DataType.f32
elif torch_dtype == torch.half:
return cutlass_library.library.DataType.f16
elif torch_dtype == torch.bfloat16:
return cutlass_library.library.DataType.bf16
else:
raise NotImplementedError(f"Unsupported data type: {torch_dtype=}")


def dtype_match(
torch_dtype: Optional[torch.dtype],
cutlass_dtype: "cutlass_library.library.DataType", # type: ignore[name-defined] # noqa: F821
) -> bool:
# Import cutlass python scripts.
assert try_import_cutlass()
import cutlass_library

if torch_dtype == torch.float:
return cutlass_dtype == cutlass_library.library.DataType.f32
elif torch_dtype == torch.half:
return cutlass_dtype == cutlass_library.library.DataType.f16
elif torch_dtype == torch.bfloat16:
return cutlass_dtype == cutlass_library.library.DataType.bf16
elif torch_dtype == torch.int8:
return cutlass_dtype == cutlass_library.library.DataType.s8
elif torch_dtype == torch.uint8:
return cutlass_dtype == cutlass_library.library.DataType.u8
elif torch_dtype == torch.int32:
return cutlass_dtype == cutlass_library.library.DataType.s32
else:
return False
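
For illustration (not part of the diff): the mapping the two helpers above encode, written as assertions; this assumes cutlass_library imports successfully in the current environment.

# Sketch of the currently supported torch <-> CUTLASS dtype mapping.
import cutlass_library

assert torch_dtype_to_cutlass_type(torch.bfloat16) == cutlass_library.library.DataType.bf16
assert dtype_match(torch.float, cutlass_library.library.DataType.f32)
assert not dtype_match(torch.half, cutlass_library.library.DataType.f32)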


def get_accumulator_dtype(
input_torch_dtypes: list[torch.dtype],
) -> Optional[torch.dtype]:
"""
Given a pair of input torch dtypes, returns the inferred accumulator torch dtype.
"""
# TODO (SYCL) : Extend this once other accumulator & input types are supported
if len(input_torch_dtypes) != 2:
return None

if all(dtype == torch.bfloat16 for dtype in input_torch_dtypes):
return torch.float
Review comment: add a TODO I think.

else:
raise NotImplementedError(f"Unsupported data types: {input_torch_dtypes}")


def get_alignments(torch_dtype: torch.dtype) -> list[int]:
"""
Returns all possible valid CUTLASS alignments in terms of the number of elements for a given dtype.
"""
# TODO (SYCL): Extend for other types & double-check alignments
if torch_dtype == torch.bfloat16:
return [8, 4, 2, 1]
Review comment: Is this in bytes? I would guess bf16 needs to be at least 2 and float needs to be at least 4.

elif torch_dtype == torch.float:
return [4, 2, 1]
else:
raise NotImplementedError(f"unsupported {torch_dtype=} for alignments")


def get_max_alignment(inductor_layout: Layout) -> int:
Review comment (Collaborator): This is currently identical to the CUDA version. Do we expect any differences here?
Reply (Author): That's true; I didn't get the chance to check it further, but I will discuss it with @aacostadiaz and the team to make sure.

"""
Returns the max alignment (in terms of number of elements) for a given Inductor Layout.
"""

dtype = inductor_layout.dtype
size = inductor_layout.size
offset = inductor_layout.offset

def is_static_int(number):
return isinstance(number, (int, sympy.Integer))

def a_factor_of(x, alignment):
if is_static_int(x) and is_static_int(alignment):
return x % alignment == 0
rem = sympy.Mod(x, alignment)
return V.graph.sizevars.evaluate_expr(sympy.Eq(rem, 0))

try:
contiguous_dim = inductor_layout.stride.index(1)
except ValueError:
# No dim with stride 1 found, return 1
return 1
alignments = get_alignments(dtype)
for alignment in alignments:
if not a_factor_of(size[contiguous_dim], alignment) or not a_factor_of(
offset, alignment
):
Review comment on lines +249 to +251 (Collaborator): Nit: formatting is a bit weird here.
Reply (Author): I did run the linter on it, but it doesn't make any change; the CUDA version has the same formatting.

continue
if all(
(dim == contiguous_dim)
or a_factor_of(inductor_layout.stride[dim], alignment)
for dim in range(len(size))
):
return alignment
return 1
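
For illustration only (not part of the diff): a worked example of the alignment computation, assuming FixedLayout from torch._inductor.ir with its usual (device, dtype, size, stride) constructor; the shape here is made up.

# Hypothetical sketch: a contiguous (128, 64) bf16 layout has stride (64, 1),
# so the contiguous dim has size 64 and offset 0; 64 % 8 == 0, so the first
# candidate alignment, 8, is accepted.
from torch._inductor.ir import FixedLayout

layout = FixedLayout(
    device=torch.device("xpu"),
    dtype=torch.bfloat16,
    size=[128, 64],
    stride=[64, 1],
)
assert get_max_alignment(layout) == 8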


# TODO (SYCL): Add helpers for CUTLASS kernel testing & benchmarking once the
# standalone runner is enabled.