Bugfix: Check cache keys as prefix to prompt tokens
parent b75fa96bf7
commit d484c5634e
1 changed file with 26 additions and 5 deletions
@@ -4,7 +4,7 @@ import uuid
 import time
 import math
 import multiprocessing
-from typing import List, Optional, Union, Generator, Sequence, Iterator, Deque
+from typing import List, Optional, Union, Generator, Sequence, Iterator, Deque, Tuple
 from collections import deque
 
 from . import llama_cpp
@@ -15,15 +15,34 @@ class LlamaCache:
     """Cache for a llama.cpp model."""
 
     def __init__(self):
-        self.cache_state: Dict[Sequence[llama_cpp.llama_token], "LlamaState"] = dict()
+        self.cache_state: Dict[Tuple[llama_cpp.llama_token, ...], "LlamaState"] = dict()
+
+    def _sorted_keys(self) -> List[Tuple[llama_cpp.llama_token, ...]]:
+        return [
+            key
+            for _, key in sorted(
+                ((len(key), key) for key in self.cache_state.keys()), reverse=True
+            )
+        ]
+
+    def _find_key(
+        self, key: Tuple[llama_cpp.llama_token, ...]
+    ) -> Optional[Tuple[llama_cpp.llama_token, ...]]:
+        for k in self._sorted_keys():
+            if key[: len(k)] == k:
+                return k
+        return None
 
     def __getitem__(
         self, key: Sequence[llama_cpp.llama_token]
     ) -> Optional["LlamaState"]:
-        return self.cache_state.get(tuple(key), None)
+        _key = self._find_key(tuple(key))
+        if _key is None:
+            return None
+        return self.cache_state[_key]
 
     def __contains__(self, key: Sequence[llama_cpp.llama_token]) -> bool:
-        return tuple(key) in self.cache_state
+        return self._find_key(tuple(key)) is not None
 
     def __setitem__(self, key: Sequence[llama_cpp.llama_token], value: "LlamaState"):
         self.cache_state = dict()  # NOTE: Currently limit to one cache entry.
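Not part of the commit — a minimal sketch of the lookup that the new _sorted_keys/_find_key pair implements, with plain ints standing in for llama_cpp.llama_token and a string standing in for LlamaState (both stand-ins are assumptions for illustration). Cached keys are tried longest-first, and the first key that is a prefix of the query wins, so a prompt that extends a previously cached prompt still resolves to the saved state:

from typing import Dict, Optional, Tuple

# Stand-ins: int token ids instead of llama_cpp.llama_token, str instead of LlamaState.
cache_state: Dict[Tuple[int, ...], str] = {(1, 2, 3): "state saved for prompt (1, 2, 3)"}

def find_key(key: Tuple[int, ...]) -> Optional[Tuple[int, ...]]:
    # Same idea as LlamaCache._find_key: try the longest cached key first and
    # return the first one that is a prefix of the query key.
    for k in sorted(cache_state.keys(), key=len, reverse=True):
        if key[: len(k)] == k:
            return k
    return None

print(find_key((1, 2, 3, 4, 5)))  # (1, 2, 3): an extended prompt is a prefix hit
print(find_key((9, 9)))           # None: unrelated prompt, cache miss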
@@ -295,7 +314,7 @@ class Llama:
         if (
             reset
             and len(self.eval_tokens) > 0
-            and self.eval_tokens == tokens[: len(self.eval_tokens)]
+            and tuple(self.eval_tokens) == tuple(tokens[: len(self.eval_tokens)])
         ):
             if self.verbose:
                 print("generate cache hit", file=sys.stderr)
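Not part of the commit — why the tuple() conversions matter: self.eval_tokens is a deque (hence the Deque/deque imports above), and in Python a deque never compares equal to a list, so the old comparison against the list slice was always False and this cache-hit branch could not be taken. A quick check:

from collections import deque

eval_tokens = deque([1, 2, 3])  # stand-in for self.eval_tokens
tokens = [1, 2, 3, 4, 5]        # stand-in for the incoming prompt tokens

print(eval_tokens == tokens[: len(eval_tokens)])                # False: deque vs. list
print(tuple(eval_tokens) == tuple(tokens[: len(eval_tokens)]))  # True: element-wise compare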
@@ -438,6 +457,8 @@ class Llama:
 
             if self.cache and len(completion_tokens) == 0:
                 if prompt_tokens not in self.cache:
                     if self.verbose:
                         print("cache miss", file=sys.stderr)
+                    self.cache[prompt_tokens] = self.save_state()
+
             completion_tokens.append(token)