
Initial support of SYCL CUTLASS for XPU backend through Inductor #2


Conversation

@OuadiElfarouki commented Apr 4, 2025

Summary: This patch enables an initial execution of torch.mm through an Inductor-generated SYCL CUTLASS kernel for Intel PVC.
Following the reference CUDA implementation, it implements the following key functionality:

  • For template generation & rendering:

    • SYCLTemplate, CUTLASSTemplate, CUTLASSGemmTemplate, CUTLASS3xGemmTemplate: handle generating the full C++ code, from the call to GeneratePVC to obtain the GemmOperations (exposed by cutlass_library), through filtering the operations, constructing the Manifest and extracting the GEMM instance from the emitter (exposed by cutlass_library), up to wrapping the C++ template code with runtime arguments and the final kernel launch.
    • cutlass_utils.py: utility file containing relevant functions used across the codegen process.
    • SYCLKernel, SYCLTemplateKernel: handle the higher-level kernel template and the kernel call from the host side. Used within the Template classes above.
  • For autotuning:

    • SYCLBenchmarkRequest: currently added as a near-dummy that does no real benchmarking, since we select a single generated configuration for this PoC.
  • For wrapping & triggering the above:

    • SYCLTemplateCaller: wrapper holding a ready-to-compile, execute, and benchmark SYCL template kernel. This is the higher-level construct added to the list of "choices" in the autotuning process for selecting the best configuration (see the sketch after this list).
  • For scheduling/execution:

    • SYCLCPPScheduling & SYCLCombinedScheduling: orchestrate kernel calls across nodes that may have different lowerings (e.g. Triton and SYCL CUTLASS). Only a few changes were made here compared to the original CUDA implementation.
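
As a rough illustration of where the SYCLTemplateCaller choices plug in (the sketch referenced above), the snippet below mirrors the way the CUDA backend registers CUTLASS choices in the mm lowering. The xpu module path and the exact signatures are assumptions based on this PR, not a verbatim excerpt.

# Hypothetical sketch, not actual PR code: how SYCL CUTLASS templates would join
# the autotuning choices for torch.mm, mirroring the CUDA lowering.
from torch._inductor.select_algorithm import autotune_select_algorithm
from torch._inductor.utils import use_cutlass_template  # heuristic gating CUTLASS usage


def tuned_mm_sketch(mat1, mat2, layout, m, n, k):
    choices = []  # ATen / Triton choices would normally be appended here as well
    if use_cutlass_template(layout, m, n, k):
        # Module path assumed from this PR; each filtered GemmOperation becomes
        # a SYCLTemplateCaller appended to `choices`.
        from torch._inductor.codegen.xpu.gemm_template import CUTLASS3xGemmTemplate

        CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(choices, layout, [mat1, mat2])
    return autotune_select_algorithm("mm", choices, [mat1, mat2], layout)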

The current state was fine-tuned to support the only type configuration exposed by CUTLASS on PVC so far, i.e. bfloat16 inputs with fp32 accumulation, which forces some workarounds on the PyTorch side, mainly around the dtypes of D (layout/output node) and C (source/input_node[2]).

Unsupported or partially implemented features are highlighted with TODO (SYCL) comments.
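
For reference, here is a minimal end-to-end sketch of how this path would be exercised from user code, assuming the XPU backend is selected through the same Inductor config knobs as the CUDA CUTLASS backend (max_autotune_gemm_backends etc.); the exact knob spelling for XPU may differ.

# Minimal sketch under the assumptions above: run torch.mm on XPU with bf16 inputs
# and let max-autotune consider the SYCL CUTLASS template.
import torch
import torch._inductor.config as inductor_config

inductor_config.max_autotune_gemm_backends = "CUTLASS"  # restrict GEMM choices to CUTLASS

a = torch.randn(4096, 4096, device="xpu", dtype=torch.bfloat16)
b = torch.randn(4096, 4096, device="xpu", dtype=torch.bfloat16)

compiled_mm = torch.compile(torch.mm, mode="max-autotune")
out = compiled_mm(a, b)  # lowered through the Inductor-generated SYCL CUTLASS kernel
print(out.shape, out.dtype)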

@sommerlukas (Collaborator) left a comment:

Overall, the implementation looks good, also judging by the changes made in comparison to the CUDA version of these files.

I've added some small comments and a few questions inline.

As I've found multiple cases of trailing whitespace and wrong ordering of imports, I think these files have not been run through a formatter yet. It's important to do that before adding them to the repository, to avoid unrelated formatting changes in future PRs.

Comment on lines 42 to 44
assert os.path.islink(
    dst_link
), f"{dst_link} is not a symlink. Try to remove {dst_link} manually and try again."
Collaborator:

I noticed that formatting here deviates from the CUDA version of this file.
I don't think that's a problem in general, but it made me wonder whether you ran the PyTorch formatters on these files?

Author:

Sure, I'll run the proper PyTorch formatter next. I did run black once with its default config on these new files, which might explain the divergence.


args = CUTLASSArgs(
    architectures=arch,
    instantiation_level = "0", # TODO (SYCL) : Make it config param once enabled in cutlass_library/generator.py
Collaborator:

What does the instantiation level express?

Author:

The instantiation level is a 4-digit number used to control the number of randomly generated configurations on the cutlass_library side. It isn't used by GeneratePVC yet, since we do an explicit cartesian product over known configs, whereas it is used for SM90 etc. More about it here: https://github.com/codeplaysoftware/cutlass-fork/blob/041d78b4d8c30722b2c2e14e858114cca273b6d7/python/cutlass_library/manifest.py#L575
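
To illustrate the "explicit cartesian product" mentioned above, here is a toy, self-contained sketch; the parameter values are placeholders, not the actual PVC configurations generated by cutlass_library.

# Toy illustration only: enumerate GEMM configurations via an explicit cartesian
# product over fixed parameter sets, instead of deriving them from an
# instantiation level. The values below are placeholders.
import itertools

tile_shapes = [(256, 256, 32)]
dtype_combos = [("bf16", "bf16", "f32")]  # (A, B, accumulator)
layouts = [("RowMajor", "RowMajor")]

for tile, (a, b, acc), (layout_a, layout_b) in itertools.product(
    tile_shapes, dtype_combos, layouts
):
    print(f"GEMM config: tile={tile}, A={a}/{layout_a}, B={b}/{layout_b}, acc={acc}")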

Comment on lines 155 to 183
def torch_dtype_to_cutlass_type(
    torch_dtype: torch.dtype,
) -> "cutlass_library.library.DataType":  # type: ignore[name-defined]  # noqa: F821
    # Import cutlass python scripts.
    assert try_import_cutlass()
    import cutlass_library  # type: ignore[import]

    if torch_dtype == torch.bfloat16:
        return cutlass_library.library.DataType.bf16
    elif torch_dtype == torch.float:
        return cutlass_library.library.DataType.f32
    else:
        raise NotImplementedError(f"Unsupported data type: {torch_dtype}")


def dtype_match(
    torch_dtype: Optional[torch.dtype],
    cutlass_dtype: "cutlass_library.library.DataType",  # type: ignore[name-defined]  # noqa: F821
) -> bool:
    # Import cutlass python scripts.
    assert try_import_cutlass()
    import cutlass_library

    if torch_dtype == torch.bfloat16:
        return cutlass_dtype == cutlass_library.library.DataType.bf16
    elif torch_dtype == torch.float:
        return cutlass_dtype == cutlass_library.library.DataType.f32
    else:
        return False
Collaborator:

I think it would be nice to keep the same order of cases in the if as in the corresponding CUDA version of this file. We currently have fewer supported cases, but with the same order, comparability of the files is better.

Author:

Noted, I will restore the order & cases. It shouldn't hurt, I guess; we only need to be careful with the examples we're running.

    raise NotImplementedError(f"unsupported {torch_dtype=} for alignments")


def get_max_alignment(inductor_layout: Layout) -> int:
Collaborator:

This is currently identical to the CUDA version. Do we expect any differences here?

Author:

That's true. I didn't get the chance to check it further, but I will discuss this with @aacostadiaz and the team to make sure.

from . import cutlass_utils
from .sycl_kernel import SYCLTemplateKernel
from .sycl_template import CUTLASSTemplate
import torch
Collaborator:

Why do we need this import here? Couldn't we use relative imports instead?

gemm_size = V.graph.sizevars.size_hint(m * n * k, fallback=-1)
if gemm_size <= 0 or gemm_size < config.sycl.cutlass_backend_min_gemm_size:
    return False
from .codegen.xpu.cutlass_utils import try_import_cutlass
Collaborator:

Should we move this import closer to the actual use of try_import_cutlass?

@sommerlukas (Collaborator) left a comment:

A couple of additional minor comments; some of my previous comments are also still open.

Comment on lines +249 to +251
if not a_factor_of(size[contiguous_dim], alignment) or not a_factor_of(
    offset, alignment
):
Collaborator:

Nit: Formatting is a bit weird here.

Author:

I did run the linter on it, but it doesn't make any changes. The CUDA version has the same formatting.

*workspace_size = gemm_op.get_workspace_size(arguments);
return 0;
}
// check for null pointers after workspace size, since querying workspace size doesn't require valid data pointers
Collaborator:

Are we actually checking for null pointers here?

Author:

I went through the next block and didn't really find any "real" check of valid pointers, but it might be somewhere down the call stack. I guess the goal here is just to make sure the ordering of the if (workspace_size) block and the can_implement(arguments) call is not switched by mistake, even though it could be without causing problems when the if condition is false (which is always the case during execution, since None is passed here: https://github.com/OuadiElfarouki/pytorch/blob/cc172171ea1bafe2138ea741fe30b00f12609bd8/torch/_inductor/codegen/xpu/sycl_kernel.py#L305).
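
To make the intent concrete, here is a schematic stand-in (a pure Python stub, not the generated code) for the two-phase protocol described above: the first call only fills the workspace-size out-parameter, so data pointers need not be valid; the second call does the actual work.

# Schematic only: two-phase call pattern of a generated kernel entry point.
import ctypes


def generated_kernel_stub(x, w, y, workspace_size_ptr, workspace, queue):
    if workspace_size_ptr is not None:
        # Size-query phase: data pointers are not dereferenced here.
        workspace_size_ptr.value = 4096
        return 0
    # Execution phase: by now the data buffers must be valid.
    assert x is not None and w is not None and y is not None
    return 0


ws = ctypes.c_size_t()
generated_kernel_stub(None, None, None, ws, None, None)         # phase 1: query size
workspace = bytearray(ws.value)
generated_kernel_stub(b"X", b"W", b"Y", None, workspace, None)  # phase 2: run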

CUTLASS_CHECK(status);
}
{
auto status = gemm_op.run();
Collaborator:

I think the second point is still open: are there downsides to using run instead of operator()?


if op.gemm_kind not in self._get_supported_ops():
    return None

Collaborator:

This is still open.

    and self.layout_match(W.get_layout(), op.B.layout)
):
    return None

Collaborator:

This is still open.

"""
sizes = self.output_node.get_size()
if len(sizes) > 2:
    return "cutlass::gemm::GemmUniversalMode::kBatched"
Collaborator:

Does batched GEMM currently work with our CUTLASS and CollectiveBuilder?

Author:

I don't think we have it enabled yet, @aacostadiaz would confirm.

return False


@functools.lru_cache(8)
Comment:

Does this need a cache?

    and _use_autotune_backend("CUTLASS")
)

from .codegen.xpu.cutlass_utils import try_import_cutlass
Comment:

should this be after if res to avoid running it when not necessary?

Author:

Agreed. I just wanted to respect the usual way modules are imported in PyTorch, which seems to be at the top of function/class scope rather than within conditional blocks, even when they end up unused.

cutlass_max_profiling_configs: Optional[int] = None

# The L2 swizzle values to consider when profiling CUTLASS configs in max_autotune.
cutlass_max_profiling_swizzle_options: list[int] = [1] # TODO(SYCL): Currently set to 1 value until benchmarking is supported
Comment:

I believe we discussed that this is unlikely to make a difference since we are not using SLM; maybe you should add a comment about that?

Author:

Yes will do.

#endif
#endif
{
auto status = gemm_op.initialize(arguments, workspace);
Comment:

You can still add the queue now; it just might not be used yet. The interface takes a sycl::queue* cast to void*. The PR to fix that is just waiting for review.

import cutlass_library.library as cutlass_lib # noqa: F401

assert isinstance(op, cutlass_gemm_op.GemmOperation), (
    "op argument is required and has to be an instance of GemmOperation"
Comment:

why is it defaulted to None then?

Author:

Yes, it's quite confusing, but I believe it has to do with the inheritance from the SYCLTemplate class, which doesn't have an op argument in its render method. The idea is to keep the inherited & overridden interface (# type: ignore[override]) consistent in the declaration while enforcing the requirement at execution time (hence the assert on the op type).
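
A minimal sketch of that override pattern (class and method bodies are simplified stand-ins, not the actual PR code):

from typing import Any, Optional


class SYCLTemplate:
    # The base render() takes no `op` argument.
    def render(self, **kwargs: Any) -> str:
        raise NotImplementedError


class CUTLASS3xGemmTemplate(SYCLTemplate):
    # `op` keeps a default of None so the override stays signature-compatible
    # with the base class; its presence is enforced at call time instead.
    def render(self, op: Optional[Any] = None, **kwargs: Any) -> str:  # type: ignore[override]
        assert op is not None, "op argument is required"
        return f"// rendered GEMM code for {op}"


print(CUTLASS3xGemmTemplate().render(op="gemm_op_stub"))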

if len(A_size) < 2:
    A_size.insert(0, 1)
if len(B_size) < 2:
    A_size.insert(1, 1)
Comment:

is this supposed to be A_size again?

if self._workspace_size_updated:
    return
# Hardcoded temporarily for testing with known kernels
self.workspace_size = 4096  # Fixed size for PoC
Comment:

The PVC GEMM implementation doesn't use SLM, so this could be 0.

return None

if all(dtype == torch.bfloat16 for dtype in input_torch_dtypes):
    return torch.float
Comment:

add a TODO I think

"""
# TODO (SYCL): Extend for other types & double-check alignments
if torch_dtype == torch.bfloat16:
    return [8, 4, 2, 1]
Comment:

Is this in bytes? I would guess bf16 needs to be at least 2 and float at least 4.

@sommerlukas (Collaborator) left a comment:

Approving this PR. There are some open questions; please address them, or at least reply to them in the comments for future reference, when merging the PR.

@FMarno left a comment:

Thanks for all your work, Ouie

@sommerlukas merged commit a47e05c into codeplaysoftware:sycl-develop on Apr 17, 2025
47 checks passed