Add experimental cache
parent a6372a7ae5
commit 92c077136d
2 changed files with 69 additions and 5 deletions
@@ -11,6 +11,15 @@ from . import llama_cpp
 from .llama_types import *
 
 
+class LlamaCache:
+    """Cache for a llama.cpp model.
+
+    NOTE: This implementation currently only tells the Llama class to avoid reprocessing bytes and continue from the last
+    completion. It does not actually cache the results."""
+
+    pass
+
+
 class Llama:
     """High-level Python wrapper for a llama.cpp model."""
 
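For orientation, a minimal sketch of how the new cache is meant to be wired up from user code (the model path is hypothetical; `set_cache` is added to the `Llama` class later in this diff):

import llama_cpp

# Hypothetical model path; substitute any local GGML model file.
llama = llama_cpp.Llama(model_path="./models/ggml-model.bin")
llama.set_cache(llama_cpp.LlamaCache())
# Later completions whose prompt starts with the previous prompt plus its
# completion can now skip resetting and re-evaluating that shared prefix.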
@@ -82,6 +91,14 @@ class Llama:
         self.n_past = 0
         self.all_logits: List[List[float]] = []  # TODO: Use an array instead of a list.
 
+        ### HACK: This is a hack to work around the fact that the llama.cpp API does not yet support
+        ### saving and restoring state, this allows us to continue a completion if the last
+        ### completion_bytes is a prefix to the prompt passed in. However this is actually incorrect
+        ### because it does not take into account stop tokens which have been processed by the model.
+        self._completion_bytes: List[bytes] = []
+        self._cache: Optional[LlamaCache] = None
+        ###
+
         self.n_threads = n_threads or max(multiprocessing.cpu_count() // 2, 1)
 
         if not os.path.exists(model_path):
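The comment block above describes the whole trick: keep the bytes of the last prompt and completion, and reuse the model state only when the incoming prompt starts with them. A standalone sketch of that prefix test (names are illustrative, not part of this diff):

from typing import List

def can_continue(completion_bytes: List[bytes], new_prompt: bytes) -> bool:
    """Return True when the previous prompt + completion is a prefix of the new
    prompt, i.e. the existing model state can be reused instead of reset."""
    previous = b"".join(completion_bytes)
    return len(previous) > 0 and new_prompt.startswith(previous)

assert can_continue([b"Hello", b", world"], b"Hello, world! How are you?")
assert not can_continue([b"Hello"], b"Goodbye")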
@@ -135,6 +152,14 @@ class Llama:
             output += llama_cpp.llama_token_to_str(self.ctx, token)
         return output
 
+    def set_cache(self, cache: Optional[LlamaCache]):
+        """Set the cache.
+
+        Args:
+            cache: The cache to set.
+        """
+        self._cache = cache
+
     def reset(self):
         """Reset the model state."""
         self.last_n_tokens_data.extend(
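The new `set_cache` method is the only public switch; passing `None` turns the behaviour off again (sketch, assuming an already constructed `llama` instance):

llama.set_cache(LlamaCache())  # reuse the previous completion as a prompt prefix
llama.set_cache(None)          # back to resetting the context on every completion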
@@ -245,6 +270,17 @@ class Llama:
             The generated tokens.
         """
         assert self.ctx is not None
+        ### HACK
+        if (
+            reset
+            and self._cache
+            and len(self.tokens) > 0
+            and self.tokens == tokens[: len(self.tokens)]
+        ):
+            if self.verbose:
+                print("generate cache hit", file=sys.stderr)
+            reset = False
+        ###
         if reset:
             self.reset()
         while True:
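The condition added to `generate` above is a token-level prefix check; in isolation it behaves like this (token values are made up for illustration):

previous_tokens = [1, 15043, 29892]       # tokens already evaluated by the model
request_tokens = [1, 15043, 29892, 3186]  # tokens of the incoming request

# Cache hit: the evaluated tokens are a prefix of the request, so reset is skipped.
cache_hit = len(previous_tokens) > 0 and previous_tokens == request_tokens[: len(previous_tokens)]
assert cache_hit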
@@ -361,6 +397,21 @@ class Llama:
                 "logprobs is not supported for models created with logits_all=False"
             )
 
+        ### HACK
+        reset: bool = True
+        _prompt: bytes = prompt.encode("utf-8")
+        _completion: bytes = b"".join(self._completion_bytes)
+        if len(_completion) and self._cache and _prompt.startswith(_completion):
+            if self.verbose:
+                print("completion cache hit", file=sys.stderr)
+            reset = False
+            _prompt = _prompt[len(_completion) :]
+            prompt_tokens = self.tokenize(b" " + _prompt)
+            self._completion_bytes.append(_prompt)
+        else:
+            self._completion_bytes = [prompt.encode("utf-8")]
+        ###
+
         finish_reason = "length"
         for token in self.generate(
             prompt_tokens,
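In terms of observable behaviour, the branch above fires when one completion call extends the previous one, for example in a chat-style loop (a hedged sketch; the prompts, `max_tokens` values, and the `create_completion` call site are assumptions, and as the comment above notes, stop tokens can still break the prefix match):

out1 = llama.create_completion("User: Hi\nAssistant:", max_tokens=32)
reply = out1["choices"][0]["text"]

# The second prompt starts with the first prompt plus its completion, so
# _prompt.startswith(_completion) should hold and only the new suffix is
# tokenized; with verbose=True, "completion cache hit" is printed to stderr.
out2 = llama.create_completion(
    "User: Hi\nAssistant:" + reply + "\nUser: Tell me more\nAssistant:",
    max_tokens=32,
)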
@@ -368,6 +419,7 @@ class Llama:
             top_p=top_p,
             temp=temperature,
             repeat_penalty=repeat_penalty,
+            reset=reset,
         ):
             if token == llama_cpp.llama_token_eos():
                 text = self.detokenize(completion_tokens)
@@ -397,6 +449,9 @@ class Llama:
                            break
                 text = all_text[: len(all_text) - longest]
                 returned_characters += len(text[start:])
+                ### HACK
+                self._completion_bytes.append(text[start:])
+                ###
                 yield {
                     "id": completion_id,
                     "object": "text_completion",
@@ -418,6 +473,9 @@ class Llama:
                 break
 
         if stream:
+            ### HACK
+            self._completion_bytes.append(text[returned_characters:])
+            ###
             yield {
                 "id": completion_id,
                 "object": "text_completion",
@@ -434,13 +492,16 @@ class Llama:
             }
             return
 
-        text = text.decode("utf-8")
+        ### HACK
+        self._completion_bytes.append(text)
+        ###
+        text_str = text.decode("utf-8")
 
         if echo:
-            text = prompt + text
+            text_str = prompt + text_str
 
         if suffix is not None:
-            text = text + suffix
+            text_str = text_str + suffix
 
         logprobs_or_none: Optional[CompletionLogprobs] = None
         if logprobs is not None:
@@ -493,7 +554,7 @@ class Llama:
             "model": self.model_path,
             "choices": [
                 {
-                    "text": text,
+                    "text": text_str,
                     "index": 0,
                     "logprobs": logprobs_or_none,
                     "finish_reason": finish_reason,

@@ -35,6 +35,7 @@ class Settings(BaseSettings):
     embedding: bool = True
     last_n_tokens_size: int = 64
     logits_all: bool = False
+    cache: bool = False  # WARNING: This is an experimental feature
 
 
 app = FastAPI(
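For reference, pydantic's `BaseSettings` matches environment variables to field names case-insensitively by default, so the new flag can presumably be flipped on without code changes when launching the server (a self-contained sketch using a demo class, not the server's actual `Settings`):

import os
from pydantic import BaseSettings  # pydantic v1 style, current at the time of this commit

class DemoSettings(BaseSettings):
    cache: bool = False  # mirrors the field added above

os.environ["CACHE"] = "true"
assert DemoSettings().cache is True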
@@ -60,6 +61,9 @@ llama = llama_cpp.Llama(
     n_ctx=settings.n_ctx,
     last_n_tokens_size=settings.last_n_tokens_size,
 )
+if settings.cache:
+    cache = llama_cpp.LlamaCache()
+    llama.set_cache(cache)
 llama_lock = Lock()
 
 
@@ -68,7 +72,6 @@ def get_llama():
         yield llama
 
 
-
 class CreateCompletionRequest(BaseModel):
     prompt: Union[str, List[str]]
     suffix: Optional[str] = Field(None)