diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index 6836ea5..de06da0 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -5,7 +5,7 @@ import time
 import math
 import multiprocessing
 from typing import List, Optional, Union, Generator, Sequence, Iterator, Deque, Tuple
-from collections import deque
+from collections import deque, OrderedDict

 from . import llama_cpp
 from .llama_types import *
@@ -14,37 +14,50 @@ from .llama_types import *
 class LlamaCache:
     """Cache for a llama.cpp model."""

-    def __init__(self):
-        self.cache_state: Dict[Tuple[llama_cpp.llama_token, ...], "LlamaState"] = dict()
+    def __init__(self, capacity_bytes: int = (2 << 30)):
+        self.cache_state: OrderedDict[
+            Tuple[llama_cpp.llama_token, ...], "LlamaState"
+        ] = OrderedDict()
+        self.capacity_bytes = capacity_bytes

-    def _sorted_keys(self) -> List[Tuple[llama_cpp.llama_token, ...]]:
-        return [
-            key
-            for _, key in sorted(
-                ((len(key), key) for key in self.cache_state.keys()), reverse=True
-            )
-        ]
+    @property
+    def cache_size(self):
+        return sum([state.llama_state_size for state in self.cache_state.values()])

-    def _find_key(
-        self, key: Tuple[llama_cpp.llama_token, ...]
+    def _find_longest_prefix_key(
+        self,
+        key: Tuple[llama_cpp.llama_token, ...],
     ) -> Optional[Tuple[llama_cpp.llama_token, ...]]:
-        for k in self._sorted_keys():
-            if key[: len(k)] == k:
-                return k
-        return None
+        min_len = 0
+        min_key = None
+        keys = (
+            (k, Llama.longest_token_prefix(k, key)) for k in self.cache_state.keys()
+        )
+        for k, prefix_len in keys:
+            if prefix_len > min_len:
+                min_len = prefix_len
+                min_key = k
+        return min_key

     def __getitem__(self, key: Sequence[llama_cpp.llama_token]) -> "LlamaState":
-        _key = self._find_key(tuple(key))
+        key = tuple(key)
+        _key = self._find_longest_prefix_key(key)
         if _key is None:
-            raise KeyError(f"Key not found: {key}")
-        return self.cache_state[_key]
+            raise KeyError(f"Key not found")
+        value = self.cache_state[_key]
+        self.cache_state.move_to_end(_key)
+        return value

     def __contains__(self, key: Sequence[llama_cpp.llama_token]) -> bool:
-        return self._find_key(tuple(key)) is not None
+        return self._find_longest_prefix_key(tuple(key)) is not None

     def __setitem__(self, key: Sequence[llama_cpp.llama_token], value: "LlamaState"):
-        self.cache_state = dict()  # NOTE: Currently limit to one cache entry.
-        self.cache_state[tuple(key)] = value
+        key = tuple(key)
+        if key in self.cache_state:
+            del self.cache_state[key]
+        self.cache_state[key] = value
+        while self.cache_size > self.capacity_bytes:
+            self.cache_state.popitem(last=False)


 class LlamaState:
@@ -53,7 +66,7 @@ class LlamaState:
         eval_tokens: Deque[llama_cpp.llama_token],
         eval_logits: Deque[List[float]],
         llama_state,  # type: llama_cpp.Array[llama_cpp.c_uint8]
-        llama_state_size: llama_cpp.c_size_t,
+        llama_state_size: int,
     ):
         self.eval_tokens = eval_tokens
         self.eval_logits = eval_logits
@@ -526,10 +539,22 @@ class Llama:
                 "logprobs is not supported for models created with logits_all=False"
             )

-        if self.cache and prompt_tokens in self.cache:
-            if self.verbose:
-                print("Llama._create_completion: cache hit", file=sys.stderr)
-            self.load_state(self.cache[prompt_tokens])
+        if self.cache:
+            try:
+                cache_item = self.cache[prompt_tokens]
+                cache_prefix_len = Llama.longest_token_prefix(
+                    cache_item.eval_tokens, prompt_tokens
+                )
+                eval_prefix_len = Llama.longest_token_prefix(
+                    self.eval_tokens, prompt_tokens
+                )
+                if cache_prefix_len > eval_prefix_len:
+                    self.load_state(cache_item)
+                    if self.verbose:
+                        print("Llama._create_completion: cache hit", file=sys.stderr)
+            except KeyError:
+                if self.verbose:
+                    print("Llama._create_completion: cache miss", file=sys.stderr)

         finish_reason = "length"
         multibyte_fix = 0
@@ -1004,3 +1029,15 @@ class Llama:
         exps = [math.exp(float(x)) for x in logits]
         sum_exps = sum(exps)
         return [math.log(x / sum_exps) for x in exps]
+
+    @staticmethod
+    def longest_token_prefix(
+        a: Sequence[llama_cpp.llama_token], b: Sequence[llama_cpp.llama_token]
+    ):
+        longest_prefix = 0
+        for _a, _b in zip(a, b):
+            if _a == _b:
+                longest_prefix += 1
+            else:
+                break
+        return longest_prefix
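
Note on the change: the patch replaces the single-entry dict cache with a byte-budgeted LRU keyed by token sequences, where lookups return the entry sharing the longest token prefix with the query. Below is a minimal standalone sketch of that eviction and lookup behavior, assuming plain ints for token ids; PrefixLRUCache and FakeState are hypothetical stand-ins for illustration, not names from the patch.

from collections import OrderedDict
from typing import Optional, Sequence, Tuple


def longest_token_prefix(a: Sequence[int], b: Sequence[int]) -> int:
    # Length of the shared prefix of two token sequences.
    n = 0
    for x, y in zip(a, b):
        if x != y:
            break
        n += 1
    return n


class FakeState:
    # Stand-in for LlamaState: only the size attribute matters here.
    def __init__(self, size: int):
        self.llama_state_size = size  # bytes this entry occupies


class PrefixLRUCache:
    def __init__(self, capacity_bytes: int):
        self.capacity_bytes = capacity_bytes
        self.cache: "OrderedDict[Tuple[int, ...], FakeState]" = OrderedDict()

    def put(self, key: Sequence[int], value: FakeState) -> None:
        key = tuple(key)
        self.cache.pop(key, None)  # re-inserting refreshes recency
        self.cache[key] = value
        # Evict least-recently-used entries until under the byte budget.
        while sum(s.llama_state_size for s in self.cache.values()) > self.capacity_bytes:
            self.cache.popitem(last=False)

    def get(self, key: Sequence[int]) -> Optional[FakeState]:
        # Return the entry whose key shares the longest prefix with `key`.
        key = tuple(key)
        best, best_len = None, 0
        for k in self.cache:
            n = longest_token_prefix(k, key)
            if n > best_len:
                best, best_len = k, n
        if best is None:
            return None
        self.cache.move_to_end(best)  # mark as most recently used
        return self.cache[best]


cache = PrefixLRUCache(capacity_bytes=200)
cache.put([1, 2, 3], FakeState(100))
cache.put([1, 2, 4], FakeState(100))
cache.put([9, 9], FakeState(100))           # over budget: evicts [1, 2, 3] (LRU)
assert (1, 2, 3) not in cache.cache
assert cache.get([1, 2, 4, 5]) is not None  # longest-prefix hit on [1, 2, 4]

This mirrors why _create_completion compares cache_prefix_len against eval_prefix_len before calling load_state: restoring a cached state only pays off when the cached prefix covers more of the prompt than the tokens already evaluated in place.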