diff --git a/.gitignore b/.gitignore
index 51f3572..9d68dbc 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,5 @@
+*.local
+
.python-version
.vscode/
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 2c8ff8b..c539ade 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,29 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased]
+## [0.2.48]
+
+- feat: Update llama.cpp to ggerganov/llama.cpp@15499eb94227401bdc8875da6eb85c15d37068f7
+- feat: Add Google's Gemma formatting via chat_format="gemma" by @alvarobartt in #1210
+- feat: support minItems/maxItems in JSON grammar converter by @nopperl in 3921e10770996d95a9eb22c8248bacef39f69365
+- fix: Update from_pretrained defaults to match hf_hub_download and pull to local cache folder by @abetlen in e6d6260a91b7831733f7d1f73c7af46a3e8185ed
+- fix: Raise exceptions when llama model or context fails to load by @abetlen in dd22010e85265ae840c76ec835d67a29ed852722
+- docs: Update README.md to fix pip install llama cpp server by @audip in #1187
+
+## [0.2.47]
+
+- feat: Update llama.cpp to ggerganov/llama.cpp@973053d8b0d04809836b3339a50f68d9c842de90
+
+## [0.2.46]
+
+- feat: Update llama.cpp to ggerganov/llama.cpp@ba2135ccae7462470b3865c6e41d2e1d734eac05
+- feat: Pull models directly from huggingface by @abetlen in #1206
+- feat(low-level-api): Improve API static type-safety and performance. Low level api functions are positional args only now. by @abetlen in #1205
+
+## [0.2.45]
+
+- feat: Update llama.cpp to ggerganov/llama.cpp@89febfed9322c8849520dc63c93ee4f5fd72556e
+
## [0.2.44]
- feat: Update llama.cpp to ggerganov/llama.cpp@4524290e87b8e107cc2b56e1251751546f4b9051
diff --git a/Makefile b/Makefile
index e2ce4d0..4ae0110 100644
--- a/Makefile
+++ b/Makefile
@@ -12,6 +12,9 @@ deps:
build:
python3 -m pip install --verbose -e .
+build.debug:
+ CMAKE_ARGS="-DCMAKE_BUILD_TYPE=Debug" python3 -m pip install --verbose --config-settings=cmake.verbose=true --config-settings=logging.level=INFO --config-settings=install.strip=false --editable .
+
build.cuda:
CMAKE_ARGS="-DLLAMA_CUBLAS=on" python3 -m pip install --verbose -e .
diff --git a/README.md b/README.md
index 7da8e5f..ccd66ea 100644
--- a/README.md
+++ b/README.md
@@ -12,60 +12,94 @@ This package provides:
- Low-level access to C API via `ctypes` interface.
- High-level Python API for text completion
- - OpenAI-like API
- - [LangChain compatibility](https://python.langchain.com/docs/integrations/llms/llamacpp)
- - [LlamaIndex compatibility](https://docs.llamaindex.ai/en/stable/examples/llm/llama_2_llama_cpp.html)
+ - OpenAI-like API
+ - [LangChain compatibility](https://python.langchain.com/docs/integrations/llms/llamacpp)
+ - [LlamaIndex compatibility](https://docs.llamaindex.ai/en/stable/examples/llm/llama_2_llama_cpp.html)
- OpenAI compatible web server
- - [Local Copilot replacement](https://llama-cpp-python.readthedocs.io/en/latest/server/#code-completion)
- - [Function Calling support](https://llama-cpp-python.readthedocs.io/en/latest/server/#function-calling)
- - [Vision API support](https://llama-cpp-python.readthedocs.io/en/latest/server/#multimodal-models)
- - [Multiple Models](https://llama-cpp-python.readthedocs.io/en/latest/server/#configuration-and-multi-model-support)
+ - [Local Copilot replacement](https://llama-cpp-python.readthedocs.io/en/latest/server/#code-completion)
+ - [Function Calling support](https://llama-cpp-python.readthedocs.io/en/latest/server/#function-calling)
+ - [Vision API support](https://llama-cpp-python.readthedocs.io/en/latest/server/#multimodal-models)
+ - [Multiple Models](https://llama-cpp-python.readthedocs.io/en/latest/server/#configuration-and-multi-model-support)
Documentation is available at [https://llama-cpp-python.readthedocs.io/en/latest](https://llama-cpp-python.readthedocs.io/en/latest).
## Installation
-`llama-cpp-python` can be installed directly from PyPI as a source distribution by running:
+Requirements:
+
+ - Python 3.8+
+ - C compiler
+ - Linux: gcc or clang
+ - Windows: Visual Studio or MinGW
+ - MacOS: Xcode
+
+To install the package, run:
```bash
pip install llama-cpp-python
```
-This will build `llama.cpp` from source using cmake and your system's c compiler (required) and install the library alongside this python package.
+This will also build `llama.cpp` from source and install it alongside this python package.
-If you run into issues during installation add the `--verbose` flag to the `pip install` command to see the full cmake build log.
+If this fails, add `--verbose` to the `pip install` command to see the full cmake build log.
-### Installation with Specific Hardware Acceleration (BLAS, CUDA, Metal, etc)
+### Installation Configuration
-The default pip install behaviour is to build `llama.cpp` for CPU only on Linux and Windows and use Metal on MacOS.
+`llama.cpp` supports a number of hardware acceleration backends to speed up inference, as well as backend-specific options. See the [llama.cpp README](https://github.com/ggerganov/llama.cpp#build) for a full list.
-`llama.cpp` supports a number of hardware acceleration backends depending including OpenBLAS, cuBLAS, CLBlast, HIPBLAS, and Metal.
-See the [llama.cpp README](https://github.com/ggerganov/llama.cpp#build) for a full list of supported backends.
+All `llama.cpp` cmake build options can be set via the `CMAKE_ARGS` environment variable or via the `--config-settings / -C` CLI flag during installation.
-All of these backends are supported by `llama-cpp-python` and can be enabled by setting the `CMAKE_ARGS` environment variable before installing.
-
-On Linux and Mac you set the `CMAKE_ARGS` like this:
+
+#### Environment Variables
```bash
-CMAKE_ARGS="-DLLAMA_BLAS=ON -DLLAMA_BLAS_VENDOR=OpenBLAS" pip install llama-cpp-python
+# Linux and Mac
+CMAKE_ARGS="-DLLAMA_BLAS=ON -DLLAMA_BLAS_VENDOR=OpenBLAS" \
+ pip install llama-cpp-python
```
-On Windows you can set the `CMAKE_ARGS` like this:
-
-```ps
+```powershell
+# Windows
$env:CMAKE_ARGS = "-DLLAMA_BLAS=ON -DLLAMA_BLAS_VENDOR=OpenBLAS"
pip install llama-cpp-python
```
+
-#### OpenBLAS
+
+#### CLI / requirements.txt
-To install with OpenBLAS, set the `LLAMA_BLAS and LLAMA_BLAS_VENDOR` environment variables before installing:
+The build options can also be set via the `pip install -C / --config-settings` command and saved to a `requirements.txt` file:
+
+```bash
+pip install --upgrade pip # ensure pip is up to date
+pip install llama-cpp-python \
+ -C cmake.args="-DLLAMA_BLAS=ON;-DLLAMA_BLAS_VENDOR=OpenBLAS"
+```
+
+```txt
+# requirements.txt
+
+llama-cpp-python -C cmake.args="-DLLAMA_BLAS=ON;-DLLAMA_BLAS_VENDOR=OpenBLAS"
+```
+
+### Supported Backends
+
+Below are some common backends, their build commands and any additional environment variables required.
+
+
+#### OpenBLAS (CPU)
+
+To install with OpenBLAS, set the `LLAMA_BLAS` and `LLAMA_BLAS_VENDOR` environment variables before installing:
```bash
CMAKE_ARGS="-DLLAMA_BLAS=ON -DLLAMA_BLAS_VENDOR=OpenBLAS" pip install llama-cpp-python
```
+
-#### cuBLAS
+
+#### cuBLAS (CUDA)
To install with cuBLAS, set the `LLAMA_CUBLAS=on` environment variable before installing:
@@ -73,7 +107,10 @@ To install with cuBLAS, set the `LLAMA_CUBLAS=on` environment variable before in
CMAKE_ARGS="-DLLAMA_CUBLAS=on" pip install llama-cpp-python
```
-#### Metal
+
+#### Metal
To install with Metal (MPS), set the `LLAMA_METAL=on` environment variable before installing:
@@ -81,7 +118,10 @@ To install with Metal (MPS), set the `LLAMA_METAL=on` environment variable befor
CMAKE_ARGS="-DLLAMA_METAL=on" pip install llama-cpp-python
```
-#### CLBlast
+
+#### CLBlast (OpenCL)
To install with CLBlast, set the `LLAMA_CLBLAST=on` environment variable before installing:
@@ -89,7 +129,10 @@ To install with CLBlast, set the `LLAMA_CLBLAST=on` environment variable before
CMAKE_ARGS="-DLLAMA_CLBLAST=on" pip install llama-cpp-python
```
-#### hipBLAS
+
+#### hipBLAS (ROCm)
To install with hipBLAS / ROCm support for AMD cards, set the `LLAMA_HIPBLAS=on` environment variable before installing:
@@ -97,7 +140,10 @@ To install with hipBLAS / ROCm support for AMD cards, set the `LLAMA_HIPBLAS=on`
CMAKE_ARGS="-DLLAMA_HIPBLAS=on" pip install llama-cpp-python
```
-#### Vulkan
+
+#### Vulkan
To install with Vulkan support, set the `LLAMA_VULKAN=on` environment variable before installing:
@@ -105,15 +151,20 @@ To install with Vulkan support, set the `LLAMA_VULKAN=on` environment variable b
CMAKE_ARGS="-DLLAMA_VULKAN=on" pip install llama-cpp-python
```
-#### Kompute
+
+#### Kompute
To install with Kompute support, set the `LLAMA_KOMPUTE=on` environment variable before installing:
```bash
CMAKE_ARGS="-DLLAMA_KOMPUTE=on" pip install llama-cpp-python
```
+
-#### SYCL
+
+#### SYCL
To install with SYCL support, set the `LLAMA_SYCL=on` environment variable before installing:
@@ -121,9 +172,14 @@ To install with SYCL support, set the `LLAMA_SYCL=on` environment variable befor
source /opt/intel/oneapi/setvars.sh
CMAKE_ARGS="-DLLAMA_SYCL=on -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx" pip install llama-cpp-python
```
+
+
### Windows Notes
+
+#### Error: Can't find 'nmake' or 'CMAKE_C_COMPILER'
+
If you run into issues where it complains it can't find `'nmake'` `'?'` or CMAKE_C_COMPILER, you can extract w64devkit as [mentioned in llama.cpp repo](https://github.com/ggerganov/llama.cpp#openblas) and add those manually to CMAKE_ARGS before running `pip` install:
```ps
@@ -132,12 +188,14 @@ $env:CMAKE_ARGS = "-DLLAMA_OPENBLAS=on -DCMAKE_C_COMPILER=C:/w64devkit/bin/gcc.e
```
See the above instructions and set `CMAKE_ARGS` to the BLAS backend you want to use.
+
### MacOS Notes
Detailed MacOS Metal GPU install documentation is available at [docs/install/macos.md](https://llama-cpp-python.readthedocs.io/en/latest/install/macos/)
-#### M1 Mac Performance Issue
+
+#### M1 Mac Performance Issue
Note: If you are using Apple Silicon (M1) Mac, make sure you have installed a version of Python that supports arm64 architecture. For example:
@@ -147,24 +205,21 @@ bash Miniforge3-MacOSX-arm64.sh
```
Otherwise, while installing it will build the llama.cpp x86 version which will be 10x slower on Apple Silicon (M1) Mac.
+
-#### M Series Mac Error: `(mach-o file, but is an incompatible architecture (have 'x86_64', need 'arm64'))`
+
+#### M Series Mac Error: `(mach-o file, but is an incompatible architecture (have 'x86_64', need 'arm64'))`
Try installing with
```bash
CMAKE_ARGS="-DCMAKE_OSX_ARCHITECTURES=arm64 -DCMAKE_APPLE_SILICON_PROCESSOR=arm64 -DLLAMA_METAL=on" pip install --upgrade --verbose --force-reinstall --no-cache-dir llama-cpp-python
```
+
### Upgrading and Reinstalling
-To upgrade or rebuild `llama-cpp-python` add the following flags to ensure that the package is rebuilt correctly:
-
-```bash
-pip install llama-cpp-python --upgrade --force-reinstall --no-cache-dir
-```
-
-This will ensure that all source files are re-built with the most recently set `CMAKE_ARGS` flags.
+To upgrade and rebuild `llama-cpp-python`, add the `--upgrade --force-reinstall --no-cache-dir` flags to the `pip install` command to ensure the package is rebuilt from source.
## High-level API
@@ -212,6 +267,21 @@ Below is a short example demonstrating how to use the high-level API to for basi
Text completion is available through the [`__call__`](https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.__call__) and [`create_completion`](https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.create_completion) methods of the [`Llama`](https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama) class.
+### Pulling models from Hugging Face Hub
+
+You can download `Llama` models in `gguf` format directly from Hugging Face using the [`from_pretrained`](https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.from_pretrained) method.
+You'll need to install the `huggingface-hub` package to use this feature (`pip install huggingface-hub`).
+
+```python
+llm = Llama.from_pretrained(
+ repo_id="Qwen/Qwen1.5-0.5B-Chat-GGUF",
+ filename="*q8_0.gguf",
+ verbose=False
+)
+```
+
+By default, [`from_pretrained`](https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.from_pretrained) will download the model to the Hugging Face cache directory; you can then manage installed model files with the [`huggingface-cli`](https://huggingface.co/docs/huggingface_hub/en/guides/cli) tool.
+
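+If you prefer to keep the weights next to your project instead of in the shared cache, [`from_pretrained`](https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.from_pretrained) also accepts a `local_dir` argument. A minimal sketch (the `./models` path is just an example):
+
+```python
+from llama_cpp import Llama
+
+# Downloads the matching gguf file into ./models instead of the Hugging Face cache
+llm = Llama.from_pretrained(
+    repo_id="Qwen/Qwen1.5-0.5B-Chat-GGUF",
+    filename="*q8_0.gguf",
+    local_dir="./models",
+    verbose=False,
+)
+```
+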
### Chat Completion
The high-level API also provides a simple interface for chat completion.
@@ -237,13 +307,16 @@ Note that `chat_format` option must be set for the particular model you are usin
Chat completion is available through the [`create_chat_completion`](https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.create_chat_completion) method of the [`Llama`](https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama) class.
+For OpenAI API v1 compatibility, you can use the [`create_chat_completion_openai_v1`](https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.create_chat_completion_openai_v1) method, which returns pydantic models instead of dicts.
+
+
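+A minimal sketch of the OpenAI v1 style, assuming the `openai` package is installed (it provides the pydantic response models); the model path, chat format, and messages below are placeholders:
+
+```python
+from llama_cpp import Llama
+
+llm = Llama(model_path="path/to/model.gguf", chat_format="chatml")
+response = llm.create_chat_completion_openai_v1(
+    messages=[
+        {"role": "system", "content": "You are a helpful assistant."},
+        {"role": "user", "content": "Name the planets in the solar system."},
+    ]
+)
+# `response` is a pydantic ChatCompletion object rather than a dict
+print(response.choices[0].message.content)
+```
+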
### JSON and JSON Schema Mode
-If you want to constrain chat responses to only valid JSON or a specific JSON Schema you can use the `response_format` argument to the `create_chat_completion` method.
+To constrain chat responses to only valid JSON or a specific JSON Schema, use the `response_format` argument in [`create_chat_completion`](https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.create_chat_completion).
#### JSON Mode
-The following example will constrain the response to be valid JSON.
+The following example will constrain the response to valid JSON strings only.
```python
>>> from llama_cpp import Llama
@@ -265,7 +338,7 @@ The following example will constrain the response to be valid JSON.
#### JSON Schema Mode
-To constrain the response to a specific JSON Schema, you can use the `schema` property of the `response_format` argument.
+To further constrain the response to a specific JSON Schema, add the schema to the `schema` property of the `response_format` argument.
```python
>>> from llama_cpp import Llama
@@ -400,7 +473,7 @@ llama = Llama(
### Embeddings
-`llama-cpp-python` supports generating embeddings from the text.
+To generate text embeddings, use [`create_embedding`](https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.create_embedding).
```python
import llama_cpp
@@ -409,7 +482,7 @@ llm = llama_cpp.Llama(model_path="path/to/model.gguf", embeddings=True)
embeddings = llm.create_embedding("Hello, world!")
-# or batched
+# or create multiple embeddings at once
embeddings = llm.create_embedding(["Hello, world!", "Goodbye, world!"])
```
@@ -432,14 +505,14 @@ This allows you to use llama.cpp compatible models with any OpenAI compatible cl
To install the server package and get started:
```bash
-pip install llama-cpp-python[server]
+pip install 'llama-cpp-python[server]'
python3 -m llama_cpp.server --model models/7B/llama-model.gguf
```
Similar to Hardware Acceleration section above, you can also install with GPU (cuBLAS) support like this:
```bash
-CMAKE_ARGS="-DLLAMA_CUBLAS=on" FORCE_CMAKE=1 pip install llama-cpp-python[server]
+CMAKE_ARGS="-DLLAMA_CUBLAS=on" FORCE_CMAKE=1 pip install 'llama-cpp-python[server]'
python3 -m llama_cpp.server --model models/7B/llama-model.gguf --n_gpu_layers 35
```
@@ -486,7 +559,7 @@ Below is a short example demonstrating how to use the low-level API to tokenize
```python
>>> import llama_cpp
>>> import ctypes
->>> llama_cpp.llama_backend_init(numa=False) # Must be called once at the start of each program
+>>> llama_cpp.llama_backend_init(False) # Must be called once at the start of each program
>>> params = llama_cpp.llama_context_default_params()
# use bytes for char * params
>>> model = llama_cpp.llama_load_model_from_file(b"./models/7b/llama-model.gguf", params)
@@ -494,7 +567,7 @@ Below is a short example demonstrating how to use the low-level API to tokenize
>>> max_tokens = params.n_ctx
# use ctypes arrays for array params
>>> tokens = (llama_cpp.llama_token * int(max_tokens))()
->>> n_tokens = llama_cpp.llama_tokenize(ctx, b"Q: Name the planets in the solar system? A: ", tokens, max_tokens, add_bos=llama_cpp.c_bool(True))
+>>> n_tokens = llama_cpp.llama_tokenize(ctx, b"Q: Name the planets in the solar system? A: ", tokens, max_tokens, llama_cpp.c_bool(True))
>>> llama_cpp.llama_free(ctx)
```
diff --git a/docs/api-reference.md b/docs/api-reference.md
index 562410f..ab51ef7 100644
--- a/docs/api-reference.md
+++ b/docs/api-reference.md
@@ -21,11 +21,13 @@ High-level Python bindings for llama.cpp.
- create_completion
- __call__
- create_chat_completion
+ - create_chat_completion_openai_v1
- set_cache
- save_state
- load_state
- token_bos
- token_eos
+ - from_pretrained
show_root_heading: true
::: llama_cpp.LlamaGrammar
diff --git a/llama_cpp/__init__.py b/llama_cpp/__init__.py
index 000e4dd..8c4f400 100644
--- a/llama_cpp/__init__.py
+++ b/llama_cpp/__init__.py
@@ -1,4 +1,4 @@
from .llama_cpp import *
from .llama import *
-__version__ = "0.2.44"
\ No newline at end of file
+__version__ = "0.2.48"
\ No newline at end of file
diff --git a/llama_cpp/_internals.py b/llama_cpp/_internals.py
index c60fdff..98ad6d7 100644
--- a/llama_cpp/_internals.py
+++ b/llama_cpp/_internals.py
@@ -51,6 +51,9 @@ class _LlamaModel:
self.path_model.encode("utf-8"), self.params
)
+ if self.model is None:
+ raise ValueError(f"Failed to load model from file: {path_model}")
+
def __del__(self):
if self.model is not None and self._llama_free_model is not None:
self._llama_free_model(self.model)
@@ -79,7 +82,7 @@ class _LlamaModel:
def desc(self) -> str:
assert self.model is not None
buf = ctypes.create_string_buffer(1024)
- llama_cpp.llama_model_desc(self.model, buf, 1024) # type: ignore
+ llama_cpp.llama_model_desc(self.model, buf, 1024)
return buf.value.decode("utf-8")
def size(self) -> int:
@@ -108,7 +111,7 @@ class _LlamaModel:
scale,
path_base_model.encode("utf-8")
if path_base_model is not None
- else llama_cpp.c_char_p(0),
+ else ctypes.c_char_p(0),
n_threads,
)
@@ -181,7 +184,7 @@ class _LlamaModel:
def token_to_piece(self, token: int) -> bytes:
assert self.model is not None
buf = ctypes.create_string_buffer(32)
- llama_cpp.llama_token_to_piece(self.model, token, buf, 32) # type: ignore
+ llama_cpp.llama_token_to_piece(self.model, token, buf, 32)
return bytes(buf)
def detokenize(self, tokens: List[int]) -> bytes:
@@ -258,6 +261,9 @@ class _LlamaContext:
self.model.model, self.params
)
+ if self.ctx is None:
+ raise ValueError("Failed to create llama_context")
+
def __del__(self):
if self.ctx is not None and self._llama_free is not None:
self._llama_free(self.ctx)
@@ -303,8 +309,8 @@ class _LlamaContext:
assert self.ctx is not None
assert batch.batch is not None
return_code = llama_cpp.llama_decode(
- ctx=self.ctx,
- batch=batch.batch,
+ self.ctx,
+ batch.batch,
)
if return_code != 0:
raise RuntimeError(f"llama_decode returned {return_code}")
@@ -343,7 +349,7 @@ class _LlamaContext:
assert self.ctx is not None
llama_cpp.llama_sample_repetition_penalties(
self.ctx,
- ctypes.byref(candidates.candidates), # type: ignore
+ llama_cpp.byref(candidates.candidates),
last_tokens_data,
penalty_last_n,
penalty_repeat,
@@ -361,7 +367,7 @@ class _LlamaContext:
assert guidance_ctx.ctx is not None
llama_cpp.llama_sample_classifier_free_guidance(
self.ctx,
- ctypes.byref(candidates.candidates), # type: ignore
+ llama_cpp.byref(candidates.candidates),
guidance_ctx.ctx,
scale,
)
@@ -370,25 +376,25 @@ class _LlamaContext:
assert self.ctx is not None
llama_cpp.llama_sample_softmax(
self.ctx,
- ctypes.byref(candidates.candidates), # type: ignore
+ llama_cpp.byref(candidates.candidates),
)
def sample_top_k(self, candidates: "_LlamaTokenDataArray", k: int, min_keep: int):
assert self.ctx is not None
llama_cpp.llama_sample_top_k(
- self.ctx, ctypes.byref(candidates.candidates), k, min_keep # type: ignore
+ self.ctx, llama_cpp.byref(candidates.candidates), k, min_keep
)
def sample_top_p(self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int):
assert self.ctx is not None
llama_cpp.llama_sample_top_p(
- self.ctx, ctypes.byref(candidates.candidates), p, min_keep # type: ignore
+ self.ctx, llama_cpp.byref(candidates.candidates), p, min_keep
)
def sample_min_p(self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int):
assert self.ctx is not None
llama_cpp.llama_sample_min_p(
- self.ctx, ctypes.byref(candidates.candidates), p, min_keep # type: ignore
+ self.ctx, llama_cpp.byref(candidates.candidates), p, min_keep
)
def sample_tail_free(
@@ -396,7 +402,7 @@ class _LlamaContext:
):
assert self.ctx is not None
llama_cpp.llama_sample_tail_free(
- self.ctx, ctypes.byref(candidates.candidates), z, min_keep # type: ignore
+ self.ctx, llama_cpp.byref(candidates.candidates), z, min_keep
)
def sample_typical(
@@ -404,13 +410,13 @@ class _LlamaContext:
):
assert self.ctx is not None
llama_cpp.llama_sample_typical(
- self.ctx, ctypes.byref(candidates.candidates), p, min_keep # type: ignore
+ self.ctx, llama_cpp.byref(candidates.candidates), p, min_keep
)
def sample_temp(self, candidates: "_LlamaTokenDataArray", temp: float):
assert self.ctx is not None
llama_cpp.llama_sample_temp(
- self.ctx, ctypes.byref(candidates.candidates), temp # type: ignore
+ self.ctx, llama_cpp.byref(candidates.candidates), temp
)
def sample_grammar(self, candidates: "_LlamaTokenDataArray", grammar: LlamaGrammar):
@@ -418,7 +424,7 @@ class _LlamaContext:
assert grammar.grammar is not None
llama_cpp.llama_sample_grammar(
self.ctx,
- ctypes.byref(candidates.candidates), # type: ignore
+ llama_cpp.byref(candidates.candidates),
grammar.grammar,
)
@@ -428,12 +434,12 @@ class _LlamaContext:
tau: float,
eta: float,
m: int,
- mu: ctypes._Pointer[ctypes.c_float], # type: ignore
+ mu: llama_cpp.CtypesPointerOrRef[ctypes.c_float],
) -> int:
assert self.ctx is not None
return llama_cpp.llama_sample_token_mirostat(
self.ctx,
- ctypes.byref(candidates.candidates), # type: ignore
+ llama_cpp.byref(candidates.candidates),
tau,
eta,
m,
@@ -441,12 +447,12 @@ class _LlamaContext:
)
def sample_token_mirostat_v2(
- self, candidates: "_LlamaTokenDataArray", tau: float, eta: float, mu: ctypes._Pointer[ctypes.c_float] # type: ignore
+ self, candidates: "_LlamaTokenDataArray", tau: float, eta: float, mu: llama_cpp.CtypesPointerOrRef[ctypes.c_float]
) -> int:
assert self.ctx is not None
return llama_cpp.llama_sample_token_mirostat_v2(
self.ctx,
- ctypes.byref(candidates.candidates), # type: ignore
+ llama_cpp.byref(candidates.candidates),
tau,
eta,
mu,
@@ -456,14 +462,14 @@ class _LlamaContext:
assert self.ctx is not None
return llama_cpp.llama_sample_token_greedy(
self.ctx,
- ctypes.byref(candidates.candidates), # type: ignore
+ llama_cpp.byref(candidates.candidates),
)
def sample_token(self, candidates: "_LlamaTokenDataArray") -> int:
assert self.ctx is not None
return llama_cpp.llama_sample_token(
self.ctx,
- ctypes.byref(candidates.candidates), # type: ignore
+ llama_cpp.byref(candidates.candidates),
)
# Grammar
@@ -493,7 +499,7 @@ class _LlamaBatch:
def __init__(
self, *, n_tokens: int, embd: int, n_seq_max: int, verbose: bool = True
):
- self.n_tokens = n_tokens
+ self._n_tokens = n_tokens
self.embd = embd
self.n_seq_max = n_seq_max
self.verbose = verbose
@@ -502,7 +508,7 @@ class _LlamaBatch:
self.batch = None
self.batch = llama_cpp.llama_batch_init(
- self.n_tokens, self.embd, self.n_seq_max
+ self._n_tokens, self.embd, self.n_seq_max
)
def __del__(self):
@@ -560,7 +566,7 @@ class _LlamaTokenDataArray:
size=self.n_vocab,
sorted=False,
)
- self.default_candidates_data_id = np.arange(self.n_vocab, dtype=np.intc)
+ self.default_candidates_data_id = np.arange(self.n_vocab, dtype=np.intc) # type: ignore
self.default_candidates_data_p = np.zeros(self.n_vocab, dtype=np.single)
def copy_logits(self, logits: npt.NDArray[np.single]):
@@ -570,12 +576,13 @@ class _LlamaTokenDataArray:
self.candidates.data = self.candidates_data.ctypes.data_as(
llama_cpp.llama_token_data_p
)
- self.candidates.sorted = llama_cpp.c_bool(False)
- self.candidates.size = llama_cpp.c_size_t(self.n_vocab)
+ self.candidates.sorted = ctypes.c_bool(False)
+ self.candidates.size = ctypes.c_size_t(self.n_vocab)
# Python wrappers over common/common
def _tokenize(model: _LlamaModel, text: str, add_bos: bool, special: bool) -> list[int]:
+ assert model.model is not None
n_tokens = len(text) + 1 if add_bos else len(text)
result = (llama_cpp.llama_token * n_tokens)()
n_tokens = llama_cpp.llama_tokenize(
@@ -747,7 +754,7 @@ class _LlamaSamplingContext:
ctx_main.sample_repetition_penalties(
token_data_array,
# TODO: Only create this once
- (llama_cpp.llama_token * len(self.prev))(*self.prev), # type: ignore
+ (llama_cpp.llama_token * len(self.prev))(*self.prev),
self.params.penalty_last_n,
self.params.penalty_repeat,
self.params.penalty_freq,
diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index 30cab0a..1226545 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -4,6 +4,8 @@ import os
import sys
import uuid
import time
+import json
+import fnmatch
import multiprocessing
from typing import (
List,
@@ -16,6 +18,7 @@ from typing import (
Callable,
)
from collections import deque
+from pathlib import Path
import ctypes
@@ -29,10 +32,7 @@ from .llama_cache import (
LlamaDiskCache, # type: ignore
LlamaRAMCache, # type: ignore
)
-from .llama_tokenizer import (
- BaseLlamaTokenizer,
- LlamaTokenizer
-)
+from .llama_tokenizer import BaseLlamaTokenizer, LlamaTokenizer
import llama_cpp.llama_cpp as llama_cpp
import llama_cpp.llama_chat_format as llama_chat_format
@@ -50,9 +50,7 @@ from ._internals import (
_LlamaSamplingContext, # type: ignore
)
from ._logger import set_verbose
-from ._utils import (
- suppress_stdout_stderr
-)
+from ._utils import suppress_stdout_stderr
class Llama:
@@ -189,7 +187,11 @@ class Llama:
Llama.__backend_initialized = True
if isinstance(numa, bool):
- self.numa = llama_cpp.GGML_NUMA_STRATEGY_DISTRIBUTE if numa else llama_cpp.GGML_NUMA_STRATEGY_DISABLED
+ self.numa = (
+ llama_cpp.GGML_NUMA_STRATEGY_DISTRIBUTE
+ if numa
+ else llama_cpp.GGML_NUMA_STRATEGY_DISABLED
+ )
else:
self.numa = numa
@@ -246,9 +248,9 @@ class Llama:
else:
raise ValueError(f"Unknown value type for {k}: {v}")
- self._kv_overrides_array[
- -1
- ].key = b"\0" # ensure sentinel element is zeroed
+ self._kv_overrides_array[-1].key = (
+ b"\0" # ensure sentinel element is zeroed
+ )
self.model_params.kv_overrides = self._kv_overrides_array
self.n_batch = min(n_ctx, n_batch) # ???
@@ -256,7 +258,7 @@ class Llama:
self.n_threads_batch = n_threads_batch or max(
multiprocessing.cpu_count() // 2, 1
)
-
+
# Context Params
self.context_params = llama_cpp.llama_context_default_params()
self.context_params.seed = seed
@@ -289,7 +291,9 @@ class Llama:
)
self.context_params.yarn_orig_ctx = yarn_orig_ctx if yarn_orig_ctx != 0 else 0
self.context_params.mul_mat_q = mul_mat_q
- self.context_params.logits_all = logits_all if draft_model is None else True # Must be set to True for speculative decoding
+ self.context_params.logits_all = (
+ logits_all if draft_model is None else True
+ ) # Must be set to True for speculative decoding
self.context_params.embedding = embedding
self.context_params.offload_kqv = offload_kqv
@@ -379,8 +383,14 @@ class Llama:
if self.verbose:
print(f"Model metadata: {self.metadata}", file=sys.stderr)
- if self.chat_format is None and self.chat_handler is None and "tokenizer.chat_template" in self.metadata:
- chat_format = llama_chat_format.guess_chat_format_from_gguf_metadata(self.metadata)
+ if (
+ self.chat_format is None
+ and self.chat_handler is None
+ and "tokenizer.chat_template" in self.metadata
+ ):
+ chat_format = llama_chat_format.guess_chat_format_from_gguf_metadata(
+ self.metadata
+ )
if chat_format is not None:
self.chat_format = chat_format
@@ -406,9 +416,7 @@ class Llama:
print(f"Using chat bos_token: {bos_token}", file=sys.stderr)
self.chat_handler = llama_chat_format.Jinja2ChatFormatter(
- template=template,
- eos_token=eos_token,
- bos_token=bos_token
+ template=template, eos_token=eos_token, bos_token=bos_token
).to_chat_handler()
if self.chat_format is None and self.chat_handler is None:
@@ -459,7 +467,9 @@ class Llama:
"""
return self.tokenizer_.tokenize(text, add_bos, special)
- def detokenize(self, tokens: List[int], prev_tokens: Optional[List[int]] = None) -> bytes:
+ def detokenize(
+ self, tokens: List[int], prev_tokens: Optional[List[int]] = None
+ ) -> bytes:
"""Detokenize a list of tokens.
Args:
@@ -565,7 +575,7 @@ class Llama:
logits[:] = (
logits_processor(self._input_ids, logits)
if idx is None
- else logits_processor(self._input_ids[:idx + 1], logits)
+ else logits_processor(self._input_ids[: idx + 1], logits)
)
sampling_params = _LlamaSamplingParams(
@@ -707,7 +717,9 @@ class Llama:
if self.draft_model is not None:
self.input_ids[self.n_tokens : self.n_tokens + len(tokens)] = tokens
- draft_tokens = self.draft_model(self.input_ids[:self.n_tokens + len(tokens)])
+ draft_tokens = self.draft_model(
+ self.input_ids[: self.n_tokens + len(tokens)]
+ )
tokens.extend(
draft_tokens.astype(int)[
: self._n_ctx - self.n_tokens - len(tokens)
@@ -792,6 +804,7 @@ class Llama:
# decode and fetch embeddings
data: List[List[float]] = []
+
def decode_batch(n_seq: int):
assert self._ctx.ctx is not None
llama_cpp.llama_kv_cache_clear(self._ctx.ctx)
@@ -800,9 +813,9 @@ class Llama:
# store embeddings
for i in range(n_seq):
- embedding: List[float] = llama_cpp.llama_get_embeddings_ith(self._ctx.ctx, i)[
- :n_embd
- ]
+ embedding: List[float] = llama_cpp.llama_get_embeddings_ith(
+ self._ctx.ctx, i
+ )[:n_embd]
if normalize:
norm = float(np.linalg.norm(embedding))
embedding = [v / norm for v in embedding]
@@ -1669,12 +1682,13 @@ class Llama:
"""
try:
from openai.types.chat import ChatCompletion, ChatCompletionChunk
- stream = kwargs.get("stream", False) # type: ignore
+
+ stream = kwargs.get("stream", False) # type: ignore
assert isinstance(stream, bool)
if stream:
- return (ChatCompletionChunk(**chunk) for chunk in self.create_chat_completion(*args, **kwargs)) # type: ignore
+ return (ChatCompletionChunk(**chunk) for chunk in self.create_chat_completion(*args, **kwargs)) # type: ignore
else:
- return ChatCompletion(**self.create_chat_completion(*args, **kwargs)) # type: ignore
+ return ChatCompletion(**self.create_chat_completion(*args, **kwargs)) # type: ignore
except ImportError:
raise ImportError(
"To use create_chat_completion_openai_v1, you must install the openai package."
@@ -1804,7 +1818,7 @@ class Llama:
self.input_ids = state.input_ids.copy()
self.n_tokens = state.n_tokens
state_size = state.llama_state_size
- LLamaStateArrayType = llama_cpp.c_uint8 * state_size
+ LLamaStateArrayType = ctypes.c_uint8 * state_size
llama_state = LLamaStateArrayType.from_buffer_copy(state.llama_state)
if llama_cpp.llama_set_state_data(self._ctx.ctx, llama_state) != state_size:
@@ -1866,7 +1880,100 @@ class Llama:
break
return longest_prefix
+ @classmethod
+ def from_pretrained(
+ cls,
+ repo_id: str,
+ filename: Optional[str],
+ local_dir: Optional[Union[str, os.PathLike[str]]] = None,
+ local_dir_use_symlinks: Union[bool, Literal["auto"]] = "auto",
+ cache_dir: Optional[Union[str, os.PathLike[str]]] = None,
+ **kwargs: Any,
+ ) -> "Llama":
+ """Create a Llama model from a pretrained model name or path.
+ This method requires the huggingface-hub package.
+ You can install it with `pip install huggingface-hub`.
+ Args:
+ repo_id: The model repo id.
+ filename: A filename or glob pattern to match the model file in the repo.
+ local_dir: The local directory to save the model to.
+            local_dir_use_symlinks: Whether to use symlinks when downloading the model.
+            cache_dir: The Hugging Face cache directory to use when downloading the model.
+ **kwargs: Additional keyword arguments to pass to the Llama constructor.
+
+ Returns:
+ A Llama model."""
+ try:
+ from huggingface_hub import hf_hub_download, HfFileSystem
+ from huggingface_hub.utils import validate_repo_id
+ except ImportError:
+ raise ImportError(
+ "Llama.from_pretrained requires the huggingface-hub package. "
+ "You can install it with `pip install huggingface-hub`."
+ )
+
+ validate_repo_id(repo_id)
+
+ hffs = HfFileSystem()
+
+ files = [
+ file["name"] if isinstance(file, dict) else file
+ for file in hffs.ls(repo_id)
+ ]
+
+ # split each file into repo_id, subfolder, filename
+ file_list: List[str] = []
+ for file in files:
+ rel_path = Path(file).relative_to(repo_id)
+ file_list.append(str(rel_path))
+
+ matching_files = [file for file in file_list if fnmatch.fnmatch(file, filename)] # type: ignore
+
+ if len(matching_files) == 0:
+ raise ValueError(
+ f"No file found in {repo_id} that match {filename}\n\n"
+ f"Available Files:\n{json.dumps(file_list)}"
+ )
+
+ if len(matching_files) > 1:
+ raise ValueError(
+ f"Multiple files found in {repo_id} matching {filename}\n\n"
+ f"Available Files:\n{json.dumps(files)}"
+ )
+
+ (matching_file,) = matching_files
+
+ subfolder = str(Path(matching_file).parent)
+ filename = Path(matching_file).name
+
+ # download the file
+ hf_hub_download(
+ repo_id=repo_id,
+ filename=filename,
+ subfolder=subfolder,
+ local_dir=local_dir,
+ local_dir_use_symlinks=local_dir_use_symlinks,
+ cache_dir=cache_dir,
+ )
+
+ if local_dir is None:
+ model_path = hf_hub_download(
+ repo_id=repo_id,
+ filename=filename,
+ subfolder=subfolder,
+ local_dir=local_dir,
+ local_dir_use_symlinks=local_dir_use_symlinks,
+ cache_dir=cache_dir,
+ local_files_only=True,
+            )
+ else:
+ model_path = os.path.join(local_dir, filename)
+
+ return cls(
+ model_path=model_path,
+ **kwargs,
+ )
class LlamaState:
diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py
index 8dd0ddf..16bccb9 100644
--- a/llama_cpp/llama_chat_format.py
+++ b/llama_cpp/llama_chat_format.py
@@ -14,6 +14,7 @@ import llama_cpp.llama as llama
import llama_cpp.llama_types as llama_types
import llama_cpp.llama_grammar as llama_grammar
+from ._logger import logger
from ._utils import suppress_stdout_stderr, Singleton
### Common Chat Templates and Special Tokens ###
@@ -993,6 +994,26 @@ def format_saiga(
return ChatFormatterResponse(prompt=_prompt.strip())
+# Chat format for Google's Gemma models, see more details and available models:
+# https://huggingface.co/collections/google/gemma-release-65d5efbccdbb8c4202ec078b
+@register_chat_format("gemma")
+def format_gemma(
+ messages: List[llama_types.ChatCompletionRequestMessage],
+ **kwargs: Any,
+) -> ChatFormatterResponse:
+ system_message = _get_system_message(messages)
+ if system_message is not None and system_message != "":
+ logger.debug(
+ "`role='system'` messages are not allowed on Google's Gemma models."
+ )
+ _roles = dict(user="user\n", assistant="model\n")
+ _sep = "\n"
+ _messages = _map_roles(messages, _roles)
+ _messages.append((_roles["assistant"], None))
+ _prompt = _format_no_colon_single(system_message="", messages=_messages, sep=_sep)
+ return ChatFormatterResponse(prompt=_prompt, stop=_sep)
+
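+# Usage sketch (hypothetical model file): this formatter is selected from the
+# high-level API via the name it is registered under, e.g.
+#   llama.Llama(model_path="gemma-2b-it.Q4_K_M.gguf", chat_format="gemma")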
+
# Tricky chat formats that require custom chat handlers
diff --git a/llama_cpp/llama_cpp.py b/llama_cpp/llama_cpp.py
index e946adb..f4d523b 100644
--- a/llama_cpp/llama_cpp.py
+++ b/llama_cpp/llama_cpp.py
@@ -1,27 +1,23 @@
+from __future__ import annotations
+
import sys
import os
import ctypes
-from ctypes import (
- c_bool,
- c_char_p,
- c_int,
- c_int8,
- c_int32,
- c_uint8,
- c_uint32,
- c_int64,
- c_size_t,
- c_float,
- c_double,
- c_void_p,
- POINTER,
- _Pointer, # type: ignore
- Structure,
- Union as CtypesUnion,
- Array,
-)
+import functools
import pathlib
-from typing import List, Union
+
+from typing import (
+ Any,
+ Callable,
+ List,
+ Union,
+ NewType,
+ Optional,
+ TYPE_CHECKING,
+ TypeVar,
+ Generic,
+)
+from typing_extensions import TypeAlias
# Load the library
@@ -71,7 +67,7 @@ def _load_shared_library(lib_base_name: str):
for _lib_path in _lib_paths:
if _lib_path.exists():
try:
- return ctypes.CDLL(str(_lib_path), **cdll_args)
+ return ctypes.CDLL(str(_lib_path), **cdll_args) # type: ignore
except Exception as e:
raise RuntimeError(f"Failed to load shared library '{_lib_path}': {e}")
@@ -86,14 +82,69 @@ _lib_base_name = "llama"
# Load the library
_lib = _load_shared_library(_lib_base_name)
-# Misc
-c_float_p = POINTER(c_float)
-c_uint8_p = POINTER(c_uint8)
-c_size_t_p = POINTER(c_size_t)
+
+# ctypes sane type hint helpers
+#
+# - Generic Pointer and Array types
+# - PointerOrRef type with a type hinted byref function
+#
+# NOTE: Only use these for static type checking not for runtime checks
+# no good will come of that
+
+if TYPE_CHECKING:
+ CtypesCData = TypeVar("CtypesCData", bound=ctypes._CData) # type: ignore
+
+ CtypesArray: TypeAlias = ctypes.Array[CtypesCData] # type: ignore
+
+ CtypesPointer: TypeAlias = ctypes._Pointer[CtypesCData] # type: ignore
+
+ CtypesVoidPointer: TypeAlias = ctypes.c_void_p
+
+ class CtypesRef(Generic[CtypesCData]):
+ pass
+
+ CtypesPointerOrRef: TypeAlias = Union[
+ CtypesPointer[CtypesCData], CtypesRef[CtypesCData]
+ ]
+
+ CtypesFuncPointer: TypeAlias = ctypes._FuncPointer # type: ignore
+
+
+def ctypes_function_for_shared_library(lib: ctypes.CDLL):
+ def ctypes_function(
+ name: str, argtypes: List[Any], restype: Any, enabled: bool = True
+ ):
+ def decorator(f: Callable[..., Any]):
+ if enabled:
+ func = getattr(lib, name)
+ func.argtypes = argtypes
+ func.restype = restype
+ functools.wraps(f)(func)
+ return func
+ else:
+ return f
+
+ return decorator
+
+ return ctypes_function
+
+
+ctypes_function = ctypes_function_for_shared_library(_lib)
+
+
+def byref(obj: CtypesCData, offset: Optional[int] = None) -> CtypesRef[CtypesCData]:
+ """Type-annotated version of ctypes.byref"""
+ ...
+
+
+byref = ctypes.byref # type: ignore
+
# from ggml-backend.h
# typedef bool (*ggml_backend_sched_eval_callback)(struct ggml_tensor * t, bool ask, void * user_data);
-ggml_backend_sched_eval_callback = ctypes.CFUNCTYPE(c_bool, c_void_p, c_bool, c_void_p)
+ggml_backend_sched_eval_callback = ctypes.CFUNCTYPE(
+ ctypes.c_bool, ctypes.c_void_p, ctypes.c_bool, ctypes.c_void_p
+)
# llama.h bindings
@@ -121,19 +172,21 @@ LLAMA_SESSION_VERSION = 4
# struct llama_model;
-llama_model_p = c_void_p
+llama_model_p = NewType("llama_model_p", int)
+llama_model_p_ctypes = ctypes.c_void_p
# struct llama_context;
-llama_context_p = c_void_p
+llama_context_p = NewType("llama_context_p", int)
+llama_context_p_ctypes = ctypes.c_void_p
# typedef int32_t llama_pos;
-llama_pos = c_int32
+llama_pos = ctypes.c_int32
# typedef int32_t llama_token;
-llama_token = c_int32
-llama_token_p = POINTER(llama_token)
+llama_token = ctypes.c_int32
+llama_token_p = ctypes.POINTER(llama_token)
# typedef int32_t llama_seq_id;
-llama_seq_id = c_int32
+llama_seq_id = ctypes.c_int32
# enum llama_vocab_type {
@@ -191,6 +244,7 @@ LLAMA_TOKEN_TYPE_BYTE = 6
# LLAMA_FTYPE_MOSTLY_Q3_K_XS = 22, // except 1d tensors
# LLAMA_FTYPE_MOSTLY_IQ3_XXS = 23, // except 1d tensors
# LLAMA_FTYPE_MOSTLY_IQ1_S = 24, // except 1d tensors
+# LLAMA_FTYPE_MOSTLY_IQ4_NL = 25, // except 1d tensors
# LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
# };
@@ -217,6 +271,7 @@ LLAMA_FTYPE_MOSTLY_Q2_K_S = 21
LLAMA_FTYPE_MOSTLY_Q3_K_XS = 22
LLAMA_FTYPE_MOSTLY_IQ3_XXS = 23
LLAMA_FTYPE_MOSTLY_IQ1_S = 24
+LLAMA_FTYPE_MOSTLY_IQ4_NL = 25
LLAMA_FTYPE_GUESSED = 1024
# enum llama_rope_scaling_type {
@@ -256,7 +311,7 @@ LLAMA_SPLIT_ROW = 2
# float logit; // log-odds of the token
# float p; // probability of the token
# } llama_token_data;
-class llama_token_data(Structure):
+class llama_token_data(ctypes.Structure):
"""Used to store token data
Attributes:
@@ -266,12 +321,12 @@ class llama_token_data(Structure):
_fields_ = [
("id", llama_token),
- ("logit", c_float),
- ("p", c_float),
+ ("logit", ctypes.c_float),
+ ("p", ctypes.c_float),
]
-llama_token_data_p = POINTER(llama_token_data)
+llama_token_data_p = ctypes.POINTER(llama_token_data)
# typedef struct llama_token_data_array {
@@ -279,7 +334,7 @@ llama_token_data_p = POINTER(llama_token_data)
# size_t size;
# bool sorted;
# } llama_token_data_array;
-class llama_token_data_array(Structure):
+class llama_token_data_array(ctypes.Structure):
"""Used to sample tokens given logits
Attributes:
@@ -289,15 +344,17 @@ class llama_token_data_array(Structure):
_fields_ = [
("data", llama_token_data_p),
- ("size", c_size_t),
- ("sorted", c_bool),
+ ("size", ctypes.c_size_t),
+ ("sorted", ctypes.c_bool),
]
-llama_token_data_array_p = POINTER(llama_token_data_array)
+llama_token_data_array_p = ctypes.POINTER(llama_token_data_array)
# typedef bool (*llama_progress_callback)(float progress, void *ctx);
-llama_progress_callback = ctypes.CFUNCTYPE(c_bool, c_float, c_void_p)
+llama_progress_callback = ctypes.CFUNCTYPE(
+ ctypes.c_bool, ctypes.c_float, ctypes.c_void_p
+)
# // Input data for llama_decode
@@ -330,7 +387,7 @@ llama_progress_callback = ctypes.CFUNCTYPE(c_bool, c_float, c_void_p)
# llama_pos all_pos_1; // used if pos == NULL
# llama_seq_id all_seq_id; // used if seq_id == NULL
# } llama_batch;
-class llama_batch(Structure):
+class llama_batch(ctypes.Structure):
"""Input data for llama_decode
A llama_batch object can contain input about one or many sequences
@@ -339,19 +396,19 @@ class llama_batch(Structure):
Attributes:
token (ctypes.Array[llama_token]): the token ids of the input (used when embd is NULL)
- embd (ctypes.Array[ctypes.c_float]): token embeddings (i.e. float vector of size n_embd) (used when token is NULL)
+        embd (ctypes.Array[ctypes.c_float]): token embeddings (i.e. float vector of size n_embd) (used when token is NULL)
pos (ctypes.Array[ctypes.Array[llama_pos]]): the positions of the respective token in the sequence
seq_id (ctypes.Array[ctypes.Array[llama_seq_id]]): the sequence to which the respective token belongs
"""
_fields_ = [
- ("n_tokens", c_int32),
- ("token", POINTER(llama_token)),
- ("embd", c_float_p),
- ("pos", POINTER(llama_pos)),
- ("n_seq_id", POINTER(c_int32)),
- ("seq_id", POINTER(POINTER(llama_seq_id))),
- ("logits", POINTER(c_int8)),
+ ("n_tokens", ctypes.c_int32),
+ ("token", ctypes.POINTER(llama_token)),
+ ("embd", ctypes.POINTER(ctypes.c_float)),
+ ("pos", ctypes.POINTER(llama_pos)),
+ ("n_seq_id", ctypes.POINTER(ctypes.c_int32)),
+ ("seq_id", ctypes.POINTER(ctypes.POINTER(llama_seq_id))),
+ ("logits", ctypes.POINTER(ctypes.c_int8)),
("all_pos_0", llama_pos),
("all_pos_1", llama_pos),
("all_seq_id", llama_seq_id),
@@ -377,18 +434,18 @@ LLAMA_KV_OVERRIDE_BOOL = 2
# bool bool_value;
# };
# };
-class llama_model_kv_override_value(CtypesUnion):
+class llama_model_kv_override_value(ctypes.Union):
_fields_ = [
- ("int_value", c_int64),
- ("float_value", c_double),
- ("bool_value", c_bool),
+ ("int_value", ctypes.c_int64),
+ ("float_value", ctypes.c_double),
+ ("bool_value", ctypes.c_bool),
]
-class llama_model_kv_override(Structure):
+class llama_model_kv_override(ctypes.Structure):
_fields_ = [
("key", ctypes.c_char * 128),
- ("tag", c_int),
+ ("tag", ctypes.c_int),
("value", llama_model_kv_override_value),
]
@@ -423,32 +480,32 @@ class llama_model_kv_override(Structure):
# bool use_mmap; // use mmap if possible
# bool use_mlock; // force system to keep model in RAM
# };
-class llama_model_params(Structure):
+class llama_model_params(ctypes.Structure):
"""Parameters for llama_model
Attributes:
n_gpu_layers (int): number of layers to store in VRAM
split_mode (int): how to split the model across multiple GPUs
main_gpu (int): the GPU that is used for the entire model. main_gpu interpretation depends on split_mode: LLAMA_SPLIT_NONE: the GPU that is used for the entire model LLAMA_SPLIT_ROW: the GPU that is used for small tensors and intermediate results LLAMA_SPLIT_LAYER: ignored
- tensor_split (ctypes.Array[ctypes.c_float]): proportion of the model (layers or rows) to offload to each GPU, size: llama_max_devices()
+        tensor_split (ctypes.Array[ctypes.c_float]): proportion of the model (layers or rows) to offload to each GPU, size: llama_max_devices()
progress_callback (llama_progress_callback): called with a progress value between 0.0 and 1.0. Pass NULL to disable. If the provided progress_callback returns true, model loading continues. If it returns false, model loading is immediately aborted.
- progress_callback_user_data (ctypes.c_void_p): context pointer passed to the progress callback
+        progress_callback_user_data (ctypes.c_void_p): context pointer passed to the progress callback
kv_overrides (ctypes.Array[llama_model_kv_override]): override key-value pairs of the model meta data
vocab_only (bool): only load the vocabulary, no weights
use_mmap (bool): use mmap if possible
use_mlock (bool): force system to keep model in RAM"""
_fields_ = [
- ("n_gpu_layers", c_int32),
- ("split_mode", c_int),
- ("main_gpu", c_int32),
- ("tensor_split", c_float_p),
+ ("n_gpu_layers", ctypes.c_int32),
+ ("split_mode", ctypes.c_int),
+ ("main_gpu", ctypes.c_int32),
+ ("tensor_split", ctypes.POINTER(ctypes.c_float)),
("progress_callback", llama_progress_callback),
- ("progress_callback_user_data", c_void_p),
- ("kv_overrides", POINTER(llama_model_kv_override)),
- ("vocab_only", c_bool),
- ("use_mmap", c_bool),
- ("use_mlock", c_bool),
+ ("progress_callback_user_data", ctypes.c_void_p),
+ ("kv_overrides", ctypes.POINTER(llama_model_kv_override)),
+ ("vocab_only", ctypes.c_bool),
+ ("use_mmap", ctypes.c_bool),
+ ("use_mlock", ctypes.c_bool),
]
@@ -483,7 +540,7 @@ class llama_model_params(Structure):
# bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
# bool do_pooling; // whether to pool (sum) embedding results by sequence id (ignored if no pooling layer)
# };
-class llama_context_params(Structure):
+class llama_context_params(ctypes.Structure):
"""Parameters for llama_context
Attributes:
@@ -501,7 +558,7 @@ class llama_context_params(Structure):
yarn_beta_slow (float): YaRN high correction dim
yarn_orig_ctx (int): YaRN original context size
cb_eval (ggml_backend_sched_eval_callback): callback for scheduling eval
- cb_eval_user_data (ctypes.c_void_p): user data for cb_eval
+        cb_eval_user_data (ctypes.c_void_p): user data for cb_eval
type_k (int): data type for K cache
type_v (int): data type for V cache
mul_mat_q (bool): if true, use experimental mul_mat_q kernels (DEPRECATED - always true)
@@ -512,28 +569,28 @@ class llama_context_params(Structure):
"""
_fields_ = [
- ("seed", c_uint32),
- ("n_ctx", c_uint32),
- ("n_batch", c_uint32),
- ("n_threads", c_uint32),
- ("n_threads_batch", c_uint32),
- ("rope_scaling_type", c_int32),
- ("rope_freq_base", c_float),
- ("rope_freq_scale", c_float),
- ("yarn_ext_factor", c_float),
- ("yarn_attn_factor", c_float),
- ("yarn_beta_fast", c_float),
- ("yarn_beta_slow", c_float),
- ("yarn_orig_ctx", c_uint32),
+ ("seed", ctypes.c_uint32),
+ ("n_ctx", ctypes.c_uint32),
+ ("n_batch", ctypes.c_uint32),
+ ("n_threads", ctypes.c_uint32),
+ ("n_threads_batch", ctypes.c_uint32),
+ ("rope_scaling_type", ctypes.c_int32),
+ ("rope_freq_base", ctypes.c_float),
+ ("rope_freq_scale", ctypes.c_float),
+ ("yarn_ext_factor", ctypes.c_float),
+ ("yarn_attn_factor", ctypes.c_float),
+ ("yarn_beta_fast", ctypes.c_float),
+ ("yarn_beta_slow", ctypes.c_float),
+ ("yarn_orig_ctx", ctypes.c_uint32),
("cb_eval", ggml_backend_sched_eval_callback),
- ("cb_eval_user_data", c_void_p),
- ("type_k", c_int),
- ("type_v", c_int),
- ("mul_mat_q", c_bool),
- ("logits_all", c_bool),
- ("embedding", c_bool),
- ("offload_kqv", c_bool),
- ("do_pooling", c_bool),
+ ("cb_eval_user_data", ctypes.c_void_p),
+ ("type_k", ctypes.c_int),
+ ("type_v", ctypes.c_int),
+ ("mul_mat_q", ctypes.c_bool),
+ ("logits_all", ctypes.c_bool),
+ ("embedding", ctypes.c_bool),
+ ("offload_kqv", ctypes.c_bool),
+ ("do_pooling", ctypes.c_bool),
]
@@ -543,7 +600,9 @@ class llama_context_params(Structure):
# // if it exists.
# // It might not exist for progress report where '.' is output repeatedly.
# typedef void (*llama_log_callback)(enum llama_log_level level, const char * text, void * user_data);
-llama_log_callback = ctypes.CFUNCTYPE(None, c_int, c_char_p, c_void_p)
+llama_log_callback = ctypes.CFUNCTYPE(
+ None, ctypes.c_int, ctypes.c_char_p, ctypes.c_void_p
+)
"""Signature for logging events
Note that text includes the new line character at the end for most events.
If your logging mechanism cannot handle that, check if the last character is '\n' and strip it
@@ -561,7 +620,7 @@ It might not exist for progress report where '.' is output repeatedly."""
# bool pure; // disable k-quant mixtures and quantize all tensors to the same type
# void * imatrix; // pointer to importance matrix data
# } llama_model_quantize_params;
-class llama_model_quantize_params(Structure):
+class llama_model_quantize_params(ctypes.Structure):
"""Parameters for llama_model_quantize
Attributes:
@@ -571,23 +630,23 @@ class llama_model_quantize_params(Structure):
quantize_output_tensor (bool): quantize output.weight
only_copy (bool): only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored
pure (bool): disable k-quant mixtures and quantize all tensors to the same type
- imatrix (ctypes.c_void_p): pointer to importance matrix data
+        imatrix (ctypes.c_void_p): pointer to importance matrix data
"""
_fields_ = [
- ("nthread", c_int32),
- ("ftype", c_int),
- ("allow_requantize", c_bool),
- ("quantize_output_tensor", c_bool),
- ("only_copy", c_bool),
- ("pure", c_bool),
- ("imatrix", c_void_p),
+ ("nthread", ctypes.c_int32),
+ ("ftype", ctypes.c_int),
+ ("allow_requantize", ctypes.c_bool),
+ ("quantize_output_tensor", ctypes.c_bool),
+ ("only_copy", ctypes.c_bool),
+ ("pure", ctypes.c_bool),
+ ("imatrix", ctypes.c_void_p),
]
# // grammar types
# struct llama_grammar;
-llama_grammar_p = c_void_p
+llama_grammar_p = ctypes.c_void_p
# // grammar element type
# enum llama_gretype {
@@ -627,14 +686,14 @@ LLAMA_GRETYPE_CHAR_ALT = 6
# enum llama_gretype type;
# uint32_t value; // Unicode code point or rule ID
# } llama_grammar_element;
-class llama_grammar_element(Structure):
+class llama_grammar_element(ctypes.Structure):
_fields_ = [
- ("type", c_int),
- ("value", c_uint32),
+ ("type", ctypes.c_int),
+ ("value", ctypes.c_uint32),
]
-llama_grammar_element_p = POINTER(llama_grammar_element)
+llama_grammar_element_p = ctypes.POINTER(llama_grammar_element)
# // performance timing information
# struct llama_timings {
@@ -650,17 +709,17 @@ llama_grammar_element_p = POINTER(llama_grammar_element)
# int32_t n_p_eval;
# int32_t n_eval;
# };
-class llama_timings(Structure):
+class llama_timings(ctypes.Structure):
_fields_ = [
- ("t_start_ms", c_double),
- ("t_end_ms", c_double),
- ("t_load_ms", c_double),
- ("t_sample_ms", c_double),
- ("t_p_eval_ms", c_double),
- ("t_eval_ms", c_double),
- ("n_sample", c_int32),
- ("n_p_eval", c_int32),
- ("n_eval", c_int32),
+ ("t_start_ms", ctypes.c_double),
+ ("t_end_ms", ctypes.c_double),
+ ("t_load_ms", ctypes.c_double),
+ ("t_sample_ms", ctypes.c_double),
+ ("t_p_eval_ms", ctypes.c_double),
+ ("t_eval_ms", ctypes.c_double),
+ ("n_sample", ctypes.c_int32),
+ ("n_p_eval", ctypes.c_int32),
+ ("n_eval", ctypes.c_int32),
]
@@ -669,42 +728,45 @@ class llama_timings(Structure):
# const char * role;
# const char * content;
# } llama_chat_message;
-class llama_chat_message(Structure):
+class llama_chat_message(ctypes.Structure):
_fields_ = [
- ("role", c_char_p),
- ("content", c_char_p),
+ ("role", ctypes.c_char_p),
+ ("content", ctypes.c_char_p),
]
# // Helpers for getting default parameters
# LLAMA_API struct llama_model_params llama_model_default_params(void);
+@ctypes_function(
+ "llama_model_default_params",
+ [],
+ llama_model_params,
+)
def llama_model_default_params() -> llama_model_params:
"""Get default parameters for llama_model"""
- return _lib.llama_model_default_params()
-
-
-_lib.llama_model_default_params.argtypes = []
-_lib.llama_model_default_params.restype = llama_model_params
+ ...
# LLAMA_API struct llama_context_params llama_context_default_params(void);
+@ctypes_function(
+ "llama_context_default_params",
+ [],
+ llama_context_params,
+)
def llama_context_default_params() -> llama_context_params:
"""Get default parameters for llama_context"""
- return _lib.llama_context_default_params()
-
-
-_lib.llama_context_default_params.argtypes = []
-_lib.llama_context_default_params.restype = llama_context_params
+ ...
# LLAMA_API struct llama_model_quantize_params llama_model_quantize_default_params(void);
+@ctypes_function(
+ "llama_model_quantize_default_params",
+ [],
+ llama_model_quantize_params,
+)
def llama_model_quantize_default_params() -> llama_model_quantize_params:
"""Get default parameters for llama_model_quantize"""
- return _lib.llama_model_quantize_default_params()
-
-
-_lib.llama_model_quantize_default_params.argtypes = []
-_lib.llama_model_quantize_default_params.restype = llama_model_quantize_params
+ ...
# // Initialize the llama + ggml backend
@@ -712,15 +774,16 @@ _lib.llama_model_quantize_default_params.restype = llama_model_quantize_params
# // Call once at the start of the program
# LLAMA_API void llama_backend_init(bool numa);
# LLAMA_API void llama_backend_init(void);
+@ctypes_function(
+ "llama_backend_init",
+ [],
+ None,
+)
def llama_backend_init():
"""Initialize the llama + ggml backend
If numa is true, use NUMA optimizations
Call once at the start of the program"""
- return _lib.llama_backend_init()
-
-
-_lib.llama_backend_init.argtypes = []
-_lib.llama_backend_init.restype = None
+ ...
# // numa strategies
@@ -742,207 +805,201 @@ GGML_NUMA_STRATEGY_COUNT = 5
# //optional:
# LLAMA_API void llama_numa_init(enum ggml_numa_strategy numa);
-def llama_numa_init(numa: int):
- return _lib.llama_numa_init(numa)
-
-
-_lib.llama_numa_init.argtypes = [c_int]
-_lib.llama_numa_init.restype = None
+@ctypes_function(
+ "llama_numa_init",
+ [ctypes.c_int],
+ None,
+)
+def llama_numa_init(numa: int, /):
+ ...
# // Call once at the end of the program - currently only used for MPI
# LLAMA_API void llama_backend_free(void);
+@ctypes_function(
+ "llama_backend_free",
+ [],
+ None,
+)
def llama_backend_free():
"""Call once at the end of the program - currently only used for MPI"""
- return _lib.llama_backend_free()
-
-
-_lib.llama_backend_free.argtypes = []
-_lib.llama_backend_free.restype = None
+ ...
# LLAMA_API struct llama_model * llama_load_model_from_file(
# const char * path_model,
# struct llama_model_params params);
+@ctypes_function(
+ "llama_load_model_from_file",
+ [ctypes.c_char_p, llama_model_params],
+ llama_model_p_ctypes,
+)
def llama_load_model_from_file(
- path_model: bytes, params: llama_model_params
-) -> llama_model_p:
- return _lib.llama_load_model_from_file(path_model, params)
-
-
-_lib.llama_load_model_from_file.argtypes = [c_char_p, llama_model_params]
-_lib.llama_load_model_from_file.restype = llama_model_p
+ path_model: bytes, params: llama_model_params, /
+) -> Optional[llama_model_p]:
+ ...
# LLAMA_API void llama_free_model(struct llama_model * model);
-def llama_free_model(model: llama_model_p):
- return _lib.llama_free_model(model)
-
-
-_lib.llama_free_model.argtypes = [llama_model_p]
-_lib.llama_free_model.restype = None
+@ctypes_function(
+ "llama_free_model",
+ [llama_model_p_ctypes],
+ None,
+)
+def llama_free_model(model: llama_model_p, /):
+ ...
# LLAMA_API struct llama_context * llama_new_context_with_model(
# struct llama_model * model,
# struct llama_context_params params);
+@ctypes_function(
+ "llama_new_context_with_model",
+ [llama_model_p_ctypes, llama_context_params],
+ llama_context_p_ctypes,
+)
def llama_new_context_with_model(
- model: llama_model_p, params: llama_context_params
-) -> llama_context_p:
- return _lib.llama_new_context_with_model(model, params)
-
-
-_lib.llama_new_context_with_model.argtypes = [llama_model_p, llama_context_params]
-_lib.llama_new_context_with_model.restype = llama_context_p
+ model: llama_model_p, params: llama_context_params, /
+) -> Optional[llama_context_p]:
+ ...
# // Frees all allocated memory
# LLAMA_API void llama_free(struct llama_context * ctx);
-def llama_free(ctx: llama_context_p):
+@ctypes_function(
+ "llama_free",
+ [llama_context_p_ctypes],
+ None,
+)
+def llama_free(ctx: llama_context_p, /):
"""Frees all allocated memory"""
- return _lib.llama_free(ctx)
-
-
-_lib.llama_free.argtypes = [llama_context_p]
-_lib.llama_free.restype = None
+ ...
# LLAMA_API int64_t llama_time_us(void);
+@ctypes_function(
+ "llama_time_us",
+ [],
+ ctypes.c_int64,
+)
def llama_time_us() -> int:
- return _lib.llama_time_us()
-
-
-_lib.llama_time_us.argtypes = []
-_lib.llama_time_us.restype = ctypes.c_int64
+ ...
# LLAMA_API size_t llama_max_devices(void);
+
+
+@ctypes_function("llama_max_devices", [], ctypes.c_size_t)
def llama_max_devices() -> int:
- return _lib.llama_max_devices()
-
-
-_lib.llama_max_devices.argtypes = []
-_lib.llama_max_devices.restype = ctypes.c_size_t
+ ...
# LLAMA_API bool llama_supports_mmap (void);
+
+
+@ctypes_function("llama_supports_mmap", [], ctypes.c_bool)
def llama_supports_mmap() -> bool:
- return _lib.llama_supports_mmap()
-
-
-_lib.llama_supports_mmap.argtypes = []
-_lib.llama_supports_mmap.restype = c_bool
+ ...
# LLAMA_API bool llama_supports_mlock (void);
+
+
+@ctypes_function("llama_supports_mlock", [], ctypes.c_bool)
def llama_supports_mlock() -> bool:
- return _lib.llama_supports_mlock()
-
-
-_lib.llama_supports_mlock.argtypes = []
-_lib.llama_supports_mlock.restype = c_bool
+ ...
# LLAMA_API bool llama_supports_gpu_offload(void);
+
+
+@ctypes_function("llama_supports_gpu_offload", [], ctypes.c_bool)
def llama_supports_gpu_offload() -> bool:
- return _lib.llama_supports_gpu_offload()
-
-
-_lib.llama_supports_gpu_offload.argtypes = []
-_lib.llama_supports_gpu_offload.restype = c_bool
+ ...
# LLAMA_API DEPRECATED(bool llama_mmap_supported (void), "use llama_supports_mmap() instead");
+
+
+@ctypes_function("llama_mmap_supported", [], ctypes.c_bool)
def llama_mmap_supported() -> bool:
- return _lib.llama_mmap_supported()
-
-
-_lib.llama_mmap_supported.argtypes = []
-_lib.llama_mmap_supported.restype = c_bool
+ ...
# LLAMA_API DEPRECATED(bool llama_mlock_supported(void), "use llama_supports_mlock() instead");
+
+
+@ctypes_function("llama_mlock_supported", [], ctypes.c_bool)
def llama_mlock_supported() -> bool:
- return _lib.llama_mlock_supported()
-
-
-_lib.llama_mlock_supported.argtypes = []
-_lib.llama_mlock_supported.restype = c_bool
+ ...
# LLAMA_API const struct llama_model * llama_get_model(const struct llama_context * ctx);
-def llama_get_model(ctx: llama_context_p) -> llama_model_p:
- return _lib.llama_get_model(ctx)
-_lib.llama_get_model.argtypes = [llama_context_p]
-_lib.llama_get_model.restype = llama_model_p
+@ctypes_function("llama_get_model", [llama_context_p_ctypes], llama_model_p_ctypes)
+def llama_get_model(ctx: llama_context_p, /) -> Optional[llama_model_p]:
+ ...
# LLAMA_API uint32_t llama_n_ctx (const struct llama_context * ctx);
-def llama_n_ctx(ctx: llama_context_p) -> int:
- return _lib.llama_n_ctx(ctx)
-_lib.llama_n_ctx.argtypes = [llama_context_p]
-_lib.llama_n_ctx.restype = c_uint32
+@ctypes_function("llama_n_ctx", [llama_context_p_ctypes], ctypes.c_uint32)
+def llama_n_ctx(ctx: llama_context_p, /) -> int:
+ ...
# LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx);
-def llama_n_batch(ctx: llama_context_p) -> int:
- return _lib.llama_n_batch(ctx)
-_lib.llama_n_batch.argtypes = [llama_context_p]
-_lib.llama_n_batch.restype = c_uint32
+@ctypes_function("llama_n_batch", [llama_context_p_ctypes], ctypes.c_uint32)
+def llama_n_batch(ctx: llama_context_p, /) -> int:
+ ...
# LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_model * model);
-def llama_vocab_type(model: llama_model_p) -> int:
- return _lib.llama_vocab_type(model)
-_lib.llama_vocab_type.argtypes = [llama_model_p]
-_lib.llama_vocab_type.restype = c_int
+@ctypes_function("llama_vocab_type", [llama_model_p_ctypes], ctypes.c_int)
+def llama_vocab_type(model: llama_model_p, /) -> int:
+ ...
# LLAMA_API int32_t llama_n_vocab (const struct llama_model * model);
-def llama_n_vocab(model: llama_model_p) -> int:
- return _lib.llama_n_vocab(model)
-_lib.llama_n_vocab.argtypes = [llama_model_p]
-_lib.llama_n_vocab.restype = c_int32
+@ctypes_function("llama_n_vocab", [llama_model_p_ctypes], ctypes.c_int32)
+def llama_n_vocab(model: llama_model_p, /) -> int:
+ ...
# LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model);
-def llama_n_ctx_train(model: llama_model_p) -> int:
- return _lib.llama_n_ctx_train(model)
-_lib.llama_n_ctx_train.argtypes = [llama_model_p]
-_lib.llama_n_ctx_train.restype = c_int32
+@ctypes_function("llama_n_ctx_train", [llama_model_p_ctypes], ctypes.c_int32)
+def llama_n_ctx_train(model: llama_model_p, /) -> int:
+ ...
# LLAMA_API int32_t llama_n_embd (const struct llama_model * model);
-def llama_n_embd(model: llama_model_p) -> int:
- return _lib.llama_n_embd(model)
-_lib.llama_n_embd.argtypes = [llama_model_p]
-_lib.llama_n_embd.restype = c_int32
+@ctypes_function("llama_n_embd", [llama_model_p_ctypes], ctypes.c_int32)
+def llama_n_embd(model: llama_model_p, /) -> int:
+ ...
# // Get the model's RoPE frequency scaling factor
# LLAMA_API float llama_rope_freq_scale_train(const struct llama_model * model);
-def llama_rope_freq_scale_train(model: llama_model_p) -> float:
+
+
+@ctypes_function("llama_rope_freq_scale_train", [llama_model_p_ctypes], ctypes.c_float)
+def llama_rope_freq_scale_train(model: llama_model_p, /) -> float:
"""Get the model's RoPE frequency scaling factor"""
- return _lib.llama_rope_freq_scale_train(model)
+ ...
-_lib.llama_rope_freq_scale_train.argtypes = [llama_model_p]
-_lib.llama_rope_freq_scale_train.restype = c_float
-
# // Functions to access the model's GGUF metadata scalar values
# // - The functions return the length of the string on success, or -1 on failure
# // - The output string is always null-terminated and cleared on failure
@@ -951,110 +1008,140 @@ _lib.llama_rope_freq_scale_train.restype = c_float
# // Get metadata value as a string by key name
# LLAMA_API int32_t llama_model_meta_val_str(const struct llama_model * model, const char * key, char * buf, size_t buf_size);
+
+
+@ctypes_function(
+ "llama_model_meta_val_str",
+ [
+ llama_model_p_ctypes,
+ ctypes.c_char_p,
+ ctypes.c_char_p,
+ ctypes.c_size_t,
+ ],
+ ctypes.c_int32,
+)
def llama_model_meta_val_str(
- model: llama_model_p, key: Union[c_char_p, bytes], buf: bytes, buf_size: int
+ model: llama_model_p,
+ key: Union[ctypes.c_char_p, bytes],
+ buf: bytes,
+ buf_size: int,
+ /,
) -> int:
"""Get metadata value as a string by key name"""
- return _lib.llama_model_meta_val_str(model, key, buf, buf_size)
-
-
-_lib.llama_model_meta_val_str.argtypes = [llama_model_p, c_char_p, c_char_p, c_size_t]
-_lib.llama_model_meta_val_str.restype = c_int32
+ ...
# // Get the number of metadata key/value pairs
# LLAMA_API int32_t llama_model_meta_count(const struct llama_model * model);
-def llama_model_meta_count(model: llama_model_p) -> int:
+
+
+@ctypes_function("llama_model_meta_count", [llama_model_p_ctypes], ctypes.c_int32)
+def llama_model_meta_count(model: llama_model_p, /) -> int:
"""Get the number of metadata key/value pairs"""
- return _lib.llama_model_meta_count(model)
-
-
-_lib.llama_model_meta_count.argtypes = [llama_model_p]
-_lib.llama_model_meta_count.restype = c_int32
+ ...
# // Get metadata key name by index
# LLAMA_API int32_t llama_model_meta_key_by_index(const struct llama_model * model, int32_t i, char * buf, size_t buf_size);
+
+
+@ctypes_function(
+ "llama_model_meta_key_by_index",
+ [
+ llama_model_p_ctypes,
+ ctypes.c_int32,
+ ctypes.c_char_p,
+ ctypes.c_size_t,
+ ],
+ ctypes.c_int32,
+)
def llama_model_meta_key_by_index(
- model: llama_model_p, i: Union[c_int, int], buf: bytes, buf_size: int
+ model: llama_model_p,
+ i: Union[ctypes.c_int, int],
+ buf: Union[bytes, CtypesArray[ctypes.c_char]],
+ buf_size: int,
+ /,
) -> int:
"""Get metadata key name by index"""
- return _lib.llama_model_meta_key_by_index(model, i, buf, buf_size)
-
-
-_lib.llama_model_meta_key_by_index.argtypes = [
- llama_model_p,
- c_int32,
- c_char_p,
- c_size_t,
-]
-_lib.llama_model_meta_key_by_index.restype = c_int32
+ ...
# // Get metadata value as a string by index
# LLAMA_API int32_t llama_model_meta_val_str_by_index(const struct llama_model * model, int32_t i, char * buf, size_t buf_size);
+
+
+@ctypes_function(
+ "llama_model_meta_val_str_by_index",
+ [
+ llama_model_p_ctypes,
+ ctypes.c_int32,
+ ctypes.c_char_p,
+ ctypes.c_size_t,
+ ],
+ ctypes.c_int32,
+)
def llama_model_meta_val_str_by_index(
- model: llama_model_p, i: Union[c_int, int], buf: bytes, buf_size: int
+ model: llama_model_p,
+ i: Union[ctypes.c_int, int],
+ buf: Union[bytes, CtypesArray[ctypes.c_char]],
+ buf_size: int,
+ /,
) -> int:
"""Get metadata value as a string by index"""
- return _lib.llama_model_meta_val_str_by_index(model, i, buf, buf_size)
-
-
-_lib.llama_model_meta_val_str_by_index.argtypes = [
- llama_model_p,
- c_int32,
- c_char_p,
- c_size_t,
-]
-_lib.llama_model_meta_val_str_by_index.restype = c_int32
+ ...
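A small sketch of driving the metadata bindings above together to enumerate GGUF key/value pairs; the 1024-byte buffers and the `model` handle from the lifecycle sketch are assumptions.

import ctypes
import llama_cpp

key_buf = ctypes.create_string_buffer(1024)
val_buf = ctypes.create_string_buffer(1024)
for i in range(llama_cpp.llama_model_meta_count(model)):
    key_len = llama_cpp.llama_model_meta_key_by_index(model, i, key_buf, len(key_buf))
    val_len = llama_cpp.llama_model_meta_val_str_by_index(model, i, val_buf, len(val_buf))
    if key_len >= 0 and val_len >= 0:  # -1 signals failure per the header comments
        print(key_buf.value.decode(), "=", val_buf.value.decode())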
# // Get a string describing the model type
# LLAMA_API int32_t llama_model_desc(const struct llama_model * model, char * buf, size_t buf_size);
+
+
+@ctypes_function(
+ "llama_model_desc",
+ [llama_model_p_ctypes, ctypes.c_char_p, ctypes.c_size_t],
+ ctypes.c_int32,
+)
def llama_model_desc(
- model: llama_model_p, buf: bytes, buf_size: Union[c_size_t, int]
+ model: llama_model_p,
+ buf: Union[bytes, CtypesArray[ctypes.c_char]],
+ buf_size: Union[ctypes.c_size_t, int],
+ /,
) -> int:
"""Get a string describing the model type"""
- return _lib.llama_model_desc(model, buf, buf_size)
-
-
-_lib.llama_model_desc.argtypes = [llama_model_p, c_char_p, c_size_t]
-_lib.llama_model_desc.restype = c_int32
+ ...
# // Returns the total size of all the tensors in the model in bytes
# LLAMA_API uint64_t llama_model_size(const struct llama_model * model);
-def llama_model_size(model: llama_model_p) -> int:
+
+
+@ctypes_function("llama_model_size", [llama_model_p_ctypes], ctypes.c_uint64)
+def llama_model_size(model: llama_model_p, /) -> int:
"""Returns the total size of all the tensors in the model in bytes"""
- return _lib.llama_model_size(model)
-
-
-_lib.llama_model_size.argtypes = [llama_model_p]
-_lib.llama_model_size.restype = ctypes.c_uint64
+ ...
# // Returns the total number of parameters in the model
# LLAMA_API uint64_t llama_model_n_params(const struct llama_model * model);
-def llama_model_n_params(model: llama_model_p) -> int:
+
+
+@ctypes_function("llama_model_n_params", [llama_model_p_ctypes], ctypes.c_uint64)
+def llama_model_n_params(model: llama_model_p, /) -> int:
"""Returns the total number of parameters in the model"""
- return _lib.llama_model_n_params(model)
-
-
-_lib.llama_model_n_params.argtypes = [llama_model_p]
-_lib.llama_model_n_params.restype = ctypes.c_uint64
+ ...
# // Get a llama model tensor
# LLAMA_API struct ggml_tensor * llama_get_model_tensor(struct llama_model * model, const char * name);
+
+
+@ctypes_function(
+ "llama_get_model_tensor", [llama_model_p_ctypes, ctypes.c_char_p], ctypes.c_void_p
+)
def llama_get_model_tensor(
- model: llama_model_p, name: Union[c_char_p, bytes]
-) -> c_void_p:
+ model: llama_model_p, name: Union[ctypes.c_char_p, bytes], /
+) -> ctypes.c_void_p:
"""Get a llama model tensor"""
- return _lib.llama_get_model_tensor(model, name)
-
-
-_lib.llama_get_model_tensor.argtypes = [llama_model_p, c_char_p]
-_lib.llama_get_model_tensor.restype = c_void_p
+ ...
# // Returns 0 on success
@@ -1062,21 +1149,25 @@ _lib.llama_get_model_tensor.restype = c_void_p
# const char * fname_inp,
# const char * fname_out,
# const llama_model_quantize_params * params);
+
+
+@ctypes_function(
+ "llama_model_quantize",
+ [
+ ctypes.c_char_p,
+ ctypes.c_char_p,
+ ctypes.POINTER(llama_model_quantize_params),
+ ],
+ ctypes.c_uint32,
+)
def llama_model_quantize(
fname_inp: bytes,
fname_out: bytes,
- params, # type: POINTER(llama_model_quantize_params) # type: ignore
+ params: CtypesPointerOrRef[llama_model_quantize_params],
+ /,
) -> int:
"""Returns 0 on success"""
- return _lib.llama_model_quantize(fname_inp, fname_out, params)
-
-
-_lib.llama_model_quantize.argtypes = [
- c_char_p,
- c_char_p,
- POINTER(llama_model_quantize_params),
-]
-_lib.llama_model_quantize.restype = c_uint32
+ ...
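A hedged sketch of calling the quantize binding above; the filenames are placeholders, `llama_model_quantize_default_params()` is assumed from earlier in this module, and the ftype constant is an assumed value from the enum defined near the top of the file.

import ctypes
import llama_cpp

qparams = llama_cpp.llama_model_quantize_default_params()
qparams.ftype = llama_cpp.LLAMA_FTYPE_MOSTLY_Q4_K_M  # assumed target quantization type
ret = llama_cpp.llama_model_quantize(
    b"model-f16.gguf", b"model-q4_k_m.gguf", ctypes.byref(qparams)
)
if ret != 0:
    raise RuntimeError("llama_model_quantize failed")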
# // Apply a LoRA adapter to a loaded model
@@ -1092,12 +1183,26 @@ _lib.llama_model_quantize.restype = c_uint32
# const char * path_base_model,
# int32_t n_threads),
# "use llama_model_apply_lora_from_file instead");
+
+
+@ctypes_function(
+ "llama_apply_lora_from_file",
+ [
+ llama_context_p_ctypes,
+ ctypes.c_char_p,
+ ctypes.c_float,
+ ctypes.c_char_p,
+ ctypes.c_int32,
+ ],
+ ctypes.c_int32,
+)
def llama_apply_lora_from_file(
ctx: llama_context_p,
- path_lora: Union[c_char_p, bytes],
- scale: Union[c_float, float],
- path_base_model: Union[c_char_p, bytes],
- n_threads: Union[c_int, int],
+ path_lora: Union[ctypes.c_char_p, bytes],
+ scale: Union[ctypes.c_float, float],
+ path_base_model: Union[ctypes.c_char_p, bytes],
+ n_threads: Union[ctypes.c_int32, int],
+ /,
) -> int:
"""Apply a LoRA adapter to a loaded model
path_base_model is the path to a higher quality model to use as a base for
@@ -1105,19 +1210,7 @@ def llama_apply_lora_from_file(
The model needs to be reloaded before applying a new adapter, otherwise the adapter
will be applied on top of the previous one
Returns 0 on success"""
- return _lib.llama_apply_lora_from_file(
- ctx, path_lora, scale, path_base_model, n_threads
- )
-
-
-_lib.llama_apply_lora_from_file.argtypes = [
- llama_context_p,
- c_char_p,
- c_float,
- c_char_p,
- c_int32,
-]
-_lib.llama_apply_lora_from_file.restype = c_int32
+ ...
# LLAMA_API int32_t llama_model_apply_lora_from_file(
@@ -1126,27 +1219,30 @@ _lib.llama_apply_lora_from_file.restype = c_int32
# float scale,
# const char * path_base_model,
# int32_t n_threads);
+
+
+@ctypes_function(
+ "llama_model_apply_lora_from_file",
+ [
+ llama_model_p_ctypes,
+ ctypes.c_char_p,
+ ctypes.c_float,
+ ctypes.c_char_p,
+ ctypes.c_int32,
+ ],
+ ctypes.c_int32,
+)
def llama_model_apply_lora_from_file(
model: llama_model_p,
- path_lora: Union[c_char_p, bytes],
- scale: Union[c_float, float],
- path_base_model: Union[c_char_p, bytes],
- n_threads: Union[c_int, int],
+ path_lora: Union[ctypes.c_char_p, bytes],
+ scale: Union[ctypes.c_float, float],
+ path_base_model: Union[ctypes.c_char_p, bytes],
+ n_threads: Union[ctypes.c_int32, int],
+ /,
) -> int:
- return _lib.llama_model_apply_lora_from_file(
- model, path_lora, scale, path_base_model, n_threads
- )
+ ...
-_lib.llama_model_apply_lora_from_file.argtypes = [
- llama_model_p,
- c_char_p,
- c_float,
- c_char_p,
- c_int32,
-]
-_lib.llama_model_apply_lora_from_file.restype = c_int32
-
# //
# // KV cache
# //
@@ -1158,7 +1254,7 @@ _lib.llama_model_apply_lora_from_file.restype = c_int32
# // May be negative if the cell is not populated.
# llama_pos pos;
# };
-class llama_kv_cache_view_cell(Structure):
+class llama_kv_cache_view_cell(ctypes.Structure):
_fields_ = [("pos", llama_pos)]
@@ -1194,92 +1290,96 @@ class llama_kv_cache_view_cell(Structure):
# // The sequences for each cell. There will be n_max_seq items per cell.
# llama_seq_id * cells_sequences;
# };
-class llama_kv_cache_view(Structure):
+class llama_kv_cache_view(ctypes.Structure):
_fields_ = [
- ("n_cells", c_int32),
- ("n_max_seq", c_int32),
- ("token_count", c_int32),
- ("used_cells", c_int32),
- ("max_contiguous", c_int32),
- ("max_contiguous_idx", c_int32),
- ("cells", POINTER(llama_kv_cache_view_cell)),
- ("cells_sequences", POINTER(llama_seq_id)),
+ ("n_cells", ctypes.c_int32),
+ ("n_max_seq", ctypes.c_int32),
+ ("token_count", ctypes.c_int32),
+ ("used_cells", ctypes.c_int32),
+ ("max_contiguous", ctypes.c_int32),
+ ("max_contiguous_idx", ctypes.c_int32),
+ ("cells", ctypes.POINTER(llama_kv_cache_view_cell)),
+ ("cells_sequences", ctypes.POINTER(llama_seq_id)),
]
-llama_kv_cache_view_p = POINTER(llama_kv_cache_view)
+llama_kv_cache_view_p = ctypes.POINTER(llama_kv_cache_view)
# // Create an empty KV cache view. (use only for debugging purposes)
# LLAMA_API struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_context * ctx, int32_t n_max_seq);
+
+
+@ctypes_function(
+ "llama_kv_cache_view_init",
+ [llama_context_p_ctypes, ctypes.c_int32],
+ llama_kv_cache_view,
+)
def llama_kv_cache_view_init(
- ctx: llama_context_p, n_max_seq: Union[c_int32, int]
+ ctx: llama_context_p, n_max_seq: Union[ctypes.c_int32, int], /
) -> llama_kv_cache_view:
"""Create an empty KV cache view. (use only for debugging purposes)"""
- return _lib.llama_kv_cache_view_init(ctx, n_max_seq)
-
-
-_lib.llama_kv_cache_view_init.argtypes = [llama_context_p, c_int32]
-_lib.llama_kv_cache_view_init.restype = llama_kv_cache_view
+ ...
# // Free a KV cache view. (use only for debugging purposes)
# LLAMA_API void llama_kv_cache_view_free(struct llama_kv_cache_view * view);
-def llama_kv_cache_view_free(view: "ctypes.pointer[llama_kv_cache_view]"): # type: ignore
+
+
+@ctypes_function("llama_kv_cache_view_free", [llama_kv_cache_view_p], None)
+def llama_kv_cache_view_free(view: "ctypes.pointer[llama_kv_cache_view]", /): # type: ignore
"""Free a KV cache view. (use only for debugging purposes)"""
- return _lib.llama_kv_cache_view_free(view)
-
-
-_lib.llama_kv_cache_view_free.argtypes = [llama_kv_cache_view_p]
-_lib.llama_kv_cache_view_free.restype = None
+ ...
# // Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes)
# LLAMA_API void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view);
-def llama_kv_cache_view_update(ctx: llama_context_p, view: "ctypes.pointer[llama_kv_cache_view]"): # type: ignore
+
+
+@ctypes_function(
+ "llama_kv_cache_view_update", [llama_context_p_ctypes, llama_kv_cache_view_p], None
+)
+def llama_kv_cache_view_update(ctx: llama_context_p, view: CtypesPointerOrRef[llama_kv_cache_view], /): # type: ignore
"""Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes)"""
- return _lib.llama_kv_cache_view_update(ctx, view)
-
-
-_lib.llama_kv_cache_view_update.argtypes = [llama_context_p, llama_kv_cache_view_p]
-_lib.llama_kv_cache_view_update.restype = None
+ ...
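Sketch of the KV cache debugging flow the three view bindings above make up (init, update, read, free); the `ctx` handle from the lifecycle sketch is assumed.

import ctypes
import llama_cpp

view = llama_cpp.llama_kv_cache_view_init(ctx, 1)  # n_max_seq = 1
llama_cpp.llama_kv_cache_view_update(ctx, ctypes.byref(view))
print("cells:", view.n_cells, "used:", view.used_cells, "tokens:", view.token_count)
llama_cpp.llama_kv_cache_view_free(ctypes.byref(view))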
# // Returns the number of tokens in the KV cache (slow, use only for debug)
# // If a KV cell has multiple sequences assigned to it, it will be counted multiple times
# LLAMA_API int32_t llama_get_kv_cache_token_count(const struct llama_context * ctx);
-def llama_get_kv_cache_token_count(ctx: llama_context_p) -> int:
+
+
+@ctypes_function(
+ "llama_get_kv_cache_token_count", [llama_context_p_ctypes], ctypes.c_int32
+)
+def llama_get_kv_cache_token_count(ctx: llama_context_p, /) -> int:
"""Returns the number of tokens in the KV cache (slow, use only for debug)
If a KV cell has multiple sequences assigned to it, it will be counted multiple times
"""
- return _lib.llama_get_kv_cache_token_count(ctx)
-
-
-_lib.llama_get_kv_cache_token_count.argtypes = [llama_context_p]
-_lib.llama_get_kv_cache_token_count.restype = c_int32
+ ...
# // Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
# LLAMA_API int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx);
-def llama_get_kv_cache_used_cells(ctx: llama_context_p) -> int:
+
+
+@ctypes_function(
+ "llama_get_kv_cache_used_cells", [llama_context_p_ctypes], ctypes.c_int32
+)
+def llama_get_kv_cache_used_cells(ctx: llama_context_p, /) -> int:
"""Returns the number of used KV cells (i.e. have at least one sequence assigned to them)"""
- return _lib.llama_get_kv_cache_used_cells(ctx)
-
-
-_lib.llama_get_kv_cache_used_cells.argtypes = [llama_context_p]
-_lib.llama_get_kv_cache_used_cells.restype = c_int32
+ ...
# // Clear the KV cache
# LLAMA_API void llama_kv_cache_clear(
# struct llama_context * ctx);
-def llama_kv_cache_clear(ctx: llama_context_p):
+
+
+@ctypes_function("llama_kv_cache_clear", [llama_context_p_ctypes], None)
+def llama_kv_cache_clear(ctx: llama_context_p, /):
"""Clear the KV cache"""
- return _lib.llama_kv_cache_clear(ctx)
-
-
-_lib.llama_kv_cache_clear.argtypes = [llama_context_p]
-_lib.llama_kv_cache_clear.restype = None
+ ...
# // Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
@@ -1291,26 +1391,30 @@ _lib.llama_kv_cache_clear.restype = None
# llama_seq_id seq_id,
# llama_pos p0,
# llama_pos p1);
+
+
+@ctypes_function(
+ "llama_kv_cache_seq_rm",
+ [
+ llama_context_p_ctypes,
+ llama_seq_id,
+ llama_pos,
+ llama_pos,
+ ],
+ None,
+)
def llama_kv_cache_seq_rm(
ctx: llama_context_p,
seq_id: Union[llama_seq_id, int],
p0: Union[llama_pos, int],
p1: Union[llama_pos, int],
+ /,
):
"""Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
seq_id < 0 : match any sequence
p0 < 0 : [0, p1]
p1 < 0 : [p0, inf)"""
- return _lib.llama_kv_cache_seq_rm(ctx, seq_id, p0, p1)
-
-
-_lib.llama_kv_cache_seq_rm.argtypes = [
- llama_context_p,
- llama_seq_id,
- llama_pos,
- llama_pos,
-]
-_lib.llama_kv_cache_seq_rm.restype = None
+ ...
# // Copy all tokens that belong to the specified sequence to another sequence
@@ -1323,44 +1427,46 @@ _lib.llama_kv_cache_seq_rm.restype = None
# llama_seq_id seq_id_dst,
# llama_pos p0,
# llama_pos p1);
+
+
+@ctypes_function(
+ "llama_kv_cache_seq_cp",
+ [
+ llama_context_p_ctypes,
+ llama_seq_id,
+ llama_seq_id,
+ llama_pos,
+ llama_pos,
+ ],
+ None,
+)
def llama_kv_cache_seq_cp(
ctx: llama_context_p,
seq_id_src: Union[llama_seq_id, int],
seq_id_dst: Union[llama_seq_id, int],
p0: Union[llama_pos, int],
p1: Union[llama_pos, int],
+ /,
):
"""Copy all tokens that belong to the specified sequence to another sequence
Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence
p0 < 0 : [0, p1]
p1 < 0 : [p0, inf)"""
- return _lib.llama_kv_cache_seq_cp(ctx, seq_id_src, seq_id_dst, p0, p1)
-
-
-_lib.llama_kv_cache_seq_cp.argtypes = [
- llama_context_p,
- llama_seq_id,
- llama_seq_id,
- llama_pos,
- llama_pos,
-]
-_lib.llama_kv_cache_seq_cp.restype = None
+ ...
# // Removes all tokens that do not belong to the specified sequence
# LLAMA_API void llama_kv_cache_seq_keep(
# struct llama_context * ctx,
# llama_seq_id seq_id);
-def llama_kv_cache_seq_keep(
- ctx: llama_context_p,
- seq_id: Union[llama_seq_id, int],
-):
+
+
+@ctypes_function(
+ "llama_kv_cache_seq_keep", [llama_context_p_ctypes, llama_seq_id], None
+)
+def llama_kv_cache_seq_keep(ctx: llama_context_p, seq_id: Union[llama_seq_id, int], /):
"""Removes all tokens that do not belong to the specified sequence"""
- return _lib.llama_kv_cache_seq_keep(ctx, seq_id)
-
-
-_lib.llama_kv_cache_seq_keep.argtypes = [llama_context_p, llama_seq_id]
-_lib.llama_kv_cache_seq_keep.restype = None
+ ...
# // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
@@ -1373,28 +1479,32 @@ _lib.llama_kv_cache_seq_keep.restype = None
# llama_pos p0,
# llama_pos p1,
# llama_pos delta);
+
+
+@ctypes_function(
+ "llama_kv_cache_seq_shift",
+ [
+ llama_context_p_ctypes,
+ llama_seq_id,
+ llama_pos,
+ llama_pos,
+ llama_pos,
+ ],
+ None,
+)
def llama_kv_cache_seq_shift(
ctx: llama_context_p,
seq_id: Union[llama_seq_id, int],
p0: Union[llama_pos, int],
p1: Union[llama_pos, int],
delta: Union[llama_pos, int],
+ /,
):
"""Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
If the KV cache is RoPEd, the KV data is updated accordingly
p0 < 0 : [0, p1]
p1 < 0 : [p0, inf)"""
- return _lib.llama_kv_cache_seq_shift(ctx, seq_id, p0, p1, delta)
-
-
-_lib.llama_kv_cache_seq_shift.argtypes = [
- llama_context_p,
- llama_seq_id,
- llama_pos,
- llama_pos,
- llama_pos,
-]
-_lib.llama_kv_cache_seq_shift.restype = None
+ ...
# // Integer division of the positions by factor of `d > 1`
@@ -1407,29 +1517,34 @@ _lib.llama_kv_cache_seq_shift.restype = None
# llama_pos p0,
# llama_pos p1,
# int d);
+
+
+@ctypes_function(
+ "llama_kv_cache_seq_div",
+ [
+ llama_context_p_ctypes,
+ llama_seq_id,
+ llama_pos,
+ llama_pos,
+ ctypes.c_int,
+ ],
+ None,
+)
def llama_kv_cache_seq_div(
ctx: llama_context_p,
seq_id: Union[llama_seq_id, int],
p0: Union[llama_pos, int],
p1: Union[llama_pos, int],
- d: Union[c_int, int],
+ d: Union[ctypes.c_int, int],
+ /,
):
"""Integer division of the positions by factor of `d > 1`
If the KV cache is RoPEd, the KV data is updated accordingly
p0 < 0 : [0, p1]
p1 < 0 : [p0, inf)"""
- return _lib.llama_kv_cache_seq_div(ctx, seq_id, p0, p1, d)
+ ...
-_lib.llama_kv_cache_seq_div.argtypes = [
- llama_context_p,
- llama_seq_id,
- llama_pos,
- llama_pos,
- c_int,
-]
-_lib.llama_kv_cache_seq_div.restype = None
-
# //
# // State / sessions
# //
@@ -1438,14 +1553,13 @@ _lib.llama_kv_cache_seq_div.restype = None
# Returns the maximum size in bytes of the state (rng, logits, embedding
# and kv_cache) - will often be smaller after compacting tokens
# LLAMA_API size_t llama_get_state_size(const struct llama_context * ctx);
-def llama_get_state_size(ctx: llama_context_p) -> int:
+
+
+@ctypes_function("llama_get_state_size", [llama_context_p_ctypes], ctypes.c_size_t)
+def llama_get_state_size(ctx: llama_context_p, /) -> int:
"""Returns the maximum size in bytes of the state (rng, logits, embedding
and kv_cache) - will often be smaller after compacting tokens"""
- return _lib.llama_get_state_size(ctx)
-
-
-_lib.llama_get_state_size.argtypes = [llama_context_p]
-_lib.llama_get_state_size.restype = c_size_t
+ ...
# Copies the state to the specified destination address.
@@ -1454,17 +1568,23 @@ _lib.llama_get_state_size.restype = c_size_t
# LLAMA_API size_t llama_copy_state_data(
# struct llama_context * ctx,
# uint8_t * dst);
+
+
+@ctypes_function(
+ "llama_copy_state_data",
+ [
+ llama_context_p_ctypes,
+ ctypes.POINTER(ctypes.c_uint8),
+ ],
+ ctypes.c_size_t,
+)
def llama_copy_state_data(
- ctx: llama_context_p, dst # type: Array[c_uint8]
+ ctx: llama_context_p, dst: CtypesArray[ctypes.c_uint8], /
) -> int:
"""Copies the state to the specified destination address.
Destination needs to have allocated enough memory.
Returns the number of bytes copied"""
- return _lib.llama_copy_state_data(ctx, dst)
-
-
-_lib.llama_copy_state_data.argtypes = [llama_context_p, c_uint8_p]
-_lib.llama_copy_state_data.restype = c_size_t
+ ...
# Set the state reading from the specified address
@@ -1472,15 +1592,18 @@ _lib.llama_copy_state_data.restype = c_size_t
# LLAMA_API size_t llama_set_state_data(
# struct llama_context * ctx,
# uint8_t * src);
+
+
+@ctypes_function(
+ "llama_set_state_data",
+ [llama_context_p_ctypes, ctypes.POINTER(ctypes.c_uint8)],
+ ctypes.c_size_t,
+)
def llama_set_state_data(
- ctx: llama_context_p, src # type: Array[c_uint8]
+ ctx: llama_context_p, src: CtypesArray[ctypes.c_uint8], /
) -> int:
"""Set the state reading from the specified address"""
- return _lib.llama_set_state_data(ctx, src)
-
-
-_lib.llama_set_state_data.argtypes = [llama_context_p, c_uint8_p]
-_lib.llama_set_state_data.restype = c_size_t
+ ...
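Sketch of an in-memory state snapshot using the three state bindings above; `ctx` is assumed from the lifecycle sketch.

import ctypes
import llama_cpp

state_size = llama_cpp.llama_get_state_size(ctx)  # upper bound on the snapshot size
state_buf = (ctypes.c_uint8 * state_size)()
n_written = llama_cpp.llama_copy_state_data(ctx, state_buf)  # bytes actually copied
# ... decode some tokens, then roll the context back to the snapshot ...
llama_cpp.llama_set_state_data(ctx, state_buf)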
# Save/load session file
@@ -1490,26 +1613,28 @@ _lib.llama_set_state_data.restype = c_size_t
# llama_token * tokens_out,
# size_t n_token_capacity,
# size_t * n_token_count_out);
+
+
+@ctypes_function(
+ "llama_load_session_file",
+ [
+ llama_context_p_ctypes,
+ ctypes.c_char_p,
+ llama_token_p,
+ ctypes.c_size_t,
+ ctypes.POINTER(ctypes.c_size_t),
+ ],
+ ctypes.c_size_t,
+)
def llama_load_session_file(
ctx: llama_context_p,
path_session: bytes,
- tokens_out, # type: Array[llama_token]
- n_token_capacity: Union[c_size_t, int],
- n_token_count_out, # type: _Pointer[c_size_t]
+ tokens_out: CtypesArray[llama_token],
+ n_token_capacity: Union[ctypes.c_size_t, int],
+ n_token_count_out: CtypesPointerOrRef[ctypes.c_size_t],
+ /,
) -> int:
- return _lib.llama_load_session_file(
- ctx, path_session, tokens_out, n_token_capacity, n_token_count_out
- )
-
-
-_lib.llama_load_session_file.argtypes = [
- llama_context_p,
- c_char_p,
- llama_token_p,
- c_size_t,
- c_size_t_p,
-]
-_lib.llama_load_session_file.restype = c_size_t
+ ...
# LLAMA_API bool llama_save_session_file(
@@ -1517,23 +1642,28 @@ _lib.llama_load_session_file.restype = c_size_t
# const char * path_session,
# const llama_token * tokens,
# size_t n_token_count);
+
+
+@ctypes_function(
+ "llama_save_session_file",
+ [
+ llama_context_p_ctypes,
+ ctypes.c_char_p,
+ llama_token_p,
+ ctypes.c_size_t,
+ ],
+ ctypes.c_size_t,
+)
def llama_save_session_file(
ctx: llama_context_p,
path_session: bytes,
- tokens, # type: Array[llama_token]
- n_token_count: Union[c_size_t, int],
+ tokens: CtypesArray[llama_token],
+ n_token_count: Union[ctypes.c_size_t, int],
+ /,
) -> int:
- return _lib.llama_save_session_file(ctx, path_session, tokens, n_token_count)
+ ...
-_lib.llama_save_session_file.argtypes = [
- llama_context_p,
- c_char_p,
- llama_token_p,
- c_size_t,
-]
-_lib.llama_save_session_file.restype = c_size_t
-
# //
# // Decoding
# //
@@ -1550,22 +1680,31 @@ _lib.llama_save_session_file.restype = c_size_t
# int32_t n_tokens,
# int32_t n_past),
# "use llama_decode() instead");
+
+
+@ctypes_function(
+ "llama_eval",
+ [
+ llama_context_p_ctypes,
+ llama_token_p,
+ ctypes.c_int32,
+ ctypes.c_int32,
+ ],
+ ctypes.c_int,
+)
def llama_eval(
ctx: llama_context_p,
- tokens, # type: Array[llama_token]
- n_tokens: Union[c_int, int],
- n_past: Union[c_int, int],
+ tokens: CtypesArray[llama_token],
+ n_tokens: Union[ctypes.c_int, int],
+ n_past: Union[ctypes.c_int, int],
+ /,
) -> int:
"""Run the llama inference to obtain the logits and probabilities for the next token(s).
tokens + n_tokens is the provided batch of new tokens to process
n_past is the number of tokens to use from previous eval calls
Returns 0 on success
DEPRECATED: use llama_decode() instead"""
- return _lib.llama_eval(ctx, tokens, n_tokens, n_past)
-
-
-_lib.llama_eval.argtypes = [llama_context_p, llama_token_p, c_int32, c_int32]
-_lib.llama_eval.restype = c_int
+ ...
# // Same as llama_eval, but use float matrix input directly.
@@ -1576,19 +1715,28 @@ _lib.llama_eval.restype = c_int
# int32_t n_tokens,
# int32_t n_past),
# "use llama_decode() instead");
+
+
+@ctypes_function(
+ "llama_eval_embd",
+ [
+ llama_context_p_ctypes,
+ ctypes.POINTER(ctypes.c_float),
+ ctypes.c_int32,
+ ctypes.c_int32,
+ ],
+ ctypes.c_int,
+)
def llama_eval_embd(
ctx: llama_context_p,
- embd, # type: Array[c_float]
- n_tokens: Union[c_int, int],
- n_past: Union[c_int, int],
+ embd: CtypesArray[ctypes.c_float],
+ n_tokens: Union[ctypes.c_int, int],
+ n_past: Union[ctypes.c_int, int],
+ /,
) -> int:
"""Same as llama_eval, but use float matrix input directly.
DEPRECATED: use llama_decode() instead"""
- return _lib.llama_eval_embd(ctx, embd, n_tokens, n_past)
-
-
-_lib.llama_eval_embd.argtypes = [llama_context_p, c_float_p, c_int32, c_int32]
-_lib.llama_eval_embd.restype = c_int
+ ...
# // Return batch for single sequence of tokens starting at pos_0
@@ -1600,26 +1748,30 @@ _lib.llama_eval_embd.restype = c_int
# int32_t n_tokens,
# llama_pos pos_0,
# llama_seq_id seq_id);
+
+
+@ctypes_function(
+ "llama_batch_get_one",
+ [
+ llama_token_p,
+ ctypes.c_int,
+ llama_pos,
+ llama_seq_id,
+ ],
+ llama_batch,
+)
def llama_batch_get_one(
- tokens, # type: Array[llama_token]
- n_tokens: Union[c_int, int],
+ tokens: CtypesArray[llama_token],
+ n_tokens: Union[ctypes.c_int, int],
pos_0: Union[llama_pos, int],
seq_id: llama_seq_id,
+ /,
) -> llama_batch:
"""Return batch for single sequence of tokens starting at pos_0
NOTE: this is a helper function to facilitate transition to the new batch API - avoid using it
"""
- return _lib.llama_batch_get_one(tokens, n_tokens, pos_0, seq_id)
-
-
-_lib.llama_batch_get_one.argtypes = [
- llama_token_p,
- c_int,
- llama_pos,
- llama_seq_id,
-]
-_lib.llama_batch_get_one.restype = llama_batch
+ ...
# // Allocates a batch of tokens on the heap that can hold a maximum of n_tokens
@@ -1633,10 +1785,16 @@ _lib.llama_batch_get_one.restype = llama_batch
# int32_t n_tokens,
# int32_t embd,
# int32_t n_seq_max);
+
+
+@ctypes_function(
+ "llama_batch_init", [ctypes.c_int32, ctypes.c_int32, ctypes.c_int32], llama_batch
+)
def llama_batch_init(
- n_tokens: Union[c_int32, int],
- embd: Union[c_int32, int],
- n_seq_max: Union[c_int32, int],
+ n_tokens: Union[ctypes.c_int32, int],
+ embd: Union[ctypes.c_int32, int],
+ n_seq_max: Union[ctypes.c_int32, int],
+ /,
) -> llama_batch:
"""Allocates a batch of tokens on the heap that can hold a maximum of n_tokens
Each token can be assigned up to n_seq_max sequence ids
@@ -1645,22 +1803,17 @@ def llama_batch_init(
Otherwise, llama_batch.token will be allocated to store n_tokens llama_token
The rest of the llama_batch members are allocated with size n_tokens
All members are left uninitialized"""
- return _lib.llama_batch_init(n_tokens, embd, n_seq_max)
-
-
-_lib.llama_batch_init.argtypes = [c_int32, c_int32, c_int32]
-_lib.llama_batch_init.restype = llama_batch
+ ...
# // Frees a batch of tokens allocated with llama_batch_init()
# LLAMA_API void llama_batch_free(struct llama_batch batch);
-def llama_batch_free(batch: llama_batch):
+
+
+@ctypes_function("llama_batch_free", [llama_batch], None)
+def llama_batch_free(batch: llama_batch, /):
"""Frees a batch of tokens allocated with llama_batch_init()"""
- return _lib.llama_batch_free(batch)
-
-
-_lib.llama_batch_free.argtypes = [llama_batch]
-_lib.llama_batch_free.restype = None
+ ...
# // Positive return values does not mean a fatal error, but rather a warning.
@@ -1670,36 +1823,43 @@ _lib.llama_batch_free.restype = None
# LLAMA_API int32_t llama_decode(
# struct llama_context * ctx,
# struct llama_batch batch);
-def llama_decode(ctx: llama_context_p, batch: llama_batch) -> int:
+
+
+@ctypes_function("llama_decode", [llama_context_p_ctypes, llama_batch], ctypes.c_int32)
+def llama_decode(ctx: llama_context_p, batch: llama_batch, /) -> int:
"""Positive return values does not mean a fatal error, but rather a warning.
0 - success
1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
< 0 - error"""
- return _lib.llama_decode(ctx, batch)
-
-
-_lib.llama_decode.argtypes = [llama_context_p, llama_batch]
-_lib.llama_decode.restype = c_int32
+ ...
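Sketch of feeding one batch through `llama_decode` using the helper above; `ctx` is assumed from the lifecycle sketch and `tokens` / `n_tokens` from the tokenize sketch further below.

import llama_cpp

batch = llama_cpp.llama_batch_get_one(tokens, n_tokens, 0, 0)  # pos_0=0, seq_id=0
ret = llama_cpp.llama_decode(ctx, batch)
if ret == 1:
    print("no KV slot for the batch - shrink the batch or grow the context")
elif ret < 0:
    raise RuntimeError("llama_decode failed")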
# // Set the number of threads used for decoding
# // n_threads is the number of threads used for generation (single token)
# // n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens)
# LLAMA_API void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_t n_threads_batch);
+
+
+@ctypes_function(
+ "llama_set_n_threads",
+ [
+ llama_context_p_ctypes,
+ ctypes.c_uint32,
+ ctypes.c_uint32,
+ ],
+ None,
+)
def llama_set_n_threads(
ctx: llama_context_p,
- n_threads: Union[c_uint32, int],
- n_threads_batch: Union[c_uint32, int],
+ n_threads: Union[ctypes.c_uint32, int],
+ n_threads_batch: Union[ctypes.c_uint32, int],
+ /,
):
"""Set the number of threads used for decoding
n_threads is the number of threads used for generation (single token)
n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens)
"""
- return _lib.llama_set_n_threads(ctx, n_threads, n_threads_batch)
-
-
-_lib.llama_set_n_threads.argtypes = [llama_context_p, c_uint32, c_uint32]
-_lib.llama_set_n_threads.restype = None
+ ...
# // Token logits obtained from the last call to llama_eval()
@@ -1708,64 +1868,68 @@ _lib.llama_set_n_threads.restype = None
# // Rows: n_tokens provided with llama_batch
# // Cols: n_vocab
# LLAMA_API float * llama_get_logits(struct llama_context * ctx);
-def llama_get_logits(
- ctx: llama_context_p,
-): # type: (...) -> Array[float] # type: ignore
+
+
+@ctypes_function(
+ "llama_get_logits", [llama_context_p_ctypes], ctypes.POINTER(ctypes.c_float)
+)
+def llama_get_logits(ctx: llama_context_p, /) -> CtypesArray[ctypes.c_float]:
"""Token logits obtained from the last call to llama_eval()
The logits for the last token are stored in the last row
Logits for which llama_batch.logits[i] == 0 are undefined
Rows: n_tokens provided with llama_batch
Cols: n_vocab"""
- return _lib.llama_get_logits(ctx)
-
-
-_lib.llama_get_logits.argtypes = [llama_context_p]
-_lib.llama_get_logits.restype = c_float_p
+ ...
# // Logits for the ith token. Equivalent to:
# // llama_get_logits(ctx) + i*n_vocab
# LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i);
+
+
+@ctypes_function(
+ "llama_get_logits_ith",
+ [llama_context_p_ctypes, ctypes.c_int32],
+ ctypes.POINTER(ctypes.c_float),
+)
def llama_get_logits_ith(
- ctx: llama_context_p, i: Union[c_int32, int]
-): # type: (...) -> Array[float] # type: ignore
+ ctx: llama_context_p, i: Union[ctypes.c_int32, int], /
+) -> CtypesArray[ctypes.c_float]:
"""Logits for the ith token. Equivalent to:
llama_get_logits(ctx) + i*n_vocab"""
- return _lib.llama_get_logits_ith(ctx, i)
-
-
-_lib.llama_get_logits_ith.argtypes = [llama_context_p, c_int32]
-_lib.llama_get_logits_ith.restype = c_float_p
+ ...
# Get the embeddings for the input
# shape: [n_embd] (1-dimensional)
# LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
-def llama_get_embeddings(
- ctx: llama_context_p,
-): # type: (...) -> Array[float] # type: ignore
+
+
+@ctypes_function(
+ "llama_get_embeddings", [llama_context_p_ctypes], ctypes.POINTER(ctypes.c_float)
+)
+def llama_get_embeddings(ctx: llama_context_p, /) -> CtypesArray[ctypes.c_float]:
"""Get the embeddings for the input
shape: [n_embd] (1-dimensional)"""
- return _lib.llama_get_embeddings(ctx)
-
-
-_lib.llama_get_embeddings.argtypes = [llama_context_p]
-_lib.llama_get_embeddings.restype = c_float_p
+ ...
# // Get the embeddings for the ith sequence
# // llama_get_embeddings(ctx) + i*n_embd
# LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i);
+
+
+@ctypes_function(
+ "llama_get_embeddings_ith",
+ [llama_context_p_ctypes, ctypes.c_int32],
+ ctypes.POINTER(ctypes.c_float),
+)
def llama_get_embeddings_ith(
- ctx: llama_context_p, i: Union[c_int32, int]
-): # type: (...) -> Array[float] # type: ignore
+ ctx: llama_context_p, i: Union[ctypes.c_int32, int], /
+) -> CtypesArray[ctypes.c_float]:
"""Get the embeddings for the ith sequence
llama_get_embeddings(ctx) + i*n_embd"""
- return _lib.llama_get_embeddings_ith(ctx, i)
-
-
-_lib.llama_get_embeddings_ith.argtypes = [llama_context_p, c_int32]
-_lib.llama_get_embeddings_ith.restype = c_float_p
+ ...
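Sketch of reading the logits for the last token of the most recent batch; assumes `ctx`, `model`, and the `n_tokens` count from the surrounding sketches.

import llama_cpp

n_vocab = llama_cpp.llama_n_vocab(model)
logits_ptr = llama_cpp.llama_get_logits_ith(ctx, n_tokens - 1)  # float* into the context
logits = [logits_ptr[j] for j in range(n_vocab)]  # copy out of the C buffer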
# //
@@ -1774,125 +1938,123 @@ _lib.llama_get_embeddings_ith.restype = c_float_p
# LLAMA_API const char * llama_token_get_text(const struct llama_model * model, llama_token token);
-def llama_token_get_text(model: llama_model_p, token: Union[llama_token, int]) -> bytes:
- return _lib.llama_token_get_text(model, token)
-_lib.llama_token_get_text.argtypes = [llama_model_p, llama_token]
-_lib.llama_token_get_text.restype = c_char_p
+@ctypes_function(
+ "llama_token_get_text", [llama_model_p_ctypes, llama_token], ctypes.c_char_p
+)
+def llama_token_get_text(
+ model: llama_model_p, token: Union[llama_token, int], /
+) -> bytes:
+ ...
# LLAMA_API float llama_token_get_score(const struct llama_model * model, llama_token token);
+
+
+@ctypes_function(
+ "llama_token_get_score", [llama_model_p_ctypes, llama_token], ctypes.c_float
+)
def llama_token_get_score(
- model: llama_model_p, token: Union[llama_token, int]
+ model: llama_model_p, token: Union[llama_token, int], /
) -> float:
- return _lib.llama_token_get_score(model, token)
-
-
-_lib.llama_token_get_score.argtypes = [llama_model_p, llama_token]
-_lib.llama_token_get_score.restype = c_float
+ ...
# LLAMA_API enum llama_token_type llama_token_get_type(const struct llama_model * model, llama_token token);
-def llama_token_get_type(model: llama_model_p, token: Union[llama_token, int]) -> int:
- return _lib.llama_token_get_type(model, token)
-_lib.llama_token_get_type.argtypes = [llama_model_p, llama_token]
-_lib.llama_token_get_type.restype = ctypes.c_int
+@ctypes_function(
+ "llama_token_get_type", [llama_model_p_ctypes, llama_token], ctypes.c_int
+)
+def llama_token_get_type(
+ model: llama_model_p, token: Union[llama_token, int], /
+) -> int:
+ ...
# // Special tokens
# LLAMA_API llama_token llama_token_bos(const struct llama_model * model); // beginning-of-sentence
-def llama_token_bos(model: llama_model_p) -> int:
+
+
+@ctypes_function("llama_token_bos", [llama_model_p_ctypes], llama_token)
+def llama_token_bos(model: llama_model_p, /) -> int:
"""beginning-of-sentence"""
- return _lib.llama_token_bos(model)
-
-
-_lib.llama_token_bos.argtypes = [llama_model_p]
-_lib.llama_token_bos.restype = llama_token
+ ...
# LLAMA_API llama_token llama_token_eos(const struct llama_model * model); // end-of-sentence
-def llama_token_eos(model: llama_model_p) -> int:
+
+
+@ctypes_function("llama_token_eos", [llama_model_p_ctypes], llama_token)
+def llama_token_eos(model: llama_model_p, /) -> int:
"""end-of-sentence"""
- return _lib.llama_token_eos(model)
-
-
-_lib.llama_token_eos.argtypes = [llama_model_p]
-_lib.llama_token_eos.restype = llama_token
+ ...
# LLAMA_API llama_token llama_token_nl (const struct llama_model * model); // next-line
-def llama_token_nl(model: llama_model_p) -> int:
+
+
+@ctypes_function("llama_token_nl", [llama_model_p_ctypes], llama_token)
+def llama_token_nl(model: llama_model_p, /) -> int:
"""next-line"""
- return _lib.llama_token_nl(model)
-
-
-_lib.llama_token_nl.argtypes = [llama_model_p]
-_lib.llama_token_nl.restype = llama_token
+ ...
# // Returns -1 if unknown, 1 for true or 0 for false.
# LLAMA_API int32_t llama_add_bos_token(const struct llama_model * model);
-def llama_add_bos_token(model: llama_model_p) -> int:
+
+
+@ctypes_function("llama_add_bos_token", [llama_model_p_ctypes], ctypes.c_int32)
+def llama_add_bos_token(model: llama_model_p, /) -> int:
"""Returns -1 if unknown, 1 for true or 0 for false."""
- return _lib.llama_add_bos_token(model)
-
-
-_lib.llama_add_bos_token.argtypes = [llama_model_p]
-_lib.llama_add_bos_token.restype = c_int32
+ ...
# // Returns -1 if unknown, 1 for true or 0 for false.
# LLAMA_API int32_t llama_add_eos_token(const struct llama_model * model);
-def llama_add_eos_token(model: llama_model_p) -> int:
+
+
+@ctypes_function("llama_add_eos_token", [llama_model_p_ctypes], ctypes.c_int32)
+def llama_add_eos_token(model: llama_model_p, /) -> int:
"""Returns -1 if unknown, 1 for true or 0 for false."""
- return _lib.llama_add_eos_token(model)
-
-
-_lib.llama_add_eos_token.argtypes = [llama_model_p]
-_lib.llama_add_eos_token.restype = c_int32
+ ...
# // codellama infill tokens
# LLAMA_API llama_token llama_token_prefix(const struct llama_model * model); // Beginning of infill prefix
+
+
+@ctypes_function("llama_token_prefix", [llama_model_p_ctypes], llama_token)
def llama_token_prefix(model: llama_model_p) -> int:
"""codellama infill tokens"""
- return _lib.llama_token_prefix(model)
-
-
-_lib.llama_token_prefix.argtypes = [llama_model_p]
-_lib.llama_token_prefix.restype = llama_token
+ ...
# LLAMA_API llama_token llama_token_middle(const struct llama_model * model); // Beginning of infill middle
-def llama_token_middle(model: llama_model_p) -> int:
- return _lib.llama_token_middle(model)
-_lib.llama_token_middle.argtypes = [llama_model_p]
-_lib.llama_token_middle.restype = llama_token
+@ctypes_function("llama_token_middle", [llama_model_p_ctypes], llama_token)
+def llama_token_middle(model: llama_model_p, /) -> int:
+ ...
# LLAMA_API llama_token llama_token_suffix(const struct llama_model * model); // Beginning of infill suffix
-def llama_token_suffix(model: llama_model_p) -> int:
- return _lib.llama_token_suffix(model)
-_lib.llama_token_suffix.argtypes = [llama_model_p]
-_lib.llama_token_suffix.restype = llama_token
+@ctypes_function("llama_token_suffix", [llama_model_p_ctypes], llama_token)
+def llama_token_suffix(model: llama_model_p, /) -> int:
+ ...
# LLAMA_API llama_token llama_token_eot (const struct llama_model * model); // End of infill middle
-def llama_token_eot(model: llama_model_p) -> int:
- return _lib.llama_token_eot(model)
-_lib.llama_token_eot.argtypes = [llama_model_p]
-_lib.llama_token_eot.restype = llama_token
+@ctypes_function("llama_token_eot", [llama_model_p_ctypes], llama_token)
+def llama_token_eot(model: llama_model_p, /) -> int:
+ ...
# //
@@ -1914,31 +2076,33 @@ _lib.llama_token_eot.restype = llama_token
# int32_t n_max_tokens,
# bool add_bos,
# bool special);
+
+
+@ctypes_function(
+ "llama_tokenize",
+ [
+ llama_model_p_ctypes,
+ ctypes.c_char_p,
+ ctypes.c_int32,
+ llama_token_p,
+ ctypes.c_int32,
+ ctypes.c_bool,
+ ctypes.c_bool,
+ ],
+ ctypes.c_int32,
+)
def llama_tokenize(
model: llama_model_p,
text: bytes,
- text_len: Union[c_int, int],
- tokens, # type: Array[llama_token]
- n_max_tokens: Union[c_int, int],
- add_bos: Union[c_bool, bool],
- special: Union[c_bool, bool],
+ text_len: Union[ctypes.c_int, int],
+ tokens: CtypesArray[llama_token],
+ n_max_tokens: Union[ctypes.c_int, int],
+ add_bos: Union[ctypes.c_bool, bool],
+ special: Union[ctypes.c_bool, bool],
+ /,
) -> int:
"""Convert the provided text into tokens."""
- return _lib.llama_tokenize(
- model, text, text_len, tokens, n_max_tokens, add_bos, special
- )
-
-
-_lib.llama_tokenize.argtypes = [
- llama_model_p,
- c_char_p,
- c_int32,
- llama_token_p,
- c_int32,
- c_bool,
- c_bool,
-]
-_lib.llama_tokenize.restype = c_int32
+ ...
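Sketch of tokenizing a prompt with the binding above; the 64-token cap is arbitrary and `model` is assumed from the lifecycle sketch.

import llama_cpp

prompt = b"Hello, world"
max_tokens = 64
tokens = (llama_cpp.llama_token * max_tokens)()
n_tokens = llama_cpp.llama_tokenize(
    model, prompt, len(prompt), tokens, max_tokens, True, False  # add_bos, no special tokens
)
if n_tokens < 0:
    raise RuntimeError("token buffer too small")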
# // Token Id -> Piece.
@@ -1950,27 +2114,36 @@ _lib.llama_tokenize.restype = c_int32
# llama_token token,
# char * buf,
# int32_t length);
+
+
+@ctypes_function(
+ "llama_token_to_piece",
+ [
+ llama_model_p_ctypes,
+ llama_token,
+ ctypes.c_char_p,
+ ctypes.c_int32,
+ ],
+ ctypes.c_int32,
+)
def llama_token_to_piece(
model: llama_model_p,
token: Union[llama_token, int],
- buf: Union[c_char_p, bytes],
- length: Union[c_int, int],
+ buf: Union[ctypes.c_char_p, bytes, CtypesArray[ctypes.c_char]],
+ length: Union[ctypes.c_int, int],
+ /,
) -> int:
"""Token Id -> Piece.
Uses the vocabulary in the provided context.
Does not write null terminator to the buffer.
    User code is responsible for removing the leading whitespace of the first non-BOS token when decoding multiple tokens.
"""
- return _lib.llama_token_to_piece(model, token, buf, length)
-
-
-_lib.llama_token_to_piece.argtypes = [llama_model_p, llama_token, c_char_p, c_int32]
-_lib.llama_token_to_piece.restype = c_int32
+ ...
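Sketch of the reverse direction, turning token ids back into text with `llama_token_to_piece`; assumes the `tokens` / `n_tokens` from the tokenize sketch.

import ctypes
import llama_cpp

piece_buf = ctypes.create_string_buffer(64)
text = b""
for i in range(n_tokens):
    n = llama_cpp.llama_token_to_piece(model, tokens[i], piece_buf, len(piece_buf))
    text += piece_buf.raw[:n]  # no null terminator is written, so slice by the returned length
print(text.decode("utf-8", errors="replace"))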
# /// Apply chat template. Inspired by hf apply_chat_template() on python.
# /// Both "model" and "custom_template" are optional, but at least one is required. "custom_template" has higher precedence than "model"
-# /// NOTE: This function only support some known jinja templates. It is not a jinja parser.
+# /// NOTE: This function does not use a jinja parser. It only supports a pre-defined list of templates. See more: https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template
# /// @param tmpl A Jinja template to use for this chat. If this is nullptr, the model’s default chat template will be used instead.
# /// @param chat Pointer to a list of multiple llama_chat_message
# /// @param n_msg Number of llama_chat_message in this chat
@@ -1986,27 +2159,26 @@ _lib.llama_token_to_piece.restype = c_int32
# bool add_ass,
# char * buf,
# int32_t length);
+
+
+@ctypes_function(
+ "llama_chat_apply_template",
+ [
+ ctypes.c_void_p,
+ ctypes.c_char_p,
+ ctypes.POINTER(llama_chat_message),
+ ctypes.c_size_t,
+ ],
+ ctypes.c_int32,
+)
def llama_chat_apply_template(
- model: llama_model_p,
- tmpl: bytes,
- chat: "ctypes._Pointer[llama_chat_message]",
- n_msg: int,
+ model: llama_model_p,
+ tmpl: bytes,
+ chat: CtypesArray[llama_chat_message],
+ n_msg: int,
+ /,
) -> int:
- return _lib.llama_chat_apply_template(
- model,
- tmpl,
- chat,
- n_msg
- )
-
-_lib.llama_chat_apply_template.argtypes = [
- ctypes.c_void_p,
- ctypes.c_char_p,
- ctypes.POINTER(llama_chat_message),
- ctypes.c_size_t
-]
-_lib.llama_chat_apply_template.restype = ctypes.c_int32
-
+ ...
# //
@@ -2018,42 +2190,51 @@ _lib.llama_chat_apply_template.restype = ctypes.c_int32
# const llama_grammar_element ** rules,
# size_t n_rules,
# size_t start_rule_index);
+
+
+@ctypes_function(
+ "llama_grammar_init",
+ [
+ ctypes.POINTER(llama_grammar_element_p),
+ ctypes.c_size_t,
+ ctypes.c_size_t,
+ ],
+ llama_grammar_p,
+)
def llama_grammar_init(
- rules, # type: Array[llama_grammar_element_p] # type: ignore
- n_rules: Union[c_size_t, int],
- start_rule_index: Union[c_size_t, int],
+ rules: CtypesArray[
+ CtypesPointer[llama_grammar_element]
+ ], # NOTE: This might be wrong type sig
+ n_rules: Union[ctypes.c_size_t, int],
+ start_rule_index: Union[ctypes.c_size_t, int],
+ /,
) -> llama_grammar_p:
"""Initialize a grammar from a set of rules."""
- return _lib.llama_grammar_init(rules, n_rules, start_rule_index)
-
-
-_lib.llama_grammar_init.argtypes = [
- POINTER(llama_grammar_element_p),
- c_size_t,
- c_size_t,
-]
-_lib.llama_grammar_init.restype = llama_grammar_p
+ ...
# LLAMA_API void llama_grammar_free(struct llama_grammar * grammar);
-def llama_grammar_free(grammar: llama_grammar_p):
+@ctypes_function(
+ "llama_grammar_free",
+ [llama_grammar_p],
+ None,
+)
+def llama_grammar_free(grammar: llama_grammar_p, /):
"""Free a grammar."""
- return _lib.llama_grammar_free(grammar)
-
-
-_lib.llama_grammar_free.argtypes = [llama_grammar_p]
-_lib.llama_grammar_free.restype = None
+ ...
# LLAMA_API struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar);
-def llama_grammar_copy(grammar: llama_grammar_p) -> llama_grammar_p:
+@ctypes_function(
+ "llama_grammar_copy",
+ [llama_grammar_p],
+ llama_grammar_p,
+)
+def llama_grammar_copy(grammar: llama_grammar_p, /) -> llama_grammar_p:
"""Copy a grammar."""
- return _lib.llama_grammar_copy(grammar)
+ ...
-_lib.llama_grammar_copy.argtypes = [llama_grammar_p]
-_lib.llama_grammar_copy.restype = llama_grammar_p
-
# //
# // Sampling functions
# //
@@ -2061,13 +2242,14 @@ _lib.llama_grammar_copy.restype = llama_grammar_p
# // Sets the current rng seed.
# LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed);
-def llama_set_rng_seed(ctx: llama_context_p, seed: Union[c_uint32, int]):
+@ctypes_function(
+ "llama_set_rng_seed",
+ [llama_context_p_ctypes, ctypes.c_uint32],
+ None,
+)
+def llama_set_rng_seed(ctx: llama_context_p, seed: Union[ctypes.c_uint32, int], /):
"""Sets the current rng seed."""
- return _lib.llama_set_rng_seed(ctx, seed)
-
-
-_lib.llama_set_rng_seed.argtypes = [llama_context_p, c_uint32]
-_lib.llama_set_rng_seed.restype = None
+ ...
# /// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
@@ -2080,39 +2262,35 @@ _lib.llama_set_rng_seed.restype = None
# float penalty_repeat,
# float penalty_freq,
# float penalty_present);
+@ctypes_function(
+ "llama_sample_repetition_penalties",
+ [
+ llama_context_p_ctypes,
+ llama_token_data_array_p,
+ llama_token_p,
+ ctypes.c_size_t,
+ ctypes.c_float,
+ ctypes.c_float,
+ ctypes.c_float,
+ ],
+ None,
+)
def llama_sample_repetition_penalties(
ctx: llama_context_p,
- candidates, # type: _Pointer[llama_token_data_array]
- last_tokens_data, # type: Array[llama_token]
- penalty_last_n: Union[c_size_t, int],
- penalty_repeat: Union[c_float, float],
- penalty_freq: Union[c_float, float],
- penalty_present: Union[c_float, float],
+ candidates: Union[
+ CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array]
+ ],
+ last_tokens_data: CtypesArray[llama_token],
+ penalty_last_n: Union[ctypes.c_size_t, int],
+ penalty_repeat: Union[ctypes.c_float, float],
+ penalty_freq: Union[ctypes.c_float, float],
+ penalty_present: Union[ctypes.c_float, float],
+ /,
):
"""Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
"""
- return _lib.llama_sample_repetition_penalties(
- ctx,
- candidates,
- last_tokens_data,
- penalty_last_n,
- penalty_repeat,
- penalty_freq,
- penalty_present,
- )
-
-
-_lib.llama_sample_repetition_penalties.argtypes = [
- llama_context_p,
- llama_token_data_array_p,
- llama_token_p,
- c_size_t,
- c_float,
- c_float,
- c_float,
-]
-_lib.llama_sample_repetition_penalties.restype = None
+ ...
# /// @details Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806
@@ -2124,23 +2302,25 @@ _lib.llama_sample_repetition_penalties.restype = None
# float * logits,
# float * logits_guidance,
# float scale);
+@ctypes_function(
+ "llama_sample_apply_guidance",
+ [
+ llama_context_p_ctypes,
+ ctypes.POINTER(ctypes.c_float),
+ ctypes.POINTER(ctypes.c_float),
+ ctypes.c_float,
+ ],
+ None,
+)
def llama_sample_apply_guidance(
ctx: llama_context_p,
- logits, # type: _Pointer[c_float]
- logits_guidance, # type: _Pointer[c_float]
- scale: Union[c_float, float],
+ logits: CtypesArray[ctypes.c_float],
+ logits_guidance: CtypesArray[ctypes.c_float],
+ scale: Union[ctypes.c_float, float],
+ /,
):
"""Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806"""
- return _lib.llama_sample_apply_guidance(ctx, logits, logits_guidance, scale)
-
-
-_lib.llama_sample_apply_guidance.argtypes = [
- llama_context_p,
- c_float_p,
- c_float_p,
- c_float,
-]
-_lib.llama_sample_apply_guidance.restype = None
+ ...
# LLAMA_API DEPRECATED(void llama_sample_classifier_free_guidance(
@@ -2149,43 +2329,47 @@ _lib.llama_sample_apply_guidance.restype = None
# struct llama_context * guidance_ctx,
# float scale),
# "use llama_sample_apply_guidance() instead");
+@ctypes_function(
+ "llama_sample_classifier_free_guidance",
+ [
+ llama_context_p_ctypes,
+ llama_token_data_array_p,
+ llama_context_p_ctypes,
+ ctypes.c_float,
+ ],
+ None,
+)
def llama_sample_classifier_free_guidance(
ctx: llama_context_p,
- candidates, # type: _Pointer[llama_token_data_array]
+ candidates: Union[
+ CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array]
+ ],
guidance_ctx: llama_context_p,
- scale: Union[c_float, float],
+ scale: Union[ctypes.c_float, float],
+ /,
):
"""Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806"""
- return _lib.llama_sample_classifier_free_guidance(
- ctx, candidates, guidance_ctx, scale
- )
-
-
-_lib.llama_sample_classifier_free_guidance.argtypes = [
- llama_context_p,
- llama_token_data_array_p,
- llama_context_p,
- c_float,
-]
-_lib.llama_sample_classifier_free_guidance.restype = None
+ ...
# /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
# LLAMA_API void llama_sample_softmax(
# struct llama_context * ctx,
# llama_token_data_array * candidates);
+@ctypes_function(
+ "llama_sample_softmax",
+ [llama_context_p_ctypes, llama_token_data_array_p],
+ None,
+)
def llama_sample_softmax(
- ctx: llama_context_p, candidates # type: _Pointer[llama_token_data]
+ ctx: llama_context_p,
+ candidates: Union[
+ CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array]
+ ],
+ /,
):
"""Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits."""
- return _lib.llama_sample_softmax(ctx, candidates)
-
-
-_lib.llama_sample_softmax.argtypes = [
- llama_context_p,
- llama_token_data_array_p,
-]
-_lib.llama_sample_softmax.restype = None
+ ...
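Sketch of wiring logits into the sampling bindings: build a `llama_token_data_array`, apply top-k, then softmax; assumes `ctx`, `n_vocab`, and `logits_ptr` from the logits sketch above, and `llama_sample_token`, which is bound further down in this module.

import ctypes
import llama_cpp

candidates = (llama_cpp.llama_token_data * n_vocab)()
for token_id in range(n_vocab):
    candidates[token_id].id = token_id
    candidates[token_id].logit = logits_ptr[token_id]
    candidates[token_id].p = 0.0
candidates_array = llama_cpp.llama_token_data_array(candidates, n_vocab, False)
candidates_p = ctypes.byref(candidates_array)

llama_cpp.llama_sample_top_k(ctx, candidates_p, 40, 1)   # keep the 40 most likely tokens
llama_cpp.llama_sample_softmax(ctx, candidates_p)        # normalize the remaining logits
token = llama_cpp.llama_sample_token(ctx, candidates_p)  # assumed: bound later in the file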
# /// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
@@ -2194,23 +2378,22 @@ _lib.llama_sample_softmax.restype = None
# llama_token_data_array * candidates,
# int32_t k,
# size_t min_keep);
+@ctypes_function(
+ "llama_sample_top_k",
+ [llama_context_p_ctypes, llama_token_data_array_p, ctypes.c_int32, ctypes.c_size_t],
+ None,
+)
def llama_sample_top_k(
ctx: llama_context_p,
- candidates, # type: _Pointer[llama_token_data_array]
- k: Union[c_int, int],
- min_keep: Union[c_size_t, int],
+ candidates: Union[
+ CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array]
+ ],
+ k: Union[ctypes.c_int, int],
+ min_keep: Union[ctypes.c_size_t, int],
+ /,
):
"""Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751"""
- return _lib.llama_sample_top_k(ctx, candidates, k, min_keep)
-
-
-_lib.llama_sample_top_k.argtypes = [
- llama_context_p,
- llama_token_data_array_p,
- c_int32,
- c_size_t,
-]
-_lib.llama_sample_top_k.restype = None
+ ...
# /// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
@@ -2219,23 +2402,22 @@ _lib.llama_sample_top_k.restype = None
# llama_token_data_array * candidates,
# float p,
# size_t min_keep);
+@ctypes_function(
+ "llama_sample_top_p",
+ [llama_context_p_ctypes, llama_token_data_array_p, ctypes.c_float, ctypes.c_size_t],
+ None,
+)
def llama_sample_top_p(
ctx: llama_context_p,
- candidates, # type: _Pointer[llama_token_data_array]
- p: Union[c_float, float],
- min_keep: Union[c_size_t, int],
+ candidates: Union[
+ CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array]
+ ],
+ p: Union[ctypes.c_float, float],
+ min_keep: Union[ctypes.c_size_t, int],
+ /,
):
"""Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751"""
- return _lib.llama_sample_top_p(ctx, candidates, p, min_keep)
-
-
-_lib.llama_sample_top_p.argtypes = [
- llama_context_p,
- llama_token_data_array_p,
- c_float,
- c_size_t,
-]
-_lib.llama_sample_top_p.restype = None
+ ...
# /// @details Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841
@@ -2244,23 +2426,22 @@ _lib.llama_sample_top_p.restype = None
# llama_token_data_array * candidates,
# float p,
# size_t min_keep);
+@ctypes_function(
+ "llama_sample_min_p",
+ [llama_context_p_ctypes, llama_token_data_array_p, ctypes.c_float, ctypes.c_size_t],
+ None,
+)
def llama_sample_min_p(
ctx: llama_context_p,
- candidates, # type: _Pointer[llama_token_data_array]
- p: Union[c_float, float],
- min_keep: Union[c_size_t, int],
+ candidates: Union[
+ CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array]
+ ],
+ p: Union[ctypes.c_float, float],
+ min_keep: Union[ctypes.c_size_t, int],
+ /,
):
"""Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841"""
- return _lib.llama_sample_min_p(ctx, candidates, p, min_keep)
-
-
-_lib.llama_sample_min_p.argtypes = [
- llama_context_p,
- llama_token_data_array_p,
- c_float,
- c_size_t,
-]
-_lib.llama_sample_min_p.restype = None
+ ...
# /// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
@@ -2269,23 +2450,22 @@ _lib.llama_sample_min_p.restype = None
# llama_token_data_array * candidates,
# float z,
# size_t min_keep);
+@ctypes_function(
+ "llama_sample_tail_free",
+ [llama_context_p_ctypes, llama_token_data_array_p, ctypes.c_float, ctypes.c_size_t],
+ None,
+)
def llama_sample_tail_free(
ctx: llama_context_p,
- candidates, # type: _Pointer[llama_token_data_array]
- z: Union[c_float, float],
- min_keep: Union[c_size_t, int],
+ candidates: Union[
+ CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array]
+ ],
+ z: Union[ctypes.c_float, float],
+ min_keep: Union[ctypes.c_size_t, int],
+ /,
):
"""Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/."""
- return _lib.llama_sample_tail_free(ctx, candidates, z, min_keep)
-
-
-_lib.llama_sample_tail_free.argtypes = [
- llama_context_p,
- llama_token_data_array_p,
- c_float,
- c_size_t,
-]
-_lib.llama_sample_tail_free.restype = None
+ ...
# /// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
@@ -2294,23 +2474,22 @@ _lib.llama_sample_tail_free.restype = None
# llama_token_data_array * candidates,
# float p,
# size_t min_keep);
+@ctypes_function(
+ "llama_sample_typical",
+ [llama_context_p_ctypes, llama_token_data_array_p, ctypes.c_float, ctypes.c_size_t],
+ None,
+)
def llama_sample_typical(
ctx: llama_context_p,
- candidates, # type: _Pointer[llama_token_data_array]
- p: Union[c_float, float],
- min_keep: Union[c_size_t, int],
+ candidates: Union[
+ CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array]
+ ],
+ p: Union[ctypes.c_float, float],
+ min_keep: Union[ctypes.c_size_t, int],
+ /,
):
"""Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666."""
- return _lib.llama_sample_typical(ctx, candidates, p, min_keep)
-
-
-_lib.llama_sample_typical.argtypes = [
- llama_context_p,
- llama_token_data_array_p,
- c_float,
- c_size_t,
-]
-_lib.llama_sample_typical.restype = None
+ ...
# /// @details Dynamic temperature implementation described in the paper https://arxiv.org/abs/2309.02772.
@@ -2320,35 +2499,47 @@ _lib.llama_sample_typical.restype = None
# float min_temp,
# float max_temp,
# float exponent_val);
+@ctypes_function(
+ "llama_sample_entropy",
+ [
+ llama_context_p_ctypes,
+ llama_token_data_array_p,
+ ctypes.c_float,
+ ctypes.c_float,
+ ctypes.c_float,
+ ],
+ None,
+)
def llama_sample_entropy(
ctx: llama_context_p,
- candidates, # type: _Pointer[llama_token_data_array]
- min_temp: Union[c_float, float],
- max_temp: Union[c_float, float],
- exponent_val: Union[c_float, float],
+ candidates: Union[
+ CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array]
+ ],
+ min_temp: Union[ctypes.c_float, float],
+ max_temp: Union[ctypes.c_float, float],
+ exponent_val: Union[ctypes.c_float, float],
+ /,
):
"""Dynamic temperature implementation described in the paper https://arxiv.org/abs/2309.02772."""
- return _lib.llama_sample_entropy(ctx, candidates, min_temp, max_temp, exponent_val)
-
-
-_lib.llama_sample_entropy.argtypes = [
- llama_context_p,
- llama_token_data_array_p,
- c_float,
- c_float,
- c_float,
-]
-_lib.llama_sample_entropy.restype = None
+ ...
# LLAMA_API void llama_sample_temp(
# struct llama_context * ctx,
# llama_token_data_array * candidates,
# float temp);
+@ctypes_function(
+ "llama_sample_temp",
+ [llama_context_p_ctypes, llama_token_data_array_p, ctypes.c_float],
+ None,
+)
def llama_sample_temp(
ctx: llama_context_p,
- candidates, # type: _Pointer[llama_token_data_array]
- temp: Union[c_float, float],
+ candidates: Union[
+ CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array]
+ ],
+ temp: Union[ctypes.c_float, float],
+ /,
):
"""Temperature sampling described in academic paper "Generating Long Sequences with Sparse Transformers" https://arxiv.org/abs/1904.10509
@@ -2356,15 +2547,7 @@ def llama_sample_temp(
candidates: A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
temp: The temperature value to use for the sampling. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
"""
- return _lib.llama_sample_temp(ctx, candidates, temp)
-
-
-_lib.llama_sample_temp.argtypes = [
- llama_context_p,
- llama_token_data_array_p,
- c_float,
-]
-_lib.llama_sample_temp.restype = None
+ ...
# LLAMA_API DEPRECATED(void llama_sample_temperature(
@@ -2372,21 +2555,21 @@ _lib.llama_sample_temp.restype = None
# llama_token_data_array * candidates,
# float temp),
# "use llama_sample_temp instead");
+@ctypes_function(
+ "llama_sample_temperature",
+ [llama_context_p_ctypes, llama_token_data_array_p, ctypes.c_float],
+ None,
+)
def llama_sample_temperature(
ctx: llama_context_p,
- candidates, # type: _Pointer[llama_token_data_array]
- temp: Union[c_float, float],
+ candidates: Union[
+ CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array]
+ ],
+ temp: Union[ctypes.c_float, float],
+ /,
):
"""use llama_sample_temp instead"""
- return _lib.llama_sample_temperature(ctx, candidates, temp)
-
-
-_lib.llama_sample_temperature.argtypes = [
- llama_context_p,
- llama_token_data_array_p,
- c_float,
-]
-_lib.llama_sample_temperature.restype = None
+ ...
# /// @details Apply constraints from grammar
@@ -2394,10 +2577,18 @@ _lib.llama_sample_temperature.restype = None
# struct llama_context * ctx,
# llama_token_data_array * candidates,
# const struct llama_grammar * grammar);
+@ctypes_function(
+ "llama_sample_grammar",
+ [llama_context_p_ctypes, llama_token_data_array_p, llama_grammar_p],
+ None,
+)
def llama_sample_grammar(
ctx: llama_context_p,
- candidates, # type: _Pointer[llama_token_data_array]
+ candidates: Union[
+ CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array]
+ ],
grammar, # type: llama_grammar_p
+ /,
):
"""Apply constraints from grammar
@@ -2405,15 +2596,7 @@ def llama_sample_grammar(
candidates: A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
grammar: A grammar object containing the rules and constraints to apply to the generated text.
"""
- return _lib.llama_sample_grammar(ctx, candidates, grammar)
-
-
-_lib.llama_sample_grammar.argtypes = [
- llama_context_p,
- llama_token_data_array_p,
- llama_grammar_p,
-]
-_lib.llama_sample_grammar.restype = None
+ ...
# /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
@@ -2429,13 +2612,28 @@ _lib.llama_sample_grammar.restype = None
# float eta,
# int32_t m,
# float * mu);
+@ctypes_function(
+ "llama_sample_token_mirostat",
+ [
+ llama_context_p_ctypes,
+ llama_token_data_array_p,
+ ctypes.c_float,
+ ctypes.c_float,
+ ctypes.c_int32,
+ ctypes.POINTER(ctypes.c_float),
+ ],
+ llama_token,
+)
def llama_sample_token_mirostat(
ctx: llama_context_p,
- candidates, # type: _Pointer[llama_token_data_array]
- tau: Union[c_float, float],
- eta: Union[c_float, float],
- m: Union[c_int, int],
- mu, # type: _Pointer[c_float]
+ candidates: Union[
+ CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array]
+ ],
+ tau: Union[ctypes.c_float, float],
+ eta: Union[ctypes.c_float, float],
+ m: Union[ctypes.c_int, int],
+ mu: CtypesPointerOrRef[ctypes.c_float],
+ /,
) -> int:
"""Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
@@ -2446,18 +2644,7 @@ def llama_sample_token_mirostat(
m: The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.
mu: Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
"""
- return _lib.llama_sample_token_mirostat(ctx, candidates, tau, eta, m, mu)
-
-
-_lib.llama_sample_token_mirostat.argtypes = [
- llama_context_p,
- llama_token_data_array_p,
- c_float,
- c_float,
- c_int32,
- c_float_p,
-]
-_lib.llama_sample_token_mirostat.restype = llama_token
+ ...
# /// @details Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
@@ -2471,12 +2658,26 @@ _lib.llama_sample_token_mirostat.restype = llama_token
# float tau,
# float eta,
# float * mu);
+@ctypes_function(
+ "llama_sample_token_mirostat_v2",
+ [
+ llama_context_p_ctypes,
+ llama_token_data_array_p,
+ ctypes.c_float,
+ ctypes.c_float,
+ ctypes.POINTER(ctypes.c_float),
+ ],
+ llama_token,
+)
def llama_sample_token_mirostat_v2(
ctx: llama_context_p,
- candidates, # type: _Pointer[llama_token_data_array]
- tau: Union[c_float, float],
- eta: Union[c_float, float],
- mu, # type: _Pointer[c_float]
+ candidates: Union[
+ CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array]
+ ],
+ tau: Union[ctypes.c_float, float],
+ eta: Union[ctypes.c_float, float],
+ mu: CtypesPointerOrRef[ctypes.c_float],
+ /,
) -> int:
"""Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
@@ -2486,17 +2687,7 @@ def llama_sample_token_mirostat_v2(
eta: The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
mu: Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
"""
- return _lib.llama_sample_token_mirostat_v2(ctx, candidates, tau, eta, mu)
-
-
-_lib.llama_sample_token_mirostat_v2.argtypes = [
- llama_context_p,
- llama_token_data_array_p,
- c_float,
- c_float,
- c_float_p,
-]
-_lib.llama_sample_token_mirostat_v2.restype = llama_token
+ ...
# /// @details Selects the token with the highest probability.
@@ -2504,38 +2695,40 @@ _lib.llama_sample_token_mirostat_v2.restype = llama_token
# LLAMA_API llama_token llama_sample_token_greedy(
# struct llama_context * ctx,
# llama_token_data_array * candidates);
+@ctypes_function(
+ "llama_sample_token_greedy",
+ [llama_context_p_ctypes, llama_token_data_array_p],
+ llama_token,
+)
def llama_sample_token_greedy(
ctx: llama_context_p,
- candidates, # type: _Pointer[llama_token_data_array]
+ candidates: Union[
+ CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array]
+ ],
+ /,
) -> int:
"""Selects the token with the highest probability."""
- return _lib.llama_sample_token_greedy(ctx, candidates)
-
-
-_lib.llama_sample_token_greedy.argtypes = [
- llama_context_p,
- llama_token_data_array_p,
-]
-_lib.llama_sample_token_greedy.restype = llama_token
+ ...
# /// @details Randomly selects a token from the candidates based on their probabilities.
# LLAMA_API llama_token llama_sample_token(
# struct llama_context * ctx,
# llama_token_data_array * candidates);
+@ctypes_function(
+ "llama_sample_token",
+ [llama_context_p_ctypes, llama_token_data_array_p],
+ llama_token,
+)
def llama_sample_token(
ctx: llama_context_p,
- candidates, # type: _Pointer[llama_token_data_array]
+ candidates: Union[
+ CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array]
+ ],
+ /,
) -> int:
"""Randomly selects a token from the candidates based on their probabilities."""
- return _lib.llama_sample_token(ctx, candidates)
-
-
-_lib.llama_sample_token.argtypes = [
- llama_context_p,
- llama_token_data_array_p,
-]
-_lib.llama_sample_token.restype = llama_token
+ ...
# /// @details Accepts the sampled token into the grammar
@@ -2543,21 +2736,16 @@ _lib.llama_sample_token.restype = llama_token
# struct llama_context * ctx,
# struct llama_grammar * grammar,
# llama_token token);
+@ctypes_function(
+ "llama_grammar_accept_token",
+ [llama_context_p_ctypes, llama_grammar_p, llama_token],
+ None,
+)
def llama_grammar_accept_token(
- ctx: llama_context_p,
- grammar: llama_grammar_p,
- token: Union[llama_token, int],
+ ctx: llama_context_p, grammar: llama_grammar_p, token: Union[llama_token, int], /
) -> None:
"""Accepts the sampled token into the grammar"""
- _lib.llama_grammar_accept_token(ctx, grammar, token)
-
-
-_lib.llama_grammar_accept_token.argtypes = [
- llama_context_p,
- llama_grammar_p,
- llama_token,
-]
-_lib.llama_grammar_accept_token.restype = None
+ ...
# //
@@ -2575,9 +2763,9 @@ _lib.llama_grammar_accept_token.restype = None
class llama_beam_view(ctypes.Structure):
_fields_ = [
("tokens", llama_token_p),
- ("n_tokens", c_size_t),
- ("p", c_float),
- ("eob", c_bool),
+ ("n_tokens", ctypes.c_size_t),
+ ("p", ctypes.c_float),
+ ("eob", ctypes.c_bool),
]
@@ -2593,10 +2781,10 @@ class llama_beam_view(ctypes.Structure):
# };
class llama_beams_state(ctypes.Structure):
_fields_ = [
- ("beam_views", POINTER(llama_beam_view)),
- ("n_beams", c_size_t),
- ("common_prefix_length", c_size_t),
- ("last_call", c_bool),
+ ("beam_views", ctypes.POINTER(llama_beam_view)),
+ ("n_beams", ctypes.c_size_t),
+ ("common_prefix_length", ctypes.c_size_t),
+ ("last_call", ctypes.c_bool),
]
@@ -2604,7 +2792,9 @@ class llama_beams_state(ctypes.Structure):
# // void* callback_data is any custom data passed to llama_beam_search, that is subsequently
# // passed back to beam_search_callback. This avoids having to use global variables in the callback.
# typedef void (*llama_beam_search_callback_fn_t)(void * callback_data, struct llama_beams_state);
-llama_beam_search_callback_fn_t = ctypes.CFUNCTYPE(None, c_void_p, llama_beams_state)
+llama_beam_search_callback_fn_t = ctypes.CFUNCTYPE(
+ None, ctypes.c_void_p, llama_beams_state
+)
# /// @details Deterministically returns entire sentence constructed by a beam search.
@@ -2622,95 +2812,103 @@ llama_beam_search_callback_fn_t = ctypes.CFUNCTYPE(None, c_void_p, llama_beams_s
# size_t n_beams,
# int32_t n_past,
# int32_t n_predict);
+@ctypes_function(
+ "llama_beam_search",
+ [
+ llama_context_p_ctypes,
+ llama_beam_search_callback_fn_t,
+ ctypes.c_void_p,
+ ctypes.c_size_t,
+ ctypes.c_int32,
+ ctypes.c_int32,
+ ],
+ None,
+)
def llama_beam_search(
ctx: llama_context_p,
- callback: "ctypes._CFuncPtr[None, c_void_p, llama_beams_state]", # type: ignore
- callback_data: c_void_p,
- n_beams: Union[c_size_t, int],
- n_past: Union[c_int, int],
- n_predict: Union[c_int, int],
+ callback: CtypesFuncPointer,
+ callback_data: ctypes.c_void_p,
+ n_beams: Union[ctypes.c_size_t, int],
+ n_past: Union[ctypes.c_int, int],
+ n_predict: Union[ctypes.c_int, int],
+ /,
):
- return _lib.llama_beam_search(
- ctx, callback, callback_data, n_beams, n_past, n_predict
- )
-
-
-_lib.llama_beam_search.argtypes = [
- llama_context_p,
- llama_beam_search_callback_fn_t,
- c_void_p,
- c_size_t,
- c_int32,
- c_int32,
-]
-_lib.llama_beam_search.restype = None
+ ...
# Performance information
# LLAMA_API struct llama_timings llama_get_timings(struct llama_context * ctx);
-def llama_get_timings(ctx: llama_context_p) -> llama_timings:
+@ctypes_function(
+ "llama_get_timings",
+ [llama_context_p_ctypes],
+ llama_timings,
+)
+def llama_get_timings(ctx: llama_context_p, /) -> llama_timings:
"""Get performance information"""
- return _lib.llama_get_timings(ctx)
-
-
-_lib.llama_get_timings.argtypes = [llama_context_p]
-_lib.llama_get_timings.restype = llama_timings
+ ...
# LLAMA_API void llama_print_timings(struct llama_context * ctx);
-def llama_print_timings(ctx: llama_context_p):
+@ctypes_function(
+ "llama_print_timings",
+ [llama_context_p_ctypes],
+ None,
+)
+def llama_print_timings(ctx: llama_context_p, /):
"""Print performance information"""
- _lib.llama_print_timings(ctx)
-
-
-_lib.llama_print_timings.argtypes = [llama_context_p]
-_lib.llama_print_timings.restype = None
+ ...
# LLAMA_API void llama_reset_timings(struct llama_context * ctx);
-def llama_reset_timings(ctx: llama_context_p):
+@ctypes_function(
+ "llama_reset_timings",
+ [llama_context_p_ctypes],
+ None,
+)
+def llama_reset_timings(ctx: llama_context_p, /):
"""Reset performance information"""
- _lib.llama_reset_timings(ctx)
-
-
-_lib.llama_reset_timings.argtypes = [llama_context_p]
-_lib.llama_reset_timings.restype = None
+ ...
# Print system information
# LLAMA_API const char * llama_print_system_info(void);
+@ctypes_function(
+ "llama_print_system_info",
+ [],
+ ctypes.c_char_p,
+)
def llama_print_system_info() -> bytes:
"""Print system information"""
- return _lib.llama_print_system_info()
-
-
-_lib.llama_print_system_info.argtypes = []
-_lib.llama_print_system_info.restype = c_char_p
+ ...
# NOTE: THIS IS CURRENTLY BROKEN AS ggml_log_callback IS NOT EXPOSED IN LLAMA.H
# // Set callback for all future logging events.
# // If this is not called, or NULL is supplied, everything is output on stderr.
# LLAMA_API void llama_log_set(ggml_log_callback log_callback, void * user_data);
+@ctypes_function(
+ "llama_log_set",
+ [ctypes.c_void_p, ctypes.c_void_p],
+ None,
+)
def llama_log_set(
- log_callback: Union["ctypes._FuncPointer", c_void_p], user_data: c_void_p # type: ignore
+ log_callback: Optional[CtypesFuncPointer],
+ user_data: ctypes.c_void_p,
+ /,
):
"""Set callback for all future logging events.
If this is not called, or NULL is supplied, everything is output on stderr."""
- return _lib.llama_log_set(log_callback, user_data)
-
-
-_lib.llama_log_set.argtypes = [ctypes.c_void_p, c_void_p]
-_lib.llama_log_set.restype = None
+ ...
# LLAMA_API void llama_dump_timing_info_yaml(FILE * stream, const struct llama_context * ctx);
-def llama_dump_timing_info_yaml(stream: ctypes.c_void_p, ctx: llama_context_p):
- return _lib.llama_dump_timing_info_yaml(stream, ctx)
-
-
-_lib.llama_dump_timing_info_yaml.argtypes = [ctypes.c_void_p, llama_context_p]
-_lib.llama_dump_timing_info_yaml.restype = None
+@ctypes_function(
+ "llama_dump_timing_info_yaml",
+ [ctypes.c_void_p, llama_context_p_ctypes],
+ None,
+)
+def llama_dump_timing_info_yaml(stream: ctypes.c_void_p, ctx: llama_context_p, /):
+ ...
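The llama_cpp.py changes above replace the old binding pattern (call `_lib.<name>(...)` inside each wrapper, then set `argtypes`/`restype` on the module object afterwards) with a single `@ctypes_function(name, argtypes, restype)` decorator, an `...` body, and positional-only (`/`) signatures. The decorator itself is defined earlier in the module and is not part of this hunk; the following is only a minimal sketch, assuming a factory bound to an already-loaded `ctypes.CDLL` handle, of how such a decorator can attach the C signature once at import time and forward calls:

    import ctypes
    import functools
    from typing import Any, Callable, List

    def ctypes_function_for_shared_library(lib: ctypes.CDLL):
        """Return a decorator factory bound to `lib` (illustrative sketch only)."""

        def ctypes_function(name: str, argtypes: List[Any], restype: Any):
            def decorator(f: Callable):
                # Configure the foreign function once, when the module is imported.
                func = getattr(lib, name)
                func.argtypes = argtypes
                func.restype = restype

                # Keep the typed stub's name and docstring; forward calls to C.
                @functools.wraps(f)
                def wrapper(*args):
                    return func(*args)

                return wrapper

            return decorator

        return ctypes_function

A hypothetical `ctypes_function = ctypes_function_for_shared_library(_lib)` at module scope would then make the decorated stubs behave like the removed hand-written wrappers.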
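Because the converted sampling functions are now positional-only and accept either a ctypes array or a pointer/byref for `candidates`, a caller builds a `llama_token_data_array` over the current logits and threads it through the sampling chain. A rough usage sketch; `ctx` and the logits are placeholders, not values taken from this repository:

    import ctypes
    import llama_cpp.llama_cpp as llama_cpp

    n_vocab = 8
    logits = [0.0] * n_vocab  # stand-in for the real llama_get_logits(ctx) output

    # One llama_token_data per candidate token: (id, logit, p).
    data = (llama_cpp.llama_token_data * n_vocab)(
        *[llama_cpp.llama_token_data(i, logits[i], 0.0) for i in range(n_vocab)]
    )
    candidates = llama_cpp.llama_token_data_array()
    candidates.data = ctypes.cast(data, ctypes.POINTER(llama_cpp.llama_token_data))
    candidates.size = n_vocab
    candidates.sorted = False

    # With a real context `ctx`, the calls are positional-only:
    # llama_cpp.llama_sample_top_k(ctx, ctypes.byref(candidates), 40, 1)
    # llama_cpp.llama_sample_top_p(ctx, ctypes.byref(candidates), 0.95, 1)
    # token = llama_cpp.llama_sample_token_greedy(ctx, ctypes.byref(candidates))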
diff --git a/llama_cpp/llama_grammar.py b/llama_cpp/llama_grammar.py
index 3eb3b96..6a37857 100644
--- a/llama_cpp/llama_grammar.py
+++ b/llama_cpp/llama_grammar.py
@@ -1498,9 +1498,21 @@ class SchemaConverter:
item_rule_name = self.visit(
schema["items"], f'{name}{"-" if name else ""}item'
)
- rule = (
- f'"[" space ({item_rule_name} ("," space {item_rule_name})*)? "]" space'
- )
+ list_item_operator = f'("," space {item_rule_name})'
+ successive_items = ""
+ min_items = schema.get("minItems", 0)
+ if min_items > 0:
+ first_item = f"({item_rule_name})"
+ successive_items = list_item_operator * (min_items - 1)
+ min_items -= 1
+ else:
+ first_item = f"({item_rule_name})?"
+ max_items = schema.get("maxItems")
+ if max_items is not None and max_items > min_items:
+ successive_items += (list_item_operator + "?") * (max_items - min_items - 1)
+ else:
+ successive_items += list_item_operator + "*"
+ rule = f'"[" space {first_item} {successive_items} "]" space'
return self._add_rule(rule_name, rule)
else:
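The new minItems/maxItems handling composes the GBNF array rule from a first item (required when `minItems > 0`, optional otherwise), a fixed run of mandatory separators for the rest of the minimum, and then either optional separators up to `maxItems` or a `*` repetition when no maximum is given. A standalone restatement of that string-building, for illustration only (the rule name `item` is arbitrary):

    from typing import Optional

    def array_rule(item_rule_name: str, min_items: int = 0, max_items: Optional[int] = None) -> str:
        # Same construction as in SchemaConverter above, pulled out for illustration.
        list_item_operator = f'("," space {item_rule_name})'
        successive_items = ""
        if min_items > 0:
            first_item = f"({item_rule_name})"
            successive_items = list_item_operator * (min_items - 1)
            min_items -= 1
        else:
            first_item = f"({item_rule_name})?"
        if max_items is not None and max_items > min_items:
            successive_items += (list_item_operator + "?") * (max_items - min_items - 1)
        else:
            successive_items += list_item_operator + "*"
        return f'"[" space {first_item} {successive_items} "]" space'

    print(array_rule("item", min_items=2, max_items=4))
    # "[" space (item) ("," space item)("," space item)?("," space item)? "]" space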
diff --git a/llama_cpp/llava_cpp.py b/llama_cpp/llava_cpp.py
index 8195bd4..4eaa9e5 100644
--- a/llama_cpp/llava_cpp.py
+++ b/llama_cpp/llava_cpp.py
@@ -5,21 +5,15 @@ from ctypes import (
c_bool,
c_char_p,
c_int,
- c_int8,
- c_int32,
c_uint8,
- c_uint32,
- c_size_t,
c_float,
- c_double,
c_void_p,
POINTER,
_Pointer, # type: ignore
Structure,
- Array,
)
import pathlib
-from typing import List, Union
+from typing import List, Union, NewType, Optional
import llama_cpp.llama_cpp as llama_cpp
@@ -67,7 +61,7 @@ def _load_shared_library(lib_base_name: str):
for _lib_path in _lib_paths:
if _lib_path.exists():
try:
- return ctypes.CDLL(str(_lib_path), **cdll_args)
+ return ctypes.CDLL(str(_lib_path), **cdll_args) # type: ignore
except Exception as e:
raise RuntimeError(f"Failed to load shared library '{_lib_path}': {e}")
@@ -88,7 +82,8 @@ _libllava = _load_shared_library(_libllava_base_name)
################################################
# struct clip_ctx;
-clip_ctx_p = c_void_p
+clip_ctx_p = NewType("clip_ctx_p", int)
+clip_ctx_p_ctypes = c_void_p
# struct llava_image_embed {
# float * embed;
@@ -102,43 +97,48 @@ class llava_image_embed(Structure):
# /** sanity check for clip <-> llava embed size match */
# LLAVA_API bool llava_validate_embed_size(const llama_context * ctx_llama, const clip_ctx * ctx_clip);
-def llava_validate_embed_size(ctx_llama: llama_cpp.llama_context_p, ctx_clip: clip_ctx_p) -> bool:
- return _libllava.llava_validate_embed_size(ctx_llama, ctx_clip)
+def llava_validate_embed_size(ctx_llama: llama_cpp.llama_context_p, ctx_clip: clip_ctx_p, /) -> bool:
+ ...
-_libllava.llava_validate_embed_size.argtypes = [llama_cpp.llama_context_p, clip_ctx_p]
-_libllava.llava_validate_embed_size.restype = c_bool
+llava_validate_embed_size = _libllava.llava_validate_embed_size
+llava_validate_embed_size.argtypes = [llama_cpp.llama_context_p_ctypes, clip_ctx_p_ctypes]
+llava_validate_embed_size.restype = c_bool
# /** build an image embed from image file bytes */
# LLAVA_API struct llava_image_embed * llava_image_embed_make_with_bytes(struct clip_ctx * ctx_clip, int n_threads, const unsigned char * image_bytes, int image_bytes_length);
-def llava_image_embed_make_with_bytes(ctx_clip: clip_ctx_p, n_threads: Union[c_int, int], image_bytes: bytes, image_bytes_length: Union[c_int, int]) -> "_Pointer[llava_image_embed]":
- return _libllava.llava_image_embed_make_with_bytes(ctx_clip, n_threads, image_bytes, image_bytes_length)
+def llava_image_embed_make_with_bytes(ctx_clip: clip_ctx_p, n_threads: Union[c_int, int], image_bytes: bytes, image_bytes_length: Union[c_int, int], /) -> "_Pointer[llava_image_embed]":
+ ...
-_libllava.llava_image_embed_make_with_bytes.argtypes = [clip_ctx_p, c_int, POINTER(c_uint8), c_int]
-_libllava.llava_image_embed_make_with_bytes.restype = POINTER(llava_image_embed)
+llava_image_embed_make_with_bytes = _libllava.llava_image_embed_make_with_bytes
+llava_image_embed_make_with_bytes.argtypes = [clip_ctx_p_ctypes, c_int, POINTER(c_uint8), c_int]
+llava_image_embed_make_with_bytes.restype = POINTER(llava_image_embed)
# /** build an image embed from a path to an image filename */
# LLAVA_API struct llava_image_embed * llava_image_embed_make_with_filename(struct clip_ctx * ctx_clip, int n_threads, const char * image_path);
-def llava_image_embed_make_with_filename(ctx_clip: clip_ctx_p, n_threads: Union[c_int, int], image_path: bytes) -> "_Pointer[llava_image_embed]":
- return _libllava.llava_image_embed_make_with_filename(ctx_clip, n_threads, image_path)
+def llava_image_embed_make_with_filename(ctx_clip: clip_ctx_p, n_threads: Union[c_int, int], image_path: bytes, /) -> "_Pointer[llava_image_embed]":
+ ...
-_libllava.llava_image_embed_make_with_filename.argtypes = [clip_ctx_p, c_int, c_char_p]
-_libllava.llava_image_embed_make_with_filename.restype = POINTER(llava_image_embed)
+llava_image_embed_make_with_filename = _libllava.llava_image_embed_make_with_filename
+llava_image_embed_make_with_filename.argtypes = [clip_ctx_p_ctypes, c_int, c_char_p]
+llava_image_embed_make_with_filename.restype = POINTER(llava_image_embed)
# LLAVA_API void llava_image_embed_free(struct llava_image_embed * embed);
# /** free an embedding made with llava_image_embed_make_* */
-def llava_image_embed_free(embed: "_Pointer[llava_image_embed]"):
- return _libllava.llava_image_embed_free(embed)
+def llava_image_embed_free(embed: "_Pointer[llava_image_embed]", /):
+ ...
-_libllava.llava_image_embed_free.argtypes = [POINTER(llava_image_embed)]
-_libllava.llava_image_embed_free.restype = None
+llava_image_embed_free = _libllava.llava_image_embed_free
+llava_image_embed_free.argtypes = [POINTER(llava_image_embed)]
+llava_image_embed_free.restype = None
# /** write the image represented by embed into the llama context with batch size n_batch, starting at context pos n_past. on completion, n_past points to the next position in the context after the image embed. */
# LLAVA_API bool llava_eval_image_embed(struct llama_context * ctx_llama, const struct llava_image_embed * embed, int n_batch, int * n_past);
-def llava_eval_image_embed(ctx_llama: llama_cpp.llama_context_p, embed: "_Pointer[llava_image_embed]", n_batch: Union[c_int, int], n_past: "_Pointer[c_int]") -> bool:
- return _libllava.llava_eval_image_embed(ctx_llama, embed, n_batch, n_past)
+def llava_eval_image_embed(ctx_llama: llama_cpp.llama_context_p, embed: "_Pointer[llava_image_embed]", n_batch: Union[c_int, int], n_past: "_Pointer[c_int]", /) -> bool:
+ ...
-_libllava.llava_eval_image_embed.argtypes = [llama_cpp.llama_context_p, POINTER(llava_image_embed), c_int, POINTER(c_int)]
-_libllava.llava_eval_image_embed.restype = c_bool
+llava_eval_image_embed = _libllava.llava_eval_image_embed
+llava_eval_image_embed.argtypes = [llama_cpp.llama_context_p_ctypes, POINTER(llava_image_embed), c_int, POINTER(c_int)]
+llava_eval_image_embed.restype = c_bool
################################################
@@ -148,16 +148,18 @@ _libllava.llava_eval_image_embed.restype = c_bool
# /** load mmproj model */
# CLIP_API struct clip_ctx * clip_model_load (const char * fname, int verbosity);
-def clip_model_load(fname: bytes, verbosity: Union[c_int, int]) -> clip_ctx_p:
- return _libllava.clip_model_load(fname, verbosity)
+def clip_model_load(fname: bytes, verbosity: Union[c_int, int], /) -> Optional[clip_ctx_p]:
+ ...
-_libllava.clip_model_load.argtypes = [c_char_p, c_int]
-_libllava.clip_model_load.restype = clip_ctx_p
+clip_model_load = _libllava.clip_model_load
+clip_model_load.argtypes = [c_char_p, c_int]
+clip_model_load.restype = clip_ctx_p_ctypes
# /** free mmproj model */
# CLIP_API void clip_free(struct clip_ctx * ctx);
-def clip_free(ctx: clip_ctx_p):
- return _libllava.clip_free(ctx)
+def clip_free(ctx: clip_ctx_p, /):
+ ...
-_libllava.clip_free.argtypes = [clip_ctx_p]
-_libllava.clip_free.restype = None
+clip_free = _libllava.clip_free
+clip_free.argtypes = [clip_ctx_p_ctypes]
+clip_free.restype = None
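In llava_cpp.py every binding now pairs a typed, positional-only stub with an explicit rebinding of the same name to the raw `_libllava` function (argtypes/restype set immediately below it), and `clip_ctx_p` becomes an opaque `NewType` while `clip_ctx_p_ctypes` carries the underlying `c_void_p`. A rough usage sketch of the rebound API; the file paths, thread count, batch size, and llama context here are placeholders, not values from this repository:

    import ctypes
    from llama_cpp import llava_cpp

    ctx_clip = llava_cpp.clip_model_load(b"mmproj.gguf", 1)
    if ctx_clip is None:  # restype is c_void_p, so a NULL return comes back as None
        raise RuntimeError("failed to load mmproj model")

    embed = llava_cpp.llava_image_embed_make_with_filename(ctx_clip, 4, b"image.png")
    n_past = ctypes.c_int(0)
    # With a real llama context `ctx_llama`:
    # llava_cpp.llava_eval_image_embed(ctx_llama, embed, 512, ctypes.byref(n_past))

    llava_cpp.llava_image_embed_free(embed)
    llava_cpp.clip_free(ctx_clip)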
diff --git a/pyproject.toml b/pyproject.toml
index 4130972..2f3d3ce 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -72,4 +72,4 @@ Documentation = "https://llama-cpp-python.readthedocs.io/en/latest/"
Changelog = "https://llama-cpp-python.readthedocs.io/en/latest/changelog/"
[tool.pytest.ini_options]
-addopts = "--ignore=vendor"
+testpaths = "tests"
diff --git a/tests/test_llama.py b/tests/test_llama.py
index dac33b7..5cf421b 100644
--- a/tests/test_llama.py
+++ b/tests/test_llama.py
@@ -54,7 +54,7 @@ def mock_llama(monkeypatch):
output_tokens = llama.tokenize(
output_text.encode("utf-8"), add_bos=True, special=True
)
- logits = (llama_cpp.c_float * (n_vocab * n_ctx))(-100.0)
+ logits = (ctypes.c_float * (n_vocab * n_ctx))(-100.0)
for i in range(n_ctx):
output_idx = i + 1 # logits for first tokens predict second token
if output_idx < len(output_tokens):
@@ -90,9 +90,9 @@ def mock_llama(monkeypatch):
assert n > 0, "mock_llama_decode not called"
assert last_n_tokens > 0, "mock_llama_decode not called"
# Return view of logits for last_n_tokens
- return (llama_cpp.c_float * (last_n_tokens * n_vocab)).from_address(
+ return (ctypes.c_float * (last_n_tokens * n_vocab)).from_address(
ctypes.addressof(logits)
- + (n - last_n_tokens) * n_vocab * ctypes.sizeof(llama_cpp.c_float)
+ + (n - last_n_tokens) * n_vocab * ctypes.sizeof(ctypes.c_float)
)
monkeypatch.setattr("llama_cpp.llama_cpp.llama_decode", mock_decode)
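The test changes only swap `llama_cpp.c_float` for `ctypes.c_float`; the mock still returns a window over the tail of its logits buffer via `from_address`. A minimal, self-contained illustration of that address arithmetic, with arbitrary sizes:

    import ctypes

    n_vocab, n_ctx = 4, 3
    # Full buffer of n_ctx * n_vocab logits, as in the mock above.
    logits = (ctypes.c_float * (n_vocab * n_ctx))(*range(n_vocab * n_ctx))

    n, last_n_tokens = 3, 2
    # View sharing memory with `logits`, covering only the last `last_n_tokens` rows.
    view = (ctypes.c_float * (last_n_tokens * n_vocab)).from_address(
        ctypes.addressof(logits)
        + (n - last_n_tokens) * n_vocab * ctypes.sizeof(ctypes.c_float)
    )
    assert list(view) == list(logits)[n_vocab:]  # rows 1..2 of the original buffer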
diff --git a/vendor/llama.cpp b/vendor/llama.cpp
index f53119c..15499eb 160000
--- a/vendor/llama.cpp
+++ b/vendor/llama.cpp
@@ -1 +1 @@
-Subproject commit f53119cec4f073b6d214195ecbe1fad3abdf2b34
+Subproject commit 15499eb94227401bdc8875da6eb85c15d37068f7