@@ -1739,6 +1739,207 @@ struct llama_sampler * llama_sampler_init_logit_bias(
 };
 }
 
+// infill
+
+//#define GGML_DEBUG_SAMPLER_INFILL
+
+struct llama_sampler_infill {
+    const struct llama_vocab * vocab;
+};
+
+static const char * llama_sampler_infill_name(const struct llama_sampler * /*smpl*/) {
+    return "infill";
+}
+
+static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+    auto * ctx = (llama_sampler_infill *) smpl->ctx;
+
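+    // sort the candidates by logit (descending) and compute their softmax probabilities -
+    // the logic below operates on cur_p->data[i].p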
+    llama_sampler_softmax_impl(cur_p);
+
+#if defined(GGML_DEBUG_SAMPLER_INFILL)
+#define LOG_DBG_CUR LLAMA_LOG_DEBUG
+#else
+#define LOG_DBG_CUR(...)
+#endif
+
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
+    }
+
+    float p_txt_sum = 0.0f;
+    float p_eog_sum = 0.0f;
+
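+    // split the total probability mass between EOG (end-of-generation) tokens and regular text tokens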
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        if (llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) {
+            p_eog_sum += cur_p->data[i].p;
+        } else {
+            p_txt_sum += cur_p->data[i].p;
+        }
+    }
+
+    const float rat = p_eog_sum == 0.0 ? INFINITY : p_txt_sum / p_eog_sum; GGML_UNUSED(rat);
+
+    LOG_DBG_CUR("%s: p_txt_sum = %.2f, p_eog_sum = %.2f, rat = %.2f, n = %zu\n", __func__, p_txt_sum, p_eog_sum, rat, cur_p->size);
+
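+    // heuristic: if the EOG mass scaled by the number of candidates outweighs the text mass
+    //            (i.e. p_eog_sum > p_txt_sum/(3*cur_p->size)), keep only the EOG tokens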
+    if (3*p_eog_sum*cur_p->size > p_txt_sum) {
+        LOG_DBG_CUR("%s: the ratio p_txt/p_eog = %.2f is too low -> sampling EOG\n", __func__, p_txt_sum/p_eog_sum);
+
+        // keep just the EOG tokens
+        const auto size_org = cur_p->size;
+
+        cur_p->size = 0;
+
+        float p_sum = 0.0f;
+
+        for (size_t i = 0; i < size_org; ++i) {
+            if (llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) {
+                p_sum += cur_p->data[i].p;
+
+                cur_p->data[cur_p->size++] = cur_p->data[i];
+            }
+        }
+
+        // normalize probs
+        for (size_t i = 0; i < cur_p->size; ++i) {
+            cur_p->data[i].p /= p_sum;
+        }
+
+        return;
+    }
+
+    size_t n_combined = 0; GGML_UNUSED(n_combined);
+
+    // combine tokens with common prefix
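+    // (whichever of the two tokens has the higher probability absorbs the other's probability;
+    //  the absorbed token is masked out by setting its logit to -INFINITY)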
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        for (size_t j = 0; j < cur_p->size; ++j) {
+            if (cur_p->data[i].logit == -INFINITY) {
+                break;
+            }
+
+            if (i == j || cur_p->data[j].logit == -INFINITY) {
+                continue;
+            }
+
+            if (llama_token_is_prefix_impl(*ctx->vocab, cur_p->data[i].id, cur_p->data[j].id)) {
+                if (cur_p->data[i].p > cur_p->data[j].p) {
+                    cur_p->data[i].p += cur_p->data[j].p;
+                    cur_p->data[j].logit = -INFINITY;
+                    cur_p->data[j].p = 0.0f;
+                } else {
+                    cur_p->data[j].p += cur_p->data[i].p;
+                    cur_p->data[i].logit = -INFINITY;
+                    cur_p->data[i].p = 0.0f;
+                }
+
+                n_combined++;
+            }
+        }
+    }
+
+    size_t n_non_eog = 0;
+
+    size_t size_org = cur_p->size;
+
+    float p_sum = 0.0f;
+    float thold = 0.2f;
+
+    cur_p->size = 0;
+
+    LOG_DBG_CUR("%s: n_combined = %zu, applying thold = %.3f\n", __func__, n_combined, thold);
+
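+    // first pass: keep all EOG tokens and any non-EOG token with p >= thold,
+    //             while counting how many non-EOG tokens survive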
+    for (size_t i = 0; i < size_org; ++i) {
+        const bool is_eog = llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id);
+
+        if (cur_p->data[i].p < thold && !is_eog) {
+            continue;
+        }
+
+        if (!is_eog) {
+            ++n_non_eog;
+        }
+
+        p_sum += cur_p->data[i].p;
+
+        // keep this token
+        cur_p->data[cur_p->size++] = cur_p->data[i];
+    }
+
+    LOG_DBG_CUR("%s: n_non_eog = %zu\n", __func__, n_non_eog);
+
+    // if no non-EOG tokens are left -> reduce cur_p to single EOT token
+    if (n_non_eog == 0) {
+        cur_p->size = 1;
+        cur_p->data[0].id = llama_token_eot_impl(*ctx->vocab);
+        cur_p->data[0].logit = 1.0f;
+
+        return;
+    }
+
+    // normalize probs
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        cur_p->data[i].p /= p_sum;
+
+        LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
+    }
+
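+    // second pass: filter again with an adaptive threshold of 1/(n_non_eog + 1), still keeping all EOG tokens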
+    size_org = cur_p->size;
+    p_sum = 0.0f;
+    thold = 1.0/(n_non_eog + 1);
+
+    cur_p->size = 0;
+
+    LOG_DBG_CUR("%s: applying thold = %.3f\n", __func__, thold);
+
+    for (size_t i = 0; i < size_org; ++i) {
+        const bool is_eog = llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id);
+
+        if (cur_p->data[i].p < thold && !is_eog) {
+            continue;
+        }
+
+        p_sum += cur_p->data[i].p;
+
+        cur_p->data[cur_p->size++] = cur_p->data[i];
+    }
+
+    // normalize probs
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        cur_p->data[i].p /= p_sum;
+
+        LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
+    }
+
+#undef LOG_DBG_CUR
+}
+
+static struct llama_sampler * llama_sampler_infill_clone(const struct llama_sampler * smpl) {
+    const auto * ctx = (const llama_sampler_infill *) smpl->ctx;
+    return llama_sampler_init_infill_impl(*ctx->vocab);
+}
+
+static void llama_sampler_infill_free(struct llama_sampler * smpl) {
+    delete (llama_sampler_infill *) smpl->ctx;
+}
+
+static struct llama_sampler_i llama_sampler_infill_i = {
+    /* .name   = */ llama_sampler_infill_name,
+    /* .accept = */ nullptr,
+    /* .apply  = */ llama_sampler_infill_apply,
+    /* .reset  = */ nullptr,
+    /* .clone  = */ llama_sampler_infill_clone,
+    /* .free   = */ llama_sampler_infill_free,
+};
+
+struct llama_sampler * llama_sampler_init_infill_impl(
+        const struct llama_vocab & vocab) {
+    return new llama_sampler {
+        /* .iface = */ &llama_sampler_infill_i,
+        /* .ctx   = */ new llama_sampler_infill {
+            /* .vocab = */ &vocab,
+        },
+    };
+}
+
 
 // utils
 
 uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl) {
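
Usage note (not part of the diff): a minimal sketch of how the new sampler might be wired into a sampler chain through the public llama.h API. The llama_sampler_init_infill(model) wrapper is an assumption here (this hunk only adds the internal llama_sampler_init_infill_impl), and model/ctx stand for an already loaded model and context.

    // sketch: build a sampler chain for infill, assuming a public llama_sampler_init_infill(model) wrapper
    struct llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());

    llama_sampler_chain_add(chain, llama_sampler_init_infill(model));            // assumed public wrapper
    llama_sampler_chain_add(chain, llama_sampler_init_dist(LLAMA_DEFAULT_SEED)); // final probabilistic pick

    const llama_token id = llama_sampler_sample(chain, ctx, -1); // sample from the last decoded position
    llama_sampler_free(chain);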