
Commit 79e4937

vanbasten23 and mgoin authored
[v1] Add comments to the new ragged paged attention Pallas kernel (#14155)
Signed-off-by: Xiongfei Wei <[email protected]>
Co-authored-by: Michael Goin <[email protected]>
1 parent cd1d3c3 commit 79e4937

File tree

1 file changed: +5 -1 lines changed

vllm/v1/attention/backends/pallas.py

Lines changed: 5 additions & 1 deletion
@@ -11,6 +11,7 @@
                               AttentionLayer, AttentionType)
 from vllm.attention.backends.utils import CommonAttentionState
 
+# These are the 2 tunable parameters of the paged attention Pallas kernel.
 NUM_QUERIES_PER_BLOCK = 16
 NUM_KV_PAGES_PER_BLOCK = 128
 
@@ -154,6 +155,9 @@ def forward(
         write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping)
 
         query = query * self.scale
+        # use_kernel switches between using kernel or reference implementation
+        # (non kernel: https://github.com/pytorch/xla/blob/cee0820e78fc9675e2d0511db891fd44342e890d/torch_xla/experimental/custom_kernel.py#L890).
+        use_kernel = False
         output = torch.ops.xla.ragged_paged_attention(
             query,
             key_cache,
@@ -164,7 +168,7 @@ def forward(
             attn_metadata.num_seqs,
             num_kv_pages_per_block=NUM_KV_PAGES_PER_BLOCK,
             num_queries_per_block=NUM_QUERIES_PER_BLOCK,
-            use_kernel=False,
+            use_kernel=use_kernel,
         )
 
         return output.reshape(num_tokens, hidden_size)
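For context on what the new comments describe: use_kernel=False routes the call to the reference implementation linked above rather than the Pallas kernel, while num_queries_per_block and num_kv_pages_per_block control how the kernel tiles its work. The sketch below is a minimal, self-contained illustration (plain PyTorch, written for this note) of what ragged paged attention computes; the function name, argument names, and shapes are assumptions chosen for illustration, not the actual torch.ops.xla.ragged_paged_attention signature, and causal masking is omitted for brevity.

# Minimal reference sketch of ragged paged attention (illustrative only).
# All argument names and shapes below are assumptions, not the pytorch/xla API.
import torch

def ragged_paged_attention_sketch(
    query,         # [total_tokens, num_heads, head_dim], all sequences packed
    key_cache,     # [num_pages, page_size, num_heads, head_dim]
    value_cache,   # [num_pages, page_size, num_heads, head_dim]
    kv_lens,       # [num_seqs] valid KV tokens per sequence
    page_indices,  # [num_seqs, max_pages] page ids owned by each sequence
    cu_q_lens,     # [num_seqs + 1] cumulative query lengths ("ragged" offsets)
    num_seqs,      # Python int
):
    page_size = key_cache.shape[1]
    outputs = []
    for i in range(num_seqs):
        # Slice this sequence's queries out of the packed batch.
        q = query[cu_q_lens[i]:cu_q_lens[i + 1]]            # [q_len, H, D]
        # Gather this sequence's K/V pages and trim the tail of the last page.
        n_pages = (kv_lens[i] + page_size - 1) // page_size
        pages = page_indices[i, :n_pages]
        k = key_cache[pages].flatten(0, 1)[:kv_lens[i]]     # [kv_len, H, D]
        v = value_cache[pages].flatten(0, 1)[:kv_lens[i]]
        # Plain per-head attention; the diff scales query beforehand
        # (query = query * self.scale), so no extra scaling here. Causal
        # masking is omitted to keep the sketch short.
        attn = torch.einsum("qhd,khd->hqk", q, k).softmax(dim=-1)
        outputs.append(torch.einsum("hqk,khd->qhd", attn, v))
    return torch.cat(outputs)                               # [total_tokens, H, D]

Roughly speaking, the Pallas kernel computes the same result but tiles the per-sequence loop: each grid step covers a block of NUM_QUERIES_PER_BLOCK query rows against NUM_KV_PAGES_PER_BLOCK KV pages, which is why those two constants are the tunable knobs called out in the new comment.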
