From 04c3d6cbe996e5e6a443e34b767720acec1e7f53 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 27 Jan 2025 11:10:38 +0800 Subject: [PATCH 1/2] fix pin memory Signed-off-by: youkaichao --- flashinfer/decode.py | 8 ++++++-- flashinfer/prefill.py | 6 +++++- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/flashinfer/decode.py b/flashinfer/decode.py index 00da72fef..9a9c2d931 100644 --- a/flashinfer/decode.py +++ b/flashinfer/decode.py @@ -651,7 +651,8 @@ def __init__( (8 * 1024 * 1024,), dtype=torch.uint8, device=self.device ) self._pin_memory_int_workspace_buffer = torch.empty( - (8 * 1024 * 1024,), dtype=torch.uint8, pin_memory=True + (8 * 1024 * 1024,), dtype=torch.uint8, + pin_memory=True, device="cpu", ) if use_cuda_graph: @@ -718,6 +719,7 @@ def reset_workspace_buffer( self._pin_memory_int_workspace_buffer = torch.empty( self._int_workspace_buffer.shape, dtype=self._int_workspace_buffer.dtype, + device="cpu", pin_memory=True, ) @@ -1277,7 +1279,8 @@ def __init__( (8 * 1024 * 1024,), dtype=torch.uint8, device=self.device ) self._pin_memory_int_workspace_buffer = torch.empty( - (8 * 1024 * 1024,), dtype=torch.uint8, pin_memory=True + (8 * 1024 * 1024,), dtype=torch.uint8, + pin_memory=True, device="cpu", ) if use_cuda_graph: @@ -1330,6 +1333,7 @@ def reset_workspace_buffer( self._pin_memory_int_workspace_buffer = torch.empty( self._int_workspace_buffer.shape, dtype=self._int_workspace_buffer.dtype, + device="cpu", pin_memory=True, ) diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py index f9093017c..96b2aa79e 100644 --- a/flashinfer/prefill.py +++ b/flashinfer/prefill.py @@ -1099,6 +1099,7 @@ def __init__( self._pin_memory_int_workspace_buffer = torch.empty( self._int_workspace_buffer.shape, dtype=self._int_workspace_buffer.dtype, + device="cpu", pin_memory=True, ) self._use_cuda_graph = use_cuda_graph @@ -1165,6 +1166,7 @@ def reset_workspace_buffer( self._pin_memory_int_workspace_buffer = torch.empty( self._int_workspace_buffer.shape, dtype=self._int_workspace_buffer.dtype, + device="cpu", pin_memory=True, ) @@ -1858,7 +1860,8 @@ def __init__( (8 * 1024 * 1024,), dtype=torch.uint8, device=self.device ) self._pin_memory_int_workspace_buffer = torch.empty( - self._int_workspace_buffer.shape, dtype=torch.uint8, pin_memory=True + self._int_workspace_buffer.shape, dtype=torch.uint8, + pin_memory=True, device="cpu", ) self._use_cuda_graph = use_cuda_graph if use_cuda_graph: @@ -1911,6 +1914,7 @@ def reset_workspace_buffer( self._pin_memory_int_workspace_buffer = torch.empty( self._int_workspace_buffer.shape, dtype=self._int_workspace_buffer.dtype, + device="cpu", pin_memory=True, ) From 7ffad0739f464385299dc7e9bb7029e9d627e192 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 27 Jan 2025 11:15:24 +0800 Subject: [PATCH 2/2] fix various device Signed-off-by: youkaichao --- flashinfer/prefill.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py index 96b2aa79e..d4239cdbe 100644 --- a/flashinfer/prefill.py +++ b/flashinfer/prefill.py @@ -1417,7 +1417,7 @@ def plan( if page_size != 1: vector_sparse_indptr_host = torch.cat( [ - torch.tensor([0], dtype=torch.int32), + torch.tensor([0], dtype=torch.int32, device=kv_lens_arr_host.device), torch.cumsum(kv_lens_arr_host, dim=0, dtype=torch.int32), ], dim=0,