Add cache implementation using llama state

Andrei Betlen 2023-04-24 19:54:41 -04:00
parent 2c359a28ff
commit cbe95bbb75


@@ -12,12 +12,22 @@ from .llama_types import *
 class LlamaCache:
-    """Cache for a llama.cpp model.
-
-    NOTE: This implementation currently only tells the Llama class to avoid reprocessing bytes and continue from the last
-    completion. It does not actually cache the results."""
-
-    pass
+    """Cache for a llama.cpp model."""
+
+    def __init__(self):
+        self.cache_state: Dict[Sequence[llama_cpp.llama_token], "LlamaState"] = dict()
+
+    def __getitem__(
+        self, key: Sequence[llama_cpp.llama_token]
+    ) -> Optional["LlamaState"]:
+        return self.cache_state.get(tuple(key), None)
+
+    def __contains__(self, key: Sequence[llama_cpp.llama_token]) -> bool:
+        return tuple(key) in self.cache_state
+
+    def __setitem__(self, key: Sequence[llama_cpp.llama_token], value: "LlamaState"):
+        self.cache_state = dict()  # NOTE: Currently limit to one cache entry.
+        self.cache_state[tuple(key)] = value
 
 
 class LlamaState:
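
For orientation, a minimal sketch of how the new class behaves, assuming the LlamaCache defined above is importable and using made-up integer token ids in place of real llama_cpp.llama_token values: keys are whole token sequences, stored as tuples so they can act as dict keys, and __setitem__ deliberately evicts any previous entry.

    cache = LlamaCache()
    prompt_tokens = [1, 15043, 3186]   # stand-in for a tokenized prompt
    state = object()                   # stand-in for a real LlamaState snapshot

    cache[prompt_tokens] = state       # stored under tuple(prompt_tokens)
    assert prompt_tokens in cache      # __contains__ also converts the key to a tuple
    assert cache[prompt_tokens] is state

    cache[[2, 9906]] = object()        # per the NOTE above, only one entry is kept,
    assert prompt_tokens not in cache  # so the earlier snapshot is gone
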
@@ -100,13 +110,7 @@ class Llama:
         self.eval_tokens: Deque[llama_cpp.llama_token] = deque(maxlen=n_ctx)
         self.eval_logits: Deque[List[float]] = deque(maxlen=n_ctx)
 
-        ### HACK: This is a hack to work around the fact that the llama.cpp API does not yet support
-        ### saving and restoring state, this allows us to continue a completion if the last
-        ### completion_bytes is a prefix to the prompt passed in. However this is actually incorrect
-        ### because it does not take into account stop tokens which have been processed by the model.
-        self._completion_bytes: List[bytes] = []
-        self._cache: Optional[LlamaCache] = None
-        ###
+        self.cache: Optional[LlamaCache] = None
 
         self.n_threads = n_threads or max(multiprocessing.cpu_count() // 2, 1)
@ -182,7 +186,7 @@ class Llama:
Args:
cache: The cache to set.
"""
self._cache = cache
self.cache = cache
def reset(self):
"""Reset the model state."""
@@ -287,10 +291,9 @@ class Llama:
             The generated tokens.
         """
         assert self.ctx is not None
-        ### HACK
+
         if (
             reset
-            and self._cache
             and len(self.eval_tokens) > 0
             and self.eval_tokens == tokens[: len(self.eval_tokens)]
         ):
@@ -298,7 +301,7 @@ class Llama:
                 print("generate cache hit", file=sys.stderr)
             reset = False
             tokens = tokens[len(self.eval_tokens) :]
-        ###
+
         if reset:
             self.reset()
 
         while True:
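
The two hunks above keep the prefix-reuse logic in generate() but drop the _cache guard: whenever the tokens already evaluated form a prefix of the incoming token sequence, the reset is skipped and only the unseen suffix is fed to the model. A toy illustration of that check, with made-up token ids and a plain list standing in for the eval_tokens deque:

    eval_tokens = [1, 15043, 3186]         # already evaluated by the model
    tokens = [1, 15043, 3186, 29892, 727]  # tokens for the new request

    if len(eval_tokens) > 0 and eval_tokens == tokens[: len(eval_tokens)]:
        # "generate cache hit": keep the existing context ...
        tokens = tokens[len(eval_tokens):]  # ... and evaluate only the suffix
    print(tokens)  # [29892, 727]
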
@@ -415,20 +418,10 @@ class Llama:
                 "logprobs is not supported for models created with logits_all=False"
             )
 
-        ### HACK
-        reset: bool = True
-        _prompt: bytes = prompt.encode("utf-8")
-        _completion: bytes = b"".join(self._completion_bytes)
-        if len(_completion) and self._cache and _prompt.startswith(_completion):
-            if self.verbose:
-                print("completion cache hit", file=sys.stderr)
-            reset = False
-            _prompt = _prompt[len(_completion) :]
-            prompt_tokens = self.tokenize(b" " + _prompt)
-            self._completion_bytes.append(_prompt)
-        else:
-            self._completion_bytes = [prompt.encode("utf-8")]
-        ###
+        if self.cache and prompt_tokens in self.cache:
+            if self.verbose:
+                print("cache hit", file=sys.stderr)
+            self.load_state(self.cache[prompt_tokens])
 
         finish_reason = "length"
         for token in self.generate(
@@ -437,12 +430,16 @@ class Llama:
             top_p=top_p,
             temp=temperature,
             repeat_penalty=repeat_penalty,
-            reset=reset,
         ):
             if token == llama_cpp.llama_token_eos():
                 text = self.detokenize(completion_tokens)
                 finish_reason = "stop"
                 break
 
+            if self.cache and len(completion_tokens) == 0:
+                if prompt_tokens not in self.cache:
+                    self.cache[prompt_tokens] = self.save_state()
+
             completion_tokens.append(token)
+
             all_text = self.detokenize(completion_tokens)
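
Taken together, the two completion hunks replace the byte-prefix bookkeeping with real llama.cpp state snapshots: a prompt whose tokens are already cached restores the saved state with load_state before generation, and an uncached prompt is snapshotted with save_state as soon as the first token comes back, i.e. right after the prompt has been evaluated. A condensed, hypothetical sketch of that flow; run_completion and the sampling values are illustrative, not part of the library:

    def run_completion(llm, prompt_tokens):
        if llm.cache and prompt_tokens in llm.cache:
            llm.load_state(llm.cache[prompt_tokens])  # restore kv cache + eval state

        completion_tokens = []
        for token in llm.generate(
            prompt_tokens, top_k=40, top_p=0.95, temp=0.8, repeat_penalty=1.1
        ):
            if llm.cache and len(completion_tokens) == 0:
                if prompt_tokens not in llm.cache:
                    # snapshot taken once, after the prompt has been evaluated
                    llm.cache[prompt_tokens] = llm.save_state()
            completion_tokens.append(token)
            yield token
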
@@ -467,9 +464,6 @@ class Llama:
                            break
                text = all_text[: len(all_text) - longest]
                returned_characters += len(text[start:])
-                ### HACK
-                self._completion_bytes.append(text[start:])
-                ###
                yield {
                    "id": completion_id,
                    "object": "text_completion",
@@ -491,9 +485,6 @@ class Llama:
             break
 
         if stream:
-            ### HACK
-            self._completion_bytes.append(text[returned_characters:])
-            ###
             yield {
                 "id": completion_id,
                 "object": "text_completion",
@@ -510,9 +501,6 @@ class Llama:
             }
             return
 
-        ### HACK
-        self._completion_bytes.append(text)
-        ###
         text_str = text.decode("utf-8")
 
         if echo:
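
End to end, the cache now acts as a transparent optimization for repeated prompts. A minimal usage sketch, assuming the model path is a placeholder and verbose output is left on so the cache messages show up on stderr: the first call evaluates the prompt and stores a state snapshot, and the second call with the same prompt restores it instead of re-evaluating.

    from llama_cpp import Llama, LlamaCache

    llm = Llama(model_path="./models/7B/ggml-model.bin", verbose=True)  # placeholder path
    llm.set_cache(LlamaCache())

    prompt = "Q: Name the planets in the solar system. A:"
    first = llm(prompt, max_tokens=1)    # prompt evaluated, state saved into the cache
    second = llm(prompt, max_tokens=32)  # "cache hit" on stderr, saved state restored
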