
Commit bf0c603

Merge branch 'main' into fix-on-m1

2 parents: 9f499af + 36041c8

File tree: 7 files changed (+83 −22 lines)

README.md (+1 −1)

@@ -169,7 +169,7 @@ docker run --rm -it -p 8000:8000 -v /path/to/models:/models -e MODEL=/models/ggm
 ## Low-level API

 The low-level API is a direct [`ctypes`](https://docs.python.org/3/library/ctypes.html) binding to the C API provided by `llama.cpp`.
-The entire lowe-level API can be found in [llama_cpp/llama_cpp.py](https://github.com/abetlen/llama-cpp-python/blob/master/llama_cpp/llama_cpp.py) and directly mirrors the C API in [llama.h](https://github.com/ggerganov/llama.cpp/blob/master/llama.h).
+The entire low-level API can be found in [llama_cpp/llama_cpp.py](https://github.com/abetlen/llama-cpp-python/blob/master/llama_cpp/llama_cpp.py) and directly mirrors the C API in [llama.h](https://github.com/ggerganov/llama.cpp/blob/master/llama.h).

 Below is a short example demonstrating how to use the low-level API to tokenize a prompt:

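The tokenization example referenced in that last context line lies outside the hunk, so it is not shown here. As a rough, hypothetical sketch of what such a low-level call looks like: the snippet below assumes the era-appropriate `llama_tokenize(ctx, text, tokens, n_max_tokens, add_bos)` signature, a placeholder GGML model path, and `llama_context_default_params`; only `llama_load_model_from_file`, `llama_new_context_with_model`, `llama_free`, and `llama_free_model` are confirmed by this commit, so check the rest against llama_cpp/llama_cpp.py for your version.

    # Hypothetical low-level tokenization sketch -- verify names and signatures
    # against llama_cpp/llama_cpp.py before relying on it.
    import ctypes
    import llama_cpp

    # Depending on the llama.cpp version, llama_cpp.llama_backend_init(False)
    # may need to be called once before loading a model.
    params = llama_cpp.llama_context_default_params()
    model = llama_cpp.llama_load_model_from_file(
        b"./models/7B/ggml-model.bin", params  # placeholder path; bytes for char *
    )
    ctx = llama_cpp.llama_new_context_with_model(model, params)

    max_tokens = params.n_ctx
    tokens = (llama_cpp.llama_token * max_tokens)()  # ctypes array to receive token ids
    n_tokens = llama_cpp.llama_tokenize(
        ctx,
        b"Q: Name the planets in the solar system? A: ",
        tokens,
        max_tokens,
        ctypes.c_bool(True),  # add_bos
    )
    print(tokens[:n_tokens])

    llama_cpp.llama_free(ctx)
    llama_cpp.llama_free_model(model)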

llama_cpp/llama.py (+30 −11)

@@ -27,6 +27,8 @@
 import numpy as np
 import numpy.typing as npt

+from .utils import suppress_stdout_stderr
+
 class BaseLlamaCache(ABC):
     """Base cache class for a llama.cpp model."""

@@ -224,7 +226,8 @@ def __init__(
         rope_freq_base: float = 10000.0,
         rope_freq_scale: float = 1.0,
         n_gqa: Optional[int] = None,  # (TEMPORARY) must be 8 for llama2 70b
-        rms_norm_eps: Optional[float] = None, # (TEMPORARY)
+        rms_norm_eps: Optional[float] = None,  # (TEMPORARY)
+        mul_mat_q: Optional[bool] = None,  # (TEMPORARY)
         verbose: bool = True,
     ):
         """Load a llama.cpp model from `model_path`.
@@ -277,7 +280,9 @@ def __init__(

         if self.tensor_split is not None:
             FloatArray = (ctypes.c_float * len(self.tensor_split))(*self.tensor_split)
-            self._p_tensor_split = ctypes.POINTER(ctypes.c_float)(FloatArray) # keep a reference to the array so it is not gc'd
+            self._p_tensor_split = ctypes.POINTER(ctypes.c_float)(
+                FloatArray
+            )  # keep a reference to the array so it is not gc'd
             self.params.tensor_split = self._p_tensor_split

         self.params.rope_freq_base = rope_freq_base
@@ -289,6 +294,9 @@ def __init__(
         if rms_norm_eps is not None:
             self.params.rms_norm_eps = rms_norm_eps

+        if mul_mat_q is not None:
+            self.params.mul_mat_q = mul_mat_q
+
         self.last_n_tokens_size = last_n_tokens_size
         self.n_batch = min(n_ctx, n_batch)

@@ -306,12 +314,25 @@ def __init__(
         if not os.path.exists(model_path):
             raise ValueError(f"Model path does not exist: {model_path}")

-        self.model = llama_cpp.llama_load_model_from_file(
-            self.model_path.encode("utf-8"), self.params
-        )
+        if verbose:
+            self.model = llama_cpp.llama_load_model_from_file(
+                self.model_path.encode("utf-8"), self.params
+            )
+        else:
+            with suppress_stdout_stderr():
+                self.model = llama_cpp.llama_load_model_from_file(
+                    self.model_path.encode("utf-8"), self.params
+                )
         assert self.model is not None

-        self.ctx = llama_cpp.llama_new_context_with_model(self.model, self.params)
+        if verbose:
+            self.ctx = llama_cpp.llama_new_context_with_model(self.model, self.params)
+        else:
+            with suppress_stdout_stderr():
+                print("here")
+                self.ctx = llama_cpp.llama_new_context_with_model(
+                    self.model, self.params
+                )

         assert self.ctx is not None

@@ -959,9 +980,7 @@ def _create_completion(
            for token in remaining_tokens:
                token_end_position += len(self.detokenize([token]))
                # Check if stop sequence is in the token
-                if token_end_position >= (
-                    remaining_length - first_stop_position
-                ):
+                if token_end_position >= (remaining_length - first_stop_position):
                    break
                logprobs_or_none: Optional[CompletionLogprobs] = None
                if logprobs is not None:
@@ -1503,10 +1522,10 @@ def create_chat_completion(
         return self._convert_text_completion_to_chat(completion)

     def __del__(self):
-        if self.model is not None:
+        if hasattr(self, "model") and self.model is not None:
             llama_cpp.llama_free_model(self.model)
             self.model = None
-        if self.ctx is not None:
+        if hasattr(self, "ctx") and self.ctx is not None:
             llama_cpp.llama_free(self.ctx)
             self.ctx = None

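Taken together, the llama.py changes add a temporary `mul_mat_q` pass-through to `llama_context_params`, route model and context creation through `suppress_stdout_stderr()` when `verbose=False`, and guard `__del__` with `hasattr` so a constructor that fails before setting `self.model`/`self.ctx` no longer raises `AttributeError` during cleanup. A minimal usage sketch of the two new knobs (the model path is a placeholder and `mul_mat_q=True` is only an illustrative value, not a recommendation from this commit):

    from llama_cpp import Llama

    llm = Llama(
        model_path="./models/7B/ggml-model.bin",  # placeholder path
        mul_mat_q=True,   # (TEMPORARY) forwarded to self.params.mul_mat_q when not None
        verbose=False,    # model/context creation now runs inside suppress_stdout_stderr()
    )
    out = llm("Q: Name the planets in the solar system? A: ", max_tokens=32)
    print(out["choices"][0]["text"])

Note that the non-verbose context branch still contains a stray debug `print("here")`, which is itself swallowed by the redirection.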

llama_cpp/server/app.py (+4)

@@ -103,6 +103,10 @@ class Settings(BaseSettings):
         default=None,
         description="TEMPORARY",
     )
+    mul_mat_q: Optional[bool] = Field(
+        default=None,
+        description="TEMPORARY",
+    )


 class ErrorResponse(TypedDict):
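Since `Settings` is a pydantic `BaseSettings` model, the new field should be settable either directly or through the environment, which is how the server is typically configured. The sketch below assumes pydantic's default case-insensitive env-var matching and the existing required `model` field for the model path; both are assumptions, not shown in this diff.

    import os
    from llama_cpp.server.app import Settings

    # Direct construction (illustrative values; the model path is a placeholder).
    settings = Settings(model="./models/7B/ggml-model.bin", mul_mat_q=True)

    # Or via environment variables, relying on BaseSettings' case-insensitive
    # field matching -- expected to behave the same as the direct form above.
    os.environ["MODEL"] = "./models/7B/ggml-model.bin"
    os.environ["MUL_MAT_Q"] = "true"
    settings = Settings()
    print(settings.mul_mat_q)  # -> True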

llama_cpp/utils.py (+38, new file)

@@ -0,0 +1,38 @@
+import os
+import sys
+
+
+class suppress_stdout_stderr(object):
+    # Oddly enough this works better than the contextlib version
+    def __enter__(self):
+        self.outnull_file = open(os.devnull, "w")
+        self.errnull_file = open(os.devnull, "w")
+
+        self.old_stdout_fileno_undup = sys.stdout.fileno()
+        self.old_stderr_fileno_undup = sys.stderr.fileno()
+
+        self.old_stdout_fileno = os.dup(sys.stdout.fileno())
+        self.old_stderr_fileno = os.dup(sys.stderr.fileno())
+
+        self.old_stdout = sys.stdout
+        self.old_stderr = sys.stderr
+
+        os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup)
+        os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup)
+
+        sys.stdout = self.outnull_file
+        sys.stderr = self.errnull_file
+        return self
+
+    def __exit__(self, *_):
+        sys.stdout = self.old_stdout
+        sys.stderr = self.old_stderr
+
+        os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup)
+        os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup)
+
+        os.close(self.old_stdout_fileno)
+        os.close(self.old_stderr_fileno)
+
+        self.outnull_file.close()
+        self.errnull_file.close()
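The new helper redirects at the file-descriptor level with `os.dup2` rather than only swapping `sys.stdout`/`sys.stderr`, which is why it can also silence output written directly to fds 1 and 2 by the native llama.cpp code, something `contextlib.redirect_stdout` cannot do. A quick standalone sketch of the intended behaviour (not taken from this commit):

    import os
    from llama_cpp.utils import suppress_stdout_stderr

    with suppress_stdout_stderr():
        print("Python-level output")        # swallowed via the sys.stdout swap
        os.write(1, b"fd-level output\n")   # swallowed via the dup2 redirection,
                                            # the same path native llama.cpp logs take

    print("visible again")  # stdout/stderr restored in __exit__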

poetry.lock (+8 −8)

(Generated lockfile; diff not rendered.)

pyproject.toml (+1 −1)

@@ -25,7 +25,7 @@ pydantic-settings = { version = ">=2.0.1", optional = true }
 [tool.poetry.group.dev.dependencies]
 black = "^23.7.0"
 twine = "^4.0.2"
-mkdocs = "^1.4.3"
+mkdocs = "^1.5.2"
 mkdocstrings = {extras = ["python"], version = "^0.22.0"}
 mkdocs-material = "^9.1.21"
 pytest = "^7.4.0"

vendor/llama.cpp (+1 −1)

(Submodule pointer update; diff not shown.)