
Commit be7c08c

ggerganov authored and arthw committed
llama : infill sampling handle very long tokens (ggml-org#9924)
* llama : infill sampling handle very long tokens

ggml-ci

* cont : better indices

ggml-ci
1 parent f3a2a9c commit be7c08c

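In short: the public llama_token_is_prefix API, which compared token texts through fixed 128-byte stack buffers, is removed, and the infill sampler now performs the prefix check itself, detokenizing into a pair of growable std::vector<char> buffers so that tokens longer than the old fixed buffers are handled correctly.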
File tree

  include/llama.h
  src/llama-sampling.cpp
  src/llama-vocab.cpp
  src/llama.cpp

4 files changed (+35, -43 lines)


include/llama.h (-6)

@@ -953,12 +953,6 @@ extern "C" {
                        int32_t   lstrip,
                           bool   special);
 
-    // check if token0 is contained as a prefix in token1
-    LLAMA_API bool llama_token_is_prefix(
-              const struct llama_model * model,
-                           llama_token   token0,
-                           llama_token   token1);
-
     /// @details Convert the provided tokens into text (inverse of llama_tokenize()).
     /// @param text The char pointer must be large enough to hold the resulting text.
     /// @return Returns the number of chars/bytes on success, no more than text_len_max.

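With the declaration above removed, code outside the library that relied on llama_token_is_prefix needs its own check. The following is a minimal caller-side sketch, not part of this commit: it uses only the public llama_token_to_piece (whose wrapper is visible in the src/llama.cpp hunk below), the helper name token_is_prefix is made up, and it assumes a negative return value means the buffer was too small, with the required size as the magnitude, mirroring the resize-and-retry pattern the sampler adopts.

#include <cstring>
#include <vector>

#include "llama.h"

// hypothetical replacement for the removed API: true when the text of
// token0 is a byte-prefix of the text of token1
static bool token_is_prefix(const llama_model * model, llama_token token0, llama_token token1) {
    std::vector<char> buf0(512);
    std::vector<char> buf1(512);

    int32_t len0 = llama_token_to_piece(model, token0, buf0.data(), (int32_t) buf0.size(), 0, false);
    if (len0 < 0) {
        buf0.resize(-len0); // assumption: -len0 is the size the piece needs
        len0 = llama_token_to_piece(model, token0, buf0.data(), (int32_t) buf0.size(), 0, false);
    }

    int32_t len1 = llama_token_to_piece(model, token1, buf1.data(), (int32_t) buf1.size(), 0, false);
    if (len1 < 0) {
        buf1.resize(-len1);
        len1 = llama_token_to_piece(model, token1, buf1.data(), (int32_t) buf1.size(), 0, false);
    }

    return len0 > 0 && len0 <= len1 && memcmp(buf0.data(), buf1.data(), len0) == 0;
}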
src/llama-sampling.cpp (+35, -13)

@@ -1745,6 +1745,9 @@ struct llama_sampler * llama_sampler_init_logit_bias(
 
 struct llama_sampler_infill {
     const struct llama_vocab * vocab;
+
+    std::vector<char> buf0;
+    std::vector<char> buf1;
 };
 
 static const char * llama_sampler_infill_name(const struct llama_sampler * /*smpl*/) {
@@ -1810,27 +1813,44 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
     size_t n_combined = 0; GGML_UNUSED(n_combined);
 
     // combine tokens with common prefix
-    for (size_t i = 0; i < cur_p->size; ++i) {
-        for (size_t j = 0; j < cur_p->size; ++j) {
-            if (cur_p->data[i].logit == -INFINITY) {
+    for (size_t i0 = 0; i0 < cur_p->size; ++i0) {
+        for (size_t i1 = 0; i1 < cur_p->size; ++i1) {
+            if (cur_p->data[i0].logit == -INFINITY) {
                 break;
             }
 
-            if (i == j || cur_p->data[j].logit == -INFINITY) {
+            if (i0 == i1 || cur_p->data[i1].logit == -INFINITY) {
                 continue;
             }
 
-            if (llama_token_is_prefix_impl(*ctx->vocab, cur_p->data[i].id, cur_p->data[j].id)) {
-                if (cur_p->data[i].p > cur_p->data[j].p) {
-                    cur_p->data[i].p += cur_p->data[j].p;
-                    cur_p->data[j].logit = -INFINITY;
-                    cur_p->data[j].p = 0.0f;
-                } else {
-                    cur_p->data[j].p += cur_p->data[i].p;
-                    cur_p->data[i].logit = -INFINITY;
-                    cur_p->data[i].p = 0.0f;
+            int len0 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i0].id, ctx->buf0.data(), ctx->buf0.size(), 0, false);
+            if (len0 < 0) {
+                ctx->buf0.resize(len0);
+                len0 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i0].id, ctx->buf0.data(), ctx->buf0.size(), 0, false);
+                assert(len0 > 0);
+            }
+
+            int len1 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i1].id, ctx->buf1.data(), ctx->buf1.size(), 0, false);
+            if (len1 < 0) {
+                ctx->buf1.resize(len1);
+                len1 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i1].id, ctx->buf1.data(), ctx->buf1.size(), 0, false);
+                assert(len1 > 0);
+            }
+
+            // token i0 is a prefix of token i1
+            if (len0 > 0 && len0 <= len1 && memcmp(ctx->buf0.data(), ctx->buf1.data(), len0) == 0) {
+                int dst = i0;
+                int src = i1;
+
+                // merge into the token with higher probability
+                if (cur_p->data[i1].p > cur_p->data[i0].p) {
+                    std::swap(dst, src);
                 }
 
+                cur_p->data[dst].p += cur_p->data[src].p;
+                cur_p->data[src].logit = -INFINITY;
+                cur_p->data[src].p = 0.0f;
+
                 n_combined++;
             }
         }
@@ -1936,6 +1956,8 @@ struct llama_sampler * llama_sampler_init_infill_impl(
         /* .iface = */ &llama_sampler_infill_i,
         /* .ctx   = */ new llama_sampler_infill {
             /* .vocab = */ &vocab,
+            /* .buf0  = */ std::vector<char>(512),
+            /* .buf1  = */ std::vector<char>(512),
         },
     };
 }

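The new merge logic detokenizes both candidates, checks the prefix with memcmp, and then folds the probability mass of the pair into whichever token is more likely, masking the loser with -INFINITY. A self-contained toy illustration of that dst/src bookkeeping, with made-up token texts and probabilities (toy_token is not a llama.cpp type):

#include <cmath>
#include <cstdio>
#include <utility>
#include <vector>

// stand-in for llama_token_data: just a text, a logit and a probability
struct toy_token {
    const char * text;
    float logit;
    float p;
};

int main() {
    std::vector<toy_token> cur = {
        { "wor",   0.0f, 0.3f },
        { "world", 0.0f, 0.5f },
    };

    // "wor" is a byte-prefix of "world", so the pair gets combined
    int dst = 0;
    int src = 1;

    // merge into the token with higher probability, as in the diff
    if (cur[1].p > cur[0].p) {
        std::swap(dst, src);
    }

    cur[dst].p     += cur[src].p;
    cur[src].logit  = -INFINITY;
    cur[src].p      = 0.0f;

    printf("kept '%s' with p = %.1f\n", cur[dst].text, cur[dst].p); // kept 'world' with p = 0.8
    return 0;
}

Summing into the more probable spelling keeps the total probability mass unchanged while removing near-duplicate candidates that differ only by where the token boundary falls.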
src/llama-vocab.cpp (-17)

@@ -1858,23 +1858,6 @@ int32_t llama_token_to_piece_impl(const struct llama_vocab & vocab, llama_token token, char * buf, int32_t length, int32_t lstrip, bool special) {
     return 0;
 }
 
-bool llama_token_is_prefix_impl(
-        const struct llama_vocab & vocab,
-               llama_token token0,
-               llama_token token1) {
-    char text_buf_0[128];
-    char text_buf_1[128];
-
-    const int32_t len0 = llama_token_to_piece_impl(vocab, token0, text_buf_0, sizeof(text_buf_0) - 1, 0, false);
-    const int32_t len1 = llama_token_to_piece_impl(vocab, token1, text_buf_1, sizeof(text_buf_1) - 1, 0, false);
-
-    if (len0 <= 0 || len1 <= 0) {
-        return false;
-    }
-
-    return len0 <= len1 && memcmp(text_buf_0, text_buf_1, len0) == 0;
-}
-
 int32_t llama_detokenize_impl(
         const struct llama_vocab & vocab,
         const llama_token * tokens,

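The removed helper explains the "very long tokens" in the commit title: with 128-byte stack buffers, llama_token_to_piece_impl cannot write pieces longer than 127 bytes, the len0 <= 0 || len1 <= 0 guard then triggers, and the function silently answers false for any comparison involving such a token. The sampler-side replacement above sizes its buffers at 512 bytes and retries after a resize instead.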
src/llama.cpp (-7)

@@ -21471,13 +21471,6 @@ int32_t llama_token_to_piece(
     return llama_token_to_piece_impl(model->vocab, token, buf, length, lstrip, special);
 }
 
-bool llama_token_is_prefix(
-          const struct llama_model * model,
-                       llama_token   token0,
-                       llama_token   token1) {
-    return llama_token_is_prefix_impl(model->vocab, token0, token1);
-}
-
 int32_t llama_detokenize(
         const struct llama_model * model,
         const llama_token * tokens,

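With this last hunk the symbol disappears from the public C API as well, so downstream callers cannot keep using it; the sketch after the include/llama.h section shows one way to reproduce the check with llama_token_to_piece.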