
[TPU][V1] Refine tpu_model_runner to mitigate future recompilation issues #16275


Merged: 7 commits into vllm-project:main on Apr 10, 2025

Conversation

@yaochengji (Collaborator) commented Apr 8, 2025

  • Wrap all TPU computation in torch.compile. The wrapped regions and their bucket counts are:
    • forward: total_bucket_size = token_bucket_num
    • select_hidden_states: total_bucket_size = token_bucket_num * req_bucket_num
    • sample_from_hidden: total_bucket_size = req_bucket_num * 2
  • Remove PyTorch operations from TPUSupportedSamplingMetadata.from_input_batch
  • Replace the torch.where implementation with if/else branches
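As a rough illustration of the bucket counts listed above, the number of compiled graphs per wrapped region can be tallied directly from those formulas. The bucket-number values below are assumed example values, not numbers taken from the PR:

```python
# Hypothetical illustration of compiled-graph counts per wrapped region.
# token_bucket_num and req_bucket_num are assumed example values.
token_bucket_num = 8
req_bucket_num = 4

graph_counts = {
    "forward": token_bucket_num,                                # one graph per token bucket
    "select_hidden_states": token_bucket_num * req_bucket_num,  # token x request buckets
    "sample_from_hidden": req_bucket_num * 2,                   # two graphs per request bucket
}
```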


github-actions bot commented Apr 8, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small, essential subset of tests to catch errors quickly. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@mergify bot added the v1 and tpu (Related to Google TPUs) labels on Apr 8, 2025
@NickLucche (Contributor)

I can't review rn sorry. Is it faster than main?

@yaochengji (Collaborator, Author)

> I can't review rn sorry. Is it faster than main?

Yes, it's faster in both compilation and execution.

@yaochengji yaochengji changed the title [TPU][V1] Refine tpu_model_runner to mitigate the future recompilation issues [TPU][V1] Refine tpu_model_runner to mitigate future recompilation issues Apr 8, 2025
@NickLucche (Contributor) left a comment


Thanks for the PR, overall this lgtm, especially the from_input_batch change.

I don't fully grasp where the optimization lies with the if/else branch as we're just forking two different graphs now, but that's about it.

@yaochengji (Collaborator, Author) commented Apr 8, 2025

> I don't fully grasp where the optimization lies with the if/else branch as we're just forking two different graphs now, but that's about it.

If we use torch.where(pred, A, B), we only need to compile one graph that covers both pred being True and pred being False, but both A and B are computed no matter what the value of pred is.

If we use

if pred:
    A
else:
    B

we need to compile two graphs, one for pred being True and one for it being False. The benefit is that only the taken branch is computed during execution.

Also cc @robertgshaw2-redhat, I remember you had similar questions.
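The compile-vs-compute tradeoff described above can be sketched without any framework. This is a hypothetical toy model, not vLLM or torch.compile code: selection-style evaluation runs both branches before picking a result, while a Python if/else runs only the taken branch, at the cost of one graph per value of pred.

```python
# Toy model of the tradeoff (hypothetical, framework-free sketch).
calls = []  # records which "expensive" branch computations actually ran

def branch_a(x):
    calls.append("A")
    return x + 1

def branch_b(x):
    calls.append("B")
    return x - 1

def where_style(pred, x):
    # torch.where-like: a single graph shape, but BOTH sides are
    # computed before one result is selected.
    a, b = branch_a(x), branch_b(x)
    return a if pred else b

def if_else_style(pred, x):
    # if/else: one graph per value of pred, but only the taken side runs.
    return branch_a(x) if pred else branch_b(x)

where_style(True, 10)
ops_where = len(calls)    # both A and B ran
calls.clear()
if_else_style(True, 10)
ops_branch = len(calls)   # only A ran
```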

@NickLucche (Contributor)

> but both A and B will be computed no matter what the value of pred is

Ok, I see the issue now. The simpler if/else solution surely fits our use case better. Thanks for explaining!

@yaochengji yaochengji added the ready ONLY add when PR is ready to merge/full CI is needed label Apr 8, 2025
@yaochengji yaochengji marked this pull request as draft April 8, 2025 20:08
@yaochengji yaochengji marked this pull request as ready for review April 8, 2025 20:23
E.g. 3 requests, tensors padded to 4
temperature: [0.7, 0.2, 0.9] => [0.7, 0.2, 0.9, 0.0]
sample indices: [4, 10, 11] => indices_do_sample: [4, 10, 11, 0]
ops to CPU and produces tensors of fixed `padded_num_reqs` size.
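The padding in the example above can be sketched with a small helper. The pad_to function here is hypothetical, not the actual vLLM implementation; the point is that padding every batch to a fixed padded_num_reqs keeps tensor shapes constant, so a new graph is not compiled per batch size:

```python
# Hypothetical helper mirroring the padding example above (not vLLM code).
def pad_to(values, padded_len, pad_value):
    """Right-pad `values` with `pad_value` up to `padded_len` entries."""
    return values + [pad_value] * (padded_len - len(values))

padded_num_reqs = 4  # 3 real requests, padded to a fixed size of 4
temperature = pad_to([0.7, 0.2, 0.9], padded_num_reqs, 0.0)
indices_do_sample = pad_to([4, 10, 11], padded_num_reqs, 0)
```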
Collaborator


nit: it seems that the implementation directly uses the CPU tensors and moves them to the XLA device

Collaborator Author


I think it aligns with the description?

@mgoin (Member)

mgoin commented Apr 9, 2025

Looks like the TPU V1 test is failing


ERROR collecting tests/v1/tpu/worker/test_tpu_model_runner.py
ImportError while importing test module '/workspace/vllm/tests/v1/tpu/worker/test_tpu_model_runner.py'.
Hint: make sure your test modules/packages have valid Python names.
Traceback:
/usr/local/lib/python3.10/importlib/__init__.py:126: in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
tests/v1/tpu/worker/test_tpu_model_runner.py:10: in <module>
    from vllm.v1.worker.tpu_model_runner import (TPUModelRunner,
E   ImportError: cannot import name '_get_paddings' from 'vllm.v1.worker.tpu_model_runner' (/workspace/vllm/vllm/v1/worker/tpu_model_runner.py)

Signed-off-by: Chengji Yao <[email protected]>
@yaochengji force-pushed the chengji/improve-recompile branch from f2a4dfd to dc5c641 on April 9, 2025
@yaochengji (Collaborator, Author)

> Looks like the TPU V1 test is failing

@mgoin thanks for the reminder, it should be fixed now. BTW, can we make the TPU CI test no longer a soft-fail? cc @robertgshaw2-redhat

@mgoin (Member) left a comment


Nice work, I feel this is a good path to go down! Thank you

@mgoin mgoin merged commit a454748 into vllm-project:main Apr 10, 2025
41 checks passed
Labels: ready (ONLY add when PR is ready to merge/full CI is needed), tpu (Related to Google TPUs), v1

4 participants