@@ -1,4 +1,3 @@
-import pytest
 import llama_cpp
 
 MODEL = "./vendor/llama.cpp/models/ggml-vocab.bin"
@@ -15,15 +14,20 @@ def test_llama():
     assert llama.detokenize(llama.tokenize(text)) == text
 
 
-@pytest.mark.skip(reason="need to update sample mocking")
+# @pytest.mark.skip(reason="need to update sample mocking")
 def test_llama_patch(monkeypatch):
     llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True)
+    n_vocab = int(llama_cpp.llama_n_vocab(llama.ctx))
 
     ## Set up mock function
     def mock_eval(*args, **kwargs):
         return 0
+
+    def mock_get_logits(*args, **kwargs):
+        return (llama_cpp.c_float * n_vocab)(*[llama_cpp.c_float(0) for _ in range(n_vocab)])
 
     monkeypatch.setattr("llama_cpp.llama_cpp.llama_eval", mock_eval)
+    monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits)
 
     output_text = " jumps over the lazy dog."
     output_tokens = llama.tokenize(output_text.encode("utf-8"))
@@ -38,7 +42,7 @@ def mock_sample(*args, **kwargs):
         else:
             return token_eos
 
-    monkeypatch.setattr("llama_cpp.llama_cpp.llama_sample_top_p_top_k", mock_sample)
+    monkeypatch.setattr("llama_cpp.llama_cpp.llama_sample_token", mock_sample)
 
     text = "The quick brown fox"
 
@@ -97,15 +101,19 @@ def test_llama_pickle():
 
     assert llama.detokenize(llama.tokenize(text)) == text
 
-@pytest.mark.skip(reason="need to update sample mocking")
 def test_utf8(monkeypatch):
     llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True)
+    n_vocab = int(llama_cpp.llama_n_vocab(llama.ctx))
 
     ## Set up mock function
     def mock_eval(*args, **kwargs):
         return 0
 
+    def mock_get_logits(*args, **kwargs):
+        return (llama_cpp.c_float * n_vocab)(*[llama_cpp.c_float(0) for _ in range(n_vocab)])
+
     monkeypatch.setattr("llama_cpp.llama_cpp.llama_eval", mock_eval)
+    monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits)
 
     output_text = "😀"
     output_tokens = llama.tokenize(output_text.encode("utf-8"))
@@ -120,7 +128,7 @@ def mock_sample(*args, **kwargs):
         else:
             return token_eos
 
-    monkeypatch.setattr("llama_cpp.llama_cpp.llama_sample_top_p_top_k", mock_sample)
+    monkeypatch.setattr("llama_cpp.llama_cpp.llama_sample_token", mock_sample)
 
     ## Test basic completion with utf8 multibyte
     n = 0  # reset
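
For readers skimming the diff: both tests stub the C-level bindings so they can run against a vocab_only model with no real inference. llama_get_logits is replaced by a zero-filled ctypes float array of length n_vocab, and the sampler hook (now llama_sample_token, replacing the removed llama_sample_top_p_top_k) replays a scripted token list followed by EOS. Below is a minimal, self-contained sketch of that monkeypatch pattern; fake_bindings and generate() are hypothetical stand-ins for llama_cpp.llama_cpp and the library's decode loop, not code from this repository.

import ctypes
import types

fake_bindings = types.SimpleNamespace()  # hypothetical stand-in for llama_cpp.llama_cpp


def generate(n_vocab, max_tokens=8):
    # Toy decode loop that calls the (patched) bindings; token 0 plays EOS here.
    tokens = []
    for _ in range(max_tokens):
        fake_bindings.llama_eval()                 # would run the model forward
        logits = fake_bindings.llama_get_logits()  # length-n_vocab float buffer
        assert len(logits) == n_vocab
        token = fake_bindings.llama_sample_token()
        if token == 0:
            break
        tokens.append(token)
    return tokens


def test_generate_with_mocks(monkeypatch):
    n_vocab = 32
    output_tokens = [5, 6, 7]
    n = 0

    # Zero logits as a ctypes array: the same trick the diff uses for llama_get_logits.
    def mock_get_logits(*args, **kwargs):
        return (ctypes.c_float * n_vocab)(*[0.0] * n_vocab)

    # Stateful sampler: replay the scripted tokens, then the EOS token (0).
    def mock_sample(*args, **kwargs):
        nonlocal n
        n += 1
        return output_tokens[n - 1] if n <= len(output_tokens) else 0

    monkeypatch.setattr(fake_bindings, "llama_eval", lambda *a, **k: 0, raising=False)
    monkeypatch.setattr(fake_bindings, "llama_get_logits", mock_get_logits, raising=False)
    monkeypatch.setattr(fake_bindings, "llama_sample_token", mock_sample, raising=False)

    assert generate(n_vocab) == output_tokens

The real tests patch by dotted string path ("llama_cpp.llama_cpp.llama_eval", etc.), so the replacement lands on the module attribute the Llama wrapper looks up at call time and the shared library is never invoked.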