@@ -128,7 +128,8 @@ def triton_scaled_mm(input: torch.Tensor,
128
128
bias : Optional [torch .Tensor ] = None ,
129
129
block_size_m : int = 32 ,
130
130
block_size_n : int = 32 ,
131
- block_size_k : int = 32 ) -> torch .Tensor :
131
+ block_size_k : int = 32 ,
132
+ use_heuristic = True ) -> torch .Tensor :
132
133
M , K = input .shape
133
134
N = weight .shape [1 ]
134
135
@@ -152,6 +153,20 @@ def triton_scaled_mm(input: torch.Tensor,
152
153
153
154
has_scalar = lambda x : x .shape [0 ] == 1 and x .shape [1 ] == 1
154
155
156
+ if use_heuristic :
157
+ is_small_N = N < 8192
158
+ next_power_of_2_M = max (32 , triton .next_power_of_2 (M ))
159
+ if next_power_of_2_M <= 32 :
160
+ tile_shape = (64 , 64 , 256 ) if is_small_N else (64 , 128 , 256 )
161
+ elif next_power_of_2_M <= 64 :
162
+ tile_shape = (64 , 64 , 256 )
163
+ elif next_power_of_2_M <= 128 :
164
+ tile_shape = (64 , 128 , 128 )
165
+ else :
166
+ tile_shape = (128 , 128 , 128 )
167
+
168
+ block_size_m , block_size_n , block_size_k = tile_shape
169
+
155
170
block_size_sa = 1 if has_scalar (scale_a ) else block_size_m
156
171
block_size_sb = 1 if has_scalar (scale_b ) else block_size_n
157
172
0 commit comments