Merge branch 'main' into fix-on-m1
This commit is contained in:
commit
bf0c603c51
7 changed files with 83 additions and 22 deletions
|
@ -169,7 +169,7 @@ docker run --rm -it -p 8000:8000 -v /path/to/models:/models -e MODEL=/models/ggm
|
||||||
## Low-level API
|
## Low-level API
|
||||||
|
|
||||||
The low-level API is a direct [`ctypes`](https://docs.python.org/3/library/ctypes.html) binding to the C API provided by `llama.cpp`.
|
The low-level API is a direct [`ctypes`](https://docs.python.org/3/library/ctypes.html) binding to the C API provided by `llama.cpp`.
|
||||||
The entire lowe-level API can be found in [llama_cpp/llama_cpp.py](https://github.com/abetlen/llama-cpp-python/blob/master/llama_cpp/llama_cpp.py) and directly mirrors the C API in [llama.h](https://github.com/ggerganov/llama.cpp/blob/master/llama.h).
|
The entire low-level API can be found in [llama_cpp/llama_cpp.py](https://github.com/abetlen/llama-cpp-python/blob/master/llama_cpp/llama_cpp.py) and directly mirrors the C API in [llama.h](https://github.com/ggerganov/llama.cpp/blob/master/llama.h).
|
||||||
|
|
||||||
Below is a short example demonstrating how to use the low-level API to tokenize a prompt:
|
Below is a short example demonstrating how to use the low-level API to tokenize a prompt:
|
||||||
|
|
||||||
|
|
|
@ -27,6 +27,8 @@ from .llama_types import *
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import numpy.typing as npt
|
import numpy.typing as npt
|
||||||
|
|
||||||
|
from .utils import suppress_stdout_stderr
|
||||||
|
|
||||||
class BaseLlamaCache(ABC):
|
class BaseLlamaCache(ABC):
|
||||||
"""Base cache class for a llama.cpp model."""
|
"""Base cache class for a llama.cpp model."""
|
||||||
|
|
||||||
|
@ -225,6 +227,7 @@ class Llama:
|
||||||
rope_freq_scale: float = 1.0,
|
rope_freq_scale: float = 1.0,
|
||||||
n_gqa: Optional[int] = None, # (TEMPORARY) must be 8 for llama2 70b
|
n_gqa: Optional[int] = None, # (TEMPORARY) must be 8 for llama2 70b
|
||||||
rms_norm_eps: Optional[float] = None, # (TEMPORARY)
|
rms_norm_eps: Optional[float] = None, # (TEMPORARY)
|
||||||
|
mul_mat_q: Optional(bool) = None, # (TEMPORARY)
|
||||||
verbose: bool = True,
|
verbose: bool = True,
|
||||||
):
|
):
|
||||||
"""Load a llama.cpp model from `model_path`.
|
"""Load a llama.cpp model from `model_path`.
|
||||||
|
@ -277,7 +280,9 @@ class Llama:
|
||||||
|
|
||||||
if self.tensor_split is not None:
|
if self.tensor_split is not None:
|
||||||
FloatArray = (ctypes.c_float * len(self.tensor_split))(*self.tensor_split)
|
FloatArray = (ctypes.c_float * len(self.tensor_split))(*self.tensor_split)
|
||||||
self._p_tensor_split = ctypes.POINTER(ctypes.c_float)(FloatArray) # keep a reference to the array so it is not gc'd
|
self._p_tensor_split = ctypes.POINTER(ctypes.c_float)(
|
||||||
|
FloatArray
|
||||||
|
) # keep a reference to the array so it is not gc'd
|
||||||
self.params.tensor_split = self._p_tensor_split
|
self.params.tensor_split = self._p_tensor_split
|
||||||
|
|
||||||
self.params.rope_freq_base = rope_freq_base
|
self.params.rope_freq_base = rope_freq_base
|
||||||
|
@ -289,6 +294,9 @@ class Llama:
|
||||||
if rms_norm_eps is not None:
|
if rms_norm_eps is not None:
|
||||||
self.params.rms_norm_eps = rms_norm_eps
|
self.params.rms_norm_eps = rms_norm_eps
|
||||||
|
|
||||||
|
if mul_mat_q is not None:
|
||||||
|
self.params.mul_mat_q = mul_mat_q
|
||||||
|
|
||||||
self.last_n_tokens_size = last_n_tokens_size
|
self.last_n_tokens_size = last_n_tokens_size
|
||||||
self.n_batch = min(n_ctx, n_batch)
|
self.n_batch = min(n_ctx, n_batch)
|
||||||
|
|
||||||
|
@ -306,12 +314,25 @@ class Llama:
|
||||||
if not os.path.exists(model_path):
|
if not os.path.exists(model_path):
|
||||||
raise ValueError(f"Model path does not exist: {model_path}")
|
raise ValueError(f"Model path does not exist: {model_path}")
|
||||||
|
|
||||||
|
if verbose:
|
||||||
|
self.model = llama_cpp.llama_load_model_from_file(
|
||||||
|
self.model_path.encode("utf-8"), self.params
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
with suppress_stdout_stderr():
|
||||||
self.model = llama_cpp.llama_load_model_from_file(
|
self.model = llama_cpp.llama_load_model_from_file(
|
||||||
self.model_path.encode("utf-8"), self.params
|
self.model_path.encode("utf-8"), self.params
|
||||||
)
|
)
|
||||||
assert self.model is not None
|
assert self.model is not None
|
||||||
|
|
||||||
|
if verbose:
|
||||||
self.ctx = llama_cpp.llama_new_context_with_model(self.model, self.params)
|
self.ctx = llama_cpp.llama_new_context_with_model(self.model, self.params)
|
||||||
|
else:
|
||||||
|
with suppress_stdout_stderr():
|
||||||
|
print("here")
|
||||||
|
self.ctx = llama_cpp.llama_new_context_with_model(
|
||||||
|
self.model, self.params
|
||||||
|
)
|
||||||
|
|
||||||
assert self.ctx is not None
|
assert self.ctx is not None
|
||||||
|
|
||||||
|
@ -959,9 +980,7 @@ class Llama:
|
||||||
for token in remaining_tokens:
|
for token in remaining_tokens:
|
||||||
token_end_position += len(self.detokenize([token]))
|
token_end_position += len(self.detokenize([token]))
|
||||||
# Check if stop sequence is in the token
|
# Check if stop sequence is in the token
|
||||||
if token_end_position >= (
|
if token_end_position >= (remaining_length - first_stop_position):
|
||||||
remaining_length - first_stop_position
|
|
||||||
):
|
|
||||||
break
|
break
|
||||||
logprobs_or_none: Optional[CompletionLogprobs] = None
|
logprobs_or_none: Optional[CompletionLogprobs] = None
|
||||||
if logprobs is not None:
|
if logprobs is not None:
|
||||||
|
@ -1503,10 +1522,10 @@ class Llama:
|
||||||
return self._convert_text_completion_to_chat(completion)
|
return self._convert_text_completion_to_chat(completion)
|
||||||
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
if self.model is not None:
|
if hasattr(self, "model") and self.model is not None:
|
||||||
llama_cpp.llama_free_model(self.model)
|
llama_cpp.llama_free_model(self.model)
|
||||||
self.model = None
|
self.model = None
|
||||||
if self.ctx is not None:
|
if hasattr(self, "ctx") and self.ctx is not None:
|
||||||
llama_cpp.llama_free(self.ctx)
|
llama_cpp.llama_free(self.ctx)
|
||||||
self.ctx = None
|
self.ctx = None
|
||||||
|
|
||||||
|
|
|
@ -103,6 +103,10 @@ class Settings(BaseSettings):
|
||||||
default=None,
|
default=None,
|
||||||
description="TEMPORARY",
|
description="TEMPORARY",
|
||||||
)
|
)
|
||||||
|
mul_mat_q: Optional[bool] = Field(
|
||||||
|
default=None,
|
||||||
|
description="TEMPORARY",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ErrorResponse(TypedDict):
|
class ErrorResponse(TypedDict):
|
||||||
|
|
38
llama_cpp/utils.py
Normal file
38
llama_cpp/utils.py
Normal file
|
@ -0,0 +1,38 @@
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
|
||||||
|
class suppress_stdout_stderr(object):
|
||||||
|
# Oddly enough this works better than the contextlib version
|
||||||
|
def __enter__(self):
|
||||||
|
self.outnull_file = open(os.devnull, "w")
|
||||||
|
self.errnull_file = open(os.devnull, "w")
|
||||||
|
|
||||||
|
self.old_stdout_fileno_undup = sys.stdout.fileno()
|
||||||
|
self.old_stderr_fileno_undup = sys.stderr.fileno()
|
||||||
|
|
||||||
|
self.old_stdout_fileno = os.dup(sys.stdout.fileno())
|
||||||
|
self.old_stderr_fileno = os.dup(sys.stderr.fileno())
|
||||||
|
|
||||||
|
self.old_stdout = sys.stdout
|
||||||
|
self.old_stderr = sys.stderr
|
||||||
|
|
||||||
|
os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup)
|
||||||
|
os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup)
|
||||||
|
|
||||||
|
sys.stdout = self.outnull_file
|
||||||
|
sys.stderr = self.errnull_file
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, *_):
|
||||||
|
sys.stdout = self.old_stdout
|
||||||
|
sys.stderr = self.old_stderr
|
||||||
|
|
||||||
|
os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup)
|
||||||
|
os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup)
|
||||||
|
|
||||||
|
os.close(self.old_stdout_fileno)
|
||||||
|
os.close(self.old_stderr_fileno)
|
||||||
|
|
||||||
|
self.outnull_file.close()
|
||||||
|
self.errnull_file.close()
|
16
poetry.lock
generated
16
poetry.lock
generated
|
@ -384,17 +384,17 @@ test = ["pytest (>=6)"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "fastapi"
|
name = "fastapi"
|
||||||
version = "0.100.1"
|
version = "0.101.0"
|
||||||
description = "FastAPI framework, high performance, easy to learn, fast to code, ready for production"
|
description = "FastAPI framework, high performance, easy to learn, fast to code, ready for production"
|
||||||
optional = true
|
optional = true
|
||||||
python-versions = ">=3.7"
|
python-versions = ">=3.7"
|
||||||
files = [
|
files = [
|
||||||
{file = "fastapi-0.100.1-py3-none-any.whl", hash = "sha256:ec6dd52bfc4eff3063cfcd0713b43c87640fefb2687bbbe3d8a08d94049cdf32"},
|
{file = "fastapi-0.101.0-py3-none-any.whl", hash = "sha256:494eb3494d89e8079c20859d7ca695f66eaccc40f46fe8c75ab6186d15f05ffd"},
|
||||||
{file = "fastapi-0.100.1.tar.gz", hash = "sha256:522700d7a469e4a973d92321ab93312448fbe20fca9c8da97effc7e7bc56df23"},
|
{file = "fastapi-0.101.0.tar.gz", hash = "sha256:ca2ae65fe42f6a34b5cf6c994337149154b1b400c39809d7b2dccdceb5ae77af"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
pydantic = ">=1.7.4,<1.8 || >1.8,<1.8.1 || >1.8.1,<2.0.0 || >2.0.0,<2.0.1 || >2.0.1,<3.0.0"
|
pydantic = ">=1.7.4,<1.8 || >1.8,<1.8.1 || >1.8.1,<2.0.0 || >2.0.0,<2.0.1 || >2.0.1,<2.1.0 || >2.1.0,<3.0.0"
|
||||||
starlette = ">=0.27.0,<0.28.0"
|
starlette = ">=0.27.0,<0.28.0"
|
||||||
typing-extensions = ">=4.5.0"
|
typing-extensions = ">=4.5.0"
|
||||||
|
|
||||||
|
@ -744,13 +744,13 @@ files = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "mkdocs"
|
name = "mkdocs"
|
||||||
version = "1.5.1"
|
version = "1.5.2"
|
||||||
description = "Project documentation with Markdown."
|
description = "Project documentation with Markdown."
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.7"
|
python-versions = ">=3.7"
|
||||||
files = [
|
files = [
|
||||||
{file = "mkdocs-1.5.1-py3-none-any.whl", hash = "sha256:67e889f8d8ba1fe5decdfc59f5f8f21d6a8925a129339e93dede303bdea03a98"},
|
{file = "mkdocs-1.5.2-py3-none-any.whl", hash = "sha256:60a62538519c2e96fe8426654a67ee177350451616118a41596ae7c876bb7eac"},
|
||||||
{file = "mkdocs-1.5.1.tar.gz", hash = "sha256:f2f323c62fffdf1b71b84849e39aef56d6852b3f0a5571552bca32cefc650209"},
|
{file = "mkdocs-1.5.2.tar.gz", hash = "sha256:70d0da09c26cff288852471be03c23f0f521fc15cf16ac89c7a3bfb9ae8d24f9"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
|
@ -1757,4 +1757,4 @@ server = ["fastapi", "pydantic-settings", "sse-starlette", "uvicorn"]
|
||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.0"
|
lock-version = "2.0"
|
||||||
python-versions = "^3.8.1"
|
python-versions = "^3.8.1"
|
||||||
content-hash = "6718d680fa89f9518a232c1110ba43958d3e21c54c4dbd9129effa4f40a02b81"
|
content-hash = "4bfb67dfb72b02c845376211f7f958b2ece8c985944fbd03d246c858e846ddf6"
|
||||||
|
|
|
@ -25,7 +25,7 @@ pydantic-settings = { version = ">=2.0.1", optional = true }
|
||||||
[tool.poetry.group.dev.dependencies]
|
[tool.poetry.group.dev.dependencies]
|
||||||
black = "^23.7.0"
|
black = "^23.7.0"
|
||||||
twine = "^4.0.2"
|
twine = "^4.0.2"
|
||||||
mkdocs = "^1.4.3"
|
mkdocs = "^1.5.2"
|
||||||
mkdocstrings = {extras = ["python"], version = "^0.22.0"}
|
mkdocstrings = {extras = ["python"], version = "^0.22.0"}
|
||||||
mkdocs-material = "^9.1.21"
|
mkdocs-material = "^9.1.21"
|
||||||
pytest = "^7.4.0"
|
pytest = "^7.4.0"
|
||||||
|
|
2
vendor/llama.cpp
vendored
2
vendor/llama.cpp
vendored
|
@ -1 +1 @@
|
||||||
Subproject commit 8183159cf3def112f6d1fe94815fce70e1bffa12
|
Subproject commit f5bfea0580e417f99850d5456ca541d871a3e48c
|
Loading…
Add table
Reference in a new issue