Merge branch 'main' of github.com:abetlen/llama-cpp-python
This commit is contained in:
commit
6df27b2da0
6 changed files with 194 additions and 46 deletions
|
@ -2,6 +2,7 @@ import os
|
|||
import sys
|
||||
import uuid
|
||||
import time
|
||||
import math
|
||||
import multiprocessing
|
||||
from typing import List, Optional, Union, Generator, Sequence, Iterator
|
||||
from collections import deque
|
||||
|
@ -10,6 +11,15 @@ from . import llama_cpp
|
|||
from .llama_types import *
|
||||
|
||||
|
||||
class LlamaCache:
|
||||
"""Cache for a llama.cpp model.
|
||||
|
||||
NOTE: This implementation currently only tells the Llama class to avoid reprocessing bytes and continue from the last
|
||||
completion. It does not actually cache the results."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class Llama:
|
||||
"""High-level Python wrapper for a llama.cpp model."""
|
||||
|
||||
|
@ -20,7 +30,7 @@ class Llama:
|
|||
n_ctx: int = 512,
|
||||
n_parts: int = -1,
|
||||
seed: int = 1337,
|
||||
f16_kv: bool = False,
|
||||
f16_kv: bool = True,
|
||||
logits_all: bool = False,
|
||||
vocab_only: bool = False,
|
||||
use_mmap: bool = True,
|
||||
|
@ -75,7 +85,19 @@ class Llama:
|
|||
maxlen=self.last_n_tokens_size,
|
||||
)
|
||||
self.tokens_consumed = 0
|
||||
self.tokens: List[llama_cpp.llama_token] = []
|
||||
self.n_batch = min(n_ctx, n_batch)
|
||||
self.n_tokens = 0
|
||||
self.n_past = 0
|
||||
self.all_logits: List[List[float]] = [] # TODO: Use an array instead of a list.
|
||||
|
||||
### HACK: This is a hack to work around the fact that the llama.cpp API does not yet support
|
||||
### saving and restoring state, this allows us to continue a completion if the last
|
||||
### completion_bytes is a prefix to the prompt passed in. However this is actually incorrect
|
||||
### because it does not take into account stop tokens which have been processed by the model.
|
||||
self._completion_bytes: List[bytes] = []
|
||||
self._cache: Optional[LlamaCache] = None
|
||||
###
|
||||
|
||||
self.n_threads = n_threads or max(multiprocessing.cpu_count() // 2, 1)
|
||||
|
||||
|
@ -130,12 +152,24 @@ class Llama:
|
|||
output += llama_cpp.llama_token_to_str(self.ctx, token)
|
||||
return output
|
||||
|
||||
def set_cache(self, cache: Optional[LlamaCache]):
|
||||
"""Set the cache.
|
||||
|
||||
Args:
|
||||
cache: The cache to set.
|
||||
"""
|
||||
self._cache = cache
|
||||
|
||||
def reset(self):
|
||||
"""Reset the model state."""
|
||||
self.last_n_tokens_data.extend(
|
||||
[llama_cpp.llama_token(0)] * self.last_n_tokens_size
|
||||
)
|
||||
self.tokens_consumed = 0
|
||||
self.tokens.clear()
|
||||
self.n_tokens = 0
|
||||
self.n_past = 0
|
||||
self.all_logits.clear()
|
||||
|
||||
def eval(self, tokens: Sequence[llama_cpp.llama_token]):
|
||||
"""Evaluate a list of tokens.
|
||||
|
@ -147,18 +181,32 @@ class Llama:
|
|||
n_ctx = int(llama_cpp.llama_n_ctx(self.ctx))
|
||||
for i in range(0, len(tokens), self.n_batch):
|
||||
batch = tokens[i : min(len(tokens), i + self.n_batch)]
|
||||
n_past = min(n_ctx - len(batch), self.tokens_consumed)
|
||||
self.n_past = min(n_ctx - len(batch), self.tokens_consumed)
|
||||
self.n_tokens = len(batch)
|
||||
return_code = llama_cpp.llama_eval(
|
||||
ctx=self.ctx,
|
||||
tokens=(llama_cpp.llama_token * len(batch))(*batch),
|
||||
n_tokens=llama_cpp.c_int(len(batch)),
|
||||
n_past=llama_cpp.c_int(n_past),
|
||||
n_tokens=llama_cpp.c_int(self.n_tokens),
|
||||
n_past=llama_cpp.c_int(self.n_past),
|
||||
n_threads=llama_cpp.c_int(self.n_threads),
|
||||
)
|
||||
if int(return_code) != 0:
|
||||
raise RuntimeError(f"llama_eval returned {return_code}")
|
||||
self.tokens.extend(batch)
|
||||
self.last_n_tokens_data.extend(batch)
|
||||
self.tokens_consumed += len(batch)
|
||||
if self.params.logits_all:
|
||||
self.all_logits.extend(self._logits())
|
||||
|
||||
def _logits(self) -> List[List[float]]:
|
||||
"""Return the logits from the last call to llama_eval."""
|
||||
assert self.ctx is not None
|
||||
n_vocab = llama_cpp.llama_n_vocab(self.ctx)
|
||||
cols = int(n_vocab)
|
||||
rows = self.n_tokens if self.params.logits_all else 1
|
||||
logits_view = llama_cpp.llama_get_logits(self.ctx)
|
||||
logits = [[logits_view[i * cols + j] for j in range(cols)] for i in range(rows)]
|
||||
return logits
|
||||
|
||||
def sample(
|
||||
self,
|
||||
|
@ -198,6 +246,7 @@ class Llama:
|
|||
top_p: float,
|
||||
temp: float,
|
||||
repeat_penalty: float,
|
||||
reset: bool = True,
|
||||
) -> Generator[
|
||||
llama_cpp.llama_token, Optional[Sequence[llama_cpp.llama_token]], None
|
||||
]:
|
||||
|
@ -215,12 +264,25 @@ class Llama:
|
|||
top_p: The top-p sampling parameter.
|
||||
temp: The temperature parameter.
|
||||
repeat_penalty: The repeat penalty parameter.
|
||||
reset: Whether to reset the model state.
|
||||
|
||||
Yields:
|
||||
The generated tokens.
|
||||
"""
|
||||
assert self.ctx is not None
|
||||
self.reset()
|
||||
### HACK
|
||||
if (
|
||||
reset
|
||||
and self._cache
|
||||
and len(self.tokens) > 0
|
||||
and self.tokens == tokens[: len(self.tokens)]
|
||||
):
|
||||
if self.verbose:
|
||||
print("generate cache hit", file=sys.stderr)
|
||||
reset = False
|
||||
###
|
||||
if reset:
|
||||
self.reset()
|
||||
while True:
|
||||
self.eval(tokens)
|
||||
token = self.sample(
|
||||
|
@ -300,19 +362,22 @@ class Llama:
|
|||
top_p: float = 0.95,
|
||||
logprobs: Optional[int] = None,
|
||||
echo: bool = False,
|
||||
stop: List[str] = [],
|
||||
stop: Optional[List[str]] = [],
|
||||
repeat_penalty: float = 1.1,
|
||||
top_k: int = 40,
|
||||
stream: bool = False,
|
||||
) -> Union[Iterator[Completion], Iterator[CompletionChunk],]:
|
||||
) -> Union[Iterator[Completion], Iterator[CompletionChunk]]:
|
||||
assert self.ctx is not None
|
||||
completion_id = f"cmpl-{str(uuid.uuid4())}"
|
||||
created = int(time.time())
|
||||
completion_id: str = f"cmpl-{str(uuid.uuid4())}"
|
||||
created: int = int(time.time())
|
||||
completion_tokens: List[llama_cpp.llama_token] = []
|
||||
# Add blank space to start of prompt to match OG llama tokenizer
|
||||
prompt_tokens = self.tokenize(b" " + prompt.encode("utf-8"))
|
||||
text = b""
|
||||
returned_characters = 0
|
||||
prompt_tokens: List[llama_cpp.llama_token] = self.tokenize(
|
||||
b" " + prompt.encode("utf-8")
|
||||
)
|
||||
text: bytes = b""
|
||||
returned_characters: int = 0
|
||||
stop = stop if stop is not None else []
|
||||
|
||||
if self.verbose:
|
||||
llama_cpp.llama_reset_timings(self.ctx)
|
||||
|
@ -327,13 +392,34 @@ class Llama:
|
|||
else:
|
||||
stop_sequences = []
|
||||
|
||||
finish_reason = None
|
||||
if logprobs is not None and self.params.logits_all is False:
|
||||
raise ValueError(
|
||||
"logprobs is not supported for models created with logits_all=False"
|
||||
)
|
||||
|
||||
### HACK
|
||||
reset: bool = True
|
||||
_prompt: bytes = prompt.encode("utf-8")
|
||||
_completion: bytes = b"".join(self._completion_bytes)
|
||||
if len(_completion) and self._cache and _prompt.startswith(_completion):
|
||||
if self.verbose:
|
||||
print("completion cache hit", file=sys.stderr)
|
||||
reset = False
|
||||
_prompt = _prompt[len(_completion) :]
|
||||
prompt_tokens = self.tokenize(b" " + _prompt)
|
||||
self._completion_bytes.append(_prompt)
|
||||
else:
|
||||
self._completion_bytes = [prompt.encode("utf-8")]
|
||||
###
|
||||
|
||||
finish_reason = "length"
|
||||
for token in self.generate(
|
||||
prompt_tokens,
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
temp=temperature,
|
||||
repeat_penalty=repeat_penalty,
|
||||
reset=reset,
|
||||
):
|
||||
if token == llama_cpp.llama_token_eos():
|
||||
text = self.detokenize(completion_tokens)
|
||||
|
@ -363,6 +449,9 @@ class Llama:
|
|||
break
|
||||
text = all_text[: len(all_text) - longest]
|
||||
returned_characters += len(text[start:])
|
||||
### HACK
|
||||
self._completion_bytes.append(text[start:])
|
||||
###
|
||||
yield {
|
||||
"id": completion_id,
|
||||
"object": "text_completion",
|
||||
|
@ -377,15 +466,16 @@ class Llama:
|
|||
}
|
||||
],
|
||||
}
|
||||
|
||||
if len(completion_tokens) >= max_tokens:
|
||||
text = self.detokenize(completion_tokens)
|
||||
finish_reason = "length"
|
||||
break
|
||||
|
||||
if finish_reason is None:
|
||||
finish_reason = "length"
|
||||
|
||||
if stream:
|
||||
### HACK
|
||||
self._completion_bytes.append(text[returned_characters:])
|
||||
###
|
||||
yield {
|
||||
"id": completion_id,
|
||||
"object": "text_completion",
|
||||
|
@ -402,16 +492,57 @@ class Llama:
|
|||
}
|
||||
return
|
||||
|
||||
text = text.decode("utf-8")
|
||||
### HACK
|
||||
self._completion_bytes.append(text)
|
||||
###
|
||||
text_str = text.decode("utf-8")
|
||||
|
||||
if echo:
|
||||
text = prompt + text
|
||||
text_str = prompt + text_str
|
||||
|
||||
if suffix is not None:
|
||||
text = text + suffix
|
||||
text_str = text_str + suffix
|
||||
|
||||
logprobs_or_none: Optional[CompletionLogprobs] = None
|
||||
if logprobs is not None:
|
||||
raise NotImplementedError("logprobs not implemented")
|
||||
text_offset = 0
|
||||
text_offsets: List[int] = []
|
||||
token_logprobs: List[float] = []
|
||||
tokens: List[str] = []
|
||||
top_logprobs: List[Dict[str, float]] = []
|
||||
|
||||
all_tokens = prompt_tokens + completion_tokens
|
||||
all_token_strs = [
|
||||
self.detokenize([token]).decode("utf-8") for token in all_tokens
|
||||
]
|
||||
all_logprobs = [
|
||||
[Llama.logit_to_logprob(logit) for logit in row]
|
||||
for row in self.all_logits
|
||||
]
|
||||
for token, token_str, logprobs_token in zip(
|
||||
all_tokens, all_token_strs, all_logprobs
|
||||
):
|
||||
text_offsets.append(text_offset)
|
||||
text_offset += len(token_str)
|
||||
tokens.append(token_str)
|
||||
sorted_logprobs = list(
|
||||
sorted(
|
||||
zip(logprobs_token, range(len(logprobs_token))), reverse=True
|
||||
)
|
||||
)
|
||||
token_logprobs.append(sorted_logprobs[int(token)][0])
|
||||
top_logprob = {
|
||||
self.detokenize([llama_cpp.llama_token(i)]).decode("utf-8"): logprob
|
||||
for logprob, i in sorted_logprobs[:logprobs]
|
||||
}
|
||||
top_logprob.update({token_str: sorted_logprobs[int(token)][0]})
|
||||
top_logprobs.append(top_logprob)
|
||||
logprobs_or_none = {
|
||||
"tokens": tokens,
|
||||
"text_offset": text_offsets,
|
||||
"token_logprobs": token_logprobs,
|
||||
"top_logprobs": top_logprobs,
|
||||
}
|
||||
|
||||
if self.verbose:
|
||||
llama_cpp.llama_print_timings(self.ctx)
|
||||
|
@ -423,9 +554,9 @@ class Llama:
|
|||
"model": self.model_path,
|
||||
"choices": [
|
||||
{
|
||||
"text": text,
|
||||
"text": text_str,
|
||||
"index": 0,
|
||||
"logprobs": None,
|
||||
"logprobs": logprobs_or_none,
|
||||
"finish_reason": finish_reason,
|
||||
}
|
||||
],
|
||||
|
@ -445,7 +576,7 @@ class Llama:
|
|||
top_p: float = 0.95,
|
||||
logprobs: Optional[int] = None,
|
||||
echo: bool = False,
|
||||
stop: List[str] = [],
|
||||
stop: Optional[List[str]] = [],
|
||||
repeat_penalty: float = 1.1,
|
||||
top_k: int = 40,
|
||||
stream: bool = False,
|
||||
|
@ -500,7 +631,7 @@ class Llama:
|
|||
top_p: float = 0.95,
|
||||
logprobs: Optional[int] = None,
|
||||
echo: bool = False,
|
||||
stop: List[str] = [],
|
||||
stop: Optional[List[str]] = [],
|
||||
repeat_penalty: float = 1.1,
|
||||
top_k: int = 40,
|
||||
stream: bool = False,
|
||||
|
@ -602,12 +733,12 @@ class Llama:
|
|||
def create_chat_completion(
|
||||
self,
|
||||
messages: List[ChatCompletionMessage],
|
||||
temperature: float = 0.8,
|
||||
temperature: float = 0.2,
|
||||
top_p: float = 0.95,
|
||||
top_k: int = 40,
|
||||
stream: bool = False,
|
||||
stop: List[str] = [],
|
||||
max_tokens: int = 128,
|
||||
stop: Optional[List[str]] = [],
|
||||
max_tokens: int = 256,
|
||||
repeat_penalty: float = 1.1,
|
||||
) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
|
||||
"""Generate a chat completion from a list of messages.
|
||||
|
@ -625,13 +756,13 @@ class Llama:
|
|||
Returns:
|
||||
Generated chat completion or a stream of chat completion chunks.
|
||||
"""
|
||||
instructions = """Complete the following chat conversation between the user and the assistant. System messages should be strictly followed as additional instructions."""
|
||||
chat_history = "\n".join(
|
||||
f'{message["role"]} {message.get("user", "")}: {message["content"]}'
|
||||
stop = stop if stop is not None else []
|
||||
chat_history = "".join(
|
||||
f'### {"Human" if message["role"] == "user" else "Assistant"}:{message["content"]}'
|
||||
for message in messages
|
||||
)
|
||||
PROMPT = f" \n\n### Instructions:{instructions}\n\n### Inputs:{chat_history}\n\n### Response:\nassistant: "
|
||||
PROMPT_STOP = ["###", "\nuser: ", "\nassistant: ", "\nsystem: "]
|
||||
PROMPT = chat_history + "### Assistant:"
|
||||
PROMPT_STOP = ["### Assistant:", "### Human:"]
|
||||
completion_or_chunks = self(
|
||||
prompt=PROMPT,
|
||||
stop=PROMPT_STOP + stop,
|
||||
|
@ -668,8 +799,6 @@ class Llama:
|
|||
use_mlock=self.params.use_mlock,
|
||||
embedding=self.params.embedding,
|
||||
last_n_tokens_size=self.last_n_tokens_size,
|
||||
last_n_tokens_data=self.last_n_tokens_data,
|
||||
tokens_consumed=self.tokens_consumed,
|
||||
n_batch=self.n_batch,
|
||||
n_threads=self.n_threads,
|
||||
)
|
||||
|
@ -691,9 +820,6 @@ class Llama:
|
|||
last_n_tokens_size=state["last_n_tokens_size"],
|
||||
verbose=state["verbose"],
|
||||
)
|
||||
self.last_n_tokens_data = state["last_n_tokens_data"]
|
||||
self.tokens_consumed = state["tokens_consumed"]
|
||||
|
||||
|
||||
@staticmethod
|
||||
def token_eos() -> llama_cpp.llama_token:
|
||||
|
@ -704,3 +830,7 @@ class Llama:
|
|||
def token_bos() -> llama_cpp.llama_token:
|
||||
"""Return the beginning-of-sequence token."""
|
||||
return llama_cpp.llama_token_bos()
|
||||
|
||||
@staticmethod
|
||||
def logit_to_logprob(x: float) -> float:
|
||||
return math.log(1.0 + math.exp(x))
|
||||
|
|
|
@ -114,6 +114,7 @@ LLAMA_FTYPE_ALL_F32 = ctypes.c_int(0)
|
|||
LLAMA_FTYPE_MOSTLY_F16 = ctypes.c_int(1) # except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_Q4_0 = ctypes.c_int(2) # except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_Q4_1 = ctypes.c_int(3) # except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = ctypes.c_int(4) # tok_embeddings.weight and output.weight are F16
|
||||
|
||||
# Functions
|
||||
|
||||
|
|
|
@ -13,12 +13,13 @@ Then visit http://localhost:8000/docs to see the interactive API docs.
|
|||
"""
|
||||
import os
|
||||
import json
|
||||
from threading import Lock
|
||||
from typing import List, Optional, Literal, Union, Iterator, Dict
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
import llama_cpp
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi import Depends, FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic import BaseModel, BaseSettings, Field, create_model_from_typeddict
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
|
@ -33,6 +34,8 @@ class Settings(BaseSettings):
|
|||
use_mlock: bool = False # This causes a silent failure on platforms that don't support mlock (e.g. Windows) took forever to figure out...
|
||||
embedding: bool = True
|
||||
last_n_tokens_size: int = 64
|
||||
logits_all: bool = False
|
||||
cache: bool = False # WARNING: This is an experimental feature
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
|
@ -52,11 +55,21 @@ llama = llama_cpp.Llama(
|
|||
f16_kv=settings.f16_kv,
|
||||
use_mlock=settings.use_mlock,
|
||||
embedding=settings.embedding,
|
||||
logits_all=settings.logits_all,
|
||||
n_threads=settings.n_threads,
|
||||
n_batch=settings.n_batch,
|
||||
n_ctx=settings.n_ctx,
|
||||
last_n_tokens_size=settings.last_n_tokens_size,
|
||||
)
|
||||
if settings.cache:
|
||||
cache = llama_cpp.LlamaCache()
|
||||
llama.set_cache(cache)
|
||||
llama_lock = Lock()
|
||||
|
||||
|
||||
def get_llama():
|
||||
with llama_lock:
|
||||
yield llama
|
||||
|
||||
|
||||
class CreateCompletionRequest(BaseModel):
|
||||
|
@ -66,7 +79,7 @@ class CreateCompletionRequest(BaseModel):
|
|||
temperature: float = 0.8
|
||||
top_p: float = 0.95
|
||||
echo: bool = False
|
||||
stop: List[str] = []
|
||||
stop: Optional[List[str]] = []
|
||||
stream: bool = False
|
||||
|
||||
# ignored or currently unsupported
|
||||
|
@ -99,7 +112,9 @@ CreateCompletionResponse = create_model_from_typeddict(llama_cpp.Completion)
|
|||
"/v1/completions",
|
||||
response_model=CreateCompletionResponse,
|
||||
)
|
||||
def create_completion(request: CreateCompletionRequest):
|
||||
def create_completion(
|
||||
request: CreateCompletionRequest, llama: llama_cpp.Llama = Depends(get_llama)
|
||||
):
|
||||
if isinstance(request.prompt, list):
|
||||
request.prompt = "".join(request.prompt)
|
||||
|
||||
|
@ -108,7 +123,6 @@ def create_completion(request: CreateCompletionRequest):
|
|||
exclude={
|
||||
"model",
|
||||
"n",
|
||||
"logprobs",
|
||||
"frequency_penalty",
|
||||
"presence_penalty",
|
||||
"best_of",
|
||||
|
@ -144,7 +158,9 @@ CreateEmbeddingResponse = create_model_from_typeddict(llama_cpp.Embedding)
|
|||
"/v1/embeddings",
|
||||
response_model=CreateEmbeddingResponse,
|
||||
)
|
||||
def create_embedding(request: CreateEmbeddingRequest):
|
||||
def create_embedding(
|
||||
request: CreateEmbeddingRequest, llama: llama_cpp.Llama = Depends(get_llama)
|
||||
):
|
||||
return llama.create_embedding(**request.dict(exclude={"model", "user"}))
|
||||
|
||||
|
||||
|
@ -160,7 +176,7 @@ class CreateChatCompletionRequest(BaseModel):
|
|||
temperature: float = 0.8
|
||||
top_p: float = 0.95
|
||||
stream: bool = False
|
||||
stop: List[str] = []
|
||||
stop: Optional[List[str]] = []
|
||||
max_tokens: int = 128
|
||||
|
||||
# ignored or currently unsupported
|
||||
|
@ -198,6 +214,7 @@ CreateChatCompletionResponse = create_model_from_typeddict(llama_cpp.ChatComplet
|
|||
)
|
||||
def create_chat_completion(
|
||||
request: CreateChatCompletionRequest,
|
||||
llama: llama_cpp.Llama = Depends(get_llama),
|
||||
) -> Union[llama_cpp.ChatCompletion, EventSourceResponse]:
|
||||
completion_or_chunks = llama.create_chat_completion(
|
||||
**request.dict(
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
[tool.poetry]
|
||||
name = "llama_cpp_python"
|
||||
version = "0.1.32"
|
||||
version = "0.1.33"
|
||||
description = "Python bindings for the llama.cpp library"
|
||||
authors = ["Andrei Betlen <abetlen@gmail.com>"]
|
||||
license = "MIT"
|
||||
|
|
2
setup.py
2
setup.py
|
@ -10,7 +10,7 @@ setup(
|
|||
description="A Python wrapper for llama.cpp",
|
||||
long_description=long_description,
|
||||
long_description_content_type="text/markdown",
|
||||
version="0.1.32",
|
||||
version="0.1.33",
|
||||
author="Andrei Betlen",
|
||||
author_email="abetlen@gmail.com",
|
||||
license="MIT",
|
||||
|
|
2
vendor/llama.cpp
vendored
2
vendor/llama.cpp
vendored
|
@ -1 +1 @@
|
|||
Subproject commit 3e6e70d8e8917b5bd14c7c9f9b89a585f1ff0b31
|
||||
Subproject commit e95b6554b493e71a0275764342e09bd5784a7026
|
Loading…
Reference in a new issue