From 0c421685084415229ae2689fb399ce55e73d9daf Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Thu, 8 Jun 2023 13:19:23 -0400 Subject: [PATCH] Fix cache implementation breaking changes --- llama_cpp/llama.py | 450 +++++++++++++++++++++------------------- llama_cpp/server/app.py | 9 + 2 files changed, 247 insertions(+), 212 deletions(-) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index b88fd65..05994b6 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -4,7 +4,7 @@ import uuid import time import math import multiprocessing -from abc import ABC +from abc import ABC, abstractmethod from typing import ( List, Optional, @@ -27,33 +27,37 @@ import numpy as np import numpy.typing as npt -class LlamaCache(ABC): +class BaseLlamaCache(ABC): """Base cache class for a llama.cpp model.""" def __init__(self, capacity_bytes: int = (2 << 30)): - pass + self.capacity_bytes = capacity_bytes @property - def cache_size(self): - return 0 + @abstractmethod + def cache_size(self) -> int: + raise NotImplementedError def _find_longest_prefix_key( - self, - key: Tuple[int, ...], + self, + key: Tuple[int, ...], ) -> Optional[Tuple[int, ...]]: pass + @abstractmethod def __getitem__(self, key: Sequence[int]) -> "LlamaState": - pass + raise NotImplementedError + @abstractmethod def __contains__(self, key: Sequence[int]) -> bool: - pass + raise NotImplementedError - def __setitem__(self, key: Sequence[int], value: "LlamaState"): - pass + @abstractmethod + def __setitem__(self, key: Sequence[int], value: "LlamaState") -> None: + raise NotImplementedError -class LlamaRAMCache(LlamaCache): +class LlamaRAMCache(BaseLlamaCache): """Cache for a llama.cpp model using RAM.""" def __init__(self, capacity_bytes: int = (2 << 30)): @@ -66,8 +70,8 @@ class LlamaRAMCache(LlamaCache): return sum([state.llama_state_size for state in self.cache_state.values()]) def _find_longest_prefix_key( - self, - key: Tuple[int, ...], + self, + key: Tuple[int, ...], ) -> Optional[Tuple[int, ...]]: min_len = 0 min_key = None @@ -97,32 +101,38 @@ class LlamaRAMCache(LlamaCache): if key in self.cache_state: del self.cache_state[key] self.cache_state[key] = value - while self.cache_size > self.capacity_bytes: + while self.cache_size > self.capacity_bytes and len(self.cache_state) > 0: self.cache_state.popitem(last=False) -class LlamaDiskCache(LlamaCache): +# Alias for backwards compatibility +LlamaCache = LlamaRAMCache + + +class LlamaDiskCache(BaseLlamaCache): """Cache for a llama.cpp model using disk.""" - def __init__(self, cache_dir="./llama_cache", capacity_bytes: int = (2 << 30)): + def __init__( + self, cache_dir: str = ".cache/llama_cache", capacity_bytes: int = (2 << 30) + ): super().__init__(capacity_bytes) self.cache = diskcache.Cache(cache_dir) @property def cache_size(self): - return self.cache.volume() + return int(self.cache.volume()) # type: ignore def _find_longest_prefix_key( - self, - key: Tuple[int, ...], + self, + key: Tuple[int, ...], ) -> Optional[Tuple[int, ...]]: min_len = 0 - min_key = None - for k in self.cache.iterkeys(): + min_key: Optional[Tuple[int, ...]] = None + for k in self.cache.iterkeys(): # type: ignore prefix_len = Llama.longest_token_prefix(k, key) if prefix_len > min_len: min_len = prefix_len - min_key = k + min_key = k # type: ignore return min_key def __getitem__(self, key: Sequence[int]) -> "LlamaState": @@ -130,29 +140,36 @@ class LlamaDiskCache(LlamaCache): _key = self._find_longest_prefix_key(key) if _key is None: raise KeyError("Key not found") - value = self.cache.pop(_key) - 
self.cache.push(_key) + value: "LlamaState" = self.cache.pop(_key) # type: ignore + self.cache.push(_key, side="front") # type: ignore return value + def __contains__(self, key: Sequence[int]) -> bool: + return self._find_longest_prefix_key(tuple(key)) is not None + def __setitem__(self, key: Sequence[int], value: "LlamaState"): + print("LlamaDiskCache.__setitem__: called", file=sys.stderr) key = tuple(key) if key in self.cache: + print("LlamaDiskCache.__setitem__: delete", file=sys.stderr) del self.cache[key] self.cache[key] = value - while self.cache_size > self.capacity_bytes: + print("LlamaDiskCache.__setitem__: set", file=sys.stderr) + while self.cache_size > self.capacity_bytes and len(self.cache) > 0: key_to_remove = next(iter(self.cache)) del self.cache[key_to_remove] + print("LlamaDiskCache.__setitem__: trim", file=sys.stderr) class LlamaState: def __init__( - self, - eval_tokens: Deque[int], - eval_logits: Deque[List[float]], - input_ids: npt.NDArray[np.intc], - scores: npt.NDArray[np.single], - llama_state, # type: llama_cpp.Array[llama_cpp.c_uint8] - llama_state_size: int, + self, + eval_tokens: Deque[int], + eval_logits: Deque[List[float]], + input_ids: npt.NDArray[np.intc], + scores: npt.NDArray[np.single], + llama_state, # type: llama_cpp.Array[llama_cpp.c_uint8] + llama_state_size: int, ): self.eval_tokens = eval_tokens self.eval_logits = eval_logits @@ -184,25 +201,25 @@ class Llama: """High-level Python wrapper for a llama.cpp model.""" def __init__( - self, - model_path: str, - # NOTE: These parameters are likely to change in the future. - n_ctx: int = 512, - n_parts: int = -1, - n_gpu_layers: int = 0, - seed: int = 1337, - f16_kv: bool = True, - logits_all: bool = False, - vocab_only: bool = False, - use_mmap: bool = True, - use_mlock: bool = False, - embedding: bool = False, - n_threads: Optional[int] = None, - n_batch: int = 512, - last_n_tokens_size: int = 64, - lora_base: Optional[str] = None, - lora_path: Optional[str] = None, - verbose: bool = True, + self, + model_path: str, + # NOTE: These parameters are likely to change in the future. + n_ctx: int = 512, + n_parts: int = -1, + n_gpu_layers: int = 0, + seed: int = 1337, + f16_kv: bool = True, + logits_all: bool = False, + vocab_only: bool = False, + use_mmap: bool = True, + use_mlock: bool = False, + embedding: bool = False, + n_threads: Optional[int] = None, + n_batch: int = 512, + last_n_tokens_size: int = 64, + lora_base: Optional[str] = None, + lora_path: Optional[str] = None, + verbose: bool = True, ): """Load a llama.cpp model from `model_path`. 
@@ -249,7 +266,7 @@ class Llama: self.eval_tokens: Deque[int] = deque(maxlen=n_ctx) self.eval_logits: Deque[List[float]] = deque(maxlen=n_ctx if logits_all else 1) - self.cache: Optional[LlamaCache] = None + self.cache: Optional[BaseLlamaCache] = None self.n_threads = n_threads or max(multiprocessing.cpu_count() // 2, 1) @@ -271,12 +288,12 @@ class Llama: if self.lora_path: if llama_cpp.llama_apply_lora_from_file( - self.ctx, - llama_cpp.c_char_p(self.lora_path.encode("utf-8")), - llama_cpp.c_char_p(self.lora_base.encode("utf-8")) - if self.lora_base is not None - else llama_cpp.c_char_p(0), - llama_cpp.c_int(self.n_threads), + self.ctx, + llama_cpp.c_char_p(self.lora_path.encode("utf-8")), + llama_cpp.c_char_p(self.lora_base.encode("utf-8")) + if self.lora_base is not None + else llama_cpp.c_char_p(0), + llama_cpp.c_int(self.n_threads), ): raise RuntimeError( f"Failed to apply LoRA from lora path: {self.lora_path} to base path: {self.lora_base}" @@ -363,7 +380,7 @@ class Llama: ) return output - def set_cache(self, cache: Optional[LlamaCache]): + def set_cache(self, cache: Optional[BaseLlamaCache]): """Set the cache. Args: @@ -387,7 +404,7 @@ class Llama: assert self.ctx is not None n_ctx = self._n_ctx for i in range(0, len(tokens), self.n_batch): - batch = tokens[i: min(len(tokens), i + self.n_batch)] + batch = tokens[i : min(len(tokens), i + self.n_batch)] n_past = min(n_ctx - len(batch), len(self._input_ids)) n_tokens = len(batch) return_code = llama_cpp.llama_eval( @@ -409,28 +426,28 @@ class Llama: n_vocab = self._n_vocab cols = n_vocab logits_view = llama_cpp.llama_get_logits(self.ctx) - logits = [logits_view[i * cols: (i + 1) * cols] for i in range(rows)] + logits = [logits_view[i * cols : (i + 1) * cols] for i in range(rows)] self.eval_logits.extend(logits) self._scores: npt.NDArray[np.single] = np.concatenate( (self._scores, np.array(logits, dtype=np.single)), axis=0 ) def _sample( - self, - last_n_tokens_data, # type: llama_cpp.Array[llama_cpp.llama_token] - last_n_tokens_size: llama_cpp.c_int, - top_k: llama_cpp.c_int, - top_p: llama_cpp.c_float, - temp: llama_cpp.c_float, - tfs_z: llama_cpp.c_float, - repeat_penalty: llama_cpp.c_float, - frequency_penalty: llama_cpp.c_float, - presence_penalty: llama_cpp.c_float, - mirostat_mode: llama_cpp.c_int, - mirostat_tau: llama_cpp.c_float, - mirostat_eta: llama_cpp.c_float, - penalize_nl: bool = True, - logits_processor: Optional[LogitsProcessorList] = None, + self, + last_n_tokens_data, # type: llama_cpp.Array[llama_cpp.llama_token] + last_n_tokens_size: llama_cpp.c_int, + top_k: llama_cpp.c_int, + top_p: llama_cpp.c_float, + temp: llama_cpp.c_float, + tfs_z: llama_cpp.c_float, + repeat_penalty: llama_cpp.c_float, + frequency_penalty: llama_cpp.c_float, + presence_penalty: llama_cpp.c_float, + mirostat_mode: llama_cpp.c_int, + mirostat_tau: llama_cpp.c_float, + mirostat_eta: llama_cpp.c_float, + penalize_nl: bool = True, + logits_processor: Optional[LogitsProcessorList] = None, ): assert self.ctx is not None assert len(self.eval_logits) > 0 @@ -550,19 +567,19 @@ class Llama: ) def sample( - self, - top_k: int = 40, - top_p: float = 0.95, - temp: float = 0.80, - repeat_penalty: float = 1.1, - frequency_penalty: float = 0.0, - presence_penalty: float = 0.0, - tfs_z: float = 1.0, - mirostat_mode: int = 0, - mirostat_eta: float = 0.1, - mirostat_tau: float = 5.0, - penalize_nl: bool = True, - logits_processor: Optional[LogitsProcessorList] = None, + self, + top_k: int = 40, + top_p: float = 0.95, + temp: float = 0.80, + repeat_penalty: 
float = 1.1, + frequency_penalty: float = 0.0, + presence_penalty: float = 0.0, + tfs_z: float = 1.0, + mirostat_mode: int = 0, + mirostat_eta: float = 0.1, + mirostat_tau: float = 5.0, + penalize_nl: bool = True, + logits_processor: Optional[LogitsProcessorList] = None, ): """Sample a token from the model. @@ -578,7 +595,7 @@ class Llama: assert self.ctx is not None last_n_tokens_data = [llama_cpp.llama_token(0)] * max( 0, self.last_n_tokens_size - len(self._input_ids) - ) + self._input_ids[-self.last_n_tokens_size:].tolist() + ) + self._input_ids[-self.last_n_tokens_size :].tolist() return self._sample( last_n_tokens_data=(llama_cpp.llama_token * self.last_n_tokens_size)( *last_n_tokens_data @@ -599,21 +616,21 @@ class Llama: ) def generate( - self, - tokens: Sequence[int], - top_k: int = 40, - top_p: float = 0.95, - temp: float = 0.80, - repeat_penalty: float = 1.1, - reset: bool = True, - frequency_penalty: float = 0.0, - presence_penalty: float = 0.0, - tfs_z: float = 1.0, - mirostat_mode: int = 0, - mirostat_tau: float = 5.0, - mirostat_eta: float = 0.1, - logits_processor: Optional[LogitsProcessorList] = None, - stopping_criteria: Optional[StoppingCriteriaList] = None, + self, + tokens: Sequence[int], + top_k: int = 40, + top_p: float = 0.95, + temp: float = 0.80, + repeat_penalty: float = 1.1, + reset: bool = True, + frequency_penalty: float = 0.0, + presence_penalty: float = 0.0, + tfs_z: float = 1.0, + mirostat_mode: int = 0, + mirostat_tau: float = 5.0, + mirostat_eta: float = 0.1, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, ) -> Generator[int, Optional[Sequence[int]], None]: """Create a generator of tokens from a prompt. @@ -676,7 +693,7 @@ class Llama: logits_processor=logits_processor, ) if stopping_criteria is not None and stopping_criteria( - self._input_ids.tolist(), self._scores[-1, :].tolist() + self._input_ids.tolist(), self._scores[-1, :].tolist() ): return tokens_or_none = yield token @@ -685,7 +702,7 @@ class Llama: tokens.extend(tokens_or_none) def create_embedding( - self, input: Union[str, List[str]], model: Optional[str] = None + self, input: Union[str, List[str]], model: Optional[str] = None ) -> Embedding: """Embed a string. 
@@ -720,8 +737,8 @@ class Llama: n_tokens = len(tokens) total_tokens += n_tokens embedding = llama_cpp.llama_get_embeddings(self.ctx)[ - : llama_cpp.llama_n_embd(self.ctx) - ] + : llama_cpp.llama_n_embd(self.ctx) + ] data.append( { @@ -755,27 +772,27 @@ class Llama: return list(map(float, self.create_embedding(input)["data"][0]["embedding"])) def _create_completion( - self, - prompt: str, - suffix: Optional[str] = None, - max_tokens: int = 16, - temperature: float = 0.8, - top_p: float = 0.95, - logprobs: Optional[int] = None, - echo: bool = False, - stop: Optional[Union[str, List[str]]] = [], - frequency_penalty: float = 0.0, - presence_penalty: float = 0.0, - repeat_penalty: float = 1.1, - top_k: int = 40, - stream: bool = False, - tfs_z: float = 1.0, - mirostat_mode: int = 0, - mirostat_tau: float = 5.0, - mirostat_eta: float = 0.1, - model: Optional[str] = None, - stopping_criteria: Optional[StoppingCriteriaList] = None, - logits_processor: Optional[LogitsProcessorList] = None, + self, + prompt: str, + suffix: Optional[str] = None, + max_tokens: int = 16, + temperature: float = 0.8, + top_p: float = 0.95, + logprobs: Optional[int] = None, + echo: bool = False, + stop: Optional[Union[str, List[str]]] = [], + frequency_penalty: float = 0.0, + presence_penalty: float = 0.0, + repeat_penalty: float = 1.1, + top_k: int = 40, + stream: bool = False, + tfs_z: float = 1.0, + mirostat_mode: int = 0, + mirostat_tau: float = 5.0, + mirostat_eta: float = 0.1, + model: Optional[str] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + logits_processor: Optional[LogitsProcessorList] = None, ) -> Union[Iterator[Completion], Iterator[CompletionChunk]]: assert self.ctx is not None @@ -827,19 +844,19 @@ class Llama: finish_reason = "length" multibyte_fix = 0 for token in self.generate( - prompt_tokens, - top_k=top_k, - top_p=top_p, - temp=temperature, - tfs_z=tfs_z, - mirostat_mode=mirostat_mode, - mirostat_tau=mirostat_tau, - mirostat_eta=mirostat_eta, - frequency_penalty=frequency_penalty, - presence_penalty=presence_penalty, - repeat_penalty=repeat_penalty, - stopping_criteria=stopping_criteria, - logits_processor=logits_processor, + prompt_tokens, + top_k=top_k, + top_p=top_p, + temp=temperature, + tfs_z=tfs_z, + mirostat_mode=mirostat_mode, + mirostat_tau=mirostat_tau, + mirostat_eta=mirostat_eta, + frequency_penalty=frequency_penalty, + presence_penalty=presence_penalty, + repeat_penalty=repeat_penalty, + stopping_criteria=stopping_criteria, + logits_processor=logits_processor, ): if token == self._token_eos: text = self.detokenize(completion_tokens) @@ -891,7 +908,7 @@ class Llama: token_end_position += len(self.detokenize([token])) # Check if stop sequence is in the token if token_end_position >= ( - remaining_length - first_stop_position - 1 + remaining_length - first_stop_position - 1 ): break logprobs_or_none: Optional[CompletionLogprobs] = None @@ -952,7 +969,7 @@ class Llama: break if stopping_criteria is not None and stopping_criteria( - self._input_ids.tolist(), self._scores[-1, :].tolist() + self._input_ids.tolist(), self._scores[-1, :].tolist() ): text = self.detokenize(completion_tokens) finish_reason = "stop" @@ -1017,8 +1034,8 @@ class Llama: "choices": [ { "text": last_text[ - : len(last_text) - (token_end_position - end) - ].decode("utf-8", errors="ignore"), + : len(last_text) - (token_end_position - end) + ].decode("utf-8", errors="ignore"), "index": 0, "logprobs": logprobs_or_none, "finish_reason": finish_reason, @@ -1049,6 +1066,7 @@ class Llama: if 
self.verbose: print("Llama._create_completion: cache save", file=sys.stderr) self.cache[prompt_tokens + completion_tokens] = self.save_state() + print("Llama._create_completion: cache saved", file=sys.stderr) return if self.cache: @@ -1084,10 +1102,10 @@ class Llama: for token in all_tokens ] all_logprobs = [ - Llama.logits_to_logprobs(row.tolist()) for row in self._scores - ][token_offset:] + Llama.logits_to_logprobs(row.tolist()) for row in self._scores + ][token_offset:] for token, token_str, logprobs_token in zip( - all_tokens, all_token_strs, all_logprobs + all_tokens, all_token_strs, all_logprobs ): text_offsets.append(text_offset) text_offset += len(token_str) @@ -1138,27 +1156,27 @@ class Llama: } def create_completion( - self, - prompt: str, - suffix: Optional[str] = None, - max_tokens: int = 128, - temperature: float = 0.8, - top_p: float = 0.95, - logprobs: Optional[int] = None, - echo: bool = False, - stop: Optional[Union[str, List[str]]] = [], - frequency_penalty: float = 0.0, - presence_penalty: float = 0.0, - repeat_penalty: float = 1.1, - top_k: int = 40, - stream: bool = False, - tfs_z: float = 1.0, - mirostat_mode: int = 0, - mirostat_tau: float = 5.0, - mirostat_eta: float = 0.1, - model: Optional[str] = None, - stopping_criteria: Optional[StoppingCriteriaList] = None, - logits_processor: Optional[LogitsProcessorList] = None, + self, + prompt: str, + suffix: Optional[str] = None, + max_tokens: int = 128, + temperature: float = 0.8, + top_p: float = 0.95, + logprobs: Optional[int] = None, + echo: bool = False, + stop: Optional[Union[str, List[str]]] = [], + frequency_penalty: float = 0.0, + presence_penalty: float = 0.0, + repeat_penalty: float = 1.1, + top_k: int = 40, + stream: bool = False, + tfs_z: float = 1.0, + mirostat_mode: int = 0, + mirostat_tau: float = 5.0, + mirostat_eta: float = 0.1, + model: Optional[str] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + logits_processor: Optional[LogitsProcessorList] = None, ) -> Union[Completion, Iterator[CompletionChunk]]: """Generate text from a prompt. @@ -1211,27 +1229,27 @@ class Llama: return completion def __call__( - self, - prompt: str, - suffix: Optional[str] = None, - max_tokens: int = 128, - temperature: float = 0.8, - top_p: float = 0.95, - logprobs: Optional[int] = None, - echo: bool = False, - stop: Optional[Union[str, List[str]]] = [], - frequency_penalty: float = 0.0, - presence_penalty: float = 0.0, - repeat_penalty: float = 1.1, - top_k: int = 40, - stream: bool = False, - tfs_z: float = 1.0, - mirostat_mode: int = 0, - mirostat_tau: float = 5.0, - mirostat_eta: float = 0.1, - model: Optional[str] = None, - stopping_criteria: Optional[StoppingCriteriaList] = None, - logits_processor: Optional[LogitsProcessorList] = None, + self, + prompt: str, + suffix: Optional[str] = None, + max_tokens: int = 128, + temperature: float = 0.8, + top_p: float = 0.95, + logprobs: Optional[int] = None, + echo: bool = False, + stop: Optional[Union[str, List[str]]] = [], + frequency_penalty: float = 0.0, + presence_penalty: float = 0.0, + repeat_penalty: float = 1.1, + top_k: int = 40, + stream: bool = False, + tfs_z: float = 1.0, + mirostat_mode: int = 0, + mirostat_tau: float = 5.0, + mirostat_eta: float = 0.1, + model: Optional[str] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + logits_processor: Optional[LogitsProcessorList] = None, ) -> Union[Completion, Iterator[CompletionChunk]]: """Generate text from a prompt. 
@@ -1279,7 +1297,7 @@ class Llama: ) def _convert_text_completion_to_chat( - self, completion: Completion + self, completion: Completion ) -> ChatCompletion: return { "id": "chat" + completion["id"], @@ -1300,8 +1318,8 @@ class Llama: } def _convert_text_completion_chunks_to_chat( - self, - chunks: Iterator[CompletionChunk], + self, + chunks: Iterator[CompletionChunk], ) -> Iterator[ChatCompletionChunk]: for i, chunk in enumerate(chunks): if i == 0: @@ -1337,22 +1355,22 @@ class Llama: } def create_chat_completion( - self, - messages: List[ChatCompletionMessage], - temperature: float = 0.2, - top_p: float = 0.95, - top_k: int = 40, - stream: bool = False, - stop: Optional[Union[str, List[str]]] = [], - max_tokens: int = 256, - presence_penalty: float = 0.0, - frequency_penalty: float = 0.0, - repeat_penalty: float = 1.1, - tfs_z: float = 1.0, - mirostat_mode: int = 0, - mirostat_tau: float = 5.0, - mirostat_eta: float = 0.1, - model: Optional[str] = None, + self, + messages: List[ChatCompletionMessage], + temperature: float = 0.2, + top_p: float = 0.95, + top_k: int = 40, + stream: bool = False, + stop: Optional[Union[str, List[str]]] = [], + max_tokens: int = 256, + presence_penalty: float = 0.0, + frequency_penalty: float = 0.0, + repeat_penalty: float = 1.1, + tfs_z: float = 1.0, + mirostat_mode: int = 0, + mirostat_tau: float = 5.0, + mirostat_eta: float = 0.1, + model: Optional[str] = None, ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]: """Generate a chat completion from a list of messages. @@ -1453,9 +1471,17 @@ class Llama: def save_state(self) -> LlamaState: assert self.ctx is not None + if self.verbose: + print("Llama.save_state: saving llama state", file=sys.stderr) state_size = llama_cpp.llama_get_state_size(self.ctx) + if self.verbose: + print(f"Llama.save_state: got state size: {state_size}", file=sys.stderr) llama_state = (llama_cpp.c_uint8 * int(state_size))() + if self.verbose: + print("Llama.save_state: allocated state", file=sys.stderr) n_bytes = llama_cpp.llama_copy_state_data(self.ctx, llama_state) + if self.verbose: + print(f"Llama.save_state: copied llama state: {n_bytes}", file=sys.stderr) if int(n_bytes) > int(state_size): raise RuntimeError("Failed to copy llama state data") llama_state_compact = (llama_cpp.c_uint8 * int(n_bytes))() diff --git a/llama_cpp/server/app.py b/llama_cpp/server/app.py index 23382e1..f70d8f0 100644 --- a/llama_cpp/server/app.py +++ b/llama_cpp/server/app.py @@ -58,6 +58,10 @@ class Settings(BaseSettings): default=False, description="Use a cache to reduce processing times for evaluated prompts.", ) + cache_type: Literal["ram", "disk"] = Field( + default="ram", + description="The type of cache to use. Only used if cache is True.", + ) cache_size: int = Field( default=2 << 30, description="The size of the cache in bytes. Only used if cache is True.", @@ -108,6 +112,11 @@ def create_app(settings: Optional[Settings] = None): verbose=settings.verbose, ) if settings.cache: + if settings.cache_type == "disk": + cache = llama_cpp.LlamaDiskCache(capacity_bytes=settings.cache_size) + else: + cache = llama_cpp.LlamaRAMCache(capacity_bytes=settings.cache_size) + cache = llama_cpp.LlamaCache(capacity_bytes=settings.cache_size) llama.set_cache(cache)
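
Usage sketches for the reworked cache API follow; none of the code below is part of the patch. The first sketch exercises the three entry points the patch leaves in place. It assumes the cache classes are reachable at the package top level (as the server code's `llama_cpp.LlamaRAMCache` references suggest), and the model path is a placeholder:

    from llama_cpp import Llama, LlamaCache, LlamaDiskCache, LlamaRAMCache

    llama = Llama(model_path="./models/7B/ggml-model.bin")  # placeholder path

    # In-memory cache capped at 2 GiB (the default capacity_bytes).
    llama.set_cache(LlamaRAMCache(capacity_bytes=2 << 30))

    # Disk-backed cache; cache_dir now defaults to ".cache/llama_cache".
    llama.set_cache(LlamaDiskCache(cache_dir=".cache/llama_cache"))

    # Existing callers that constructed LlamaCache keep working: the name is
    # now an alias for LlamaRAMCache rather than the old concrete base class.
    llama.set_cache(LlamaCache(capacity_bytes=2 << 30))

Because `set_cache` is now typed against `Optional[BaseLlamaCache]`, the three calls are interchangeable at runtime; the last cache set is the one used.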
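
Turning the old `LlamaCache` into the abstract `BaseLlamaCache` also spells out the contract a third-party cache must satisfy: the `cache_size` property plus `__getitem__`, `__contains__`, and `__setitem__` are now abstract. The subclass below is an illustrative sketch, not part of the patch; it stores states in a plain dict, logs evictions, and reuses the `Llama.longest_token_prefix` helper that the built-in caches rely on:

    import sys
    from typing import Dict, Optional, Sequence, Tuple

    from llama_cpp.llama import BaseLlamaCache, Llama, LlamaState


    class LoggingRAMCache(BaseLlamaCache):
        """Illustrative cache: dict-backed storage with logged evictions."""

        def __init__(self, capacity_bytes: int = (2 << 30)):
            super().__init__(capacity_bytes)
            self.states: Dict[Tuple[int, ...], LlamaState] = {}

        @property
        def cache_size(self) -> int:
            # Same accounting as LlamaRAMCache: sum of saved llama state sizes.
            return sum(s.llama_state_size for s in self.states.values())

        def _find_longest_prefix_key(
            self, key: Tuple[int, ...]
        ) -> Optional[Tuple[int, ...]]:
            best_len, best_key = 0, None
            for k in self.states:
                prefix_len = Llama.longest_token_prefix(k, key)
                if prefix_len > best_len:
                    best_len, best_key = prefix_len, k
            return best_key

        def __getitem__(self, key: Sequence[int]) -> LlamaState:
            _key = self._find_longest_prefix_key(tuple(key))
            if _key is None:
                raise KeyError("Key not found")
            return self.states[_key]

        def __contains__(self, key: Sequence[int]) -> bool:
            return self._find_longest_prefix_key(tuple(key)) is not None

        def __setitem__(self, key: Sequence[int], value: LlamaState) -> None:
            self.states[tuple(key)] = value
            # Evict oldest entries (insertion order) once over capacity.
            while self.cache_size > self.capacity_bytes and self.states:
                evicted = next(iter(self.states))
                print(f"LoggingRAMCache: evicting {len(evicted)}-token key",
                      file=sys.stderr)
                del self.states[evicted]

An instance of this class can be handed straight to `llama.set_cache(...)`, since `set_cache` only requires a `BaseLlamaCache`.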
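
On the server side, the new `cache_type` setting selects between the two implementations. A minimal sketch of wiring it up programmatically; the `model` field and the `uvicorn` invocation are assumptions about the rest of `app.py` and the package's `__main__`, not something shown in this hunk:

    import uvicorn
    from llama_cpp.server.app import Settings, create_app

    settings = Settings(
        model="./models/7B/ggml-model.bin",  # assumed required field, placeholder path
        cache=True,
        cache_type="disk",   # new setting: "ram" (default) or "disk"
        cache_size=2 << 30,  # capacity in bytes, as before
    )
    app = create_app(settings=settings)

    if __name__ == "__main__":
        uvicorn.run(app)

Note that in the final hunk the pre-existing `cache = llama_cpp.LlamaCache(capacity_bytes=settings.cache_size)` line still follows the new if/else and reassigns `cache`, so as written the selected `cache_type` is overridden by the `LlamaRAMCache` alias before `llama.set_cache(cache)` runs; dropping that leftover line would let the disk option take effect.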