Skip to content

Commit c088a2b

Browse files
committed
Un-skip tests
1 parent bf3d0dc commit c088a2b

File tree

1 file changed

+13
-5
lines changed

1 file changed

+13
-5
lines changed

tests/test_llama.py

+13-5
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
-import pytest
 import llama_cpp

 MODEL = "./vendor/llama.cpp/models/ggml-vocab.bin"
@@ -15,15 +14,20 @@ def test_llama():
     assert llama.detokenize(llama.tokenize(text)) == text


-@pytest.mark.skip(reason="need to update sample mocking")
+# @pytest.mark.skip(reason="need to update sample mocking")
 def test_llama_patch(monkeypatch):
     llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True)
+    n_vocab = int(llama_cpp.llama_n_vocab(llama.ctx))

     ## Set up mock function
     def mock_eval(*args, **kwargs):
         return 0
+
+    def mock_get_logits(*args, **kwargs):
+        return (llama_cpp.c_float * n_vocab)(*[llama_cpp.c_float(0) for _ in range(n_vocab)])

     monkeypatch.setattr("llama_cpp.llama_cpp.llama_eval", mock_eval)
+    monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits)

     output_text = " jumps over the lazy dog."
     output_tokens = llama.tokenize(output_text.encode("utf-8"))
@@ -38,7 +42,7 @@ def mock_sample(*args, **kwargs):
         else:
             return token_eos

-    monkeypatch.setattr("llama_cpp.llama_cpp.llama_sample_top_p_top_k", mock_sample)
+    monkeypatch.setattr("llama_cpp.llama_cpp.llama_sample_token", mock_sample)

     text = "The quick brown fox"


@@ -97,15 +101,19 @@ def test_llama_pickle():

     assert llama.detokenize(llama.tokenize(text)) == text

-@pytest.mark.skip(reason="need to update sample mocking")
 def test_utf8(monkeypatch):
     llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True)
+    n_vocab = int(llama_cpp.llama_n_vocab(llama.ctx))

     ## Set up mock function
     def mock_eval(*args, **kwargs):
         return 0

+    def mock_get_logits(*args, **kwargs):
+        return (llama_cpp.c_float * n_vocab)(*[llama_cpp.c_float(0) for _ in range(n_vocab)])
+
     monkeypatch.setattr("llama_cpp.llama_cpp.llama_eval", mock_eval)
+    monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits)

     output_text = "😀"
     output_tokens = llama.tokenize(output_text.encode("utf-8"))
@@ -120,7 +128,7 @@ def mock_sample(*args, **kwargs):
         else:
             return token_eos

-    monkeypatch.setattr("llama_cpp.llama_cpp.llama_sample_top_p_top_k", mock_sample)
+    monkeypatch.setattr("llama_cpp.llama_cpp.llama_sample_token", mock_sample)

     ## Test basic completion with utf8 multibyte
     n = 0 # reset

0 commit comments

Comments
 (0)