Format
This commit is contained in:
parent
4f8cf52a38
commit
2753b85321
1 changed file with 6 additions and 9 deletions
|
@@ -127,9 +127,7 @@ class Llama:
|
|||
self.last_n_tokens_size = last_n_tokens_size
|
||||
self.n_batch = min(n_ctx, n_batch)
|
||||
self.eval_tokens: Deque[llama_cpp.llama_token] = deque(maxlen=n_ctx)
|
||||
self.eval_logits: Deque[List[float]] = deque(
|
||||
maxlen=n_ctx if logits_all else 1
|
||||
)
|
||||
self.eval_logits: Deque[List[float]] = deque(maxlen=n_ctx if logits_all else 1)
|
||||
|
||||
self.cache: Optional[LlamaCache] = None
|
||||
|
||||
|
@@ -547,12 +545,6 @@ class Llama:
|
|||
finish_reason = "stop"
|
||||
break
|
||||
|
||||
if self.cache and len(completion_tokens) == 0:
|
||||
if prompt_tokens not in self.cache:
|
||||
if self.verbose:
|
||||
print("Llama._create_completion: cache miss", file=sys.stderr)
|
||||
self.cache[prompt_tokens] = self.save_state()
|
||||
|
||||
completion_tokens.append(token)
|
||||
|
||||
all_text = self.detokenize(completion_tokens)
|
||||
|
@@ -611,6 +603,11 @@ class Llama:
|
|||
finish_reason = "length"
|
||||
break
|
||||
|
||||
if self.cache:
|
||||
if self.verbose:
|
||||
print("Llama._create_completion: cache save", file=sys.stderr)
|
||||
self.cache[prompt_tokens + completion_tokens] = self.save_state()
|
||||
|
||||
if stream:
|
||||
yield {
|
||||
"id": completion_id,
|
||||
|
|
Loading…
Reference in a new issue