[RFC]: TPU V1 Sampler planning #16268
Comments
Thanks @NickLucche for the RFC! cc @Chenyaaang who is working on TPU structured decoding.
Is simple max padding enough for this case?
Each request executed might not hurt that much, if there's no recompilation.
> Is simple max padding enough for this case?

For returning the prob of the sampled token, yes. But this is more about gathering the top-k logprobs: either we do a top-k with some fixed/maximum K on TPU to cut down the vocab dimension and then select the actual K requested on CPU, or we compile multiple Ks.

> Each request executed might not hurt that much, if there's no recompilation.

I think we'd pay the overhead of launching hundreds (num_seqs) of graphs plus gathering the results, and we'd lose arithmetic intensity on all those ops that execute on BxV shapes (now B times 1xV). Although this last point might be ~invalid if dispatching is optimized on TPU.
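For reference, a minimal sketch of the first option (a single fixed-K top-k on TPU, per-request K selection on CPU); `MAX_K`, `gather_topk_logprobs` and `requested_ks` are illustrative names, not existing vLLM code:

```python
import torch

MAX_K = 32  # illustrative fixed upper bound baked into the compiled TPU graph

def gather_topk_logprobs(logprobs: torch.Tensor, requested_ks: list[int]):
    """logprobs: [B, V] tensor on TPU; requested_ks: per-request K (each <= MAX_K)."""
    # Single static-shaped topk on device: the output is always [B, MAX_K].
    top_vals, top_ids = torch.topk(logprobs, k=MAX_K, dim=-1)
    # Move the small [B, MAX_K] result to host and slice per request there,
    # so the variable K never shows up in the compiled graph.
    top_vals, top_ids = top_vals.cpu(), top_ids.cpu()
    return [(top_vals[i, :k], top_ids[i, :k]) for i, k in enumerate(requested_ks)]
```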
It depends. We have a benchmark on the performance of a transformer-like model executed op by op, and we still get about 40% of the performance compared with executing it as an entire graph. As long as the sampler is not a performance bottleneck, flexibility is more important than performance.
Why are there two dimensions? I think the dimension we might iterate over should be the
Yes, bad naming on my side, sorry. Currently we have a single
Motivation.
I'd like to gather some input on how to move forward with sampling support, and also provide a brief recap of the current state and planned support.
At a high level, the current design splits model forward and sampling into two separate graphs.
As of now (f2ebb6f54) only `temperature` and `min_p` have been intentionally enabled.
As more techniques are added, the sampling graph will grow in size (vertically, with more sequential ops) and performance may need monitoring, as we're simply evaluating more operations at runtime.
To clarify, even when an option is not enabled, we still evaluate a no-op version that goes through the same ops in the graph (e.g. top-p with p=1).
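For illustration, a hypothetical sketch of such a no-op path (not the actual vLLM kernel): with p=1.0 the top-p mask keeps every token, so the same ops run whether or not the request asked for top-p.

```python
import torch

def apply_top_p(logits: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
    """logits: [B, V]; p: [B], with p = 1.0 for requests that did not set top-p."""
    probs = logits.softmax(dim=-1)
    sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
    cum_probs = sorted_probs.cumsum(dim=-1)
    # Mask tokens whose cumulative probability (excluding themselves) already
    # exceeds p. With p = 1.0 nothing is ever masked, so the graph stays static
    # and the op degenerates to an (expensive) identity.
    mask = (cum_probs - sorted_probs) > p.unsqueeze(-1)
    sorted_logits = logits.gather(-1, sorted_idx).masked_fill(mask, float("-inf"))
    return torch.full_like(logits, float("-inf")).scatter(-1, sorted_idx, sorted_logits)
```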
Proposed Change.
Following #15489, a few concerns have been raised regarding performance when enabling top-k, which adds the very first op to the initial sampling graph, so I'd like to re-evaluate the current approach.
Looking at the opposite end of the spectrum, one could ideally provide a sampling graph for each combination of parameters.
While this is unfeasible due to the number of parameters that sampling needs to support, one approach "in the middle" is to pre-compile a set of common sampling-param combinations and route each request to the "correct" graph.
The main issue I see here is batching: since every request may specify different sampling params, either we identify the superset of params for the current batch and route to the corresponding graph, or each request is executed on a separate graph, which I believe would hurt performance even more. With that said, I still think most requests will fall into the temperature-only "bucket", followed by the top-k/top-p one, so one could implement only the most popular routes. I have no production data to back this assertion up with though, so don't quote me on that.
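To make the routing idea concrete, a rough sketch of how a batch could be bucketed; the bucket names and the `RequestParams` fields are illustrative, not existing vLLM structures:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class RequestParams:
    """Illustrative stand-in for per-request sampling parameters."""
    temperature: float
    top_k: Optional[int] = None
    top_p: Optional[float] = None

# Pre-compiled sampling graphs, from cheapest to most general. Each graph always
# runs every op in its bucket (as no-ops for requests that don't use them).
BUCKETS = ["greedy", "temperature", "temperature_topk_topp", "full"]

def pick_bucket(batch: list[RequestParams]) -> str:
    """Route the whole batch to the smallest pre-compiled graph covering all requests."""
    if all(r.temperature == 0.0 for r in batch):
        return "greedy"
    if all(r.top_k is None and r.top_p is None for r in batch):
        return "temperature"
    # A real implementation would also check penalties, logit_bias, etc.
    # and fall back to the most general "full" graph when needed.
    return "temperature_topk_topp"
```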
I think this is the main point to clarify before moving on and expand the number of supported parameters.
Please note the above is based on the assumption that latency is indeed going up. To clear up any such doubts, I think PR #16022 will go a long way towards allowing easy benchmarking of sampling parameters.
Moving forward, I've compiled a list of parameters to support, along with the effort needed to implement each and my own suggestions.
We can also use it to track progress.
- `sampling_metadata.max_num_logprobs`: similarly to what we do in other parts, we need to compile for different `max_num_logprobs` values, as the output is Bx`max_num_logprobs` (plus there's a `torch.topk` call), unless we fix it to some arbitrary value. [TPU][V1] Add support for top-logprobs #17072
- `sampling_metadata.prompt_token_ids` for penalties. This should be fine as is, given we're already compiling for different (padded) input sizes; it just can't be optional, since None and an actual value count as different inputs.
- `sampling_metadata.output_token_ids` for penalties. It's already converted into a padded tensor.
- `get_token_bin_counts_and_mask` uses a scatter_add op that may be slow on TPU. Its shapes depend on `num_seqs`, which we pre-compile, so they're fine.
- `logits[logits > 0]`: we can probably replace this with a `masked_fill`.
- `sampling_metadata.min_tokens`: the penalty must be re-implemented and vectorized (it currently uses a for loop over an input dict, so the graph would be dynamic). I am less familiar with this implementation, so TBD.
- `sampling_metadata.logit_bias`: the current interface needs to be rethought because it can introduce dynamism. We could create a BxV matrix (B padded and pre-compiled) to pack the preferences from the list of dicts; this would work, but the expansion factor can obviously be quite big (e.g. downgrading a single token would materialize a whole BxV matrix); see the sketch after this list. Alternatively, we could provide different pre-compiled values for V (2, 4, ...) at the cost of increased complexity and longer compilation time. Also, the current CUDA impl is highly unoptimized.
- `sampling_metadata.allowed_token_ids_mask` is fine as is, no effort required imo; it just can't be None on a single graph.
- `sampling_metadata.bad_words_token_ids`: probably better to support the more general `logit_bias` option.
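To illustrate the BxV packing mentioned for `logit_bias` above, a sketch under the assumption that biases arrive as one optional dict per request (`pack_logit_bias` is a hypothetical helper, not existing vLLM code):

```python
import torch
from typing import Optional

def pack_logit_bias(
    logit_bias: list[Optional[dict[int, float]]],  # one optional {token_id: bias} per request
    vocab_size: int,
) -> torch.Tensor:
    """Pack sparse per-request biases into a dense [B, V] tensor to add to logits.

    This keeps the sampling graph static (always B x V), at the cost of
    materializing the full matrix even when only a handful of tokens are biased.
    """
    bias = torch.zeros(len(logit_bias), vocab_size)
    for i, per_req in enumerate(logit_bias):
        if per_req:
            ids = torch.tensor(list(per_req.keys()), dtype=torch.long)
            bias[i, ids] = torch.tensor(list(per_req.values()), dtype=bias.dtype)
    return bias  # later, on device: logits = logits + bias.to(logits.device)
```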
Feedback Period.
No response
CC List.
@robertgshaw2-redhat @yaochengji @alexm-redhat @mgoin @bvrockwell @hyeygit @lsy323
Any Other Things.
UPDATE: this is very much related to #13360.