@@ -402,34 +402,36 @@ impl Validation {
402
402
None => None ,
403
403
} ;
404
404
405
- let logit_bias = match & request. parameters . logit_bias {
406
- Some ( bias) if !bias. is_empty ( ) => {
407
- for ( token_str, _) in bias. iter ( ) {
408
- let token_id = token_str. parse :: < u32 > ( ) . map_err ( |_| {
409
- ValidationError :: LogitBiasInvalid ( format ! (
410
- "Token ID {} is not a valid number." ,
411
- token_str
412
- ) )
413
- } ) ?;
414
-
415
- if token_id >= self . vocab_size {
416
- return Err ( ValidationError :: LogitBiasInvalid ( format ! (
417
- "Token ID {} is out of range. Must be between 0 and {}." ,
418
- token_id,
419
- self . vocab_size - 1
420
- ) ) ) ;
421
- }
422
- }
405
+ // Validate logit bias and convert to a vector of (token_id, bias_value)
406
+ let logit_bias = request
407
+ . parameters
408
+ . logit_bias
409
+ . as_ref ( )
410
+ . filter ( |bias_map| !bias_map. is_empty ( ) )
411
+ . map ( |bias_map| {
412
+ bias_map
413
+ . iter ( )
414
+ . map ( |( token_str, & bias_value) | {
415
+ let token_id: u32 = token_str. parse ( ) . map_err ( |_| {
416
+ ValidationError :: LogitBiasInvalid ( format ! (
417
+ "Token ID {token_str} is not a valid number."
418
+ ) )
419
+ } ) ?;
423
420
424
- // Transform into the required format
425
- Some (
426
- bias. iter ( )
427
- . map ( |( k, v) | ( k. parse :: < u32 > ( ) . unwrap ( ) , * v as f32 ) )
428
- . collect ( ) ,
429
- )
430
- }
431
- _ => None ,
432
- } ;
421
+ if token_id >= self . vocab_size {
422
+ return Err ( ValidationError :: LogitBiasInvalid ( format ! (
423
+ "Token ID {token_id} is out of range (0..{})." ,
424
+ self . vocab_size - 1
425
+ ) ) ) ;
426
+ }
427
+
428
+ Ok ( ( token_id, bias_value as f32 ) )
429
+ } )
430
+ . collect :: < Result < Vec < _ > , _ > > ( )
431
+ } )
432
+ // convert Option<Result<T, E>> to Result<Option<T>, E> to throw
433
+ // if any of the token IDs are invalid
434
+ . transpose ( ) ?;
433
435
434
436
let parameters = ValidParameters {
435
437
temperature,
0 commit comments