
Commit e0c3b66

fix: improve processor logic and refactor
1 parent 4aff36a commit e0c3b66

4 files changed, +110 -101 lines changed

router/src/server.rs

Lines changed: 0 additions & 1 deletion
@@ -1192,7 +1192,6 @@ pub(crate) async fn chat_completions(
     let (generate_request, using_tools): (GenerateRequest, bool) =
         chat.clone().try_into_generate(&infer)?;
     span.record("parameters", format!("{:?}", generate_request.parameters));
-    println!("ChatRequest: {:#?}", generate_request);
     let logprobs = logprobs.unwrap_or_default();

     // extract model id from request if specified

router/src/validation.rs

Lines changed: 47 additions & 12 deletions
@@ -34,6 +34,7 @@ pub struct Validation {
     max_input_length: usize,
     max_total_tokens: usize,
     disable_grammar_support: bool,
+    vocab_size: u32,
     /// Channel to communicate with the background tokenization task
     sender: mpsc::UnboundedSender<TokenizerRequest>,
 }
@@ -88,6 +89,19 @@ impl Validation {
             validation_sender
         };

+        let vocab_size = match &tokenizer {
+            Tokenizer::Python { tokenizer_name, .. } => {
+                warn!(
+                    "Tokenizer {} is not supported for validation",
+                    tokenizer_name
+                );
+                0
+            }
+            Tokenizer::Rust(tokenizer) => tokenizer.get_vocab_size(false),
+        }
+        .try_into()
+        .unwrap_or(0);
+
         Self {
             max_best_of,
             sender,
@@ -96,6 +110,7 @@ impl Validation {
             max_input_length,
             max_total_tokens,
             disable_grammar_support,
+            vocab_size,
         }
     }

@@ -387,6 +402,35 @@ impl Validation {
             None => None,
         };

+        let logit_bias = match &request.parameters.logit_bias {
+            Some(bias) if !bias.is_empty() => {
+                for (token_str, _) in bias.iter() {
+                    let token_id = token_str.parse::<u32>().map_err(|_| {
+                        ValidationError::LogitBiasInvalid(format!(
+                            "Token ID {} is not a valid number.",
+                            token_str
+                        ))
+                    })?;
+
+                    if token_id >= self.vocab_size {
+                        return Err(ValidationError::LogitBiasInvalid(format!(
+                            "Token ID {} is out of range. Must be between 0 and {}.",
+                            token_id,
+                            self.vocab_size - 1
+                        )));
+                    }
+                }
+
+                // Transform into the required format
+                Some(
+                    bias.iter()
+                        .map(|(k, v)| (k.parse::<u32>().unwrap(), *v as f32))
+                        .collect(),
+                )
+            }
+            _ => None,
+        };
+
         let parameters = ValidParameters {
             temperature,
             repetition_penalty,
@@ -398,18 +442,7 @@ impl Validation {
             seed,
             watermark,
             grammar,
-            logit_bias: Some(
-                request
-                    .parameters
-                    .logit_bias
-                    .iter()
-                    .flat_map(|bias| {
-                        bias.iter()
-                            .map(|(k, v)| (k.parse::<u32>().unwrap(), *v as f32))
-                            .collect::<Vec<_>>()
-                    })
-                    .collect(),
-            ),
+            logit_bias,
         };
         let stopping_parameters = ValidStoppingParameters {
             max_new_tokens,
@@ -989,6 +1022,8 @@ pub enum ValidationError {
     FailedFetchImage(#[from] reqwest::Error),
     #[error("{0} modality is not supported")]
     UnsupportedModality(&'static str),
+    #[error("logit_bias is not valid: {0}")]
+    LogitBiasInvalid(String),
 }

 #[cfg(test)]
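
For context, here is a minimal Python sketch that mirrors the validation rule this hunk adds (the vocab_size value and the bias map are made-up examples; the actual implementation is the Rust code above):

# Every logit_bias key must parse as an integer token ID and fall inside
# [0, vocab_size); values are carried through as plain floats.
def validate_logit_bias(bias, vocab_size):
    validated = {}
    for token_str, value in bias.items():
        try:
            token_id = int(token_str)
        except ValueError:
            raise ValueError(f"Token ID {token_str} is not a valid number.")
        if not 0 <= token_id < vocab_size:
            raise ValueError(
                f"Token ID {token_id} is out of range. "
                f"Must be between 0 and {vocab_size - 1}."
            )
        validated[token_id] = float(value)
    return validated

print(validate_logit_bias({"42": -100.0}, vocab_size=32000))  # {42: -100.0}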

server/text_generation_server/utils/logits_process.py

Lines changed: 60 additions & 78 deletions
@@ -625,110 +625,92 @@ def filter(self, indices):
         return self


-class LogitBiasProcessor:
-    """Process logits with logit biases."""
+class LogitBiasProcessor(LogitsProcessor):
+    """
+    `LogitsProcessor` creates a bias tensor from a dictionary of token IDs and their
+    corresponding bias values. Bias are applied to the logits during each forward pass.
+
+    Supports token IDs provided as strings (e.g., {"9707": -100}).
+    """

     def __init__(
-        self, logit_biases: Optional[dict], tokenizer: PreTrainedTokenizerBase
+        self,
+        logit_biases: dict,
+        tokenizer: PreTrainedTokenizerBase,
+        device: torch.device,
     ):
-        self.tokenizer = tokenizer
-        self.logit_biases = logit_biases or {}
+        assert logit_biases, "LogitBiasProcessor requires non-empty logit_biases"

-        # Pre-compute token IDs for each token string
-        self.token_id_mapping = {}
+        vocab_size = len(tokenizer)

-    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
-        # If no logit biases, return scores unchanged
-        if not self.logit_biases:
-            return scores
-
-        # Apply bias to the corresponding scores
-        for token_str, bias_value in self.logit_biases.items():
-            # Get token ID, either from cache or by computing it
-            if token_str not in self.token_id_mapping:
-                if token_str.isdigit():
-                    # If the token string is already a numeric ID
-                    token_id = int(token_str)
-                else:
-                    # Otherwise, use the tokenizer to get the ID
-                    tokens = self.tokenizer.encode(token_str, add_special_tokens=False)
-                    token_id = tokens[0] if tokens else -1  # Use -1 for not found
-
-                self.token_id_mapping[token_str] = token_id
-
-            token_id = self.token_id_mapping[token_str]
-
-            # Apply bias if token ID is valid
-            if 0 <= token_id < scores.size(-1):
-                scores[:, token_id] += bias_value
+        # Convert keys to integers and values to a list
+        token_ids = torch.tensor(
+            [int(k) for k in logit_biases.keys()], dtype=torch.long
+        )
+        bias_values = torch.tensor(list(logit_biases.values()), dtype=torch.float)

-        return scores
+        # Create a tensor and directly copy bias values at the corresponding indices
+        self.bias_tensor = torch.zeros(vocab_size, dtype=torch.float)
+        self.bias_tensor.index_put_((token_ids,), bias_values, accumulate=True)

-    def filter(self, indices):
-        """Keep only the logit biases for the specified indices."""
-        new_logit_biases = {
-            k: self.logit_biases[k] for k in indices if k in self.logit_biases
-        }
-        return LogitBiasProcessor(new_logit_biases, self.tokenizer)
+    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
+        # Apply bias tensor as a broadcasted addition
+        if self.bias_tensor.shape[0] != scores.shape[1]:
+            # Fix if the bias tensor is smaller than the scores
+            self.bias_tensor = torch.nn.functional.pad(
+                self.bias_tensor, (0, scores.shape[1] - self.bias_tensor.shape[0])
+            )
+        scores.add_(self.bias_tensor.to(device=scores.device, dtype=scores.dtype))
+        return scores


-class HeterogeneousLogitBiasProcessor:
-    """Process logits with different logit biases for each sequence in the batch."""
+class HeterogeneousLogitBiasProcessor(LogitsProcessor):
+    """
+    Process logits with different logit biases for each sequence in the batch.
+    """

     def __init__(
         self,
         logit_biases: List[Optional[dict]],
         tokenizer: PreTrainedTokenizerBase,
         device: torch.device,
     ):
-        self.device = device
         self.tokenizer = tokenizer
         self.logit_biases = logit_biases
-        self.batch_size = len(logit_biases)
+        # import ipdb; ipdb.set_trace()
+        self.vocab_size = len(tokenizer)

-        # Pre-compute token IDs for each token string
-        self.token_id_mapping = {}
+        # Create batch_size x vocab_size bias matrix
+        self.bias_matrix = torch.zeros(
+            (len(logit_biases), self.vocab_size), dtype=torch.float, device=device
+        )

-        # Create a mapping of indices that have logit biases
-        self.indices_with_biases = {
-            i: bias_dict
-            for i, bias_dict in enumerate(self.logit_biases)
-            if bias_dict is not None and len(bias_dict) > 0
-        }
+        # for each logit bias dictionary, convert keys to integers and values to a list
+        for i, logit_bias in enumerate(logit_biases):
+            token_ids = torch.tensor(
+                [int(k) for k in logit_bias.keys()], dtype=torch.long
+            ).to(device=device)
+            bias_values = torch.tensor(list(logit_bias.values()), dtype=torch.float).to(
+                device=device
+            )
+            # Create a tensor and directly copy bias values at the corresponding indices
+            self.bias_matrix[i].index_put_((token_ids,), bias_values, accumulate=True)

     def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
-        # If no indices have biases, return scores unchanged
-        if not self.indices_with_biases:
-            return scores
-
-        # For each index with a bias, apply the bias to the corresponding scores
-        for i, bias_dict in self.indices_with_biases.items():
-            for token_str, bias_value in bias_dict.items():
-                # Get token ID, either from cache or by computing it
-                if token_str not in self.token_id_mapping:
-                    if token_str.isdigit():
-                        # If the token string is already a numeric ID
-                        token_id = int(token_str)
-                    else:
-                        # Otherwise, use the tokenizer to get the ID
-                        tokens = self.tokenizer.encode(
-                            token_str, add_special_tokens=False
-                        )
-                        token_id = tokens[0] if tokens else -1  # Use -1 for not found
-
-                    self.token_id_mapping[token_str] = token_id
-
-                token_id = self.token_id_mapping[token_str]
-
-                # Apply bias if token ID is valid
-                if 0 <= token_id < scores.size(-1):
-                    scores[i, token_id] += bias_value
+        # Apply bias matrix as a broadcasted addition
+        if self.bias_matrix.shape[1] != scores.shape[1]:
+            # Fix if the bias matrix is smaller than the scores
+            self.bias_matrix = torch.nn.functional.pad(
+                self.bias_matrix, (0, scores.shape[1] - self.bias_matrix.shape[1])
+            )

+        scores.add_(self.bias_matrix.to(device=scores.device, dtype=scores.dtype))
         return scores

-    def filter(self, indices: List[int]):
-        """Keep only the logit biases for the specified indices."""
+    def filter(self, indices):
         new_logit_biases = [self.logit_biases[i] for i in indices]
+        if not any(bias and len(bias) > 0 for bias in new_logit_biases):
+            return None
         return HeterogeneousLogitBiasProcessor(
             new_logit_biases, self.tokenizer, self.device
         )
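
For readers skimming the diff, a short standalone sketch of the technique the rewritten processors rely on: precompute a vocab-sized bias tensor with index_put_, then broadcast-add it onto the (batch, vocab) score matrix on each step. The vocabulary size, token IDs, and batch size below are made-up illustrations, not values from this commit:

import torch

# Token IDs arrive as strings in the request payload, e.g. {"3": -100.0, "5": 4.0}.
vocab_size = 8
logit_bias = {"3": -100.0, "5": 4.0}

# Build the bias tensor once; it has the same width as one row of the logits.
token_ids = torch.tensor([int(k) for k in logit_bias], dtype=torch.long)
bias_values = torch.tensor(list(logit_bias.values()), dtype=torch.float)
bias_tensor = torch.zeros(vocab_size, dtype=torch.float)
bias_tensor.index_put_((token_ids,), bias_values, accumulate=True)

# Each forward pass: one in-place broadcasted addition over the whole batch.
scores = torch.zeros(2, vocab_size)
scores.add_(bias_tensor)
print(scores[0, 3].item(), scores[0, 5].item())  # -100.0 4.0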

server/text_generation_server/utils/tokens.py

Lines changed: 3 additions & 10 deletions
@@ -66,7 +66,6 @@ def __init__(
             else None
         )
         self.tokenizer = tokenizer
-        self.logit_bias = logit_bias

         has_warpers = (
             (temperature is not None and temperature != 1.0)
@@ -136,7 +135,7 @@ def from_pb(
             tokenizer=tokenizer,
             grammar=pb.grammar,
             grammar_type=pb.grammar_type,
-            logit_bias=dict(pb.logit_bias) if pb.logit_bias else None,
+            logit_bias=pb.logit_bias,
         )


@@ -264,10 +263,6 @@ def __init__(
     ):
         warpers = []

-        # Initialize with empty logit biases if none provided
-        if logit_biases is None:
-            logit_biases = [None] * len(do_sample)
-
         self.watermark_processor = (
             HeterogeneousProcessorWrapper(
                 {
@@ -306,7 +301,7 @@ def __init__(

         self.logit_bias_processor = (
             HeterogeneousLogitBiasProcessor(logit_biases, tokenizer, device)
-            if any([bias is not None and len(bias) > 0 for bias in logit_biases])
+            if any([logit_bias is not None for logit_bias in logit_biases])
             else None
         )

@@ -530,9 +525,7 @@ def from_pb(
             fsm_grammar_states=(
                 fsm_grammar_states if fsm_grammar_states else [0] * len(pb)
             ),
-            logit_biases=[
-                dict(pb_.logit_bias) if pb_.logit_bias else None for pb_ in pb
-            ],
+            logit_biases=[pb_.logit_bias for pb_ in pb],
         )