This commit is contained in:
Mug 2023-04-17 14:45:42 +02:00
commit 1b73a15e62
7 changed files with 244 additions and 55 deletions

View file

@ -104,10 +104,13 @@ python3 setup.py develop
- create_completion
- __call__
- create_chat_completion
- set_cache
- token_bos
- token_eos
show_root_heading: true
::: llama_cpp.LlamaCache
::: llama_cpp.llama_cpp
options:
show_if_no_docstring: true

View file

@ -2,6 +2,7 @@ import os
import sys
import uuid
import time
import math
import multiprocessing
from typing import List, Optional, Union, Generator, Sequence, Iterator
from collections import deque
@ -10,6 +11,15 @@ from . import llama_cpp
from .llama_types import *
class LlamaCache:
"""Cache for a llama.cpp model.
NOTE: This implementation currently only tells the Llama class to avoid reprocessing bytes and continue from the last
completion. It does not actually cache the results."""
pass
class Llama:
"""High-level Python wrapper for a llama.cpp model."""
@ -20,7 +30,7 @@ class Llama:
n_ctx: int = 512,
n_parts: int = -1,
seed: int = 1337,
f16_kv: bool = False,
f16_kv: bool = True,
logits_all: bool = False,
vocab_only: bool = False,
use_mmap: bool = True,
@ -75,7 +85,19 @@ class Llama:
maxlen=self.last_n_tokens_size,
)
self.tokens_consumed = 0
self.tokens: List[llama_cpp.llama_token] = []
self.n_batch = min(n_ctx, n_batch)
self.n_tokens = 0
self.n_past = 0
self.all_logits: List[List[float]] = [] # TODO: Use an array instead of a list.
### HACK: This is a hack to work around the fact that the llama.cpp API does not yet support
### saving and restoring state, this allows us to continue a completion if the last
### completion_bytes is a prefix to the prompt passed in. However this is actually incorrect
### because it does not take into account stop tokens which have been processed by the model.
self._completion_bytes: List[bytes] = []
self._cache: Optional[LlamaCache] = None
###
self.n_threads = n_threads or max(multiprocessing.cpu_count() // 2, 1)
@ -130,12 +152,24 @@ class Llama:
output += llama_cpp.llama_token_to_str(self.ctx, token)
return output
def set_cache(self, cache: Optional[LlamaCache]):
"""Set the cache.
Args:
cache: The cache to set.
"""
self._cache = cache
def reset(self):
"""Reset the model state."""
self.last_n_tokens_data.extend(
[llama_cpp.llama_token(0)] * self.last_n_tokens_size
)
self.tokens_consumed = 0
self.tokens.clear()
self.n_tokens = 0
self.n_past = 0
self.all_logits.clear()
def eval(self, tokens: Sequence[llama_cpp.llama_token]):
"""Evaluate a list of tokens.
@ -147,18 +181,32 @@ class Llama:
n_ctx = int(llama_cpp.llama_n_ctx(self.ctx))
for i in range(0, len(tokens), self.n_batch):
batch = tokens[i : min(len(tokens), i + self.n_batch)]
n_past = min(n_ctx - len(batch), self.tokens_consumed)
self.n_past = min(n_ctx - len(batch), self.tokens_consumed)
self.n_tokens = len(batch)
return_code = llama_cpp.llama_eval(
ctx=self.ctx,
tokens=(llama_cpp.llama_token * len(batch))(*batch),
n_tokens=llama_cpp.c_int(len(batch)),
n_past=llama_cpp.c_int(n_past),
n_tokens=llama_cpp.c_int(self.n_tokens),
n_past=llama_cpp.c_int(self.n_past),
n_threads=llama_cpp.c_int(self.n_threads),
)
if int(return_code) != 0:
raise RuntimeError(f"llama_eval returned {return_code}")
self.tokens.extend(batch)
self.last_n_tokens_data.extend(batch)
self.tokens_consumed += len(batch)
if self.params.logits_all:
self.all_logits.extend(self._logits())
def _logits(self) -> List[List[float]]:
"""Return the logits from the last call to llama_eval."""
assert self.ctx is not None
n_vocab = llama_cpp.llama_n_vocab(self.ctx)
cols = int(n_vocab)
rows = self.n_tokens if self.params.logits_all else 1
logits_view = llama_cpp.llama_get_logits(self.ctx)
logits = [[logits_view[i * cols + j] for j in range(cols)] for i in range(rows)]
return logits
def sample(
self,
@ -198,6 +246,7 @@ class Llama:
top_p: float,
temp: float,
repeat_penalty: float,
reset: bool = True,
) -> Generator[
llama_cpp.llama_token, Optional[Sequence[llama_cpp.llama_token]], None
]:
@ -215,12 +264,26 @@ class Llama:
top_p: The top-p sampling parameter.
temp: The temperature parameter.
repeat_penalty: The repeat penalty parameter.
reset: Whether to reset the model state.
Yields:
The generated tokens.
"""
assert self.ctx is not None
self.reset()
### HACK
if (
reset
and self._cache
and len(self.tokens) > 0
and self.tokens == tokens[: len(self.tokens)]
):
if self.verbose:
print("generate cache hit", file=sys.stderr)
reset = False
tokens = tokens[len(self.tokens) :]
###
if reset:
self.reset()
while True:
self.eval(tokens)
token = self.sample(
@ -300,19 +363,22 @@ class Llama:
top_p: float = 0.95,
logprobs: Optional[int] = None,
echo: bool = False,
stop: List[str] = [],
stop: Optional[List[str]] = [],
repeat_penalty: float = 1.1,
top_k: int = 40,
stream: bool = False,
) -> Union[Iterator[Completion], Iterator[CompletionChunk],]:
) -> Union[Iterator[Completion], Iterator[CompletionChunk]]:
assert self.ctx is not None
completion_id = f"cmpl-{str(uuid.uuid4())}"
created = int(time.time())
completion_id: str = f"cmpl-{str(uuid.uuid4())}"
created: int = int(time.time())
completion_tokens: List[llama_cpp.llama_token] = []
# Add blank space to start of prompt to match OG llama tokenizer
prompt_tokens = self.tokenize(b" " + prompt.encode("utf-8"))
text = b""
returned_characters = 0
prompt_tokens: List[llama_cpp.llama_token] = self.tokenize(
b" " + prompt.encode("utf-8")
)
text: bytes = b""
returned_characters: int = 0
stop = stop if stop is not None else []
if self.verbose:
llama_cpp.llama_reset_timings(self.ctx)
@ -327,13 +393,34 @@ class Llama:
else:
stop_sequences = []
finish_reason = None
if logprobs is not None and self.params.logits_all is False:
raise ValueError(
"logprobs is not supported for models created with logits_all=False"
)
### HACK
reset: bool = True
_prompt: bytes = prompt.encode("utf-8")
_completion: bytes = b"".join(self._completion_bytes)
if len(_completion) and self._cache and _prompt.startswith(_completion):
if self.verbose:
print("completion cache hit", file=sys.stderr)
reset = False
_prompt = _prompt[len(_completion) :]
prompt_tokens = self.tokenize(b" " + _prompt)
self._completion_bytes.append(_prompt)
else:
self._completion_bytes = [prompt.encode("utf-8")]
###
finish_reason = "length"
for token in self.generate(
prompt_tokens,
top_k=top_k,
top_p=top_p,
temp=temperature,
repeat_penalty=repeat_penalty,
reset=reset,
):
if token == llama_cpp.llama_token_eos():
text = self.detokenize(completion_tokens)
@ -363,6 +450,9 @@ class Llama:
break
text = all_text[: len(all_text) - longest]
returned_characters += len(text[start:])
### HACK
self._completion_bytes.append(text[start:])
###
yield {
"id": completion_id,
"object": "text_completion",
@ -377,15 +467,16 @@ class Llama:
}
],
}
if len(completion_tokens) >= max_tokens:
text = self.detokenize(completion_tokens)
finish_reason = "length"
break
if finish_reason is None:
finish_reason = "length"
if stream:
### HACK
self._completion_bytes.append(text[returned_characters:])
###
yield {
"id": completion_id,
"object": "text_completion",
@ -402,16 +493,57 @@ class Llama:
}
return
text = text.decode("utf-8")
### HACK
self._completion_bytes.append(text)
###
text_str = text.decode("utf-8")
if echo:
text = prompt + text
text_str = prompt + text_str
if suffix is not None:
text = text + suffix
text_str = text_str + suffix
logprobs_or_none: Optional[CompletionLogprobs] = None
if logprobs is not None:
raise NotImplementedError("logprobs not implemented")
text_offset = 0
text_offsets: List[int] = []
token_logprobs: List[float] = []
tokens: List[str] = []
top_logprobs: List[Dict[str, float]] = []
all_tokens = prompt_tokens + completion_tokens
all_token_strs = [
self.detokenize([token]).decode("utf-8") for token in all_tokens
]
all_logprobs = [
[Llama.logit_to_logprob(logit) for logit in row]
for row in self.all_logits
]
for token, token_str, logprobs_token in zip(
all_tokens, all_token_strs, all_logprobs
):
text_offsets.append(text_offset)
text_offset += len(token_str)
tokens.append(token_str)
sorted_logprobs = list(
sorted(
zip(logprobs_token, range(len(logprobs_token))), reverse=True
)
)
token_logprobs.append(sorted_logprobs[int(token)][0])
top_logprob = {
self.detokenize([llama_cpp.llama_token(i)]).decode("utf-8"): logprob
for logprob, i in sorted_logprobs[:logprobs]
}
top_logprob.update({token_str: sorted_logprobs[int(token)][0]})
top_logprobs.append(top_logprob)
logprobs_or_none = {
"tokens": tokens,
"text_offset": text_offsets,
"token_logprobs": token_logprobs,
"top_logprobs": top_logprobs,
}
if self.verbose:
llama_cpp.llama_print_timings(self.ctx)
@ -423,9 +555,9 @@ class Llama:
"model": self.model_path,
"choices": [
{
"text": text,
"text": text_str,
"index": 0,
"logprobs": None,
"logprobs": logprobs_or_none,
"finish_reason": finish_reason,
}
],
@ -445,7 +577,7 @@ class Llama:
top_p: float = 0.95,
logprobs: Optional[int] = None,
echo: bool = False,
stop: List[str] = [],
stop: Optional[List[str]] = [],
repeat_penalty: float = 1.1,
top_k: int = 40,
stream: bool = False,
@ -500,7 +632,7 @@ class Llama:
top_p: float = 0.95,
logprobs: Optional[int] = None,
echo: bool = False,
stop: List[str] = [],
stop: Optional[List[str]] = [],
repeat_penalty: float = 1.1,
top_k: int = 40,
stream: bool = False,
@ -602,12 +734,12 @@ class Llama:
def create_chat_completion(
self,
messages: List[ChatCompletionMessage],
temperature: float = 0.8,
temperature: float = 0.2,
top_p: float = 0.95,
top_k: int = 40,
stream: bool = False,
stop: List[str] = [],
max_tokens: int = 128,
stop: Optional[List[str]] = [],
max_tokens: int = 256,
repeat_penalty: float = 1.1,
) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
"""Generate a chat completion from a list of messages.
@ -625,13 +757,13 @@ class Llama:
Returns:
Generated chat completion or a stream of chat completion chunks.
"""
instructions = """Complete the following chat conversation between the user and the assistant. System messages should be strictly followed as additional instructions."""
chat_history = "\n".join(
f'{message["role"]} {message.get("user", "")}: {message["content"]}'
stop = stop if stop is not None else []
chat_history = "".join(
f'### {"Human" if message["role"] == "user" else "Assistant"}:{message["content"]}'
for message in messages
)
PROMPT = f" \n\n### Instructions:{instructions}\n\n### Inputs:{chat_history}\n\n### Response:\nassistant: "
PROMPT_STOP = ["###", "\nuser: ", "\nassistant: ", "\nsystem: "]
PROMPT = chat_history + "### Assistant:"
PROMPT_STOP = ["### Assistant:", "### Human:"]
completion_or_chunks = self(
prompt=PROMPT,
stop=PROMPT_STOP + stop,
@ -668,8 +800,6 @@ class Llama:
use_mlock=self.params.use_mlock,
embedding=self.params.embedding,
last_n_tokens_size=self.last_n_tokens_size,
last_n_tokens_data=self.last_n_tokens_data,
tokens_consumed=self.tokens_consumed,
n_batch=self.n_batch,
n_threads=self.n_threads,
)
@ -691,9 +821,6 @@ class Llama:
last_n_tokens_size=state["last_n_tokens_size"],
verbose=state["verbose"],
)
self.last_n_tokens_data = state["last_n_tokens_data"]
self.tokens_consumed = state["tokens_consumed"]
@staticmethod
def token_eos() -> llama_cpp.llama_token:
@ -704,3 +831,7 @@ class Llama:
def token_bos() -> llama_cpp.llama_token:
"""Return the beginning-of-sequence token."""
return llama_cpp.llama_token_bos()
@staticmethod
def logit_to_logprob(x: float) -> float:
return math.log(1.0 + math.exp(x))

View file

@ -1,9 +1,21 @@
import sys
import os
import ctypes
from ctypes import c_int, c_float, c_char_p, c_void_p, c_bool, POINTER, Structure, Array, c_uint8, c_size_t
from ctypes import (
c_int,
c_float,
c_char_p,
c_void_p,
c_bool,
POINTER,
Structure,
Array,
c_uint8,
c_size_t,
)
import pathlib
# Load the library
def _load_shared_library(lib_base_name):
# Determine the file extension based on the platform
@ -22,9 +34,15 @@ def _load_shared_library(lib_base_name):
# for llamacpp) and "llama" (default name for this repo)
_lib_paths = [
_base_path / f"lib{lib_base_name}{lib_ext}",
_base_path / f"{lib_base_name}{lib_ext}"
_base_path / f"{lib_base_name}{lib_ext}",
]
if "LLAMA_CPP_LIB" in os.environ:
lib_base_name = os.environ["LLAMA_CPP_LIB"]
_lib = pathlib.Path(lib_base_name)
_base_path = _lib.parent.resolve()
_lib_paths = [_lib.resolve()]
# Add the library directory to the DLL search path on Windows (if needed)
if sys.platform == "win32" and sys.version_info >= (3, 8):
os.add_dll_directory(str(_base_path))
@ -37,7 +55,10 @@ def _load_shared_library(lib_base_name):
except Exception as e:
raise RuntimeError(f"Failed to load shared library '{_lib_path}': {e}")
raise FileNotFoundError(f"Shared library with base name '{lib_base_name}' not found")
raise FileNotFoundError(
f"Shared library with base name '{lib_base_name}' not found"
)
# Specify the base name of the shared library to load
_lib_base_name = "llama"
@ -89,6 +110,11 @@ class llama_context_params(Structure):
llama_context_params_p = POINTER(llama_context_params)
LLAMA_FTYPE_ALL_F32 = ctypes.c_int(0)
LLAMA_FTYPE_MOSTLY_F16 = ctypes.c_int(1) # except 1d tensors
LLAMA_FTYPE_MOSTLY_Q4_0 = ctypes.c_int(2) # except 1d tensors
LLAMA_FTYPE_MOSTLY_Q4_1 = ctypes.c_int(3) # except 1d tensors
LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = ctypes.c_int(4) # tok_embeddings.weight and output.weight are F16
# Functions
@ -100,18 +126,23 @@ def llama_context_default_params() -> llama_context_params:
_lib.llama_context_default_params.argtypes = []
_lib.llama_context_default_params.restype = llama_context_params
def llama_mmap_supported() -> c_bool:
return _lib.llama_mmap_supported()
_lib.llama_mmap_supported.argtypes = []
_lib.llama_mmap_supported.restype = c_bool
def llama_mlock_supported() -> c_bool:
return _lib.llama_mlock_supported()
_lib.llama_mlock_supported.argtypes = []
_lib.llama_mlock_supported.restype = c_bool
# Various functions for loading a ggml llama model.
# Allocate (almost) all memory needed for the model.
# Return NULL on failure
@ -136,42 +167,49 @@ _lib.llama_free.restype = None
# TODO: not great API - very likely to change
# Returns 0 on success
def llama_model_quantize(
fname_inp: bytes, fname_out: bytes, itype: c_int
) -> c_int:
def llama_model_quantize(fname_inp: bytes, fname_out: bytes, itype: c_int) -> c_int:
return _lib.llama_model_quantize(fname_inp, fname_out, itype)
_lib.llama_model_quantize.argtypes = [c_char_p, c_char_p, c_int]
_lib.llama_model_quantize.restype = c_int
# Returns the KV cache that will contain the context for the
# ongoing prediction with the model.
def llama_get_kv_cache(ctx: llama_context_p):
return _lib.llama_get_kv_cache(ctx)
_lib.llama_get_kv_cache.argtypes = [llama_context_p]
_lib.llama_get_kv_cache.restype = POINTER(c_uint8)
# Returns the size of the KV cache
def llama_get_kv_cache_size(ctx: llama_context_p) -> c_size_t:
return _lib.llama_get_kv_cache_size(ctx)
_lib.llama_get_kv_cache_size.argtypes = [llama_context_p]
_lib.llama_get_kv_cache_size.restype = c_size_t
# Returns the number of tokens in the KV cache
def llama_get_kv_cache_token_count(ctx: llama_context_p) -> c_int:
return _lib.llama_get_kv_cache_token_count(ctx)
_lib.llama_get_kv_cache_token_count.argtypes = [llama_context_p]
_lib.llama_get_kv_cache_token_count.restype = c_int
# Sets the KV cache containing the current context for the model
def llama_set_kv_cache(ctx: llama_context_p, kv_cache, n_size: c_size_t, n_token_count: c_int):
def llama_set_kv_cache(
ctx: llama_context_p, kv_cache, n_size: c_size_t, n_token_count: c_int
):
return _lib.llama_set_kv_cache(ctx, kv_cache, n_size, n_token_count)
_lib.llama_set_kv_cache.argtypes = [llama_context_p, POINTER(c_uint8), c_size_t, c_int]
_lib.llama_set_kv_cache.restype = None

View file

@ -13,12 +13,13 @@ Then visit http://localhost:8000/docs to see the interactive API docs.
"""
import os
import json
from threading import Lock
from typing import List, Optional, Literal, Union, Iterator, Dict
from typing_extensions import TypedDict
import llama_cpp
from fastapi import FastAPI
from fastapi import Depends, FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, BaseSettings, Field, create_model_from_typeddict
from sse_starlette.sse import EventSourceResponse
@ -33,6 +34,8 @@ class Settings(BaseSettings):
use_mlock: bool = False # This causes a silent failure on platforms that don't support mlock (e.g. Windows) took forever to figure out...
embedding: bool = True
last_n_tokens_size: int = 64
logits_all: bool = False
cache: bool = False # WARNING: This is an experimental feature
app = FastAPI(
@ -52,11 +55,21 @@ llama = llama_cpp.Llama(
f16_kv=settings.f16_kv,
use_mlock=settings.use_mlock,
embedding=settings.embedding,
logits_all=settings.logits_all,
n_threads=settings.n_threads,
n_batch=settings.n_batch,
n_ctx=settings.n_ctx,
last_n_tokens_size=settings.last_n_tokens_size,
)
if settings.cache:
cache = llama_cpp.LlamaCache()
llama.set_cache(cache)
llama_lock = Lock()
def get_llama():
with llama_lock:
yield llama
class CreateCompletionRequest(BaseModel):
@ -66,7 +79,7 @@ class CreateCompletionRequest(BaseModel):
temperature: float = 0.8
top_p: float = 0.95
echo: bool = False
stop: List[str] = []
stop: Optional[List[str]] = []
stream: bool = False
# ignored or currently unsupported
@ -99,7 +112,9 @@ CreateCompletionResponse = create_model_from_typeddict(llama_cpp.Completion)
"/v1/completions",
response_model=CreateCompletionResponse,
)
def create_completion(request: CreateCompletionRequest):
def create_completion(
request: CreateCompletionRequest, llama: llama_cpp.Llama = Depends(get_llama)
):
if isinstance(request.prompt, list):
request.prompt = "".join(request.prompt)
@ -108,7 +123,6 @@ def create_completion(request: CreateCompletionRequest):
exclude={
"model",
"n",
"logprobs",
"frequency_penalty",
"presence_penalty",
"best_of",
@ -144,7 +158,9 @@ CreateEmbeddingResponse = create_model_from_typeddict(llama_cpp.Embedding)
"/v1/embeddings",
response_model=CreateEmbeddingResponse,
)
def create_embedding(request: CreateEmbeddingRequest):
def create_embedding(
request: CreateEmbeddingRequest, llama: llama_cpp.Llama = Depends(get_llama)
):
return llama.create_embedding(**request.dict(exclude={"model", "user"}))
@ -160,7 +176,7 @@ class CreateChatCompletionRequest(BaseModel):
temperature: float = 0.8
top_p: float = 0.95
stream: bool = False
stop: List[str] = []
stop: Optional[List[str]] = []
max_tokens: int = 128
# ignored or currently unsupported
@ -196,8 +212,9 @@ CreateChatCompletionResponse = create_model_from_typeddict(llama_cpp.ChatComplet
"/v1/chat/completions",
response_model=CreateChatCompletionResponse,
)
async def create_chat_completion(
def create_chat_completion(
request: CreateChatCompletionRequest,
llama: llama_cpp.Llama = Depends(get_llama),
) -> Union[llama_cpp.ChatCompletion, EventSourceResponse]:
completion_or_chunks = llama.create_chat_completion(
**request.dict(

View file

@ -1,6 +1,6 @@
[tool.poetry]
name = "llama_cpp_python"
version = "0.1.30"
version = "0.1.34"
description = "Python bindings for the llama.cpp library"
authors = ["Andrei Betlen <abetlen@gmail.com>"]
license = "MIT"

View file

@ -3,14 +3,14 @@ from skbuild import setup
from pathlib import Path
this_directory = Path(__file__).parent
long_description = (this_directory / "README.md").read_text()
long_description = (this_directory / "README.md").read_text(encoding="utf-8")
setup(
name="llama_cpp_python",
description="A Python wrapper for llama.cpp",
long_description=long_description,
long_description_content_type="text/markdown",
version="0.1.30",
version="0.1.34",
author="Andrei Betlen",
author_email="abetlen@gmail.com",
license="MIT",

2
vendor/llama.cpp vendored

@ -1 +1 @@
Subproject commit 180b693a47b6b825288ef9f2c39d24b6eea4eea6
Subproject commit e95b6554b493e71a0275764342e09bd5784a7026