Skip to content

Commit a174f63

Browse files
committed
fix: improve validation and transform logic
1 parent f27586e commit a174f63

File tree

1 file changed

+29
-27
lines changed

1 file changed

+29
-27
lines changed

router/src/validation.rs

Lines changed: 29 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -402,34 +402,36 @@ impl Validation {
402402
None => None,
403403
};
404404

405-
let logit_bias = match &request.parameters.logit_bias {
406-
Some(bias) if !bias.is_empty() => {
407-
for (token_str, _) in bias.iter() {
408-
let token_id = token_str.parse::<u32>().map_err(|_| {
409-
ValidationError::LogitBiasInvalid(format!(
410-
"Token ID {} is not a valid number.",
411-
token_str
412-
))
413-
})?;
414-
415-
if token_id >= self.vocab_size {
416-
return Err(ValidationError::LogitBiasInvalid(format!(
417-
"Token ID {} is out of range. Must be between 0 and {}.",
418-
token_id,
419-
self.vocab_size - 1
420-
)));
421-
}
422-
}
405+
// Validate logit bias and convert to a vector of (token_id, bias_value)
406+
let logit_bias = request
407+
.parameters
408+
.logit_bias
409+
.as_ref()
410+
.filter(|bias_map| !bias_map.is_empty())
411+
.map(|bias_map| {
412+
bias_map
413+
.iter()
414+
.map(|(token_str, &bias_value)| {
415+
let token_id: u32 = token_str.parse().map_err(|_| {
416+
ValidationError::LogitBiasInvalid(format!(
417+
"Token ID {token_str} is not a valid number."
418+
))
419+
})?;
423420

424-
// Transform into the required format
425-
Some(
426-
bias.iter()
427-
.map(|(k, v)| (k.parse::<u32>().unwrap(), *v as f32))
428-
.collect(),
429-
)
430-
}
431-
_ => None,
432-
};
421+
if token_id >= self.vocab_size {
422+
return Err(ValidationError::LogitBiasInvalid(format!(
423+
"Token ID {token_id} is out of range (0..{}).",
424+
self.vocab_size - 1
425+
)));
426+
}
427+
428+
Ok((token_id, bias_value as f32))
429+
})
430+
.collect::<Result<Vec<_>, _>>()
431+
})
432+
// convert Option<Result<T, E>> to Result<Option<T>, E> to throw
433+
// if any of the token IDs are invalid
434+
.transpose()?;
433435

434436
let parameters = ValidParameters {
435437
temperature,

0 commit comments

Comments
 (0)