Bugfix: Check cache keys as prefix to prompt tokens
parent b75fa96bf7
commit d484c5634e
1 changed file with 26 additions and 5 deletions
@@ -4,7 +4,7 @@ import uuid
 import time
 import math
 import multiprocessing
-from typing import List, Optional, Union, Generator, Sequence, Iterator, Deque
+from typing import List, Optional, Union, Generator, Sequence, Iterator, Deque, Tuple
 from collections import deque
 
 from . import llama_cpp
@@ -15,15 +15,34 @@ class LlamaCache:
     """Cache for a llama.cpp model."""
 
     def __init__(self):
-        self.cache_state: Dict[Sequence[llama_cpp.llama_token], "LlamaState"] = dict()
+        self.cache_state: Dict[Tuple[llama_cpp.llama_token, ...], "LlamaState"] = dict()
 
+    def _sorted_keys(self) -> List[Tuple[llama_cpp.llama_token, ...]]:
+        return [
+            key
+            for _, key in sorted(
+                ((len(key), key) for key in self.cache_state.keys()), reverse=True
+            )
+        ]
+
+    def _find_key(
+        self, key: Tuple[llama_cpp.llama_token, ...]
+    ) -> Optional[Tuple[llama_cpp.llama_token, ...]]:
+        for k in self._sorted_keys():
+            if key[: len(k)] == k:
+                return k
+        return None
+
     def __getitem__(
         self, key: Sequence[llama_cpp.llama_token]
     ) -> Optional["LlamaState"]:
-        return self.cache_state.get(tuple(key), None)
+        _key = self._find_key(tuple(key))
+        if _key is None:
+            return None
+        return self.cache_state[_key]
 
     def __contains__(self, key: Sequence[llama_cpp.llama_token]) -> bool:
-        return tuple(key) in self.cache_state
+        return self._find_key(tuple(key)) is not None
 
     def __setitem__(self, key: Sequence[llama_cpp.llama_token], value: "LlamaState"):
         self.cache_state = dict()  # NOTE: Currently limit to one cache entry.
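For context, a minimal sketch (not part of the commit) of how the patched LlamaCache resolves lookups: a stored key now matches any later key that it is a prefix of, with the longest stored key checked first. The string value is a hypothetical stand-in for a real LlamaState, which would normally come from Llama.save_state().

from llama_cpp.llama import LlamaCache

cache = LlamaCache()
cache[[1, 2, 3]] = "saved-state"      # placeholder for a LlamaState object

print([1, 2, 3, 4, 5] in cache)       # True: stored key [1, 2, 3] is a prefix of it
print(cache[[1, 2, 3, 4, 5]])         # "saved-state", resolved via _find_key
print([2, 3, 4] in cache)             # False: no stored key is a prefix of it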
@@ -295,7 +314,7 @@ class Llama:
         if (
             reset
             and len(self.eval_tokens) > 0
-            and self.eval_tokens == tokens[: len(self.eval_tokens)]
+            and tuple(self.eval_tokens) == tuple(tokens[: len(self.eval_tokens)])
         ):
             if self.verbose:
                 print("generate cache hit", file=sys.stderr)
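The tuple() conversion above is the core of the fix for the cache-hit check: self.eval_tokens is a collections.deque, and a deque never compares equal to a plain list slice, so the old comparison was always False. A small standalone illustration (not from the commit), with hypothetical token ids:

from collections import deque

eval_tokens = deque([15043, 3186, 29991])   # hypothetical token ids
tokens = [15043, 3186, 29991, 29889]

# Old check: deque vs. list slice is always False, so the hit path never ran.
print(eval_tokens == tokens[: len(eval_tokens)])                  # False

# New check: compare as tuples, True when the evaluated tokens are a prefix.
print(tuple(eval_tokens) == tuple(tokens[: len(eval_tokens)]))    # True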
@@ -438,6 +457,8 @@ class Llama:
 
             if self.cache and len(completion_tokens) == 0:
                 if prompt_tokens not in self.cache:
+                    if self.verbose:
+                        print("cache miss", file=sys.stderr)
                     self.cache[prompt_tokens] = self.save_state()
 
             completion_tokens.append(token)
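Taken together, the change lets a completion whose prompt extends a previously cached prompt reuse the saved model state. A hedged usage sketch, assuming the Llama.set_cache() helper in this version of llama_cpp/llama.py and a placeholder model path:

from llama_cpp import Llama, LlamaCache

llm = Llama(model_path="./models/ggml-model.bin", verbose=True)  # hypothetical path
llm.set_cache(LlamaCache())

# First call: the prompt is not cached yet, so "cache miss" is printed and the
# model state is saved under the prompt tokens.
llm.create_completion("Q: Name the planets in the solar system. A:", max_tokens=16)

# Second call: its prompt tokens start with the cached key, so the prefix
# lookup in LlamaCache finds the saved state instead of missing the cache.
llm.create_completion(
    "Q: Name the planets in the solar system. A: Mercury, Venus,", max_tokens=16
)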