@@ -1890,12 +1890,10 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
1890
1890
self .initialize_attn_backend (kv_cache_config )
1891
1891
1892
1892
kv_caches : dict [str , torch .Tensor ] = {}
1893
- kv_cache_group_ids : dict [str , int ] = {}
1894
1893
1895
- for id , kv_cache_group in enumerate (kv_cache_config .kv_cache_groups ):
1894
+ for i , kv_cache_group in enumerate (kv_cache_config .kv_cache_groups ):
1896
1895
kv_cache_spec = kv_cache_group .kv_cache_spec
1897
1896
for layer_name in kv_cache_group .layer_names :
1898
- kv_cache_group_ids [layer_name ] = id
1899
1897
tensor_config = kv_cache_config .tensors [layer_name ]
1900
1898
assert tensor_config .size % kv_cache_spec .page_size_bytes == 0
1901
1899
num_blocks = tensor_config .size // kv_cache_spec .page_size_bytes
@@ -1922,12 +1920,9 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
1922
1920
1923
1921
if self .speculative_config and self .speculative_config .use_eagle ():
1924
1922
assert isinstance (self .drafter , EagleProposer )
1925
- assert len (
1926
- set ([
1927
- kv_cache_group_ids [layer_name ]
1928
- for layer_name in self .drafter .attn_layer_names
1929
- ])) == 1 , "For multi-layer eagle draft model, "
1930
- "all layers should belong to the same kv cache group"
1923
+ # validate all draft model layers belong to the same kv cache
1924
+ # group
1925
+ self .drafter .validate_kv_cache_group (kv_cache_config )
1931
1926
1932
1927
bind_kv_cache (
1933
1928
kv_caches ,
0 commit comments