
[RFC]: TPU V1 Sampler planning #16268

Open
3 of 11 tasks
NickLucche opened this issue Apr 8, 2025 · 6 comments

@NickLucche
Contributor

NickLucche commented Apr 8, 2025

Motivation.

I'd like to gather some input on how to move forward with sampling support, and also provide a brief recap of the current state and the planned support.

At a high level, the current design splits the model forward pass and sampling into two separate graphs.
As of now (f2ebb6f54), only temperature and min_p have been intentionally enabled.
As more techniques are added, the sampling graph will grow in size (vertically, with sequential ops), and performance may need monitoring, since we're simply evaluating more operations at runtime.
To clarify, even when an option is not enabled, we still evaluate a no-op version that goes through the same ops in the graph (e.g. top-p with p=1).
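
As a rough illustration (apply_top_p here is an illustrative sketch, not the actual vLLM kernel), a shape-stable top-p op still pays for its sort/cumsum/scatter even when p=1 disables it:

```python
import torch

def apply_top_p(logits: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
    # logits: [num_reqs, vocab]; p: [num_reqs], with p = 1.0 meaning "disabled".
    probs = logits.softmax(dim=-1)
    probs_sorted, idx = probs.sort(dim=-1, descending=True)
    cum_probs = probs_sorted.cumsum(dim=-1)
    # Drop tokens whose preceding cumulative mass already exceeds p. With
    # p = 1.0 nothing is ever dropped, yet the sort/cumsum/scatter still run
    # on every decoding step.
    drop_sorted = (cum_probs - probs_sorted) > p.unsqueeze(-1)
    drop = torch.zeros_like(drop_sorted).scatter(-1, idx, drop_sorted)
    return logits.masked_fill(drop, float("-inf"))

# e.g. p = torch.ones(num_reqs) disables top-p, but the compiled graph is identical.
```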

Proposed Change.

Following #15489, a few concerns have been raised regarding performance when enabling top-k, which adds the very first op to the initial sampling graph, so I'd like to re-evaluate the current approach.
At the opposite end of the spectrum, one could ideally provide a sampling graph for each combination of parameters.
While this is unfeasible due to the number of parameters sampling needs to support, a middle-ground approach is to pre-compile a set of common sampling-param combinations and route requests to the "correct" graph.
The main issue I see here is batching: since every request may specify different sampling params, either we identify the superset for the current batch and route to the corresponding graph, or each request is executed on a separate graph, which I believe would hurt performance even more. That said, I still think most requests will fall into the temperature-only "bucket", followed by the top-k/top-p one, so one could implement only the most popular routes, as sketched below. I have no production data to back this assertion, though, so don't quote me on that.
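
To make the routing idea concrete, here is a hypothetical sketch (SamplingBucket and pick_bucket are made-up names, not vLLM APIs) of computing the per-batch superset of sampling features and dispatching on it:

```python
from enum import IntFlag, auto

class SamplingBucket(IntFlag):
    GREEDY = 0
    TEMPERATURE = auto()
    TOP_K_TOP_P = auto()
    PENALTIES = auto()

def pick_bucket(requests) -> SamplingBucket:
    # Superset of the sampling features needed by any request in the batch.
    bucket = SamplingBucket.GREEDY
    for req in requests:
        if req.temperature > 0.0:
            bucket |= SamplingBucket.TEMPERATURE
        if req.top_k > 0 or req.top_p < 1.0:
            bucket |= SamplingBucket.TOP_K_TOP_P
        if req.frequency_penalty or req.presence_penalty or req.repetition_penalty != 1.0:
            bucket |= SamplingBucket.PENALTIES
    return bucket

# compiled_samplers: dict[SamplingBucket, Callable], pre-built for the most
# popular combinations; anything else falls back to the full sampling graph:
#   sampler = compiled_samplers.get(pick_bucket(batch), full_graph_sampler)
```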

I think this is the main point to clarify before moving on and expanding the number of supported parameters.

Please note the above is based on the assumption that latency is indeed going up. To clear up any doubts, I think PR #16022 will go a long way toward allowing easy benchmarking of sampling parameters.


Moving forward, I've compiled a list of parameters to support, along with the effort needed to implement them and my own suggestions.
We can also use it to track progress.

  • temperature/min_p
  • top-k/top-p: we already have an implementation for top-k in [V1][TPU] Enable Top K #15489.
  • logprobs/sampling_metadata.max_num_logprobs: similarly to what we do elsewhere, we need to compile for different max_num_logprobs values, as the output is Bxmax_num_logprobs (plus there's a torch.topk call), unless we fix it to some arbitrary value. [TPU][V1] Add support for top-logprobs #17072
  • sampling_metadata.prompt_token_ids for penalties. This should be fine as-is, given we're already compiling for different (padded) input sizes. It just can't be optional: None vs. a value count as different inputs.
  • sampling_metadata.output_token_ids for penalties. It's already converted into a padded tensor.
  • penalties:
    • get_token_bin_counts_and_mask uses a scatter_add op that may be slow on TPU
    • the *penalties tensors are already of shape num_seqs, which we pre-compile, so they're fine.
    • there are multiple lines where a tensor is sliced in a value-dependent way (recompilation risk), e.g. logits[logits > 0]. We can probably replace this with a masked, shape-stable op (see the sketch after this list).
  • sampling_metadata.min_tokens penalty must be re-implemented and vectorized (it currently uses a for loop over an input dict, so the graph would be dynamic). I am less familiar with this implementation, so TBD.
  • sampling_metadata.logit_bias: the current interface needs to be rethought because it can introduce dynamism. We could create a BxV matrix (B padded and pre-compiled) to pack the preferences from the list[dict] (see the packing sketch after this list). This would work, but the expansion factor can obviously be quite big (e.g. downgrading a single token would materialize a whole BxV matrix). Alternatively, we could provide different pre-compiled values for V (2, 4, ...) at the cost of increased complexity and longer compilation time. Also, the current CUDA impl is highly unoptimized.
  • sampling_metadata.allowed_token_ids_mask is fine as-is, no effort required IMO. It just can't be None on a single graph.
  • sampling_metadata.bad_words_token_ids: probably better to support the more general logit_bias option instead.
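
For the penalties item above, a minimal sketch of what a shape-stable rewrite could look like (torch.where is used here in place of the masked_fill mentioned above; apply_repetition_penalty and seen_mask are illustrative names, not vLLM code):

```python
import torch

def apply_repetition_penalty(logits: torch.Tensor,
                             penalties: torch.Tensor,
                             seen_mask: torch.Tensor) -> torch.Tensor:
    # logits: [num_reqs, vocab]; penalties: [num_reqs]; seen_mask: bool [num_reqs, vocab].
    p = torch.where(seen_mask, penalties.unsqueeze(-1), torch.ones_like(logits))
    # Value-dependent form (recompilation risk on XLA):
    #   logits[logits > 0] /= p[logits > 0]; logits[logits <= 0] *= p[logits <= 0]
    # Shape-stable form, same [num_reqs, vocab] output shape:
    return torch.where(logits > 0, logits / p, logits * p)
```

And for logit_bias, a rough sketch (pack_logit_bias is a hypothetical helper) of packing the per-request dicts into a dense BxV tensor on the host, so the graph only ever sees one fixed-shape input:

```python
def pack_logit_bias(logit_bias, padded_num_reqs: int, vocab_size: int) -> torch.Tensor:
    # logit_bias: list of Optional[dict[token_id, bias]]; runs on CPU, outside
    # the compiled graph, so the Python loop is not a tracing concern.
    bias = torch.zeros(padded_num_reqs, vocab_size)
    for i, per_req in enumerate(logit_bias):
        if per_req:
            ids = torch.tensor(list(per_req.keys()), dtype=torch.long)
            bias[i, ids] = torch.tensor(list(per_req.values()), dtype=torch.float32)
    return bias  # later added in-graph: logits = logits + bias.to(logits.device)
```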

Feedback Period.

No response

CC List.

@robertgshaw2-redhat @yaochengji @alexm-redhat @mgoin @bvrockwell @hyeygit @lsy323

Any Other Things.

UPDATE: this is very much related to #13360.

Before submitting a new issue...

  • Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the documentation page, which can answer lots of frequently asked questions.
NickLucche added the RFC label Apr 8, 2025
@yaochengji
Collaborator

Thanks @NickLucche for the RFC!

cc @Chenyaaang who is working on TPU structured decoding.

@yaochengji
Collaborator

we need to compile for different max_num_logprobs values

Is simple max padding enough for this case?

@yaochengji
Collaborator

each request is executed on a separate graph

Executing each request separately might not hurt that much, if there's no recompilation.

@NickLucche
Contributor Author

| Is simple max padding enough for this case?

For returning the prob of the sampled token, yes. But this is more about gathering the top-K logprobs. Either we do a topk with some fixed/maximum K on TPU to cut down the vocab dimension and then select the actual K requested on CPU, or we compile for multiple Ks.
I am fine with the former option.
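
A minimal sketch of the former option (MAX_LOGPROBS is an assumed compile-time constant, not an existing vLLM symbol): the device graph always computes the same K so its output shape stays static, and each request's actual k is applied on the host.

```python
import torch

MAX_LOGPROBS = 20  # assumed fixed at compile time

def device_top_logprobs(logits: torch.Tensor):
    # Runs inside the compiled graph; logits: [num_reqs, vocab].
    logprobs = torch.log_softmax(logits, dim=-1)
    # Always the same K, so the output shape is static: [num_reqs, MAX_LOGPROBS].
    return logprobs.topk(MAX_LOGPROBS, dim=-1)

def host_select(values, indices, requested_ks):
    # Runs on CPU after the device-to-host transfer; per-request k is just a slice.
    return [(values[i, :k], indices[i, :k]) for i, k in enumerate(requested_ks)]
```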

| Executing each request separately might not hurt that much, if there's no recompilation.

I think we'd pay the overhead of launching hundreds (num_seqs) of graphs, plus gathering results, and we'd lose arithmetic intensity on all those ops that execute on BxV shapes (now B times 1xV). Although this last point might be ~invalid if dispatching is optimized on TPU.
Would it still be viable?

@yaochengji
Collaborator

yaochengji commented Apr 9, 2025

I think we'd pay the overhead of launching hundreds (num_seqs) of graphs, plus gathering results, and we'd lose arithmetic intensity on all those ops

It depends. We have a benchmark of the performance of a transformer-like model executed op by op, and we can still get ~40% of the performance of executing it as an entire graph. As long as the sampler is not a performance bottleneck, flexibility is more important than performance.

BxV

Why are there two dimensions? I think the dimension we might iterate over should be the num_reqs dimension.

@NickLucche
Contributor Author

Why are there two dimensions? I think the dimension we might iterate over should be the num_reqs dimension.

Yes, bad naming on my side, sorry. Currently we have a single num_reqs x V tensor. We would need to iterate over the first dim and launch num_reqs graphs (see the toy sketch below).
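
For reference, a toy sketch of the difference (sampler_fn and sampler_fns are placeholders): the batched path runs one graph over the whole num_reqs x V tensor, while the per-request path would slice the first dim and launch num_reqs graphs.

```python
import torch

def sample_batched(sampler_fn, logits: torch.Tensor) -> torch.Tensor:
    # One compiled graph over the full [num_reqs, vocab] tensor.
    return sampler_fn(logits)

def sample_per_request(sampler_fns, logits: torch.Tensor) -> torch.Tensor:
    # One (pre-compiled, per-request-params) graph per [1, vocab] slice; the
    # dispatch overhead and lost arithmetic intensity are the concern above.
    return torch.cat([fn(logits[i:i + 1]) for i, fn in enumerate(sampler_fns)])
```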
