
llama : custom attention mask + parallel decoding + no context swaps #3228

Merged
merged 57 commits into master from custom-attention-mask on Sep 28, 2023

Commits (57). The file changes shown further below are from 3 of these commits.
c5df72e
tests : verify that RoPE is "additive"
ggerganov Sep 17, 2023
3b4bab6
llama : replace ggml_diag_mask_inf with ggml_add (custom -inf mask)
ggerganov Sep 17, 2023
1fb033f
ggml : ggml_rope now takes a vector with positions instead of n_past
ggerganov Sep 17, 2023
fad5693
metal : add rope_f16 kernel + optimize cpy kernels
ggerganov Sep 17, 2023
d29e769
llama : unified KV cache + batch inference API
ggerganov Sep 18, 2023
58bb511
Merge branch 'master' into custom-attention-mask
ggerganov Sep 18, 2023
9f42e75
llama : add new llama_decode() API that works with llama_batch
ggerganov Sep 18, 2023
6952a46
llama : add cell_max heuristic for more efficient kv_cache
ggerganov Sep 18, 2023
4d76d76
llama : extend llama_kv_cache API
ggerganov Sep 18, 2023
f015b26
llama : more robust cell_max heuristic + wip shift
ggerganov Sep 18, 2023
86c90e3
metal : disable concurrency optimization
ggerganov Sep 18, 2023
0cbf3bf
llama : add llama_kv_cache_shift_seq + no more context swaps
ggerganov Sep 18, 2023
7c1bdd0
llama : apply K-cache roping for Falcon and Baichuan
ggerganov Sep 18, 2023
1f17ea6
speculative : fix KV cache management
ggerganov Sep 18, 2023
0161372
parallel : example for serving multiple users in parallel
ggerganov Sep 18, 2023
466b513
parallel : disable hot-plug to avoid cache fragmentation
ggerganov Sep 18, 2023
897cacc
fixes : speculative KV cache + llama worst-case graph
ggerganov Sep 18, 2023
fa0e677
llama : extend batch API to select which logits to output
ggerganov Sep 18, 2023
daf4c6d
llama : fix worst case graph build
ggerganov Sep 19, 2023
7e2b997
ggml-cuda : update rope implementation for parallel decoding (#3254)
slaren Sep 19, 2023
25bd254
make : add parallel to build + fix static functions in llama.cpp
ggerganov Sep 19, 2023
467e307
simple : fix token counting
ggerganov Sep 19, 2023
36714e1
parallel : various improvements
ggerganov Sep 19, 2023
ddad227
llama : fix cell_max logic + rename functions
ggerganov Sep 19, 2023
806d397
parallel : try smaller batches when the KV cache is fragmented
ggerganov Sep 19, 2023
16090a5
parallel : fix sequence termination criteria
ggerganov Sep 19, 2023
d37081a
llama : silence errors KV cache errors
ggerganov Sep 19, 2023
82e20e9
parallel : remove new line from prompt
ggerganov Sep 19, 2023
4b5f3cd
parallel : process system prompt once + configurable paramters + llam…
ggerganov Sep 19, 2023
8a9aca3
parallel : remove question with short answers
ggerganov Sep 19, 2023
eed3fd4
parallel : count cache misses
ggerganov Sep 19, 2023
6028879
parallel : print misses on each request
ggerganov Sep 19, 2023
7b7472e
parallel : minor
ggerganov Sep 19, 2023
e1067ef
llama : fix n_kv to never become 0
ggerganov Sep 20, 2023
a1327c7
parallel : rename hot-plug to continuous-batching
ggerganov Sep 20, 2023
addae65
llama : improve llama_batch API + simplify parallel example
ggerganov Sep 20, 2023
b377bf2
simple : add parallel decoding support
ggerganov Sep 20, 2023
db0fc2d
simple : improve comments + free batch
ggerganov Sep 20, 2023
e04dc51
ggml-cuda : add rope f16, restore performance with parallel decoding …
slaren Sep 20, 2023
5420696
llama : disable MPI for now
ggerganov Sep 20, 2023
2f3a46f
train : make KQ_pos memory buffer permanent via dummy scale op
ggerganov Sep 20, 2023
1be2b8c
ggml : revert change to ggml_cpy, add ggml_cont_Nd instead (#3275)
slaren Sep 20, 2023
ee1d670
parallel : fix bug (extra BOS) + smaller token_prev array
ggerganov Sep 20, 2023
ded9b43
parallel : fix cases where the input prompts can overflow the batch
ggerganov Sep 20, 2023
b2debf6
parallel : add disabled experimental batch chunking in powers of two
ggerganov Sep 20, 2023
5a3369d
llama : llama.h formatting + comments
ggerganov Sep 21, 2023
8845160
simple : add README.md
ggerganov Sep 21, 2023
c1596f6
llama : fix kv cache heuristic when context is less than 32
ggerganov Sep 27, 2023
2585690
Merge branch 'master' into custom-attention-mask
ggerganov Sep 28, 2023
4ad0676
parallel : fix crash when `-n -1`
ggerganov Sep 28, 2023
e946379
llama : simplify returns if/else branches
ggerganov Sep 28, 2023
4c72ab1
metal : use mm kernels for batch size > 2
ggerganov Sep 28, 2023
d008733
examples : utilize new llama_get_logits_ith()
ggerganov Sep 28, 2023
a207561
examples : add example for batched decoding
ggerganov Sep 28, 2023
2b8830a
examples : do not eval prompt 2 times (close #3348)
ggerganov Sep 28, 2023
ce2d995
server : clear the KV cache beyond n_past before llama_decode
ggerganov Sep 28, 2023
c5650ed
server : avoid context swaps by shifting the KV cache
ggerganov Sep 28, 2023
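
The commits above introduce the llama_batch / llama_decode() API, the unified KV cache, and the llama_get_logits_ith() accessor used by the new parallel and batched examples. As a rough orientation, the sketch below shows how decoding several sequences in a single call is meant to look; the llama_batch_init()/llama_batch_free() helpers, field names, and signatures are assumptions inferred from the commit titles and may not match this exact revision.

    // hypothetical sketch of parallel decoding with the batch API from this PR
    // (helper names, fields, and signatures are assumptions, not verified against this revision)
    #include <cstdio>
    #include "llama.h"

    static void decode_step(llama_context * ctx, const llama_token * new_tokens,
                            const llama_pos * positions, int32_t n_seqs) {
        // one batch holding one new token per sequence
        llama_batch batch = llama_batch_init(n_seqs, 0);

        for (int32_t i = 0; i < n_seqs; ++i) {
            batch.token [batch.n_tokens] = new_tokens[i];  // token to append to sequence i
            batch.pos   [batch.n_tokens] = positions[i];   // its position within sequence i
            batch.seq_id[batch.n_tokens] = i;              // which KV-cache sequence it belongs to
            batch.logits[batch.n_tokens] = true;           // request logits for this token
            batch.n_tokens++;
        }

        // a single llama_decode() evaluates all sequences against the shared KV cache
        if (llama_decode(ctx, batch) != 0) {
            fprintf(stderr, "llama_decode() failed\n");
            return;
        }

        // logits are addressed by the token's index within the batch
        for (int32_t i = 0; i < n_seqs; ++i) {
            const float * logits = llama_get_logits_ith(ctx, i);
            (void) logits; // sample the next token for sequence i here
        }

        llama_batch_free(batch);
    }

The key point, per the commit titles, is that all sequences share one KV cache and are evaluated in a single graph, rather than one evaluation per sequence.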
37 changes: 31 additions & 6 deletions examples/baby-llama/baby-llama.cpp
@@ -556,6 +556,14 @@ struct ggml_tensor * forward(
struct ggml_tensor * kc = kv_self.k;
struct ggml_tensor * vc = kv_self.v;

struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
{
int * data = (int *) KQ_pos->data;
for (int i = 0; i < N; ++i) {
data[i] = n_past + i;
}
}

// inpL shape [n_embd,N,1,1]
struct ggml_tensor * inpL = ggml_get_rows(ctx0, model->tok_embeddings, tokens);
for (int il = 0; il < n_layer; ++il) {
@@ -583,8 +591,8 @@ struct ggml_tensor * forward(
// wk shape [n_embd, n_embd, 1, 1]
// Qcur shape [n_embd/n_head, n_head, N, 1]
// Kcur shape [n_embd/n_head, n_head, N, 1]
struct ggml_tensor * Qcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0, 0);
struct ggml_tensor * Kcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0, 0);
struct ggml_tensor * Qcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N), KQ_pos, n_rot, 0, 0);
struct ggml_tensor * Kcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N), KQ_pos, n_rot, 0, 0);

// store key and value to memory
{
@@ -810,9 +818,18 @@ struct ggml_tensor * forward_batch(
struct ggml_tensor * kc = kv_self.k;
struct ggml_tensor * vc = kv_self.v;

struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
{
int * data = (int *) KQ_pos->data;
for (int i = 0; i < N; ++i) {
data[i] = n_past + i;
}
}

// inpL shape [n_embd,N*n_batch,1]
struct ggml_tensor * inpL = ggml_get_rows(ctx0, model->tok_embeddings, tokens);
assert_shape_2d(inpL, n_embd, N*n_batch);

for (int il = 0; il < n_layer; ++il) {
struct ggml_tensor * inpSA = inpL;

@@ -840,8 +857,8 @@ struct ggml_tensor * forward_batch(
// wk shape [n_embd, n_embd, 1, 1]
// Qcur shape [n_embd/n_head, n_head, N, n_batch]
// Kcur shape [n_embd/n_head, n_head, N, n_batch]
struct ggml_tensor * Qcur = ggml_rope(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N, n_batch), n_past, n_rot, 0, 0);
struct ggml_tensor * Kcur = ggml_rope(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N, n_batch), n_past, n_rot, 0, 0);
struct ggml_tensor * Qcur = ggml_rope(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N, n_batch), KQ_pos, n_rot, 0, 0);
struct ggml_tensor * Kcur = ggml_rope(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N, n_batch), KQ_pos, n_rot, 0, 0);
assert_shape_4d(Qcur, n_embd/n_head, n_head, N, n_batch);
assert_shape_4d(Kcur, n_embd/n_head, n_head, N, n_batch);

@@ -1100,6 +1117,14 @@ struct ggml_tensor * forward_lora(
struct ggml_tensor * kc = kv_self.k;
struct ggml_tensor * vc = kv_self.v;

struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
{
int * data = (int *) KQ_pos->data;
for (int i = 0; i < N; ++i) {
data[i] = n_past + i;
}
}

// inpL shape [n_embd,N,1,1]
struct ggml_tensor * inpL = ggml_get_rows(ctx0, model->tok_embeddings, tokens);
for (int il = 0; il < n_layer; ++il) {
@@ -1133,7 +1158,7 @@ struct ggml_tensor * forward_lora(
model->layers[il].wqb,
cur)),
n_embd/n_head, n_head, N),
n_past, n_rot, 0, 0);
KQ_pos, n_rot, 0, 0);
struct ggml_tensor * Kcur = ggml_rope(ctx0,
ggml_reshape_3d(ctx0,
ggml_mul_mat(ctx0,
@@ -1142,7 +1167,7 @@
model->layers[il].wkb,
cur)),
n_embd/n_head, n_head, N),
n_past, n_rot, 0, 0);
KQ_pos, n_rot, 0, 0);

// store key and value to memory
{
14 changes: 11 additions & 3 deletions examples/train-text-from-scratch/train-text-from-scratch.cpp
@@ -679,15 +679,23 @@ struct ggml_tensor * llama_build_train_graphs(
}
};

// KQ_pos - contains the positions
struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, N);
{
int * data = (int *) KQ_pos->data;
for (int i = 0; i < N; ++i) {
data[i] = n_past + i;
}
}

ggerganov (Member, Author) commented:

@xaedes I'm changing the API of ggml_rope to take an entire vector with positions instead of n_past. I have a small concern about this particular change in train-text-from-scratch and cannot test it at the moment: I'm not sure that the allocator won't let some intermediate results overwrite the data of KQ_pos at some point.

In other places, we fix this using ggml_allocr_alloc():

https://github.com/ggerganov/llama.cpp/blob/1fb033fd85f8125d2830bbfe6d384be3baa17ae8/llama.cpp#L2431-L2439

But I wasn't sure whether it's applicable here.
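
For reference, the linked lines follow roughly this pattern: the position tensor is allocated through the graph allocator, and its data is written only when the allocator is not in measuring mode. This is a minimal sketch; ctx0, alloc and n_tokens are illustrative names, not the exact code.

    // minimal sketch of the ggml_allocr_alloc() pattern referenced above (illustrative names)
    struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
    ggml_allocr_alloc(alloc, KQ_pos);              // reserve the tensor through the graph allocator
    if (!ggml_allocr_is_measure(alloc)) {          // data is only valid outside the measuring pass
        int * data = (int *) KQ_pos->data;
        for (int i = 0; i < n_tokens; ++i) {
            data[i] = n_past + i;                  // absolute position of each token in the batch
        }
    }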

xaedes (Collaborator) replied on Sep 18, 2023:
During training (finetune and train-text-from-scratch) n_past is always zero, so I guess KQ_pos would always be empty.

> not sure if the allocator won't make some intermediate results to overwrite the data

To avoid deallocation of certain tensors T until the end of computation, I added a temporary scale_inplace(T, 1.0f) operation at the end of the computation graph before giving it to the allocator. With this the allocator cannot deallocate T before the original end of the graph. Those temporary operations are removed from the graph after allocations are done, so that they are not actually executed.
For example here: https://github.com/ggerganov/llama.cpp/blob/5ce74ee4613c06bf3391c72d7115d10726200bff/examples/finetune/finetune.cpp#L768
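
A minimal sketch of the keep-alive trick described above, with illustrative names (the real code is at the linked finetune.cpp line): a dummy scale-by-1.0 node is appended to the graph so the allocator treats the tensor as live until the end of the computation, and the temporary node is dropped again after allocation so it is never executed.

    // illustrative sketch, assuming the ggml graph allocator API of this period
    int n_leafs_before = gf->n_leafs;
    int n_nodes_before = gf->n_nodes;

    struct ggml_tensor * one = ggml_new_f32(ctx, 1.0f);

    // dummy dependency: KQ_pos * 1.0f keeps KQ_pos alive for the whole graph
    ggml_build_forward_expand(gf, ggml_scale_inplace(ctx, KQ_pos, one));

    // ... ggml_allocr_alloc_graph(alloc, gf) runs here ...

    // remove the temporary node again so it is never actually executed
    gf->n_leafs = n_leafs_before;
    gf->n_nodes = n_nodes_before;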

ggerganov (Member, Author) replied:
Hah, clever solution :) I added a scale op for KQ_pos to be safe.

> During training (finetune and train-text-from-scratch) n_past is always zero, so I guess KQ_pos would always be empty.

Btw, when n_past == 0, the KQ_pos tensor is not empty: it holds the values 0, 1, 2, 3, ... (i.e. n_past + i).

// rope has so much parameters that we make a custom function for it
auto rope = [ctx, n_rot, n_ctx, rope_freq_base, rope_freq_scale]
auto rope = [ctx, KQ_pos, n_rot, n_ctx, rope_freq_base, rope_freq_scale]
(struct ggml_tensor * t) -> struct ggml_tensor * {
// not capturing these, to silcence warnings
const int n_past = 0;
const int rope_mode = 0;

return ggml_rope_custom(ctx,
t, n_past, n_rot, rope_mode, n_ctx,
t, KQ_pos, n_rot, rope_mode, n_ctx,
rope_freq_base, rope_freq_scale);
};

99 changes: 68 additions & 31 deletions ggml-metal.m
@@ -736,25 +736,59 @@ void ggml_metal_graph_compute(
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(ggml_is_contiguous(src1));

// utilize float4
GGML_ASSERT(ne00 % 4 == 0);
const int64_t nb = ne00/4;
bool bcast_row = false;

if (ggml_nelements(src1) == ne10) {
int64_t nb = ne00;

if (ggml_nelements(src1) == ne10 && ne00 % 4 == 0) {
// src1 is a row
GGML_ASSERT(ne11 == 1);

nb = ne00 / 4;
[encoder setComputePipelineState:ctx->pipeline_add_row];

bcast_row = true;
} else {
[encoder setComputePipelineState:ctx->pipeline_add];
}
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&nb length:sizeof(nb) atIndex:3];

const int64_t n = ggml_nelements(dst)/4;
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
[encoder setBytes:&nb length:sizeof(nb) atIndex:27];

if (bcast_row) {
const int64_t n = ggml_nelements(dst)/4;

[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} else {
const int nth = MIN(1024, ne0);

[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
}
} break;
case GGML_OP_MUL:
{
@@ -1176,7 +1210,9 @@ void ggml_metal_graph_compute(
} break;
case GGML_OP_ROPE:
{
const int n_past = ((int32_t *) dst->op_params)[0];
GGML_ASSERT(ne10 == ne02);

//const int n_past = ((int32_t *) dst->op_params)[0];
const int n_dims = ((int32_t *) dst->op_params)[1];
const int mode = ((int32_t *) dst->op_params)[2];

@@ -1187,28 +1223,29 @@

[encoder setComputePipelineState:ctx->pipeline_rope];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
[encoder setBytes:&n_past length:sizeof( int) atIndex:18];
[encoder setBytes:&n_dims length:sizeof( int) atIndex:19];
[encoder setBytes:&mode length:sizeof( int) atIndex:20];
[encoder setBytes:&freq_base length:sizeof(float) atIndex:21];
[encoder setBytes:&freq_scale length:sizeof(float) atIndex:22];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:4];
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:5];
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:6];
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:7];
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:11];
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:12];
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:13];
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:14];
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:15];
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:16];
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:17];
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:18];
//[encoder setBytes:&n_past length:sizeof( int) atIndex:19];
[encoder setBytes:&n_dims length:sizeof( int) atIndex:20];
[encoder setBytes:&mode length:sizeof( int) atIndex:21];
[encoder setBytes:&freq_base length:sizeof(float) atIndex:22];
[encoder setBytes:&freq_scale length:sizeof(float) atIndex:23];

[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
} break;