@@ -174,8 +174,39 @@ def _is_neuron_on_device_sampling_disabled(model_config: ModelConfig) -> bool:
174
174
175
175
def _get_neuron_config_after_override(default_neuron_config,
                                      overridden_neuron_config):
    """Build a ``NeuronConfig`` from default kwargs plus user overrides.

    Nested override sections supplied as plain dicts are converted into the
    matching ``transformers_neuronx.config`` objects before being merged
    over the defaults.

    Args:
        default_neuron_config: Dict of default ``NeuronConfig`` keyword
            arguments. It is updated in place with the resolved overrides.
        overridden_neuron_config: Optional dict of user overrides. ``None``
            or an empty dict means "no overrides". This dict is not
            modified.

    Returns:
        A ``NeuronConfig`` constructed from the merged keyword arguments.
    """
    from transformers_neuronx.config import (ContinuousBatchingConfig,
                                             GenerationConfig,
                                             KVCacheQuantizationConfig,
                                             NeuronConfig, QuantizationConfig,
                                             SparseAttnConfig)

    # Override key -> config class its dict payload is expanded into.
    nested_config_types = {
        "sparse_attn": SparseAttnConfig,
        "kv_cache_quant": KVCacheQuantizationConfig,
        "continuous_batching": ContinuousBatchingConfig,
        "quant": QuantizationConfig,
        "on_device_generation": GenerationConfig,
    }

    # Work on a shallow copy: popping from / writing config objects into the
    # caller's dict would break a second call with the same overrides
    # (``Cls(**obj)`` on an already-converted object).
    overrides = dict(overridden_neuron_config or {})

    for key, config_cls in nested_config_types.items():
        section = overrides.pop(key, {})
        # An empty or otherwise falsy section is dropped entirely, so the
        # value from ``default_neuron_config`` (if any) stays in effect.
        if section:
            overrides[key] = config_cls(**section)

    default_neuron_config.update(overrides)
    return NeuronConfig(**default_neuron_config)
181
212
0 commit comments