Merge branch 'main' of github.com:abetlen/llama_cpp_python into better-server-params-and-fields

Andrei Betlen 2023-05-07 01:54:00 -04:00
commit d8fddcce73
13 changed files with 341 additions and 142 deletions

80
.github/ISSUE_TEMPLATE/bug_report.md vendored Normal file

@ -0,0 +1,80 @@
---
name: Bug report
about: Create a report to help us improve
title: ''
labels: ''
assignees: ''
---
# Prerequisites
Please answer the following questions for yourself before submitting an issue.
- [ ] I am running the latest code. Development is very rapid so there are no tagged versions as of now.
- [ ] I carefully followed the [README.md](https://github.com/abetlen/llama-cpp-python/blob/main/README.md).
- [ ] I [searched using keywords relevant to my issue](https://docs.github.com/en/issues/tracking-your-work-with-issues/filtering-and-searching-issues-and-pull-requests) to make sure that I am creating a new issue that is not already open (or closed).
- [ ] I reviewed the [Discussions](https://github.com/abetlen/llama-cpp-python/discussions), and have a new bug or useful enhancement to share.
# Expected Behavior
Please provide a detailed written description of what you were trying to do, and what you expected `llama-cpp-python` to do.
# Current Behavior
Please provide a detailed written description of what `llama-cpp-python` did, instead.
# Environment and Context
Please provide detailed information about your computer setup. This is important in case the issue is not reproducible except for under certain specific conditions.
* Physical (or virtual) hardware you are using, e.g. for Linux:
`$ lscpu`
* Operating System, e.g. for Linux:
`$ uname -a`
* SDK version, e.g. for Linux:
```
$ python3 --version
$ make --version
$ g++ --version
```
# Failure Information (for bugs)
Please help provide information about the failure if this is a bug. If it is not a bug, please remove the rest of this template.
# Steps to Reproduce
Please provide detailed steps for reproducing the issue. We are not sitting in front of your screen, so the more detail the better.
1. step 1
2. step 2
3. step 3
4. etc.
**Note: Many issues seem to be regarding performance issues / differences with `llama.cpp`. In these cases we need to confirm that you're comparing against the version of `llama.cpp` that was built with your python package, and which parameters you're passing to the context.**
# Failure Logs
Please include any relevant log snippets or files. If it works under one configuration but not under another, please provide logs for both configurations and their corresponding outputs so it is easy to see where behavior changes.
Also, please try to **avoid using screenshots** if at all possible. Instead, copy/paste the console output and use [Github's markdown](https://docs.github.com/en/get-started/writing-on-github/getting-started-with-writing-and-formatting-on-github/basic-writing-and-formatting-syntax) to cleanly format your logs for easy readability.
Example environment info:
```
llama-cpp-python$ git log | head -1
commit 47b0aa6e957b93dbe2c29d53af16fbae2dd628f2
llama-cpp-python$ python3 --version
Python 3.10.10
llama-cpp-python$ pip list | egrep "uvicorn|fastapi|sse-starlette"
fastapi 0.95.0
sse-starlette 1.3.3
uvicorn 0.21.1
```


@ -0,0 +1,20 @@
---
name: Feature request
about: Suggest an idea for this project
title: ''
labels: ''
assignees: ''
---
**Is your feature request related to a problem? Please describe.**
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
**Describe the solution you'd like**
A clear and concise description of what you want to happen.
**Describe alternatives you've considered**
A clear and concise description of any alternative solutions or features you've considered.
**Additional context**
Add any other context or screenshots about the feature request here.

11
.github/dependabot.yml vendored Normal file

@ -0,0 +1,11 @@
# To get started with Dependabot version updates, you'll need to specify which
# package ecosystems to update and where the package manifests are located.
# Please see the documentation for all configuration options:
# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates
version: 2
updates:
- package-ecosystem: "pip" # See documentation for possible values
directory: "/" # Location of package manifests
schedule:
interval: "weekly"


@ -36,4 +36,4 @@ jobs:
push: true # push to registry
pull: true # always fetch the latest base images
platforms: linux/amd64,linux/arm64 # build for both amd64 and arm64
tags: ghcr.io/abetlen/llama-cpp-python:latest
tags: ghcr.io/abetlen/llama-cpp-python:latest


@ -1,4 +1,4 @@
FROM python:3-bullseye
FROM python:3-slim-bullseye
# We need to set the host to 0.0.0.0 to allow outside access
ENV HOST 0.0.0.0
@ -6,10 +6,10 @@ ENV HOST 0.0.0.0
COPY . .
# Install the package
RUN apt update && apt install -y libopenblas-dev
RUN apt update && apt install -y libopenblas-dev ninja-build build-essential
RUN python -m pip install --upgrade pip pytest cmake scikit-build setuptools fastapi uvicorn sse-starlette
RUN LLAMA_OPENBLAS=1 python3 setup.py develop
# Run the server
CMD python3 -m llama_cpp.server
CMD python3 -m llama_cpp.server


@ -31,6 +31,10 @@ You can force the use of `cmake` on Linux / MacOS setting the `FORCE_CMAKE=1` en
## High-level API
The high-level API provides a simple managed interface through the `Llama` class.
Below is a short example demonstrating how to use the high-level API to generate text:
```python
>>> from llama_cpp import Llama
>>> llm = Llama(model_path="./models/7B/ggml-model.bin")
@ -64,12 +68,20 @@ This allows you to use llama.cpp compatible models with any OpenAI compatible cl
To install the server package and get started:
Linux/MacOS
```bash
pip install llama-cpp-python[server]
export MODEL=./models/7B/ggml-model.bin
python3 -m llama_cpp.server
```
Windows
```cmd
pip install llama-cpp-python[server]
SET MODEL=..\models\7B\ggml-model.bin
python3 -m llama_cpp.server
```
Navigate to [http://localhost:8000/docs](http://localhost:8000/docs) to see the OpenAPI documentation.
## Docker image
@ -82,8 +94,25 @@ docker run --rm -it -p8000:8000 -v /path/to/models:/models -eMODEL=/models/ggml-
## Low-level API
The low-level API is a direct `ctypes` binding to the C API provided by `llama.cpp`.
The entire API can be found in [llama_cpp/llama_cpp.py](https://github.com/abetlen/llama-cpp-python/blob/master/llama_cpp/llama_cpp.py) and should mirror [llama.h](https://github.com/ggerganov/llama.cpp/blob/master/llama.h).
The low-level API is a direct [`ctypes`](https://docs.python.org/3/library/ctypes.html) binding to the C API provided by `llama.cpp`.
The entire low-level API can be found in [llama_cpp/llama_cpp.py](https://github.com/abetlen/llama-cpp-python/blob/master/llama_cpp/llama_cpp.py) and directly mirrors the C API in [llama.h](https://github.com/ggerganov/llama.cpp/blob/master/llama.h).
Below is a short example demonstrating how to use the low-level API to tokenize a prompt:
```python
>>> import llama_cpp
>>> import ctypes
>>> params = llama_cpp.llama_context_default_params()
# use bytes for char * params
>>> ctx = llama_cpp.llama_init_from_file(b"./models/7b/ggml-model.bin", params)
>>> max_tokens = params.n_ctx
# use ctypes arrays for array params
>>> tokens = (llama_cpp.llama_token * int(max_tokens))()
>>> n_tokens = llama_cpp.llama_tokenize(ctx, b"Q: Name the planets in the solar system? A: ", tokens, max_tokens, add_bos=llama_cpp.c_bool(True))
>>> llama_cpp.llama_free(ctx)
```
Check out the [examples folder](examples/low_level_api) for more examples of using the low-level API.
# Documentation


@ -33,12 +33,10 @@ class LlamaCache:
return k
return None
def __getitem__(
self, key: Sequence[llama_cpp.llama_token]
) -> Optional["LlamaState"]:
def __getitem__(self, key: Sequence[llama_cpp.llama_token]) -> "LlamaState":
_key = self._find_key(tuple(key))
if _key is None:
return None
raise KeyError(f"Key not found: {key}")
return self.cache_state[_key]
def __contains__(self, key: Sequence[llama_cpp.llama_token]) -> bool:
@ -53,8 +51,8 @@ class LlamaState:
def __init__(
self,
eval_tokens: Deque[llama_cpp.llama_token],
eval_logits: Deque[List[llama_cpp.c_float]],
llama_state,
eval_logits: Deque[List[float]],
llama_state, # type: llama_cpp.Array[llama_cpp.c_uint8]
llama_state_size: llama_cpp.c_size_t,
):
self.eval_tokens = eval_tokens
@ -129,7 +127,7 @@ class Llama:
self.last_n_tokens_size = last_n_tokens_size
self.n_batch = min(n_ctx, n_batch)
self.eval_tokens: Deque[llama_cpp.llama_token] = deque(maxlen=n_ctx)
self.eval_logits: Deque[List[llama_cpp.c_float]] = deque(
self.eval_logits: Deque[List[float]] = deque(
maxlen=n_ctx if logits_all else 1
)
@ -247,7 +245,7 @@ class Llama:
n_vocab = llama_cpp.llama_n_vocab(self.ctx)
cols = int(n_vocab)
logits_view = llama_cpp.llama_get_logits(self.ctx)
logits: List[List[llama_cpp.c_float]] = [
logits: List[List[float]] = [
[logits_view[i * cols + j] for j in range(cols)] for i in range(rows)
]
self.eval_logits.extend(logits)
@ -289,7 +287,7 @@ class Llama:
candidates=llama_cpp.ctypes.pointer(candidates),
penalty=repeat_penalty,
)
if temp == 0.0:
if float(temp.value) == 0.0:
return llama_cpp.llama_sample_token_greedy(
ctx=self.ctx,
candidates=llama_cpp.ctypes.pointer(candidates),
@ -299,21 +297,25 @@ class Llama:
ctx=self.ctx,
candidates=llama_cpp.ctypes.pointer(candidates),
k=top_k,
min_keep=llama_cpp.c_size_t(1),
)
llama_cpp.llama_sample_tail_free(
ctx=self.ctx,
candidates=llama_cpp.ctypes.pointer(candidates),
z=llama_cpp.c_float(1.0),
min_keep=llama_cpp.c_size_t(1),
)
llama_cpp.llama_sample_typical(
ctx=self.ctx,
candidates=llama_cpp.ctypes.pointer(candidates),
p=llama_cpp.c_float(1.0),
min_keep=llama_cpp.c_size_t(1),
)
llama_cpp.llama_sample_top_p(
ctx=self.ctx,
candidates=llama_cpp.ctypes.pointer(candidates),
p=top_p,
min_keep=llama_cpp.c_size_t(1),
)
llama_cpp.llama_sample_temperature(
ctx=self.ctx,
@ -390,18 +392,28 @@ class Llama:
"""
assert self.ctx is not None
if (
reset
and len(self.eval_tokens) > 0
and tuple(self.eval_tokens) == tuple(tokens[: len(self.eval_tokens)])
):
if self.verbose:
print("Llama.generate: cache hit", file=sys.stderr)
reset = False
tokens = tokens[len(self.eval_tokens) :]
if reset and len(self.eval_tokens) > 0:
longest_prefix = 0
for a, b in zip(self.eval_tokens, tokens[:-1]):
if a == b:
longest_prefix += 1
else:
break
if longest_prefix > 0:
if self.verbose:
print("Llama.generate: prefix-match hit", file=sys.stderr)
reset = False
tokens = tokens[longest_prefix:]
for _ in range(len(self.eval_tokens) - longest_prefix):
self.eval_tokens.pop()
try:
self.eval_logits.pop()
except IndexError:
pass
if reset:
self.reset()
while True:
self.eval(tokens)
token = self.sample(
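The prefix-match logic in the hunk above can be sketched in isolation. This is a simplified illustration rather than the method itself: it leaves out the `reset` flag and the parallel `eval_logits` deque, and the helper names `longest_common_prefix` and `reuse_prefix` are hypothetical. Comparing against `tokens[:-1]` mirrors the diff and guarantees that at least one token is left to evaluate.

```python
from collections import deque
from typing import Deque, List, Sequence, Tuple


def longest_common_prefix(cached: Sequence[int], tokens: Sequence[int]) -> int:
    """Count leading tokens shared by the cache and all but the last new token."""
    n = 0
    for a, b in zip(cached, tokens[:-1]):
        if a != b:
            break
        n += 1
    return n


def reuse_prefix(cached: Deque[int], tokens: List[int]) -> Tuple[List[int], bool]:
    """Trim `cached` to the shared prefix; return (tokens still to evaluate, hit?)."""
    n = longest_common_prefix(cached, tokens)
    if n == 0:
        return tokens, False
    for _ in range(len(cached) - n):
        cached.pop()  # drop cached tokens past the shared prefix
    return tokens[n:], True
```

With a cache of `[1, 2, 3, 4]` and a new prompt `[1, 2, 9, 10]`, the sketch keeps `[1, 2]` in the cache and leaves `[9, 10]` to be evaluated.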
@ -639,7 +651,10 @@ class Llama:
self.detokenize([token]).decode("utf-8", errors="ignore")
for token in all_tokens
]
all_logprobs = [Llama._logits_to_logprobs(row) for row in self.eval_logits]
all_logprobs = [
Llama.logits_to_logprobs(list(map(float, row)))
for row in self.eval_logits
]
for token, token_str, logprobs_token in zip(
all_tokens, all_token_strs, all_logprobs
):
@ -958,7 +973,10 @@ class Llama:
llama_state_compact = (llama_cpp.c_uint8 * int(n_bytes))()
llama_cpp.ctypes.memmove(llama_state_compact, llama_state, int(n_bytes))
if self.verbose:
print(f"Llama.save_state: saving {n_bytes} bytes of llama state", file=sys.stderr)
print(
f"Llama.save_state: saving {n_bytes} bytes of llama state",
file=sys.stderr,
)
return LlamaState(
eval_tokens=self.eval_tokens.copy(),
eval_logits=self.eval_logits.copy(),
@ -985,7 +1003,7 @@ class Llama:
return llama_cpp.llama_token_bos()
@staticmethod
def logits_to_logprobs(logits: List[llama_cpp.c_float]) -> List[llama_cpp.c_float]:
def logits_to_logprobs(logits: List[float]) -> List[float]:
exps = [math.exp(float(x)) for x in logits]
sum_exps = sum(exps)
return [llama_cpp.c_float(math.log(x / sum_exps)) for x in exps]
return [math.log(x / sum_exps) for x in exps]
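The `logits_to_logprobs` change above now returns plain Python floats computed as `log(exp(x) / sum(exp))`. The sketch below reproduces that direct form next to a max-shifted variant. The variant is not part of this commit; `stable_logprobs` is a hypothetical name, shown only because the direct form can overflow in `math.exp` for very large logits.

```python
import math
from typing import List


def logits_to_logprobs(logits: List[float]) -> List[float]:
    """Direct form, as in the diff: log(exp(x) / sum(exp))."""
    exps = [math.exp(float(x)) for x in logits]
    sum_exps = sum(exps)
    return [math.log(x / sum_exps) for x in exps]


def stable_logprobs(logits: List[float]) -> List[float]:
    """Hypothetical variant: x - max - log(sum(exp(x - max)))."""
    m = max(logits)
    log_sum = math.log(sum(math.exp(x - m) for x in logits))
    return [x - m - log_sum for x in logits]
```

Both functions agree on moderate inputs; the shifted variant only changes behaviour where the direct form would overflow.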


@ -8,6 +8,7 @@ from ctypes import (
c_void_p,
c_bool,
POINTER,
_Pointer, # type: ignore
Structure,
Array,
c_uint8,
@ -17,7 +18,7 @@ import pathlib
# Load the library
def _load_shared_library(lib_base_name):
def _load_shared_library(lib_base_name: str):
# Determine the file extension based on the platform
if sys.platform.startswith("linux"):
lib_ext = ".so"
@ -67,11 +68,11 @@ _lib_base_name = "llama"
_lib = _load_shared_library(_lib_base_name)
# C types
LLAMA_FILE_VERSION = ctypes.c_int(1)
LLAMA_FILE_VERSION = c_int(1)
LLAMA_FILE_MAGIC = b"ggjt"
LLAMA_FILE_MAGIC_UNVERSIONED = b"ggml"
LLAMA_SESSION_MAGIC = b"ggsn"
LLAMA_SESSION_VERSION = ctypes.c_int(1)
LLAMA_SESSION_VERSION = c_int(1)
llama_context_p = c_void_p
@ -127,18 +128,23 @@ class llama_context_params(Structure):
llama_context_params_p = POINTER(llama_context_params)
LLAMA_FTYPE_ALL_F32 = ctypes.c_int(0)
LLAMA_FTYPE_MOSTLY_F16 = ctypes.c_int(1) # except 1d tensors
LLAMA_FTYPE_MOSTLY_Q4_0 = ctypes.c_int(2) # except 1d tensors
LLAMA_FTYPE_MOSTLY_Q4_1 = ctypes.c_int(3) # except 1d tensors
LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = ctypes.c_int(
LLAMA_FTYPE_ALL_F32 = c_int(0)
LLAMA_FTYPE_MOSTLY_F16 = c_int(1) # except 1d tensors
LLAMA_FTYPE_MOSTLY_Q4_0 = c_int(2) # except 1d tensors
LLAMA_FTYPE_MOSTLY_Q4_1 = c_int(3) # except 1d tensors
LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = c_int(
4
) # tok_embeddings.weight and output.weight are F16
LLAMA_FTYPE_MOSTLY_Q4_2 = ctypes.c_int(5) # except 1d tensors
# LLAMA_FTYPE_MOSTYL_Q4_3 = ctypes.c_int(6) # except 1d tensors
LLAMA_FTYPE_MOSTLY_Q8_0 = ctypes.c_int(7) # except 1d tensors
LLAMA_FTYPE_MOSTLY_Q5_0 = ctypes.c_int(8) # except 1d tensors
LLAMA_FTYPE_MOSTLY_Q5_1 = ctypes.c_int(9) # except 1d tensors
LLAMA_FTYPE_MOSTLY_Q4_2 = c_int(5) # except 1d tensors
# LLAMA_FTYPE_MOSTYL_Q4_3 = c_int(6) # except 1d tensors
LLAMA_FTYPE_MOSTLY_Q8_0 = c_int(7) # except 1d tensors
LLAMA_FTYPE_MOSTLY_Q5_0 = c_int(8) # except 1d tensors
LLAMA_FTYPE_MOSTLY_Q5_1 = c_int(9) # except 1d tensors
# Misc
c_float_p = POINTER(c_float)
c_uint8_p = POINTER(c_uint8)
c_size_t_p = POINTER(c_size_t)
# Functions
@ -210,8 +216,8 @@ _lib.llama_model_quantize.restype = c_int
# Returns 0 on success
def llama_apply_lora_from_file(
ctx: llama_context_p,
path_lora: ctypes.c_char_p,
path_base_model: ctypes.c_char_p,
path_lora: c_char_p,
path_base_model: c_char_p,
n_threads: c_int,
) -> c_int:
return _lib.llama_apply_lora_from_file(ctx, path_lora, path_base_model, n_threads)
@ -252,21 +258,25 @@ _lib.llama_get_state_size.restype = c_size_t
# Copies the state to the specified destination address.
# Destination needs to have allocated enough memory.
# Returns the number of bytes copied
def llama_copy_state_data(ctx: llama_context_p, dest) -> c_size_t:
def llama_copy_state_data(
ctx: llama_context_p, dest # type: Array[c_uint8]
) -> c_size_t:
return _lib.llama_copy_state_data(ctx, dest)
_lib.llama_copy_state_data.argtypes = [llama_context_p, POINTER(c_uint8)]
_lib.llama_copy_state_data.argtypes = [llama_context_p, c_uint8_p]
_lib.llama_copy_state_data.restype = c_size_t
# Set the state reading from the specified address
# Returns the number of bytes read
def llama_set_state_data(ctx: llama_context_p, src) -> c_size_t:
def llama_set_state_data(
ctx: llama_context_p, src # type: Array[c_uint8]
) -> c_size_t:
return _lib.llama_set_state_data(ctx, src)
_lib.llama_set_state_data.argtypes = [llama_context_p, POINTER(c_uint8)]
_lib.llama_set_state_data.argtypes = [llama_context_p, c_uint8_p]
_lib.llama_set_state_data.restype = c_size_t
@ -274,9 +284,9 @@ _lib.llama_set_state_data.restype = c_size_t
def llama_load_session_file(
ctx: llama_context_p,
path_session: bytes,
tokens_out,
tokens_out, # type: Array[llama_token]
n_token_capacity: c_size_t,
n_token_count_out,
n_token_count_out, # type: _Pointer[c_size_t]
) -> c_size_t:
return _lib.llama_load_session_file(
ctx, path_session, tokens_out, n_token_capacity, n_token_count_out
@ -288,13 +298,16 @@ _lib.llama_load_session_file.argtypes = [
c_char_p,
llama_token_p,
c_size_t,
POINTER(c_size_t),
c_size_t_p,
]
_lib.llama_load_session_file.restype = c_size_t
def llama_save_session_file(
ctx: llama_context_p, path_session: bytes, tokens, n_token_count: c_size_t
ctx: llama_context_p,
path_session: bytes,
tokens, # type: Array[llama_token]
n_token_count: c_size_t,
) -> c_size_t:
return _lib.llama_save_session_file(ctx, path_session, tokens, n_token_count)
@ -374,22 +387,22 @@ _lib.llama_n_embd.restype = c_int
# Can be mutated in order to change the probabilities of the next token
# Rows: n_tokens
# Cols: n_vocab
def llama_get_logits(ctx: llama_context_p):
def llama_get_logits(ctx: llama_context_p): # type: (...) -> Array[float] # type: ignore
return _lib.llama_get_logits(ctx)
_lib.llama_get_logits.argtypes = [llama_context_p]
_lib.llama_get_logits.restype = POINTER(c_float)
_lib.llama_get_logits.restype = c_float_p
# Get the embeddings for the input
# shape: [n_embd] (1-dimensional)
def llama_get_embeddings(ctx: llama_context_p):
def llama_get_embeddings(ctx: llama_context_p): # type: (...) -> Array[float] # type: ignore
return _lib.llama_get_embeddings(ctx)
_lib.llama_get_embeddings.argtypes = [llama_context_p]
_lib.llama_get_embeddings.restype = POINTER(c_float)
_lib.llama_get_embeddings.restype = c_float_p
# Token Id -> String. Uses the vocabulary in the provided context
@ -433,8 +446,8 @@ _lib.llama_token_nl.restype = llama_token
# @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
def llama_sample_repetition_penalty(
ctx: llama_context_p,
candidates,
last_tokens_data,
candidates, # type: _Pointer[llama_token_data_array]
last_tokens_data, # type: Array[llama_token]
last_tokens_size: c_int,
penalty: c_float,
):
@ -456,8 +469,8 @@ _lib.llama_sample_repetition_penalty.restype = None
# @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
def llama_sample_frequency_and_presence_penalties(
ctx: llama_context_p,
candidates,
last_tokens_data,
candidates, # type: _Pointer[llama_token_data_array]
last_tokens_data, # type: Array[llama_token]
last_tokens_size: c_int,
alpha_frequency: c_float,
alpha_presence: c_float,
@ -484,7 +497,9 @@ _lib.llama_sample_frequency_and_presence_penalties.restype = None
# @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
def llama_sample_softmax(ctx: llama_context_p, candidates):
def llama_sample_softmax(
ctx: llama_context_p, candidates # type: _Pointer[llama_token_data]
):
return _lib.llama_sample_softmax(ctx, candidates)
@ -497,7 +512,10 @@ _lib.llama_sample_softmax.restype = None
# @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
def llama_sample_top_k(
ctx: llama_context_p, candidates, k: c_int, min_keep: c_size_t = c_size_t(1)
ctx: llama_context_p,
candidates, # type: _Pointer[llama_token_data_array]
k: c_int,
min_keep: c_size_t,
):
return _lib.llama_sample_top_k(ctx, candidates, k, min_keep)
@ -513,7 +531,10 @@ _lib.llama_sample_top_k.restype = None
# @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
def llama_sample_top_p(
ctx: llama_context_p, candidates, p: c_float, min_keep: c_size_t = c_size_t(1)
ctx: llama_context_p,
candidates, # type: _Pointer[llama_token_data_array]
p: c_float,
min_keep: c_size_t,
):
return _lib.llama_sample_top_p(ctx, candidates, p, min_keep)
@ -529,7 +550,10 @@ _lib.llama_sample_top_p.restype = None
# @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
def llama_sample_tail_free(
ctx: llama_context_p, candidates, z: c_float, min_keep: c_size_t = c_size_t(1)
ctx: llama_context_p,
candidates, # type: _Pointer[llama_token_data_array]
z: c_float,
min_keep: c_size_t,
):
return _lib.llama_sample_tail_free(ctx, candidates, z, min_keep)
@ -545,7 +569,10 @@ _lib.llama_sample_tail_free.restype = None
# @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
def llama_sample_typical(
ctx: llama_context_p, candidates, p: c_float, min_keep: c_size_t = c_size_t(1)
ctx: llama_context_p,
candidates, # type: _Pointer[llama_token_data_array]
p: c_float,
min_keep: c_size_t,
):
return _lib.llama_sample_typical(ctx, candidates, p, min_keep)
@ -559,7 +586,11 @@ _lib.llama_sample_typical.argtypes = [
_lib.llama_sample_typical.restype = None
def llama_sample_temperature(ctx: llama_context_p, candidates, temp: c_float):
def llama_sample_temperature(
ctx: llama_context_p,
candidates, # type: _Pointer[llama_token_data_array]
temp: c_float,
):
return _lib.llama_sample_temperature(ctx, candidates, temp)
@ -578,7 +609,12 @@ _lib.llama_sample_temperature.restype = None
# @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.
# @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
def llama_sample_token_mirostat(
ctx: llama_context_p, candidates, tau: c_float, eta: c_float, m: c_int, mu
ctx: llama_context_p,
candidates, # type: _Pointer[llama_token_data_array]
tau: c_float,
eta: c_float,
m: c_int,
mu, # type: _Pointer[c_float]
) -> llama_token:
return _lib.llama_sample_token_mirostat(ctx, candidates, tau, eta, m, mu)
@ -589,7 +625,7 @@ _lib.llama_sample_token_mirostat.argtypes = [
c_float,
c_float,
c_int,
POINTER(c_float),
c_float_p,
]
_lib.llama_sample_token_mirostat.restype = llama_token
@ -600,7 +636,11 @@ _lib.llama_sample_token_mirostat.restype = llama_token
# @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
# @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
def llama_sample_token_mirostat_v2(
ctx: llama_context_p, candidates, tau: c_float, eta: c_float, mu
ctx: llama_context_p,
candidates, # type: _Pointer[llama_token_data_array]
tau: c_float,
eta: c_float,
mu, # type: _Pointer[c_float]
) -> llama_token:
return _lib.llama_sample_token_mirostat_v2(ctx, candidates, tau, eta, mu)
@ -610,13 +650,16 @@ _lib.llama_sample_token_mirostat_v2.argtypes = [
llama_token_data_array_p,
c_float,
c_float,
POINTER(c_float),
c_float_p,
]
_lib.llama_sample_token_mirostat_v2.restype = llama_token
# @details Selects the token with the highest probability.
def llama_sample_token_greedy(ctx: llama_context_p, candidates) -> llama_token:
def llama_sample_token_greedy(
ctx: llama_context_p,
candidates, # type: _Pointer[llama_token_data_array]
) -> llama_token:
return _lib.llama_sample_token_greedy(ctx, candidates)
@ -628,7 +671,10 @@ _lib.llama_sample_token_greedy.restype = llama_token
# @details Randomly selects a token from the candidates based on their probabilities.
def llama_sample_token(ctx: llama_context_p, candidates) -> llama_token:
def llama_sample_token(
ctx: llama_context_p,
candidates, # type: _Pointer[llama_token_data_array]
) -> llama_token:
return _lib.llama_sample_token(ctx, candidates)
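The type comments added in this file distinguish `Array[...]` arguments from `_Pointer[...]` arguments. A minimal sketch of what those two kinds of value look like on the Python side, using only the standard `ctypes` module and no shared library; the helper `make_byte_buffer` is hypothetical and not part of the package.

```python
import ctypes
from ctypes import POINTER, c_float, c_uint8

# Pointer aliases in the style the diff introduces.
c_float_p = POINTER(c_float)
c_uint8_p = POINTER(c_uint8)


def make_byte_buffer(data: bytes):
    """Hypothetical helper: copy `data` into a fresh ctypes uint8 array."""
    buf = (c_uint8 * len(data))()
    ctypes.memmove(buf, data, len(data))
    return buf


# An `Array[c_uint8]` value, viewed through a uint8 pointer.
buf = make_byte_buffer(b"\x01\x02\x03")
as_ptr = ctypes.cast(buf, c_uint8_p)

# A `_Pointer[c_float]` value, as used for an in/out parameter such as `mu`.
mu = c_float(10.0)
mu_p = ctypes.pointer(mu)
mu_p.contents.value = 5.0  # writes through the pointer, as the C side would
```

After the last line, `mu.value` reflects the write made through `mu_p`, which is how an output parameter reports its result back to Python.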


@ -22,12 +22,26 @@ Then visit http://localhost:8000/docs to see the interactive API docs.
"""
import os
import argparse
import uvicorn
from llama_cpp.server.app import create_app
from llama_cpp.server.app import create_app, Settings
if __name__ == "__main__":
app = create_app()
parser = argparse.ArgumentParser()
for name, field in Settings.__fields__.items():
parser.add_argument(
f"--{name}",
dest=name,
type=field.type_,
default=field.default,
help=field.field_info.description,
)
args = parser.parse_args()
settings = Settings(**vars(args))
app = create_app(settings=settings)
uvicorn.run(
app, host=os.getenv("HOST", "localhost"), port=int(os.getenv("PORT", 8000))
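The hunk above builds one command-line flag per settings field. A standalone sketch of the same idea follows, using stdlib `dataclasses` in place of the pydantic model so that it runs without extra dependencies. The `Settings` fields, `parse_bool`, `build_parser`, and `settings_from_args` are all hypothetical names for illustration. One detail worth noting: `argparse` with `type=bool` treats any non-empty string as `True`, so the sketch parses booleans explicitly.

```python
import argparse
from dataclasses import dataclass, field, fields


@dataclass
class Settings:
    """Hypothetical stand-in for the server's settings model."""

    model: str = field(
        default="./models/7B/ggml-model.bin",
        metadata={"help": "Path to the model file."},
    )
    n_ctx: int = field(default=2048, metadata={"help": "Context size in tokens."})
    use_mlock: bool = field(default=False, metadata={"help": "Lock the model in RAM."})


def parse_bool(text: str) -> bool:
    """Parse a boolean flag value explicitly instead of relying on bool(text)."""
    lowered = text.strip().lower()
    if lowered in {"1", "true", "yes", "on"}:
        return True
    if lowered in {"0", "false", "no", "off"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {text!r}")


def build_parser() -> argparse.ArgumentParser:
    """Add one --<name> option per settings field, with its default and help text."""
    parser = argparse.ArgumentParser()
    for f in fields(Settings):
        arg_type = parse_bool if f.type is bool else f.type
        parser.add_argument(
            f"--{f.name}",
            dest=f.name,
            type=arg_type,
            default=f.default,
            help=f.metadata.get("help"),
        )
    return parser


def settings_from_args(argv) -> Settings:
    """Parse `argv` and build a Settings instance from the result."""
    return Settings(**vars(build_parser().parse_args(argv)))
```

For example, `settings_from_args(["--n_ctx", "512", "--use_mlock", "false"])` yields a `Settings` with `n_ctx` set to 512, `use_mlock` left `False`, and the default model path.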

103
poetry.lock generated

@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.4.2 and should not be changed by hand.
# This file is automatically @generated by Poetry and should not be changed by hand.
[[package]]
name = "anyio"
@ -21,58 +21,39 @@ doc = ["packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphinx-rtd-theme"]
test = ["contextlib2", "coverage[toml] (>=4.5)", "hypothesis (>=4.0)", "mock (>=4)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "uvloop (<0.15)", "uvloop (>=0.15)"]
trio = ["trio (>=0.16,<0.22)"]
[[package]]
name = "attrs"
version = "22.2.0"
description = "Classes Without Boilerplate"
category = "dev"
optional = false
python-versions = ">=3.6"
files = [
{file = "attrs-22.2.0-py3-none-any.whl", hash = "sha256:29e95c7f6778868dbd49170f98f8818f78f3dc5e0e37c0b1f474e3561b240836"},
{file = "attrs-22.2.0.tar.gz", hash = "sha256:c9227bfc2f01993c03f68db37d1d15c9690188323c067c641f1a35ca58185f99"},
]
[package.extras]
cov = ["attrs[tests]", "coverage-enable-subprocess", "coverage[toml] (>=5.3)"]
dev = ["attrs[docs,tests]"]
docs = ["furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib-towncrier", "towncrier", "zope.interface"]
tests = ["attrs[tests-no-zope]", "zope.interface"]
tests-no-zope = ["cloudpickle", "cloudpickle", "hypothesis", "hypothesis", "mypy (>=0.971,<0.990)", "mypy (>=0.971,<0.990)", "pympler", "pympler", "pytest (>=4.3.0)", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-mypy-plugins", "pytest-xdist[psutil]", "pytest-xdist[psutil]"]
[[package]]
name = "black"
version = "23.1.0"
version = "23.3.0"
description = "The uncompromising code formatter."
category = "dev"
optional = false
python-versions = ">=3.7"
files = [
{file = "black-23.1.0-cp310-cp310-macosx_10_16_arm64.whl", hash = "sha256:b6a92a41ee34b883b359998f0c8e6eb8e99803aa8bf3123bf2b2e6fec505a221"},
{file = "black-23.1.0-cp310-cp310-macosx_10_16_universal2.whl", hash = "sha256:57c18c5165c1dbe291d5306e53fb3988122890e57bd9b3dcb75f967f13411a26"},
{file = "black-23.1.0-cp310-cp310-macosx_10_16_x86_64.whl", hash = "sha256:9880d7d419bb7e709b37e28deb5e68a49227713b623c72b2b931028ea65f619b"},
{file = "black-23.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e6663f91b6feca5d06f2ccd49a10f254f9298cc1f7f49c46e498a0771b507104"},
{file = "black-23.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:9afd3f493666a0cd8f8df9a0200c6359ac53940cbde049dcb1a7eb6ee2dd7074"},
{file = "black-23.1.0-cp311-cp311-macosx_10_16_arm64.whl", hash = "sha256:bfffba28dc52a58f04492181392ee380e95262af14ee01d4bc7bb1b1c6ca8d27"},
{file = "black-23.1.0-cp311-cp311-macosx_10_16_universal2.whl", hash = "sha256:c1c476bc7b7d021321e7d93dc2cbd78ce103b84d5a4cf97ed535fbc0d6660648"},
{file = "black-23.1.0-cp311-cp311-macosx_10_16_x86_64.whl", hash = "sha256:382998821f58e5c8238d3166c492139573325287820963d2f7de4d518bd76958"},
{file = "black-23.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2bf649fda611c8550ca9d7592b69f0637218c2369b7744694c5e4902873b2f3a"},
{file = "black-23.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:121ca7f10b4a01fd99951234abdbd97728e1240be89fde18480ffac16503d481"},
{file = "black-23.1.0-cp37-cp37m-macosx_10_16_x86_64.whl", hash = "sha256:a8471939da5e824b891b25751955be52ee7f8a30a916d570a5ba8e0f2eb2ecad"},
{file = "black-23.1.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8178318cb74f98bc571eef19068f6ab5613b3e59d4f47771582f04e175570ed8"},
{file = "black-23.1.0-cp37-cp37m-win_amd64.whl", hash = "sha256:a436e7881d33acaf2536c46a454bb964a50eff59b21b51c6ccf5a40601fbef24"},
{file = "black-23.1.0-cp38-cp38-macosx_10_16_arm64.whl", hash = "sha256:a59db0a2094d2259c554676403fa2fac3473ccf1354c1c63eccf7ae65aac8ab6"},
{file = "black-23.1.0-cp38-cp38-macosx_10_16_universal2.whl", hash = "sha256:0052dba51dec07ed029ed61b18183942043e00008ec65d5028814afaab9a22fd"},
{file = "black-23.1.0-cp38-cp38-macosx_10_16_x86_64.whl", hash = "sha256:49f7b39e30f326a34b5c9a4213213a6b221d7ae9d58ec70df1c4a307cf2a1580"},
{file = "black-23.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:162e37d49e93bd6eb6f1afc3e17a3d23a823042530c37c3c42eeeaf026f38468"},
{file = "black-23.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:8b70eb40a78dfac24842458476135f9b99ab952dd3f2dab738c1881a9b38b753"},
{file = "black-23.1.0-cp39-cp39-macosx_10_16_arm64.whl", hash = "sha256:a29650759a6a0944e7cca036674655c2f0f63806ddecc45ed40b7b8aa314b651"},
{file = "black-23.1.0-cp39-cp39-macosx_10_16_universal2.whl", hash = "sha256:bb460c8561c8c1bec7824ecbc3ce085eb50005883a6203dcfb0122e95797ee06"},
{file = "black-23.1.0-cp39-cp39-macosx_10_16_x86_64.whl", hash = "sha256:c91dfc2c2a4e50df0026f88d2215e166616e0c80e86004d0003ece0488db2739"},
{file = "black-23.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2a951cc83ab535d248c89f300eccbd625e80ab880fbcfb5ac8afb5f01a258ac9"},
{file = "black-23.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:0680d4380db3719ebcfb2613f34e86c8e6d15ffeabcf8ec59355c5e7b85bb555"},
{file = "black-23.1.0-py3-none-any.whl", hash = "sha256:7a0f701d314cfa0896b9001df70a530eb2472babb76086344e688829efd97d32"},
{file = "black-23.1.0.tar.gz", hash = "sha256:b0bd97bea8903f5a2ba7219257a44e3f1f9d00073d6cc1add68f0beec69692ac"},
{file = "black-23.3.0-cp310-cp310-macosx_10_16_arm64.whl", hash = "sha256:0945e13506be58bf7db93ee5853243eb368ace1c08a24c65ce108986eac65915"},
{file = "black-23.3.0-cp310-cp310-macosx_10_16_universal2.whl", hash = "sha256:67de8d0c209eb5b330cce2469503de11bca4085880d62f1628bd9972cc3366b9"},
{file = "black-23.3.0-cp310-cp310-macosx_10_16_x86_64.whl", hash = "sha256:7c3eb7cea23904399866c55826b31c1f55bbcd3890ce22ff70466b907b6775c2"},
{file = "black-23.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:32daa9783106c28815d05b724238e30718f34155653d4d6e125dc7daec8e260c"},
{file = "black-23.3.0-cp310-cp310-win_amd64.whl", hash = "sha256:35d1381d7a22cc5b2be2f72c7dfdae4072a3336060635718cc7e1ede24221d6c"},
{file = "black-23.3.0-cp311-cp311-macosx_10_16_arm64.whl", hash = "sha256:a8a968125d0a6a404842fa1bf0b349a568634f856aa08ffaff40ae0dfa52e7c6"},
{file = "black-23.3.0-cp311-cp311-macosx_10_16_universal2.whl", hash = "sha256:c7ab5790333c448903c4b721b59c0d80b11fe5e9803d8703e84dcb8da56fec1b"},
{file = "black-23.3.0-cp311-cp311-macosx_10_16_x86_64.whl", hash = "sha256:a6f6886c9869d4daae2d1715ce34a19bbc4b95006d20ed785ca00fa03cba312d"},
{file = "black-23.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6f3c333ea1dd6771b2d3777482429864f8e258899f6ff05826c3a4fcc5ce3f70"},
{file = "black-23.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:11c410f71b876f961d1de77b9699ad19f939094c3a677323f43d7a29855fe326"},
{file = "black-23.3.0-cp37-cp37m-macosx_10_16_x86_64.whl", hash = "sha256:1d06691f1eb8de91cd1b322f21e3bfc9efe0c7ca1f0e1eb1db44ea367dff656b"},
{file = "black-23.3.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:50cb33cac881766a5cd9913e10ff75b1e8eb71babf4c7104f2e9c52da1fb7de2"},
{file = "black-23.3.0-cp37-cp37m-win_amd64.whl", hash = "sha256:e114420bf26b90d4b9daa597351337762b63039752bdf72bf361364c1aa05925"},
{file = "black-23.3.0-cp38-cp38-macosx_10_16_arm64.whl", hash = "sha256:48f9d345675bb7fbc3dd85821b12487e1b9a75242028adad0333ce36ed2a6d27"},
{file = "black-23.3.0-cp38-cp38-macosx_10_16_universal2.whl", hash = "sha256:714290490c18fb0126baa0fca0a54ee795f7502b44177e1ce7624ba1c00f2331"},
{file = "black-23.3.0-cp38-cp38-macosx_10_16_x86_64.whl", hash = "sha256:064101748afa12ad2291c2b91c960be28b817c0c7eaa35bec09cc63aa56493c5"},
{file = "black-23.3.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:562bd3a70495facf56814293149e51aa1be9931567474993c7942ff7d3533961"},
{file = "black-23.3.0-cp38-cp38-win_amd64.whl", hash = "sha256:e198cf27888ad6f4ff331ca1c48ffc038848ea9f031a3b40ba36aced7e22f2c8"},
{file = "black-23.3.0-cp39-cp39-macosx_10_16_arm64.whl", hash = "sha256:3238f2aacf827d18d26db07524e44741233ae09a584273aa059066d644ca7b30"},
{file = "black-23.3.0-cp39-cp39-macosx_10_16_universal2.whl", hash = "sha256:f0bd2f4a58d6666500542b26354978218a9babcdc972722f4bf90779524515f3"},
{file = "black-23.3.0-cp39-cp39-macosx_10_16_x86_64.whl", hash = "sha256:92c543f6854c28a3c7f39f4d9b7694f9a6eb9d3c5e2ece488c327b6e7ea9b266"},
{file = "black-23.3.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3a150542a204124ed00683f0db1f5cf1c2aaaa9cc3495b7a3b5976fb136090ab"},
{file = "black-23.3.0-cp39-cp39-win_amd64.whl", hash = "sha256:6b39abdfb402002b8a7d030ccc85cf5afff64ee90fa4c5aebc531e3ad0175ddb"},
{file = "black-23.3.0-py3-none-any.whl", hash = "sha256:ec751418022185b0c1bb7d7736e6933d40bbb14c14a0abcf9123d1b159f98dd4"},
{file = "black-23.3.0.tar.gz", hash = "sha256:1c7b8d606e728a41ea1ccbd7264677e494e87cf630e399262ced92d4a8dac940"},
]
[package.dependencies]
@ -747,14 +728,14 @@ files = [
[[package]]
name = "mkdocs"
version = "1.4.2"
version = "1.4.3"
description = "Project documentation with Markdown."
category = "dev"
optional = false
python-versions = ">=3.7"
files = [
{file = "mkdocs-1.4.2-py3-none-any.whl", hash = "sha256:c8856a832c1e56702577023cd64cc5f84948280c1c0fcc6af4cd39006ea6aa8c"},
{file = "mkdocs-1.4.2.tar.gz", hash = "sha256:8947af423a6d0facf41ea1195b8e1e8c85ad94ac95ae307fe11232e0424b11c5"},
{file = "mkdocs-1.4.3-py3-none-any.whl", hash = "sha256:6ee46d309bda331aac915cd24aab882c179a933bd9e77b80ce7d2eaaa3f689dd"},
{file = "mkdocs-1.4.3.tar.gz", hash = "sha256:5955093bbd4dd2e9403c5afaf57324ad8b04f16886512a3ee6ef828956481c57"},
]
[package.dependencies]
@ -792,14 +773,14 @@ mkdocs = ">=1.1"
[[package]]
name = "mkdocs-material"
version = "9.1.4"
version = "9.1.9"
description = "Documentation that simply works"
category = "dev"
optional = false
python-versions = ">=3.7"
files = [
{file = "mkdocs_material-9.1.4-py3-none-any.whl", hash = "sha256:4c92dcf9365068259bef3eed8e0dd5410056b6f7187bdea2d52848c0f94cd94c"},
{file = "mkdocs_material-9.1.4.tar.gz", hash = "sha256:c3a8943e9e4a7d2624291da365bbccf0b9f88688aa6947a46260d8c165cd4389"},
{file = "mkdocs_material-9.1.9-py3-none-any.whl", hash = "sha256:7db24261cb17400e132c46d17eea712bfe71056d892a9beba32cf68210297141"},
{file = "mkdocs_material-9.1.9.tar.gz", hash = "sha256:74d8da1371ab3a326868fe47bae3cbc4aa22e93c048b4ca5117e6817b88bd734"},
]
[package.dependencies]
@ -827,14 +808,14 @@ files = [
[[package]]
name = "mkdocstrings"
version = "0.20.0"
version = "0.21.2"
description = "Automatic documentation from sources, for MkDocs."
category = "dev"
optional = false
python-versions = ">=3.7"
files = [
{file = "mkdocstrings-0.20.0-py3-none-any.whl", hash = "sha256:f17fc2c4f760ec302b069075ef9e31045aa6372ca91d2f35ded3adba8e25a472"},
{file = "mkdocstrings-0.20.0.tar.gz", hash = "sha256:c757f4f646d4f939491d6bc9256bfe33e36c5f8026392f49eaa351d241c838e5"},
{file = "mkdocstrings-0.21.2-py3-none-any.whl", hash = "sha256:949ef8da92df9d692ca07be50616459a6b536083a25520fd54b00e8814ce019b"},
{file = "mkdocstrings-0.21.2.tar.gz", hash = "sha256:304e56a2e90595708a38a13a278e538a67ad82052dd5c8b71f77a604a4f3d911"},
]
[package.dependencies]
@ -845,6 +826,7 @@ mkdocs = ">=1.2"
mkdocs-autorefs = ">=0.3.1"
mkdocstrings-python = {version = ">=0.5.2", optional = true, markers = "extra == \"python\""}
pymdown-extensions = ">=6.3"
typing-extensions = {version = ">=4.1", markers = "python_version < \"3.10\""}
[package.extras]
crystal = ["mkdocstrings-crystal (>=0.3.4)"]
@ -1007,18 +989,17 @@ pyyaml = "*"
[[package]]
name = "pytest"
version = "7.2.2"
version = "7.3.1"
description = "pytest: simple powerful testing with Python"
category = "dev"
optional = false
python-versions = ">=3.7"
files = [
{file = "pytest-7.2.2-py3-none-any.whl", hash = "sha256:130328f552dcfac0b1cec75c12e3f005619dc5f874f0a06e8ff7263f0ee6225e"},
{file = "pytest-7.2.2.tar.gz", hash = "sha256:c99ab0c73aceb050f68929bc93af19ab6db0558791c6a0715723abe9d0ade9d4"},
{file = "pytest-7.3.1-py3-none-any.whl", hash = "sha256:3799fa815351fea3a5e96ac7e503a96fa51cc9942c3753cda7651b93c1cfa362"},
{file = "pytest-7.3.1.tar.gz", hash = "sha256:434afafd78b1d78ed0addf160ad2b77a30d35d4bdf8af234fe621919d9ed15e3"},
]
[package.dependencies]
attrs = ">=19.2.0"
colorama = {version = "*", markers = "sys_platform == \"win32\""}
exceptiongroup = {version = ">=1.0.0rc8", markers = "python_version < \"3.11\""}
iniconfig = "*"
@ -1027,7 +1008,7 @@ pluggy = ">=0.12,<2.0"
tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""}
[package.extras]
testing = ["argcomplete", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "xmlschema"]
testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "xmlschema"]
[[package]]
name = "python-dateutil"
@ -1458,4 +1439,4 @@ testing = ["big-O", "flake8 (<5)", "jaraco.functools", "jaraco.itertools", "more
[metadata]
lock-version = "2.0"
python-versions = "^3.8.1"
content-hash = "aa15e57300668bd23c051b4cd87bec4c1a58dcccd2f2b4767579fea7f2c5fa41"
content-hash = "e87403dcd0a0b8484436b02c392326adfaf22b8d7e182d77e4a155c67a7435bc"


@ -1,6 +1,6 @@
[tool.poetry]
name = "llama_cpp_python"
version = "0.1.41"
version = "0.1.43"
description = "Python bindings for the llama.cpp library"
authors = ["Andrei Betlen <abetlen@gmail.com>"]
license = "MIT"
@ -18,12 +18,12 @@ typing-extensions = "^4.5.0"
[tool.poetry.group.dev.dependencies]
black = "^23.1.0"
black = "^23.3.0"
twine = "^4.0.2"
mkdocs = "^1.4.2"
mkdocstrings = {extras = ["python"], version = "^0.20.0"}
mkdocs-material = "^9.1.4"
pytest = "^7.2.2"
mkdocs = "^1.4.3"
mkdocstrings = {extras = ["python"], version = "^0.21.2"}
mkdocs-material = "^9.1.9"
pytest = "^7.3.1"
httpx = "^0.24.0"
[build-system]


@ -10,7 +10,7 @@ setup(
description="A Python wrapper for llama.cpp",
long_description=long_description,
long_description_content_type="text/markdown",
version="0.1.41",
version="0.1.43",
author="Andrei Betlen",
author_email="abetlen@gmail.com",
license="MIT",

2
vendor/llama.cpp vendored

@ -1 +1 @@
Subproject commit e216aa04633892b972d013719e38b59fd4917341
Subproject commit 1b0fd454650ef4d68a980e3225488b79e6e9af25