|
| 1 | +""" |
| 2 | +Test: |
| 3 | +
|
| 4 | +* Tests for MultiHeadAttention layer |
| 5 | +""" |
| 6 | +from unittest.mock import patch |
| 7 | + |
| 8 | +import pytest |
| 9 | +import torch |
| 10 | + |
| 11 | +from vllm.attention.layer import MultiHeadAttention |
| 12 | +from vllm.attention.selector import _Backend, _cached_get_attn_backend |
| 13 | +from vllm.platforms import current_platform |
| 14 | +from vllm.platforms.cpu import CpuPlatform |
| 15 | +from vllm.platforms.cuda import CudaPlatform |
| 16 | +from vllm.platforms.rocm import RocmPlatform |
| 17 | + |
| 18 | + |
| 19 | +@pytest.fixture(autouse=True) |
| 20 | +def clear_cache(): |
| 21 | + """Clear lru cache to ensure each test case runs without caching. |
| 22 | + """ |
| 23 | + _cached_get_attn_backend.cache_clear() |
| 24 | + |
| 25 | + |
| 26 | +@pytest.mark.parametrize("device", ["cpu", "hip", "cuda"]) |
| 27 | +def test_mha_attn_platform(device: str): |
| 28 | + """ |
| 29 | + Test that the attention selector between different platform and device. |
| 30 | + """ |
| 31 | + torch.set_default_dtype(torch.float16) |
| 32 | + |
| 33 | + if device == "cpu": |
| 34 | + with patch("vllm.attention.selector.current_platform", CpuPlatform()): |
| 35 | + attn = MultiHeadAttention(16, 64, scale=1) |
| 36 | + assert attn.attn_backend == _Backend.TORCH_SDPA |
| 37 | + elif device == "hip": |
| 38 | + with patch("vllm.attention.selector.current_platform", RocmPlatform()): |
| 39 | + attn = MultiHeadAttention(16, 64, scale=1) |
| 40 | + assert attn.attn_backend == _Backend.TORCH_SDPA |
| 41 | + else: |
| 42 | + with patch("vllm.attention.selector.current_platform", CudaPlatform()): |
| 43 | + attn = MultiHeadAttention(16, 64, scale=1) |
| 44 | + assert attn.attn_backend == _Backend.FLASH_ATTN |
| 45 | + |
| 46 | + with patch("vllm.attention.selector.current_platform", CudaPlatform()): |
| 47 | + attn = MultiHeadAttention(16, 72, scale=1) |
| 48 | + assert attn.attn_backend == _Backend.XFORMERS |
| 49 | + |
| 50 | + |
| 51 | +def ref_attention( |
| 52 | + query: torch.Tensor, |
| 53 | + key: torch.Tensor, |
| 54 | + value: torch.Tensor, |
| 55 | + scale: float, |
| 56 | +) -> torch.Tensor: |
| 57 | + """ |
| 58 | + Native implementation of scaled dot product attention without mask: |
| 59 | + - query, key, value: [batch_size, seq_len, num_heads, head_size] |
| 60 | + - attn_mask: [batch_size, seq_len, seq_len] |
| 61 | + """ |
| 62 | + query, key, value = (x.transpose(1, 2) for x in (query, key, value)) |
| 63 | + attn_weights = scale * torch.matmul(query, key.transpose(2, 3)) |
| 64 | + attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype) |
| 65 | + out = torch.matmul(attn_weights, value).transpose(1, 2) |
| 66 | + return out |
| 67 | + |
| 68 | + |
| 69 | +BATCH_SIZES = [1, 16] |
| 70 | +SEQ_LENS = [1] |
| 71 | +NUM_HEADS = [1, 16] |
| 72 | +NUM_KV_HEADS = [1] |
| 73 | +HEAD_SIZES = [64, 80] |
| 74 | +# flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16} |
| 75 | +DTYPES = [ |
| 76 | + torch.half, torch.bfloat16, torch.float |
| 77 | +] if not current_platform.is_rocm() else [torch.half, torch.bfloat16] |
| 78 | +CUDA_DEVICES = ["cuda"] |
| 79 | + |
| 80 | + |
| 81 | +@pytest.mark.parametrize("batch_size", BATCH_SIZES) |
| 82 | +@pytest.mark.parametrize("seq_len", SEQ_LENS) |
| 83 | +@pytest.mark.parametrize("num_heads", NUM_HEADS) |
| 84 | +@pytest.mark.parametrize("num_kv_heads", NUM_KV_HEADS) |
| 85 | +@pytest.mark.parametrize("head_size", HEAD_SIZES) |
| 86 | +@pytest.mark.parametrize("dtype", DTYPES) |
| 87 | +@pytest.mark.parametrize("device", CUDA_DEVICES) |
| 88 | +def test_mha_attn_forward( |
| 89 | + batch_size: int, |
| 90 | + seq_len: int, |
| 91 | + num_heads: int, |
| 92 | + num_kv_heads: int, |
| 93 | + head_size: int, |
| 94 | + dtype: torch.dtype, |
| 95 | + device: str, |
| 96 | +): |
| 97 | + current_platform.seed_everything(0) |
| 98 | + torch.set_default_device(device) |
| 99 | + torch.set_default_dtype(dtype) |
| 100 | + |
| 101 | + q = torch.randn(batch_size, seq_len, num_heads * head_size) |
| 102 | + k = torch.randn(batch_size, seq_len, num_kv_heads * head_size) |
| 103 | + v = torch.randn(batch_size, seq_len, num_kv_heads * head_size) |
| 104 | + scale = 1.0 / head_size**0.5 |
| 105 | + attn = MultiHeadAttention(num_heads, |
| 106 | + head_size, |
| 107 | + scale=scale, |
| 108 | + num_kv_heads=num_kv_heads) |
| 109 | + output = attn(q, k, v) |
| 110 | + |
| 111 | + assert num_heads % num_kv_heads == 0 |
| 112 | + num_queries_per_kv = num_heads // num_kv_heads |
| 113 | + q = q.reshape(batch_size, seq_len, num_heads, head_size) |
| 114 | + k = k.reshape(batch_size, seq_len, num_kv_heads, head_size) |
| 115 | + v = v.reshape(batch_size, seq_len, num_kv_heads, head_size) |
| 116 | + if num_queries_per_kv > 1: |
| 117 | + k = torch.repeat_interleave(k, num_queries_per_kv, dim=2) |
| 118 | + v = torch.repeat_interleave(v, num_queries_per_kv, dim=2) |
| 119 | + |
| 120 | + ref_output = ref_attention( |
| 121 | + q, |
| 122 | + k, |
| 123 | + v, |
| 124 | + scale=scale, |
| 125 | + ).reshape(batch_size, seq_len, num_heads * head_size) |
| 126 | + torch.testing.assert_close(output, ref_output) |
0 commit comments