[Bugfix][TPU] Fix tpu model runner testcase failure #18810


Merged

merged 6 commits into vllm-project:main from cazheng/fix-tpu-tc6 on May 30, 2025

Conversation

@CAROLZXYZXY (Contributor) commented May 28, 2025

  1. The previous failure occurred because input_batch is only created when the KV cache is initialized. Test 6 now passes in CI.
  2. [v1] Redo "Support multiple KV cache groups in GPU model runner (#17945)" #18593 introduced the multi-group block table, which causes some tests to fail; see the sketch and failure log below.
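
For context, here is a minimal, self-contained sketch of why the runner can no longer build its persistent batch at construction time. All names below (`Runner`, `InputBatch`, `MultiGroupBlockTable`, `blocks_per_group`) are illustrative stand-ins, not the actual vLLM classes: after #18593 the batch owns one block table per KV cache group, and the group layout is only known once the KV cache is initialized.

```python
# Illustrative sketch only -- simplified stand-ins, not the real vLLM code.


class BlockTable:
    """A fixed-size table of block IDs, one row per request slot."""

    def __init__(self, max_num_reqs: int, max_num_blocks: int):
        self.table = [[0] * max_num_blocks for _ in range(max_num_reqs)]


class MultiGroupBlockTable:
    """One BlockTable per KV cache group (the structure added by #18593)."""

    def __init__(self, max_num_reqs: int, blocks_per_group: list[int]):
        self.block_tables = [
            BlockTable(max_num_reqs, n) for n in blocks_per_group
        ]


class InputBatch:
    """Persistent batch state; now depends on the KV cache group layout."""

    def __init__(self, max_num_reqs: int, blocks_per_group: list[int]):
        self.req_id_to_index: dict[str, int] = {}
        self.block_table = MultiGroupBlockTable(max_num_reqs, blocks_per_group)


class Runner:
    def __init__(self, max_num_reqs: int):
        self.max_num_reqs = max_num_reqs
        # input_batch is deliberately NOT created here: the number of KV
        # cache groups is unknown until initialize_kv_cache() runs.

    def initialize_kv_cache(self, blocks_per_group: list[int]) -> None:
        self.input_batch = InputBatch(self.max_num_reqs, blocks_per_group)
```

Any code path, including a test, that touches `runner.input_batch` before `initialize_kv_cache()` therefore raises AttributeError, which is exactly the CI failure below.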
```
=================================== FAILURES ===================================
________________________ test_update_states_new_request ________________________

model_runner = <vllm.v1.worker.tpu_model_runner.TPUModelRunner object at 0x7a7edb115b10>

    def test_update_states_new_request(model_runner):
        req_id = "req_0"

        # new req
        scheduler_output = _schedule_new_request(req_id)

>       model_runner._update_states(scheduler_output)

tests/v1/tpu/worker/test_tpu_model_runner.py:131:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = <vllm.v1.worker.tpu_model_runner.TPUModelRunner object at 0x7a7edb115b10>
scheduler_output = SchedulerOutput(scheduled_new_reqs=[NewRequestData(req_id=req_0,prompt_token_ids=[1, 2, 3],mm_inputs=[],mm_hashes=[],m...s=set(), free_encoder_input_ids=[], structured_output_request_ids={}, grammar_bitmask=None, kv_connector_metadata=None)

    def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
        """Update the cached states and the persistent batch with the scheduler
        output.

        The updated states are used by the `_prepare_inputs` function to create
        the input GPU tensors for the model.

        Returns:
            True if there is a new/resumed/paused/finished request.
            If False, we can skip copying SamplingMetadata to the GPU.
        """
        # Remove finished requests from the cached states.
        for req_id in scheduler_output.finished_req_ids:
            self.requests.pop(req_id, None)
            self.encoder_cache.pop(req_id, None)

        # Remove the finished requests from the persistent batch.
        # NOTE(woosuk): There could be an edge case where finished_req_ids and
        # scheduled_req_ids overlap. This happens when a request is aborted and
        # then resubmitted with the same ID. In this case, we treat them as two
        # distinct requests - clearing the cached states for the first request
        # and handling the second as a new request.
        removed_req_indices: list[int] = []
        for req_id in scheduler_output.finished_req_ids:
            req_index = self.input_batch.remove_request(req_id)
            if req_index is not None:
                removed_req_indices.append(req_index)

        # Free the cached encoder outputs.
        for req_id, input_id in scheduler_output.free_encoder_input_ids:
            encoder_outputs = self.encoder_cache.get(req_id)
            if encoder_outputs is not None:
                encoder_outputs.pop(input_id, None)
                if not encoder_outputs:
                    self.encoder_cache.pop(req_id, None)

        # Remove the unscheduled requests from the persistent batch.
        # NOTE(woosuk): The unscheduled requests are either preempted requests
        # or running requests that are not scheduled in this step. We remove
        # them from the persistent batch but keep their cached states since
        # they will be scheduled again sometime in the future.
        scheduled_req_ids = scheduler_output.num_scheduled_tokens.keys()
>       cached_req_ids = self.input_batch.req_id_to_index.keys()
E       AttributeError: 'TPUModelRunner' object has no attribute 'input_batch'

vllm/v1/worker/tpu_model_runner.py:327: AttributeError
```
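
The fixture-level fix can be sketched against the toy `Runner` from the sketch above. The real test builds a TPUModelRunner and a KVCacheConfig, which are elided here; the fixture and test names are hypothetical. The point is simply that the KV cache, and with it input_batch, is initialized before any test calls `_update_states()`.

```python
# Sketch of the fixture-level fix, reusing the toy Runner defined in the
# earlier sketch -- hypothetical names, not the actual test code in this PR.
import pytest


@pytest.fixture
def model_runner():
    runner = Runner(max_num_reqs=16)
    # The crux of the fix: initialize the KV cache (and with it the
    # persistent input_batch) before any test exercises the runner.
    runner.initialize_kv_cache(blocks_per_group=[128])
    return runner


def test_input_batch_exists(model_runner):
    # Without the initialize_kv_cache() call in the fixture, this attribute
    # access would raise AttributeError, exactly as in the CI log above.
    assert model_runner.input_batch.req_id_to_index == {}
```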


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only fastcheck CI runs, covering a small, essential subset of tests to catch errors quickly. You can run additional CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@CAROLZXYZXY CAROLZXYZXY force-pushed the cazheng/fix-tpu-tc6 branch from d615026 to c2803c3 Compare May 28, 2025 06:33
@mergify (bot) added the v1 and tpu (Related to Google TPUs) labels May 28, 2025
@CAROLZXYZXY CAROLZXYZXY marked this pull request as draft May 28, 2025 06:38
@CAROLZXYZXY CAROLZXYZXY force-pushed the cazheng/fix-tpu-tc6 branch 2 times, most recently from d9c22b9 to 0b096db Compare May 28, 2025 17:46
@CAROLZXYZXY CAROLZXYZXY marked this pull request as ready for review May 28, 2025 18:21
@yaochengji added the ready (ONLY add when PR is ready to merge/full CI is needed) label May 28, 2025
@CAROLZXYZXY CAROLZXYZXY force-pushed the cazheng/fix-tpu-tc6 branch from 0b096db to 9cb81f1 Compare May 28, 2025 23:16
@CAROLZXYZXY CAROLZXYZXY force-pushed the cazheng/fix-tpu-tc6 branch 2 times, most recently from db70231 to 084fba5 Compare May 29, 2025 06:20
@yaochengji (Collaborator) left a comment:

LGTM, thanks for fixing this!

@yaochengji yaochengji changed the title Fix tpu model runner testcase failure [Bugfix][TPU] Fix tpu model runner testcase failure May 29, 2025
Signed-off-by: Carol Zheng <[email protected]>
@CAROLZXYZXY CAROLZXYZXY force-pushed the cazheng/fix-tpu-tc6 branch from 084fba5 to c8dcb5d Compare May 29, 2025 20:15
@DarkLight1337 DarkLight1337 merged commit fba02e3 into vllm-project:main May 30, 2025
62 checks passed
amitm02 pushed commits to amitm02/vllm that referenced this pull request Jun 1, 2025
Labels: ready (ONLY add when PR is ready to merge/full CI is needed), tpu (Related to Google TPUs), v1

3 participants