from enum import Enum
from typing import List

from evaluator import equal_group
from transformers import AutoTokenizer

from tensorrt_llm.scaffolding import Controller, GenerationTask


class DynasorGenerationController(Controller):

    class WorkerTag(Enum):
        GENERATION = "generation_with_dynasor_cot"

    # certainty_threshold and chunk_size control the compute-saving level.
    # Decreasing certainty_threshold and chunk_size saves tokens but may risk compromising accuracy.
    def __init__(self,
                 generation_dir,
                 max_tokens=8192,
                 certainty_threshold=3,
                 chunk_size=64):
        """
        Initializes the controller with parameters controlling token limits and certainty thresholds.

        Args:
            generation_dir (str): Path or name of the pretrained model whose tokenizer is loaded.
            max_tokens (int): Maximum number of tokens to generate in total.
            certainty_threshold (int): Number of consecutive identical and confident probe answers
                required to consider the generation as certain.
            chunk_size (int): Number of tokens to generate per proposal round.
        """
        super().__init__()
        self.generation_dir = generation_dir
        self.max_tokens = max_tokens
        self.certainty_threshold = certainty_threshold
        self.chunk_size = chunk_size
        # Words that signal the probe answer is still hedging rather than committing.
        self.uncertain_words = ["wait", "hold", "but", "okay", "no", "hmm"]
        # Appended to the prompt to nudge the model into committing to a boxed answer early.
        self.probe_suffix = "... Oh, I suddenly got the answer to the whole problem, **Final Answer**\n\n\\[ \\boxed{"
        # Used to finalize the output once the probe answers are consistent; the second variant
        # also closes the chain of thought with a </think> marker when it is still open.
        self.answer_suffix = "\n\n... Oh, I have got the answer to the whole problem\n**Final Answer:**\n\\[\n  \\boxed{"
        self.answer_suffix_with_marker = "\n\n...</think>\n Oh, I have got the answer to the whole problem\n**Final Answer:**\n\\[\n \\boxed{"
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.generation_dir,
            legacy=False,
            padding_side='left',
            truncation_side='left',
            trust_remote_code=False,
            use_fast=True,
        )

    def process(self, tasks: List[GenerationTask], **kwargs):
        """
        Process the generation task using an iterative approach:
        1. Generate a probe response with an extra suffix to simulate chain-of-thought.
        2. Evaluate the probe response to extract a potential answer.
        3. Check for consistency over several rounds (using certainty_threshold).
        4. If consistent, finalize the answer and return. Otherwise, continue appending new proposals.

        Args:
            tasks (List[GenerationTask]): A list of generation tasks to process.
                The first task is assumed to hold the initial prompt.

        Yields:
            A list of GenerationTask objects to be executed in further processing steps.
        """
        # Start with the initial prompt provided by the first task.
        initial_prompt = tasks[0].input_str

        proposer_task = GenerationTask()
        proposer_task.max_tokens = self.chunk_size
        proposer_task.temperature = 0.6
        proposer_task.top_p = 0.95
        proposer_task.worker_tag = self.WorkerTag.GENERATION

        probe_task = GenerationTask()
        probe_task.max_tokens = 20
        probe_task.temperature = 0.6
        probe_task.top_p = 0.95
        probe_task.worker_tag = self.WorkerTag.GENERATION

        probe_answers = []
        probe_responses = []

        initial_prompt_token_num = len(
            self.tokenizer.encode(initial_prompt, add_special_tokens=False))
        probe_suffix_token_num = len(
            self.tokenizer.encode(self.probe_suffix, add_special_tokens=False))

        current_prompt = initial_prompt

        # Iterate over generation rounds until the maximum token limit is reached.
        # Make sure the prefill length always stays below the max_tokens passed to TRTLLMWorker.init_with_new_llm;
        # otherwise it will trigger an assertion failure, as described in issue #3576.
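        # Each iteration adds at most one chunk_size proposal, so the number of proposal rounds is
        # roughly (max_tokens - prompt_tokens - probe_suffix_tokens) / chunk_size. For example, a
        # 200-token prompt with the defaults allows about (8192 - 200 - 30) / 64, i.e. roughly 124 rounds.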
        for _ in range(initial_prompt_token_num + probe_suffix_token_num,
                       self.max_tokens, self.chunk_size):
            proposer_task.input_str = current_prompt
            probe_task.input_str = current_prompt + self.probe_suffix

            # For the probe task, append the suffix to force a chain-of-thought leading to an answer.
            yield [probe_task]

            # Retrieve the output from the probe task.
            probe_text = probe_task.output_str

            # Extract the potential answer from the probe response.
            answer = self.obtain_answer(probe_text)
            probe_answers.append(answer)
            probe_responses.append(probe_text)

            # Determine if the last few probe responses are considered confident enough.
            # A response is flagged as confident if it does not contain any of the uncertain words.
            probe_certain_count = [
                not any(word in res.lower() for word in self.uncertain_words)
                for res in probe_responses[-self.certainty_threshold:]
            ]
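            # For example, with certainty_threshold=3 this is a 3-element list, and it only counts as
            # fully confident when none of the last three probe responses contain words like "wait" or "but".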

            # Check whether the last 'certainty_threshold' probe answers are identical (via equal_group),
            # all non-empty, and all flagged as confident.
            if (equal_group(probe_answers[-self.certainty_threshold:])
                    and self.count_not_empty(
                        probe_answers[-self.certainty_threshold:])
                    == self.certainty_threshold
                    and sum(probe_certain_count) == self.certainty_threshold):
                # If the current prompt indicates the chain-of-thought phase has already ended, use the plain suffix.
                if "</think>" in current_prompt:
                    tasks[0].output_str = (current_prompt + self.answer_suffix +
                                           probe_answers[-1] + "}\n\\]")
                    return
                else:
                    # Otherwise, use the suffix with the </think> marker to close the chain of thought explicitly.
                    tasks[0].output_str = (current_prompt +
                                           self.answer_suffix_with_marker +
                                           probe_answers[-1] + "}\n\\]")
                    return

            # If not confident, do another round of generation.
            yield [proposer_task]

            # Append the newly generated text from the proposer to the current prompt for the next iteration.
            current_prompt += proposer_task.output_str

        # If the maximum token limit is reached without satisfying the certainty condition,
        # output the accumulated prompt as the final output.
        tasks[0].output_str = current_prompt
        return

    @staticmethod
    def obtain_answer(s):
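        """
        Return the prefix of ``s`` up to the first unmatched '}', or "" if every '}' is matched.

        The probe prompt ends with an opened "\\boxed{", so the probe output up to its first
        unmatched '}' is the candidate final answer. For example, obtain_answer("42} is the answer")
        returns "42", while obtain_answer("{x} + {y}") returns "" because every brace is matched.
        """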
        # Find first unpaired } by counting { and }
        stack = []
        for i, c in enumerate(s):
            if c == "{":
                stack.append(c)
            elif c == "}":
                if not stack:  # No matching { found
                    return s[:i]
                stack.pop()
        return ""

    @staticmethod
    def count_not_empty(answers):
        return sum(1 for answer in answers if answer != "")
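
For reference, below is a minimal sketch of how the controller's generator protocol can be driven. It is illustrative only: in practice the scaffolding runtime (e.g. ScaffoldingLlm with a TRTLLMWorker, as in the other scaffolding examples) executes the yielded tasks and fills in their outputs; the run_on_worker helper, the model path, and the prompt below are hypothetical placeholders.

    from tensorrt_llm.scaffolding import GenerationTask

    # Placeholder model path; the tokenizer is loaded from the same directory as the generation model.
    model_dir = "path/to/reasoning-model"

    controller = DynasorGenerationController(generation_dir=model_dir,
                                              max_tokens=8192,
                                              certainty_threshold=3,
                                              chunk_size=64)

    task = GenerationTask()
    task.input_str = "What is 17 * 24? Think step by step."
    tasks = [task]

    # Each yield is a list of GenerationTask objects; the runtime runs them and fills output_str.
    for task_batch in controller.process(tasks):
        for t in task_batch:
            t.output_str = run_on_worker(t)  # hypothetical helper standing in for the LLM worker call

    # When process() returns, the finalized text is stored on the first task.
    print(tasks[0].output_str)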