llama : custom attention mask + parallel decoding + no context swaps #3228
Merged
Commits
57 commits
c5df72e
tests : verify that RoPE is "additive"
ggerganov 3b4bab6
llama : replace ggml_diag_mask_inf with ggml_add (custom -inf mask)
ggerganov 1fb033f
ggml : ggml_rope now takes a vector with positions instead of n_past
ggerganov fad5693
metal : add rope_f16 kernel + optimize cpy kernels
ggerganov d29e769
llama : unified KV cache + batch inference API
ggerganov 58bb511
Merge branch 'master' into custom-attention-mask
ggerganov 9f42e75
llama : add new llama_decode() API that works with llama_batch
ggerganov 6952a46
llama : add cell_max heuristic for more efficient kv_cache
ggerganov 4d76d76
llama : extend llama_kv_cache API
ggerganov f015b26
llama : more robust cell_max heuristic + wip shift
ggerganov 86c90e3
metal : disable concurrency optimization
ggerganov 0cbf3bf
llama : add llama_kv_cache_shift_seq + no more context swaps
ggerganov 7c1bdd0
llama : apply K-cache roping for Falcon and Baichuan
ggerganov 1f17ea6
speculative : fix KV cache management
ggerganov 0161372
parallel : example for serving multiple users in parallel
ggerganov 466b513
parallel : disable hot-plug to avoid cache fragmentation
ggerganov 897cacc
fixes : speculative KV cache + llama worst-case graph
ggerganov fa0e677
llama : extend batch API to select which logits to output
ggerganov daf4c6d
llama : fix worst case graph build
ggerganov 7e2b997
ggml-cuda : update rope implementation for parallel decoding (#3254)
slaren 25bd254
make : add parallel to build + fix static functions in llama.cpp
ggerganov 467e307
simple : fix token counting
ggerganov 36714e1
parallel : various improvements
ggerganov ddad227
llama : fix cell_max logic + rename functions
ggerganov 806d397
parallel : try smaller batches when the KV cache is fragmented
ggerganov 16090a5
parallel : fix sequence termination criteria
ggerganov d37081a
llama : silence errors KV cache errors
ggerganov 82e20e9
parallel : remove new line from prompt
ggerganov 4b5f3cd
parallel : process system prompt once + configurable paramters + llam…
ggerganov 8a9aca3
parallel : remove question with short answers
ggerganov eed3fd4
parallel : count cache misses
ggerganov 6028879
parallel : print misses on each request
ggerganov 7b7472e
parallel : minor
ggerganov e1067ef
llama : fix n_kv to never become 0
ggerganov a1327c7
parallel : rename hot-plug to continuous-batching
ggerganov addae65
llama : improve llama_batch API + simplify parallel example
ggerganov b377bf2
simple : add parallel decoding support
ggerganov db0fc2d
simple : improve comments + free batch
ggerganov e04dc51
ggml-cuda : add rope f16, restore performance with parallel decoding …
slaren 5420696
llama : disable MPI for now
ggerganov 2f3a46f
train : make KQ_pos memory buffer permanent via dummy scale op
ggerganov 1be2b8c
ggml : revert change to ggml_cpy, add ggml_cont_Nd instead (#3275)
slaren ee1d670
parallel : fix bug (extra BOS) + smaller token_prev array
ggerganov ded9b43
parallel : fix cases where the input prompts can overflow the batch
ggerganov b2debf6
parallel : add disabled experimental batch chunking in powers of two
ggerganov 5a3369d
llama : llama.h formatting + comments
ggerganov 8845160
simple : add README.md
ggerganov c1596f6
llama : fix kv cache heuristic when context is less than 32
ggerganov 2585690
Merge branch 'master' into custom-attention-mask
ggerganov 4ad0676
parallel : fix crash when `-n -1`
ggerganov e946379
llama : simplify returns if/else branches
ggerganov 4c72ab1
metal : use mm kernels for batch size > 2
ggerganov d008733
examples : utilize new llama_get_logits_ith()
ggerganov a207561
examples : add example for batched decoding
ggerganov 2b8830a
examples : do not eval prompt 2 times (close #3348)
ggerganov ce2d995
server : clear the KV cache beyond n_past before llama_decode
ggerganov c5650ed
server : avoid context swaps by shifting the KV cache
ggerganov
@xaedes I'm changing the API of `ggml_rope` to take an entire vector with positions instead of `n_past`. I have a small concern about this particular change in `train-text-from-scratch` and cannot test it atm. I'm not sure whether the allocator might let some intermediate results overwrite the data of `KQ_pos` at some point. In other places, we fix this using `ggml_allocr_alloc()`:

https://github.com/ggerganov/llama.cpp/blob/1fb033fd85f8125d2830bbfe6d384be3baa17ae8/llama.cpp#L2431-L2439

But I wasn't sure if it's applicable here.
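For reference, the pattern at the linked lines looks roughly like this (a paraphrased sketch, not the verbatim code; `ctx0`, `lctx.alloc` and `n_tokens` are the surrounding names assumed here): the positions tensor is allocated through the allocator and only written when the allocator is not in measure mode.

```c
// sketch of the KQ_pos allocation pattern referenced above (paraphrased)
struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
ggml_allocr_alloc(lctx.alloc, KQ_pos);

// during the measure pass the tensor has no real backing buffer,
// so the positions are only written for the actual allocation
if (!ggml_allocr_is_measure(lctx.alloc)) {
    int * data = (int *) KQ_pos->data;
    for (int i = 0; i < n_tokens; ++i) {
        data[i] = n_past + i;
    }
}
```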
During training (finetune and train-text-from-scratch) n_past is always zero, so I guess KQ_pos would always be empty.
To avoid deallocation of certain tensors T until the end of computation, I added a temporary scale_inplace(T, 1.0f) operation at the end of the computation graph before giving it to the allocator. With this the allocator cannot deallocate T before the original end of the graph. Those temporary operations are removed from the graph after allocations are done, so that they are not actually executed.
For example here: https://github.com/ggerganov/llama.cpp/blob/5ce74ee4613c06bf3391c72d7115d10726200bff/examples/finetune/finetune.cpp#L768
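A minimal sketch of that trick (illustrative only, not the exact finetune.cpp code; it assumes the ggml API of that time, where `ggml_scale_inplace` takes the scale factor as a 1-element tensor):

```c
// dummy op: scale KQ_pos by 1.0f in-place and add the result as an extra
// graph output, so the allocator keeps KQ_pos alive until the end of the graph
struct ggml_tensor * one = ggml_new_f32(ctx, 1.0f);
ggml_build_forward_expand(gf, ggml_scale_inplace(ctx, KQ_pos, one));

// once ggml_allocr_alloc_graph(alloc, gf) has assigned the buffers, the dummy
// node can be removed from the graph again so it is never actually executed
```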
Hah, clever solution :) I added a scale op for `KQ_pos` to be safe.

Btw, when `n_past == 0`, the `KQ_pos` tensor would have values `0, 1, 2, 3, ...` (i.e. `n_past + i`).
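For context, a sketch of how such a positions tensor is consumed after this change (assuming the updated `ggml_rope` signature from this PR, where the old `n_past` argument is replaced by an int32 tensor holding one position per token; `Qcur` and `n_rot` are placeholder names):

```c
// before: positions were implied by a single n_past offset
// struct ggml_tensor * q = ggml_rope(ctx0, Qcur, n_past, n_rot, 0, 0);

// after: explicit per-token positions, with KQ_pos[i] == n_past + i
struct ggml_tensor * q = ggml_rope(ctx0, Qcur, KQ_pos, n_rot, 0, 0);
```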