
Commit 06e841b

ggerganov authored and drollings committed
llama : add infill sampler (ggml-org#9896)
ggml-ci
1 parent ed0843d commit 06e841b

File tree

9 files changed: +294 −23

common/common.h (+2 −2)

@@ -91,7 +91,7 @@ enum common_sampler_type {
     COMMON_SAMPLER_TYPE_TYPICAL_P   = 5,
     COMMON_SAMPLER_TYPE_TEMPERATURE = 6,
     COMMON_SAMPLER_TYPE_XTC         = 7,
-
+    COMMON_SAMPLER_TYPE_INFILL      = 8,
 };

 // dimensionality reduction methods, used by cvector-generator
@@ -136,7 +136,7 @@ struct common_sampler_params {
         COMMON_SAMPLER_TYPE_TOP_P,
         COMMON_SAMPLER_TYPE_MIN_P,
         COMMON_SAMPLER_TYPE_XTC,
-        COMMON_SAMPLER_TYPE_TEMPERATURE
+        COMMON_SAMPLER_TYPE_TEMPERATURE,
     };

     std::string grammar; // optional BNF-like grammar to constrain sampling

common/sampling.cpp (+8 −1)

@@ -196,6 +196,9 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
             case COMMON_SAMPLER_TYPE_TEMPERATURE:
                 llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
                 break;
+            case COMMON_SAMPLER_TYPE_INFILL:
+                llama_sampler_chain_add(result->chain, llama_sampler_init_infill (model));
+                break;
             default:
                 GGML_ASSERT(false && "unknown sampler type");
         }
@@ -376,6 +379,7 @@ char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
         case COMMON_SAMPLER_TYPE_MIN_P:       return 'm';
         case COMMON_SAMPLER_TYPE_TEMPERATURE: return 't';
         case COMMON_SAMPLER_TYPE_XTC:         return 'x';
+        case COMMON_SAMPLER_TYPE_INFILL:      return 'i';
         default : return '?';
     }
 }
@@ -389,6 +393,7 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
         case COMMON_SAMPLER_TYPE_MIN_P:       return "min_p";
         case COMMON_SAMPLER_TYPE_TEMPERATURE: return "temperature";
         case COMMON_SAMPLER_TYPE_XTC:         return "xtc";
+        case COMMON_SAMPLER_TYPE_INFILL:      return "infill";
         default : return "";
     }
 }
@@ -402,6 +407,7 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
         { "tfs_z",       COMMON_SAMPLER_TYPE_TFS_Z },
         { "temperature", COMMON_SAMPLER_TYPE_TEMPERATURE },
         { "xtc",         COMMON_SAMPLER_TYPE_XTC },
+        { "infill",      COMMON_SAMPLER_TYPE_INFILL },
     };

     // since samplers names are written multiple ways
@@ -448,7 +454,8 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std::stri
         { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_P),       COMMON_SAMPLER_TYPE_TOP_P },
         { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_MIN_P),       COMMON_SAMPLER_TYPE_MIN_P },
         { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TEMPERATURE), COMMON_SAMPLER_TYPE_TEMPERATURE },
-        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC),         COMMON_SAMPLER_TYPE_XTC }
+        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC),         COMMON_SAMPLER_TYPE_XTC },
+        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_INFILL),      COMMON_SAMPLER_TYPE_INFILL },
     };

     std::vector<common_sampler_type> samplers;
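With these tables in place, the new sampler is selectable like the existing ones: by name in the semicolon-separated --samplers list (e.g. "top_k;top_p;infill"), or by its character code 'i' in a --sampling-seq string. (This assumes the usual llama.cpp CLI flags; the flag plumbing itself is not part of this diff.)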

examples/main/main.cpp (+17 −17)

@@ -569,30 +569,30 @@ int main(int argc, char ** argv) {
             if (!params.ctx_shift){
                 LOG_DBG("\n\n%s: context full and context shift is disabled => stopping\n", __func__);
                 break;
-            } else {
-                if (params.n_predict == -2) {
-                    LOG_DBG("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict);
-                    break;
-                }
+            }
+
+            if (params.n_predict == -2) {
+                LOG_DBG("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict);
+                break;
+            }

-                const int n_left    = n_past - params.n_keep;
-                const int n_discard = n_left/2;
+            const int n_left    = n_past - params.n_keep;
+            const int n_discard = n_left/2;

-                LOG_DBG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
-                        n_past, n_left, n_ctx, params.n_keep, n_discard);
+            LOG_DBG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
+                    n_past, n_left, n_ctx, params.n_keep, n_discard);

-                llama_kv_cache_seq_rm (ctx, 0, params.n_keep            , params.n_keep + n_discard);
-                llama_kv_cache_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard);
+            llama_kv_cache_seq_rm (ctx, 0, params.n_keep            , params.n_keep + n_discard);
+            llama_kv_cache_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard);

-                n_past -= n_discard;
+            n_past -= n_discard;

-                LOG_DBG("after swap: n_past = %d\n", n_past);
+            LOG_DBG("after swap: n_past = %d\n", n_past);

-                LOG_DBG("embd: %s\n", string_from(ctx, embd).c_str());
+            LOG_DBG("embd: %s\n", string_from(ctx, embd).c_str());

-                LOG_DBG("clear session path\n");
-                path_session.clear();
-            }
+            LOG_DBG("clear session path\n");
+            path_session.clear();
         }
     } else {
         // context extension via Self-Extend
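Aside: the shift arithmetic in this hunk is easier to follow with concrete numbers. A minimal standalone replay (the values are illustrative, not from the commit):

// toy replay of the context-shift arithmetic above; n_past/n_keep are made-up values
#include <cstdio>

int main() {
    int n_past = 512;       // hypothetical: tokens currently in the KV cache
    const int n_keep = 64;  // hypothetical: prompt tokens pinned at the start

    const int n_left    = n_past - n_keep; // 448
    const int n_discard = n_left/2;        // 224

    // llama_kv_cache_seq_rm drops cells [n_keep, n_keep + n_discard);
    // llama_kv_cache_seq_add shifts cells [n_keep + n_discard, n_past) down by n_discard
    n_past -= n_discard;

    std::printf("n_left = %d, n_discard = %d, n_past = %d\n", n_left, n_discard, n_past);
    // -> n_left = 448, n_discard = 224, n_past = 288
}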

include/llama.h (+28)

@@ -953,6 +953,12 @@ extern "C" {
                               int32_t   lstrip,
                                  bool   special);

+    // check if token0 is contained as a prefix in token1
+    LLAMA_API bool llama_token_is_prefix(
+              const struct llama_model * model,
+                           llama_token   token0,
+                           llama_token   token1);
+
     /// @details Convert the provided tokens into text (inverse of llama_tokenize()).
     /// @param text The char pointer must be large enough to hold the resulting text.
     /// @return Returns the number of chars/bytes on success, no more than text_len_max.
@@ -1148,6 +1154,28 @@ extern "C" {
                               int32_t   n_logit_bias,
                const llama_logit_bias * logit_bias);

+    // this sampler is meant to be used for fill-in-the-middle infilling
+    // it's supposed to be used after top_k + top_p sampling
+    //
+    // 1. if the sum of the EOG probs times the number of candidates is higher than the sum of the other probs -> pick EOG
+    // 2. combine probs of tokens that have the same prefix
+    //
+    // example:
+    //
+    // - before:
+    //   "hel":   0.5
+    //   "hell":  0.2
+    //   "hello": 0.1
+    //   "dummy": 0.1
+    //
+    // - after:
+    //   "hel":   0.8
+    //   "dummy": 0.1
+    //
+    // 3. discard non-EOG tokens with low prob
+    // 4. if no tokens are left -> pick EOT
+    //
+    LLAMA_API struct llama_sampler * llama_sampler_init_infill(const struct llama_model * model);

     // Returns the seed used by the sampler if applicable, LLAMA_DEFAULT_SEED otherwise
     LLAMA_API uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl);
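Going by the header comment above, a minimal sketch of where the new sampler sits in a chain: top_k/top_p first, infill after, then a final picker. The model and ctx handles are assumed to be already loaded, and the k = 40 / p = 0.95 values are illustrative, not taken from the commit:

// sketch: top_k -> top_p -> infill -> dist, per the ordering suggested above
llama_sampler_chain_params sparams = llama_sampler_chain_default_params();
llama_sampler * chain = llama_sampler_chain_init(sparams);

llama_sampler_chain_add(chain, llama_sampler_init_top_k(40));
llama_sampler_chain_add(chain, llama_sampler_init_top_p(0.95f, 1));
llama_sampler_chain_add(chain, llama_sampler_init_infill(model));            // new in this commit
llama_sampler_chain_add(chain, llama_sampler_init_dist(LLAMA_DEFAULT_SEED)); // final pick

const llama_token id = llama_sampler_sample(chain, ctx, -1);

llama_sampler_free(chain);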

src/llama-sampling.cpp (+201)

@@ -1739,6 +1739,207 @@ struct llama_sampler * llama_sampler_init_logit_bias(
     };
 }

+// infill
+
+//#define GGML_DEBUG_SAMPLER_INFILL
+
+struct llama_sampler_infill {
+    const struct llama_vocab * vocab;
+};
+
+static const char * llama_sampler_infill_name(const struct llama_sampler * /*smpl*/) {
+    return "infill";
+}
+
+static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+    auto * ctx = (llama_sampler_infill *) smpl->ctx;
+
+    llama_sampler_softmax_impl(cur_p);
+
+#if defined(GGML_DEBUG_SAMPLER_INFILL)
+#define LOG_DBG_CUR LLAMA_LOG_DEBUG
+#else
+#define LOG_DBG_CUR(...)
+#endif
+
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
+    }
+
+    float p_txt_sum = 0.0f;
+    float p_eog_sum = 0.0f;
+
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        if (llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) {
+            p_eog_sum += cur_p->data[i].p;
+        } else {
+            p_txt_sum += cur_p->data[i].p;
+        }
+    }
+
+    const float rat = p_eog_sum == 0.0 ? INFINITY : p_txt_sum / p_eog_sum; GGML_UNUSED(rat);
+
+    LOG_DBG_CUR("%s: p_txt_sum = %.2f, p_eog_sum = %.2f, rat = %.2f, n = %zu\n", __func__, p_txt_sum, p_eog_sum, rat, cur_p->size);
+
+    if (3*p_eog_sum*cur_p->size > p_txt_sum) {
+        LOG_DBG_CUR("%s: the ratio p_txt/p_eog = %.2f is too low -> sampling EOG\n", __func__, p_txt_sum/p_eog_sum);
+
+        // keep just the EOG tokens
+        const auto size_org = cur_p->size;
+
+        cur_p->size = 0;
+
+        float p_sum = 0.0f;
+
+        for (size_t i = 0; i < size_org; ++i) {
+            if (llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) {
+                p_sum += cur_p->data[i].p;
+
+                cur_p->data[cur_p->size++] = cur_p->data[i];
+            }
+        }
+
+        // normalize probs
+        for (size_t i = 0; i < cur_p->size; ++i) {
+            cur_p->data[i].p /= p_sum;
+        }
+
+        return;
+    }
+
+    size_t n_combined = 0; GGML_UNUSED(n_combined);
+
+    // combine tokens with common prefix
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        for (size_t j = 0; j < cur_p->size; ++j) {
+            if (cur_p->data[i].logit == -INFINITY) {
+                break;
+            }
+
+            if (i == j || cur_p->data[j].logit == -INFINITY) {
+                continue;
+            }
+
+            if (llama_token_is_prefix_impl(*ctx->vocab, cur_p->data[i].id, cur_p->data[j].id)) {
+                if (cur_p->data[i].p > cur_p->data[j].p) {
+                    cur_p->data[i].p += cur_p->data[j].p;
+                    cur_p->data[j].logit = -INFINITY;
+                    cur_p->data[j].p     = 0.0f;
+                } else {
+                    cur_p->data[j].p += cur_p->data[i].p;
+                    cur_p->data[i].logit = -INFINITY;
+                    cur_p->data[i].p     = 0.0f;
+                }
+
+                n_combined++;
+            }
+        }
+    }
+
+    size_t n_non_eog = 0;
+
+    size_t size_org = cur_p->size;
+
+    float p_sum = 0.0f;
+    float thold = 0.2f;
+
+    cur_p->size = 0;
+
+    LOG_DBG_CUR("%s: n_combined = %zu, applying thold = %.3f\n", __func__, n_combined, thold);
+
+    for (size_t i = 0; i < size_org; ++i) {
+        const bool is_eog = llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id);
+
+        if (cur_p->data[i].p < thold && !is_eog) {
+            continue;
+        }
+
+        if (!is_eog) {
+            ++n_non_eog;
+        }
+
+        p_sum += cur_p->data[i].p;
+
+        // keep this token
+        cur_p->data[cur_p->size++] = cur_p->data[i];
+    }
+
+    LOG_DBG_CUR("%s: n_non_eog = %zu\n", __func__, n_non_eog);
+
+    // if no non-EOG tokens are left -> reduce cur_p to single EOT token
+    if (n_non_eog == 0) {
+        cur_p->size = 1;
+        cur_p->data[0].id    = llama_token_eot_impl(*ctx->vocab);
+        cur_p->data[0].logit = 1.0f;
+
+        return;
+    }
+
+    // normalize probs
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        cur_p->data[i].p /= p_sum;
+
+        LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
+    }
+
+    size_org = cur_p->size;
+    p_sum = 0.0f;
+    thold = 1.0/(n_non_eog + 1);
+
+    cur_p->size = 0;
+
+    LOG_DBG_CUR("%s: applying thold = %.3f\n", __func__, thold);
+
+    for (size_t i = 0; i < size_org; ++i) {
+        const bool is_eog = llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id);
+
+        if (cur_p->data[i].p < thold && !is_eog) {
+            continue;
+        }
+
+        p_sum += cur_p->data[i].p;
+
+        cur_p->data[cur_p->size++] = cur_p->data[i];
+    }
+
+    // normalize probs
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        cur_p->data[i].p /= p_sum;
+
+        LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
+    }
+
+#undef LOG_DBG_CUR
+}
+
+static struct llama_sampler * llama_sampler_infill_clone(const struct llama_sampler * smpl) {
+    const auto * ctx = (const llama_sampler_infill *) smpl->ctx;
+    return llama_sampler_init_infill_impl(*ctx->vocab);
+}
+
+static void llama_sampler_infill_free(struct llama_sampler * smpl) {
+    delete (llama_sampler_infill *) smpl->ctx;
+}
+
+static struct llama_sampler_i llama_sampler_infill_i = {
+    /* .name   = */ llama_sampler_infill_name,
+    /* .accept = */ nullptr,
+    /* .apply  = */ llama_sampler_infill_apply,
+    /* .reset  = */ nullptr,
+    /* .clone  = */ llama_sampler_infill_clone,
+    /* .free   = */ llama_sampler_infill_free,
+};
+
+struct llama_sampler * llama_sampler_init_infill_impl(
+        const struct llama_vocab & vocab) {
+    return new llama_sampler {
+        /* .iface = */ &llama_sampler_infill_i,
+        /* .ctx   = */ new llama_sampler_infill {
+            /* .vocab = */ &vocab,
+        },
+    };
+}
+
 // utils

 uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl) {
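To see step 2 of the algorithm in isolation, here is a standalone toy replay of the prefix-combining loop on the "hel"/"hell"/"hello"/"dummy" example from include/llama.h. Plain strings stand in for vocab tokens and p == 0.0f plays the role of the -INFINITY logit marker; this is an illustration, not library code:

#include <cstdio>
#include <string>
#include <vector>

struct cand { std::string text; float p; };

int main() {
    // the example distribution from the header comment
    std::vector<cand> cur = { {"hel", 0.5f}, {"hell", 0.2f}, {"hello", 0.1f}, {"dummy", 0.1f} };

    // combine tokens with a common prefix: fold the smaller prob into the larger
    for (size_t i = 0; i < cur.size(); ++i) {
        for (size_t j = 0; j < cur.size(); ++j) {
            if (i == j || cur[i].p == 0.0f || cur[j].p == 0.0f) {
                continue;
            }
            if (cur[j].text.compare(0, cur[i].text.size(), cur[i].text) == 0) { // cur[i] prefixes cur[j]
                if (cur[i].p > cur[j].p) { cur[i].p += cur[j].p; cur[j].p = 0.0f; }
                else                     { cur[j].p += cur[i].p; cur[i].p = 0.0f; }
            }
        }
    }

    for (const auto & c : cur) {
        if (c.p > 0.0f) {
            std::printf("%-6s %.1f\n", c.text.c_str(), c.p); // prints: hel 0.8, dummy 0.1
        }
    }
}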

src/llama-sampling.h (+3 −2)

@@ -4,8 +4,6 @@

 #include "llama-grammar.h"

-#include <unordered_map>
-
 struct llama_vocab;
 struct llama_grammar;

@@ -27,3 +25,6 @@ struct llama_sampler * llama_sampler_init_grammar_impl(
            const struct llama_vocab & vocab,
                          const char * grammar_str,
                          const char * grammar_root);
+
+struct llama_sampler * llama_sampler_init_infill_impl(
+        const struct llama_vocab & vocab);

src/llama-vocab.cpp (+17)

@@ -1858,6 +1858,23 @@ int32_t llama_token_to_piece_impl(const struct llama_vocab & vocab, llama_token
     return 0;
 }

+bool llama_token_is_prefix_impl(
+        const struct llama_vocab & vocab,
+                     llama_token   token0,
+                     llama_token   token1) {
+    char text_buf_0[128];
+    char text_buf_1[128];
+
+    const int32_t len0 = llama_token_to_piece_impl(vocab, token0, text_buf_0, sizeof(text_buf_0) - 1, 0, false);
+    const int32_t len1 = llama_token_to_piece_impl(vocab, token1, text_buf_1, sizeof(text_buf_1) - 1, 0, false);
+
+    if (len0 <= 0 || len1 <= 0) {
+        return false;
+    }
+
+    return len0 <= len1 && memcmp(text_buf_0, text_buf_1, len0) == 0;
+}
+
 int32_t llama_detokenize_impl(
         const struct llama_vocab & vocab,
         const llama_token * tokens,
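Semantics of the check, sketched with hypothetical token ids (whichever ids decode to these pieces in a given vocab):

// hypothetical: tok_hel decodes to "hel", tok_hello decodes to "hello"
// llama_token_is_prefix_impl(vocab, tok_hel,   tok_hello) -> true  ("hel" starts "hello")
// llama_token_is_prefix_impl(vocab, tok_hello, tok_hel)   -> false (len0 > len1)
// pieces longer than the 127-byte buffers fail to decode (len <= 0), so the check returns false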
