|
| 1 | +from typing import List, Optional, Tuple |
| 2 | + |
| 3 | +import pytest |
| 4 | +import torch |
| 5 | + |
| 6 | +from vllm.platforms import current_platform |
| 7 | +from vllm.v1.attention.backends.flash_attn import (cascade_attention, |
| 8 | + merge_attn_states) |
| 9 | +from vllm.vllm_flash_attn import flash_attn_varlen_func |
| 10 | + |
| 11 | +NUM_HEADS = [(4, 4), (8, 2), (16, 2)] |
| 12 | +HEAD_SIZES = [128, 192, 256] |
| 13 | +BLOCK_SIZES = [16] |
| 14 | +DTYPES = [torch.float16, torch.bfloat16] |
| 15 | + |
| 16 | + |
| 17 | +@pytest.mark.parametrize("num_tokens", [1, 39, 16912]) |
| 18 | +@pytest.mark.parametrize("num_heads", NUM_HEADS) |
| 19 | +@pytest.mark.parametrize("head_size", HEAD_SIZES) |
| 20 | +@pytest.mark.parametrize("dtype", DTYPES) |
| 21 | +@torch.inference_mode() |
| 22 | +def test_merge_kernel( |
| 23 | + num_tokens: int, |
| 24 | + num_heads: Tuple[int, int], |
| 25 | + head_size: int, |
| 26 | + dtype: torch.dtype, |
| 27 | +): |
| 28 | + torch.set_default_device("cuda") |
| 29 | + current_platform.seed_everything(0) |
| 30 | + num_query_heads = num_heads[0] |
| 31 | + num_kv_heads = num_heads[1] |
| 32 | + assert num_query_heads % num_kv_heads == 0 |
| 33 | + |
| 34 | + # Prepare inputs. |
| 35 | + prefix_output = torch.randn(num_tokens, |
| 36 | + num_query_heads, |
| 37 | + head_size, |
| 38 | + dtype=dtype) |
| 39 | + suffix_output = torch.randn(num_tokens, |
| 40 | + num_query_heads, |
| 41 | + head_size, |
| 42 | + dtype=dtype) |
| 43 | + prefix_lse = torch.randn(num_query_heads, num_tokens, dtype=torch.float32) |
| 44 | + suffix_lse = torch.randn(num_query_heads, num_tokens, dtype=torch.float32) |
| 45 | + |
| 46 | + # Run the kernel. |
| 47 | + output = torch.empty(num_tokens, num_query_heads, head_size, dtype=dtype) |
| 48 | + merge_attn_states(output, prefix_output, prefix_lse, suffix_output, |
| 49 | + suffix_lse) |
| 50 | + |
| 51 | + # Reference implementation. |
| 52 | + max_lse = torch.maximum(prefix_lse, suffix_lse) |
| 53 | + p_lse = torch.exp(prefix_lse - max_lse) |
| 54 | + s_lse = torch.exp(suffix_lse - max_lse) |
| 55 | + p_scale = p_lse / (p_lse + s_lse) |
| 56 | + s_scale = s_lse / (p_lse + s_lse) |
| 57 | + p_scale = p_scale.transpose(0, 1).unsqueeze(2) |
| 58 | + s_scale = s_scale.transpose(0, 1).unsqueeze(2) |
| 59 | + ref_output = p_scale * prefix_output + s_scale * suffix_output |
| 60 | + ref_output = ref_output.to(dtype) |
| 61 | + |
| 62 | + # Compare the results. |
| 63 | + torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2) |
| 64 | + |
| 65 | + |
| 66 | +CASES = [ |
| 67 | + # Case 1. A general case. |
| 68 | + ([(129, 871), (18, 280), (37, 988), (1023, 2304), (1, 257)], 256), |
| 69 | + # Case 2. Flash-decoding case. |
| 70 | + ([(1, 1023), (1, 879), (1, 778), (1, 1777)] * 100, 512), |
| 71 | +] |
| 72 | + |
| 73 | + |
| 74 | +@pytest.mark.parametrize("seq_lens_and_common_prefix", CASES) |
| 75 | +@pytest.mark.parametrize("num_heads", NUM_HEADS) |
| 76 | +@pytest.mark.parametrize("head_size", HEAD_SIZES) |
| 77 | +@pytest.mark.parametrize("dtype", DTYPES) |
| 78 | +@pytest.mark.parametrize("block_size", BLOCK_SIZES) |
| 79 | +@pytest.mark.parametrize("soft_cap", [None, 50]) |
| 80 | +@pytest.mark.parametrize("num_blocks", [2048]) |
| 81 | +@torch.inference_mode() |
| 82 | +def test_cascade( |
| 83 | + seq_lens_and_common_prefix: Tuple[List[Tuple[int, int]], int], |
| 84 | + num_heads: Tuple[int, int], |
| 85 | + head_size: int, |
| 86 | + dtype: torch.dtype, |
| 87 | + block_size: int, |
| 88 | + soft_cap: Optional[float], |
| 89 | + num_blocks: int, |
| 90 | +) -> None: |
| 91 | + torch.set_default_device("cuda") |
| 92 | + current_platform.seed_everything(0) |
| 93 | + |
| 94 | + window_size = (-1, -1) |
| 95 | + scale = head_size**-0.5 |
| 96 | + num_query_heads = num_heads[0] |
| 97 | + num_kv_heads = num_heads[1] |
| 98 | + assert num_query_heads % num_kv_heads == 0 |
| 99 | + key_cache = torch.randn(num_blocks, |
| 100 | + block_size, |
| 101 | + num_kv_heads, |
| 102 | + head_size, |
| 103 | + dtype=dtype) |
| 104 | + value_cache = torch.randn_like(key_cache) |
| 105 | + |
| 106 | + seq_lens, common_prefix_len = seq_lens_and_common_prefix |
| 107 | + num_seqs = len(seq_lens) |
| 108 | + query_lens = [x[0] for x in seq_lens] |
| 109 | + kv_lens = [x[1] for x in seq_lens] |
| 110 | + max_query_len = max(query_lens) |
| 111 | + max_kv_len = max(kv_lens) |
| 112 | + |
| 113 | + total_num_query_tokens = sum(query_lens) |
| 114 | + query = torch.randn(total_num_query_tokens, |
| 115 | + num_query_heads, |
| 116 | + head_size, |
| 117 | + dtype=dtype) |
| 118 | + cu_query_lens = torch.tensor([0] + query_lens, |
| 119 | + dtype=torch.int32).cumsum(dim=0, |
| 120 | + dtype=torch.int32) |
| 121 | + cu_kv_lens = torch.tensor([0] + kv_lens, |
| 122 | + dtype=torch.int32).cumsum(dim=0, |
| 123 | + dtype=torch.int32) |
| 124 | + max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size |
| 125 | + block_tables = torch.randint(0, |
| 126 | + num_blocks, |
| 127 | + (num_seqs, max_num_blocks_per_seq), |
| 128 | + dtype=torch.int32) |
| 129 | + |
| 130 | + assert common_prefix_len > 0 |
| 131 | + assert common_prefix_len % block_size == 0 |
| 132 | + num_common_kv_blocks = common_prefix_len // block_size |
| 133 | + # Make sure the first `num_common_kv_blocks` blocks are the same. |
| 134 | + block_tables[:, :num_common_kv_blocks] = \ |
| 135 | + block_tables[0, :num_common_kv_blocks] |
| 136 | + |
| 137 | + # Run the regular attention. |
| 138 | + ref_output = flash_attn_varlen_func( |
| 139 | + q=query, |
| 140 | + k=key_cache, |
| 141 | + v=value_cache, |
| 142 | + cu_seqlens_q=cu_query_lens, |
| 143 | + cu_seqlens_k=cu_kv_lens, |
| 144 | + max_seqlen_q=max_query_len, |
| 145 | + max_seqlen_k=max_kv_len, |
| 146 | + softmax_scale=scale, |
| 147 | + causal=True, |
| 148 | + window_size=window_size, |
| 149 | + block_table=block_tables, |
| 150 | + softcap=soft_cap if soft_cap is not None else 0, |
| 151 | + ) |
| 152 | + |
| 153 | + # Run cascade attention. |
| 154 | + assert all(common_prefix_len < kv_len for kv_len in kv_lens) |
| 155 | + cu_prefix_query_lens = torch.tensor([0, total_num_query_tokens], |
| 156 | + dtype=torch.int32) |
| 157 | + cu_prefix_kv_lens = torch.tensor([0, common_prefix_len], dtype=torch.int32) |
| 158 | + cu_suffix_kv_lens = ( |
| 159 | + cu_kv_lens - |
| 160 | + torch.arange(num_seqs + 1, dtype=torch.int32) * common_prefix_len) |
| 161 | + output = torch.empty_like(query) |
| 162 | + cascade_attention( |
| 163 | + output=output, |
| 164 | + query=query, |
| 165 | + key_cache=key_cache, |
| 166 | + value_cache=value_cache, |
| 167 | + cu_query_lens=cu_query_lens, |
| 168 | + max_query_len=max_query_len, |
| 169 | + cu_prefix_query_lens=cu_prefix_query_lens, |
| 170 | + cu_prefix_kv_lens=cu_prefix_kv_lens, |
| 171 | + cu_suffix_kv_lens=cu_suffix_kv_lens, |
| 172 | + max_kv_len=max_kv_len, |
| 173 | + softmax_scale=scale, |
| 174 | + alibi_slopes=None, |
| 175 | + sliding_window=window_size, |
| 176 | + logits_soft_cap=soft_cap if soft_cap is not None else 0, |
| 177 | + block_table=block_tables, |
| 178 | + common_prefix_len=common_prefix_len, |
| 179 | + ) |
| 180 | + |
| 181 | + # Compare the results. |
| 182 | + torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2) |
0 commit comments