@@ -1745,6 +1745,9 @@ struct llama_sampler * llama_sampler_init_logit_bias(
1745
1745
1746
1746
struct llama_sampler_infill {
1747
1747
const struct llama_vocab * vocab;
1748
+
1749
+ std::vector<char > buf0;
1750
+ std::vector<char > buf1;
1748
1751
};
1749
1752
1750
1753
static const char * llama_sampler_infill_name (const struct llama_sampler * /* smpl*/ ) {
@@ -1810,27 +1813,44 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
1810
1813
size_t n_combined = 0 ; GGML_UNUSED (n_combined);
1811
1814
1812
1815
// combine tokens with common prefix
1813
- for (size_t i = 0 ; i < cur_p->size ; ++i ) {
1814
- for (size_t j = 0 ; j < cur_p->size ; ++j ) {
1815
- if (cur_p->data [i ].logit == -INFINITY) {
1816
+ for (size_t i0 = 0 ; i0 < cur_p->size ; ++i0 ) {
1817
+ for (size_t i1 = 0 ; i1 < cur_p->size ; ++i1 ) {
1818
+ if (cur_p->data [i0 ].logit == -INFINITY) {
1816
1819
break ;
1817
1820
}
1818
1821
1819
- if (i == j || cur_p->data [j ].logit == -INFINITY) {
1822
+ if (i0 == i1 || cur_p->data [i1 ].logit == -INFINITY) {
1820
1823
continue ;
1821
1824
}
1822
1825
1823
- if (llama_token_is_prefix_impl (*ctx->vocab , cur_p->data [i].id , cur_p->data [j].id )) {
1824
- if (cur_p->data [i].p > cur_p->data [j].p ) {
1825
- cur_p->data [i].p += cur_p->data [j].p ;
1826
- cur_p->data [j].logit = -INFINITY;
1827
- cur_p->data [j].p = 0 .0f ;
1828
- } else {
1829
- cur_p->data [j].p += cur_p->data [i].p ;
1830
- cur_p->data [i].logit = -INFINITY;
1831
- cur_p->data [i].p = 0 .0f ;
1826
+ int len0 = llama_token_to_piece_impl (*ctx->vocab , cur_p->data [i0].id , ctx->buf0 .data (), ctx->buf0 .size (), 0 , false );
1827
+ if (len0 < 0 ) {
1828
+ ctx->buf0 .resize (len0);
1829
+ len0 = llama_token_to_piece_impl (*ctx->vocab , cur_p->data [i0].id , ctx->buf0 .data (), ctx->buf0 .size (), 0 , false );
1830
+ assert (len0 > 0 );
1831
+ }
1832
+
1833
+ int len1 = llama_token_to_piece_impl (*ctx->vocab , cur_p->data [i1].id , ctx->buf1 .data (), ctx->buf1 .size (), 0 , false );
1834
+ if (len1 < 0 ) {
1835
+ ctx->buf1 .resize (len1);
1836
+ len1 = llama_token_to_piece_impl (*ctx->vocab , cur_p->data [i1].id , ctx->buf1 .data (), ctx->buf1 .size (), 0 , false );
1837
+ assert (len1 > 0 );
1838
+ }
1839
+
1840
+ // token i0 is a prefix of token i1
1841
+ if (len0 > 0 && len0 <= len1 && memcmp (ctx->buf0 .data (), ctx->buf1 .data (), len0) == 0 ) {
1842
+ int dst = i0;
1843
+ int src = i1;
1844
+
1845
+ // merge into the token with higher probability
1846
+ if (cur_p->data [i1].p > cur_p->data [i0].p ) {
1847
+ std::swap (dst, src);
1832
1848
}
1833
1849
1850
+ cur_p->data [dst].p += cur_p->data [src].p ;
1851
+ cur_p->data [src].logit = -INFINITY;
1852
+ cur_p->data [src].p = 0 .0f ;
1853
+
1834
1854
n_combined++;
1835
1855
}
1836
1856
}
@@ -1936,6 +1956,8 @@ struct llama_sampler * llama_sampler_init_infill_impl(
1936
1956
/* .iface = */ &llama_sampler_infill_i,
1937
1957
/* .ctx = */ new llama_sampler_infill {
1938
1958
/* .vocab = */ &vocab,
1959
+ /* .buf0 = */ std::vector<char >(512 ),
1960
+ /* .buf1 = */ std::vector<char >(512 ),
1939
1961
},
1940
1962
};
1941
1963
}
0 commit comments