Add experimental cache
This commit is contained in:
parent
a6372a7ae5
commit
92c077136d
2 changed files with 69 additions and 5 deletions
|
@ -11,6 +11,15 @@ from . import llama_cpp
|
|||
from .llama_types import *
|
||||
|
||||
|
||||
class LlamaCache:
|
||||
"""Cache for a llama.cpp model.
|
||||
|
||||
NOTE: This implementation currently only tells the Llama class to avoid reprocessing bytes and continue from the last
|
||||
completion. It does not actually cache the results."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class Llama:
|
||||
"""High-level Python wrapper for a llama.cpp model."""
|
||||
|
||||
|
@ -82,6 +91,14 @@ class Llama:
|
|||
self.n_past = 0
|
||||
self.all_logits: List[List[float]] = [] # TODO: Use an array instead of a list.
|
||||
|
||||
### HACK: This is a hack to work around the fact that the llama.cpp API does not yet support
|
||||
### saving and restoring state, this allows us to continue a completion if the last
|
||||
### completion_bytes is a prefix to the prompt passed in. However this is actually incorrect
|
||||
### because it does not take into account stop tokens which have been processed by the model.
|
||||
self._completion_bytes: List[bytes] = []
|
||||
self._cache: Optional[LlamaCache] = None
|
||||
###
|
||||
|
||||
self.n_threads = n_threads or max(multiprocessing.cpu_count() // 2, 1)
|
||||
|
||||
if not os.path.exists(model_path):
|
||||
|
@ -135,6 +152,14 @@ class Llama:
|
|||
output += llama_cpp.llama_token_to_str(self.ctx, token)
|
||||
return output
|
||||
|
||||
def set_cache(self, cache: Optional[LlamaCache]):
|
||||
"""Set the cache.
|
||||
|
||||
Args:
|
||||
cache: The cache to set.
|
||||
"""
|
||||
self._cache = cache
|
||||
|
||||
def reset(self):
|
||||
"""Reset the model state."""
|
||||
self.last_n_tokens_data.extend(
|
||||
|
@ -245,6 +270,17 @@ class Llama:
|
|||
The generated tokens.
|
||||
"""
|
||||
assert self.ctx is not None
|
||||
### HACK
|
||||
if (
|
||||
reset
|
||||
and self._cache
|
||||
and len(self.tokens) > 0
|
||||
and self.tokens == tokens[: len(self.tokens)]
|
||||
):
|
||||
if self.verbose:
|
||||
print("generate cache hit", file=sys.stderr)
|
||||
reset = False
|
||||
###
|
||||
if reset:
|
||||
self.reset()
|
||||
while True:
|
||||
|
@ -361,6 +397,21 @@ class Llama:
|
|||
"logprobs is not supported for models created with logits_all=False"
|
||||
)
|
||||
|
||||
### HACK
|
||||
reset: bool = True
|
||||
_prompt: bytes = prompt.encode("utf-8")
|
||||
_completion: bytes = b"".join(self._completion_bytes)
|
||||
if len(_completion) and self._cache and _prompt.startswith(_completion):
|
||||
if self.verbose:
|
||||
print("completion cache hit", file=sys.stderr)
|
||||
reset = False
|
||||
_prompt = _prompt[len(_completion) :]
|
||||
prompt_tokens = self.tokenize(b" " + _prompt)
|
||||
self._completion_bytes.append(_prompt)
|
||||
else:
|
||||
self._completion_bytes = [prompt.encode("utf-8")]
|
||||
###
|
||||
|
||||
finish_reason = "length"
|
||||
for token in self.generate(
|
||||
prompt_tokens,
|
||||
|
@ -368,6 +419,7 @@ class Llama:
|
|||
top_p=top_p,
|
||||
temp=temperature,
|
||||
repeat_penalty=repeat_penalty,
|
||||
reset=reset,
|
||||
):
|
||||
if token == llama_cpp.llama_token_eos():
|
||||
text = self.detokenize(completion_tokens)
|
||||
|
@ -397,6 +449,9 @@ class Llama:
|
|||
break
|
||||
text = all_text[: len(all_text) - longest]
|
||||
returned_characters += len(text[start:])
|
||||
### HACK
|
||||
self._completion_bytes.append(text[start:])
|
||||
###
|
||||
yield {
|
||||
"id": completion_id,
|
||||
"object": "text_completion",
|
||||
|
@ -418,6 +473,9 @@ class Llama:
|
|||
break
|
||||
|
||||
if stream:
|
||||
### HACK
|
||||
self._completion_bytes.append(text[returned_characters:])
|
||||
###
|
||||
yield {
|
||||
"id": completion_id,
|
||||
"object": "text_completion",
|
||||
|
@ -434,13 +492,16 @@ class Llama:
|
|||
}
|
||||
return
|
||||
|
||||
text = text.decode("utf-8")
|
||||
### HACK
|
||||
self._completion_bytes.append(text)
|
||||
###
|
||||
text_str = text.decode("utf-8")
|
||||
|
||||
if echo:
|
||||
text = prompt + text
|
||||
text_str = prompt + text_str
|
||||
|
||||
if suffix is not None:
|
||||
text = text + suffix
|
||||
text_str = text_str + suffix
|
||||
|
||||
logprobs_or_none: Optional[CompletionLogprobs] = None
|
||||
if logprobs is not None:
|
||||
|
@ -493,7 +554,7 @@ class Llama:
|
|||
"model": self.model_path,
|
||||
"choices": [
|
||||
{
|
||||
"text": text,
|
||||
"text": text_str,
|
||||
"index": 0,
|
||||
"logprobs": logprobs_or_none,
|
||||
"finish_reason": finish_reason,
|
||||
|
|
|
@ -35,6 +35,7 @@ class Settings(BaseSettings):
|
|||
embedding: bool = True
|
||||
last_n_tokens_size: int = 64
|
||||
logits_all: bool = False
|
||||
cache: bool = False # WARNING: This is an experimental feature
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
|
@ -60,6 +61,9 @@ llama = llama_cpp.Llama(
|
|||
n_ctx=settings.n_ctx,
|
||||
last_n_tokens_size=settings.last_n_tokens_size,
|
||||
)
|
||||
if settings.cache:
|
||||
cache = llama_cpp.LlamaCache()
|
||||
llama.set_cache(cache)
|
||||
llama_lock = Lock()
|
||||
|
||||
|
||||
|
@ -68,7 +72,6 @@ def get_llama():
|
|||
yield llama
|
||||
|
||||
|
||||
|
||||
class CreateCompletionRequest(BaseModel):
|
||||
prompt: Union[str, List[str]]
|
||||
suffix: Optional[str] = Field(None)
|
||||
|
|
Loading…
Add table
Reference in a new issue