commit 1b73a15e62
Author: Mug
Date:   2023-04-17 14:45:42 +02:00

7 changed files with 244 additions and 55 deletions


@@ -104,10 +104,13 @@ python3 setup.py develop
             - create_completion
             - __call__
             - create_chat_completion
+            - set_cache
             - token_bos
             - token_eos
         show_root_heading: true
 
+::: llama_cpp.LlamaCache
+
 ::: llama_cpp.llama_cpp
     options:
         show_if_no_docstring: true


@@ -2,6 +2,7 @@ import os
 import sys
 import uuid
 import time
+import math
 import multiprocessing
 from typing import List, Optional, Union, Generator, Sequence, Iterator
 from collections import deque
@@ -10,6 +11,15 @@ from . import llama_cpp
 from .llama_types import *
 
 
+class LlamaCache:
+    """Cache for a llama.cpp model.
+
+    NOTE: This implementation currently only tells the Llama class to avoid reprocessing bytes and continue from the last
+    completion. It does not actually cache the results."""
+
+    pass
+
+
 class Llama:
     """High-level Python wrapper for a llama.cpp model."""
@@ -20,7 +30,7 @@ class Llama:
         n_ctx: int = 512,
         n_parts: int = -1,
         seed: int = 1337,
-        f16_kv: bool = False,
+        f16_kv: bool = True,
         logits_all: bool = False,
         vocab_only: bool = False,
         use_mmap: bool = True,
@@ -75,7 +85,19 @@ class Llama:
             maxlen=self.last_n_tokens_size,
         )
         self.tokens_consumed = 0
+        self.tokens: List[llama_cpp.llama_token] = []
         self.n_batch = min(n_ctx, n_batch)
+        self.n_tokens = 0
+        self.n_past = 0
+        self.all_logits: List[List[float]] = []  # TODO: Use an array instead of a list.
+
+        ### HACK: This is a hack to work around the fact that the llama.cpp API does not yet support
+        ### saving and restoring state, this allows us to continue a completion if the last
+        ### completion_bytes is a prefix to the prompt passed in. However this is actually incorrect
+        ### because it does not take into account stop tokens which have been processed by the model.
+        self._completion_bytes: List[bytes] = []
+        self._cache: Optional[LlamaCache] = None
+        ###
 
         self.n_threads = n_threads or max(multiprocessing.cpu_count() // 2, 1)
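The state added here only records what the model has already seen; the reuse decision made later can be summarised in a few lines. A minimal standalone sketch of the prefix-reuse idea, assuming only the `_completion_bytes` bookkeeping shown in this hunk (the helper name is hypothetical):

```python
from typing import List

def reusable_suffix(completion_bytes: List[bytes], new_prompt: bytes) -> bytes:
    """Return the part of new_prompt that still needs to be evaluated (sketch only)."""
    seen = b"".join(completion_bytes)          # bytes the model has already processed
    if seen and new_prompt.startswith(seen):   # "cache hit": the new prompt extends the old text
        return new_prompt[len(seen):]          # only the unseen tail is re-evaluated
    return new_prompt                          # otherwise evaluation starts from a clean state
```

As the HACK comment warns, stop tokens consumed by the model are not tracked, so a hit can resume from slightly incorrect state.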
@@ -130,12 +152,24 @@ class Llama:
             output += llama_cpp.llama_token_to_str(self.ctx, token)
         return output
 
+    def set_cache(self, cache: Optional[LlamaCache]):
+        """Set the cache.
+
+        Args:
+            cache: The cache to set.
+        """
+        self._cache = cache
+
     def reset(self):
         """Reset the model state."""
         self.last_n_tokens_data.extend(
             [llama_cpp.llama_token(0)] * self.last_n_tokens_size
         )
         self.tokens_consumed = 0
+        self.tokens.clear()
+        self.n_tokens = 0
+        self.n_past = 0
+        self.all_logits.clear()
 
     def eval(self, tokens: Sequence[llama_cpp.llama_token]):
         """Evaluate a list of tokens.
@@ -147,18 +181,32 @@ class Llama:
         n_ctx = int(llama_cpp.llama_n_ctx(self.ctx))
         for i in range(0, len(tokens), self.n_batch):
             batch = tokens[i : min(len(tokens), i + self.n_batch)]
-            n_past = min(n_ctx - len(batch), self.tokens_consumed)
+            self.n_past = min(n_ctx - len(batch), self.tokens_consumed)
+            self.n_tokens = len(batch)
             return_code = llama_cpp.llama_eval(
                 ctx=self.ctx,
                 tokens=(llama_cpp.llama_token * len(batch))(*batch),
-                n_tokens=llama_cpp.c_int(len(batch)),
-                n_past=llama_cpp.c_int(n_past),
+                n_tokens=llama_cpp.c_int(self.n_tokens),
+                n_past=llama_cpp.c_int(self.n_past),
                 n_threads=llama_cpp.c_int(self.n_threads),
             )
             if int(return_code) != 0:
                 raise RuntimeError(f"llama_eval returned {return_code}")
+            self.tokens.extend(batch)
             self.last_n_tokens_data.extend(batch)
             self.tokens_consumed += len(batch)
+            if self.params.logits_all:
+                self.all_logits.extend(self._logits())
+
+    def _logits(self) -> List[List[float]]:
+        """Return the logits from the last call to llama_eval."""
+        assert self.ctx is not None
+        n_vocab = llama_cpp.llama_n_vocab(self.ctx)
+        cols = int(n_vocab)
+        rows = self.n_tokens if self.params.logits_all else 1
+        logits_view = llama_cpp.llama_get_logits(self.ctx)
+        logits = [[logits_view[i * cols + j] for j in range(cols)] for i in range(rows)]
+        return logits
 
     def sample(
         self,
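`_logits()` reads a flat, row-major float buffer returned by `llama_get_logits`; a small standalone sketch of that indexing, with plain Python lists standing in for the ctypes view:

```python
def unflatten_logits(flat, rows, cols):
    """Split a row-major buffer of rows * cols floats into per-token rows."""
    return [[flat[i * cols + j] for j in range(cols)] for i in range(rows)]

# With logits_all=True there is one row of n_vocab scores per evaluated token;
# otherwise only a single row (the scores for the next token) is available.
flat = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
assert unflatten_logits(flat, rows=2, cols=3) == [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
```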
@@ -198,6 +246,7 @@ class Llama:
         top_p: float,
         temp: float,
         repeat_penalty: float,
+        reset: bool = True,
     ) -> Generator[
         llama_cpp.llama_token, Optional[Sequence[llama_cpp.llama_token]], None
     ]:
@@ -215,12 +264,26 @@ class Llama:
             top_p: The top-p sampling parameter.
             temp: The temperature parameter.
             repeat_penalty: The repeat penalty parameter.
+            reset: Whether to reset the model state.
 
         Yields:
             The generated tokens.
         """
         assert self.ctx is not None
-        self.reset()
+        ### HACK
+        if (
+            reset
+            and self._cache
+            and len(self.tokens) > 0
+            and self.tokens == tokens[: len(self.tokens)]
+        ):
+            if self.verbose:
+                print("generate cache hit", file=sys.stderr)
+            reset = False
+            tokens = tokens[len(self.tokens) :]
+        ###
+        if reset:
+            self.reset()
         while True:
             self.eval(tokens)
             token = self.sample(
@@ -300,19 +363,22 @@ class Llama:
         top_p: float = 0.95,
         logprobs: Optional[int] = None,
         echo: bool = False,
-        stop: List[str] = [],
+        stop: Optional[List[str]] = [],
         repeat_penalty: float = 1.1,
         top_k: int = 40,
         stream: bool = False,
-    ) -> Union[Iterator[Completion], Iterator[CompletionChunk],]:
+    ) -> Union[Iterator[Completion], Iterator[CompletionChunk]]:
         assert self.ctx is not None
-        completion_id = f"cmpl-{str(uuid.uuid4())}"
-        created = int(time.time())
+        completion_id: str = f"cmpl-{str(uuid.uuid4())}"
+        created: int = int(time.time())
         completion_tokens: List[llama_cpp.llama_token] = []
         # Add blank space to start of prompt to match OG llama tokenizer
-        prompt_tokens = self.tokenize(b" " + prompt.encode("utf-8"))
-        text = b""
-        returned_characters = 0
+        prompt_tokens: List[llama_cpp.llama_token] = self.tokenize(
+            b" " + prompt.encode("utf-8")
+        )
+        text: bytes = b""
+        returned_characters: int = 0
+        stop = stop if stop is not None else []
 
         if self.verbose:
             llama_cpp.llama_reset_timings(self.ctx)
@@ -327,13 +393,34 @@ class Llama:
         else:
             stop_sequences = []
 
-        finish_reason = None
+        if logprobs is not None and self.params.logits_all is False:
+            raise ValueError(
+                "logprobs is not supported for models created with logits_all=False"
+            )
+
+        ### HACK
+        reset: bool = True
+        _prompt: bytes = prompt.encode("utf-8")
+        _completion: bytes = b"".join(self._completion_bytes)
+        if len(_completion) and self._cache and _prompt.startswith(_completion):
+            if self.verbose:
+                print("completion cache hit", file=sys.stderr)
+            reset = False
+            _prompt = _prompt[len(_completion) :]
+            prompt_tokens = self.tokenize(b" " + _prompt)
+            self._completion_bytes.append(_prompt)
+        else:
+            self._completion_bytes = [prompt.encode("utf-8")]
+        ###
+
+        finish_reason = "length"
         for token in self.generate(
             prompt_tokens,
             top_k=top_k,
             top_p=top_p,
             temp=temperature,
             repeat_penalty=repeat_penalty,
+            reset=reset,
         ):
             if token == llama_cpp.llama_token_eos():
                 text = self.detokenize(completion_tokens)
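A hypothetical end-to-end use of the cache path added here: with a `LlamaCache` set, a second call whose prompt starts with the previous prompt plus its completion skips re-evaluating that shared prefix (paths and prompts are made up):

```python
from llama_cpp import Llama, LlamaCache

llm = Llama(model_path="./models/ggml-model-q4_0.bin", verbose=True)  # illustrative path
llm.set_cache(LlamaCache())

prompt = "Q: Name the planets in the solar system. A:"
first = llm(prompt, max_tokens=48)
answer = first["choices"][0]["text"]

# Re-sending the whole exchange plus a follow-up should print "completion cache hit"
# and only evaluate the new tail of the prompt.
second = llm(prompt + answer + " Q: Which of them is largest? A:", max_tokens=48)
```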
@@ -363,6 +450,9 @@ class Llama:
                            break
                 text = all_text[: len(all_text) - longest]
                 returned_characters += len(text[start:])
+                ### HACK
+                self._completion_bytes.append(text[start:])
+                ###
                 yield {
                     "id": completion_id,
                     "object": "text_completion",
@ -377,15 +467,16 @@ class Llama:
} }
], ],
} }
if len(completion_tokens) >= max_tokens: if len(completion_tokens) >= max_tokens:
text = self.detokenize(completion_tokens) text = self.detokenize(completion_tokens)
finish_reason = "length" finish_reason = "length"
break break
if finish_reason is None:
finish_reason = "length"
if stream: if stream:
### HACK
self._completion_bytes.append(text[returned_characters:])
###
yield { yield {
"id": completion_id, "id": completion_id,
"object": "text_completion", "object": "text_completion",
@@ -402,16 +493,57 @@ class Llama:
             }
             return
 
-        text = text.decode("utf-8")
+        ### HACK
+        self._completion_bytes.append(text)
+        ###
+        text_str = text.decode("utf-8")
 
         if echo:
-            text = prompt + text
+            text_str = prompt + text_str
 
         if suffix is not None:
-            text = text + suffix
+            text_str = text_str + suffix
 
+        logprobs_or_none: Optional[CompletionLogprobs] = None
         if logprobs is not None:
-            raise NotImplementedError("logprobs not implemented")
+            text_offset = 0
+            text_offsets: List[int] = []
+            token_logprobs: List[float] = []
+            tokens: List[str] = []
+            top_logprobs: List[Dict[str, float]] = []
+
+            all_tokens = prompt_tokens + completion_tokens
+            all_token_strs = [
+                self.detokenize([token]).decode("utf-8") for token in all_tokens
+            ]
+            all_logprobs = [
+                [Llama.logit_to_logprob(logit) for logit in row]
+                for row in self.all_logits
+            ]
+            for token, token_str, logprobs_token in zip(
+                all_tokens, all_token_strs, all_logprobs
+            ):
+                text_offsets.append(text_offset)
+                text_offset += len(token_str)
+                tokens.append(token_str)
+                sorted_logprobs = list(
+                    sorted(
+                        zip(logprobs_token, range(len(logprobs_token))), reverse=True
+                    )
+                )
+                token_logprobs.append(sorted_logprobs[int(token)][0])
+                top_logprob = {
+                    self.detokenize([llama_cpp.llama_token(i)]).decode("utf-8"): logprob
+                    for logprob, i in sorted_logprobs[:logprobs]
+                }
+                top_logprob.update({token_str: sorted_logprobs[int(token)][0]})
+                top_logprobs.append(top_logprob)
+
+            logprobs_or_none = {
+                "tokens": tokens,
+                "text_offset": text_offsets,
+                "token_logprobs": token_logprobs,
+                "top_logprobs": top_logprobs,
+            }
+
         if self.verbose:
             llama_cpp.llama_print_timings(self.ctx)
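A hedged sketch of requesting the newly supported `logprobs` output; per the check added above, the model must be constructed with `logits_all=True` (path and prompt are illustrative, and the field names follow the dictionary built in this hunk):

```python
from llama_cpp import Llama

llm = Llama(model_path="./models/ggml-model-q4_0.bin", logits_all=True)  # illustrative path
out = llm("The capital of France is", max_tokens=4, logprobs=5)

lp = out["choices"][0]["logprobs"]
print(lp["tokens"])           # decoded token strings (prompt + completion)
print(lp["text_offset"])      # character offset of each token in the text
print(lp["top_logprobs"][0])  # top-5 alternatives for the first token
```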
@@ -423,9 +555,9 @@ class Llama:
             "model": self.model_path,
             "choices": [
                 {
-                    "text": text,
+                    "text": text_str,
                     "index": 0,
-                    "logprobs": None,
+                    "logprobs": logprobs_or_none,
                     "finish_reason": finish_reason,
                 }
             ],
@@ -445,7 +577,7 @@ class Llama:
         top_p: float = 0.95,
         logprobs: Optional[int] = None,
         echo: bool = False,
-        stop: List[str] = [],
+        stop: Optional[List[str]] = [],
         repeat_penalty: float = 1.1,
         top_k: int = 40,
         stream: bool = False,
@@ -500,7 +632,7 @@ class Llama:
         top_p: float = 0.95,
         logprobs: Optional[int] = None,
         echo: bool = False,
-        stop: List[str] = [],
+        stop: Optional[List[str]] = [],
         repeat_penalty: float = 1.1,
         top_k: int = 40,
         stream: bool = False,
@@ -602,12 +734,12 @@ class Llama:
     def create_chat_completion(
         self,
         messages: List[ChatCompletionMessage],
-        temperature: float = 0.8,
+        temperature: float = 0.2,
         top_p: float = 0.95,
         top_k: int = 40,
         stream: bool = False,
-        stop: List[str] = [],
-        max_tokens: int = 128,
+        stop: Optional[List[str]] = [],
+        max_tokens: int = 256,
         repeat_penalty: float = 1.1,
     ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
         """Generate a chat completion from a list of messages.
@@ -625,13 +757,13 @@ class Llama:
         Returns:
             Generated chat completion or a stream of chat completion chunks.
         """
-        instructions = """Complete the following chat conversation between the user and the assistant. System messages should be strictly followed as additional instructions."""
-        chat_history = "\n".join(
-            f'{message["role"]} {message.get("user", "")}: {message["content"]}'
+        stop = stop if stop is not None else []
+        chat_history = "".join(
+            f'### {"Human" if message["role"] == "user" else "Assistant"}:{message["content"]}'
             for message in messages
         )
-        PROMPT = f" \n\n### Instructions:{instructions}\n\n### Inputs:{chat_history}\n\n### Response:\nassistant: "
-        PROMPT_STOP = ["###", "\nuser: ", "\nassistant: ", "\nsystem: "]
+        PROMPT = chat_history + "### Assistant:"
+        PROMPT_STOP = ["### Assistant:", "### Human:"]
         completion_or_chunks = self(
             prompt=PROMPT,
             stop=PROMPT_STOP + stop,
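For reference, the new prompt format renders a message list like this (the messages are invented; the formatting expression mirrors the hunk above):

```python
messages = [
    {"role": "user", "content": " Hello, who are you?"},
    {"role": "assistant", "content": " I am an assistant."},
    {"role": "user", "content": " What can you do?"},
]
chat_history = "".join(
    f'### {"Human" if m["role"] == "user" else "Assistant"}:{m["content"]}'
    for m in messages
)
prompt = chat_history + "### Assistant:"
# '### Human: Hello, who are you?### Assistant: I am an assistant.### Human: What can you do?### Assistant:'
```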
@@ -668,8 +800,6 @@ class Llama:
             use_mlock=self.params.use_mlock,
             embedding=self.params.embedding,
             last_n_tokens_size=self.last_n_tokens_size,
-            last_n_tokens_data=self.last_n_tokens_data,
-            tokens_consumed=self.tokens_consumed,
             n_batch=self.n_batch,
             n_threads=self.n_threads,
         )
@@ -691,9 +821,6 @@ class Llama:
             last_n_tokens_size=state["last_n_tokens_size"],
             verbose=state["verbose"],
         )
-        self.last_n_tokens_data = state["last_n_tokens_data"]
-        self.tokens_consumed = state["tokens_consumed"]
-
 
     @staticmethod
     def token_eos() -> llama_cpp.llama_token:
@@ -704,3 +831,7 @@ class Llama:
     def token_bos() -> llama_cpp.llama_token:
         """Return the beginning-of-sequence token."""
         return llama_cpp.llama_token_bos()
+
+    @staticmethod
+    def logit_to_logprob(x: float) -> float:
+        return math.log(1.0 + math.exp(x))


@@ -1,9 +1,21 @@
 import sys
 import os
 import ctypes
-from ctypes import c_int, c_float, c_char_p, c_void_p, c_bool, POINTER, Structure, Array, c_uint8, c_size_t
+from ctypes import (
+    c_int,
+    c_float,
+    c_char_p,
+    c_void_p,
+    c_bool,
+    POINTER,
+    Structure,
+    Array,
+    c_uint8,
+    c_size_t,
+)
 import pathlib
 
 # Load the library
 def _load_shared_library(lib_base_name):
     # Determine the file extension based on the platform
@@ -22,9 +34,15 @@ def _load_shared_library(lib_base_name):
     # for llamacpp) and "llama" (default name for this repo)
     _lib_paths = [
         _base_path / f"lib{lib_base_name}{lib_ext}",
-        _base_path / f"{lib_base_name}{lib_ext}"
+        _base_path / f"{lib_base_name}{lib_ext}",
     ]
 
+    if "LLAMA_CPP_LIB" in os.environ:
+        lib_base_name = os.environ["LLAMA_CPP_LIB"]
+        _lib = pathlib.Path(lib_base_name)
+        _base_path = _lib.parent.resolve()
+        _lib_paths = [_lib.resolve()]
+
     # Add the library directory to the DLL search path on Windows (if needed)
     if sys.platform == "win32" and sys.version_info >= (3, 8):
         os.add_dll_directory(str(_base_path))
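A hypothetical use of the new `LLAMA_CPP_LIB` override: point it at a specific shared library before the bindings are imported (the path is illustrative):

```python
import os

# Must be set before importing llama_cpp, since the library is loaded at import time.
os.environ["LLAMA_CPP_LIB"] = "/opt/llama.cpp/libllama.so"  # illustrative path

import llama_cpp  # _load_shared_library() now tries exactly this file
```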
@@ -37,7 +55,10 @@ def _load_shared_library(lib_base_name):
         except Exception as e:
             raise RuntimeError(f"Failed to load shared library '{_lib_path}': {e}")
 
-    raise FileNotFoundError(f"Shared library with base name '{lib_base_name}' not found")
+    raise FileNotFoundError(
+        f"Shared library with base name '{lib_base_name}' not found"
+    )
 
 # Specify the base name of the shared library to load
 _lib_base_name = "llama"
@@ -89,6 +110,11 @@ class llama_context_params(Structure):
 
 llama_context_params_p = POINTER(llama_context_params)
 
+LLAMA_FTYPE_ALL_F32 = ctypes.c_int(0)
+LLAMA_FTYPE_MOSTLY_F16 = ctypes.c_int(1)  # except 1d tensors
+LLAMA_FTYPE_MOSTLY_Q4_0 = ctypes.c_int(2)  # except 1d tensors
+LLAMA_FTYPE_MOSTLY_Q4_1 = ctypes.c_int(3)  # except 1d tensors
+LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = ctypes.c_int(4)  # tok_embeddings.weight and output.weight are F16
 
 # Functions
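A hedged sketch of where these constants fit: the `itype` argument of `llama_model_quantize` (declared further down in this file) takes an integer file-type value, and assuming it follows the same numbering as these constants, quantizing to Q4_0 would look roughly like this (paths are illustrative):

```python
from llama_cpp import llama_cpp

ret = llama_cpp.llama_model_quantize(
    b"./models/ggml-model-f16.bin",     # illustrative input path
    b"./models/ggml-model-q4_0.bin",    # illustrative output path
    llama_cpp.LLAMA_FTYPE_MOSTLY_Q4_0,  # assumed to match the expected itype value
)
assert int(ret) == 0  # the binding documents 0 as success
```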
@@ -100,18 +126,23 @@ def llama_context_default_params() -> llama_context_params:
 _lib.llama_context_default_params.argtypes = []
 _lib.llama_context_default_params.restype = llama_context_params
 
+
 def llama_mmap_supported() -> c_bool:
     return _lib.llama_mmap_supported()
 
+
 _lib.llama_mmap_supported.argtypes = []
 _lib.llama_mmap_supported.restype = c_bool
 
+
 def llama_mlock_supported() -> c_bool:
     return _lib.llama_mlock_supported()
 
+
 _lib.llama_mlock_supported.argtypes = []
 _lib.llama_mlock_supported.restype = c_bool
 
+
 # Various functions for loading a ggml llama model.
 # Allocate (almost) all memory needed for the model.
 # Return NULL on failure
@@ -136,42 +167,49 @@ _lib.llama_free.restype = None
 # TODO: not great API - very likely to change
 # Returns 0 on success
-def llama_model_quantize(
-    fname_inp: bytes, fname_out: bytes, itype: c_int
-) -> c_int:
+def llama_model_quantize(fname_inp: bytes, fname_out: bytes, itype: c_int) -> c_int:
     return _lib.llama_model_quantize(fname_inp, fname_out, itype)
 
+
 _lib.llama_model_quantize.argtypes = [c_char_p, c_char_p, c_int]
 _lib.llama_model_quantize.restype = c_int
 
+
 # Returns the KV cache that will contain the context for the
 # ongoing prediction with the model.
 def llama_get_kv_cache(ctx: llama_context_p):
     return _lib.llama_get_kv_cache(ctx)
 
+
 _lib.llama_get_kv_cache.argtypes = [llama_context_p]
 _lib.llama_get_kv_cache.restype = POINTER(c_uint8)
 
+
 # Returns the size of the KV cache
 def llama_get_kv_cache_size(ctx: llama_context_p) -> c_size_t:
     return _lib.llama_get_kv_cache_size(ctx)
 
+
 _lib.llama_get_kv_cache_size.argtypes = [llama_context_p]
 _lib.llama_get_kv_cache_size.restype = c_size_t
 
+
 # Returns the number of tokens in the KV cache
 def llama_get_kv_cache_token_count(ctx: llama_context_p) -> c_int:
     return _lib.llama_get_kv_cache_token_count(ctx)
 
+
 _lib.llama_get_kv_cache_token_count.argtypes = [llama_context_p]
 _lib.llama_get_kv_cache_token_count.restype = c_int
 
+
 # Sets the KV cache containing the current context for the model
-def llama_set_kv_cache(ctx: llama_context_p, kv_cache, n_size: c_size_t, n_token_count: c_int):
+def llama_set_kv_cache(
+    ctx: llama_context_p, kv_cache, n_size: c_size_t, n_token_count: c_int
+):
     return _lib.llama_set_kv_cache(ctx, kv_cache, n_size, n_token_count)
 
+
 _lib.llama_set_kv_cache.argtypes = [llama_context_p, POINTER(c_uint8), c_size_t, c_int]
 _lib.llama_set_kv_cache.restype = None
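These KV-cache bindings are not used by the high-level wrapper in this commit (hence the `_completion_bytes` hack in llama.py above), but a hedged sketch of copying the cache in and out with them could look like this (`ctx` is an already-initialized `llama_context_p`; error handling omitted):

```python
import ctypes
from llama_cpp import llama_cpp

def snapshot_kv(ctx):
    size = int(llama_cpp.llama_get_kv_cache_size(ctx))
    n_tokens = int(llama_cpp.llama_get_kv_cache_token_count(ctx))
    src = llama_cpp.llama_get_kv_cache(ctx)   # POINTER(c_uint8) into the context
    buf = (ctypes.c_uint8 * size)()
    ctypes.memmove(buf, src, size)            # copy the cache out of the context
    return buf, size, n_tokens

def restore_kv(ctx, snapshot):
    buf, size, n_tokens = snapshot
    llama_cpp.llama_set_kv_cache(ctx, buf, size, n_tokens)
```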


@@ -13,12 +13,13 @@ Then visit http://localhost:8000/docs to see the interactive API docs.
 """
 import os
 import json
+from threading import Lock
 from typing import List, Optional, Literal, Union, Iterator, Dict
 from typing_extensions import TypedDict
 
 import llama_cpp
 
-from fastapi import FastAPI
+from fastapi import Depends, FastAPI
 from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel, BaseSettings, Field, create_model_from_typeddict
 from sse_starlette.sse import EventSourceResponse
@@ -33,6 +34,8 @@ class Settings(BaseSettings):
     use_mlock: bool = False  # This causes a silent failure on platforms that don't support mlock (e.g. Windows) took forever to figure out...
     embedding: bool = True
     last_n_tokens_size: int = 64
+    logits_all: bool = False
+    cache: bool = False  # WARNING: This is an experimental feature
 
 
 app = FastAPI(
@@ -52,11 +55,21 @@ llama = llama_cpp.Llama(
     f16_kv=settings.f16_kv,
     use_mlock=settings.use_mlock,
     embedding=settings.embedding,
+    logits_all=settings.logits_all,
     n_threads=settings.n_threads,
     n_batch=settings.n_batch,
     n_ctx=settings.n_ctx,
     last_n_tokens_size=settings.last_n_tokens_size,
 )
+if settings.cache:
+    cache = llama_cpp.LlamaCache()
+    llama.set_cache(cache)
+llama_lock = Lock()
+
+
+def get_llama():
+    with llama_lock:
+        yield llama
+
 
 class CreateCompletionRequest(BaseModel):
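The `get_llama` dependency added above serialises access to the single shared model: the `Lock` is held for the duration of each request that depends on it, so concurrent requests are handled one at a time. A minimal standalone sketch of the same pattern (the `/ping` route and `resource` object are made up):

```python
from threading import Lock
from fastapi import Depends, FastAPI

app = FastAPI()
resource = {"name": "shared model"}  # stands in for the llama_cpp.Llama instance
resource_lock = Lock()

def get_resource():
    with resource_lock:   # acquired before the handler runs...
        yield resource    # ...and released after it finishes

@app.get("/ping")
def ping(model: dict = Depends(get_resource)):
    return {"ok": model["name"]}
```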
@@ -66,7 +79,7 @@ class CreateCompletionRequest(BaseModel):
     temperature: float = 0.8
     top_p: float = 0.95
     echo: bool = False
-    stop: List[str] = []
+    stop: Optional[List[str]] = []
     stream: bool = False
 
     # ignored or currently unsupported
@@ -99,7 +112,9 @@ CreateCompletionResponse = create_model_from_typeddict(llama_cpp.Completion)
     "/v1/completions",
     response_model=CreateCompletionResponse,
 )
-def create_completion(request: CreateCompletionRequest):
+def create_completion(
+    request: CreateCompletionRequest, llama: llama_cpp.Llama = Depends(get_llama)
+):
     if isinstance(request.prompt, list):
         request.prompt = "".join(request.prompt)
@@ -108,7 +123,6 @@ def create_completion(request: CreateCompletionRequest):
         exclude={
             "model",
             "n",
-            "logprobs",
             "frequency_penalty",
             "presence_penalty",
             "best_of",
@@ -144,7 +158,9 @@ CreateEmbeddingResponse = create_model_from_typeddict(llama_cpp.Embedding)
     "/v1/embeddings",
     response_model=CreateEmbeddingResponse,
 )
-def create_embedding(request: CreateEmbeddingRequest):
+def create_embedding(
+    request: CreateEmbeddingRequest, llama: llama_cpp.Llama = Depends(get_llama)
+):
     return llama.create_embedding(**request.dict(exclude={"model", "user"}))
@@ -160,7 +176,7 @@ class CreateChatCompletionRequest(BaseModel):
     temperature: float = 0.8
     top_p: float = 0.95
     stream: bool = False
-    stop: List[str] = []
+    stop: Optional[List[str]] = []
     max_tokens: int = 128
 
     # ignored or currently unsupported
@@ -196,8 +212,9 @@ CreateChatCompletionResponse = create_model_from_typeddict(llama_cpp.ChatComplet
     "/v1/chat/completions",
     response_model=CreateChatCompletionResponse,
 )
-async def create_chat_completion(
+def create_chat_completion(
     request: CreateChatCompletionRequest,
+    llama: llama_cpp.Llama = Depends(get_llama),
 ) -> Union[llama_cpp.ChatCompletion, EventSourceResponse]:
     completion_or_chunks = llama.create_chat_completion(
         **request.dict(


@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "llama_cpp_python"
-version = "0.1.30"
+version = "0.1.34"
 description = "Python bindings for the llama.cpp library"
 authors = ["Andrei Betlen <abetlen@gmail.com>"]
 license = "MIT"


@@ -3,14 +3,14 @@ from skbuild import setup
 from pathlib import Path
 
 this_directory = Path(__file__).parent
-long_description = (this_directory / "README.md").read_text()
+long_description = (this_directory / "README.md").read_text(encoding="utf-8")
 
 setup(
     name="llama_cpp_python",
     description="A Python wrapper for llama.cpp",
     long_description=long_description,
     long_description_content_type="text/markdown",
-    version="0.1.30",
+    version="0.1.34",
     author="Andrei Betlen",
     author_email="abetlen@gmail.com",
     license="MIT",

vendor/llama.cpp (submodule)

@@ -1 +1 @@
-Subproject commit 180b693a47b6b825288ef9f2c39d24b6eea4eea6
+Subproject commit e95b6554b493e71a0275764342e09bd5784a7026