Implement logprobs parameter for text completion. Closes #2
This commit is contained in:
parent
2a60eb820f
commit
b3805bb9cc
2 changed files with 111 additions and 16 deletions
|
@ -2,6 +2,7 @@ import os
|
|||
import sys
|
||||
import uuid
|
||||
import time
|
||||
import math
|
||||
import multiprocessing
|
||||
from typing import List, Optional, Union, Generator, Sequence, Iterator
|
||||
from collections import deque
|
||||
|
@ -76,6 +77,9 @@ class Llama:
|
|||
)
|
||||
self.tokens_consumed = 0
|
||||
self.n_batch = min(n_ctx, n_batch)
|
||||
self.n_tokens = 0
|
||||
self.n_past = 0
|
||||
self.all_logits: List[List[float]] = [] # TODO: Use an array instead of a list.
|
||||
|
||||
self.n_threads = n_threads or max(multiprocessing.cpu_count() // 2, 1)
|
||||
|
||||
|
@ -136,6 +140,9 @@ class Llama:
|
|||
[llama_cpp.llama_token(0)] * self.last_n_tokens_size
|
||||
)
|
||||
self.tokens_consumed = 0
|
||||
self.n_tokens = 0
|
||||
self.n_past = 0
|
||||
self.all_logits = []
|
||||
|
||||
def eval(self, tokens: Sequence[llama_cpp.llama_token]):
|
||||
"""Evaluate a list of tokens.
|
||||
|
@ -147,18 +154,31 @@ class Llama:
|
|||
n_ctx = int(llama_cpp.llama_n_ctx(self.ctx))
|
||||
for i in range(0, len(tokens), self.n_batch):
|
||||
batch = tokens[i : min(len(tokens), i + self.n_batch)]
|
||||
n_past = min(n_ctx - len(batch), self.tokens_consumed)
|
||||
self.n_past = min(n_ctx - len(batch), self.tokens_consumed)
|
||||
self.n_tokens = len(batch)
|
||||
return_code = llama_cpp.llama_eval(
|
||||
ctx=self.ctx,
|
||||
tokens=(llama_cpp.llama_token * len(batch))(*batch),
|
||||
n_tokens=llama_cpp.c_int(len(batch)),
|
||||
n_past=llama_cpp.c_int(n_past),
|
||||
n_tokens=llama_cpp.c_int(self.n_tokens),
|
||||
n_past=llama_cpp.c_int(self.n_past),
|
||||
n_threads=llama_cpp.c_int(self.n_threads),
|
||||
)
|
||||
if int(return_code) != 0:
|
||||
raise RuntimeError(f"llama_eval returned {return_code}")
|
||||
self.last_n_tokens_data.extend(batch)
|
||||
self.tokens_consumed += len(batch)
|
||||
if self.params.logits_all:
|
||||
self.all_logits.extend(self._logits())
|
||||
|
||||
def _logits(self) -> List[List[float]]:
|
||||
"""Return the logits from the last call to llama_eval."""
|
||||
assert self.ctx is not None
|
||||
n_vocab = llama_cpp.llama_n_vocab(self.ctx)
|
||||
cols = int(n_vocab)
|
||||
rows = self.n_tokens if self.params.logits_all else 1
|
||||
logits_view = llama_cpp.llama_get_logits(self.ctx)
|
||||
logits = [[logits_view[i * cols + j] for j in range(cols)] for i in range(rows)]
|
||||
return logits
|
||||
|
||||
def sample(
|
||||
self,
|
||||
|
@ -327,14 +347,55 @@ class Llama:
|
|||
else:
|
||||
stop_sequences = []
|
||||
|
||||
finish_reason = None
|
||||
for token in self.generate(
|
||||
prompt_tokens,
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
temp=temperature,
|
||||
repeat_penalty=repeat_penalty,
|
||||
):
|
||||
text_offset = 0
|
||||
text_offsets: List[int] = []
|
||||
token_logprobs: List[float] = []
|
||||
tokens: List[str] = []
|
||||
top_logprobs: List[Dict[str, float]] = []
|
||||
|
||||
self.reset()
|
||||
self.eval(prompt_tokens)
|
||||
|
||||
if logprobs is not None and self.params.logits_all is False:
|
||||
raise ValueError(
|
||||
"logprobs is not supported for models created with logits_all=False"
|
||||
)
|
||||
|
||||
if logprobs is not None:
|
||||
token_strs = [
|
||||
self.detokenize([token]).decode("utf-8") for token in prompt_tokens
|
||||
]
|
||||
logprobs_all = [
|
||||
[Llama.logit_to_logprob(logit) for logit in row]
|
||||
for row in self.all_logits
|
||||
]
|
||||
for token, token_str, logprobs_token in zip(
|
||||
prompt_tokens, token_strs, logprobs_all
|
||||
):
|
||||
text_offsets.append(text_offset)
|
||||
text_offset += len(token_str)
|
||||
tokens.append(token_str)
|
||||
sorted_logprobs = list(
|
||||
sorted(
|
||||
zip(logprobs_token, range(len(logprobs_token))), reverse=True
|
||||
)
|
||||
)
|
||||
token_logprobs.append(sorted_logprobs[int(token)][0])
|
||||
top_logprob = {
|
||||
self.detokenize([llama_cpp.llama_token(i)]).decode("utf-8"): logprob
|
||||
for logprob, i in sorted_logprobs[:logprobs]
|
||||
}
|
||||
top_logprob.update({token_str: sorted_logprobs[int(token)][0]})
|
||||
top_logprobs.append(top_logprob)
|
||||
|
||||
finish_reason = "length"
|
||||
while True:
|
||||
token = self.sample(
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
temp=temperature,
|
||||
repeat_penalty=repeat_penalty,
|
||||
)
|
||||
if token == llama_cpp.llama_token_eos():
|
||||
text = self.detokenize(completion_tokens)
|
||||
finish_reason = "stop"
|
||||
|
@ -377,13 +438,35 @@ class Llama:
|
|||
}
|
||||
],
|
||||
}
|
||||
|
||||
if logprobs is not None:
|
||||
# TODO: Confirm wether this should happen before or after
|
||||
# next eval.
|
||||
token_str = self.detokenize([token]).decode("utf-8")
|
||||
text_offsets.append(text_offset)
|
||||
text_offset += len(token_str)
|
||||
tokens.append(token_str)
|
||||
logprobs_token = [
|
||||
Llama.logit_to_logprob(logit) for logit in self.all_logits[-1]
|
||||
]
|
||||
sorted_logprobs = list(
|
||||
sorted(
|
||||
zip(logprobs_token, range(len(logprobs_token))), reverse=True
|
||||
)
|
||||
)
|
||||
token_logprobs.append(sorted_logprobs[int(token)][0])
|
||||
top_logprob = {
|
||||
self.detokenize([llama_cpp.llama_token(i)]).decode("utf-8"): logprob
|
||||
for logprob, i in sorted_logprobs[:logprobs]
|
||||
}
|
||||
top_logprob.update({token_str: logprobs_token[int(token)]})
|
||||
top_logprobs.append(top_logprob)
|
||||
|
||||
if len(completion_tokens) >= max_tokens:
|
||||
text = self.detokenize(completion_tokens)
|
||||
finish_reason = "length"
|
||||
break
|
||||
|
||||
if finish_reason is None:
|
||||
finish_reason = "length"
|
||||
self.eval([token])
|
||||
|
||||
if stream:
|
||||
yield {
|
||||
|
@ -410,8 +493,14 @@ class Llama:
|
|||
if suffix is not None:
|
||||
text = text + suffix
|
||||
|
||||
logprobs_or_none: Optional[CompletionLogprobs] = None
|
||||
if logprobs is not None:
|
||||
raise NotImplementedError("logprobs not implemented")
|
||||
logprobs_or_none = {
|
||||
"tokens": tokens,
|
||||
"text_offset": text_offsets,
|
||||
"token_logprobs": token_logprobs,
|
||||
"top_logprobs": top_logprobs,
|
||||
}
|
||||
|
||||
if self.verbose:
|
||||
llama_cpp.llama_print_timings(self.ctx)
|
||||
|
@ -425,7 +514,7 @@ class Llama:
|
|||
{
|
||||
"text": text,
|
||||
"index": 0,
|
||||
"logprobs": None,
|
||||
"logprobs": logprobs_or_none,
|
||||
"finish_reason": finish_reason,
|
||||
}
|
||||
],
|
||||
|
@ -704,3 +793,7 @@ class Llama:
|
|||
def token_bos() -> llama_cpp.llama_token:
|
||||
"""Return the beginning-of-sequence token."""
|
||||
return llama_cpp.llama_token_bos()
|
||||
|
||||
@staticmethod
|
||||
def logit_to_logprob(x: float) -> float:
|
||||
return math.log(1.0 + math.exp(x))
|
||||
|
|
|
@ -33,6 +33,7 @@ class Settings(BaseSettings):
|
|||
use_mlock: bool = False # This causes a silent failure on platforms that don't support mlock (e.g. Windows) took forever to figure out...
|
||||
embedding: bool = True
|
||||
last_n_tokens_size: int = 64
|
||||
logits_all: bool = False
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
|
@ -52,6 +53,7 @@ llama = llama_cpp.Llama(
|
|||
f16_kv=settings.f16_kv,
|
||||
use_mlock=settings.use_mlock,
|
||||
embedding=settings.embedding,
|
||||
logits_all=settings.logits_all,
|
||||
n_threads=settings.n_threads,
|
||||
n_batch=settings.n_batch,
|
||||
n_ctx=settings.n_ctx,
|
||||
|
|
Loading…
Reference in a new issue