[V1][TPU] Support V1 Sampler for ragged attention #14227


Merged
merged 22 commits into vllm-project:main on Mar 20, 2025

Conversation

NickLucche
Contributor

@NickLucche NickLucche commented Mar 4, 2025

Updated version of #13982 with ragged kernel attention.

All considerations described in the previous PR attempt are still valid (in particular, we're still aiming for a single optimized graph with model + sampling); the main difference is how the number of sampled tokens is handled now that it's not always equal to the batch size.

I think there's fundamentally one trade-off to consider:

  • A: we sample on all the tokens in the unified batch, even those that do not need it. This way the input to the sampler is fixed, so we don't recompile. The number of tokens that need sampling ranges from 1 to max_num_seqs.
  • B: we select only the logits that require sampling before model.sample and eat the cost of compilation (as happens on main now).

I think there's also a third option which I tried to implement here:

  • C: we pre-compile a bunch of fixed-size tensors and only feed those sizes to the Sampler. Basically, the logits that need sampling are selected and then padded to the nearest pre-compiled shape, so the input to the sampler (and consequently the output of the ModelWrapper) has a pre-compiled fixed shape. E.g. do_sample_indices: [4, 10] => padded_do_sample_indices: [4, 10, 0, 0] (see the sketch below).
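
A minimal sketch of how option C's padding could work (hypothetical helper name, not code from this PR; assumes a pre-compiled size large enough for the batch always exists):

import torch

def pad_do_sample_indices(do_sample_indices: torch.Tensor,
                          precompiled_sizes: list[int]) -> torch.Tensor:
    """Pad e.g. [4, 10] -> [4, 10, 0, 0] so the sampler sees a fixed shape."""
    n = do_sample_indices.numel()
    # Pick the smallest pre-compiled size that fits the current batch.
    target = next(s for s in sorted(precompiled_sizes) if s >= n)
    padded = torch.zeros(target, dtype=do_sample_indices.dtype,
                         device=do_sample_indices.device)
    padded[:n] = do_sample_indices
    return padded

# pad_do_sample_indices(torch.tensor([4, 10]), [4, 8, 16]) -> tensor([4, 10, 0, 0])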

Note: the compilation test is currently failing (re-running a request with different sampling params isn't as fast); I'll dig more into it. Execution just isn't as fast as it was in the previous PR attempt: compiling doesn't always drop times by >10x. This is happening on main too, so I don't think it has anything to do with this PR. As a result, I lowered the bar in test_sampler_compilation.py. We can focus on optimizing performance later (this is anyway the first PR where we try to track it).

Tested with VLLM_USE_V1=1 python -m pytest -s tests/v1/tpu/test_sampler.py.

Update:

I broke up the model + sampling graph again, so the compilation stages are now separated:

INFO 03-06 08:57:12 [tpu_model_runner.py:908] Compiling the model with different input shapes.
INFO 03-06 08:57:12 [tpu_model_runner.py:913]   -- num_tokens: 16
INFO 03-06 08:57:51 [tpu_model_runner.py:913]   -- num_tokens: 32
INFO 03-06 08:58:53 [tpu_model_runner.py:913]   -- num_tokens: 64
INFO 03-06 09:01:12 [tpu_model_runner.py:921] Compilation finished in in 240.17 [secs].
INFO 03-06 09:01:12 [tpu_model_runner.py:923] Compiling sampling with different input shapes.
INFO 03-06 09:01:12 [tpu_model_runner.py:937]   -- num_tokens: 16, num_seqs: 8
INFO 03-06 09:01:12 [tpu_model_runner.py:937]   -- num_tokens: 16, num_seqs: 16
INFO 03-06 09:01:12 [tpu_model_runner.py:937]   -- num_tokens: 16, num_seqs: 32
INFO 03-06 09:01:12 [tpu_model_runner.py:937]   -- num_tokens: 16, num_seqs: 64
INFO 03-06 09:01:12 [tpu_model_runner.py:937]   -- num_tokens: 32, num_seqs: 64
INFO 03-06 09:01:12 [tpu_model_runner.py:937]   -- num_tokens: 64, num_seqs: 64
INFO 03-06 09:01:13 [tpu_model_runner.py:949] Compilation finished in in 0.39 [secs].


github-actions bot commented Mar 4, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@NickLucche
Contributor Author

Also, sadly compilation takes forever now. Waiting for new kernels to address that.

Comment on lines 913 to 805
num_reqs_to_sample = MIN_NUM_SEQS
while True:
    self._dummy_run(self.kv_caches, num_tokens)
    logger.info(" -- num_tokens: %d", num_tokens)
    xm.mark_step()
    xm.wait_device_ops()
    if num_tokens >= self.max_num_tokens:
        # Compile for different sampler outputs in
        # [MIN_NUM_SEQS, pad(max_num_reqs)]
        while True:
            self._dummy_run(self.kv_caches, num_tokens)
            logger.info(" -- num_tokens: %d, num_seqs: %d", num_tokens,
                        num_reqs_to_sample)
            xm.mark_step()
            xm.wait_device_ops()
            if num_reqs_to_sample >= self.max_num_reqs:
                break
            num_reqs_to_sample *= 2
    if num_tokens >= self.scheduler_config.max_num_batched_tokens:
        break
    num_tokens *= 2
Member

This doesn't seem right... each time we are only passing in num_tokens, so what is the point of num_reqs_to_sample in this loop? Looking around, it seems like num_tokens_to_sample is missing and should be passed in.

This potential explosion of compilations makes me feel a better solution would be to completely separate the graphs for model forward and sampling. That way we can just define a "max sampling batch size" that is handled separately, so we end up with M+N compilations rather than M*N (see the sketch below).
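
A rough sketch of the M+N idea (all names here are illustrative, not the PR's actual code): warm up the model-forward graph over its token buckets and the sampler graph over its batch-size buckets in two independent loops.

import torch_xla.core.xla_model as xm

token_buckets = [16, 32, 64, 128]   # M model-forward input shapes
seq_buckets = [8, 16, 32, 64]       # N sampler batch sizes

for num_tokens in token_buckets:    # M compilations of the model graph
    dummy_model_run(kv_caches, num_tokens)   # hypothetical helper
    xm.mark_step()
    xm.wait_device_ops()

for num_seqs in seq_buckets:        # N compilations of the sampler graph
    dummy_sampler_run(num_seqs)              # hypothetical helper
    xm.mark_step()
    xm.wait_device_ops()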

Contributor Author

I don't think I can do it in M+N here, as sampling depends on both num_reqs_to_sample and hidden_state, so for each different hidden_state shape we still need to compile. But we will only have M*N compilations of the sampling step, which should be much lighter.

I did consider breaking the graph in multiple places, but IIRC the main concern was the interaction with torch.compile.

@NickLucche
Contributor Author

NickLucche commented Mar 6, 2025

Quickly looking into an issue with greedy sampling and the control flow introduced in #13587.

EDIT: pushed a new commit addressing it.

Collaborator

@alexm-redhat alexm-redhat left a comment

@NickLucche Nice work with the sampler and the optimizations to improve compilation times. Left some comments and questions.

- # Apply mask using boolean indexing
- logits[~valid_token_mask] = -float('inf')
+ # Apply mask using boolean indexing (xla friendly)
+ logits.masked_fill_(~valid_token_mask, -float("inf"))
Collaborator

how did you discover this optimization? :)

Contributor Author

I think I was inspecting the compilations with PT_XLA_DEBUG=2

Collaborator

Should we revert this or does it also help GPU performance?

Contributor Author

Good point, I can add this in a separate PR if need be, now that the two samplers are separated.
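
For reference, a self-contained toy illustration of the XLA-friendly masking discussed in this thread (shapes and values are made up):

import torch

logits = torch.randn(4, 8)
valid_token_mask = torch.zeros(4, 8, dtype=torch.bool)
valid_token_mask[:, :5] = True

# Boolean-index assignment (the original form):
# logits[~valid_token_mask] = -float('inf')

# In-place masked_fill_ (the form this PR switches to), which traces more
# cleanly under XLA according to the discussion above:
logits.masked_fill_(~valid_token_mask, -float("inf"))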

s = time()
_ = llm.generate(prompts, sampling_params)
run2 = time() - s
assert run1 > run2
Collaborator

How much slower is run1? Maybe it would be safer to do something like run1 > run2 * 2 (otherwise you may be in the noise range if it does recompile).

Contributor Author

Performance is back with the kernel update in #14310, so I'll bring back the 0.1 check :)
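
Presumably the "0.1 check" amounts to requiring the second (compiled) run to be at least ~10x faster; a sketch of what that assertion could look like, reusing the variables from the snippet above:

s = time()
_ = llm.generate(prompts, sampling_params)
run2 = time() - s
# Hypothetical stricter check: well outside the noise range, not merely faster.
assert run2 < run1 * 0.1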

kv_caches,
num_tokens: int,
) -> None:
@torch.no_grad()
Collaborator

Why did you need torch.no_grad? I recall this was causing errors.


mergify bot commented Mar 7, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @NickLucche.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Mar 7, 2025
@NickLucche
Contributor Author

Thanks a lot for the review and comments, @alexm-redhat!
Just rebased on top of #14310 and performance is back to normal. I also addressed your comments!

@NickLucche NickLucche force-pushed the tpu-sampler-ragged branch from c5a52a2 to a56afc2 Compare March 7, 2025 12:07
@mergify mergify bot removed the needs-rebase label Mar 7, 2025
Member

@mgoin mgoin left a comment

Awesome work building the structure for future sampling param enablement once we get performant kernels!

Eval smoke test looks good (and runs fast!)

Running generate_until requests: 100%|██████████████████████| 1319/1319 [01:44<00:00, 12.67it/s]
INFO:lm-eval:Output path not provided, skipping saving results aggregated
vllm (pretrained=Qwen/Qwen2-1.5B-Instruct,max_model_len=2048,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.5868|±  |0.0136|
|     |       |strict-match    |     5|exact_match|↑  |0.5815|±  |0.0136|

@mgoin
Member

mgoin commented Mar 7, 2025

Full log here:

VLLM_USE_V1=1 lm_eval --model vllm --model_args pretrained=Qwen/Qwen2-1.5B-Instruct,max_model_len=2048 --trust_remote_code --tasks gsm8k --num_fewshot 5 --batch_size auto
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
INFO 03-07 15:11:36 [__init__.py:256] Automatically detected platform tpu.
INFO:lm-eval:Verbosity set to INFO
INFO:lm-eval:Passed `--trust_remote_code`, setting environment variable `HF_DATASETS_TRUST_REMOTE_CODE=true`
INFO:lm-eval:Selected Tasks: ['gsm8k']
INFO:lm-eval:Setting random seed to 0 | Setting numpy seed to 1234 | Setting torch manual seed to 1234 | Setting fewshot manual seed to 1234
INFO:lm-eval:Initializing vllm model, with arguments: {'pretrained': 'Qwen/Qwen2-1.5B-Instruct', 'max_model_len': 2048, 'trust_remote_code': True}
WARNING 03-07 15:11:41 [arg_utils.py:1450] Setting max_num_batched_tokens to 8192 for LLM_CLASS usage context.
config.json: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 660/660 [00:00<00:00, 7.19MB/s]
INFO 03-07 15:11:48 [config.py:576] This model supports multiple tasks: {'score', 'generate', 'embed', 'reward', 'classify'}. Defaulting to 'generate'.
INFO 03-07 15:11:48 [config.py:1666] Chunked prefill is enabled with max_num_batched_tokens=8192.
INFO 03-07 15:11:48 [tpu.py:76] [TPU] Forcing DYNAMO_ONCE compilation level
WARNING 03-07 15:11:48 [tpu.py:108] [V1][TPU] Disable prefix caching
tokenizer_config.json: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1.29k/1.29k [00:00<00:00, 16.7MB/s]
vocab.json: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2.78M/2.78M [00:00<00:00, 26.2MB/s]
merges.txt: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1.67M/1.67M [00:00<00:00, 23.6MB/s]
tokenizer.json: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7.03M/7.03M [00:00<00:00, 60.3MB/s]
generation_config.json: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 242/242 [00:00<00:00, 3.50MB/s]
INFO 03-07 15:11:49 [core.py:50] Initializing a V1 LLM engine (v0.7.4.dev217+geb59b5a6c.d20250305) with config: model='Qwen/Qwen2-1.5B-Instruct', speculative_config=None, tokenizer='Qwen/Qwen2-1.5B-Instruct', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=True, dtype=torch.bfloat16, max_seq_len=2048, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto,  device_config=None, decoding_config=DecodingConfig(guided_decoding_backend='xgrammar', reasoning_backend=None), observability_config=ObservabilityConfig(show_hidden_metrics=False, otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=1234, served_model_name=Qwen/Qwen2-1.5B-Instruct, num_scheduler_steps=1, multi_step_stream_outputs=True, enable_prefix_caching=False, chunked_prefill_enabled=True, use_async_output_proc=False, disable_mm_preprocessor_cache=False, mm_processor_kwargs=None, pooler_config=None, compilation_config={"level":2,"backend":"openxla","custom_ops":["none"],"splitting_ops":["vllm.unified_attention","vllm.unified_attention_with_output"],"use_inductor":true,"compile_sizes":[],"use_cudagraph":true,"cudagraph_num_of_warmups":1,"cudagraph_capture_sizes":[512,504,496,488,480,472,464,456,448,440,432,424,416,408,400,392,384,376,368,360,352,344,336,328,320,312,304,296,288,280,272,264,256,248,240,232,224,216,208,200,192,184,176,168,160,152,144,136,128,120,112,104,96,88,80,72,64,56,48,40,32,24,16,8,4,2,1],"max_capture_size":512}
INFO 03-07 15:11:49 [parallel_state.py:948] rank 0 in world size 1 is assigned as DP rank 0, PP rank 0, TP rank 0
WARNING 03-07 15:12:00 [tpu.py:116] Pin memory is not supported on TPU.
INFO 03-07 15:12:00 [tpu.py:39] Cannot use None backend on TPU.
INFO 03-07 15:12:00 [tpu.py:42] Using Pallas V1 backend.
INFO 03-07 15:12:00 [weight_utils.py:257] Using model weights format ['*.safetensors']
model.safetensors: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3.09G/3.09G [00:14<00:00, 211MB/s]
INFO 03-07 15:12:15 [weight_utils.py:273] Time spent downloading weights for Qwen/Qwen2-1.5B-Instruct: 14.914241 seconds
INFO 03-07 15:12:15 [weight_utils.py:307] No model.safetensors.index.json found in remote.
Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  2.64it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  2.63it/s]

INFO 03-07 15:12:16 [loader.py:422] Loading weights took 0.47 seconds
INFO 03-07 15:12:37 [kv_cache_utils.py:537] GPU KV cache size: 921,568 tokens
INFO 03-07 15:12:37 [kv_cache_utils.py:540] Maximum concurrency for 2,048 tokens per request: 449.98x
INFO 03-07 15:12:37 [tpu_model_runner.py:766] Compiling the model with different input shapes.
INFO 03-07 15:12:37 [tpu_model_runner.py:771]   -- num_tokens: 16
INFO 03-07 15:13:08 [tpu_model_runner.py:771]   -- num_tokens: 32
INFO 03-07 15:13:14 [tpu_model_runner.py:771]   -- num_tokens: 64
INFO 03-07 15:13:19 [tpu_model_runner.py:771]   -- num_tokens: 128
INFO 03-07 15:13:26 [tpu_model_runner.py:771]   -- num_tokens: 256
INFO 03-07 15:13:32 [tpu_model_runner.py:771]   -- num_tokens: 512
INFO 03-07 15:13:39 [tpu_model_runner.py:771]   -- num_tokens: 1024
INFO 03-07 15:13:47 [tpu_model_runner.py:771]   -- num_tokens: 2048
INFO 03-07 15:13:56 [tpu_model_runner.py:771]   -- num_tokens: 4096
INFO 03-07 15:14:06 [tpu_model_runner.py:771]   -- num_tokens: 8192
INFO 03-07 15:14:17 [tpu_model_runner.py:779] Compilation finished in in 99.80 [secs].
INFO 03-07 15:14:17 [tpu_model_runner.py:781] Compiling sampling with different input shapes.
INFO 03-07 15:14:17 [tpu_model_runner.py:795]   -- num_tokens: 16, num_seqs: 8
INFO 03-07 15:14:17 [tpu_model_runner.py:795]   -- num_tokens: 16, num_seqs: 16
INFO 03-07 15:14:17 [tpu_model_runner.py:795]   -- num_tokens: 16, num_seqs: 32
INFO 03-07 15:14:17 [tpu_model_runner.py:795]   -- num_tokens: 16, num_seqs: 64
INFO 03-07 15:14:17 [tpu_model_runner.py:795]   -- num_tokens: 16, num_seqs: 128
INFO 03-07 15:14:17 [tpu_model_runner.py:795]   -- num_tokens: 16, num_seqs: 256
INFO 03-07 15:14:17 [tpu_model_runner.py:795]   -- num_tokens: 16, num_seqs: 512
INFO 03-07 15:14:17 [tpu_model_runner.py:795]   -- num_tokens: 16, num_seqs: 1024
INFO 03-07 15:14:17 [tpu_model_runner.py:795]   -- num_tokens: 32, num_seqs: 1024
INFO 03-07 15:14:17 [tpu_model_runner.py:795]   -- num_tokens: 64, num_seqs: 1024
INFO 03-07 15:14:17 [tpu_model_runner.py:795]   -- num_tokens: 128, num_seqs: 1024
INFO 03-07 15:14:17 [tpu_model_runner.py:795]   -- num_tokens: 256, num_seqs: 1024
INFO 03-07 15:14:18 [tpu_model_runner.py:795]   -- num_tokens: 512, num_seqs: 1024
INFO 03-07 15:14:18 [tpu_model_runner.py:795]   -- num_tokens: 1024, num_seqs: 1024
INFO 03-07 15:14:19 [tpu_model_runner.py:795]   -- num_tokens: 2048, num_seqs: 1024
INFO 03-07 15:14:20 [tpu_model_runner.py:795]   -- num_tokens: 4096, num_seqs: 1024
INFO 03-07 15:14:20 [tpu_model_runner.py:795]   -- num_tokens: 8192, num_seqs: 1024
INFO 03-07 15:14:21 [tpu_model_runner.py:807] Compilation finished in in 3.85 [secs].
INFO 03-07 15:14:21 [core.py:116] init engine (profile, create kv cache, warmup model) took 124.89 seconds
README.md: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7.94k/7.94k [00:00<00:00, 67.0MB/s]
train-00000-of-00001.parquet: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2.31M/2.31M [00:00<00:00, 36.8MB/s]
test-00000-of-00001.parquet: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 419k/419k [00:00<00:00, 257MB/s]
Generating train split: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7473/7473 [00:00<00:00, 218160.79 examples/s]
Generating test split: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1319/1319 [00:00<00:00, 413721.73 examples/s]
WARNING:lm-eval:Overwriting default num_fewshot of gsm8k from 5 to 5
INFO:lm-eval:Building contexts for gsm8k on rank 0...
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1319/1319 [00:02<00:00, 447.85it/s]
INFO:lm-eval:Running generate_until requests
Processed prompts: 100%|█████████████████████████████████████████████████████████████████████████████████| 1319/1319 [01:43<00:00, 12.69it/s, est. speed input: 12596.01 toks/s, output: 1618.45 toks/s]
Running generate_until requests: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1319/1319 [01:44<00:00, 12.67it/s]
INFO:lm-eval:Output path not provided, skipping saving results aggregated
vllm (pretrained=Qwen/Qwen2-1.5B-Instruct,max_model_len=2048,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.5868|±  |0.0136|
|     |       |strict-match    |     5|exact_match|↑  |0.5815|±  |0.0136|

@mgoin mgoin added tpu Related to Google TPUs ready ONLY add when PR is ready to merge/full CI is needed labels Mar 7, 2025
@mgoin mgoin enabled auto-merge (squash) March 7, 2025 17:24

mergify bot commented Mar 7, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @NickLucche.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Mar 7, 2025
Collaborator

@yaochengji yaochengji left a comment

Hi @NickLucche, thanks for your contribution. I found two places causing recompilation. Is it possible to resolve them?

Collaborator

@WoosukKwon WoosukKwon left a comment

Sorry for chiming in this late, but I think we need a broader discussion on how to share the sampler between different hardware backends. In V0, we found that the sampler implementation got over-complicated partly because different hardware backends added some if statements here and there just for their own purposes. I do really want to prevent it from happening this time.

auto-merge was automatically disabled March 12, 2025 15:38

Head branch was pushed to by a user without write access

Signed-off-by: NickLucche <[email protected]>
@mergify mergify bot removed the needs-rebase label Mar 12, 2025
@NickLucche
Contributor Author

NickLucche commented Mar 12, 2025

Rebased and updated.

@WoosukKwon I've separated the Sampler code for TPU into a different namespace so it doesn't make use of the CUDA path any longer. Let me know what you think.

@yaochengji I've addressed the small recompilation issue that was happening due to a mismatch between pre-processing at compile time and at runtime.
Please note that there's still a small recompilation happening on second runs, which I pinpointed to vllm.v1.worker.gpu_input_batch.InputBatch._make_sampling_metadata slicing tensors to comply with the V1 "data structures reuse" policy.
If that's OK, I'd like to address it in a separate PR, as this one has gotten quite clunky already.

PS. I'd also like to postpone the torch.compile wrapping of the sampling graph, as it's currently unclear whether we get any benefit from it (compile time appears to go up).
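
A toy illustration of the slicing issue mentioned above (not the actual InputBatch code): XLA compiles one graph per tensor shape, so slicing a persistent buffer down to the current number of requests yields a new shape, and hence a recompilation, whenever that number changes.

import torch

MAX_NUM_REQS = 256
# Persistent max-sized buffer, reused across steps (per the V1 reuse policy).
temperature = torch.ones(MAX_NUM_REQS)

def make_sampling_metadata(num_reqs: int) -> torch.Tensor:
    # Shape (num_reqs,) varies with the batch, so XLA sees a new graph
    # for every distinct num_reqs value.
    return temperature[:num_reqs]

# A shape-stable alternative is to keep the full padded buffer and mask the
# unused slots instead, trading a bit of extra work for no recompilation.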

@yaochengji
Collaborator

@NickLucche Thanks for your update!

Previously, when resolving the recompilation issue in the logits processor, I noticed the compilation time is about 500 ms. This is not acceptable in some use cases (e.g. a typical TPOT requirement is 100 ms). Therefore, could you resolve all the recompilation issues?

BTW, there's an ongoing task to ensure that there's no recompilation after warmup, #14580.

@NickLucche
Contributor Author

This is a different recompilation, happening in pre-processing rather than post-processing, and it depends on existing V1 code.

@NickLucche NickLucche requested a review from WoosukKwon March 12, 2025 17:57
Collaborator

@alexm-redhat alexm-redhat left a comment

LGTM!


mergify bot commented Mar 17, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @NickLucche.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Mar 17, 2025
@mergify mergify bot removed the needs-rebase label Mar 17, 2025
Signed-off-by: NickLucche <[email protected]>
@hyeygit
Contributor

hyeygit commented Mar 19, 2025

Thank you @NickLucche for this PR! Do you plan to merge it soon? I hope to send a follow-up PR to optimize top-k for TPU.

@robertgshaw2-redhat robertgshaw2-redhat enabled auto-merge (squash) March 20, 2025 02:47
@robertgshaw2-redhat robertgshaw2-redhat merged commit d8c6d7d into vllm-project:main Mar 20, 2025
31 checks passed
@NickLucche
Contributor Author

@hyeygit Thanks for your optimization work, feel free to ping me when you have something ready!

erictang000 pushed a commit to erictang000/vllm that referenced this pull request Mar 25, 2025
gmarinho2 pushed a commit to gmarinho2/vllm that referenced this pull request Apr 1, 2025
lulmer pushed a commit to lulmer/vllm that referenced this pull request Apr 7, 2025
nishith-fujitsu pushed a commit to nishith-fujitsu/vllm that referenced this pull request Apr 9, 2025
shreyankg pushed a commit to shreyankg/vllm that referenced this pull request May 3, 2025
RichardoMrMu pushed a commit to RichardoMrMu/vllm that referenced this pull request May 12, 2025
Labels
ci/build · ready (ONLY add when PR is ready to merge/full CI is needed) · tpu (Related to Google TPUs) · v1
7 participants