Initial support of SYCL CUTLASS for XPU backend through Inductor #2
Conversation
Overall the implementation looks good, also judging by the changes made in comparison to the CUDA version of these files.
I've added some small comments and a few questions inline.
As I've found multiple cases of trailing whitespace and wrong ordering of imports, I think these files have not been run through a formatter yet. I think it's important to do that before adding them to the repository, to avoid unrelated formatting changes in future PRs.
assert os.path.islink(
    dst_link
), f"{dst_link} is not a symlink. Try to remove {dst_link} manually and try again."
I noticed that the formatting here deviates from the CUDA version of this file.
I don't think that's a problem in general, but it made me wonder whether you ran the PyTorch formatters on these files?
Sure, I'll run the proper pytorch formatter next. I did run black with its default config once on these new files, which might explain the divergence.
args = CUTLASSArgs(
    architectures=arch,
    instantiation_level = "0", # TODO (SYCL) : Make it config param once enabled in cutlass_library/generator.py
What does the instantiation level express?
The instantiation level is a 4-digit number used to control the number of randomly generated configurations on the cutlass_library side. It's not used by GeneratePVC yet, as we do an explicit cartesian product with known configs, while it is used for SM90 etc. More about it here: https://github.com/codeplaysoftware/cutlass-fork/blob/041d78b4d8c30722b2c2e14e858114cca273b6d7/python/cutlass_library/manifest.py#L575
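As a side note on the TODO in the snippet above, here is a minimal sketch of how the level could be plumbed through as a parameter instead of the hard-coded "0". The `CUTLASSArgs` stand-in, the `make_cutlass_args` helper, and the architecture string are hypothetical illustrations, not the PR's actual code:

```python
from dataclasses import dataclass


@dataclass
class CUTLASSArgs:
    # Simplified stand-in for the args object used above, not the PR's actual class.
    architectures: str = ""
    # 4-digit string consumed by cutlass_library's manifest/generator; currently
    # ignored by GeneratePVC, which enumerates known configs explicitly.
    instantiation_level: str = "0"


def make_cutlass_args(arch: str, configured_level: str = "0") -> CUTLASSArgs:
    # Plumb the level through instead of hard-coding "0" at the call site.
    return CUTLASSArgs(architectures=arch, instantiation_level=configured_level)


# "pvc" is only a placeholder architecture string for the example.
args = make_cutlass_args("pvc", configured_level="0")
```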
def torch_dtype_to_cutlass_type(
    torch_dtype: torch.dtype,
) -> "cutlass_library.library.DataType":  # type: ignore[name-defined] # noqa: F821
    # Import cutlass python scripts.
    assert try_import_cutlass()
    import cutlass_library  # type: ignore[import]

    if torch_dtype == torch.bfloat16:
        return cutlass_library.library.DataType.bf16
    elif torch_dtype == torch.float:
        return cutlass_library.library.DataType.f32
    else:
        raise NotImplementedError(f"Unsupported data type: {torch_dtype}")


def dtype_match(
    torch_dtype: Optional[torch.dtype],
    cutlass_dtype: "cutlass_library.library.DataType",  # type: ignore[name-defined] # noqa: F821
) -> bool:
    # Import cutlass python scripts.
    assert try_import_cutlass()
    import cutlass_library

    if torch_dtype == torch.bfloat16:
        return cutlass_dtype == cutlass_library.library.DataType.bf16
    elif torch_dtype == torch.float:
        return cutlass_dtype == cutlass_library.library.DataType.f32
    else:
        return False
I think it would be nice to keep the same order of cases in the `if` as in the corresponding CUDA version of this file. We currently have fewer supported cases, but with the same order, comparability of the files is better.
Noted, I will restore the order & cases; it won't harm, I guess. We only need to be careful with the examples we're running.
    raise NotImplementedError(f"unsupported {torch_dtype=} for alignments")


def get_max_alignment(inductor_layout: Layout) -> int:
This is currently identical to the CUDA version. Do we expect any differences here?
That's true, I didn't get the chance to check it further, but I will discuss this with @aacostadiaz and the team to make sure.
from . import cutlass_utils
from .sycl_kernel import SYCLTemplateKernel
from .sycl_template import CUTLASSTemplate
import torch
Why do we need this import here? Couldn't we use relative imports instead?
torch/_inductor/utils.py (outdated)
    gemm_size = V.graph.sizevars.size_hint(m * n * k, fallback=-1)
    if gemm_size <= 0 or gemm_size < config.sycl.cutlass_backend_min_gemm_size:
        return False
    from .codegen.xpu.cutlass_utils import try_import_cutlass
Should we move this import closer to the actual use of `try_import_cutlass`?
A couple of additional minor comments; some of my previous comments are also still open.
    if not a_factor_of(size[contiguous_dim], alignment) or not a_factor_of(
        offset, alignment
    ):
Nit: Formatting is a bit weird here.
I did run the linter on it, but it doesn't make any changes. The CUDA version has the same formatting.
    *workspace_size = gemm_op.get_workspace_size(arguments);
    return 0;
  }
  // check for null pointers after workspace size, since querying workspace size doesn't require valid data pointers
Are we actually checking for null pointers here?
I went through the next block and didn't really find any "real" valid-pointer checking, but it might be somewhere down the call stack. I guess the goal here is just to make sure the ordering of the `if (workspace_size)` block and the `can_implement(arguments)` call is not switched by mistake, even though it could be without causing problems when the if condition is false (which is always the case during execution, as None is passed here: https://github.com/OuadiElfarouki/pytorch/blob/cc172171ea1bafe2138ea741fe30b00f12609bd8/torch/_inductor/codegen/xpu/sycl_kernel.py#L305).
    CUTLASS_CHECK(status);
  }
  {
    auto status = gemm_op.run();
I think the second point is still open: is there a downside here of using `run` instead of `operator()`?
if op.gemm_kind not in self._get_supported_ops():
    return None
This is still open.
    and self.layout_match(W.get_layout(), op.B.layout)
):
    return None
This is still open.
""" | ||
sizes = self.output_node.get_size() | ||
if len(sizes) > 2: | ||
return "cutlass::gemm::GemmUniversalMode::kBatched" |
Does batched GEMM currently work with our CUTLASS and `CollectiveBuilder`?
I don't think we have it enabled yet; @aacostadiaz would confirm.
    return False


@functools.lru_cache(8)
Does this need a cache?
        and _use_autotune_backend("CUTLASS")
    )

    from .codegen.xpu.cutlass_utils import try_import_cutlass
Should this be after `if res` to avoid running it when not necessary?
Agreed, I just wanted to respect the usual way modules are imported in pytorch, which seems to be at the top of function/class scope rather than within conditional blocks, even when they end up unused.
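For reference, a minimal sketch of the suggested placement, with the import deferred until the cheap check has passed; the wrapper function name and the bare `res` flag are hypothetical stand-ins for the surrounding code:

```python
def _cutlass_backend_applicable(res: bool) -> bool:
    # Bail out early on the cheap condition before paying for the import.
    if not res:
        return False
    # Deferred import: only executed when the CUTLASS path is actually considered.
    from torch._inductor.codegen.xpu.cutlass_utils import try_import_cutlass
    return try_import_cutlass()
```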
torch/_inductor/config.py (outdated)
cutlass_max_profiling_configs: Optional[int] = None

# The L2 swizzle values to consider when profiling CUTLASS configs in max_autotune.
cutlass_max_profiling_swizzle_options: list[int] = [1]  # TODO(SYCL): Currently set to 1 value until benchmarking is supported
I believe we discussed that this is unlikely to make a difference since we are not using SLM; maybe you should add a comment about that?
Yes, will do.
#endif
#endif
  {
    auto status = gemm_op.initialize(arguments, workspace);
You can still add the queue now, it just might not be used yet. The interface takes a `sycl::queue*` cast to `void*`. The PR to fix that is just waiting for review.
import cutlass_library.library as cutlass_lib  # noqa: F401

assert isinstance(op, cutlass_gemm_op.GemmOperation), (
    "op argument is required and has to be an instance of GemmOperation"
Why is it defaulted to `None` then?
Yes, quite confusing, but I believe it has to do with the inheritance from the `SYCLTemplate` class, which doesn't have an `op` argument in its `render` method. So we kind of keep the inherited & overridden interface (`# type: ignore[override]`) consistent in declaration, but force the requirement at execution time (hence the assert on the op type).
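For future readers, a minimal self-contained sketch of the pattern described here, with hypothetical class names: the base class declares `render()` without an `op` parameter, so the override keeps `op=None` for signature compatibility and enforces the real requirement at call time:

```python
from typing import Any, Optional


class TemplateBase:
    # Base interface without an `op` parameter.
    def render(self, **kwargs: Any) -> str:
        raise NotImplementedError


class GemmTemplateSketch(TemplateBase):
    # Keeps the inherited signature shape (hence op=None and the type: ignore),
    # but a valid op is still required once render() is actually called.
    def render(self, op: Optional[Any] = None, **kwargs: Any) -> str:  # type: ignore[override]
        assert op is not None, "op argument is required"
        return f"rendered {op}"


print(GemmTemplateSketch().render(op="gemm_operation"))
```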
if len(A_size) < 2:
    A_size.insert(0, 1)
if len(B_size) < 2:
    A_size.insert(1, 1)
Is this supposed to be `A_size` again?
torch/_inductor/autotune_process.py (outdated)
if self._workspace_size_updated:
    return
# Harcoded temporarily for testing with known kernels
self.workspace_size = 4096  # Fixed size for PoC
The PVC GEMM implementation doesn't use SLM, so this could be 0.
    return None

if all(dtype == torch.bfloat16 for dtype in input_torch_dtypes):
    return torch.float
Add a TODO here, I think.
""" | ||
# TODO (SYCL): Extend for other types & double-check alignments | ||
if torch_dtype == torch.bfloat16: | ||
return [8, 4, 2, 1] |
Is this in bytes? I would guess bf16 needs to be at least 2 and float needs to be at least 4.
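For what it's worth, a quick sanity check assuming these candidate alignments are counted in elements rather than bytes (as in the CUDA version of this helper, though worth confirming): 8 bf16 elements and 4 fp32 elements both correspond to 16-byte accesses.

```python
import torch


def alignment_bytes(torch_dtype: torch.dtype, alignment_elements: int) -> int:
    # Convert an element-count alignment into a byte alignment for the dtype.
    return alignment_elements * torch_dtype.itemsize


assert alignment_bytes(torch.bfloat16, 8) == 16  # 8 * 2 bytes
assert alignment_bytes(torch.float32, 4) == 16   # 4 * 4 bytes
```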
Approving this PR. There are some open questions; please address them, or at least reply to them in the comments for future reference, when merging the PR.
Thanks for all your work, Ouie
Summary: This patch enables an initial execution of `torch.mm` through an inductor-generated SYCL CUTLASS kernel for Intel PVC. Following the reference CUDA implementation, it implements the following key functionalities:

For template generation & rendering:
- `SYCLTemplate`, `CUTLASSTemplate`, `CUTLASSGemmTemplate`, `CUTLASS3xGemmTemplate`: handle generating the full C++ code, from the call to `GeneratePVC` to get the `GemmOperations` (exposed by cutlass_library), filtering the operations, constructing the Manifest & extracting the gemm instance from the emitter (exposed by cutlass_library), up to the full wrapping of the C++ template code with runtime arguments & the final kernel launch.
- `cutlass_utils.py`: utility file containing relevant functions used across the codegen process.
- `SYCLKernel`, `SYCLTemplateKernel`: handle the higher-level kernel template and the kernel call from the host side. Used within the previous template classes.

For autotuning:
- `SYCLBenchmarkRequest`: currently added as an almost-dummy, not really benchmarking, since we're selecting a single generated configuration for this PoC.

For wrapping & triggering the above:
- `SYCLTemplateCaller`: wrapper holding a ready-to-compile/execute/benchmark SYCL template kernel. This is the higher-level construct that's added to the list of "choices" in the autotuning process for selecting the best configuration.

For scheduling/execution:
- `SYCLCPPScheduling` & `SYCLCombinedScheduling`: orchestrators of kernel calls across nodes with eventually different lowerings (Triton & CUTLASS SYCL, for instance). Few changes have been made to these compared to the original CUDA implementation.

The current state was fine-tuned to support the only type configuration exposed by CUTLASS on PVC so far, i.e. `bfloat16` input and `fp32` accumulation, forcing some workarounds on the pytorch side, namely related to the D (layout/output node) & C (source/input_node[2]) dtypes.

Unsupported features and/or partially implemented ones are highlighted with `TODO (SYCL)` comments.
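For context, a hypothetical end-to-end usage sketch of the flow described above, assuming the standard inductor max-autotune knobs gate the new XPU CUTLASS backend the same way they do for CUDA (exact option names and device availability may differ in this PR):

```python
import torch
from torch._inductor import config

# Assumed configuration: restrict GEMM autotuning to the CUTLASS backend.
config.max_autotune = True
config.max_autotune_gemm_backends = "CUTLASS"

# bfloat16 inputs with fp32 accumulation, matching the only PVC configuration
# currently exposed by CUTLASS (per the summary above).
a = torch.randn(512, 512, device="xpu", dtype=torch.bfloat16)
b = torch.randn(512, 512, device="xpu", dtype=torch.bfloat16)

compiled_mm = torch.compile(torch.mm)
out = compiled_mm(a, b)  # lowered through the SYCL CUTLASS template when selected
```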