@@ -184,6 +184,9 @@ class SamplingParams(
184
184
allowed_token_ids: If provided, the engine will construct a logits
185
185
processor which only retains scores for the given token ids.
186
186
Defaults to None.
187
+ extra_args: Arbitrary additional args, that can be used by custom
188
+ sampling implementations. Not used by any in-tree sampling
189
+ implementations.
187
190
"""
188
191
189
192
n : int = 1
@@ -227,6 +230,7 @@ class SamplingParams(
227
230
guided_decoding : Optional [GuidedDecodingParams ] = None
228
231
logit_bias : Optional [dict [int , float ]] = None
229
232
allowed_token_ids : Optional [list [int ]] = None
233
+ extra_args : Optional [dict [str , Any ]] = None
230
234
231
235
@staticmethod
232
236
def from_optional (
@@ -259,6 +263,7 @@ def from_optional(
259
263
guided_decoding : Optional [GuidedDecodingParams ] = None ,
260
264
logit_bias : Optional [Union [dict [int , float ], dict [str , float ]]] = None ,
261
265
allowed_token_ids : Optional [list [int ]] = None ,
266
+ extra_args : Optional [dict [str , Any ]] = None ,
262
267
) -> "SamplingParams" :
263
268
if logit_bias is not None :
264
269
# Convert token_id to integer
@@ -300,6 +305,7 @@ def from_optional(
300
305
guided_decoding = guided_decoding ,
301
306
logit_bias = logit_bias ,
302
307
allowed_token_ids = allowed_token_ids ,
308
+ extra_args = extra_args ,
303
309
)
304
310
305
311
def __post_init__ (self ) -> None :
@@ -509,7 +515,8 @@ def __repr__(self) -> str:
509
515
"spaces_between_special_tokens="
510
516
f"{ self .spaces_between_special_tokens } , "
511
517
f"truncate_prompt_tokens={ self .truncate_prompt_tokens } , "
512
- f"guided_decoding={ self .guided_decoding } )" )
518
+ f"guided_decoding={ self .guided_decoding } , "
519
+ f"extra_args={ self .extra_args } )" )
513
520
514
521
515
522
class BeamSearchParams (
0 commit comments