diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index 4f10227..064b982 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -735,10 +735,10 @@ class Llama:
             try:
                 cache_item = self.cache[prompt_tokens]
                 cache_prefix_len = Llama.longest_token_prefix(
-                    cache_item.eval_tokens, prompt_tokens
+                    cache_item.input_ids.tolist(), prompt_tokens
                 )
                 eval_prefix_len = Llama.longest_token_prefix(
-                    self.eval_tokens, prompt_tokens
+                    self._input_ids.tolist(), prompt_tokens
                 )
                 if cache_prefix_len > eval_prefix_len:
                     self.load_state(cache_item)
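
The change switches the cache-lookup prefix matching from the old eval_tokens sequence to the array-backed input_ids / _input_ids, calling .tolist() so the comparison runs over plain Python ints. A minimal sketch of what Llama.longest_token_prefix is assumed to do is below: count how many leading token ids two sequences share, so the caller can decide whether the cached state or the current context shares more of the new prompt. This is illustrative only; the helper's exact implementation in llama_cpp may differ.

from typing import Sequence

def longest_token_prefix(a: Sequence[int], b: Sequence[int]) -> int:
    """Length of the longest common prefix of two token id sequences."""
    n = 0
    for x, y in zip(a, b):
        if x != y:
            break
        n += 1
    return n

# Example: a cached sequence and a new prompt sharing a 3-token prefix.
assert longest_token_prefix([1, 2, 3, 9], [1, 2, 3, 4, 5]) == 3

With that helper, the diffed logic reads naturally: if the cached item shares a longer token prefix with the new prompt than the tokens already evaluated in the current context do, it is cheaper to restore the cached state via self.load_state(cache_item) and only evaluate the remaining suffix.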