This commit is contained in:
baalajimaestro 2024-02-23 18:03:03 +05:30
commit da343412ee
Signed by: baalajimaestro
GPG key ID: F93C394FE9BBAFD5
15 changed files with 1610 additions and 1160 deletions

2
.gitignore vendored
View file

@ -1,3 +1,5 @@
*.local
.python-version
.vscode/

View file

@ -7,6 +7,29 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased]
## [0.2.48]
- feat: Update llama.cpp to ggerganov/llama.cpp@15499eb94227401bdc8875da6eb85c15d37068f7
- feat: Add Google's Gemma formatting via chat_format="gemma" by @alvarobartt in #1210
- feat: support minItems/maxItems in JSON grammar converter by @nopperl in 3921e10770996d95a9eb22c8248bacef39f69365
- fix: Update from_pretrained defaults to match hf_hub_download and pull to local cache folder by @abetlen in e6d6260a91b7831733f7d1f73c7af46a3e8185ed
- fix: Raise exceptions when llama model or context fails to load by @abetlen in dd22010e85265ae840c76ec835d67a29ed852722
- docs: Update README.md to fix pip install llama cpp server by @audip in #1187
## [0.2.47]
- feat: Update llama.cpp to ggerganov/llama.cpp@973053d8b0d04809836b3339a50f68d9c842de90
## [0.2.46]
- feat: Update llama.cpp to ggerganov/llama.cpp@ba2135ccae7462470b3865c6e41d2e1d734eac05
- feat: Pull models directly from huggingface by @abetlen in #1206
- feat(low-level-api): Improve API static type-safety and performance. Low level api functions are positional args only now. by @abetlen in #1205
## [0.2.45]
- feat: Update llama.cpp to ggerganov/llama.cpp@89febfed9322c8849520dc63c93ee4f5fd72556e
## [0.2.44]
- feat: Update llama.cpp to ggerganov/llama.cpp@4524290e87b8e107cc2b56e1251751546f4b9051

View file

@ -12,6 +12,9 @@ deps:
build:
python3 -m pip install --verbose -e .
build.debug:
CMAKE_ARGS="-DCMAKE_BUILD_TYPE=Debug" python3 -m pip install --verbose --config-settings=cmake.verbose=true --config-settings=logging.level=INFO --config-settings=install.strip=false --editable .
build.cuda:
CMAKE_ARGS="-DLLAMA_CUBLAS=on" python3 -m pip install --verbose -e .

169
README.md
View file

@ -12,60 +12,94 @@ This package provides:
- Low-level access to C API via `ctypes` interface.
- High-level Python API for text completion
- OpenAI-like API
- [LangChain compatibility](https://python.langchain.com/docs/integrations/llms/llamacpp)
- [LlamaIndex compatibility](https://docs.llamaindex.ai/en/stable/examples/llm/llama_2_llama_cpp.html)
- OpenAI-like API
- [LangChain compatibility](https://python.langchain.com/docs/integrations/llms/llamacpp)
- [LlamaIndex compatibility](https://docs.llamaindex.ai/en/stable/examples/llm/llama_2_llama_cpp.html)
- OpenAI compatible web server
- [Local Copilot replacement](https://llama-cpp-python.readthedocs.io/en/latest/server/#code-completion)
- [Function Calling support](https://llama-cpp-python.readthedocs.io/en/latest/server/#function-calling)
- [Vision API support](https://llama-cpp-python.readthedocs.io/en/latest/server/#multimodal-models)
- [Multiple Models](https://llama-cpp-python.readthedocs.io/en/latest/server/#configuration-and-multi-model-support)
- [Local Copilot replacement](https://llama-cpp-python.readthedocs.io/en/latest/server/#code-completion)
- [Function Calling support](https://llama-cpp-python.readthedocs.io/en/latest/server/#function-calling)
- [Vision API support](https://llama-cpp-python.readthedocs.io/en/latest/server/#multimodal-models)
- [Multiple Models](https://llama-cpp-python.readthedocs.io/en/latest/server/#configuration-and-multi-model-support)
Documentation is available at [https://llama-cpp-python.readthedocs.io/en/latest](https://llama-cpp-python.readthedocs.io/en/latest).
## Installation
`llama-cpp-python` can be installed directly from PyPI as a source distribution by running:
Requirements:
- Python 3.8+
- C compiler
- Linux: gcc or clang
- Windows: Visual Studio or MinGW
- MacOS: Xcode
To install the package, run:
```bash
pip install llama-cpp-python
```
This will build `llama.cpp` from source using cmake and your system's c compiler (required) and install the library alongside this python package.
This will also build `llama.cpp` from source and install it alongside this python package.
If you run into issues during installation add the `--verbose` flag to the `pip install` command to see the full cmake build log.
If this fails, add `--verbose` to the `pip install` see the full cmake build log.
### Installation with Specific Hardware Acceleration (BLAS, CUDA, Metal, etc)
### Installation Configuration
The default pip install behaviour is to build `llama.cpp` for CPU only on Linux and Windows and use Metal on MacOS.
`llama.cpp` supports a number of hardware acceleration backends to speed up inference as well as backend specific options. See the [llama.cpp README](https://github.com/ggerganov/llama.cpp#build) for a full list.
`llama.cpp` supports a number of hardware acceleration backends depending including OpenBLAS, cuBLAS, CLBlast, HIPBLAS, and Metal.
See the [llama.cpp README](https://github.com/ggerganov/llama.cpp#build) for a full list of supported backends.
All `llama.cpp` cmake build options can be set via the `CMAKE_ARGS` environment variable or via the `--config-settings / -C` cli flag during installation.
All of these backends are supported by `llama-cpp-python` and can be enabled by setting the `CMAKE_ARGS` environment variable before installing.
On Linux and Mac you set the `CMAKE_ARGS` like this:
<details open>
<summary>Environment Variables</summary>
```bash
CMAKE_ARGS="-DLLAMA_BLAS=ON -DLLAMA_BLAS_VENDOR=OpenBLAS" pip install llama-cpp-python
# Linux and Mac
CMAKE_ARGS="-DLLAMA_BLAS=ON -DLLAMA_BLAS_VENDOR=OpenBLAS" \
pip install llama-cpp-python
```
On Windows you can set the `CMAKE_ARGS` like this:
```ps
```powershell
# Windows
$env:CMAKE_ARGS = "-DLLAMA_BLAS=ON -DLLAMA_BLAS_VENDOR=OpenBLAS"
pip install llama-cpp-python
```
</details>
#### OpenBLAS
<details>
<summary>CLI / requirements.txt</summary>
To install with OpenBLAS, set the `LLAMA_BLAS and LLAMA_BLAS_VENDOR` environment variables before installing:
They can also be set via `pip install -C / --config-settings` command and saved to a `requirements.txt` file:
```bash
pip install --upgrade pip # ensure pip is up to date
pip install llama-cpp-python \
-C cmake.args="-DLLAMA_BLAS=ON;-DLLAMA_BLAS_VENDOR=OpenBLAS"
```
```txt
# requirements.txt
llama-cpp-python -C cmake.args="-DLLAMA_BLAS=ON;-DLLAMA_BLAS_VENDOR=OpenBLAS"
```
</details>
### Supported Backends
Below are some common backends, their build commands and any additional environment variables required.
<details open>
<summary>OpenBLAS (CPU)</summary>
To install with OpenBLAS, set the `LLAMA_BLAS` and `LLAMA_BLAS_VENDOR` environment variables before installing:
```bash
CMAKE_ARGS="-DLLAMA_BLAS=ON -DLLAMA_BLAS_VENDOR=OpenBLAS" pip install llama-cpp-python
```
</details>
#### cuBLAS
<details>
<summary>cuBLAS (CUDA)</summary>
To install with cuBLAS, set the `LLAMA_CUBLAS=on` environment variable before installing:
@ -73,7 +107,10 @@ To install with cuBLAS, set the `LLAMA_CUBLAS=on` environment variable before in
CMAKE_ARGS="-DLLAMA_CUBLAS=on" pip install llama-cpp-python
```
#### Metal
</details>
<details>
<summary>Metal</summary>
To install with Metal (MPS), set the `LLAMA_METAL=on` environment variable before installing:
@ -81,7 +118,10 @@ To install with Metal (MPS), set the `LLAMA_METAL=on` environment variable befor
CMAKE_ARGS="-DLLAMA_METAL=on" pip install llama-cpp-python
```
#### CLBlast
</details>
<details>
<summary>CLBlast (OpenCL)</summary>
To install with CLBlast, set the `LLAMA_CLBLAST=on` environment variable before installing:
@ -89,7 +129,10 @@ To install with CLBlast, set the `LLAMA_CLBLAST=on` environment variable before
CMAKE_ARGS="-DLLAMA_CLBLAST=on" pip install llama-cpp-python
```
#### hipBLAS
</details>
<details>
<summary>hipBLAS (ROCm)</summary>
To install with hipBLAS / ROCm support for AMD cards, set the `LLAMA_HIPBLAS=on` environment variable before installing:
@ -97,7 +140,10 @@ To install with hipBLAS / ROCm support for AMD cards, set the `LLAMA_HIPBLAS=on`
CMAKE_ARGS="-DLLAMA_HIPBLAS=on" pip install llama-cpp-python
```
#### Vulkan
</details>
<details>
<summary>Vulkan</summary>
To install with Vulkan support, set the `LLAMA_VULKAN=on` environment variable before installing:
@ -105,15 +151,20 @@ To install with Vulkan support, set the `LLAMA_VULKAN=on` environment variable b
CMAKE_ARGS="-DLLAMA_VULKAN=on" pip install llama-cpp-python
```
#### Kompute
</details>
<details>
<summary>Kompute</summary>
To install with Kompute support, set the `LLAMA_KOMPUTE=on` environment variable before installing:
```bash
CMAKE_ARGS="-DLLAMA_KOMPUTE=on" pip install llama-cpp-python
```
</details>
#### SYCL
<details>
<summary>SYCL</summary>
To install with SYCL support, set the `LLAMA_SYCL=on` environment variable before installing:
@ -121,9 +172,14 @@ To install with SYCL support, set the `LLAMA_SYCL=on` environment variable befor
source /opt/intel/oneapi/setvars.sh
CMAKE_ARGS="-DLLAMA_SYCL=on -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx" pip install llama-cpp-python
```
</details>
### Windows Notes
<details>
<summary>Error: Can't find 'nmake' or 'CMAKE_C_COMPILER'</summary>
If you run into issues where it complains it can't find `'nmake'` `'?'` or CMAKE_C_COMPILER, you can extract w64devkit as [mentioned in llama.cpp repo](https://github.com/ggerganov/llama.cpp#openblas) and add those manually to CMAKE_ARGS before running `pip` install:
```ps
@ -132,12 +188,14 @@ $env:CMAKE_ARGS = "-DLLAMA_OPENBLAS=on -DCMAKE_C_COMPILER=C:/w64devkit/bin/gcc.e
```
See the above instructions and set `CMAKE_ARGS` to the BLAS backend you want to use.
</details>
### MacOS Notes
Detailed MacOS Metal GPU install documentation is available at [docs/install/macos.md](https://llama-cpp-python.readthedocs.io/en/latest/install/macos/)
#### M1 Mac Performance Issue
<details>
<summary>M1 Mac Performance Issue</summary>
Note: If you are using Apple Silicon (M1) Mac, make sure you have installed a version of Python that supports arm64 architecture. For example:
@ -147,24 +205,21 @@ bash Miniforge3-MacOSX-arm64.sh
```
Otherwise, while installing it will build the llama.cpp x86 version which will be 10x slower on Apple Silicon (M1) Mac.
</details>
#### M Series Mac Error: `(mach-o file, but is an incompatible architecture (have 'x86_64', need 'arm64'))`
<details>
<summary>M Series Mac Error: `(mach-o file, but is an incompatible architecture (have 'x86_64', need 'arm64'))`</summary>
Try installing with
```bash
CMAKE_ARGS="-DCMAKE_OSX_ARCHITECTURES=arm64 -DCMAKE_APPLE_SILICON_PROCESSOR=arm64 -DLLAMA_METAL=on" pip install --upgrade --verbose --force-reinstall --no-cache-dir llama-cpp-python
```
</details>
### Upgrading and Reinstalling
To upgrade or rebuild `llama-cpp-python` add the following flags to ensure that the package is rebuilt correctly:
```bash
pip install llama-cpp-python --upgrade --force-reinstall --no-cache-dir
```
This will ensure that all source files are re-built with the most recently set `CMAKE_ARGS` flags.
To upgrade and rebuild `llama-cpp-python` add `--upgrade --force-reinstall --no-cache-dir` flags to the `pip install` command to ensure the package is rebuilt from source.
## High-level API
@ -212,6 +267,21 @@ Below is a short example demonstrating how to use the high-level API to for basi
Text completion is available through the [`__call__`](https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.__call__) and [`create_completion`](https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.create_completion) methods of the [`Llama`](https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama) class.
### Pulling models from Hugging Face Hub
You can download `Llama` models in `gguf` format directly from Hugging Face using the [`from_pretrained`](https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.from_pretrained) method.
You'll need to install the `huggingface-hub` package to use this feature (`pip install huggingface-hub`).
```python
llm = Llama.from_pretrained(
repo_id="Qwen/Qwen1.5-0.5B-Chat-GGUF",
filename="*q8_0.gguf",
verbose=False
)
```
By default [`from_pretrained`](https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.from_pretrained) will download the model to the huggingface cache directory, you can then manage installed model files with the [`huggingface-cli`](https://huggingface.co/docs/huggingface_hub/en/guides/cli) tool.
### Chat Completion
The high-level API also provides a simple interface for chat completion.
@ -237,13 +307,16 @@ Note that `chat_format` option must be set for the particular model you are usin
Chat completion is available through the [`create_chat_completion`](https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.create_chat_completion) method of the [`Llama`](https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama) class.
For OpenAI API v1 compatibility, you use the [`create_chat_completion_openai_v1`](https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.create_chat_completion_openai_v1) method which will return pydantic models instead of dicts.
### JSON and JSON Schema Mode
If you want to constrain chat responses to only valid JSON or a specific JSON Schema you can use the `response_format` argument to the `create_chat_completion` method.
To constrain chat responses to only valid JSON or a specific JSON Schema use the `response_format` argument in [`create_chat_completion`](http://localhost:8000/api-reference/#llama_cpp.Llama.create_chat_completion).
#### JSON Mode
The following example will constrain the response to be valid JSON.
The following example will constrain the response to valid JSON strings only.
```python
>>> from llama_cpp import Llama
@ -265,7 +338,7 @@ The following example will constrain the response to be valid JSON.
#### JSON Schema Mode
To constrain the response to a specific JSON Schema, you can use the `schema` property of the `response_format` argument.
To constrain the response further to a specific JSON Schema add the schema to the `schema` property of the `response_format` argument.
```python
>>> from llama_cpp import Llama
@ -400,7 +473,7 @@ llama = Llama(
### Embeddings
`llama-cpp-python` supports generating embeddings from the text.
To generate text embeddings use [`create_embedding`](http://localhost:8000/api-reference/#llama_cpp.Llama.create_embedding).
```python
import llama_cpp
@ -409,7 +482,7 @@ llm = llama_cpp.Llama(model_path="path/to/model.gguf", embeddings=True)
embeddings = llm.create_embedding("Hello, world!")
# or batched
# or create multiple embeddings at once
embeddings = llm.create_embedding(["Hello, world!", "Goodbye, world!"])
```
@ -432,14 +505,14 @@ This allows you to use llama.cpp compatible models with any OpenAI compatible cl
To install the server package and get started:
```bash
pip install llama-cpp-python[server]
pip install 'llama-cpp-python[server]'
python3 -m llama_cpp.server --model models/7B/llama-model.gguf
```
Similar to Hardware Acceleration section above, you can also install with GPU (cuBLAS) support like this:
```bash
CMAKE_ARGS="-DLLAMA_CUBLAS=on" FORCE_CMAKE=1 pip install llama-cpp-python[server]
CMAKE_ARGS="-DLLAMA_CUBLAS=on" FORCE_CMAKE=1 pip install 'llama-cpp-python[server]'
python3 -m llama_cpp.server --model models/7B/llama-model.gguf --n_gpu_layers 35
```
@ -486,7 +559,7 @@ Below is a short example demonstrating how to use the low-level API to tokenize
```python
>>> import llama_cpp
>>> import ctypes
>>> llama_cpp.llama_backend_init(numa=False) # Must be called once at the start of each program
>>> llama_cpp.llama_backend_init(False) # Must be called once at the start of each program
>>> params = llama_cpp.llama_context_default_params()
# use bytes for char * params
>>> model = llama_cpp.llama_load_model_from_file(b"./models/7b/llama-model.gguf", params)
@ -494,7 +567,7 @@ Below is a short example demonstrating how to use the low-level API to tokenize
>>> max_tokens = params.n_ctx
# use ctypes arrays for array params
>>> tokens = (llama_cpp.llama_token * int(max_tokens))()
>>> n_tokens = llama_cpp.llama_tokenize(ctx, b"Q: Name the planets in the solar system? A: ", tokens, max_tokens, add_bos=llama_cpp.c_bool(True))
>>> n_tokens = llama_cpp.llama_tokenize(ctx, b"Q: Name the planets in the solar system? A: ", tokens, max_tokens, llama_cpp.c_bool(True))
>>> llama_cpp.llama_free(ctx)
```

View file

@ -21,11 +21,13 @@ High-level Python bindings for llama.cpp.
- create_completion
- __call__
- create_chat_completion
- create_chat_completion_openai_v1
- set_cache
- save_state
- load_state
- token_bos
- token_eos
- from_pretrained
show_root_heading: true
::: llama_cpp.LlamaGrammar

View file

@ -1,4 +1,4 @@
from .llama_cpp import *
from .llama import *
__version__ = "0.2.44"
__version__ = "0.2.48"

View file

@ -51,6 +51,9 @@ class _LlamaModel:
self.path_model.encode("utf-8"), self.params
)
if self.model is None:
raise ValueError(f"Failed to load model from file: {path_model}")
def __del__(self):
if self.model is not None and self._llama_free_model is not None:
self._llama_free_model(self.model)
@ -79,7 +82,7 @@ class _LlamaModel:
def desc(self) -> str:
assert self.model is not None
buf = ctypes.create_string_buffer(1024)
llama_cpp.llama_model_desc(self.model, buf, 1024) # type: ignore
llama_cpp.llama_model_desc(self.model, buf, 1024)
return buf.value.decode("utf-8")
def size(self) -> int:
@ -108,7 +111,7 @@ class _LlamaModel:
scale,
path_base_model.encode("utf-8")
if path_base_model is not None
else llama_cpp.c_char_p(0),
else ctypes.c_char_p(0),
n_threads,
)
@ -181,7 +184,7 @@ class _LlamaModel:
def token_to_piece(self, token: int) -> bytes:
assert self.model is not None
buf = ctypes.create_string_buffer(32)
llama_cpp.llama_token_to_piece(self.model, token, buf, 32) # type: ignore
llama_cpp.llama_token_to_piece(self.model, token, buf, 32)
return bytes(buf)
def detokenize(self, tokens: List[int]) -> bytes:
@ -258,6 +261,9 @@ class _LlamaContext:
self.model.model, self.params
)
if self.ctx is None:
raise ValueError("Failed to create llama_context")
def __del__(self):
if self.ctx is not None and self._llama_free is not None:
self._llama_free(self.ctx)
@ -303,8 +309,8 @@ class _LlamaContext:
assert self.ctx is not None
assert batch.batch is not None
return_code = llama_cpp.llama_decode(
ctx=self.ctx,
batch=batch.batch,
self.ctx,
batch.batch,
)
if return_code != 0:
raise RuntimeError(f"llama_decode returned {return_code}")
@ -343,7 +349,7 @@ class _LlamaContext:
assert self.ctx is not None
llama_cpp.llama_sample_repetition_penalties(
self.ctx,
ctypes.byref(candidates.candidates), # type: ignore
llama_cpp.byref(candidates.candidates),
last_tokens_data,
penalty_last_n,
penalty_repeat,
@ -361,7 +367,7 @@ class _LlamaContext:
assert guidance_ctx.ctx is not None
llama_cpp.llama_sample_classifier_free_guidance(
self.ctx,
ctypes.byref(candidates.candidates), # type: ignore
llama_cpp.byref(candidates.candidates),
guidance_ctx.ctx,
scale,
)
@ -370,25 +376,25 @@ class _LlamaContext:
assert self.ctx is not None
llama_cpp.llama_sample_softmax(
self.ctx,
ctypes.byref(candidates.candidates), # type: ignore
llama_cpp.byref(candidates.candidates),
)
def sample_top_k(self, candidates: "_LlamaTokenDataArray", k: int, min_keep: int):
assert self.ctx is not None
llama_cpp.llama_sample_top_k(
self.ctx, ctypes.byref(candidates.candidates), k, min_keep # type: ignore
self.ctx, llama_cpp.byref(candidates.candidates), k, min_keep
)
def sample_top_p(self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int):
assert self.ctx is not None
llama_cpp.llama_sample_top_p(
self.ctx, ctypes.byref(candidates.candidates), p, min_keep # type: ignore
self.ctx, llama_cpp.byref(candidates.candidates), p, min_keep
)
def sample_min_p(self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int):
assert self.ctx is not None
llama_cpp.llama_sample_min_p(
self.ctx, ctypes.byref(candidates.candidates), p, min_keep # type: ignore
self.ctx, llama_cpp.byref(candidates.candidates), p, min_keep
)
def sample_tail_free(
@ -396,7 +402,7 @@ class _LlamaContext:
):
assert self.ctx is not None
llama_cpp.llama_sample_tail_free(
self.ctx, ctypes.byref(candidates.candidates), z, min_keep # type: ignore
self.ctx, llama_cpp.byref(candidates.candidates), z, min_keep
)
def sample_typical(
@ -404,13 +410,13 @@ class _LlamaContext:
):
assert self.ctx is not None
llama_cpp.llama_sample_typical(
self.ctx, ctypes.byref(candidates.candidates), p, min_keep # type: ignore
self.ctx, llama_cpp.byref(candidates.candidates), p, min_keep
)
def sample_temp(self, candidates: "_LlamaTokenDataArray", temp: float):
assert self.ctx is not None
llama_cpp.llama_sample_temp(
self.ctx, ctypes.byref(candidates.candidates), temp # type: ignore
self.ctx, llama_cpp.byref(candidates.candidates), temp
)
def sample_grammar(self, candidates: "_LlamaTokenDataArray", grammar: LlamaGrammar):
@ -418,7 +424,7 @@ class _LlamaContext:
assert grammar.grammar is not None
llama_cpp.llama_sample_grammar(
self.ctx,
ctypes.byref(candidates.candidates), # type: ignore
llama_cpp.byref(candidates.candidates),
grammar.grammar,
)
@ -428,12 +434,12 @@ class _LlamaContext:
tau: float,
eta: float,
m: int,
mu: ctypes._Pointer[ctypes.c_float], # type: ignore
mu: llama_cpp.CtypesPointerOrRef[ctypes.c_float],
) -> int:
assert self.ctx is not None
return llama_cpp.llama_sample_token_mirostat(
self.ctx,
ctypes.byref(candidates.candidates), # type: ignore
llama_cpp.byref(candidates.candidates),
tau,
eta,
m,
@ -441,12 +447,12 @@ class _LlamaContext:
)
def sample_token_mirostat_v2(
self, candidates: "_LlamaTokenDataArray", tau: float, eta: float, mu: ctypes._Pointer[ctypes.c_float] # type: ignore
self, candidates: "_LlamaTokenDataArray", tau: float, eta: float, mu: llama_cpp.CtypesPointerOrRef[ctypes.c_float]
) -> int:
assert self.ctx is not None
return llama_cpp.llama_sample_token_mirostat_v2(
self.ctx,
ctypes.byref(candidates.candidates), # type: ignore
llama_cpp.byref(candidates.candidates),
tau,
eta,
mu,
@ -456,14 +462,14 @@ class _LlamaContext:
assert self.ctx is not None
return llama_cpp.llama_sample_token_greedy(
self.ctx,
ctypes.byref(candidates.candidates), # type: ignore
llama_cpp.byref(candidates.candidates),
)
def sample_token(self, candidates: "_LlamaTokenDataArray") -> int:
assert self.ctx is not None
return llama_cpp.llama_sample_token(
self.ctx,
ctypes.byref(candidates.candidates), # type: ignore
llama_cpp.byref(candidates.candidates),
)
# Grammar
@ -493,7 +499,7 @@ class _LlamaBatch:
def __init__(
self, *, n_tokens: int, embd: int, n_seq_max: int, verbose: bool = True
):
self.n_tokens = n_tokens
self._n_tokens = n_tokens
self.embd = embd
self.n_seq_max = n_seq_max
self.verbose = verbose
@ -502,7 +508,7 @@ class _LlamaBatch:
self.batch = None
self.batch = llama_cpp.llama_batch_init(
self.n_tokens, self.embd, self.n_seq_max
self._n_tokens, self.embd, self.n_seq_max
)
def __del__(self):
@ -560,7 +566,7 @@ class _LlamaTokenDataArray:
size=self.n_vocab,
sorted=False,
)
self.default_candidates_data_id = np.arange(self.n_vocab, dtype=np.intc)
self.default_candidates_data_id = np.arange(self.n_vocab, dtype=np.intc) # type: ignore
self.default_candidates_data_p = np.zeros(self.n_vocab, dtype=np.single)
def copy_logits(self, logits: npt.NDArray[np.single]):
@ -570,12 +576,13 @@ class _LlamaTokenDataArray:
self.candidates.data = self.candidates_data.ctypes.data_as(
llama_cpp.llama_token_data_p
)
self.candidates.sorted = llama_cpp.c_bool(False)
self.candidates.size = llama_cpp.c_size_t(self.n_vocab)
self.candidates.sorted = ctypes.c_bool(False)
self.candidates.size = ctypes.c_size_t(self.n_vocab)
# Python wrappers over common/common
def _tokenize(model: _LlamaModel, text: str, add_bos: bool, special: bool) -> list[int]:
assert model.model is not None
n_tokens = len(text) + 1 if add_bos else len(text)
result = (llama_cpp.llama_token * n_tokens)()
n_tokens = llama_cpp.llama_tokenize(
@ -747,7 +754,7 @@ class _LlamaSamplingContext:
ctx_main.sample_repetition_penalties(
token_data_array,
# TODO: Only create this once
(llama_cpp.llama_token * len(self.prev))(*self.prev), # type: ignore
(llama_cpp.llama_token * len(self.prev))(*self.prev),
self.params.penalty_last_n,
self.params.penalty_repeat,
self.params.penalty_freq,

View file

@ -4,6 +4,8 @@ import os
import sys
import uuid
import time
import json
import fnmatch
import multiprocessing
from typing import (
List,
@ -16,6 +18,7 @@ from typing import (
Callable,
)
from collections import deque
from pathlib import Path
import ctypes
@ -29,10 +32,7 @@ from .llama_cache import (
LlamaDiskCache, # type: ignore
LlamaRAMCache, # type: ignore
)
from .llama_tokenizer import (
BaseLlamaTokenizer,
LlamaTokenizer
)
from .llama_tokenizer import BaseLlamaTokenizer, LlamaTokenizer
import llama_cpp.llama_cpp as llama_cpp
import llama_cpp.llama_chat_format as llama_chat_format
@ -50,9 +50,7 @@ from ._internals import (
_LlamaSamplingContext, # type: ignore
)
from ._logger import set_verbose
from ._utils import (
suppress_stdout_stderr
)
from ._utils import suppress_stdout_stderr
class Llama:
@ -189,7 +187,11 @@ class Llama:
Llama.__backend_initialized = True
if isinstance(numa, bool):
self.numa = llama_cpp.GGML_NUMA_STRATEGY_DISTRIBUTE if numa else llama_cpp.GGML_NUMA_STRATEGY_DISABLED
self.numa = (
llama_cpp.GGML_NUMA_STRATEGY_DISTRIBUTE
if numa
else llama_cpp.GGML_NUMA_STRATEGY_DISABLED
)
else:
self.numa = numa
@ -246,9 +248,9 @@ class Llama:
else:
raise ValueError(f"Unknown value type for {k}: {v}")
self._kv_overrides_array[
-1
].key = b"\0" # ensure sentinel element is zeroed
self._kv_overrides_array[-1].key = (
b"\0" # ensure sentinel element is zeroed
)
self.model_params.kv_overrides = self._kv_overrides_array
self.n_batch = min(n_ctx, n_batch) # ???
@ -256,7 +258,7 @@ class Llama:
self.n_threads_batch = n_threads_batch or max(
multiprocessing.cpu_count() // 2, 1
)
# Context Params
self.context_params = llama_cpp.llama_context_default_params()
self.context_params.seed = seed
@ -289,7 +291,9 @@ class Llama:
)
self.context_params.yarn_orig_ctx = yarn_orig_ctx if yarn_orig_ctx != 0 else 0
self.context_params.mul_mat_q = mul_mat_q
self.context_params.logits_all = logits_all if draft_model is None else True # Must be set to True for speculative decoding
self.context_params.logits_all = (
logits_all if draft_model is None else True
) # Must be set to True for speculative decoding
self.context_params.embedding = embedding
self.context_params.offload_kqv = offload_kqv
@ -379,8 +383,14 @@ class Llama:
if self.verbose:
print(f"Model metadata: {self.metadata}", file=sys.stderr)
if self.chat_format is None and self.chat_handler is None and "tokenizer.chat_template" in self.metadata:
chat_format = llama_chat_format.guess_chat_format_from_gguf_metadata(self.metadata)
if (
self.chat_format is None
and self.chat_handler is None
and "tokenizer.chat_template" in self.metadata
):
chat_format = llama_chat_format.guess_chat_format_from_gguf_metadata(
self.metadata
)
if chat_format is not None:
self.chat_format = chat_format
@ -406,9 +416,7 @@ class Llama:
print(f"Using chat bos_token: {bos_token}", file=sys.stderr)
self.chat_handler = llama_chat_format.Jinja2ChatFormatter(
template=template,
eos_token=eos_token,
bos_token=bos_token
template=template, eos_token=eos_token, bos_token=bos_token
).to_chat_handler()
if self.chat_format is None and self.chat_handler is None:
@ -459,7 +467,9 @@ class Llama:
"""
return self.tokenizer_.tokenize(text, add_bos, special)
def detokenize(self, tokens: List[int], prev_tokens: Optional[List[int]] = None) -> bytes:
def detokenize(
self, tokens: List[int], prev_tokens: Optional[List[int]] = None
) -> bytes:
"""Detokenize a list of tokens.
Args:
@ -565,7 +575,7 @@ class Llama:
logits[:] = (
logits_processor(self._input_ids, logits)
if idx is None
else logits_processor(self._input_ids[:idx + 1], logits)
else logits_processor(self._input_ids[: idx + 1], logits)
)
sampling_params = _LlamaSamplingParams(
@ -707,7 +717,9 @@ class Llama:
if self.draft_model is not None:
self.input_ids[self.n_tokens : self.n_tokens + len(tokens)] = tokens
draft_tokens = self.draft_model(self.input_ids[:self.n_tokens + len(tokens)])
draft_tokens = self.draft_model(
self.input_ids[: self.n_tokens + len(tokens)]
)
tokens.extend(
draft_tokens.astype(int)[
: self._n_ctx - self.n_tokens - len(tokens)
@ -792,6 +804,7 @@ class Llama:
# decode and fetch embeddings
data: List[List[float]] = []
def decode_batch(n_seq: int):
assert self._ctx.ctx is not None
llama_cpp.llama_kv_cache_clear(self._ctx.ctx)
@ -800,9 +813,9 @@ class Llama:
# store embeddings
for i in range(n_seq):
embedding: List[float] = llama_cpp.llama_get_embeddings_ith(self._ctx.ctx, i)[
:n_embd
]
embedding: List[float] = llama_cpp.llama_get_embeddings_ith(
self._ctx.ctx, i
)[:n_embd]
if normalize:
norm = float(np.linalg.norm(embedding))
embedding = [v / norm for v in embedding]
@ -1669,12 +1682,13 @@ class Llama:
"""
try:
from openai.types.chat import ChatCompletion, ChatCompletionChunk
stream = kwargs.get("stream", False) # type: ignore
stream = kwargs.get("stream", False) # type: ignore
assert isinstance(stream, bool)
if stream:
return (ChatCompletionChunk(**chunk) for chunk in self.create_chat_completion(*args, **kwargs)) # type: ignore
return (ChatCompletionChunk(**chunk) for chunk in self.create_chat_completion(*args, **kwargs)) # type: ignore
else:
return ChatCompletion(**self.create_chat_completion(*args, **kwargs)) # type: ignore
return ChatCompletion(**self.create_chat_completion(*args, **kwargs)) # type: ignore
except ImportError:
raise ImportError(
"To use create_chat_completion_openai_v1, you must install the openai package."
@ -1804,7 +1818,7 @@ class Llama:
self.input_ids = state.input_ids.copy()
self.n_tokens = state.n_tokens
state_size = state.llama_state_size
LLamaStateArrayType = llama_cpp.c_uint8 * state_size
LLamaStateArrayType = ctypes.c_uint8 * state_size
llama_state = LLamaStateArrayType.from_buffer_copy(state.llama_state)
if llama_cpp.llama_set_state_data(self._ctx.ctx, llama_state) != state_size:
@ -1866,7 +1880,100 @@ class Llama:
break
return longest_prefix
@classmethod
def from_pretrained(
cls,
repo_id: str,
filename: Optional[str],
local_dir: Optional[Union[str, os.PathLike[str]]] = None,
local_dir_use_symlinks: Union[bool, Literal["auto"]] = "auto",
cache_dir: Optional[Union[str, os.PathLike[str]]] = None,
**kwargs: Any,
) -> "Llama":
"""Create a Llama model from a pretrained model name or path.
This method requires the huggingface-hub package.
You can install it with `pip install huggingface-hub`.
Args:
repo_id: The model repo id.
filename: A filename or glob pattern to match the model file in the repo.
local_dir: The local directory to save the model to.
local_dir_use_symlinks: Whether to use symlinks when downloading the model.
**kwargs: Additional keyword arguments to pass to the Llama constructor.
Returns:
A Llama model."""
try:
from huggingface_hub import hf_hub_download, HfFileSystem
from huggingface_hub.utils import validate_repo_id
except ImportError:
raise ImportError(
"Llama.from_pretrained requires the huggingface-hub package. "
"You can install it with `pip install huggingface-hub`."
)
validate_repo_id(repo_id)
hffs = HfFileSystem()
files = [
file["name"] if isinstance(file, dict) else file
for file in hffs.ls(repo_id)
]
# split each file into repo_id, subfolder, filename
file_list: List[str] = []
for file in files:
rel_path = Path(file).relative_to(repo_id)
file_list.append(str(rel_path))
matching_files = [file for file in file_list if fnmatch.fnmatch(file, filename)] # type: ignore
if len(matching_files) == 0:
raise ValueError(
f"No file found in {repo_id} that match {filename}\n\n"
f"Available Files:\n{json.dumps(file_list)}"
)
if len(matching_files) > 1:
raise ValueError(
f"Multiple files found in {repo_id} matching {filename}\n\n"
f"Available Files:\n{json.dumps(files)}"
)
(matching_file,) = matching_files
subfolder = str(Path(matching_file).parent)
filename = Path(matching_file).name
# download the file
hf_hub_download(
repo_id=repo_id,
filename=filename,
subfolder=subfolder,
local_dir=local_dir,
local_dir_use_symlinks=local_dir_use_symlinks,
cache_dir=cache_dir,
)
if local_dir is None:
model_path = hf_hub_download(
repo_id=repo_id,
filename=filename,
subfolder=subfolder,
local_dir=local_dir,
local_dir_use_symlinks=local_dir_use_symlinks,
cache_dir=cache_dir,
local_files_only=True,
)
else:
model_path = os.path.join(local_dir, filename)
return cls(
model_path=model_path,
**kwargs,
)
class LlamaState:

View file

@ -14,6 +14,7 @@ import llama_cpp.llama as llama
import llama_cpp.llama_types as llama_types
import llama_cpp.llama_grammar as llama_grammar
from ._logger import logger
from ._utils import suppress_stdout_stderr, Singleton
### Common Chat Templates and Special Tokens ###
@ -993,6 +994,26 @@ def format_saiga(
return ChatFormatterResponse(prompt=_prompt.strip())
# Chat format for Google's Gemma models, see more details and available models:
# https://huggingface.co/collections/google/gemma-release-65d5efbccdbb8c4202ec078b
@register_chat_format("gemma")
def format_gemma(
messages: List[llama_types.ChatCompletionRequestMessage],
**kwargs: Any,
) -> ChatFormatterResponse:
system_message = _get_system_message(messages)
if system_message is not None and system_message != "":
logger.debug(
"`role='system'` messages are not allowed on Google's Gemma models."
)
_roles = dict(user="<start_of_turn>user\n", assistant="<start_of_turn>model\n")
_sep = "<end_of_turn>\n"
_messages = _map_roles(messages, _roles)
_messages.append((_roles["assistant"], None))
_prompt = _format_no_colon_single(system_message="", messages=_messages, sep=_sep)
return ChatFormatterResponse(prompt=_prompt, stop=_sep)
# Tricky chat formats that require custom chat handlers

File diff suppressed because it is too large Load diff

View file

@ -1498,9 +1498,21 @@ class SchemaConverter:
item_rule_name = self.visit(
schema["items"], f'{name}{"-" if name else ""}item'
)
rule = (
f'"[" space ({item_rule_name} ("," space {item_rule_name})*)? "]" space'
)
list_item_operator = f'("," space {item_rule_name})'
successive_items = ""
min_items = schema.get("minItems", 0)
if min_items > 0:
first_item = f"({item_rule_name})"
successive_items = list_item_operator * (min_items - 1)
min_items -= 1
else:
first_item = f"({item_rule_name})?"
max_items = schema.get("maxItems")
if max_items is not None and max_items > min_items:
successive_items += (list_item_operator + "?") * (max_items - min_items - 1)
else:
successive_items += list_item_operator + "*"
rule = f'"[" space {first_item} {successive_items} "]" space'
return self._add_rule(rule_name, rule)
else:

View file

@ -5,21 +5,15 @@ from ctypes import (
c_bool,
c_char_p,
c_int,
c_int8,
c_int32,
c_uint8,
c_uint32,
c_size_t,
c_float,
c_double,
c_void_p,
POINTER,
_Pointer, # type: ignore
Structure,
Array,
)
import pathlib
from typing import List, Union
from typing import List, Union, NewType, Optional
import llama_cpp.llama_cpp as llama_cpp
@ -67,7 +61,7 @@ def _load_shared_library(lib_base_name: str):
for _lib_path in _lib_paths:
if _lib_path.exists():
try:
return ctypes.CDLL(str(_lib_path), **cdll_args)
return ctypes.CDLL(str(_lib_path), **cdll_args) # type: ignore
except Exception as e:
raise RuntimeError(f"Failed to load shared library '{_lib_path}': {e}")
@ -88,7 +82,8 @@ _libllava = _load_shared_library(_libllava_base_name)
################################################
# struct clip_ctx;
clip_ctx_p = c_void_p
clip_ctx_p = NewType("clip_ctx_p", int)
clip_ctx_p_ctypes = c_void_p
# struct llava_image_embed {
# float * embed;
@ -102,43 +97,48 @@ class llava_image_embed(Structure):
# /** sanity check for clip <-> llava embed size match */
# LLAVA_API bool llava_validate_embed_size(const llama_context * ctx_llama, const clip_ctx * ctx_clip);
def llava_validate_embed_size(ctx_llama: llama_cpp.llama_context_p, ctx_clip: clip_ctx_p) -> bool:
return _libllava.llava_validate_embed_size(ctx_llama, ctx_clip)
def llava_validate_embed_size(ctx_llama: llama_cpp.llama_context_p, ctx_clip: clip_ctx_p, /) -> bool:
...
_libllava.llava_validate_embed_size.argtypes = [llama_cpp.llama_context_p, clip_ctx_p]
_libllava.llava_validate_embed_size.restype = c_bool
llava_validate_embed_size = _libllava.llava_validate_embed_size
llava_validate_embed_size.argtypes = [llama_cpp.llama_context_p_ctypes, clip_ctx_p_ctypes]
llava_validate_embed_size.restype = c_bool
# /** build an image embed from image file bytes */
# LLAVA_API struct llava_image_embed * llava_image_embed_make_with_bytes(struct clip_ctx * ctx_clip, int n_threads, const unsigned char * image_bytes, int image_bytes_length);
def llava_image_embed_make_with_bytes(ctx_clip: clip_ctx_p, n_threads: Union[c_int, int], image_bytes: bytes, image_bytes_length: Union[c_int, int]) -> "_Pointer[llava_image_embed]":
return _libllava.llava_image_embed_make_with_bytes(ctx_clip, n_threads, image_bytes, image_bytes_length)
def llava_image_embed_make_with_bytes(ctx_clip: clip_ctx_p, n_threads: Union[c_int, int], image_bytes: bytes, image_bytes_length: Union[c_int, int], /) -> "_Pointer[llava_image_embed]":
...
_libllava.llava_image_embed_make_with_bytes.argtypes = [clip_ctx_p, c_int, POINTER(c_uint8), c_int]
_libllava.llava_image_embed_make_with_bytes.restype = POINTER(llava_image_embed)
llava_image_embed_make_with_bytes = _libllava.llava_image_embed_make_with_bytes
llava_image_embed_make_with_bytes.argtypes = [clip_ctx_p_ctypes, c_int, POINTER(c_uint8), c_int]
llava_image_embed_make_with_bytes.restype = POINTER(llava_image_embed)
# /** build an image embed from a path to an image filename */
# LLAVA_API struct llava_image_embed * llava_image_embed_make_with_filename(struct clip_ctx * ctx_clip, int n_threads, const char * image_path);
def llava_image_embed_make_with_filename(ctx_clip: clip_ctx_p, n_threads: Union[c_int, int], image_path: bytes) -> "_Pointer[llava_image_embed]":
return _libllava.llava_image_embed_make_with_filename(ctx_clip, n_threads, image_path)
def llava_image_embed_make_with_filename(ctx_clip: clip_ctx_p, n_threads: Union[c_int, int], image_path: bytes, /) -> "_Pointer[llava_image_embed]":
...
_libllava.llava_image_embed_make_with_filename.argtypes = [clip_ctx_p, c_int, c_char_p]
_libllava.llava_image_embed_make_with_filename.restype = POINTER(llava_image_embed)
llava_image_embed_make_with_filename = _libllava.llava_image_embed_make_with_filename
llava_image_embed_make_with_filename.argtypes = [clip_ctx_p_ctypes, c_int, c_char_p]
llava_image_embed_make_with_filename.restype = POINTER(llava_image_embed)
# LLAVA_API void llava_image_embed_free(struct llava_image_embed * embed);
# /** free an embedding made with llava_image_embed_make_* */
def llava_image_embed_free(embed: "_Pointer[llava_image_embed]"):
return _libllava.llava_image_embed_free(embed)
def llava_image_embed_free(embed: "_Pointer[llava_image_embed]", /):
...
_libllava.llava_image_embed_free.argtypes = [POINTER(llava_image_embed)]
_libllava.llava_image_embed_free.restype = None
llava_image_embed_free = _libllava.llava_image_embed_free
llava_image_embed_free.argtypes = [POINTER(llava_image_embed)]
llava_image_embed_free.restype = None
# /** write the image represented by embed into the llama context with batch size n_batch, starting at context pos n_past. on completion, n_past points to the next position in the context after the image embed. */
# LLAVA_API bool llava_eval_image_embed(struct llama_context * ctx_llama, const struct llava_image_embed * embed, int n_batch, int * n_past);
def llava_eval_image_embed(ctx_llama: llama_cpp.llama_context_p, embed: "_Pointer[llava_image_embed]", n_batch: Union[c_int, int], n_past: "_Pointer[c_int]") -> bool:
return _libllava.llava_eval_image_embed(ctx_llama, embed, n_batch, n_past)
def llava_eval_image_embed(ctx_llama: llama_cpp.llama_context_p, embed: "_Pointer[llava_image_embed]", n_batch: Union[c_int, int], n_past: "_Pointer[c_int]", /) -> bool:
...
_libllava.llava_eval_image_embed.argtypes = [llama_cpp.llama_context_p, POINTER(llava_image_embed), c_int, POINTER(c_int)]
_libllava.llava_eval_image_embed.restype = c_bool
llava_eval_image_embed = _libllava.llava_eval_image_embed
llava_eval_image_embed.argtypes = [llama_cpp.llama_context_p_ctypes, POINTER(llava_image_embed), c_int, POINTER(c_int)]
llava_eval_image_embed.restype = c_bool
################################################
@ -148,16 +148,18 @@ _libllava.llava_eval_image_embed.restype = c_bool
# /** load mmproj model */
# CLIP_API struct clip_ctx * clip_model_load (const char * fname, int verbosity);
def clip_model_load(fname: bytes, verbosity: Union[c_int, int]) -> clip_ctx_p:
return _libllava.clip_model_load(fname, verbosity)
def clip_model_load(fname: bytes, verbosity: Union[c_int, int], /) -> Optional[clip_ctx_p]:
...
_libllava.clip_model_load.argtypes = [c_char_p, c_int]
_libllava.clip_model_load.restype = clip_ctx_p
clip_model_load = _libllava.clip_model_load
clip_model_load.argtypes = [c_char_p, c_int]
clip_model_load.restype = clip_ctx_p_ctypes
# /** free mmproj model */
# CLIP_API void clip_free(struct clip_ctx * ctx);
def clip_free(ctx: clip_ctx_p):
return _libllava.clip_free(ctx)
def clip_free(ctx: clip_ctx_p, /):
...
_libllava.clip_free.argtypes = [clip_ctx_p]
_libllava.clip_free.restype = None
clip_free = _libllava.clip_free
clip_free.argtypes = [clip_ctx_p_ctypes]
clip_free.restype = None

View file

@ -72,4 +72,4 @@ Documentation = "https://llama-cpp-python.readthedocs.io/en/latest/"
Changelog = "https://llama-cpp-python.readthedocs.io/en/latest/changelog/"
[tool.pytest.ini_options]
addopts = "--ignore=vendor"
testpaths = "tests"

View file

@ -54,7 +54,7 @@ def mock_llama(monkeypatch):
output_tokens = llama.tokenize(
output_text.encode("utf-8"), add_bos=True, special=True
)
logits = (llama_cpp.c_float * (n_vocab * n_ctx))(-100.0)
logits = (ctypes.c_float * (n_vocab * n_ctx))(-100.0)
for i in range(n_ctx):
output_idx = i + 1 # logits for first tokens predict second token
if output_idx < len(output_tokens):
@ -90,9 +90,9 @@ def mock_llama(monkeypatch):
assert n > 0, "mock_llama_decode not called"
assert last_n_tokens > 0, "mock_llama_decode not called"
# Return view of logits for last_n_tokens
return (llama_cpp.c_float * (last_n_tokens * n_vocab)).from_address(
return (ctypes.c_float * (last_n_tokens * n_vocab)).from_address(
ctypes.addressof(logits)
+ (n - last_n_tokens) * n_vocab * ctypes.sizeof(llama_cpp.c_float)
+ (n - last_n_tokens) * n_vocab * ctypes.sizeof(ctypes.c_float)
)
monkeypatch.setattr("llama_cpp.llama_cpp.llama_decode", mock_decode)

2
vendor/llama.cpp vendored

@ -1 +1 @@
Subproject commit f53119cec4f073b6d214195ecbe1fad3abdf2b34
Subproject commit 15499eb94227401bdc8875da6eb85c15d37068f7