Merge branch 'main' into c0sogi/main
This commit is contained in:
commit
843b7ccd90
8 changed files with 85 additions and 22 deletions
|
@ -169,7 +169,7 @@ docker run --rm -it -p 8000:8000 -v /path/to/models:/models -e MODEL=/models/ggm
|
|||
## Low-level API
|
||||
|
||||
The low-level API is a direct [`ctypes`](https://docs.python.org/3/library/ctypes.html) binding to the C API provided by `llama.cpp`.
|
||||
The entire lowe-level API can be found in [llama_cpp/llama_cpp.py](https://github.com/abetlen/llama-cpp-python/blob/master/llama_cpp/llama_cpp.py) and directly mirrors the C API in [llama.h](https://github.com/ggerganov/llama.cpp/blob/master/llama.h).
|
||||
The entire low-level API can be found in [llama_cpp/llama_cpp.py](https://github.com/abetlen/llama-cpp-python/blob/master/llama_cpp/llama_cpp.py) and directly mirrors the C API in [llama.h](https://github.com/ggerganov/llama.cpp/blob/master/llama.h).
|
||||
|
||||
Below is a short example demonstrating how to use the low-level API to tokenize a prompt:
|
||||
|
||||
|
|
|
@ -29,6 +29,8 @@ from .llama_grammar import LlamaGrammar
|
|||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
|
||||
from .utils import suppress_stdout_stderr
|
||||
|
||||
class BaseLlamaCache(ABC):
|
||||
"""Base cache class for a llama.cpp model."""
|
||||
|
||||
|
@ -227,7 +229,8 @@ class Llama:
|
|||
rope_freq_scale: float = 1.0,
|
||||
grammar: Optional[Union[str, Path]] = None,
|
||||
n_gqa: Optional[int] = None, # (TEMPORARY) must be 8 for llama2 70b
|
||||
rms_norm_eps: Optional[float] = None, # (TEMPORARY)
|
||||
rms_norm_eps: Optional[float] = None, # (TEMPORARY)
|
||||
mul_mat_q: Optional(bool) = None, # (TEMPORARY)
|
||||
verbose: bool = True,
|
||||
):
|
||||
"""Load a llama.cpp model from `model_path`.
|
||||
|
@ -281,7 +284,9 @@ class Llama:
|
|||
|
||||
if self.tensor_split is not None:
|
||||
FloatArray = (ctypes.c_float * len(self.tensor_split))(*self.tensor_split)
|
||||
self._p_tensor_split = ctypes.POINTER(ctypes.c_float)(FloatArray) # keep a reference to the array so it is not gc'd
|
||||
self._p_tensor_split = ctypes.POINTER(ctypes.c_float)(
|
||||
FloatArray
|
||||
) # keep a reference to the array so it is not gc'd
|
||||
self.params.tensor_split = self._p_tensor_split
|
||||
|
||||
self.params.rope_freq_base = rope_freq_base
|
||||
|
@ -293,6 +298,9 @@ class Llama:
|
|||
if rms_norm_eps is not None:
|
||||
self.params.rms_norm_eps = rms_norm_eps
|
||||
|
||||
if mul_mat_q is not None:
|
||||
self.params.mul_mat_q = mul_mat_q
|
||||
|
||||
self.last_n_tokens_size = last_n_tokens_size
|
||||
self.n_batch = min(n_ctx, n_batch)
|
||||
|
||||
|
@ -310,12 +318,25 @@ class Llama:
|
|||
if not os.path.exists(model_path):
|
||||
raise ValueError(f"Model path does not exist: {model_path}")
|
||||
|
||||
self.model = llama_cpp.llama_load_model_from_file(
|
||||
self.model_path.encode("utf-8"), self.params
|
||||
)
|
||||
if verbose:
|
||||
self.model = llama_cpp.llama_load_model_from_file(
|
||||
self.model_path.encode("utf-8"), self.params
|
||||
)
|
||||
else:
|
||||
with suppress_stdout_stderr():
|
||||
self.model = llama_cpp.llama_load_model_from_file(
|
||||
self.model_path.encode("utf-8"), self.params
|
||||
)
|
||||
assert self.model is not None
|
||||
|
||||
self.ctx = llama_cpp.llama_new_context_with_model(self.model, self.params)
|
||||
if verbose:
|
||||
self.ctx = llama_cpp.llama_new_context_with_model(self.model, self.params)
|
||||
else:
|
||||
with suppress_stdout_stderr():
|
||||
print("here")
|
||||
self.ctx = llama_cpp.llama_new_context_with_model(
|
||||
self.model, self.params
|
||||
)
|
||||
|
||||
assert self.ctx is not None
|
||||
|
||||
|
@ -986,9 +1007,7 @@ class Llama:
|
|||
for token in remaining_tokens:
|
||||
token_end_position += len(self.detokenize([token]))
|
||||
# Check if stop sequence is in the token
|
||||
if token_end_position >= (
|
||||
remaining_length - first_stop_position
|
||||
):
|
||||
if token_end_position >= (remaining_length - first_stop_position):
|
||||
break
|
||||
logprobs_or_none: Optional[CompletionLogprobs] = None
|
||||
if logprobs is not None:
|
||||
|
@ -1530,10 +1549,10 @@ class Llama:
|
|||
return self._convert_text_completion_to_chat(completion)
|
||||
|
||||
def __del__(self):
|
||||
if self.model is not None:
|
||||
if hasattr(self, "model") and self.model is not None:
|
||||
llama_cpp.llama_free_model(self.model)
|
||||
self.model = None
|
||||
if self.ctx is not None:
|
||||
if hasattr(self, "ctx") and self.ctx is not None:
|
||||
llama_cpp.llama_free(self.ctx)
|
||||
self.ctx = None
|
||||
|
||||
|
|
|
@ -181,6 +181,7 @@ llama_progress_callback = ctypes.CFUNCTYPE(None, c_float, c_void_p)
|
|||
|
||||
# // Keep the booleans together to avoid misalignment during copy-by-value.
|
||||
# bool low_vram; // if true, reduce VRAM usage at the cost of performance
|
||||
# bool mul_mat_q; // if true, use experimental mul_mat_q kernels
|
||||
# bool f16_kv; // use fp16 for KV cache
|
||||
# bool logits_all; // the llama_eval() call computes all logits, not just the last one
|
||||
# bool vocab_only; // only load the vocabulary, no weights
|
||||
|
@ -203,6 +204,7 @@ class llama_context_params(Structure):
|
|||
("progress_callback", llama_progress_callback),
|
||||
("progress_callback_user_data", c_void_p),
|
||||
("low_vram", c_bool),
|
||||
("mul_mat_q", c_bool),
|
||||
("f16_kv", c_bool),
|
||||
("logits_all", c_bool),
|
||||
("vocab_only", c_bool),
|
||||
|
|
|
@ -103,6 +103,10 @@ class Settings(BaseSettings):
|
|||
default=None,
|
||||
description="TEMPORARY",
|
||||
)
|
||||
mul_mat_q: Optional[bool] = Field(
|
||||
default=None,
|
||||
description="TEMPORARY",
|
||||
)
|
||||
|
||||
|
||||
class ErrorResponse(TypedDict):
|
||||
|
|
38
llama_cpp/utils.py
Normal file
38
llama_cpp/utils.py
Normal file
|
@ -0,0 +1,38 @@
|
|||
import os
|
||||
import sys
|
||||
|
||||
|
||||
class suppress_stdout_stderr(object):
|
||||
# Oddly enough this works better than the contextlib version
|
||||
def __enter__(self):
|
||||
self.outnull_file = open(os.devnull, "w")
|
||||
self.errnull_file = open(os.devnull, "w")
|
||||
|
||||
self.old_stdout_fileno_undup = sys.stdout.fileno()
|
||||
self.old_stderr_fileno_undup = sys.stderr.fileno()
|
||||
|
||||
self.old_stdout_fileno = os.dup(sys.stdout.fileno())
|
||||
self.old_stderr_fileno = os.dup(sys.stderr.fileno())
|
||||
|
||||
self.old_stdout = sys.stdout
|
||||
self.old_stderr = sys.stderr
|
||||
|
||||
os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup)
|
||||
os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup)
|
||||
|
||||
sys.stdout = self.outnull_file
|
||||
sys.stderr = self.errnull_file
|
||||
return self
|
||||
|
||||
def __exit__(self, *_):
|
||||
sys.stdout = self.old_stdout
|
||||
sys.stderr = self.old_stderr
|
||||
|
||||
os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup)
|
||||
os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup)
|
||||
|
||||
os.close(self.old_stdout_fileno)
|
||||
os.close(self.old_stderr_fileno)
|
||||
|
||||
self.outnull_file.close()
|
||||
self.errnull_file.close()
|
16
poetry.lock
generated
16
poetry.lock
generated
|
@ -384,17 +384,17 @@ test = ["pytest (>=6)"]
|
|||
|
||||
[[package]]
|
||||
name = "fastapi"
|
||||
version = "0.100.1"
|
||||
version = "0.101.0"
|
||||
description = "FastAPI framework, high performance, easy to learn, fast to code, ready for production"
|
||||
optional = true
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "fastapi-0.100.1-py3-none-any.whl", hash = "sha256:ec6dd52bfc4eff3063cfcd0713b43c87640fefb2687bbbe3d8a08d94049cdf32"},
|
||||
{file = "fastapi-0.100.1.tar.gz", hash = "sha256:522700d7a469e4a973d92321ab93312448fbe20fca9c8da97effc7e7bc56df23"},
|
||||
{file = "fastapi-0.101.0-py3-none-any.whl", hash = "sha256:494eb3494d89e8079c20859d7ca695f66eaccc40f46fe8c75ab6186d15f05ffd"},
|
||||
{file = "fastapi-0.101.0.tar.gz", hash = "sha256:ca2ae65fe42f6a34b5cf6c994337149154b1b400c39809d7b2dccdceb5ae77af"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
pydantic = ">=1.7.4,<1.8 || >1.8,<1.8.1 || >1.8.1,<2.0.0 || >2.0.0,<2.0.1 || >2.0.1,<3.0.0"
|
||||
pydantic = ">=1.7.4,<1.8 || >1.8,<1.8.1 || >1.8.1,<2.0.0 || >2.0.0,<2.0.1 || >2.0.1,<2.1.0 || >2.1.0,<3.0.0"
|
||||
starlette = ">=0.27.0,<0.28.0"
|
||||
typing-extensions = ">=4.5.0"
|
||||
|
||||
|
@ -744,13 +744,13 @@ files = [
|
|||
|
||||
[[package]]
|
||||
name = "mkdocs"
|
||||
version = "1.5.1"
|
||||
version = "1.5.2"
|
||||
description = "Project documentation with Markdown."
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "mkdocs-1.5.1-py3-none-any.whl", hash = "sha256:67e889f8d8ba1fe5decdfc59f5f8f21d6a8925a129339e93dede303bdea03a98"},
|
||||
{file = "mkdocs-1.5.1.tar.gz", hash = "sha256:f2f323c62fffdf1b71b84849e39aef56d6852b3f0a5571552bca32cefc650209"},
|
||||
{file = "mkdocs-1.5.2-py3-none-any.whl", hash = "sha256:60a62538519c2e96fe8426654a67ee177350451616118a41596ae7c876bb7eac"},
|
||||
{file = "mkdocs-1.5.2.tar.gz", hash = "sha256:70d0da09c26cff288852471be03c23f0f521fc15cf16ac89c7a3bfb9ae8d24f9"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
|
@ -1757,4 +1757,4 @@ server = ["fastapi", "pydantic-settings", "sse-starlette", "uvicorn"]
|
|||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.8.1"
|
||||
content-hash = "6718d680fa89f9518a232c1110ba43958d3e21c54c4dbd9129effa4f40a02b81"
|
||||
content-hash = "4bfb67dfb72b02c845376211f7f958b2ece8c985944fbd03d246c858e846ddf6"
|
||||
|
|
|
@ -25,7 +25,7 @@ pydantic-settings = { version = ">=2.0.1", optional = true }
|
|||
[tool.poetry.group.dev.dependencies]
|
||||
black = "^23.7.0"
|
||||
twine = "^4.0.2"
|
||||
mkdocs = "^1.4.3"
|
||||
mkdocs = "^1.5.2"
|
||||
mkdocstrings = {extras = ["python"], version = "^0.22.0"}
|
||||
mkdocs-material = "^9.1.21"
|
||||
pytest = "^7.4.0"
|
||||
|
|
2
vendor/llama.cpp
vendored
2
vendor/llama.cpp
vendored
|
@ -1 +1 @@
|
|||
Subproject commit 41c674161fb2459bdf7806d1eebead15bc5d046e
|
||||
Subproject commit f5bfea0580e417f99850d5456ca541d871a3e48c
|
Loading…
Reference in a new issue