[RFC]: How to handle the compilation of PyTorch/XLA in vLLM #16282


Closed · 1 task done
yaochengji opened this issue Apr 8, 2025 · 4 comments
Labels
RFC · tpu (Related to Google TPUs)

Comments

@yaochengji
Collaborator

yaochengji commented Apr 8, 2025

Motivation.

vLLM currently uses PyTorch/XLA to provide TPU backend support. However, PyTorch/XLA differs significantly from native PyTorch in how it is used: it is a compilation-only framework with no true eager mode. In particular, for LLM serving, recompilation must be avoided once the server is running.
When writing TPU code, it is important to know which code creates PyTorch operations (e.g., tensor.copy_(), tensor[:index], torch.ones(...)) and which events trigger graph capture and compilation (e.g., xm.mark_step(), xla_tensor.cpu(), if xla_tensor:, torch.compile(backend="openxla")). Because PyTorch/XLA is complex, this document only provides basic rules to simplify vLLM development on TPU.
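As an illustration only (not vLLM code), the sketch below shows where PyTorch/XLA merely records operations on its lazy graph and which calls actually force graph capture and compilation; the tensor names are made up.

```python
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()

# These lines only record operations on the lazy XLA graph; nothing executes yet.
x = torch.ones(4, 4, device=device)
y = x[:2] + 1.0

# Any of the following forces graph capture and, for a shape seen for the
# first time, compilation:
xm.mark_step()      # explicit graph cut
y_host = y.cpu()    # device-to-host transfer materializes the tensor
if y.sum() > 0:     # reading a value on the host also materializes it
    pass
```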

Ways to avoid recompilation

The model executor has two primary components:

  • preparing the model and sampler inputs
  • executing the model and sampler

Step 1

It is recommended to avoid TPU operations when preparing the model and sampler inputs. Prepare the tensors on the CPU and transfer them to the XLA device with cpu_tensor.to(xla_device), which only triggers a CPU-to-TPU transfer and avoids compilation.
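For example (a hedged sketch; the tensor names are illustrative, not actual vLLM fields), input preparation can stay entirely on the CPU and end with a single host-to-TPU copy:

```python
import torch
import torch_xla.core.xla_model as xm

xla_device = xm.xla_device()

# Build the inputs with ordinary CPU tensor ops; no XLA graph is created here.
input_ids_cpu = torch.zeros(8, 128, dtype=torch.int64)
positions_cpu = torch.arange(128, dtype=torch.int64).repeat(8, 1)

# .to(xla_device) only issues a CPU-to-TPU transfer; it does not trigger compilation.
input_ids = input_ids_cpu.to(xla_device)
positions = positions_cpu.to(xla_device)
```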

Step 2

The TPU execution should be decomposed into subgraphs (4 at the moment):

  • the main model
  • selecting hidden states for each request
  • the sampler
  • the encoder

Each subgraph should be wrapped in torch.compile (see the sketch below). This ensures that dummy_run and execute_model produce the same subgraph topology. The results of these subgraphs should either be passed into other subgraphs or transferred from TPU to CPU using xla_tensor.cpu() for subsequent processing on the CPU.
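A minimal sketch of this decomposition, assuming illustrative function names and signatures (this is not the actual tpu_model_runner code, and only two of the four subgraphs are shown):

```python
import torch

@torch.compile(backend="openxla", fullgraph=True, dynamic=False)
def select_hidden_states(hidden_states: torch.Tensor,
                         logits_indices: torch.Tensor) -> torch.Tensor:
    # Pick one hidden state per request; the result stays on the TPU and can
    # feed directly into the sampler subgraph.
    return hidden_states[logits_indices]

@torch.compile(backend="openxla", fullgraph=True, dynamic=False)
def greedy_sample(logits: torch.Tensor) -> torch.Tensor:
    # Stand-in for the real sampler subgraph.
    return logits.argmax(dim=-1)

# Only the final result leaves the device, e.g.:
#   token_ids = greedy_sample(logits).cpu()
```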

Step 3

The dummy_run should be comprehensive: every potential input shape and branch outcome must be exercised as a subgraph input so that all graphs are pre-compiled before serving begins (a sketch follows below).
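A hedged sketch of such a warm-up loop, assuming a hypothetical list of padded token-count buckets (the real dummy_run must cover every shape and branch the server can hit at runtime):

```python
import torch
import torch_xla.core.xla_model as xm

xla_device = xm.xla_device()

@torch.compile(backend="openxla")
def toy_subgraph(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for any of the compiled subgraphs from Step 2.
    return x * 2 + 1

# Hypothetical padded bucket sizes; compiling each one up front means no new
# shapes (and therefore no recompilation) appear while serving.
for num_tokens in [16, 32, 64, 128]:
    dummy = torch.zeros(num_tokens, dtype=torch.int64).to(xla_device)
    _ = toy_subgraph(dummy)
    xm.mark_step()  # force capture/compilation of this shape now, not later
```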

Feedback Period.

No response

CC List.

@robertgshaw2-redhat @NickLucche @WoosukKwon @yarongmu-google @bvrockwell

Any Other Things.

No response

Before submitting a new issue...

  • Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the documentation page, which can answer lots of frequently asked questions.
yaochengji added the RFC label Apr 8, 2025
@yaochengji
Collaborator Author

Related PR: #16275

yaochengji added the tpu (Related to Google TPUs) label Apr 9, 2025
@youkaichao
Member

This is not really an RFC but a design doc for TPU compilation. We can write it down in vllm/v1/worker/tpu_model_runner.py.

@yaochengji
Collaborator Author

@youkaichao, thanks for the suggestion. I can submit a PR to add it.

@yaochengji
Collaborator Author

Submitted: #16614
