
Commit 5f0ec39

[V1] Remove _get_cache_block_size (#12214)
Signed-off-by: Chen Zhang <[email protected]>
1 parent: c222f47

File tree: 1 file changed (+1, -23 lines)


vllm/v1/worker/gpu_worker.py

Lines changed: 1 addition & 23 deletions
@@ -8,14 +8,13 @@
 import torch.nn as nn
 
 import vllm.envs as envs
-from vllm.config import CacheConfig, ModelConfig, ParallelConfig, VllmConfig
+from vllm.config import ParallelConfig, VllmConfig
 from vllm.distributed import (ensure_model_parallel_initialized,
                               init_distributed_environment,
                               set_custom_all_reduce)
 from vllm.logger import init_logger
 from vllm.model_executor import set_random_seed
 from vllm.platforms import current_platform
-from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType, get_dtype_size
 from vllm.v1.core.scheduler import SchedulerOutput
 from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
 from vllm.v1.outputs import ModelRunnerOutput
@@ -235,24 +234,3 @@ def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
         f"of at least 8.0. Your {gpu_name} GPU {compute_str}. "
         "You can use float16 instead by explicitly setting the"
         "`dtype` flag in CLI, for example: --dtype=half.")
-
-
-def _get_cache_block_size(
-    cache_config: CacheConfig,
-    model_config: ModelConfig,
-    parallel_config: ParallelConfig,
-) -> int:
-    head_size = model_config.get_head_size()
-    num_heads = model_config.get_num_kv_heads(parallel_config)
-    num_attention_layers = model_config.get_num_layers_by_block_type(
-        parallel_config, LayerBlockType.attention)
-
-    key_cache_block = cache_config.block_size * num_heads * head_size
-    value_cache_block = key_cache_block
-    total = num_attention_layers * (key_cache_block + value_cache_block)
-    if cache_config.cache_dtype == "auto":
-        dtype = model_config.dtype
-    else:
-        dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
-    dtype_size = get_dtype_size(dtype)
-    return dtype_size * total
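
For context, the deleted helper computed the size in bytes of a single KV cache block: key and value storage for one block of tokens, summed across all attention layers. A minimal standalone sketch of that arithmetic, using assumed illustrative dimensions (the block size, head counts, and layer count below are examples, not taken from any particular vllm model config):

import torch

# Illustrative values (assumptions for this example):
block_size = 16            # tokens per KV cache block (cache_config.block_size)
num_kv_heads = 8           # KV heads per tensor-parallel rank
head_size = 128            # elements per head
num_attention_layers = 32  # attention layers only (LayerBlockType.attention)
dtype = torch.float16      # cache_dtype == "auto" falls back to the model dtype

# Per layer, one block stores block_size * num_kv_heads * head_size elements,
# once for the key cache and once for the value cache.
key_cache_block = block_size * num_kv_heads * head_size
value_cache_block = key_cache_block
total_elements = num_attention_layers * (key_cache_block + value_cache_block)

# element_size() stands in for vllm.utils.get_dtype_size here.
dtype_size = torch.tensor([], dtype=dtype).element_size()
print(dtype_size * total_elements)  # 2097152 bytes, i.e. 2 MiB per block

With these numbers, each cache block costs 2 MiB of GPU memory, which is the quantity the worker previously used when sizing the KV cache.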
