Rewind model to longest prefix.
This commit is contained in:
parent
cabd8b8ed1
commit
97c6372350
1 changed files with 19 additions and 9 deletions
|
@ -390,18 +390,28 @@ class Llama:
|
||||||
"""
|
"""
|
||||||
assert self.ctx is not None
|
assert self.ctx is not None
|
||||||
|
|
||||||
if (
|
if reset and len(self.eval_tokens) > 0:
|
||||||
reset
|
longest_prefix = 0
|
||||||
and len(self.eval_tokens) > 0
|
for a, b in zip(self.eval_tokens, tokens[:-1]):
|
||||||
and tuple(self.eval_tokens) == tuple(tokens[: len(self.eval_tokens)])
|
if a == b:
|
||||||
):
|
longest_prefix += 1
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
if longest_prefix > 0:
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
print("Llama.generate: cache hit", file=sys.stderr)
|
print("Llama.generate: prefix-match hit", file=sys.stderr)
|
||||||
reset = False
|
reset = False
|
||||||
tokens = tokens[len(self.eval_tokens) :]
|
tokens = tokens[longest_prefix:]
|
||||||
|
for _ in range(len(self.eval_tokens) - longest_prefix):
|
||||||
|
self.eval_tokens.pop()
|
||||||
|
try:
|
||||||
|
self.eval_logits.pop()
|
||||||
|
except IndexError:
|
||||||
|
pass
|
||||||
|
|
||||||
if reset:
|
if reset:
|
||||||
self.reset()
|
self.reset()
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
self.eval(tokens)
|
self.eval(tokens)
|
||||||
token = self.sample(
|
token = self.sample(
|
||||||
|
|
Loading…
Add table
Reference in a new issue