Skip to content

Commit 73fa77d

Browse files
committed
llama.vim : accept/cancel suggestions
1 parent 474d0e6 commit 73fa77d

File tree

2 files changed

+155
-35
lines changed

2 files changed

+155
-35
lines changed

examples/llama.vim

+140 -24
Original file line number | Diff line number | Diff line change
@@ -6,31 +6,36 @@
66
"
77
"augroup llama_cpp
88
" autocmd!
9-
" autocmd InsertEnter * inoremap <buffer> <silent> <C-F> <Esc>:call llama#fim()<CR>
9+
" autocmd InsertEnter * inoremap <buffer> <silent> <C-F> <Esc>:call llama#fim()<CR>a
1010
"augroup END
1111
"
1212

13+
" color of the suggested text
14+
highlight llama_hint guifg=#ff772f
15+
1316
let s:default_config = {
14-
\ 'endpoint': 'http://127.0.0.1:8012/infill',
15-
\ 'prefix_lines': 32,
16-
\ 'suffix_lines': 32,
17-
\ 'n_predict': 64,
18-
\ 'n_probs': 3,
19-
\ 'temperature': 0.1,
20-
\ 'stop': ["\n"]
17+
\ 'endpoint': 'http://127.0.0.1:8012/infill',
18+
\ 'n_prefix': 32,
19+
\ 'n_suffix': 32,
20+
\ 'n_predict': 64,
21+
\ 'n_probs': 3,
22+
\ 'temperature': 0.1,
23+
\ 'stop': ["\n"]
2124
\ }
2225

2326
let g:llama_config = get(g:, 'llama_config', s:default_config)
2427

2528
function! llama#fim() abort
26-
let l:lines_prefix = getline(max([1, line('.') - g:llama_config.suffix_lines]), line('.') - 1)
27-
let l:lines_suffix = getline(line('.') + 1, min([line('$'), line('.') + g:llama_config.prefix_lines]))
29+
let l:pos_x = col('.')
30+
let l:pos_y = line('.')
31+
let l:max_y = line('$')
2832

29-
let l:cursor_col = col('.')
33+
let l:lines_prefix = getline(max([1, l:pos_y - g:llama_config.n_prefix]), l:pos_y - 1)
34+
let l:lines_suffix = getline(l:pos_y + 1, min([l:max_y, l:pos_y + g:llama_config.n_suffix]))
3035

3136
let l:line_cur = getline('.')
32-
let l:line_cur_prefix = strpart(l:line_cur, 0, l:cursor_col)
33-
let l:line_cur_suffix = strpart(l:line_cur, l:cursor_col)
37+
let l:line_cur_prefix = strpart(l:line_cur, 0, l:pos_x)
38+
let l:line_cur_suffix = strpart(l:line_cur, l:pos_x)
3439

3540
let l:prefix = ""
3641
\ . join(l:lines_prefix, "\n")
@@ -40,6 +45,7 @@ function! llama#fim() abort
4045
let l:suffix = ""
4146
\ . l:line_cur_suffix
4247
\ . join(l:lines_suffix, "\n")
48+
\ . "\n"
4349

4450
let l:request = json_encode({
4551
\ 'prompt': "",
@@ -63,21 +69,131 @@ function! llama#fim() abort
6369
\ g:llama_config.endpoint, shellescape(l:request)
6470
\ )
6571

66-
let l:response = json_decode(system(l:curl_command))
72+
let l:can_accept = v:true
73+
let s:content = []
74+
75+
let l:raw = system(l:curl_command)
76+
if l:can_accept && v:shell_error
77+
call add(s:content, "<| curl error: is the server on? |>")
78+
let l:can_accept = v:false
79+
endif
80+
81+
if l:can_accept && l:raw == ""
82+
call add(s:content, "<| empty response: is the server on? |>")
83+
let l:can_accept = v:false
84+
endif
85+
86+
" get the generated suggestion
87+
if l:can_accept
88+
let l:response = json_decode(l:raw)
89+
90+
for l:part in split(get(l:response, 'content', ''), "\n", 1)
91+
call add(s:content, l:part)
92+
endfor
93+
94+
" remove trailing new lines
95+
while len(s:content) > 0 && s:content[-1] == ""
96+
call remove(s:content, -1)
97+
endwhile
98+
endif
99+
100+
if len(s:content) == 0
101+
call add(s:content, "<| nothing to suggest |>")
102+
let l:can_accept = v:false
103+
endif
104+
105+
let s:pos_dx = len(s:content[-1])
106+
let s:content[-1] .= l:line_cur_suffix
107+
108+
" display virtual text with the suggestion
109+
let l:bufnr = bufnr('%')
110+
let s:ns_id = nvim_create_namespace('llama_virtual_text')
111+
112+
call nvim_buf_set_extmark(l:bufnr, s:ns_id, l:pos_y - 1, l:pos_x - 1, {
113+
\ 'virt_text': [[s:content[0], 'llama_hint']],
114+
\ 'virt_text_win_col': virtcol('.')
115+
\ })
116+
117+
call nvim_buf_set_extmark(l:bufnr, s:ns_id, l:pos_y - 1, 0, {
118+
\ 'virt_lines': map(s:content[1:], {idx, val -> [[val, 'llama_hint']]}),
119+
\ 'virt_text_win_col': virtcol('.')
120+
\ })
67121

68-
echom l:response
122+
" accept suggestion with Tab and reject it with any other key
123+
if l:can_accept
124+
inoremap <buffer> <Tab> <C-O>:call llama#accept_virtual_text()<CR>
125+
else
126+
inoremap <buffer> <Tab> <C-O>:call llama#cancel_virtual_text()<CR>
127+
endif
69128

70-
let l:content = []
71-
for l:part in split(get(l:response, 'content', ''), "\n", 1)
72-
call add(l:content, l:part)
129+
for l:key in range(33, 127) + [8, 27]
130+
if l:key != 0x7C
131+
if l:key == 8
132+
execute 'inoremap <buffer> <Bs> <C-O>:call llama#cancel_virtual_text()<CR><Bs>'
133+
elseif l:key == 27
134+
execute 'inoremap <buffer> <Esc> <C-O>:call llama#cancel_virtual_text()<CR><Esc>'
135+
elseif l:key == 127
136+
execute 'inoremap <buffer> <Del> <C-O>:call llama#cancel_virtual_text()<CR><Del>'
137+
else
138+
execute 'inoremap <buffer> ' . nr2char(l:key) . ' <C-O>:call llama#cancel_virtual_text()<CR>' . nr2char(l:key)
139+
endif
140+
endif
73141
endfor
74142

75-
echom l:content
143+
inoremap <buffer> <Up> <C-O>:call llama#cancel_virtual_text()<CR><Up>
144+
inoremap <buffer> <Down> <C-O>:call llama#cancel_virtual_text()<CR><Down>
145+
inoremap <buffer> <Left> <C-O>:call llama#cancel_virtual_text()<CR><Left>
146+
inoremap <buffer> <Right> <C-O>:call llama#cancel_virtual_text()<CR><Right>
147+
endfunction
148+
149+
function! llama#accept_virtual_text()
150+
let l:pos_x = col('.')
151+
let l:pos_y = line('.')
152+
153+
let l:line_cur = getline('.')
154+
155+
let l:pos0 = l:pos_x - 2
156+
157+
if l:pos_x == len(l:line_cur)
158+
let l:pos0 = l:pos_x - 1
159+
endif
160+
161+
" insert the suggestion at the cursor location
162+
call setline(l:pos_y, l:line_cur[:l:pos0] . s:content[0])
163+
if len(s:content) > 1
164+
call append(l:pos_y, s:content[1:-1])
165+
endif
76166

77-
" insert the 'content' at the current cursor location
78-
let l:content[0] = l:line_cur_prefix . l:content[0]
79-
let l:content[-1] .= l:line_cur_suffix
167+
" move the cursor to the end of the accepted text
168+
call cursor(l:pos_y + len(s:content) - 1, l:pos_x + s:pos_dx)
169+
170+
call llama#cancel_virtual_text()
171+
endfunction
172+
173+
function! llama#cancel_virtual_text()
174+
" clear the virtual text
175+
let l:bufnr = bufnr('%')
176+
call nvim_buf_clear_namespace(l:bufnr, s:ns_id, 0, -1)
177+
178+
" remove the mappings
179+
iunmap <buffer> <Tab>
180+
181+
for l:key in range(33, 127) + [8, 27]
182+
if l:key != 0x7C
183+
if l:key == 8
184+
execute 'iunmap <buffer> <Bs>'
185+
elseif l:key == 27
186+
execute 'iunmap <buffer> <Esc>'
187+
elseif l:key == 127
188+
execute 'iunmap <buffer> <Del>'
189+
else
190+
execute 'iunmap <buffer> ' . nr2char(l:key)
191+
endif
192+
endif
193+
endfor
80194

81-
call setline('.', l:content[0])
82-
call append (line('.'), l:content[1:-1])
195+
iunmap <buffer> <Up>
196+
iunmap <buffer> <Down>
197+
iunmap <buffer> <Left>
198+
iunmap <buffer> <Right>
83199
endfunction

src/llama-sampling.cpp

+15 -11
Original file line number | Diff line number | Diff line change
@@ -1724,24 +1724,28 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
17241724
}
17251725
}
17261726

1727+
// determine the token with max logit
1728+
float l_max = -INFINITY;
1729+
int i_max = -1;
1730+
for (size_t i = 0; i < cur_p->size; ++i) {
1731+
if (cur_p->data[i].logit > l_max) {
1732+
l_max = cur_p->data[i].logit;
1733+
i_max = i;
1734+
}
1735+
}
1736+
17271737
// if all probs are -INFINITY -> reduce cur_p to single EOG token
1728-
if (std::all_of(cur_p->data, cur_p->data + cur_p->size, [](const llama_token_data & td) { return td.logit == -INFINITY; })) {
1738+
if (i_max == -1) {
17291739
cur_p->size = 1;
17301740
cur_p->data[0].id = llama_token_eot_impl(*ctx->vocab);
17311741
cur_p->data[0].logit = 1.0f;
1732-
}
17331742

1734-
// resize
1735-
const auto size_org = cur_p->size;
1736-
1737-
cur_p->size = 0;
1738-
1739-
for (size_t i = 0; i < size_org; ++i) {
1740-
if (cur_p->data[i].logit != -INFINITY) {
1741-
cur_p->data[cur_p->size++] = cur_p->data[i];
1742-
}
1743+
return;
17431744
}
17441745

1746+
cur_p->size = 1;
1747+
cur_p->data[0] = cur_p->data[i_max];
1748+
17451749
for (size_t i = 0; i < cur_p->size; ++i) {
17461750
LLAMA_LOG_DEBUG("after : cur_p[%zu] = { id: %d, p: %f, logit: %f }\n", i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
17471751
}

0 commit comments

Comments
 (0)