feat: paged attention v2 #1183

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged 2 commits on Oct 23, 2023
2 changes: 1 addition & 1 deletion server/Makefile-flash-att-v2
@@ -1,4 +1,4 @@
-flash_att_v2_commit := 601b4dc48dbe9d87c468daa2b4c0c8388b83753c
+flash_att_v2_commit := 02ac572f3ffc4f402e4183aaa6824b45859d3ed3

flash-attention-v2:
# Clone flash attention
4 changes: 2 additions & 2 deletions server/Makefile-vllm
@@ -1,8 +1,8 @@
-vllm_commit := 25dbff97d5a8f2ba331847237b458b2692e9ae78
+vllm_commit := f8a1e39fae05ca610be8d5a78be9d40f5274e5fc

vllm:
# Clone vllm
-	git clone https://github.com/OlivierDehaene/vllm.git
+	git clone https://github.com/vllm-project/vllm.git

build-vllm: vllm
cd vllm && git fetch && git checkout $(vllm_commit)
@@ -29,11 +29,7 @@
# Flash attention imports
import dropout_layer_norm

-# vllm imports
-import vllm_cache_ops
-import vllm_attention_ops
-
-from text_generation_server.utils.flash_attn import attention
+from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.utils.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
@@ -269,7 +265,7 @@ def forward(
self.rotary_emb(query, cos, sin)
self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)

-vllm_cache_ops.reshape_and_cache(
+paged_attention.reshape_and_cache(
kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
)

@@ -279,7 +275,7 @@ def forward(
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
-attention(
+flash_attn.attention(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
@@ -290,9 +286,7 @@
)
# Decode
else:
-# kv_cache[1] => [num_blocks, num_heads, head_size, block_size]
-block_size = kv_cache[1].shape[3]
-vllm_attention_ops.single_query_cached_kv_attention(
+paged_attention.attention(
attn_output,
query,
kv_cache[0],
@@ -301,7 +295,6 @@
self.softmax_scale,
block_tables,
input_lengths,
-block_size,
max_s,
)

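Note on the call-site change above: the models now go through a text_generation_server.utils.paged_attention wrapper instead of calling vllm_cache_ops / vllm_attention_ops directly, and they no longer pass block_size (the wrapper can recover it from the cache layout described in the removed comments). The sketch below is only a simplified pure-PyTorch reference of the decode step such a wrapper performs, assuming a flat [num_blocks, num_heads, head_size, block_size] cache layout and num_kv_heads == num_heads; it is not the fused kernel or the exact signature in this PR (the real call also takes arguments hidden by the truncated diff).

```python
# Hypothetical reference for paged_attention.attention(...): per-sequence,
# pure-PyTorch single-query attention over a block-paged KV cache. The real
# wrapper dispatches to vllm's fused paged-attention kernels; layouts and
# argument names here are assumptions for illustration only.
import torch


def attention(
    out: torch.Tensor,           # [num_seqs, num_heads, head_size], written in place
    query: torch.Tensor,         # [num_seqs, num_heads, head_size]
    key_cache: torch.Tensor,     # assumed [num_blocks, num_heads, head_size, block_size]
    value_cache: torch.Tensor,   # assumed [num_blocks, num_heads, head_size, block_size]
    softmax_scale: float,
    block_tables: torch.Tensor,  # [num_seqs, max_blocks_per_seq] physical block ids
    input_lengths: torch.Tensor, # [num_seqs] number of cached tokens per sequence
    max_s: int,                  # unused in this reference; fused kernels use it for sizing
) -> torch.Tensor:
    # block_size is derived from the cache instead of being passed by every model.
    block_size = value_cache.shape[3]
    num_seqs, num_heads, head_size = query.shape

    for i in range(num_seqs):
        length = int(input_lengths[i])
        n_blocks = (length + block_size - 1) // block_size
        blocks = block_tables[i, :n_blocks].long()

        # Gather this sequence's K/V: [num_heads, head_size, n_blocks * block_size]
        keys = key_cache[blocks].permute(1, 2, 0, 3).reshape(num_heads, head_size, -1)[..., :length]
        values = value_cache[blocks].permute(1, 2, 0, 3).reshape(num_heads, head_size, -1)[..., :length]

        q = query[i]                                  # [num_heads, head_size]
        scores = torch.einsum("hd,hdt->ht", q, keys) * softmax_scale
        probs = torch.softmax(scores, dim=-1)         # [num_heads, length]
        out[i] = torch.einsum("ht,hdt->hd", probs, values)
    return out
```

Compared to v1, the v2 kernel additionally splits long sequences into partitions and reduces the partial results, which is presumably what the bumped vllm commit brings in; the reference above shows only the math, not that partitioning.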
@@ -29,10 +29,7 @@
# Flash attention imports
import dropout_layer_norm

-# vllm imports
-import vllm_cache_ops
-import vllm_attention_ops
-
+from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.utils.flash_attn import attention, HAS_FLASH_ATTN_V2
from text_generation_server.utils.layers import (
TensorParallelRowLinear,
@@ -272,7 +269,7 @@ def forward(
else:
kv_to_cache = kv

-vllm_cache_ops.reshape_and_cache(
+paged_attention.reshape_and_cache(
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
)

@@ -282,7 +279,7 @@ def forward(
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
-attention(
+flash_attn.attention(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
@@ -294,9 +291,7 @@
)
# Decode
else:
-# kv_cache[1] => [num_blocks, num_heads, head_size, block_size]
-block_size = kv_cache[1].shape[3]
-vllm_attention_ops.single_query_cached_kv_attention(
+paged_attention.attention(
attn_output,
query,
kv_cache[0],
@@ -305,7 +300,6 @@
self.softmax_scale,
block_tables,
input_lengths,
-block_size,
max_s,
)

@@ -27,10 +27,7 @@
from transformers.models.gpt_neox import GPTNeoXConfig
from typing import Optional, List, Tuple

-# vllm imports
-import vllm_cache_ops
-import vllm_attention_ops
-
+from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.utils.flash_attn import attention
from text_generation_server.utils.layers import (
TensorParallelRowLinear,
@@ -141,7 +138,7 @@ def forward(
self.rotary_emb(qkv[:, 0], cos, sin)
self.rotary_emb(qkv[:, 1], cos, sin)

-vllm_cache_ops.reshape_and_cache(
+paged_attention.reshape_and_cache(
qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots
)

@@ -151,7 +148,7 @@ def forward(
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
-attention(
+flash_attn.attention(
qkv[:, 0],
qkv[:, 1],
qkv[:, 2],
@@ -162,9 +159,7 @@
)
# Decode
else:
-# kv_cache[1] => [num_blocks, num_heads, head_size, block_size]
-block_size = kv_cache[1].shape[3]
-vllm_attention_ops.single_query_cached_kv_attention(
+paged_attention.attention(
attn_output,
qkv[:, 0],
kv_cache[0],
@@ -173,7 +168,6 @@
self.softmax_scale,
block_tables,
input_lengths,
-block_size,
max_s,
)

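Similarly, every call site switches vllm_cache_ops.reshape_and_cache to paged_attention.reshape_and_cache with the same five arguments. As a rough mental model only (the real op is a fused CUDA kernel, and vllm's actual key-cache layout differs), the scatter it performs looks like the sketch below, under the same assumed flat [num_blocks, num_heads, head_size, block_size] layout as the reference above:

```python
import torch


def reshape_and_cache(
    key: torch.Tensor,          # [num_tokens, num_heads, head_size]
    value: torch.Tensor,        # [num_tokens, num_heads, head_size]
    key_cache: torch.Tensor,    # assumed [num_blocks, num_heads, head_size, block_size]
    value_cache: torch.Tensor,  # assumed [num_blocks, num_heads, head_size, block_size]
    slots: torch.Tensor,        # [num_tokens] flat slot ids: block_id * block_size + offset
) -> None:
    # Scatter each new token's K/V vectors into its (block, offset) position.
    block_size = value_cache.shape[3]
    block_idx = torch.div(slots, block_size, rounding_mode="floor")
    offset = slots % block_size
    key_cache[block_idx, :, :, offset] = key
    value_cache[block_idx, :, :, offset] = value
```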
@@ -6,10 +6,7 @@
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple

-# vllm imports
-import vllm_cache_ops
-import vllm_attention_ops
-
+from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.utils.flash_attn import attention
from text_generation_server.utils.layers import (
TensorParallelRowLinear,
@@ -191,7 +188,7 @@ def forward(
self.rotary_emb(query, cos, sin)
self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)

-vllm_cache_ops.reshape_and_cache(
+paged_attention.reshape_and_cache(
kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
)

@@ -201,7 +198,7 @@ def forward(
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
-attention(
+flash_attn.attention(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
@@ -212,9 +209,7 @@
)
# Decode
else:
-# kv_cache[1] => [num_blocks, num_heads_kv, head_size, block_size]
-block_size = kv_cache[1].shape[3]
-vllm_attention_ops.single_query_cached_kv_attention(
+paged_attention.attention(
attn_output,
query,
kv_cache[0],
@@ -223,7 +218,6 @@
self.softmax_scale,
block_tables,
input_lengths,
-block_size,
max_s,
)

@@ -310,7 +304,7 @@ def forward(
self.rotary_emb(query, cos, sin)
self.rotary_emb(torch.select(kv, dim=2, index=0), cos, sin)

-vllm_cache_ops.reshape_and_cache(
+paged_attention.reshape_and_cache(
kv[:, :, 0].contiguous(),
kv[:, :, 1].contiguous(),
kv_cache[0],
@@ -324,7 +318,7 @@
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
-attention(
+flash_attn.attention(
query,
torch.select(kv, dim=2, index=0),
torch.select(kv, dim=2, index=1),
@@ -335,9 +329,7 @@
)
# Decode
else:
-# kv_cache[1] => [num_blocks, num_groups, head_size, block_size]
-block_size = kv_cache[1].shape[3]
-vllm_attention_ops.single_query_cached_kv_attention(
+paged_attention.attention(
attn_output,
query,
kv_cache[0],
@@ -346,7 +338,6 @@
self.softmax_scale,
block_tables,
input_lengths,
-block_size,
max_s,
)

@@ -5,10 +5,7 @@
from transformers.activations import ACT2FN
from typing import Optional, List, Tuple

-# vllm imports
-import vllm_cache_ops
-import vllm_attention_ops
-
+from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.utils.flash_attn import attention
from text_generation_server.utils.layers import (
TensorParallelRowLinear,
@@ -18,7 +15,6 @@
FastLayerNorm,
get_linear,
)
-from safetensors import SafetensorError


def load_multi_mqa(
@@ -258,7 +254,7 @@ def forward(
query = query.view(-1, self.num_heads, self.head_size)
key_value = key_value.view(-1, 2, 1, self.head_size)

-vllm_cache_ops.reshape_and_cache(
+paged_attention.reshape_and_cache(
key_value[:, 0], key_value[:, 1], kv_cache[0], kv_cache[1], slots
)

@@ -268,7 +264,7 @@
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
-attention(
+flash_attn.attention(
query,
torch.select(key_value, dim=1, index=0),
torch.select(key_value, dim=1, index=1),
@@ -279,9 +275,7 @@
)
# Decode
else:
-# kv_cache[1] => [num_blocks, 1, head_size, block_size]
-block_size = kv_cache[1].shape[3]
-vllm_attention_ops.single_query_cached_kv_attention(
+paged_attention.attention(
attn_output,
query,
kv_cache[0],
@@ -290,7 +284,6 @@
self.softmax_scale,
block_tables,
input_lengths,
-block_size,
max_s,
)
