Bugfix: Check cache keys as prefix to prompt tokens

Andrei Betlen 2023-04-24 22:18:54 -04:00
parent b75fa96bf7
commit d484c5634e


@@ -4,7 +4,7 @@ import uuid
 import time
 import math
 import multiprocessing
-from typing import List, Optional, Union, Generator, Sequence, Iterator, Deque
+from typing import List, Optional, Union, Generator, Sequence, Iterator, Deque, Tuple
 from collections import deque
 
 from . import llama_cpp
@@ -15,15 +15,34 @@ class LlamaCache:
     """Cache for a llama.cpp model."""
 
     def __init__(self):
-        self.cache_state: Dict[Sequence[llama_cpp.llama_token], "LlamaState"] = dict()
+        self.cache_state: Dict[Tuple[llama_cpp.llama_token, ...], "LlamaState"] = dict()
+
+    def _sorted_keys(self) -> List[Tuple[llama_cpp.llama_token, ...]]:
+        return [
+            key
+            for _, key in sorted(
+                ((len(key), key) for key in self.cache_state.keys()), reverse=True
+            )
+        ]
+
+    def _find_key(
+        self, key: Tuple[llama_cpp.llama_token, ...]
+    ) -> Optional[Tuple[llama_cpp.llama_token, ...]]:
+        for k in self._sorted_keys():
+            if key[: len(k)] == k:
+                return k
+        return None
 
     def __getitem__(
         self, key: Sequence[llama_cpp.llama_token]
     ) -> Optional["LlamaState"]:
-        return self.cache_state.get(tuple(key), None)
+        _key = self._find_key(tuple(key))
+        if _key is None:
+            return None
+        return self.cache_state[_key]
 
     def __contains__(self, key: Sequence[llama_cpp.llama_token]) -> bool:
-        return tuple(key) in self.cache_state
+        return self._find_key(tuple(key)) is not None
 
     def __setitem__(self, key: Sequence[llama_cpp.llama_token], value: "LlamaState"):
         self.cache_state = dict()  # NOTE: Currently limit to one cache entry.
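
The lookup above can be exercised on its own. Below is a minimal sketch of the same prefix-matching idea, using plain int token ids in place of llama_cpp.llama_token and a throwaway string in place of LlamaState; the class and variable names here are illustrative stand-ins, not part of the library.

from typing import Dict, List, Optional, Tuple

class PrefixCache:
    """Minimal sketch of the prefix-matching lookup added in this commit."""

    def __init__(self) -> None:
        # Keys are full token sequences; the values stand in for LlamaState.
        self.cache_state: Dict[Tuple[int, ...], object] = {}

    def _sorted_keys(self) -> List[Tuple[int, ...]]:
        # Longest keys first, so the longest cached prefix wins.
        return sorted(self.cache_state.keys(), key=len, reverse=True)

    def _find_key(self, key: Tuple[int, ...]) -> Optional[Tuple[int, ...]]:
        for k in self._sorted_keys():
            if key[: len(k)] == k:  # cached key is a prefix of the query
                return k
        return None

    def __contains__(self, key) -> bool:
        return self._find_key(tuple(key)) is not None

cache = PrefixCache()
cache.cache_state[(1, 2, 3)] = "saved state for prompt (1, 2, 3)"

assert [1, 2, 3, 4, 5] in cache   # extends the cached prompt -> hit
assert [9, 9] not in cache        # unrelated prompt -> miss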
@@ -295,7 +314,7 @@ class Llama:
         if (
             reset
            and len(self.eval_tokens) > 0
-            and self.eval_tokens == tokens[: len(self.eval_tokens)]
+            and tuple(self.eval_tokens) == tuple(tokens[: len(self.eval_tokens)])
         ):
             if self.verbose:
                 print("generate cache hit", file=sys.stderr)
@@ -438,6 +457,8 @@ class Llama:
 
             if self.cache and len(completion_tokens) == 0:
                 if prompt_tokens not in self.cache:
+                    if self.verbose:
+                        print("cache miss", file=sys.stderr)
                     self.cache[prompt_tokens] = self.save_state()
 
             completion_tokens.append(token)
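
Combined with the prefix-aware __contains__ above, this hunk means save_state() runs (and "cache miss" is printed) only when no cached prompt is a prefix of the current one. A rough sketch of that control flow, using a plain dict and a placeholder object rather than the real LlamaCache and save_state():

import sys
from typing import Dict, Tuple

cache: Dict[Tuple[int, ...], object] = {}  # stand-in for self.cache

def has_prefix_key(prompt: Tuple[int, ...]) -> bool:
    # Same prefix test as LlamaCache._find_key, inlined for the sketch.
    return any(prompt[: len(k)] == k for k in cache)

def maybe_cache(prompt_tokens: Tuple[int, ...], verbose: bool = True) -> None:
    if not has_prefix_key(prompt_tokens):
        if verbose:
            print("cache miss", file=sys.stderr)
        cache[prompt_tokens] = object()  # stand-in for self.save_state()

maybe_cache((1, 2, 3))     # new prompt: logs "cache miss" and stores state
maybe_cache((1, 2, 3, 4))  # extends a cached prompt: no miss, nothing stored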