Add in-memory longest prefix cache. Closes #158

Author: Andrei Betlen
Date:   2023-05-07 19:31:26 -04:00
parent 8dfde63255
commit 0e94a70de1

@@ -5,7 +5,7 @@ import time
 import math
 import multiprocessing
 from typing import List, Optional, Union, Generator, Sequence, Iterator, Deque, Tuple
-from collections import deque
+from collections import deque, OrderedDict
 from . import llama_cpp
 from .llama_types import *
@@ -14,37 +14,50 @@ from .llama_types import *
 class LlamaCache:
     """Cache for a llama.cpp model."""

-    def __init__(self):
-        self.cache_state: Dict[Tuple[llama_cpp.llama_token, ...], "LlamaState"] = dict()
+    def __init__(self, capacity_bytes: int = (2 << 30)):
+        self.cache_state: OrderedDict[
+            Tuple[llama_cpp.llama_token, ...], "LlamaState"
+        ] = OrderedDict()
+        self.capacity_bytes = capacity_bytes

-    def _sorted_keys(self) -> List[Tuple[llama_cpp.llama_token, ...]]:
-        return [
-            key
-            for _, key in sorted(
-                ((len(key), key) for key in self.cache_state.keys()), reverse=True
-            )
-        ]
+    @property
+    def cache_size(self):
+        return sum([state.llama_state_size for state in self.cache_state.values()])

-    def _find_key(
-        self, key: Tuple[llama_cpp.llama_token, ...]
+    def _find_longest_prefix_key(
+        self,
+        key: Tuple[llama_cpp.llama_token, ...],
     ) -> Optional[Tuple[llama_cpp.llama_token, ...]]:
-        for k in self._sorted_keys():
-            if key[: len(k)] == k:
-                return k
-        return None
+        min_len = 0
+        min_key = None
+        keys = (
+            (k, Llama.longest_token_prefix(k, key)) for k in self.cache_state.keys()
+        )
+        for k, prefix_len in keys:
+            if prefix_len > min_len:
+                min_len = prefix_len
+                min_key = k
+        return min_key

     def __getitem__(self, key: Sequence[llama_cpp.llama_token]) -> "LlamaState":
-        _key = self._find_key(tuple(key))
+        key = tuple(key)
+        _key = self._find_longest_prefix_key(key)
         if _key is None:
-            raise KeyError(f"Key not found: {key}")
-        return self.cache_state[_key]
+            raise KeyError(f"Key not found")
+        value = self.cache_state[_key]
+        self.cache_state.move_to_end(_key)
+        return value

     def __contains__(self, key: Sequence[llama_cpp.llama_token]) -> bool:
-        return self._find_key(tuple(key)) is not None
+        return self._find_longest_prefix_key(tuple(key)) is not None

     def __setitem__(self, key: Sequence[llama_cpp.llama_token], value: "LlamaState"):
-        self.cache_state = dict()  # NOTE: Currently limit to one cache entry.
-        self.cache_state[tuple(key)] = value
+        key = tuple(key)
+        if key in self.cache_state:
+            del self.cache_state[key]
+        self.cache_state[key] = value
+        while self.cache_size > self.capacity_bytes:
+            self.cache_state.popitem(last=False)


 class LlamaState:
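
Not part of the diff: a minimal sketch of how the reworked cache behaves, assuming LlamaCache can be imported from the package (e.g. from llama_cpp.llama). DummyState and the byte sizes are invented stand-ins; the only LlamaState field the cache reads is llama_state_size.

# Illustrative sketch only (not part of the commit).
# from llama_cpp.llama import LlamaCache   # assumed import path

class DummyState:
    # Stand-in for LlamaState: the cache only reads llama_state_size.
    def __init__(self, llama_state_size: int):
        self.llama_state_size = llama_state_size

cache = LlamaCache(capacity_bytes=2 << 20)  # 2 MiB cap instead of the 2 GiB default

cache[[1, 2, 3, 4, 5]] = DummyState(llama_state_size=1 << 20)   # ~1 MiB entry

# Lookups match on the longest shared token prefix, not exact equality,
# and move the hit to the back of the OrderedDict (most recently used).
state = cache[[1, 2, 3, 4, 5, 6, 7]]        # returns the entry keyed by (1, 2, 3, 4, 5)

# Inserting past capacity_bytes evicts from the front of the OrderedDict.
cache[[9, 9, 9]] = DummyState(llama_state_size=(3 << 20) // 2)  # ~1.5 MiB entry
print([1, 2, 3, 4, 5] in cache)             # False: the oldest entry was evicted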
@@ -53,7 +66,7 @@ class LlamaState:
         eval_tokens: Deque[llama_cpp.llama_token],
         eval_logits: Deque[List[float]],
         llama_state,  # type: llama_cpp.Array[llama_cpp.c_uint8]
-        llama_state_size: llama_cpp.c_size_t,
+        llama_state_size: int,
     ):
         self.eval_tokens = eval_tokens
         self.eval_logits = eval_logits
@@ -526,10 +539,22 @@ class Llama:
                 "logprobs is not supported for models created with logits_all=False"
             )

-        if self.cache and prompt_tokens in self.cache:
-            if self.verbose:
-                print("Llama._create_completion: cache hit", file=sys.stderr)
-            self.load_state(self.cache[prompt_tokens])
+        if self.cache:
+            try:
+                cache_item = self.cache[prompt_tokens]
+                cache_prefix_len = Llama.longest_token_prefix(
+                    cache_item.eval_tokens, prompt_tokens
+                )
+                eval_prefix_len = Llama.longest_token_prefix(
+                    self.eval_tokens, prompt_tokens
+                )
+                if cache_prefix_len > eval_prefix_len:
+                    self.load_state(cache_item)
+                    if self.verbose:
+                        print("Llama._create_completion: cache hit", file=sys.stderr)
+            except KeyError:
+                if self.verbose:
+                    print("Llama._create_completion: cache miss", file=sys.stderr)

         finish_reason = "length"
         multibyte_fix = 0
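
Not part of the diff: a toy illustration of the comparison in this hunk, with invented token values. The cached state is restored only when it covers a longer prefix of the new prompt than the tokens already evaluated in the live context.

# Invented token sequences, for illustration only.
prompt_tokens = [1, 2, 3, 4, 5, 6]
cached = [1, 2, 3, 4]        # eval_tokens stored in the cached LlamaState
current = [1, 2, 9, 9]       # eval_tokens of the live context

cache_prefix_len = Llama.longest_token_prefix(cached, prompt_tokens)   # 4
eval_prefix_len = Llama.longest_token_prefix(current, prompt_tokens)   # 2

# 4 > 2: loading the cached state skips re-evaluating two more prompt tokens.
should_restore = cache_prefix_len > eval_prefix_len                    # True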
@@ -1004,3 +1029,15 @@ class Llama:
         exps = [math.exp(float(x)) for x in logits]
         sum_exps = sum(exps)
         return [math.log(x / sum_exps) for x in exps]
+
+    @staticmethod
+    def longest_token_prefix(
+        a: Sequence[llama_cpp.llama_token], b: Sequence[llama_cpp.llama_token]
+    ):
+        longest_prefix = 0
+        for _a, _b in zip(a, b):
+            if _a == _b:
+                longest_prefix += 1
+            else:
+                break
+        return longest_prefix
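
Not part of the diff: a quick sanity check of the new helper, using invented token values.

assert Llama.longest_token_prefix([1, 2, 3], [1, 2, 3, 4]) == 3  # full match of the shorter side
assert Llama.longest_token_prefix([1, 2, 3], [1, 9, 3]) == 1     # stops at the first mismatch
assert Llama.longest_token_prefix([], [1, 2, 3]) == 0            # empty sequence shares nothing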