import multiprocessing as mp
from queue import Empty
from typing import Union

import vllm.envs as envs
from vllm.distributed.communication_op import broadcast_tensor_dict
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.executor.multiproc_gpu_executor import MultiprocessingGPUExecutor
from vllm.executor.ray_gpu_executor import RayGPUExecutor
from vllm.inputs import PromptType, TokensPrompt
from vllm.logger import init_logger
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Counter

logger = init_logger(__name__)


class FastSyncLLM:
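    """Synchronous LLM engine driven by a pair of multiprocessing queues.

    Requests arrive on ``input_queue`` as ``(request_id, prompt,
    sampling_params)`` tuples; the engine pushes per-step outputs and, for
    finished requests, summary stats to ``result_queue``.

    Illustrative usage (a minimal sketch; process management is left to the
    caller, which typically runs ``run_engine()`` in a dedicated process)::

        llm = FastSyncLLM(engine_args, input_queue, result_queue)
        llm.run_engine()

        # From the driving process:
        #   result_queue.get()  -> ("Ready", None, None) once the engine is up
        #   input_queue.put((request_id, prompt, sampling_params))
        #   result_queue.get()  -> (request_id, generated_text, stats_or_None)
    """
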
    def __init__(
        self,
        engine_args: EngineArgs,
        input_queue: mp.Queue,
        result_queue: mp.Queue,
        **kwargs,
    ) -> None:
        if "disable_log_stats" not in kwargs:
            kwargs["disable_log_stats"] = True
        self.engine_args = engine_args
        self.request_counter = Counter()

        self.input_queue = input_queue
        self.result_queue = result_queue
        self.finish = False
        self.need_restart = False
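        # The engine itself is constructed lazily in run_engine().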
        self.llm_engine: LLMEngine

    def _add_request(
        self,
        inputs: PromptType,
        params: Union[SamplingParams, PoolingParams],
        request_id: str,
    ) -> None:
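        # A bare list of token ids is wrapped as a TokensPrompt so the
        # engine treats it as an already-tokenized prompt.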
        if isinstance(inputs, list):
            inputs = TokensPrompt(prompt_token_ids=inputs)
        self.llm_engine.add_request(request_id, inputs, params)

    def _poll_requests(self):
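        """Move pending requests from the input queue into the engine.

        Blocks for the next request when the engine is idle; otherwise
        drains whatever is already queued without blocking.
        """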
        while True:
            if not self.llm_engine.has_unfinished_requests():
                logger.info("No unfinished requests. Waiting...")
                (request_id, prompt, sampling_params) = self.input_queue.get()
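                # Worker execution loops were stopped when the engine went
                # idle (see run_engine); restart them before adding new work.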
                if self.need_restart and isinstance(
                        self.llm_engine.model_executor,
                        MultiprocessingGPUExecutor):
                    logger.info("Restarting worker loops")
                    for worker in self.llm_engine.model_executor.workers:
                        worker.execute_method("start_worker_execution_loop")
                    self.need_restart = False

            else:
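                # The engine still has work, so only drain requests that are
                # already queued; never block here.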
                try:
                    (request_id, prompt,
                     sampling_params) = self.input_queue.get_nowait()
                except Empty:
                    break
            self._add_request(prompt, sampling_params, request_id)

    def run_engine(self):
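        """Build the engine and run the serving loop indefinitely.

        Each iteration polls the input queue for new requests, performs one
        engine step, and forwards the step's outputs to the result queue.
        """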
        self.llm_engine = LLMEngine.from_engine_args(
            self.engine_args, usage_context=UsageContext.LLM_CLASS)
        assert not isinstance(
            self.llm_engine.model_executor,
            RayGPUExecutor), "Ray is not supported in sync openai mode"

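        # Tell the driving process that engine initialization has finished.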
        self.result_queue.put(("Ready", None, None))
        prompt_lens = {}
        tokens = {}  # type: ignore
        log_interval = 100
        poll_interval = envs.VLLM_SYNC_SERVER_ENGINE_STEPS_BETWEEN_POLLS
        try:
            while True:
                poll_interval -= 1
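                # Pull new requests when enough have accumulated, when the
                # poll interval expires, or when the engine has run dry.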
                if (self.input_queue.qsize() >=
                        envs.VLLM_SYNC_SERVER_ACCUM_REQUESTS
                        or poll_interval <= 0
                        or not self.llm_engine.has_unfinished_requests()):
                    self._poll_requests()
                    poll_interval = \
                        envs.VLLM_SYNC_SERVER_ENGINE_STEPS_BETWEEN_POLLS
                step_outputs = self.llm_engine.step()
                log_interval -= 1
                if log_interval == 0:
                    log_interval = 100
                    logger.info("Step finished. Unfinished requests: %d",
                                self.llm_engine.get_num_unfinished_requests())
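                # The engine went idle: broadcast an empty dict so the worker
                # execution loops stop; _poll_requests() restarts them once
                # new requests arrive.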
                if not self.llm_engine.has_unfinished_requests():
                    logger.info("Broadcast stop")
                    broadcast_tensor_dict({}, src=0)
                    self.need_restart = True
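                # Forward each request's generated text to the result queue;
                # a stats dict (prompt length, token count, finish/stop
                # reason) is attached once the request finishes.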
                for output in step_outputs:
                    assert len(output.outputs) == 1  # type: ignore
                    first_out = output.outputs[0]  # type: ignore
                    stats = None
                    result = first_out.text
                    tokens[output.request_id] = tokens.get(
                        output.request_id, 0) + len(first_out.token_ids)
                    if output.prompt_token_ids is not None:
                        prompt_lens[output.request_id] = len(
                            output.prompt_token_ids)
                    if output.finished:
                        assert output.request_id in prompt_lens
                        stats = {
                            "prompt": prompt_lens[output.request_id],
                            "tokens": tokens[output.request_id],
                            "finish_reason": first_out.finish_reason,
                            "stop_reason": first_out.stop_reason,
                        }
                        del prompt_lens[output.request_id]
                    self.result_queue.put_nowait(
                        (output.request_id, result, stats))

        except Exception as e:
            logger.error("Error in run_engine: %s", e)
            raise e