Merge branch 'main' of github.com:abetlen/llama-cpp-python

Author: Niek van der Maas
Date:   2023-04-15 20:24:59 +02:00
Commit: 6df27b2da0

6 changed files with 194 additions and 46 deletions

View file

@@ -2,6 +2,7 @@ import os
 import sys
 import uuid
 import time
+import math
 import multiprocessing
 from typing import List, Optional, Union, Generator, Sequence, Iterator
 from collections import deque
@@ -10,6 +11,15 @@ from . import llama_cpp
 from .llama_types import *


+class LlamaCache:
+    """Cache for a llama.cpp model.
+
+    NOTE: This implementation currently only tells the Llama class to avoid reprocessing bytes and continue from the last
+    completion. It does not actually cache the results."""
+
+    pass
+
+
 class Llama:
     """High-level Python wrapper for a llama.cpp model."""
@@ -20,7 +30,7 @@ class Llama:
         n_ctx: int = 512,
         n_parts: int = -1,
         seed: int = 1337,
-        f16_kv: bool = False,
+        f16_kv: bool = True,
         logits_all: bool = False,
         vocab_only: bool = False,
         use_mmap: bool = True,
@@ -75,7 +85,19 @@ class Llama:
             maxlen=self.last_n_tokens_size,
         )
         self.tokens_consumed = 0
+        self.tokens: List[llama_cpp.llama_token] = []
         self.n_batch = min(n_ctx, n_batch)
+        self.n_tokens = 0
+        self.n_past = 0
+        self.all_logits: List[List[float]] = []  # TODO: Use an array instead of a list.
+
+        ### HACK: This is a hack to work around the fact that the llama.cpp API does not yet support
+        ### saving and restoring state, this allows us to continue a completion if the last
+        ### completion_bytes is a prefix to the prompt passed in. However this is actually incorrect
+        ### because it does not take into account stop tokens which have been processed by the model.
+        self._completion_bytes: List[bytes] = []
+        self._cache: Optional[LlamaCache] = None
+        ###

         self.n_threads = n_threads or max(multiprocessing.cpu_count() // 2, 1)
@@ -130,12 +152,24 @@ class Llama:
             output += llama_cpp.llama_token_to_str(self.ctx, token)
         return output

+    def set_cache(self, cache: Optional[LlamaCache]):
+        """Set the cache.
+
+        Args:
+            cache: The cache to set.
+        """
+        self._cache = cache
+
     def reset(self):
         """Reset the model state."""
         self.last_n_tokens_data.extend(
             [llama_cpp.llama_token(0)] * self.last_n_tokens_size
         )
         self.tokens_consumed = 0
+        self.tokens.clear()
+        self.n_tokens = 0
+        self.n_past = 0
+        self.all_logits.clear()

     def eval(self, tokens: Sequence[llama_cpp.llama_token]):
         """Evaluate a list of tokens.
@@ -147,18 +181,32 @@ class Llama:
         n_ctx = int(llama_cpp.llama_n_ctx(self.ctx))
         for i in range(0, len(tokens), self.n_batch):
             batch = tokens[i : min(len(tokens), i + self.n_batch)]
-            n_past = min(n_ctx - len(batch), self.tokens_consumed)
+            self.n_past = min(n_ctx - len(batch), self.tokens_consumed)
+            self.n_tokens = len(batch)
             return_code = llama_cpp.llama_eval(
                 ctx=self.ctx,
                 tokens=(llama_cpp.llama_token * len(batch))(*batch),
-                n_tokens=llama_cpp.c_int(len(batch)),
-                n_past=llama_cpp.c_int(n_past),
+                n_tokens=llama_cpp.c_int(self.n_tokens),
+                n_past=llama_cpp.c_int(self.n_past),
                 n_threads=llama_cpp.c_int(self.n_threads),
             )
             if int(return_code) != 0:
                 raise RuntimeError(f"llama_eval returned {return_code}")
+            self.tokens.extend(batch)
             self.last_n_tokens_data.extend(batch)
             self.tokens_consumed += len(batch)
+            if self.params.logits_all:
+                self.all_logits.extend(self._logits())
+
+    def _logits(self) -> List[List[float]]:
+        """Return the logits from the last call to llama_eval."""
+        assert self.ctx is not None
+        n_vocab = llama_cpp.llama_n_vocab(self.ctx)
+        cols = int(n_vocab)
+        rows = self.n_tokens if self.params.logits_all else 1
+        logits_view = llama_cpp.llama_get_logits(self.ctx)
+        logits = [[logits_view[i * cols + j] for j in range(cols)] for i in range(rows)]
+        return logits

     def sample(
         self,
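The new _logits helper reads llama.cpp's flat logits buffer as a rows x n_vocab matrix, where rows equals the batch size when logits_all is set and 1 otherwise. A standalone sketch of the same indexing, using a hypothetical split_logits helper and plain lists in place of the ctypes view:

# Illustrative only: mirrors the row/column indexing used by _logits.
def split_logits(flat, n_tokens, n_vocab, logits_all):
    rows = n_tokens if logits_all else 1
    return [[flat[i * n_vocab + j] for j in range(n_vocab)] for i in range(rows)]

flat = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]  # pretend buffer: 2 tokens, vocabulary of 3
assert split_logits(flat, n_tokens=2, n_vocab=3, logits_all=True) == [
    [0.1, 0.2, 0.3],
    [0.4, 0.5, 0.6],
]
# With logits_all off, llama.cpp exposes only the last token's row, so a single
# row is read from the start of the buffer.
assert split_logits(flat, n_tokens=2, n_vocab=3, logits_all=False) == [[0.1, 0.2, 0.3]]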
@@ -198,6 +246,7 @@ class Llama:
         top_p: float,
         temp: float,
         repeat_penalty: float,
+        reset: bool = True,
     ) -> Generator[
         llama_cpp.llama_token, Optional[Sequence[llama_cpp.llama_token]], None
     ]:
@@ -215,12 +264,25 @@ class Llama:
             top_p: The top-p sampling parameter.
             temp: The temperature parameter.
             repeat_penalty: The repeat penalty parameter.
+            reset: Whether to reset the model state.

         Yields:
             The generated tokens.
         """
         assert self.ctx is not None
-        self.reset()
+        ### HACK
+        if (
+            reset
+            and self._cache
+            and len(self.tokens) > 0
+            and self.tokens == tokens[: len(self.tokens)]
+        ):
+            if self.verbose:
+                print("generate cache hit", file=sys.stderr)
+            reset = False
+        ###
+        if reset:
+            self.reset()
         while True:
             self.eval(tokens)
             token = self.sample(
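The cache-hit branch above skips reset() whenever the tokens already evaluated form a prefix of the newly requested token sequence, so only the unseen tail has to go through eval(). A small sketch of that prefix test in isolation (can_reuse_state is a hypothetical helper, not part of the library):

# Sketch of the check generate() uses to decide whether it may keep its state.
def can_reuse_state(evaluated, requested):
    return len(evaluated) > 0 and evaluated == requested[: len(evaluated)]

assert can_reuse_state([1, 2, 3], [1, 2, 3, 4, 5])      # continue from token 4
assert not can_reuse_state([1, 2, 9], [1, 2, 3, 4, 5])  # diverged: full reset needed
assert not can_reuse_state([], [1, 2, 3])               # nothing evaluated yet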
@@ -300,19 +362,22 @@ class Llama:
         top_p: float = 0.95,
         logprobs: Optional[int] = None,
         echo: bool = False,
-        stop: List[str] = [],
+        stop: Optional[List[str]] = [],
         repeat_penalty: float = 1.1,
         top_k: int = 40,
         stream: bool = False,
-    ) -> Union[Iterator[Completion], Iterator[CompletionChunk],]:
+    ) -> Union[Iterator[Completion], Iterator[CompletionChunk]]:
         assert self.ctx is not None
-        completion_id = f"cmpl-{str(uuid.uuid4())}"
-        created = int(time.time())
+        completion_id: str = f"cmpl-{str(uuid.uuid4())}"
+        created: int = int(time.time())
         completion_tokens: List[llama_cpp.llama_token] = []
         # Add blank space to start of prompt to match OG llama tokenizer
-        prompt_tokens = self.tokenize(b" " + prompt.encode("utf-8"))
-        text = b""
-        returned_characters = 0
+        prompt_tokens: List[llama_cpp.llama_token] = self.tokenize(
+            b" " + prompt.encode("utf-8")
+        )
+        text: bytes = b""
+        returned_characters: int = 0
+        stop = stop if stop is not None else []

         if self.verbose:
             llama_cpp.llama_reset_timings(self.ctx)
@@ -327,13 +392,34 @@ class Llama:
         else:
             stop_sequences = []

-        finish_reason = None
+        if logprobs is not None and self.params.logits_all is False:
+            raise ValueError(
+                "logprobs is not supported for models created with logits_all=False"
+            )
+
+        ### HACK
+        reset: bool = True
+        _prompt: bytes = prompt.encode("utf-8")
+        _completion: bytes = b"".join(self._completion_bytes)
+        if len(_completion) and self._cache and _prompt.startswith(_completion):
+            if self.verbose:
+                print("completion cache hit", file=sys.stderr)
+            reset = False
+            _prompt = _prompt[len(_completion) :]
+            prompt_tokens = self.tokenize(b" " + _prompt)
+            self._completion_bytes.append(_prompt)
+        else:
+            self._completion_bytes = [prompt.encode("utf-8")]
+        ###
+
+        finish_reason = "length"
         for token in self.generate(
             prompt_tokens,
             top_k=top_k,
             top_p=top_p,
             temp=temperature,
             repeat_penalty=repeat_penalty,
+            reset=reset,
         ):
             if token == llama_cpp.llama_token_eos():
                 text = self.detokenize(completion_tokens)
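_create_completion applies the same idea at the byte level: when the previously cached prompt-plus-completion bytes are a prefix of the new prompt, only the unseen suffix is re-tokenized and evaluated. A hedged sketch of that trimming step (remaining_prompt is a hypothetical helper):

# Sketch: trim the part of the prompt the model has already processed.
def remaining_prompt(completion_bytes, prompt):
    cached = b"".join(completion_bytes)
    new = prompt.encode("utf-8")
    if cached and new.startswith(cached):
        return new[len(cached):]   # only this tail needs tokenizing and eval
    return new                     # no overlap: start from scratch

assert remaining_prompt([b"Hello", b" world"], "Hello world, again") == b", again"
assert remaining_prompt([b"Hello"], "Goodbye") == b"Goodbye"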
@@ -363,6 +449,9 @@ class Llama:
                             break
                 text = all_text[: len(all_text) - longest]
                 returned_characters += len(text[start:])
+                ### HACK
+                self._completion_bytes.append(text[start:])
+                ###
                 yield {
                     "id": completion_id,
                     "object": "text_completion",
@@ -377,15 +466,16 @@
                         }
                     ],
                 }

             if len(completion_tokens) >= max_tokens:
                 text = self.detokenize(completion_tokens)
                 finish_reason = "length"
                 break

-        if finish_reason is None:
-            finish_reason = "length"

         if stream:
+            ### HACK
+            self._completion_bytes.append(text[returned_characters:])
+            ###
             yield {
                 "id": completion_id,
                 "object": "text_completion",
@@ -402,16 +492,57 @@
             }
             return

-        text = text.decode("utf-8")
+        ### HACK
+        self._completion_bytes.append(text)
+        ###
+        text_str = text.decode("utf-8")

         if echo:
-            text = prompt + text
+            text_str = prompt + text_str

         if suffix is not None:
-            text = text + suffix
+            text_str = text_str + suffix

+        logprobs_or_none: Optional[CompletionLogprobs] = None
         if logprobs is not None:
-            raise NotImplementedError("logprobs not implemented")
+            text_offset = 0
+            text_offsets: List[int] = []
+            token_logprobs: List[float] = []
+            tokens: List[str] = []
+            top_logprobs: List[Dict[str, float]] = []
+
+            all_tokens = prompt_tokens + completion_tokens
+            all_token_strs = [
+                self.detokenize([token]).decode("utf-8") for token in all_tokens
+            ]
+            all_logprobs = [
+                [Llama.logit_to_logprob(logit) for logit in row]
+                for row in self.all_logits
+            ]
+            for token, token_str, logprobs_token in zip(
+                all_tokens, all_token_strs, all_logprobs
+            ):
+                text_offsets.append(text_offset)
+                text_offset += len(token_str)
+                tokens.append(token_str)
+                sorted_logprobs = list(
+                    sorted(
+                        zip(logprobs_token, range(len(logprobs_token))), reverse=True
+                    )
+                )
+                token_logprobs.append(sorted_logprobs[int(token)][0])
+                top_logprob = {
+                    self.detokenize([llama_cpp.llama_token(i)]).decode("utf-8"): logprob
+                    for logprob, i in sorted_logprobs[:logprobs]
+                }
+                top_logprob.update({token_str: sorted_logprobs[int(token)][0]})
+                top_logprobs.append(top_logprob)
+
+            logprobs_or_none = {
+                "tokens": tokens,
+                "text_offset": text_offsets,
+                "token_logprobs": token_logprobs,
+                "top_logprobs": top_logprobs,
+            }

         if self.verbose:
             llama_cpp.llama_print_timings(self.ctx)
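The logprobs path converts each cached logits row with Llama.logit_to_logprob, sorts the row once, and keeps the top-N (token string, score) pairs per position. A standalone sketch of just that top-N step over a toy three-token vocabulary (top_n_for_row is a hypothetical helper):

# Sketch: pick the top-N scores from one row and label them with token strings.
def top_n_for_row(row, n_top, id_to_str):
    ranked = sorted(zip(row, range(len(row))), reverse=True)  # (score, token_id), best first
    return {id_to_str(token_id): score for score, token_id in ranked[:n_top]}

row = [0.5, 2.0, 1.0]             # toy scores over a 3-token vocabulary
vocab = {0: "a", 1: "b", 2: "c"}
assert top_n_for_row(row, n_top=2, id_to_str=vocab.get) == {"b": 2.0, "c": 1.0}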
@@ -423,9 +554,9 @@
             "model": self.model_path,
             "choices": [
                 {
-                    "text": text,
+                    "text": text_str,
                     "index": 0,
-                    "logprobs": None,
+                    "logprobs": logprobs_or_none,
                     "finish_reason": finish_reason,
                 }
             ],
@@ -445,7 +576,7 @@
         top_p: float = 0.95,
         logprobs: Optional[int] = None,
         echo: bool = False,
-        stop: List[str] = [],
+        stop: Optional[List[str]] = [],
         repeat_penalty: float = 1.1,
         top_k: int = 40,
         stream: bool = False,
@@ -500,7 +631,7 @@
         top_p: float = 0.95,
         logprobs: Optional[int] = None,
         echo: bool = False,
-        stop: List[str] = [],
+        stop: Optional[List[str]] = [],
         repeat_penalty: float = 1.1,
         top_k: int = 40,
         stream: bool = False,
@@ -602,12 +733,12 @@
     def create_chat_completion(
         self,
         messages: List[ChatCompletionMessage],
-        temperature: float = 0.8,
+        temperature: float = 0.2,
         top_p: float = 0.95,
         top_k: int = 40,
         stream: bool = False,
-        stop: List[str] = [],
-        max_tokens: int = 128,
+        stop: Optional[List[str]] = [],
+        max_tokens: int = 256,
         repeat_penalty: float = 1.1,
     ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
         """Generate a chat completion from a list of messages.
@@ -625,13 +756,13 @@
         Returns:
             Generated chat completion or a stream of chat completion chunks.
         """
-        instructions = """Complete the following chat conversation between the user and the assistant. System messages should be strictly followed as additional instructions."""
-        chat_history = "\n".join(
-            f'{message["role"]} {message.get("user", "")}: {message["content"]}'
+        stop = stop if stop is not None else []
+        chat_history = "".join(
+            f'### {"Human" if message["role"] == "user" else "Assistant"}:{message["content"]}'
             for message in messages
         )
-        PROMPT = f" \n\n### Instructions:{instructions}\n\n### Inputs:{chat_history}\n\n### Response:\nassistant: "
-        PROMPT_STOP = ["###", "\nuser: ", "\nassistant: ", "\nsystem: "]
+        PROMPT = chat_history + "### Assistant:"
+        PROMPT_STOP = ["### Assistant:", "### Human:"]
         completion_or_chunks = self(
             prompt=PROMPT,
             stop=PROMPT_STOP + stop,
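With the new template, messages are flattened into "### Human:"/"### Assistant:" turns (only the user role maps to Human, so system messages land in the Assistant bucket) and the prompt ends with an open "### Assistant:" for the model to complete. A quick sketch of the string this produces for a toy conversation:

# Sketch: reproduce the prompt create_chat_completion now builds.
messages = [
    {"role": "user", "content": " Hi, who are you?"},
    {"role": "assistant", "content": " I am a helpful assistant."},
    {"role": "user", "content": " What can you do?"},
]
chat_history = "".join(
    f'### {"Human" if m["role"] == "user" else "Assistant"}:{m["content"]}'
    for m in messages
)
prompt = chat_history + "### Assistant:"
# prompt == "### Human: Hi, who are you?### Assistant: I am a helpful assistant."
#           "### Human: What can you do?### Assistant:"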
@@ -668,8 +799,6 @@
             use_mlock=self.params.use_mlock,
             embedding=self.params.embedding,
             last_n_tokens_size=self.last_n_tokens_size,
-            last_n_tokens_data=self.last_n_tokens_data,
-            tokens_consumed=self.tokens_consumed,
             n_batch=self.n_batch,
             n_threads=self.n_threads,
         )
@@ -691,9 +820,6 @@
             last_n_tokens_size=state["last_n_tokens_size"],
             verbose=state["verbose"],
         )
-        self.last_n_tokens_data = state["last_n_tokens_data"]
-        self.tokens_consumed = state["tokens_consumed"]

     @staticmethod
     def token_eos() -> llama_cpp.llama_token:
@@ -704,3 +830,7 @@
     def token_bos() -> llama_cpp.llama_token:
         """Return the beginning-of-sequence token."""
         return llama_cpp.llama_token_bos()
+
+    @staticmethod
+    def logit_to_logprob(x: float) -> float:
+        return math.log(1.0 + math.exp(x))
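For reference, the new logit_to_logprob helper evaluates log(1 + e^x), i.e. the softplus of the raw logit rather than a normalized log-softmax over the vocabulary. A quick numeric sanity check of the formula:

import math

def softplus(x: float) -> float:
    return math.log(1.0 + math.exp(x))  # same expression as Llama.logit_to_logprob

assert abs(softplus(0.0) - math.log(2.0)) < 1e-12
assert softplus(5.0) > 5.0              # ~5.0067: approaches x for large positive x
assert 0.0 < softplus(-5.0) < 0.01      # ~0.0067: approaches 0 for large negative x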

View file

@@ -114,6 +114,7 @@ LLAMA_FTYPE_ALL_F32 = ctypes.c_int(0)
 LLAMA_FTYPE_MOSTLY_F16 = ctypes.c_int(1)  # except 1d tensors
 LLAMA_FTYPE_MOSTLY_Q4_0 = ctypes.c_int(2)  # except 1d tensors
 LLAMA_FTYPE_MOSTLY_Q4_1 = ctypes.c_int(3)  # except 1d tensors
+LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = ctypes.c_int(4)  # tok_embeddings.weight and output.weight are F16

 # Functions

View file

@@ -13,12 +13,13 @@ Then visit http://localhost:8000/docs to see the interactive API docs.
 """
 import os
 import json
+from threading import Lock
 from typing import List, Optional, Literal, Union, Iterator, Dict
 from typing_extensions import TypedDict

 import llama_cpp

-from fastapi import FastAPI
+from fastapi import Depends, FastAPI
 from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel, BaseSettings, Field, create_model_from_typeddict
 from sse_starlette.sse import EventSourceResponse
@@ -33,6 +34,8 @@ class Settings(BaseSettings):
     use_mlock: bool = False  # This causes a silent failure on platforms that don't support mlock (e.g. Windows) took forever to figure out...
     embedding: bool = True
     last_n_tokens_size: int = 64
+    logits_all: bool = False
+    cache: bool = False  # WARNING: This is an experimental feature


 app = FastAPI(
@@ -52,11 +55,21 @@ llama = llama_cpp.Llama(
     f16_kv=settings.f16_kv,
     use_mlock=settings.use_mlock,
     embedding=settings.embedding,
+    logits_all=settings.logits_all,
     n_threads=settings.n_threads,
     n_batch=settings.n_batch,
     n_ctx=settings.n_ctx,
     last_n_tokens_size=settings.last_n_tokens_size,
 )
+if settings.cache:
+    cache = llama_cpp.LlamaCache()
+    llama.set_cache(cache)
+llama_lock = Lock()
+
+
+def get_llama():
+    with llama_lock:
+        yield llama
+

 class CreateCompletionRequest(BaseModel):
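get_llama is a generator dependency, so FastAPI holds the threading.Lock for the whole request: handlers that depend on it are serialized against the single shared Llama instance. A minimal standalone sketch of the same pattern with a toy resource (the endpoint name and counter are illustrative, not part of this server):

# Sketch: serialize access to one shared resource via a FastAPI dependency.
from threading import Lock

from fastapi import Depends, FastAPI

app = FastAPI()
shared_state = {"requests_served": 0}  # stand-in for the single shared Llama object
state_lock = Lock()


def get_state():
    with state_lock:  # held until the handler finishes, so requests run one at a time
        yield shared_state


@app.post("/bump")
def bump(state: dict = Depends(get_state)):
    state["requests_served"] += 1
    return state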
@@ -66,7 +79,7 @@ class CreateCompletionRequest(BaseModel):
     temperature: float = 0.8
     top_p: float = 0.95
     echo: bool = False
-    stop: List[str] = []
+    stop: Optional[List[str]] = []
     stream: bool = False

     # ignored or currently unsupported
@@ -99,7 +112,9 @@ CreateCompletionResponse = create_model_from_typeddict(llama_cpp.Completion)
     "/v1/completions",
     response_model=CreateCompletionResponse,
 )
-def create_completion(request: CreateCompletionRequest):
+def create_completion(
+    request: CreateCompletionRequest, llama: llama_cpp.Llama = Depends(get_llama)
+):
     if isinstance(request.prompt, list):
         request.prompt = "".join(request.prompt)
@@ -108,7 +123,6 @@ def create_completion(request: CreateCompletionRequest):
         exclude={
             "model",
             "n",
-            "logprobs",
             "frequency_penalty",
             "presence_penalty",
             "best_of",
@@ -144,7 +158,9 @@ CreateEmbeddingResponse = create_model_from_typeddict(llama_cpp.Embedding)
     "/v1/embeddings",
     response_model=CreateEmbeddingResponse,
 )
-def create_embedding(request: CreateEmbeddingRequest):
+def create_embedding(
+    request: CreateEmbeddingRequest, llama: llama_cpp.Llama = Depends(get_llama)
+):
     return llama.create_embedding(**request.dict(exclude={"model", "user"}))
@@ -160,7 +176,7 @@ class CreateChatCompletionRequest(BaseModel):
     temperature: float = 0.8
     top_p: float = 0.95
     stream: bool = False
-    stop: List[str] = []
+    stop: Optional[List[str]] = []
     max_tokens: int = 128

     # ignored or currently unsupported
@@ -198,6 +214,7 @@ CreateChatCompletionResponse = create_model_from_typeddict(llama_cpp.ChatComplet
 )
 def create_chat_completion(
     request: CreateChatCompletionRequest,
+    llama: llama_cpp.Llama = Depends(get_llama),
 ) -> Union[llama_cpp.ChatCompletion, EventSourceResponse]:
     completion_or_chunks = llama.create_chat_completion(
         **request.dict(

View file

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "llama_cpp_python"
-version = "0.1.32"
+version = "0.1.33"
 description = "Python bindings for the llama.cpp library"
 authors = ["Andrei Betlen <abetlen@gmail.com>"]
 license = "MIT"

View file

@@ -10,7 +10,7 @@ setup(
     description="A Python wrapper for llama.cpp",
     long_description=long_description,
     long_description_content_type="text/markdown",
-    version="0.1.32",
+    version="0.1.33",
     author="Andrei Betlen",
     author_email="abetlen@gmail.com",
     license="MIT",

vendor/llama.cpp (vendored submodule, 2 lines changed)

@@ -1 +1 @@
-Subproject commit 3e6e70d8e8917b5bd14c7c9f9b89a585f1ff0b31
+Subproject commit e95b6554b493e71a0275764342e09bd5784a7026