
Commit 843b7cc

Merge branch 'main' into c0sogi/main

2 parents: 0d7d203 + bf0c603

File tree

8 files changed, +85 / -22 lines

README.md (+1 / -1)
@@ -169,7 +169,7 @@ docker run --rm -it -p 8000:8000 -v /path/to/models:/models -e MODEL=/models/ggm
 ## Low-level API
 
 The low-level API is a direct [`ctypes`](https://docs.python.org/3/library/ctypes.html) binding to the C API provided by `llama.cpp`.
-The entire lowe-level API can be found in [llama_cpp/llama_cpp.py](https://github.com/abetlen/llama-cpp-python/blob/master/llama_cpp/llama_cpp.py) and directly mirrors the C API in [llama.h](https://github.com/ggerganov/llama.cpp/blob/master/llama.h).
+The entire low-level API can be found in [llama_cpp/llama_cpp.py](https://github.com/abetlen/llama-cpp-python/blob/master/llama_cpp/llama_cpp.py) and directly mirrors the C API in [llama.h](https://github.com/ggerganov/llama.cpp/blob/master/llama.h).
 
 Below is a short example demonstrating how to use the low-level API to tokenize a prompt:
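The "short example" the README points to is not part of this diff. For context, a minimal sketch of low-level tokenization, using the binding functions that appear later in this commit (`llama_load_model_from_file`, `llama_new_context_with_model`, `llama_free`, `llama_free_model`) plus `llama_tokenize`; the model path and prompt are placeholders, and the `llama_tokenize` call shape is assumed from the bindings of this era:

```python
import llama_cpp

# Placeholder model path: any ggml model file readable by this llama.cpp build.
params = llama_cpp.llama_context_default_params()
model = llama_cpp.llama_load_model_from_file(b"./models/7b/ggml-model.bin", params)
ctx = llama_cpp.llama_new_context_with_model(model, params)

prompt = b"Q: Name the planets in the solar system? A: "
max_tokens = params.n_ctx
tokens = (llama_cpp.llama_token * int(max_tokens))()  # ctypes array that receives the token ids
n_tokens = llama_cpp.llama_tokenize(ctx, prompt, tokens, max_tokens, llama_cpp.c_bool(True))
print(list(tokens[:n_tokens]))

llama_cpp.llama_free(ctx)
llama_cpp.llama_free_model(model)
```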

llama_cpp/llama.py (+30 / -11)
@@ -29,6 +29,8 @@
 import numpy as np
 import numpy.typing as npt
 
+from .utils import suppress_stdout_stderr
+
 class BaseLlamaCache(ABC):
     """Base cache class for a llama.cpp model."""
 
@@ -227,7 +229,8 @@ def __init__(
         rope_freq_scale: float = 1.0,
         grammar: Optional[Union[str, Path]] = None,
         n_gqa: Optional[int] = None,  # (TEMPORARY) must be 8 for llama2 70b
-        rms_norm_eps: Optional[float] = None,  # (TEMPORARY)
+        rms_norm_eps: Optional[float] = None,  # (TEMPORARY)
+        mul_mat_q: Optional[bool] = None,  # (TEMPORARY)
         verbose: bool = True,
     ):
         """Load a llama.cpp model from `model_path`.
@@ -281,7 +284,9 @@ def __init__(
 
         if self.tensor_split is not None:
             FloatArray = (ctypes.c_float * len(self.tensor_split))(*self.tensor_split)
-            self._p_tensor_split = ctypes.POINTER(ctypes.c_float)(FloatArray)  # keep a reference to the array so it is not gc'd
+            self._p_tensor_split = ctypes.POINTER(ctypes.c_float)(
+                FloatArray
+            )  # keep a reference to the array so it is not gc'd
             self.params.tensor_split = self._p_tensor_split
 
         self.params.rope_freq_base = rope_freq_base
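The inline comment is the point of the reformatted statement above: a ctypes pointer does not keep the array it was built from alive, so the array is stashed on `self`. A standalone sketch of the pattern (illustrative only, using `ctypes.cast` rather than the exact construction in the diff):

```python
import ctypes

split = [0.5, 0.5]
arr = (ctypes.c_float * len(split))(*split)             # the ctypes array owns the buffer
ptr = ctypes.cast(arr, ctypes.POINTER(ctypes.c_float))  # the pointer merely borrows it

# The buffer stays valid only while `arr` is reachable (e.g. kept as an attribute);
# if only `ptr` survived, the C side could end up reading freed memory.
print(ptr[0], ptr[1])
```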
@@ -293,6 +298,9 @@ def __init__(
         if rms_norm_eps is not None:
             self.params.rms_norm_eps = rms_norm_eps
 
+        if mul_mat_q is not None:
+            self.params.mul_mat_q = mul_mat_q
+
         self.last_n_tokens_size = last_n_tokens_size
         self.n_batch = min(n_ctx, n_batch)
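With the constructor parameter added above and forwarded into `llama_context_params` here, the flag can be passed straight through the high-level API. A hedged usage sketch; the model path is a placeholder and the option only matters on builds of llama.cpp that ship the experimental kernels:

```python
from llama_cpp import Llama

llm = Llama(
    model_path="./models/7b/ggml-model.bin",  # placeholder path
    mul_mat_q=True,                           # None keeps llama.cpp's default
)
```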

@@ -310,12 +318,25 @@ def __init__(
         if not os.path.exists(model_path):
             raise ValueError(f"Model path does not exist: {model_path}")
 
-        self.model = llama_cpp.llama_load_model_from_file(
-            self.model_path.encode("utf-8"), self.params
-        )
+        if verbose:
+            self.model = llama_cpp.llama_load_model_from_file(
+                self.model_path.encode("utf-8"), self.params
+            )
+        else:
+            with suppress_stdout_stderr():
+                self.model = llama_cpp.llama_load_model_from_file(
+                    self.model_path.encode("utf-8"), self.params
+                )
         assert self.model is not None
 
-        self.ctx = llama_cpp.llama_new_context_with_model(self.model, self.params)
+        if verbose:
+            self.ctx = llama_cpp.llama_new_context_with_model(self.model, self.params)
+        else:
+            with suppress_stdout_stderr():
+                print("here")
+                self.ctx = llama_cpp.llama_new_context_with_model(
+                    self.model, self.params
+                )
 
         assert self.ctx is not None
 
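The branches above mean that `verbose=False` now also silences the loader output that llama.cpp writes directly to the C-level stdout/stderr, not just Python-side prints. A brief sketch (placeholder path):

```python
from llama_cpp import Llama

# Model loading and context creation run inside suppress_stdout_stderr(),
# so the usual llama.cpp banner and tensor listing are not printed.
llm = Llama(model_path="./models/7b/ggml-model.bin", verbose=False)
```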

@@ -986,9 +1007,7 @@ def _create_completion(
             for token in remaining_tokens:
                 token_end_position += len(self.detokenize([token]))
                 # Check if stop sequence is in the token
-                if token_end_position >= (
-                    remaining_length - first_stop_position
-                ):
+                if token_end_position >= (remaining_length - first_stop_position):
                     break
                 logprobs_or_none: Optional[CompletionLogprobs] = None
                 if logprobs is not None:
@@ -1530,10 +1549,10 @@ def create_chat_completion(
         return self._convert_text_completion_to_chat(completion)
 
     def __del__(self):
-        if self.model is not None:
+        if hasattr(self, "model") and self.model is not None:
             llama_cpp.llama_free_model(self.model)
             self.model = None
-        if self.ctx is not None:
+        if hasattr(self, "ctx") and self.ctx is not None:
             llama_cpp.llama_free(self.ctx)
             self.ctx = None
 
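The `hasattr` guards are needed because `__del__` also runs on a partially constructed object: if `__init__` raises before `self.model` or `self.ctx` is assigned (for example on a bad model path), the old destructor itself raised `AttributeError`. A minimal sketch of that failure mode, independent of this codebase:

```python
class Resource:
    def __init__(self, path: str):
        if not path:
            raise ValueError("no path given")  # raises before self.handle exists
        self.handle = open(path, "rb")

    def __del__(self):
        # Without the hasattr() guard this would raise AttributeError
        # whenever __init__ bailed out early.
        if hasattr(self, "handle") and self.handle is not None:
            self.handle.close()
            self.handle = None


try:
    Resource("")  # __init__ raises; __del__ still runs on the partial object
except ValueError:
    pass
```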

llama_cpp/llama_cpp.py (+2)
@@ -181,6 +181,7 @@ class llama_token_data_array(Structure):
 
 # // Keep the booleans together to avoid misalignment during copy-by-value.
 # bool low_vram; // if true, reduce VRAM usage at the cost of performance
+# bool mul_mat_q; // if true, use experimental mul_mat_q kernels
 # bool f16_kv; // use fp16 for KV cache
 # bool logits_all; // the llama_eval() call computes all logits, not just the last one
 # bool vocab_only; // only load the vocabulary, no weights
@@ -203,6 +204,7 @@ class llama_context_params(Structure):
         ("progress_callback", llama_progress_callback),
         ("progress_callback_user_data", c_void_p),
         ("low_vram", c_bool),
+        ("mul_mat_q", c_bool),
         ("f16_kv", c_bool),
         ("logits_all", c_bool),
         ("vocab_only", c_bool),

llama_cpp/server/app.py (+4)
@@ -103,6 +103,10 @@ class Settings(BaseSettings):
         default=None,
         description="TEMPORARY",
     )
+    mul_mat_q: Optional[bool] = Field(
+        default=None,
+        description="TEMPORARY",
+    )
 
 
 class ErrorResponse(TypedDict):
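Since `Settings` is a pydantic `BaseSettings`, the new field should also be settable from the environment when running the server; a hedged sketch assuming the default field-name-to-environment-variable mapping and that the server extras are installed:

```python
import os

os.environ["MUL_MAT_Q"] = "true"  # assumed mapping to Settings.mul_mat_q

from llama_cpp.server.app import Settings

settings = Settings(model="./models/7b/ggml-model.bin")  # placeholder model path
print(settings.mul_mat_q)  # expected: True
```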

llama_cpp/utils.py (+38)
@@ -0,0 +1,38 @@
+import os
+import sys
+
+
+class suppress_stdout_stderr(object):
+    # Oddly enough this works better than the contextlib version
+    def __enter__(self):
+        self.outnull_file = open(os.devnull, "w")
+        self.errnull_file = open(os.devnull, "w")
+
+        self.old_stdout_fileno_undup = sys.stdout.fileno()
+        self.old_stderr_fileno_undup = sys.stderr.fileno()
+
+        self.old_stdout_fileno = os.dup(sys.stdout.fileno())
+        self.old_stderr_fileno = os.dup(sys.stderr.fileno())
+
+        self.old_stdout = sys.stdout
+        self.old_stderr = sys.stderr
+
+        os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup)
+        os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup)
+
+        sys.stdout = self.outnull_file
+        sys.stderr = self.errnull_file
+        return self
+
+    def __exit__(self, *_):
+        sys.stdout = self.old_stdout
+        sys.stderr = self.old_stderr
+
+        os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup)
+        os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup)
+
+        os.close(self.old_stdout_fileno)
+        os.close(self.old_stderr_fileno)
+
+        self.outnull_file.close()
+        self.errnull_file.close()
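The class swaps the process-level file descriptors with `os.dup2` rather than only rebinding `sys.stdout`/`sys.stderr`, which is why it can also silence output written directly by the C library. A small standalone sketch of the behaviour (assumes stdout is attached to fd 1, as usual):

```python
import os

from llama_cpp.utils import suppress_stdout_stderr

with suppress_stdout_stderr():
    print("hidden")                # Python-level write: redirected to os.devnull
    os.write(1, b"also hidden\n")  # raw fd write, as C code would do: also redirected

print("visible again")             # descriptors and sys streams are restored on exit
```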

poetry.lock (+8 / -8)

Some generated files are not rendered by default.

pyproject.toml (+1 / -1)
@@ -25,7 +25,7 @@ pydantic-settings = { version = ">=2.0.1", optional = true }
 [tool.poetry.group.dev.dependencies]
 black = "^23.7.0"
 twine = "^4.0.2"
-mkdocs = "^1.4.3"
+mkdocs = "^1.5.2"
 mkdocstrings = {extras = ["python"], version = "^0.22.0"}
 mkdocs-material = "^9.1.21"
 pytest = "^7.4.0"

vendor/llama.cpp (+1 / -1, submodule pointer update; contents not shown in this view)
