Add cache implementation using llama state
parent 2c359a28ff · commit cbe95bbb75
1 changed file with 26 additions and 38 deletions
@@ -12,12 +12,22 @@ from .llama_types import *
 class LlamaCache:
-    """Cache for a llama.cpp model.
+    """Cache for a llama.cpp model."""
 
-    NOTE: This implementation currently only tells the Llama class to avoid reprocessing bytes and continue from the last
-    completion. It does not actually cache the results."""
+    def __init__(self):
+        self.cache_state: Dict[Sequence[llama_cpp.llama_token], "LlamaState"] = dict()
 
-    pass
+    def __getitem__(
+        self, key: Sequence[llama_cpp.llama_token]
+    ) -> Optional["LlamaState"]:
+        return self.cache_state.get(tuple(key), None)
+
+    def __contains__(self, key: Sequence[llama_cpp.llama_token]) -> bool:
+        return tuple(key) in self.cache_state
+
+    def __setitem__(self, key: Sequence[llama_cpp.llama_token], value: "LlamaState"):
+        self.cache_state = dict()  # NOTE: Currently limit to one cache entry.
+        self.cache_state[tuple(key)] = value
 
 
 class LlamaState:
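Note: the new LlamaCache is a thin dict-like container keyed by token sequences (converted to tuples so they are hashable) and holding LlamaState snapshots, and __setitem__ deliberately clears the dict so only one entry is kept. A minimal sketch of those mapping semantics, assuming llama-cpp-python at this revision is importable; the object() placeholders stand in for real LlamaState values returned by Llama.save_state():

from llama_cpp import LlamaCache

cache = LlamaCache()
key_a = [1, 2, 3]                  # token ids; stored internally as tuple(key_a)
key_b = [1, 2, 3, 4]

cache[key_a] = object()            # placeholder for a LlamaState snapshot
assert key_a in cache              # __contains__ also converts the key to a tuple
assert cache[key_b] is None        # misses return None instead of raising KeyError

cache[key_b] = object()            # __setitem__ clears the dict first ...
assert key_a not in cache          # ... so only the newest entry survives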
@@ -100,13 +110,7 @@ class Llama:
         self.eval_tokens: Deque[llama_cpp.llama_token] = deque(maxlen=n_ctx)
         self.eval_logits: Deque[List[float]] = deque(maxlen=n_ctx)
 
-        ### HACK: This is a hack to work around the fact that the llama.cpp API does not yet support
-        ### saving and restoring state, this allows us to continue a completion if the last
-        ### completion_bytes is a prefix to the prompt passed in. However this is actually incorrect
-        ### because it does not take into account stop tokens which have been processed by the model.
-        self._completion_bytes: List[bytes] = []
-        self._cache: Optional[LlamaCache] = None
-        ###
+        self.cache: Optional[LlamaCache] = None
 
         self.n_threads = n_threads or max(multiprocessing.cpu_count() // 2, 1)
@@ -182,7 +186,7 @@ class Llama:
         Args:
             cache: The cache to set.
         """
-        self._cache = cache
+        self.cache = cache
 
     def reset(self):
         """Reset the model state."""
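Note: the setter above (assumed to be Llama.set_cache, whose docstring is shown) now simply stores the LlamaCache on self.cache; the hit/miss handling appears in the completion hunks further down. A hedged end-to-end sketch of how a caller would wire this up, assuming llama-cpp-python at this revision and a local GGML model at the illustrative path below:

from llama_cpp import Llama, LlamaCache

# assumption: path to a local quantized model file
llm = Llama(model_path="./models/ggml-model-q4_0.bin", verbose=True)
llm.set_cache(LlamaCache())

prompt = "Q: Name the planets in the solar system. A:"

# First call: prompt_tokens are not in the cache, so the prompt is evaluated
# and save_state() is stored under those tokens (see the completion hunks below).
print(llm(prompt, max_tokens=16)["choices"][0]["text"])

# Second call with the identical prompt: prompt_tokens are found in the cache,
# so load_state() restores the evaluated context instead of re-processing it.
print(llm(prompt, max_tokens=16)["choices"][0]["text"])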
@@ -287,10 +291,9 @@ class Llama:
             The generated tokens.
         """
         assert self.ctx is not None
-        ### HACK
+
         if (
             reset
-            and self._cache
             and len(self.eval_tokens) > 0
             and self.eval_tokens == tokens[: len(self.eval_tokens)]
         ):
@@ -298,7 +301,7 @@ class Llama:
                 print("generate cache hit", file=sys.stderr)
             reset = False
             tokens = tokens[len(self.eval_tokens) :]
-            ###
+
         if reset:
             self.reset()
         while True:
@@ -415,20 +418,10 @@ class Llama:
                 "logprobs is not supported for models created with logits_all=False"
             )
 
-        ### HACK
-        reset: bool = True
-        _prompt: bytes = prompt.encode("utf-8")
-        _completion: bytes = b"".join(self._completion_bytes)
-        if len(_completion) and self._cache and _prompt.startswith(_completion):
+        if self.cache and prompt_tokens in self.cache:
             if self.verbose:
-                print("completion cache hit", file=sys.stderr)
-            reset = False
-            _prompt = _prompt[len(_completion) :]
-            prompt_tokens = self.tokenize(b" " + _prompt)
-            self._completion_bytes.append(_prompt)
-        else:
-            self._completion_bytes = [prompt.encode("utf-8")]
-        ###
+                print("cache hit", file=sys.stderr)
+            self.load_state(self.cache[prompt_tokens])
 
         finish_reason = "length"
         for token in self.generate(
@@ -437,12 +430,16 @@ class Llama:
             top_p=top_p,
             temp=temperature,
             repeat_penalty=repeat_penalty,
-            reset=reset,
         ):
             if token == llama_cpp.llama_token_eos():
                 text = self.detokenize(completion_tokens)
                 finish_reason = "stop"
                 break
 
+            if self.cache and len(completion_tokens) == 0:
+                if prompt_tokens not in self.cache:
+                    self.cache[prompt_tokens] = self.save_state()
+
             completion_tokens.append(token)
 
             all_text = self.detokenize(completion_tokens)
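Note: the two completion hunks above are the whole caching story — on a miss the state is snapshotted with save_state() as soon as the first token arrives, and on a hit load_state() restores it before generation. A minimal sketch of that same snapshot/restore round trip without the cache wrapper, using save_state()/load_state() from this diff plus tokenize/eval/reset from the same class (the model path is an assumption):

from llama_cpp import Llama

llm = Llama(model_path="./models/ggml-model-q4_0.bin")  # assumption: local model

tokens = llm.tokenize(b" Q: What is the capital of France? A:")
llm.eval(tokens)            # evaluate the prompt once
state = llm.save_state()    # snapshot the evaluated context (a LlamaState)

llm.reset()                 # discard the current context ...
llm.load_state(state)       # ... and restore the snapshot without re-evaluating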
@@ -467,9 +464,6 @@ class Llama:
                             break
                 text = all_text[: len(all_text) - longest]
                 returned_characters += len(text[start:])
-                ### HACK
-                self._completion_bytes.append(text[start:])
-                ###
                 yield {
                     "id": completion_id,
                     "object": "text_completion",
@@ -491,9 +485,6 @@ class Llama:
                 break
 
         if stream:
-            ### HACK
-            self._completion_bytes.append(text[returned_characters:])
-            ###
             yield {
                 "id": completion_id,
                 "object": "text_completion",
@@ -510,9 +501,6 @@ class Llama:
             }
             return
 
-        ### HACK
-        self._completion_bytes.append(text)
-        ###
         text_str = text.decode("utf-8")
 
         if echo: