1
1
# mypy: allow-untyped-defs, disable-error-code="attr-defined, valid-type"
2
2
import copy
3
3
import logging
4
+ import math
4
5
import random
6
+ from collections import namedtuple
5
7
from typing import Optional
6
8
7
9
import sympy
22
24
23
25
log = logging .getLogger (__name__ )
24
26
27
+ # lightweight collection of information about a single op
28
+ InductorROCmOp = namedtuple ("InductorROCmOp" , ["op" , "kBatch" ])
29
+
30
+ padding_lookup = {
31
+ "M" : {
32
+ "GemmSpecialization::MPadding" : True ,
33
+ "GemmSpecialization::MNPadding" : True ,
34
+ "GemmSpecialization::MKPadding" : True ,
35
+ "GemmSpecialization::MNKPadding" : True ,
36
+ },
37
+ "N" : {
38
+ "GemmSpecialization::NPadding" : True ,
39
+ "GemmSpecialization::MNPadding" : True ,
40
+ "GemmSpecialization::NKPadding" : True ,
41
+ "GemmSpecialization::MNKPadding" : True ,
42
+ },
43
+ "K" : {
44
+ "GemmSpecialization::KPadding" : True ,
45
+ "GemmSpecialization::MKPadding" : True ,
46
+ "GemmSpecialization::NKPadding" : True ,
47
+ "GemmSpecialization::MNKPadding" : True ,
48
+ },
49
+ }
50
+
25
51
26
52
def is_static_int (number ):
27
53
return isinstance (number , (int , sympy .Integer ))
@@ -363,7 +389,14 @@ def inline_utils(self):
363
389
)
364
390
return res
365
391
366
- def filter_op (self , op : "CKGemmOperation" ):
392
+ def _has_padding (self , dimension , gemm_specialization ):
393
+ # Get the relevant padding map for the given dimension
394
+ dimension_padding = padding_lookup .get (dimension , {})
395
+
396
+ # Check if the specialization is in the dimension's padding map
397
+ return dimension_padding .get (gemm_specialization , False )
398
+
399
+ def filter_op (self , op_info : InductorROCmOp ):
367
400
"""
368
401
Determines whether a given op definition is suitable for the current
369
402
input / output of the operation that this template implements.
@@ -372,6 +405,7 @@ def filter_op(self, op: "CKGemmOperation"):
372
405
373
406
Returns None if the op is not suitable, otherwise returns the op to be used.
374
407
"""
408
+ op , kBatch = op_info .op , op_info .kBatch
375
409
metas = [T .get_layout () for T in [* self .input_nodes , self .output_node ]]
376
410
X_meta = metas [0 ]
377
411
W_meta = metas [1 ]
@@ -398,26 +432,27 @@ def filter_op(self, op: "CKGemmOperation"):
398
432
N = W_meta .size [- 1 ]
399
433
400
434
if is_static_int (M ):
401
- if not any (
402
- m_padding in op .gemm_specialization
403
- for m_padding in ["MPadding" , "MNPadding" , "MKPadding" , "MNKPadding" ]
404
- ):
435
+ if not self ._has_padding ("M" , op .gemm_specialization ):
405
436
if M % op .m_per_block != 0 :
406
437
return None
407
438
if is_static_int (N ):
408
- if not any (
409
- n_padding in op .gemm_specialization
410
- for n_padding in ["NPadding" , "MNPadding" , "NKPadding" , "MNKPadding" ]
411
- ):
439
+ if not self ._has_padding ("N" , op .gemm_specialization ):
412
440
if N % op .n_per_block != 0 :
413
441
return None
414
442
if is_static_int (K ):
415
- if not any (
416
- k_padding in op .gemm_specialization
417
- for k_padding in ["KPadding" , "MKPadding" , "NKPadding" , "MNKPadding" ]
418
- ):
443
+ if not self ._has_padding ("K" , op .gemm_specialization ):
419
444
if K % op .k_per_block != 0 :
420
445
return None
446
+ K_t = kBatch * op .k_per_block
447
+ if K % K_t != 0 :
448
+ return None
449
+ else :
450
+ # need another kBatch check here
451
+ lcm = abs (op .a_k1 * op .b_k1 ) // math .gcd (op .a_k1 , op .b_k1 )
452
+ K_t = kBatch * lcm
453
+ k_read_pad_splited = math .ceil (K / K_t ) * lcm
454
+ if (k_read_pad_splited * (kBatch - 1 )) >= K :
455
+ return None
421
456
422
457
a_contig_size = (
423
458
K if op .a_layout == "Row" else M if op .a_layout == "Col" else None
@@ -451,12 +486,83 @@ def filter_op(self, op: "CKGemmOperation"):
451
486
!= 0
452
487
):
453
488
return None
454
-
489
+ if not self ._check_num_k_loops (op , kBatch ):
490
+ return None
455
491
# TBD disable instances with invalid number of pipeline prefetch stages
456
492
# It will avoid compiling a small percentage of unrunnable instances which fail the gemm argument check
457
493
458
494
return op
459
495
496
+ def _check_num_k_loops (self , op , kBatch ):
497
+ # Additional splitK scenario check
498
+ metas = [T .get_layout () for T in [* self .input_nodes ]]
499
+ X_meta = metas [0 ]
500
+ W_meta = metas [1 ]
501
+ K = X_meta .size [- 1 ]
502
+ if kBatch > 1 :
503
+ if op .block_gemm_pipeline_version != "BlockGemmPipelineVersion::v1" :
504
+ try :
505
+ prefetch_stages = self ._prefetch_stages (
506
+ op ,
507
+ torch .empty ((), dtype = X_meta .dtype ).element_size (),
508
+ torch .empty ((), dtype = W_meta .dtype ).element_size (),
509
+ torch .cuda .get_device_properties (X_meta .device ).warp_size ,
510
+ )
511
+ except Exception as e :
512
+ log .debug (
513
+ "Failed to prefetch_stages for %s with exception %s" , op .name , e
514
+ )
515
+ # be conservative here and disable the op
516
+ return False
517
+
518
+ K_t = op .k_per_block * kBatch
519
+ ak0 = (K + K_t - 1 ) // K_t * (op .k_per_block // op .a_k1 )
520
+ num_k_loop = ak0 // (op .k_per_block // op .a_k1 )
521
+ if num_k_loop <= prefetch_stages :
522
+ log .debug (
523
+ "Op %s is not compatible due to invalid number of pipeline prefetch stages. "
524
+ "Parameters: kBatch=%s, block_gemm_pipeline_version=%s, prefetch_stages=%s, num_k_loop=%s" ,
525
+ op .name (),
526
+ kBatch ,
527
+ op .block_gemm_pipeline_version ,
528
+ prefetch_stages ,
529
+ num_k_loop ,
530
+ )
531
+ return False
532
+
533
+ return True
534
+
535
+ # small helper to figure out the prefetch stages on AMD
536
+ def _prefetch_stages (self , op , a_dtype_size , b_dtype_size , warp_size : int = 64 ):
537
+ version_str = op .block_gemm_pipeline_version .split ("::" )[- 1 ]
538
+ try :
539
+ version = int (version_str [1 :]) # Assuming the format is always 'vX'
540
+ except ValueError as e :
541
+ raise ValueError (f"Invalid version string: { version_str } " ) from e
542
+ if version not in [1 , 2 , 3 , 4 , 5 ]:
543
+ raise ValueError (
544
+ f"unknown prefetch stages for { op .block_gemm_pipeline_version } "
545
+ )
546
+ # Define the mapping of versions to stages
547
+ version_to_stages = {1 : 1 , 3 : 2 , 4 : 4 , 5 : 3 }
548
+ # Get the stages for the given version
549
+ stages = version_to_stages .get (version , None )
550
+ if stages is None :
551
+ # This means we're at stage 2, and this requires computation
552
+ # See github.com/ROCm/composable_kernel/blob/d6a4605/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp#L143 # noqa: B950
553
+ wgp_per_cu = max (4 * warp_size // op .block_size , 1 )
554
+ full_mem_band_prefetch_stages = math .ceil (
555
+ 32768
556
+ / wgp_per_cu
557
+ / (
558
+ (op .m_per_block * a_dtype_size + op .n_per_block * b_dtype_size )
559
+ * op .k_per_block
560
+ )
561
+ )
562
+ stages = min (max (full_mem_band_prefetch_stages , 2 ), 8 )
563
+
564
+ return stages
565
+
460
566
def emit_ck_instance (self , op : "CKGemmOperation" ):
461
567
# The Jinja template for generating a C++ type alias *definition* for a Universal GEMM instance
462
568
struct_name = (
@@ -765,7 +871,7 @@ def _is_rcr_f16(self):
765
871
and Y_layout == "Row"
766
872
)
767
873
768
- def gen_ops (self ):
874
+ def gen_ops (self ) -> list [ InductorROCmOp ] :
769
875
"""
770
876
Creates a list of `CKGemmOperation` instances that match the GEMM operation this template represents.
771
877
The instances are guaranteed to have the correct layout, dtype and dimension padding for the GEMM input arguments.
@@ -794,7 +900,17 @@ def gen_ops(self):
794
900
795
901
assert generator is not None
796
902
797
- filtered_instances = list (filter (lambda op : self .filter_op (op ), generator ()))
903
+ # NOTE(coconutruben): for now, we only support kBatch 1
904
+ # TODO(coconturuben): infer a better kBatch depending on the input shape
905
+ # TODO(coconutruben): allow users to provide a list of kBatches to sweep over
906
+ kBatches = [1 ]
907
+ rops = generator ()
908
+ ops = [
909
+ InductorROCmOp (op = op , kBatch = kBatch ) for op in rops for kBatch in kBatches
910
+ ]
911
+
912
+ filtered_instances = list (filter (lambda op : self .filter_op (op ), ops ))
913
+
798
914
# NB: when using a fixed list order, most likely we will pick the subset of instances
799
915
# which are very similar to each other. Randomizing the choice seems to solve this.
800
916
random .seed (- 11 )
@@ -836,8 +952,8 @@ def add_ck_gemm_choices(
836
952
for op in ops :
837
953
template .maybe_append_choice (
838
954
choices ,
839
- op = op ,
840
- kBatch = 1 ,
955
+ op = op . op ,
956
+ kBatch = op . kBatch ,
841
957
)
842
958
843
959
def size_args (self ):
0 commit comments