diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index 0db5c10..6836ea5 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -127,9 +127,7 @@ class Llama:
         self.last_n_tokens_size = last_n_tokens_size
         self.n_batch = min(n_ctx, n_batch)
         self.eval_tokens: Deque[llama_cpp.llama_token] = deque(maxlen=n_ctx)
-        self.eval_logits: Deque[List[float]] = deque(
-            maxlen=n_ctx if logits_all else 1
-        )
+        self.eval_logits: Deque[List[float]] = deque(maxlen=n_ctx if logits_all else 1)

         self.cache: Optional[LlamaCache] = None

@@ -547,12 +545,6 @@ class Llama:
                 finish_reason = "stop"
                 break

-            if self.cache and len(completion_tokens) == 0:
-                if prompt_tokens not in self.cache:
-                    if self.verbose:
-                        print("Llama._create_completion: cache miss", file=sys.stderr)
-                    self.cache[prompt_tokens] = self.save_state()
-
             completion_tokens.append(token)

             all_text = self.detokenize(completion_tokens)
@@ -611,6 +603,11 @@ class Llama:
                 finish_reason = "length"
                 break

+        if self.cache:
+            if self.verbose:
+                print("Llama._create_completion: cache save", file=sys.stderr)
+            self.cache[prompt_tokens + completion_tokens] = self.save_state()
+
         if stream:
             yield {
                 "id": completion_id,
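
This patch moves the state save out of the per-token loop: instead of caching the prompt state on a cache miss before generation, the state is saved once after generation finishes, keyed by prompt_tokens + completion_tokens. A minimal usage sketch of how that path would be exercised, assuming the LlamaCache/set_cache API of llama-cpp-python at this version; the model path and prompts below are placeholders, not part of the patch:

    from llama_cpp import Llama, LlamaCache

    # Placeholder model path; verbose=True makes the "cache save" message visible.
    llm = Llama(model_path="./models/ggml-model.bin", verbose=True)
    llm.set_cache(LlamaCache())

    # First call: after generation, the state for prompt + completion tokens is saved.
    first = llm("Q: What is the capital of France? A:", max_tokens=16)

    # A follow-up prompt that extends the previous prompt + completion can reuse
    # the cached state instead of re-evaluating those tokens.
    second = llm(
        "Q: What is the capital of France? A: Paris. Q: And of Spain? A:",
        max_tokens=16,
    )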