diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 8cd77ee..7a8c25b 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -390,18 +390,28 @@ class Llama: """ assert self.ctx is not None - if ( - reset - and len(self.eval_tokens) > 0 - and tuple(self.eval_tokens) == tuple(tokens[: len(self.eval_tokens)]) - ): - if self.verbose: - print("Llama.generate: cache hit", file=sys.stderr) - reset = False - tokens = tokens[len(self.eval_tokens) :] + if reset and len(self.eval_tokens) > 0: + longest_prefix = 0 + for a, b in zip(self.eval_tokens, tokens[:-1]): + if a == b: + longest_prefix += 1 + else: + break + if longest_prefix > 0: + if self.verbose: + print("Llama.generate: prefix-match hit", file=sys.stderr) + reset = False + tokens = tokens[longest_prefix:] + for _ in range(len(self.eval_tokens) - longest_prefix): + self.eval_tokens.pop() + try: + self.eval_logits.pop() + except IndexError: + pass if reset: self.reset() + while True: self.eval(tokens) token = self.sample(