Add in-memory longest prefix cache. Closes #158

Andrei Betlen 2023-05-07 19:31:26 -04:00
parent 8dfde63255
commit 0e94a70de1


@@ -5,7 +5,7 @@ import time
 import math
 import multiprocessing
 from typing import List, Optional, Union, Generator, Sequence, Iterator, Deque, Tuple
-from collections import deque
+from collections import deque, OrderedDict
 from . import llama_cpp
 from .llama_types import *
@@ -14,37 +14,50 @@ from .llama_types import *
 class LlamaCache:
     """Cache for a llama.cpp model."""
 
-    def __init__(self):
-        self.cache_state: Dict[Tuple[llama_cpp.llama_token, ...], "LlamaState"] = dict()
+    def __init__(self, capacity_bytes: int = (2 << 30)):
+        self.cache_state: OrderedDict[
+            Tuple[llama_cpp.llama_token, ...], "LlamaState"
+        ] = OrderedDict()
+        self.capacity_bytes = capacity_bytes
 
-    def _sorted_keys(self) -> List[Tuple[llama_cpp.llama_token, ...]]:
-        return [
-            key
-            for _, key in sorted(
-                ((len(key), key) for key in self.cache_state.keys()), reverse=True
-            )
-        ]
+    @property
+    def cache_size(self):
+        return sum([state.llama_state_size for state in self.cache_state.values()])
 
-    def _find_key(
-        self, key: Tuple[llama_cpp.llama_token, ...]
+    def _find_longest_prefix_key(
+        self,
+        key: Tuple[llama_cpp.llama_token, ...],
     ) -> Optional[Tuple[llama_cpp.llama_token, ...]]:
-        for k in self._sorted_keys():
-            if key[: len(k)] == k:
-                return k
-        return None
+        min_len = 0
+        min_key = None
+        keys = (
+            (k, Llama.longest_token_prefix(k, key)) for k in self.cache_state.keys()
+        )
+        for k, prefix_len in keys:
+            if prefix_len > min_len:
+                min_len = prefix_len
+                min_key = k
+        return min_key
 
     def __getitem__(self, key: Sequence[llama_cpp.llama_token]) -> "LlamaState":
-        _key = self._find_key(tuple(key))
+        key = tuple(key)
+        _key = self._find_longest_prefix_key(key)
         if _key is None:
-            raise KeyError(f"Key not found: {key}")
-        return self.cache_state[_key]
+            raise KeyError(f"Key not found")
+        value = self.cache_state[_key]
+        self.cache_state.move_to_end(_key)
+        return value
 
     def __contains__(self, key: Sequence[llama_cpp.llama_token]) -> bool:
-        return self._find_key(tuple(key)) is not None
+        return self._find_longest_prefix_key(tuple(key)) is not None
 
     def __setitem__(self, key: Sequence[llama_cpp.llama_token], value: "LlamaState"):
-        self.cache_state = dict()  # NOTE: Currently limit to one cache entry.
-        self.cache_state[tuple(key)] = value
+        key = tuple(key)
+        if key in self.cache_state:
+            del self.cache_state[key]
+        self.cache_state[key] = value
+        while self.cache_size > self.capacity_bytes:
+            self.cache_state.popitem(last=False)
 
 
 class LlamaState:
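
Taken together, the new LlamaCache behaves as a byte-bounded LRU keyed by token sequences, with lookups matching the cached entry that shares the longest token prefix. A minimal sketch of that behavior (the FakeState stand-in and the tiny capacity are illustrative only, not part of the commit):

    from llama_cpp.llama import LlamaCache

    class FakeState:
        # Stand-in exposing only the attribute the cache reads for sizing.
        def __init__(self, size: int):
            self.llama_state_size = size

    cache = LlamaCache(capacity_bytes=100)
    cache[(1, 2, 3)] = FakeState(60)
    cache[(1, 2, 3, 4, 5)] = FakeState(60)  # total 120 > 100: oldest entry is evicted

    # Lookup matches the cached key sharing the longest prefix with the query.
    assert (1, 2, 3, 4, 5, 6, 7) in cache
    state = cache[(1, 2, 3, 4, 5, 6, 7)]    # returns the (1, 2, 3, 4, 5) entry
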
@@ -53,7 +66,7 @@ class LlamaState:
         eval_tokens: Deque[llama_cpp.llama_token],
         eval_logits: Deque[List[float]],
         llama_state,  # type: llama_cpp.Array[llama_cpp.c_uint8]
-        llama_state_size: llama_cpp.c_size_t,
+        llama_state_size: int,
     ):
         self.eval_tokens = eval_tokens
         self.eval_logits = eval_logits
@@ -526,10 +539,22 @@ class Llama:
                 "logprobs is not supported for models created with logits_all=False"
             )
 
-        if self.cache and prompt_tokens in self.cache:
-            if self.verbose:
-                print("Llama._create_completion: cache hit", file=sys.stderr)
-            self.load_state(self.cache[prompt_tokens])
+        if self.cache:
+            try:
+                cache_item = self.cache[prompt_tokens]
+                cache_prefix_len = Llama.longest_token_prefix(
+                    cache_item.eval_tokens, prompt_tokens
+                )
+                eval_prefix_len = Llama.longest_token_prefix(
+                    self.eval_tokens, prompt_tokens
+                )
+                if cache_prefix_len > eval_prefix_len:
+                    self.load_state(cache_item)
+                    if self.verbose:
+                        print("Llama._create_completion: cache hit", file=sys.stderr)
+            except KeyError:
+                if self.verbose:
+                    print("Llama._create_completion: cache miss", file=sys.stderr)
 
         finish_reason = "length"
         multibyte_fix = 0
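
For intuition, a worked example of the comparison above (token values are illustrative): with prompt_tokens = [1, 2, 3, 4, 5], cache_item.eval_tokens = [1, 2, 3, 4] and self.eval_tokens = [1, 2], cache_prefix_len is 4 and eval_prefix_len is 2, so the cached state is restored; if the live context already shared the longer prefix, the cache entry would be left alone and evaluation would continue from the current state.
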
@@ -1004,3 +1029,15 @@ class Llama:
         exps = [math.exp(float(x)) for x in logits]
         sum_exps = sum(exps)
         return [math.log(x / sum_exps) for x in exps]
+
+    @staticmethod
+    def longest_token_prefix(
+        a: Sequence[llama_cpp.llama_token], b: Sequence[llama_cpp.llama_token]
+    ):
+        longest_prefix = 0
+        for _a, _b in zip(a, b):
+            if _a == _b:
+                longest_prefix += 1
+            else:
+                break
+        return longest_prefix
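
End to end, a rough usage sketch (the model path and prompts are placeholders; it assumes the Llama instance's cache attribute can be assigned directly, as the diff above only shows it being checked, and that the existing save-to-cache path in _create_completion, not shown here, stores the state after generation):

    from llama_cpp import Llama
    from llama_cpp.llama import LlamaCache

    llm = Llama(model_path="./models/7B/ggml-model-q4_0.bin")  # placeholder path
    llm.cache = LlamaCache(capacity_bytes=2 << 30)  # keep up to ~2 GiB of saved states

    # First call evaluates the prompt and saves the resulting state keyed by its tokens.
    llm("Q: Name the planets in the solar system. A:", max_tokens=32)

    # A later prompt sharing a long token prefix can reload that state instead of
    # re-evaluating the shared prefix from scratch.
    llm("Q: Name the planets in the solar system. A: Mercury, Venus,", max_tokens=32)
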