from typing import Optional

import torch
import triton
import triton.language as tl


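# A 2-D tensor is treated as "weakly contiguous" if one dimension has unit
# stride and the other stride is at least the size of that unit-stride
# dimension, i.e. the tensor is stored row-major or column-major (a transposed
# view or a padded leading dimension is fine). That is all the stride-based
# addressing in the kernel below requires.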
def is_weak_contiguous(x: torch.Tensor):
    strides = x.stride()
    sizes = x.shape
    is_not_transpose = strides[0] == 1 and (strides[1] >= max(1, sizes[0]))
    is_transpose = strides[1] == 1 and (strides[0] >= max(1, sizes[1]))
    return is_transpose or is_not_transpose


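# Computes one (BLOCK_SIZE_M x BLOCK_SIZE_N) tile of
#   C = scale_a * (A @ B) * scale_b^T (+ bias)
# where scale_a is per-tensor or per-row and scale_b is per-tensor or
# per-column. The matmul is accumulated in ACCUMULATOR_DTYPE and the scales
# are applied once, after the K loop.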
@triton.jit
def scaled_mm_kernel(a_ptr, b_ptr, scale_a_ptr, scale_b_ptr, c_ptr, bias_ptr,
                     M, N, K, stride_am, stride_ak, stride_bk, stride_bn,
                     stride_cm, stride_cn, ACCUMULATOR_DTYPE: tl.constexpr,
                     BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
                     BLOCK_SIZE_K: tl.constexpr,
                     BLOCK_SIZE_SCALE_A: tl.constexpr,
                     BLOCK_SIZE_SCALE_B: tl.constexpr):
    pid = tl.program_id(axis=0)

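    # The 1-D launch grid is unraveled row-major into a 2-D grid of output
    # tiles: pid_m indexes the tile row, pid_n the tile column.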
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)

    pid_m = pid // num_pid_n
    pid_n = pid % num_pid_n

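    # Accumulate in a wider type (int32 for integer inputs, float32 for
    # floating-point inputs, as chosen by the host wrapper) to avoid overflow
    # and precision loss inside the K loop.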
    accumulator_dtype = ACCUMULATOR_DTYPE
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N),
                           dtype=accumulator_dtype)

    # NOTE: Some tensor inputs are large enough that int32 offsets would
    # overflow, so tl.int64 is used for all offsets; otherwise a SEGV will
    # eventually occur.

    # Offsets and masks.
    offsets_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
    masks_am = offsets_am < M

    offsets_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)
    masks_bn = offsets_bn < N

    offsets_k = tl.arange(0, BLOCK_SIZE_K).to(tl.int64)
    offsets_a = (stride_am * offsets_am[:, None] +
                 stride_ak * offsets_k[None, :])
    offsets_b = (stride_bk * offsets_k[:, None] +
                 stride_bn * offsets_bn[None, :])

    # NOTE: BLOCK_SIZE_SCALE_A could be 1 or BLOCK_SIZE_M, so appropriate
    # offsets and masks are created for each case. The same goes for
    # BLOCK_SIZE_SCALE_B.
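    # When BLOCK_SIZE_SCALE_A == 1 the (BLOCK_SIZE_SCALE_A > 1) factor below
    # is 0, so every program loads the single per-tensor scale; otherwise each
    # program loads one scale per output row (and likewise per column for B).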
    offsets_scale_am = (tl.arange(0, BLOCK_SIZE_SCALE_A) +
                        (BLOCK_SIZE_SCALE_A > 1) * pid_m * BLOCK_SIZE_M)
    masks_scale_am = offsets_scale_am < M

    offsets_scale_bn = (tl.arange(0, BLOCK_SIZE_SCALE_B) +
                        (BLOCK_SIZE_SCALE_B > 1) * pid_n * BLOCK_SIZE_N)
    masks_scale_bn = offsets_scale_bn < N

    a_ptrs = a_ptr + offsets_a
    b_ptrs = b_ptr + offsets_b

    scale_a_ptrs = scale_a_ptr + offsets_scale_am
    scale_b_ptrs = scale_b_ptr + offsets_scale_bn

    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        masks_k = offsets_k < K
        masks_a = masks_am[:, None] & masks_k[None, :]
        a = tl.load(a_ptrs, mask=masks_a)

        masks_b = masks_k[:, None] & masks_bn[None, :]
        b = tl.load(b_ptrs, mask=masks_b)

        # Accumulate results.
        accumulator = tl.dot(a, b, accumulator, out_dtype=accumulator_dtype)

        offsets_k += BLOCK_SIZE_K
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    # Apply scale at end.
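    # scale_a is loaded as (BLOCK_SIZE_SCALE_A, 1), broadcast to
    # (BLOCK_SIZE_M, 1), and applied row-wise; scale_b is applied column-wise
    # via the transpose below. The accumulator is converted to float32 before
    # scaling.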
    masks_scale_a = masks_scale_am[:, None] & (tl.arange(0, 1) < 1)[:, None]
    scale_a = tl.load(scale_a_ptrs[:, None], masks_scale_a)
    # Broadcast to the appropriate size; if scale_a is already
    # (BLOCK_SIZE_M, 1) this broadcast is a no-op. The same goes for scale_b
    # below.
    scale_a = scale_a.broadcast_to((BLOCK_SIZE_M, 1))
    accumulator = scale_a * accumulator.to(tl.float32)

    masks_scale_b = masks_scale_bn[:, None] & (tl.arange(0, 1) < 1)[None, :]
    scale_b = tl.load(scale_b_ptrs[:, None], masks_scale_b)
    scale_b = scale_b.broadcast_to((BLOCK_SIZE_N, 1))
    accumulator = scale_b.T * accumulator.to(tl.float32)

    # Convert to output format.
    c = accumulator.to(c_ptr.type.element_ty)

    # The bias is already in the output dtype, so add it after the conversion.
    if bias_ptr:
        offsets_bias = offsets_bn
        bias_ptrs = bias_ptr + offsets_bias
        bias_mask = offsets_bias < N
        bias = tl.load(bias_ptrs, bias_mask)
        c += bias

    # Save output
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)
    c_ptrs = (c_ptr + stride_cm * offs_cm[:, None] +
              stride_cn * offs_cn[None, :])
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)

    tl.store(c_ptrs, c, mask=c_mask)


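# Computes out = (scale_a * (input @ weight)) * scale_b.T + bias, where
# scale_a is a per-tensor [1, 1] or per-row [M, 1] scale and scale_b is a
# per-tensor [1, 1] or per-column [N, 1] scale. Inputs may be integer
# (e.g. int8) or floating-point; the result is returned in out_dtype.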
# input - [M, K]
# weight - [K, N]
def triton_scaled_mm(input: torch.Tensor,
                     weight: torch.Tensor,
                     scale_a: torch.Tensor,
                     scale_b: torch.Tensor,
                     out_dtype: torch.dtype,
                     bias: Optional[torch.Tensor] = None,
                     block_size_m: int = 32,
                     block_size_n: int = 32,
                     block_size_k: int = 32) -> torch.Tensor:
    M, K = input.shape
    N = weight.shape[1]

    assert N > 0 and K > 0 and M > 0
    assert weight.shape[0] == K
    assert input.dtype == weight.dtype
    assert scale_a.dtype == scale_b.dtype and scale_a.is_floating_point()
    assert scale_a.shape == torch.Size([1, 1]) or scale_a.shape == torch.Size(
        [M, 1])
    assert scale_b.shape == torch.Size([1, 1]) or scale_b.shape == torch.Size(
        [N, 1])
    assert out_dtype.is_floating_point
    assert bias is None or bias.is_floating_point()
    assert is_weak_contiguous(input)
    assert is_weak_contiguous(weight)

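    # One program per (BLOCK_SIZE_M x BLOCK_SIZE_N) tile of the output.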
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(
        N, META['BLOCK_SIZE_N']), )

    result = torch.empty((M, N), dtype=out_dtype, device=input.device)

    has_scalar = lambda x: x.shape[0] == 1 and x.shape[1] == 1

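    # A per-tensor scale is passed to the kernel with a block size of 1 (a
    # single scalar load); per-row / per-column scales use the full tile size.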
    block_size_sa = 1 if has_scalar(scale_a) else block_size_m
    block_size_sb = 1 if has_scalar(scale_b) else block_size_n

    accumulator_dtype = tl.float32 if input.is_floating_point() else tl.int32

    # A = input, B = weight, C = result
    # A = M x K, B = K x N, C = M x N
    scaled_mm_kernel[grid](input,
                           weight,
                           scale_a,
                           scale_b,
                           result,
                           bias,
                           M,
                           N,
                           K,
                           input.stride(0),
                           input.stride(1),
                           weight.stride(0),
                           weight.stride(1),
                           result.stride(0),
                           result.stride(1),
                           accumulator_dtype,
                           BLOCK_SIZE_M=block_size_m,
                           BLOCK_SIZE_N=block_size_n,
                           BLOCK_SIZE_K=block_size_k,
                           BLOCK_SIZE_SCALE_A=block_size_sa,
                           BLOCK_SIZE_SCALE_B=block_size_sb)

    return result.to(out_dtype)
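

# Illustrative usage sketch (not part of the kernel itself): an int8 GEMM with
# a per-row scale for the activations and a per-column scale for the weights.
# The shapes, dtypes, and device below are assumptions for demonstration only
# and require a CUDA device with Triton support.
if __name__ == "__main__":
    M, K, N = 64, 128, 256
    a = torch.randint(-128, 127, (M, K), dtype=torch.int8, device="cuda")
    b = torch.randint(-128, 127, (K, N), dtype=torch.int8, device="cuda")
    scale_a = torch.rand(M, 1, device="cuda")  # per-row scales for `a`
    scale_b = torch.rand(N, 1, device="cuda")  # per-column scales for `b`

    out = triton_scaled_mm(a, b, scale_a, scale_b, torch.float16)

    # Rough sanity check against a float32 reference.
    ref = (scale_a * (a.float() @ b.float()) * scale_b.t()).to(torch.float16)
    print("max abs diff:", (out.float() - ref.float()).abs().max().item())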