Andrei Betlen 2023-05-07 13:19:56 -04:00
parent 4f8cf52a38
commit 2753b85321

@@ -127,9 +127,7 @@ class Llama:
         self.last_n_tokens_size = last_n_tokens_size
         self.n_batch = min(n_ctx, n_batch)
         self.eval_tokens: Deque[llama_cpp.llama_token] = deque(maxlen=n_ctx)
-        self.eval_logits: Deque[List[float]] = deque(
-            maxlen=n_ctx if logits_all else 1
-        )
+        self.eval_logits: Deque[List[float]] = deque(maxlen=n_ctx if logits_all else 1)
 
         self.cache: Optional[LlamaCache] = None
@@ -547,12 +545,6 @@ class Llama:
                 finish_reason = "stop"
                 break
 
-            if self.cache and len(completion_tokens) == 0:
-                if prompt_tokens not in self.cache:
-                    if self.verbose:
-                        print("Llama._create_completion: cache miss", file=sys.stderr)
-                    self.cache[prompt_tokens] = self.save_state()
-
             completion_tokens.append(token)
 
             all_text = self.detokenize(completion_tokens)
@@ -611,6 +603,11 @@ class Llama:
                 finish_reason = "length"
                 break
 
+        if self.cache:
+            if self.verbose:
+                print("Llama._create_completion: cache save", file=sys.stderr)
+            self.cache[prompt_tokens + completion_tokens] = self.save_state()
+
         if stream:
             yield {
                 "id": completion_id,