diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index d7dc625..447acb7 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -17,20 +17,23 @@ from typing import (
 )
 from collections import deque, OrderedDict
 
+import diskcache
+
 from . import llama_cpp
 from .llama_types import *
 
 
+
 class LlamaCache:
     """Cache for a llama.cpp model."""
 
-    def __init__(self, capacity_bytes: int = (2 << 30)):
-        self.cache_state: OrderedDict[Tuple[int, ...], "LlamaState"] = OrderedDict()
+    def __init__(self, cache_dir="./llama_cache", capacity_bytes: int = (2 << 30)):
+        self.cache = diskcache.Cache(cache_dir)
         self.capacity_bytes = capacity_bytes
 
     @property
     def cache_size(self):
-        return sum([state.llama_state_size for state in self.cache_state.values()])
+        return self.cache.volume()
 
     def _find_longest_prefix_key(
         self,
@@ -38,10 +41,8 @@ class LlamaCache:
     ) -> Optional[Tuple[int, ...]]:
         min_len = 0
         min_key = None
-        keys = (
-            (k, Llama.longest_token_prefix(k, key)) for k in self.cache_state.keys()
-        )
-        for k, prefix_len in keys:
+        for k in self.cache.iterkeys():
+            prefix_len = Llama.longest_token_prefix(k, key)
             if prefix_len > min_len:
                 min_len = prefix_len
                 min_key = k
@@ -51,9 +52,9 @@ class LlamaCache:
         key = tuple(key)
         _key = self._find_longest_prefix_key(key)
         if _key is None:
-            raise KeyError(f"Key not found")
-        value = self.cache_state[_key]
-        self.cache_state.move_to_end(_key)
+            raise KeyError("Key not found")
+        value = self.cache.pop(_key)
+        self.cache[_key] = value  # re-insert the popped state; diskcache has no move_to_end
         return value
 
     def __contains__(self, key: Sequence[int]) -> bool:
@@ -61,11 +62,13 @@ class LlamaCache:
 
     def __setitem__(self, key: Sequence[int], value: "LlamaState"):
         key = tuple(key)
-        if key in self.cache_state:
-            del self.cache_state[key]
-        self.cache_state[key] = value
+        if key in self.cache:
+            del self.cache[key]
+        self.cache[key] = value
         while self.cache_size > self.capacity_bytes:
-            self.cache_state.popitem(last=False)
+            key_to_remove = next(iter(self.cache))  # oldest entry first (insertion order)
+            del self.cache[key_to_remove]
+
 
 
 class LlamaState:
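
Below is a minimal usage sketch of the disk-backed cache, assuming llama-cpp-python's existing Llama.set_cache API; the model path and prompt are illustrative and not part of this diff.

# Usage sketch (assumptions: the model path below is hypothetical; Llama.set_cache
# accepts a LlamaCache instance as in the existing in-memory implementation).
from llama_cpp import Llama, LlamaCache

llm = Llama(model_path="./models/7B/ggml-model.bin")  # hypothetical model path
llm.set_cache(LlamaCache(cache_dir="./llama_cache", capacity_bytes=2 << 30))

# The first completion pays the full prompt-evaluation cost; a later call whose
# prompt shares a token prefix can restore the saved llama.cpp state from the
# diskcache directory instead of re-evaluating those tokens, and the cached
# state now survives process restarts.
out = llm("Q: Name the planets in the solar system. A:", max_tokens=32)
print(out["choices"][0]["text"])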