# SPDX-License-Identifier: Apache-2.0

import pytest
import torch

from vllm.model_executor.layers.lightning_attn import (
    lightning_attention, linear_decode_forward_triton)
from vllm.platforms import current_platform

NUM_HEADS = [4, 8]
HEAD_SIZES = [64]
BATCH_SIZES = [1, 2]
SEQ_LENGTHS = [16]
DTYPES = [torch.float32]


def reference_lightning_attention(q, k, v, ed, block_size, kv_history):
    """Reference implementation of the lightning attention core algorithm.

    Unlike the main implementation, this processes each timestep
    sequentially instead of using parallelized Triton kernels.
    """
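    # Recurrence implemented below, per batch b, head h and timestep t:
    #   kv_t  = exp(-ed[h]) * kv_{t-1} + outer(k_t, v_t)   # [D, E] state
    #   out_t = q_t @ kv_t                                  # [E]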
    B, H, S, D = q.shape
    E = v.shape[-1]
    dtype = q.dtype
    output = torch.zeros((B, H, S, E), dtype=dtype, device=q.device)

    # Use clone() to ensure an independent copy
    if kv_history is None:
        kv_cache = torch.zeros((B, H, D, E), dtype=dtype, device=q.device)
    else:
        kv_cache = kv_history.clone()

    # Convert the per-head decay factors to a broadcastable [1, H, 1, 1] form
    if ed.dim() == 1:
        decay = torch.exp(-ed).view(1, -1, 1, 1)
    else:
        decay = torch.exp(-ed)

    for b in range(B):
        for step in range(S):
            # Slice out all heads at this position
            q_bs = q[b, :, step]  # [H, D]
            k_bs = k[b, :, step]  # [H, D]
            v_bs = v[b, :, step]  # [H, E]

            # Update the KV state and compute the output for each head
            for h in range(H):
                # Calculate the KV outer product
                kv_outer = torch.outer(k_bs[h], v_bs[h])

                # Update KV cache with decay
                # Note: Using the same order as in the Triton kernel
                kv_cache[b, h] = decay[0, h, 0, 0] * kv_cache[b, h] + kv_outer

                # Calculate attention output
                output[b, h, step] = torch.matmul(q_bs[h], kv_cache[b, h])

    # Match the shape returned by the actual implementation:
    # a tensor of shape [B, H, 2, D, E] where dimension 2 holds both
    # the KV state and the KV history
    kv_reshaped = kv_cache.unsqueeze(2)  # [B, H, 1, D, E]
    final_kv_cache = torch.cat([kv_reshaped, kv_reshaped],
                               dim=2)  # [B, H, 2, D, E]

    return output, final_kv_cache


def reference_linear_decode(q, k, v, kv_caches, slope_rate, slot_idx):
    """Reference implementation of the linear attention decode step."""
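    # Single-token (S == 1) case of the recurrence above:
    #   kv_new = outer(k, v) + exp(-slope_rate[h]) * kv_old
    #   out    = q @ kv_new, written into a flat [B, H * D] output row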
    B, H, _, D = q.shape
    output = torch.zeros(B, H * D, dtype=q.dtype, device=q.device)

    # Compute the per-head decay factors once
    decay = torch.exp(-slope_rate).view(-1, 1, 1)  # [H, 1, 1]

    # Process each batch
    for b in range(B):
        slot_id = slot_idx[b].item()

        # Skip padding positions
        if slot_id == -1:
            continue

        # Slice out all heads for this batch
        q_b = q[b, :, 0]  # [H, D]
        k_b = k[b, :, 0]  # [H, D]
        v_b = v[b, :, 0]  # [H, D]

        # Process each attention head
        for h in range(H):
            # Get the current query, key and value
            q_bh = q_b[h]
            k_bh = k_b[h]
            v_bh = v_b[h]

            # Get the cached KV state
            kv_cache_old = kv_caches[b, h]

            # Calculate the new key-value outer product
            kv_outer = torch.outer(k_bh, v_bh)

            # Apply decay and update the cache
            kv_new = kv_outer + decay[h, 0, 0] * kv_cache_old

            # Calculate the output
            out_h = torch.matmul(q_bh, kv_new)

            # Update output and cache
            output[b, h * D:(h + 1) * D] = out_h
            kv_caches[b, h] = kv_new

    return output


@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@torch.inference_mode()
def test_linear_decode_forward_triton(
    batch_size: int,
    num_heads: int,
    head_size: int,
    dtype: torch.dtype,
):
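    """Compare the Triton decode kernel against the sequential reference."""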
    torch.set_default_device("cuda")
    torch.manual_seed(42)
    torch.cuda.manual_seed_all(42)
    current_platform.seed_everything(42)
    base = 0.01
    q = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
    k = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
    v = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)

    kv_caches = base * torch.randn(batch_size,
                                   num_heads,
                                   head_size,
                                   head_size,
                                   dtype=dtype,
                                   device="cuda")

    kv_caches_copy = kv_caches.clone()

    slope_rate = torch.zeros(num_heads, device="cuda")
    for h in range(num_heads):
        slope_rate[h] = 0.1 * (h + 1)

    slot_idx = torch.arange(batch_size, device="cuda")

    triton_output = linear_decode_forward_triton(q, k, v, kv_caches,
                                                 slope_rate, slot_idx)

    reference_output = reference_linear_decode(q, k, v, kv_caches_copy,
                                               slope_rate, slot_idx)
    torch.testing.assert_close(triton_output,
                               reference_output,
                               rtol=1e-1,
                               atol=1e-1)
    torch.testing.assert_close(kv_caches, kv_caches_copy, rtol=1e-1, atol=1e-1)

    assert triton_output.shape == (batch_size, num_heads * head_size)


@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@torch.inference_mode()
def test_linear_decode_forward_triton_with_padding(
    num_heads: int,
    head_size: int,
    dtype: torch.dtype,
):
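    """Exercise decode with padded slots (slot_idx == -1) in the batch."""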
    torch.set_default_device("cuda")
    torch.manual_seed(42)
    torch.cuda.manual_seed_all(42)
    current_platform.seed_everything(42)

    batch_size = 4
    base = 0.01
    q = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
    k = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
    v = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)

    kv_caches = base * torch.randn(batch_size,
                                   num_heads,
                                   head_size,
                                   head_size,
                                   dtype=dtype,
                                   device="cuda")

    kv_caches_copy = kv_caches.clone()

    slope_rate = torch.zeros(num_heads, device="cuda")
    for h in range(num_heads):
        slope_rate[h] = 0.1 * (h + 1)

    slot_idx = torch.tensor([0, 1, -1, 2], device="cuda")

    triton_output = linear_decode_forward_triton(q, k, v, kv_caches,
                                                 slope_rate, slot_idx)

    reference_output = reference_linear_decode(q, k, v, kv_caches_copy,
                                               slope_rate, slot_idx)

    padding_mask = (slot_idx
                    != -1).unsqueeze(1).expand(-1, num_heads * head_size)

    triton_masked = triton_output[padding_mask]
    reference_masked = reference_output[padding_mask]

    atol, rtol = 1.5e-1, 1.5e-1

    valid_indices = slot_idx != -1

    # Only compare KV caches for non-padded slots
    for i in range(batch_size):
        if valid_indices[i]:
            torch.testing.assert_close(kv_caches[i],
                                       kv_caches_copy[i],
                                       rtol=rtol,
                                       atol=atol)

    torch.testing.assert_close(triton_masked,
                               reference_masked,
                               rtol=rtol,
                               atol=atol)

    assert triton_output.shape == (batch_size, num_heads * head_size)


@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("seq_len", SEQ_LENGTHS)
@pytest.mark.parametrize("dtype", DTYPES)
@torch.inference_mode()
def test_lightning_attention_reference(
    batch_size: int,
    num_heads: int,
    head_size: int,
    seq_len: int,
    dtype: torch.dtype,
):
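    """Compare the sequential reference against lightning_attention."""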
    torch.set_default_device("cuda")
    torch.manual_seed(42)
    torch.cuda.manual_seed_all(42)
    current_platform.seed_everything(42)

    base = 0.01
    q = base * torch.randn(
        batch_size, num_heads, seq_len, head_size, dtype=dtype)
    k = base * torch.randn(
        batch_size, num_heads, seq_len, head_size, dtype=dtype)
    v = base * torch.randn(
        batch_size, num_heads, seq_len, head_size, dtype=dtype)

    ed = torch.zeros(num_heads, device="cuda")
    for h in range(num_heads):
        ed[h] = 0.1 * (h + 1)

    kv_history = base * torch.randn(batch_size,
                                    num_heads,
                                    head_size,
                                    head_size,
                                    dtype=dtype,
                                    device="cuda")

    kv_history_clone = kv_history.clone()

    ref_output, ref_kv_cache = reference_lightning_attention(
        q, k, v, ed, 256, kv_history)

    actual_output, actual_kv_cache = lightning_attention(
        q, k, v, ed, 256, kv_history_clone)

    atol, rtol = 1.5e-1, 1.5e-1
    torch.testing.assert_close(ref_output, actual_output, rtol=rtol, atol=atol)
    torch.testing.assert_close(ref_kv_cache,
                               actual_kv_cache,
                               rtol=rtol,
                               atol=atol)

    assert ref_output.shape == (batch_size, num_heads, seq_len, head_size)
    assert ref_kv_cache.shape == actual_kv_cache.shape