Bugfix: Check cache keys as prefix to prompt tokens

Andrei Betlen 2023-04-24 22:18:54 -04:00
parent b75fa96bf7
commit d484c5634e


@@ -4,7 +4,7 @@ import uuid
 import time
 import math
 import multiprocessing
-from typing import List, Optional, Union, Generator, Sequence, Iterator, Deque
+from typing import List, Optional, Union, Generator, Sequence, Iterator, Deque, Tuple
 from collections import deque
 
 from . import llama_cpp
@@ -15,15 +15,34 @@ class LlamaCache:
     """Cache for a llama.cpp model."""
 
     def __init__(self):
-        self.cache_state: Dict[Sequence[llama_cpp.llama_token], "LlamaState"] = dict()
+        self.cache_state: Dict[Tuple[llama_cpp.llama_token, ...], "LlamaState"] = dict()
+
+    def _sorted_keys(self) -> List[Tuple[llama_cpp.llama_token, ...]]:
+        return [
+            key
+            for _, key in sorted(
+                ((len(key), key) for key in self.cache_state.keys()), reverse=True
+            )
+        ]
+
+    def _find_key(
+        self, key: Tuple[llama_cpp.llama_token, ...]
+    ) -> Optional[Tuple[llama_cpp.llama_token, ...]]:
+        for k in self._sorted_keys():
+            if key[: len(k)] == k:
+                return k
+        return None
 
     def __getitem__(
         self, key: Sequence[llama_cpp.llama_token]
     ) -> Optional["LlamaState"]:
-        return self.cache_state.get(tuple(key), None)
+        _key = self._find_key(tuple(key))
+        if _key is None:
+            return None
+        return self.cache_state[_key]
 
     def __contains__(self, key: Sequence[llama_cpp.llama_token]) -> bool:
-        return tuple(key) in self.cache_state
+        return self._find_key(tuple(key)) is not None
 
     def __setitem__(self, key: Sequence[llama_cpp.llama_token], value: "LlamaState"):
         self.cache_state = dict()  # NOTE: Currently limit to one cache entry.
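
The lookup added above can be illustrated with a minimal, self-contained sketch of the same prefix matching (plain ints stand in for llama_cpp.llama_token, and find_prefix_key is a hypothetical helper written for this note, not part of the commit):

    from typing import Dict, Optional, Tuple

    Token = int  # stand-in for llama_cpp.llama_token (an integer token id)

    def find_prefix_key(
        cache: Dict[Tuple[Token, ...], str], key: Tuple[Token, ...]
    ) -> Optional[Tuple[Token, ...]]:
        # Scan the longest cached keys first so the most specific prefix wins,
        # mirroring _sorted_keys()/_find_key() above.
        for k in sorted(cache.keys(), key=len, reverse=True):
            if key[: len(k)] == k:
                return k
        return None

    cache = {(1, 2, 3): "state saved after evaluating tokens 1 2 3"}
    print(find_prefix_key(cache, (1, 2, 3, 4, 5)))  # (1, 2, 3): cached key is a prefix, hit
    print(find_prefix_key(cache, (9, 9)))           # None: no cached key is a prefix, miss

Because __setitem__ clears the dict before storing (see the NOTE above), there is at most one key to scan today, but the lookup is already written to handle several.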
@@ -295,7 +314,7 @@ class Llama:
         if (
             reset
             and len(self.eval_tokens) > 0
-            and self.eval_tokens == tokens[: len(self.eval_tokens)]
+            and tuple(self.eval_tokens) == tuple(tokens[: len(self.eval_tokens)])
         ):
             if self.verbose:
                 print("generate cache hit", file=sys.stderr)
@@ -438,6 +457,8 @@ class Llama:
 
             if self.cache and len(completion_tokens) == 0:
                 if prompt_tokens not in self.cache:
+                    if self.verbose:
+                        print("cache miss", file=sys.stderr)
                     self.cache[prompt_tokens] = self.save_state()
 
             completion_tokens.append(token)
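
Taken together, the change lets a state saved for one prompt be reused by any later prompt whose tokens begin with the cached key. A rough usage sketch, not a verified transcript: the model path is a placeholder, and the cache is attached by assigning the self.cache attribute checked in the hunk above.

    from llama_cpp import Llama, LlamaCache

    llm = Llama(model_path="./models/ggml-model.bin", verbose=True)  # placeholder path
    llm.cache = LlamaCache()  # the self.cache attribute used in the diff above

    prompt = "Q: Name the planets in the solar system. A:"
    llm(prompt, max_tokens=16)  # first call: "cache miss" is printed and the state is saved

    # If the extended prompt tokenizes to a sequence that starts with the cached key,
    # the prefix lookup finds the saved state and the "generate cache hit" path can
    # skip re-evaluating the shared tokens.
    llm(prompt + " 1. Mercury 2.", max_tokens=16)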