
Commit f0d0042

coconutruben authored and pytorchmergebot committed
[inductor][ck] kBatch filtering with gen_ops (pytorch#148004)
Summary:

# Why

Not all choices of kBatch are valid; invalid values lead to a runtime error when CK checks the validity of the kernel arguments:
https://github.com/ROCm/composable_kernel/blob/c9bcfd755ed4d2102d76a6f545ac6e9a030d7d8e/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp#L1020

# What

- move kBatch inside gen_ops to have more control over it and be able to filter it
- expand the filtering based on the C++ logic
- refactor the padding checks to be more readable

Test Plan:

```
buck2 run -c fbcode.re_gpu_tests=False mode/opt-amd-gpu fbcode//deeplearning/aot_inductor/benchmark/sampling:test_gemm_autotune_benchmark_AMD_block_0
```

- kBatch = 128: some filtering
- kBatch = 1: no filtering
- kBatch = 1738: all options filtered out

Reviewed By: henrylhtsang

Differential Revision: D70211442

Pull Request resolved: pytorch#148004

Approved by: https://github.com/ColinPeppler, https://github.com/tenpercent
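As a rough illustration of the behavior reported in the test plan, the sketch below pairs candidate ops with kBatch values and screens them with the split-K divisibility rule this change adds to `filter_op`. `FakeOp`, its field values, and `K` are made-up stand-ins for illustration, not values from the benchmark or the exact CK conditions.

```python
from collections import namedtuple

# Made-up stand-in for a CK op descriptor; only the field needed here.
FakeOp = namedtuple("FakeOp", ["name", "k_per_block"])


def keeps_op(K: int, op: FakeOp, k_batch: int) -> bool:
    # Without K padding, K must tile exactly into k_per_block blocks and
    # into kBatch * k_per_block chunks, or CK rejects the kernel arguments.
    return K % op.k_per_block == 0 and K % (k_batch * op.k_per_block) == 0


K = 8192  # illustrative static K dimension
ops = [FakeOp("op64", 64), FakeOp("op128", 128), FakeOp("op256", 256)]
for k_batch in (1, 128, 1738):
    kept = [op.name for op in ops if keeps_op(K, op, k_batch)]
    print(k_batch, kept)
# kBatch=1 keeps everything, kBatch=128 filters some ops,
# kBatch=1738 filters every op out.
```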
1 parent ce805a5 commit f0d0042

1 file changed (+134, -18):

torch/_inductor/codegen/rocm/ck_universal_gemm_template.py
```diff
@@ -1,7 +1,9 @@
 # mypy: allow-untyped-defs, disable-error-code="attr-defined, valid-type"
 import copy
 import logging
+import math
 import random
+from collections import namedtuple
 from typing import Optional
 
 import sympy
```
```diff
@@ -22,6 +24,30 @@
 
 log = logging.getLogger(__name__)
 
+# lightweight collection of information about a single op
+InductorROCmOp = namedtuple("InductorROCmOp", ["op", "kBatch"])
+
+padding_lookup = {
+    "M": {
+        "GemmSpecialization::MPadding": True,
+        "GemmSpecialization::MNPadding": True,
+        "GemmSpecialization::MKPadding": True,
+        "GemmSpecialization::MNKPadding": True,
+    },
+    "N": {
+        "GemmSpecialization::NPadding": True,
+        "GemmSpecialization::MNPadding": True,
+        "GemmSpecialization::NKPadding": True,
+        "GemmSpecialization::MNKPadding": True,
+    },
+    "K": {
+        "GemmSpecialization::KPadding": True,
+        "GemmSpecialization::MKPadding": True,
+        "GemmSpecialization::NKPadding": True,
+        "GemmSpecialization::MNKPadding": True,
+    },
+}
+
 
 def is_static_int(number):
     return isinstance(number, (int, sympy.Integer))
```
```diff
@@ -363,7 +389,14 @@ def inline_utils(self):
         )
         return res
 
-    def filter_op(self, op: "CKGemmOperation"):
+    def _has_padding(self, dimension, gemm_specialization):
+        # Get the relevant padding map for the given dimension
+        dimension_padding = padding_lookup.get(dimension, {})
+
+        # Check if the specialization is in the dimension's padding map
+        return dimension_padding.get(gemm_specialization, False)
+
+    def filter_op(self, op_info: InductorROCmOp):
         """
         Determines whether a given op definition is suitable for the current
         input / output of the operation that this template implements.
```
```diff
@@ -372,6 +405,7 @@ def filter_op(self, op: "CKGemmOperation"):
 
         Returns None if the op is not suitable, otherwise returns the op to be used.
         """
+        op, kBatch = op_info.op, op_info.kBatch
         metas = [T.get_layout() for T in [*self.input_nodes, self.output_node]]
         X_meta = metas[0]
         W_meta = metas[1]
```
```diff
@@ -398,26 +432,27 @@ def filter_op(self, op: "CKGemmOperation"):
         N = W_meta.size[-1]
 
         if is_static_int(M):
-            if not any(
-                m_padding in op.gemm_specialization
-                for m_padding in ["MPadding", "MNPadding", "MKPadding", "MNKPadding"]
-            ):
+            if not self._has_padding("M", op.gemm_specialization):
                 if M % op.m_per_block != 0:
                     return None
         if is_static_int(N):
-            if not any(
-                n_padding in op.gemm_specialization
-                for n_padding in ["NPadding", "MNPadding", "NKPadding", "MNKPadding"]
-            ):
+            if not self._has_padding("N", op.gemm_specialization):
                 if N % op.n_per_block != 0:
                     return None
         if is_static_int(K):
-            if not any(
-                k_padding in op.gemm_specialization
-                for k_padding in ["KPadding", "MKPadding", "NKPadding", "MNKPadding"]
-            ):
+            if not self._has_padding("K", op.gemm_specialization):
                 if K % op.k_per_block != 0:
                     return None
+                K_t = kBatch * op.k_per_block
+                if K % K_t != 0:
+                    return None
+            else:
+                # need another kBatch check here
+                lcm = abs(op.a_k1 * op.b_k1) // math.gcd(op.a_k1, op.b_k1)
+                K_t = kBatch * lcm
+                k_read_pad_splited = math.ceil(K / K_t) * lcm
+                if (k_read_pad_splited * (kBatch - 1)) >= K:
+                    return None
 
         a_contig_size = (
             K if op.a_layout == "Row" else M if op.a_layout == "Col" else None
```
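For the padded branch above, a quick worked example shows how an oversized kBatch is rejected because the trailing splits would read nothing but padding; the `a_k1`, `b_k1`, and `K` values below are made up purely to exercise the check.

```python
import math

a_k1, b_k1 = 8, 2  # made-up vector load sizes; lcm(8, 2) = 8
K = 100            # made-up K that needs padding


def padded_splitk_ok(k_batch: int) -> bool:
    lcm = abs(a_k1 * b_k1) // math.gcd(a_k1, b_k1)
    k_read_pad_splited = math.ceil(K / (k_batch * lcm)) * lcm
    # Reject if the first (kBatch - 1) splits already cover all of K,
    # i.e. at least one split would have no real work left.
    return k_read_pad_splited * (k_batch - 1) < K


print(padded_splitk_ok(4))  # True:  ceil(100/32) * 8 = 32; 32 * 3 = 96  < 100
print(padded_splitk_ok(8))  # False: ceil(100/64) * 8 = 16; 16 * 7 = 112 >= 100
```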
```diff
@@ -451,12 +486,83 @@ def filter_op(self, op: "CKGemmOperation"):
             != 0
         ):
             return None
-
+        if not self._check_num_k_loops(op, kBatch):
+            return None
         # TBD disable instances with invalid number of pipeline prefetch stages
         # It will avoid compiling a small percentage of unrunnable instances which fail the gemm argument check
 
         return op
 
+    def _check_num_k_loops(self, op, kBatch):
+        # Additional splitK scenario check
+        metas = [T.get_layout() for T in [*self.input_nodes]]
+        X_meta = metas[0]
+        W_meta = metas[1]
+        K = X_meta.size[-1]
+        if kBatch > 1:
+            if op.block_gemm_pipeline_version != "BlockGemmPipelineVersion::v1":
+                try:
+                    prefetch_stages = self._prefetch_stages(
+                        op,
+                        torch.empty((), dtype=X_meta.dtype).element_size(),
+                        torch.empty((), dtype=W_meta.dtype).element_size(),
+                        torch.cuda.get_device_properties(X_meta.device).warp_size,
+                    )
+                except Exception as e:
+                    log.debug(
+                        "Failed to prefetch_stages for %s with exception %s", op.name, e
+                    )
+                    # be conservative here and disable the op
+                    return False
+
+                K_t = op.k_per_block * kBatch
+                ak0 = (K + K_t - 1) // K_t * (op.k_per_block // op.a_k1)
+                num_k_loop = ak0 // (op.k_per_block // op.a_k1)
+                if num_k_loop <= prefetch_stages:
+                    log.debug(
+                        "Op %s is not compatible due to invalid number of pipeline prefetch stages. "
+                        "Parameters: kBatch=%s, block_gemm_pipeline_version=%s, prefetch_stages=%s, num_k_loop=%s",
+                        op.name(),
+                        kBatch,
+                        op.block_gemm_pipeline_version,
+                        prefetch_stages,
+                        num_k_loop,
+                    )
+                    return False
+
+        return True
+
+    # small helper to figure out the prefetch stages on AMD
+    def _prefetch_stages(self, op, a_dtype_size, b_dtype_size, warp_size: int = 64):
+        version_str = op.block_gemm_pipeline_version.split("::")[-1]
+        try:
+            version = int(version_str[1:])  # Assuming the format is always 'vX'
+        except ValueError as e:
+            raise ValueError(f"Invalid version string: {version_str}") from e
+        if version not in [1, 2, 3, 4, 5]:
+            raise ValueError(
+                f"unknown prefetch stages for {op.block_gemm_pipeline_version}"
+            )
+        # Define the mapping of versions to stages
+        version_to_stages = {1: 1, 3: 2, 4: 4, 5: 3}
+        # Get the stages for the given version
+        stages = version_to_stages.get(version, None)
+        if stages is None:
+            # This means we're at stage 2, and this requires computation
+            # See github.com/ROCm/composable_kernel/blob/d6a4605/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp#L143  # noqa: B950
+            wgp_per_cu = max(4 * warp_size // op.block_size, 1)
+            full_mem_band_prefetch_stages = math.ceil(
+                32768
+                / wgp_per_cu
+                / (
+                    (op.m_per_block * a_dtype_size + op.n_per_block * b_dtype_size)
+                    * op.k_per_block
+                )
+            )
+            stages = min(max(full_mem_band_prefetch_stages, 2), 8)
+
+        return stages
+
     def emit_ck_instance(self, op: "CKGemmOperation"):
         # The Jinja template for generating a C++ type alias *definition* for a Universal GEMM instance
         struct_name = (
```
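To see when the new `_check_num_k_loops` rejects an op, here is a small numeric sketch; the shapes are made up, and the two prefetch stages correspond to the `BlockGemmPipelineVersion::v3` entry of the `version_to_stages` table above.

```python
k_per_block, a_k1, K = 64, 8, 1024  # made-up op parameters and K dimension
prefetch_stages = 2                 # v3 in the version_to_stages mapping


def num_k_loops(k_batch: int) -> int:
    # Same arithmetic as _check_num_k_loops: how many K iterations remain
    # per split once K is divided across kBatch partitions.
    K_t = k_per_block * k_batch
    ak0 = (K + K_t - 1) // K_t * (k_per_block // a_k1)
    return ak0 // (k_per_block // a_k1)


print(num_k_loops(2))  # 8 loops  > 2 prefetch stages -> op is kept
print(num_k_loops(8))  # 2 loops <= 2 prefetch stages -> op is filtered out
```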
```diff
@@ -765,7 +871,7 @@ def _is_rcr_f16(self):
             and Y_layout == "Row"
         )
 
-    def gen_ops(self):
+    def gen_ops(self) -> list[InductorROCmOp]:
         """
         Creates a list of `CKGemmOperation` instances that match the GEMM operation this template represents.
         The instances are guaranteed to have the correct layout, dtype and dimension padding for the GEMM input arguments.
```
```diff
@@ -794,7 +900,17 @@ def gen_ops(self):
 
         assert generator is not None
 
-        filtered_instances = list(filter(lambda op: self.filter_op(op), generator()))
+        # NOTE(coconutruben): for now, we only support kBatch 1
+        # TODO(coconutruben): infer a better kBatch depending on the input shape
+        # TODO(coconutruben): allow users to provide a list of kBatches to sweep over
+        kBatches = [1]
+        rops = generator()
+        ops = [
+            InductorROCmOp(op=op, kBatch=kBatch) for op in rops for kBatch in kBatches
+        ]
+
+        filtered_instances = list(filter(lambda op: self.filter_op(op), ops))
+
         # NB: when using a fixed list order, most likely we will pick the subset of instances
         # which are very similar to each other. Randomizing the choice seems to solve this.
         random.seed(-11)
```
```diff
@@ -836,8 +952,8 @@ def add_ck_gemm_choices(
         for op in ops:
             template.maybe_append_choice(
                 choices,
-                op=op,
-                kBatch=1,
+                op=op.op,
+                kBatch=op.kBatch,
             )
 
     def size_args(self):
```
