Skip to content

Commit 5d028dd

Browse files
ajayvohra2005, Mu Huai
authored and
Mu Huai
committed
update neuron config (vllm-project#16289)
Signed-off-by: Ajay Vohra <[email protected]> Signed-off-by: Mu Huai <[email protected]>
1 parent 2c19c03 commit 5d028dd

File tree

1 file changed

+32
-1
lines changed

1 file changed

+32
-1
lines changed

vllm/model_executor/model_loader/neuron.py

Lines changed: 32 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -174,8 +174,39 @@ def _is_neuron_on_device_sampling_disabled(model_config: ModelConfig) -> bool:
174174

175175
def _get_neuron_config_after_override(default_neuron_config,
                                      overridden_neuron_config):
    """Merge user overrides into the default settings and build a NeuronConfig.

    Some override entries are expected as plain dicts and are turned into the
    matching ``transformers_neuronx.config`` object before merging:
    ``sparse_attn``, ``kv_cache_quant``, ``continuous_batching``, ``quant``
    and ``on_device_generation``.

    Args:
        default_neuron_config: Dict of default ``NeuronConfig`` keyword
            arguments. It is updated in place with the overrides (this
            matches the previous behavior of this function).
        overridden_neuron_config: Optional dict of user overrides. ``None``
            or an empty dict means "no overrides". This dict is NOT modified.

    Returns:
        A ``NeuronConfig`` built from the merged keyword arguments.

    Raises:
        TypeError: If a nested override is neither a mapping of keyword
            arguments nor already an instance of the expected config class.
    """
    from transformers_neuronx.config import (ContinuousBatchingConfig,
                                             GenerationConfig,
                                             KVCacheQuantizationConfig,
                                             NeuronConfig, QuantizationConfig,
                                             SparseAttnConfig)

    # Override key -> config class used to wrap the nested dict.
    nested_config_classes = {
        "sparse_attn": SparseAttnConfig,
        "kv_cache_quant": KVCacheQuantizationConfig,
        "continuous_batching": ContinuousBatchingConfig,
        "quant": QuantizationConfig,
        "on_device_generation": GenerationConfig,
    }

    # Work on a shallow copy: popping/replacing entries on the caller's dict
    # would leave config objects behind in it, so a second call with the same
    # overrides would fail when trying to ``**``-unpack a config object.
    overrides = dict(overridden_neuron_config or {})

    for key, config_cls in nested_config_classes.items():
        nested = overrides.pop(key, None)
        if not nested:
            # Missing or empty nested override: keep whatever the defaults
            # specify for this key.
            continue
        if not isinstance(nested, config_cls):
            nested = config_cls(**nested)
        overrides[key] = nested

    default_neuron_config.update(overrides)
    return NeuronConfig(**default_neuron_config)
181212

0 commit comments

Comments
 (0)