```diff
@@ -8,14 +8,13 @@
 import torch.nn as nn

 import vllm.envs as envs
-from vllm.config import CacheConfig, ModelConfig, ParallelConfig, VllmConfig
+from vllm.config import ParallelConfig, VllmConfig
 from vllm.distributed import (ensure_model_parallel_initialized,
                               init_distributed_environment,
                               set_custom_all_reduce)
 from vllm.logger import init_logger
 from vllm.model_executor import set_random_seed
 from vllm.platforms import current_platform
-from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType, get_dtype_size
 from vllm.v1.core.scheduler import SchedulerOutput
 from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
 from vllm.v1.outputs import ModelRunnerOutput
```
```diff
@@ -235,24 +234,3 @@ def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
                 f"of at least 8.0. Your {gpu_name} GPU {compute_str}. "
                 "You can use float16 instead by explicitly setting the"
                 "`dtype` flag in CLI, for example: --dtype=half.")
-
-
-def _get_cache_block_size(
-    cache_config: CacheConfig,
-    model_config: ModelConfig,
-    parallel_config: ParallelConfig,
-) -> int:
-    head_size = model_config.get_head_size()
-    num_heads = model_config.get_num_kv_heads(parallel_config)
-    num_attention_layers = model_config.get_num_layers_by_block_type(
-        parallel_config, LayerBlockType.attention)
-
-    key_cache_block = cache_config.block_size * num_heads * head_size
-    value_cache_block = key_cache_block
-    total = num_attention_layers * (key_cache_block + value_cache_block)
-    if cache_config.cache_dtype == "auto":
-        dtype = model_config.dtype
-    else:
-        dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
-    dtype_size = get_dtype_size(dtype)
-    return dtype_size * total
```
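For reference, the deleted helper computed the byte footprint of a single KV-cache block: per attention layer, a block holds `block_size * num_kv_heads * head_size` elements for keys and the same again for values, scaled by the element size of the cache dtype. A minimal standalone sketch of that arithmetic follows; the helper name and the Llama-2-7B-style numbers are illustrative assumptions, not values taken from this commit:

```python
import torch


def cache_block_size_bytes(block_size: int, num_kv_heads: int,
                           head_size: int, num_attention_layers: int,
                           dtype: torch.dtype) -> int:
    """Bytes occupied by one KV-cache block across all attention layers."""
    # A block holds `block_size` tokens; each token stores
    # num_kv_heads * head_size elements for its keys.
    key_elems = block_size * num_kv_heads * head_size
    value_elems = key_elems  # the value cache mirrors the key cache
    total_elems = num_attention_layers * (key_elems + value_elems)
    # element_size() gives the bytes per element for the given dtype.
    return total_elems * torch.tensor([], dtype=dtype).element_size()


# Illustrative Llama-2-7B-like shape: 32 layers, 32 KV heads,
# head size 128, 16-token blocks, fp16 cache.
print(cache_block_size_bytes(16, 32, 128, 32, torch.float16))
# 16 * 32 * 128 elems * 2 (K+V) * 32 layers * 2 bytes = 8,388,608 B (8 MiB)
```

In vLLM's memory profiling, the cache memory left over after a forward pass is divided by this per-block size to determine how many KV blocks the worker can allocate.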