diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index 2a96ff8..b88fd65 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -4,6 +4,7 @@ import uuid
 import time
 import math
 import multiprocessing
+from abc import ABC
 from typing import (
     List,
     Optional,
@@ -17,6 +18,8 @@ from typing import (
 )
 from collections import deque, OrderedDict
 
+import diskcache
+
 from . import llama_cpp
 from .llama_types import *
 
@@ -24,20 +27,47 @@ import numpy as np
 import numpy.typing as npt
 
 
-class LlamaCache:
-    """Cache for a llama.cpp model."""
+class LlamaCache(ABC):
+    """Base cache class for a llama.cpp model."""
 
     def __init__(self, capacity_bytes: int = (2 << 30)):
-        self.cache_state: OrderedDict[Tuple[int, ...], "LlamaState"] = OrderedDict()
+        self.capacity_bytes = capacity_bytes  # shared by subclasses for eviction
+
+    @property
+    def cache_size(self):
+        return 0
+
+    def _find_longest_prefix_key(
+        self,
+        key: Tuple[int, ...],
+    ) -> Optional[Tuple[int, ...]]:
+        pass
+
+    def __getitem__(self, key: Sequence[int]) -> "LlamaState":
+        pass
+
+    def __contains__(self, key: Sequence[int]) -> bool:
+        pass
+
+    def __setitem__(self, key: Sequence[int], value: "LlamaState"):
+        pass
+
+
+class LlamaRAMCache(LlamaCache):
+    """Cache for a llama.cpp model using RAM."""
+
+    def __init__(self, capacity_bytes: int = (2 << 30)):
+        super().__init__(capacity_bytes)
         self.capacity_bytes = capacity_bytes
+        self.cache_state: OrderedDict[Tuple[int, ...], "LlamaState"] = OrderedDict()
 
     @property
     def cache_size(self):
         return sum([state.llama_state_size for state in self.cache_state.values()])
 
     def _find_longest_prefix_key(
-        self,
-        key: Tuple[int, ...],
+            self,
+            key: Tuple[int, ...],
     ) -> Optional[Tuple[int, ...]]:
         min_len = 0
         min_key = None
@@ -54,7 +84,7 @@ class LlamaCache:
         key = tuple(key)
         _key = self._find_longest_prefix_key(key)
         if _key is None:
-            raise KeyError(f"Key not found")
+            raise KeyError("Key not found")
         value = self.cache_state[_key]
         self.cache_state.move_to_end(_key)
         return value
@@ -71,15 +101,58 @@ class LlamaCache:
             self.cache_state.popitem(last=False)
 
 
+class LlamaDiskCache(LlamaCache):
+    """Cache for a llama.cpp model using disk."""
+
+    def __init__(self, cache_dir="./llama_cache", capacity_bytes: int = (2 << 30)):
+        super().__init__(capacity_bytes)
+        self.cache = diskcache.Cache(cache_dir)
+
+    @property
+    def cache_size(self):
+        return self.cache.volume()
+
+    def _find_longest_prefix_key(
+        self,
+        key: Tuple[int, ...],
+    ) -> Optional[Tuple[int, ...]]:
+        min_len = 0
+        min_key = None
+        for k in self.cache.iterkeys():
+            prefix_len = Llama.longest_token_prefix(k, key)
+            if prefix_len > min_len:
+                min_len = prefix_len
+                min_key = k
+        return min_key
+
+    def __getitem__(self, key: Sequence[int]) -> "LlamaState":
+        key = tuple(key)
+        _key = self._find_longest_prefix_key(key)
+        if _key is None:
+            raise KeyError("Key not found")
+        value = self.cache.pop(_key)
+        self.cache[_key] = value  # re-insert so the state stays cached after a hit
+        return value
+
+    def __setitem__(self, key: Sequence[int], value: "LlamaState"):
+        key = tuple(key)
+        if key in self.cache:
+            del self.cache[key]
+        self.cache[key] = value
+        while self.cache_size > self.capacity_bytes:
+            key_to_remove = next(iter(self.cache))
+            del self.cache[key_to_remove]
+
+
 class LlamaState:
     def __init__(
-        self,
-        eval_tokens: Deque[int],
-        eval_logits: Deque[List[float]],
-        input_ids: npt.NDArray[np.intc],
-        scores: npt.NDArray[np.single],
-        llama_state,  # type: llama_cpp.Array[llama_cpp.c_uint8]
-        llama_state_size: int,
+            self,
+            eval_tokens: Deque[int],
+            eval_logits: Deque[List[float]],
+            input_ids: npt.NDArray[np.intc],
+            scores: npt.NDArray[np.single],
+            llama_state,  # type: llama_cpp.Array[llama_cpp.c_uint8]
+            llama_state_size: int,
     ):
         self.eval_tokens = eval_tokens
         self.eval_logits = eval_logits
@@ -111,25 +184,25 @@ class Llama:
     """High-level Python wrapper for a llama.cpp model."""
 
     def __init__(
-        self,
-        model_path: str,
-        # NOTE: These parameters are likely to change in the future.
-        n_ctx: int = 512,
-        n_parts: int = -1,
-        n_gpu_layers: int = 0,
-        seed: int = 1337,
-        f16_kv: bool = True,
-        logits_all: bool = False,
-        vocab_only: bool = False,
-        use_mmap: bool = True,
-        use_mlock: bool = False,
-        embedding: bool = False,
-        n_threads: Optional[int] = None,
-        n_batch: int = 512,
-        last_n_tokens_size: int = 64,
-        lora_base: Optional[str] = None,
-        lora_path: Optional[str] = None,
-        verbose: bool = True,
+            self,
+            model_path: str,
+            # NOTE: These parameters are likely to change in the future.
+            n_ctx: int = 512,
+            n_parts: int = -1,
+            n_gpu_layers: int = 0,
+            seed: int = 1337,
+            f16_kv: bool = True,
+            logits_all: bool = False,
+            vocab_only: bool = False,
+            use_mmap: bool = True,
+            use_mlock: bool = False,
+            embedding: bool = False,
+            n_threads: Optional[int] = None,
+            n_batch: int = 512,
+            last_n_tokens_size: int = 64,
+            lora_base: Optional[str] = None,
+            lora_path: Optional[str] = None,
+            verbose: bool = True,
     ):
         """Load a llama.cpp model from `model_path`.
 
@@ -198,12 +271,12 @@ class Llama:
         if self.lora_path:
             if llama_cpp.llama_apply_lora_from_file(
-                self.ctx,
-                llama_cpp.c_char_p(self.lora_path.encode("utf-8")),
-                llama_cpp.c_char_p(self.lora_base.encode("utf-8"))
-                if self.lora_base is not None
-                else llama_cpp.c_char_p(0),
-                llama_cpp.c_int(self.n_threads),
+                    self.ctx,
+                    llama_cpp.c_char_p(self.lora_path.encode("utf-8")),
+                    llama_cpp.c_char_p(self.lora_base.encode("utf-8"))
+                    if self.lora_base is not None
+                    else llama_cpp.c_char_p(0),
+                    llama_cpp.c_int(self.n_threads),
             ):
                 raise RuntimeError(
                     f"Failed to apply LoRA from lora path: {self.lora_path} to base path: {self.lora_base}"
                 )
@@ -314,7 +387,7 @@ class Llama:
         assert self.ctx is not None
         n_ctx = self._n_ctx
         for i in range(0, len(tokens), self.n_batch):
-            batch = tokens[i : min(len(tokens), i + self.n_batch)]
+            batch = tokens[i: min(len(tokens), i + self.n_batch)]
             n_past = min(n_ctx - len(batch), len(self._input_ids))
             n_tokens = len(batch)
             return_code = llama_cpp.llama_eval(
@@ -336,28 +409,28 @@ class Llama:
             n_vocab = self._n_vocab
             cols = n_vocab
             logits_view = llama_cpp.llama_get_logits(self.ctx)
-            logits = [logits_view[i * cols : (i + 1) * cols] for i in range(rows)]
+            logits = [logits_view[i * cols: (i + 1) * cols] for i in range(rows)]
             self.eval_logits.extend(logits)
             self._scores: npt.NDArray[np.single] = np.concatenate(
                 (self._scores, np.array(logits, dtype=np.single)), axis=0
             )
 
     def _sample(
-        self,
-        last_n_tokens_data,  # type: llama_cpp.Array[llama_cpp.llama_token]
-        last_n_tokens_size: llama_cpp.c_int,
-        top_k: llama_cpp.c_int,
-        top_p: llama_cpp.c_float,
-        temp: llama_cpp.c_float,
-        tfs_z: llama_cpp.c_float,
-        repeat_penalty: llama_cpp.c_float,
-        frequency_penalty: llama_cpp.c_float,
-        presence_penalty: llama_cpp.c_float,
-        mirostat_mode: llama_cpp.c_int,
-        mirostat_tau: llama_cpp.c_float,
-        mirostat_eta: llama_cpp.c_float,
-        penalize_nl: bool = True,
-        logits_processor: Optional[LogitsProcessorList] = None,
+            self,
+            last_n_tokens_data,  # type: llama_cpp.Array[llama_cpp.llama_token]
+            last_n_tokens_size: llama_cpp.c_int,
+            top_k: llama_cpp.c_int,
+            top_p: llama_cpp.c_float,
+            temp: llama_cpp.c_float,
+            tfs_z: llama_cpp.c_float,
+            repeat_penalty: llama_cpp.c_float,
+            frequency_penalty: llama_cpp.c_float,
+            presence_penalty: llama_cpp.c_float,
+            mirostat_mode: llama_cpp.c_int,
+            mirostat_tau: llama_cpp.c_float,
+            mirostat_eta: llama_cpp.c_float,
+            penalize_nl: bool = True,
+            logits_processor: Optional[LogitsProcessorList] = None,
     ):
         assert self.ctx is not None
         assert len(self.eval_logits) > 0
@@ -477,19 +550,19 @@ class Llama:
             )
 
     def sample(
-        self,
-        top_k: int = 40,
-        top_p: float = 0.95,
-        temp: float = 0.80,
-        repeat_penalty: float = 1.1,
-        frequency_penalty: float = 0.0,
-        presence_penalty: float = 0.0,
-        tfs_z: float = 1.0,
-        mirostat_mode: int = 0,
-        mirostat_eta: float = 0.1,
-        mirostat_tau: float = 5.0,
-        penalize_nl: bool = True,
-        logits_processor: Optional[LogitsProcessorList] = None,
+            self,
+            top_k: int = 40,
+            top_p: float = 0.95,
+            temp: float = 0.80,
+            repeat_penalty: float = 1.1,
+            frequency_penalty: float = 0.0,
+            presence_penalty: float = 0.0,
+            tfs_z: float = 1.0,
+            mirostat_mode: int = 0,
+            mirostat_eta: float = 0.1,
+            mirostat_tau: float = 5.0,
+            penalize_nl: bool = True,
+            logits_processor: Optional[LogitsProcessorList] = None,
     ):
         """Sample a token from the model.
 
@@ -505,7 +578,7 @@ class Llama:
         assert self.ctx is not None
         last_n_tokens_data = [llama_cpp.llama_token(0)] * max(
             0, self.last_n_tokens_size - len(self._input_ids)
-        ) + self._input_ids[-self.last_n_tokens_size :].tolist()
+        ) + self._input_ids[-self.last_n_tokens_size:].tolist()
         return self._sample(
             last_n_tokens_data=(llama_cpp.llama_token * self.last_n_tokens_size)(
                 *last_n_tokens_data
@@ -526,21 +599,21 @@ class Llama:
         )
 
     def generate(
-        self,
-        tokens: Sequence[int],
-        top_k: int = 40,
-        top_p: float = 0.95,
-        temp: float = 0.80,
-        repeat_penalty: float = 1.1,
-        reset: bool = True,
-        frequency_penalty: float = 0.0,
-        presence_penalty: float = 0.0,
-        tfs_z: float = 1.0,
-        mirostat_mode: int = 0,
-        mirostat_tau: float = 5.0,
-        mirostat_eta: float = 0.1,
-        logits_processor: Optional[LogitsProcessorList] = None,
-        stopping_criteria: Optional[StoppingCriteriaList] = None,
+            self,
+            tokens: Sequence[int],
+            top_k: int = 40,
+            top_p: float = 0.95,
+            temp: float = 0.80,
+            repeat_penalty: float = 1.1,
+            reset: bool = True,
+            frequency_penalty: float = 0.0,
+            presence_penalty: float = 0.0,
+            tfs_z: float = 1.0,
+            mirostat_mode: int = 0,
+            mirostat_tau: float = 5.0,
+            mirostat_eta: float = 0.1,
+            logits_processor: Optional[LogitsProcessorList] = None,
+            stopping_criteria: Optional[StoppingCriteriaList] = None,
     ) -> Generator[int, Optional[Sequence[int]], None]:
         """Create a generator of tokens from a prompt.
 
@@ -603,7 +676,7 @@ class Llama:
                 logits_processor=logits_processor,
             )
             if stopping_criteria is not None and stopping_criteria(
-                self._input_ids.tolist(), self._scores[-1, :].tolist()
+                    self._input_ids.tolist(), self._scores[-1, :].tolist()
             ):
                 return
             tokens_or_none = yield token
@@ -612,7 +685,7 @@ class Llama:
                 tokens.extend(tokens_or_none)
 
     def create_embedding(
-        self, input: Union[str, List[str]], model: Optional[str] = None
+            self, input: Union[str, List[str]], model: Optional[str] = None
    ) -> Embedding:
         """Embed a string.
 
@@ -647,8 +720,8 @@ class Llama:
             n_tokens = len(tokens)
             total_tokens += n_tokens
             embedding = llama_cpp.llama_get_embeddings(self.ctx)[
-                : llama_cpp.llama_n_embd(self.ctx)
-            ]
+                    : llama_cpp.llama_n_embd(self.ctx)
+                ]
 
             data.append(
                 {
@@ -682,27 +755,27 @@ class Llama:
         return list(map(float, self.create_embedding(input)["data"][0]["embedding"]))
 
     def _create_completion(
-        self,
-        prompt: str,
-        suffix: Optional[str] = None,
-        max_tokens: int = 16,
-        temperature: float = 0.8,
-        top_p: float = 0.95,
-        logprobs: Optional[int] = None,
-        echo: bool = False,
-        stop: Optional[Union[str, List[str]]] = [],
-        frequency_penalty: float = 0.0,
-        presence_penalty: float = 0.0,
-        repeat_penalty: float = 1.1,
-        top_k: int = 40,
-        stream: bool = False,
-        tfs_z: float = 1.0,
-        mirostat_mode: int = 0,
-        mirostat_tau: float = 5.0,
-        mirostat_eta: float = 0.1,
-        model: Optional[str] = None,
-        stopping_criteria: Optional[StoppingCriteriaList] = None,
-        logits_processor: Optional[LogitsProcessorList] = None,
+            self,
+            prompt: str,
+            suffix: Optional[str] = None,
+            max_tokens: int = 16,
+            temperature: float = 0.8,
+            top_p: float = 0.95,
+            logprobs: Optional[int] = None,
+            echo: bool = False,
+            stop: Optional[Union[str, List[str]]] = [],
+            frequency_penalty: float = 0.0,
+            presence_penalty: float = 0.0,
+            repeat_penalty: float = 1.1,
+            top_k: int = 40,
+            stream: bool = False,
+            tfs_z: float = 1.0,
+            mirostat_mode: int = 0,
+            mirostat_tau: float = 5.0,
+            mirostat_eta: float = 0.1,
+            model: Optional[str] = None,
+            stopping_criteria: Optional[StoppingCriteriaList] = None,
+            logits_processor: Optional[LogitsProcessorList] = None,
     ) -> Union[Iterator[Completion], Iterator[CompletionChunk]]:
         assert self.ctx is not None
 
@@ -754,19 +827,19 @@ class Llama:
         finish_reason = "length"
         multibyte_fix = 0
         for token in self.generate(
-            prompt_tokens,
-            top_k=top_k,
-            top_p=top_p,
-            temp=temperature,
-            tfs_z=tfs_z,
-            mirostat_mode=mirostat_mode,
-            mirostat_tau=mirostat_tau,
-            mirostat_eta=mirostat_eta,
-            frequency_penalty=frequency_penalty,
-            presence_penalty=presence_penalty,
-            repeat_penalty=repeat_penalty,
-            stopping_criteria=stopping_criteria,
-            logits_processor=logits_processor,
+                prompt_tokens,
+                top_k=top_k,
+                top_p=top_p,
+                temp=temperature,
+                tfs_z=tfs_z,
+                mirostat_mode=mirostat_mode,
+                mirostat_tau=mirostat_tau,
+                mirostat_eta=mirostat_eta,
+                frequency_penalty=frequency_penalty,
+                presence_penalty=presence_penalty,
+                repeat_penalty=repeat_penalty,
+                stopping_criteria=stopping_criteria,
+                logits_processor=logits_processor,
         ):
             if token == self._token_eos:
                 text = self.detokenize(completion_tokens)
@@ -818,7 +891,7 @@ class Llama:
                     token_end_position += len(self.detokenize([token]))
                     # Check if stop sequence is in the token
                     if token_end_position >= (
-                        remaining_length - first_stop_position - 1
+                            remaining_length - first_stop_position - 1
                     ):
                         break
                     logprobs_or_none: Optional[CompletionLogprobs] = None
@@ -879,7 +952,7 @@ class Llama:
                 break
 
         if stopping_criteria is not None and stopping_criteria(
-            self._input_ids.tolist(), self._scores[-1, :].tolist()
+                self._input_ids.tolist(), self._scores[-1, :].tolist()
         ):
             text = self.detokenize(completion_tokens)
             finish_reason = "stop"
@@ -944,8 +1017,8 @@ class Llama:
                         "choices": [
                             {
                                 "text": last_text[
-                                    : len(last_text) - (token_end_position - end)
-                                ].decode("utf-8", errors="ignore"),
+                                        : len(last_text) - (token_end_position - end)
+                                    ].decode("utf-8", errors="ignore"),
                                 "index": 0,
                                 "logprobs": logprobs_or_none,
                                 "finish_reason": finish_reason,
@@ -1011,10 +1084,10 @@ class Llama:
                 for token in all_tokens
             ]
             all_logprobs = [
-                Llama.logits_to_logprobs(row.tolist()) for row in self._scores
-            ][token_offset:]
+                    Llama.logits_to_logprobs(row.tolist()) for row in self._scores
+                ][token_offset:]
             for token, token_str, logprobs_token in zip(
-                all_tokens, all_token_strs, all_logprobs
+                    all_tokens, all_token_strs, all_logprobs
             ):
                 text_offsets.append(text_offset)
                 text_offset += len(token_str)
@@ -1065,27 +1138,27 @@ class Llama:
         }
 
     def create_completion(
-        self,
-        prompt: str,
-        suffix: Optional[str] = None,
-        max_tokens: int = 128,
-        temperature: float = 0.8,
-        top_p: float = 0.95,
-        logprobs: Optional[int] = None,
-        echo: bool = False,
-        stop: Optional[Union[str, List[str]]] = [],
-        frequency_penalty: float = 0.0,
-        presence_penalty: float = 0.0,
-        repeat_penalty: float = 1.1,
-        top_k: int = 40,
-        stream: bool = False,
-        tfs_z: float = 1.0,
-        mirostat_mode: int = 0,
-        mirostat_tau: float = 5.0,
-        mirostat_eta: float = 0.1,
-        model: Optional[str] = None,
-        stopping_criteria: Optional[StoppingCriteriaList] = None,
-        logits_processor: Optional[LogitsProcessorList] = None,
+            self,
+            prompt: str,
+            suffix: Optional[str] = None,
+            max_tokens: int = 128,
+            temperature: float = 0.8,
+            top_p: float = 0.95,
+            logprobs: Optional[int] = None,
+            echo: bool = False,
+            stop: Optional[Union[str, List[str]]] = [],
+            frequency_penalty: float = 0.0,
+            presence_penalty: float = 0.0,
+            repeat_penalty: float = 1.1,
+            top_k: int = 40,
+            stream: bool = False,
+            tfs_z: float = 1.0,
+            mirostat_mode: int = 0,
+            mirostat_tau: float = 5.0,
+            mirostat_eta: float = 0.1,
+            model: Optional[str] = None,
+            stopping_criteria: Optional[StoppingCriteriaList] = None,
+            logits_processor: Optional[LogitsProcessorList] = None,
     ) -> Union[Completion, Iterator[CompletionChunk]]:
         """Generate text from a prompt.
 
@@ -1138,27 +1211,27 @@ class Llama:
         return completion
 
     def __call__(
-        self,
-        prompt: str,
-        suffix: Optional[str] = None,
-        max_tokens: int = 128,
-        temperature: float = 0.8,
-        top_p: float = 0.95,
-        logprobs: Optional[int] = None,
-        echo: bool = False,
-        stop: Optional[Union[str, List[str]]] = [],
-        frequency_penalty: float = 0.0,
-        presence_penalty: float = 0.0,
-        repeat_penalty: float = 1.1,
-        top_k: int = 40,
-        stream: bool = False,
-        tfs_z: float = 1.0,
-        mirostat_mode: int = 0,
-        mirostat_tau: float = 5.0,
-        mirostat_eta: float = 0.1,
-        model: Optional[str] = None,
-        stopping_criteria: Optional[StoppingCriteriaList] = None,
-        logits_processor: Optional[LogitsProcessorList] = None,
+            self,
+            prompt: str,
+            suffix: Optional[str] = None,
+            max_tokens: int = 128,
+            temperature: float = 0.8,
+            top_p: float = 0.95,
+            logprobs: Optional[int] = None,
+            echo: bool = False,
+            stop: Optional[Union[str, List[str]]] = [],
+            frequency_penalty: float = 0.0,
+            presence_penalty: float = 0.0,
+            repeat_penalty: float = 1.1,
+            top_k: int = 40,
+            stream: bool = False,
+            tfs_z: float = 1.0,
+            mirostat_mode: int = 0,
+            mirostat_tau: float = 5.0,
+            mirostat_eta: float = 0.1,
+            model: Optional[str] = None,
+            stopping_criteria: Optional[StoppingCriteriaList] = None,
+            logits_processor: Optional[LogitsProcessorList] = None,
     ) -> Union[Completion, Iterator[CompletionChunk]]:
         """Generate text from a prompt.
 
@@ -1206,7 +1279,7 @@ class Llama:
         )
 
     def _convert_text_completion_to_chat(
-        self, completion: Completion
+            self, completion: Completion
     ) -> ChatCompletion:
         return {
             "id": "chat" + completion["id"],
@@ -1227,8 +1300,8 @@ class Llama:
         }
 
     def _convert_text_completion_chunks_to_chat(
-        self,
-        chunks: Iterator[CompletionChunk],
+            self,
+            chunks: Iterator[CompletionChunk],
     ) -> Iterator[ChatCompletionChunk]:
         for i, chunk in enumerate(chunks):
             if i == 0:
@@ -1264,22 +1337,22 @@ class Llama:
         }
 
     def create_chat_completion(
-        self,
-        messages: List[ChatCompletionMessage],
-        temperature: float = 0.2,
-        top_p: float = 0.95,
-        top_k: int = 40,
-        stream: bool = False,
-        stop: Optional[Union[str, List[str]]] = [],
-        max_tokens: int = 256,
-        presence_penalty: float = 0.0,
-        frequency_penalty: float = 0.0,
-        repeat_penalty: float = 1.1,
-        tfs_z: float = 1.0,
-        mirostat_mode: int = 0,
-        mirostat_tau: float = 5.0,
-        mirostat_eta: float = 0.1,
-        model: Optional[str] = None,
+            self,
+            messages: List[ChatCompletionMessage],
+            temperature: float = 0.2,
+            top_p: float = 0.95,
+            top_k: int = 40,
+            stream: bool = False,
+            stop: Optional[Union[str, List[str]]] = [],
+            max_tokens: int = 256,
+            presence_penalty: float = 0.0,
+            frequency_penalty: float = 0.0,
+            repeat_penalty: float = 1.1,
+            tfs_z: float = 1.0,
+            mirostat_mode: int = 0,
+            mirostat_tau: float = 5.0,
+            mirostat_eta: float = 0.1,
+            model: Optional[str] = None,
     ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
         """Generate a chat completion from a list of messages.
 
diff --git a/pyproject.toml b/pyproject.toml
index 09991e9..ccbefe2 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -16,6 +16,7 @@ include = [
 python = "^3.8.1"
 typing-extensions = "^4.6.3"
 numpy = "^1.20.0"
+diskcache = "^5.6.1"
 uvicorn = { version = "^0.22.0", optional = true }
 fastapi = { version = "^0.96.0", optional = true }
 sse-starlette = { version = "^1.6.1", optional = true }
diff --git a/setup.py b/setup.py
index 7a0cdc3..4b0d7cb 100644
--- a/setup.py
+++ b/setup.py
@@ -16,7 +16,7 @@ setup(
     license="MIT",
     package_dir={"llama_cpp": "llama_cpp", "llama_cpp.server": "llama_cpp/server"},
     packages=["llama_cpp", "llama_cpp.server"],
-    install_requires=["typing-extensions>=4.5.0", "numpy>=1.20.0"],
+    install_requires=["typing-extensions>=4.5.0", "numpy>=1.20.0", "diskcache>=5.6.1"],
     extras_require={
         "server": ["uvicorn>=0.21.1", "fastapi>=0.95.0", "sse-starlette>=1.3.3"],
     },
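
Usage note (illustration only, not part of the patch): a minimal sketch of how the new disk-backed cache could be wired in, assuming the existing Llama.set_cache API; the model path and cache directory below are placeholders.

    # Hypothetical example -- paths are placeholders, not taken from this diff.
    from llama_cpp.llama import Llama, LlamaDiskCache

    llm = Llama(model_path="./models/ggml-model-q4_0.bin")

    # Persist evaluated prompt states to disk (capped here at ~2 GiB) so completions
    # that share a prompt prefix can reuse them across processes and restarts.
    llm.set_cache(LlamaDiskCache(cache_dir="./llama_cache", capacity_bytes=2 << 30))

    print(llm("Q: What does a disk cache buy us? A:", max_tokens=32)["choices"][0]["text"])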