Implement logprobs parameter for text completion. Closes #2

This commit is contained in:
Andrei Betlen 2023-04-12 14:05:11 -04:00
parent 2a60eb820f
commit b3805bb9cc
2 changed files with 111 additions and 16 deletions

View file

@ -2,6 +2,7 @@ import os
import sys
import uuid
import time
import math
import multiprocessing
from typing import List, Optional, Union, Generator, Sequence, Iterator
from collections import deque
@ -76,6 +77,9 @@ class Llama:
)
self.tokens_consumed = 0
self.n_batch = min(n_ctx, n_batch)
self.n_tokens = 0
self.n_past = 0
self.all_logits: List[List[float]] = [] # TODO: Use an array instead of a list.
self.n_threads = n_threads or max(multiprocessing.cpu_count() // 2, 1)
@ -136,6 +140,9 @@ class Llama:
[llama_cpp.llama_token(0)] * self.last_n_tokens_size
)
self.tokens_consumed = 0
self.n_tokens = 0
self.n_past = 0
self.all_logits = []
def eval(self, tokens: Sequence[llama_cpp.llama_token]):
"""Evaluate a list of tokens.
@ -147,18 +154,31 @@ class Llama:
n_ctx = int(llama_cpp.llama_n_ctx(self.ctx))
for i in range(0, len(tokens), self.n_batch):
batch = tokens[i : min(len(tokens), i + self.n_batch)]
n_past = min(n_ctx - len(batch), self.tokens_consumed)
self.n_past = min(n_ctx - len(batch), self.tokens_consumed)
self.n_tokens = len(batch)
return_code = llama_cpp.llama_eval(
ctx=self.ctx,
tokens=(llama_cpp.llama_token * len(batch))(*batch),
n_tokens=llama_cpp.c_int(len(batch)),
n_past=llama_cpp.c_int(n_past),
n_tokens=llama_cpp.c_int(self.n_tokens),
n_past=llama_cpp.c_int(self.n_past),
n_threads=llama_cpp.c_int(self.n_threads),
)
if int(return_code) != 0:
raise RuntimeError(f"llama_eval returned {return_code}")
self.last_n_tokens_data.extend(batch)
self.tokens_consumed += len(batch)
if self.params.logits_all:
self.all_logits.extend(self._logits())
def _logits(self) -> List[List[float]]:
"""Return the logits from the last call to llama_eval."""
assert self.ctx is not None
n_vocab = llama_cpp.llama_n_vocab(self.ctx)
cols = int(n_vocab)
rows = self.n_tokens if self.params.logits_all else 1
logits_view = llama_cpp.llama_get_logits(self.ctx)
logits = [[logits_view[i * cols + j] for j in range(cols)] for i in range(rows)]
return logits
def sample(
self,
@ -327,14 +347,55 @@ class Llama:
else:
stop_sequences = []
finish_reason = None
for token in self.generate(
prompt_tokens,
top_k=top_k,
top_p=top_p,
temp=temperature,
repeat_penalty=repeat_penalty,
):
text_offset = 0
text_offsets: List[int] = []
token_logprobs: List[float] = []
tokens: List[str] = []
top_logprobs: List[Dict[str, float]] = []
self.reset()
self.eval(prompt_tokens)
if logprobs is not None and self.params.logits_all is False:
raise ValueError(
"logprobs is not supported for models created with logits_all=False"
)
if logprobs is not None:
token_strs = [
self.detokenize([token]).decode("utf-8") for token in prompt_tokens
]
logprobs_all = [
[Llama.logit_to_logprob(logit) for logit in row]
for row in self.all_logits
]
for token, token_str, logprobs_token in zip(
prompt_tokens, token_strs, logprobs_all
):
text_offsets.append(text_offset)
text_offset += len(token_str)
tokens.append(token_str)
sorted_logprobs = list(
sorted(
zip(logprobs_token, range(len(logprobs_token))), reverse=True
)
)
token_logprobs.append(sorted_logprobs[int(token)][0])
top_logprob = {
self.detokenize([llama_cpp.llama_token(i)]).decode("utf-8"): logprob
for logprob, i in sorted_logprobs[:logprobs]
}
top_logprob.update({token_str: sorted_logprobs[int(token)][0]})
top_logprobs.append(top_logprob)
finish_reason = "length"
while True:
token = self.sample(
top_k=top_k,
top_p=top_p,
temp=temperature,
repeat_penalty=repeat_penalty,
)
if token == llama_cpp.llama_token_eos():
text = self.detokenize(completion_tokens)
finish_reason = "stop"
@ -377,13 +438,35 @@ class Llama:
}
],
}
if logprobs is not None:
# TODO: Confirm wether this should happen before or after
# next eval.
token_str = self.detokenize([token]).decode("utf-8")
text_offsets.append(text_offset)
text_offset += len(token_str)
tokens.append(token_str)
logprobs_token = [
Llama.logit_to_logprob(logit) for logit in self.all_logits[-1]
]
sorted_logprobs = list(
sorted(
zip(logprobs_token, range(len(logprobs_token))), reverse=True
)
)
token_logprobs.append(sorted_logprobs[int(token)][0])
top_logprob = {
self.detokenize([llama_cpp.llama_token(i)]).decode("utf-8"): logprob
for logprob, i in sorted_logprobs[:logprobs]
}
top_logprob.update({token_str: logprobs_token[int(token)]})
top_logprobs.append(top_logprob)
if len(completion_tokens) >= max_tokens:
text = self.detokenize(completion_tokens)
finish_reason = "length"
break
if finish_reason is None:
finish_reason = "length"
self.eval([token])
if stream:
yield {
@ -410,8 +493,14 @@ class Llama:
if suffix is not None:
text = text + suffix
logprobs_or_none: Optional[CompletionLogprobs] = None
if logprobs is not None:
raise NotImplementedError("logprobs not implemented")
logprobs_or_none = {
"tokens": tokens,
"text_offset": text_offsets,
"token_logprobs": token_logprobs,
"top_logprobs": top_logprobs,
}
if self.verbose:
llama_cpp.llama_print_timings(self.ctx)
@ -425,7 +514,7 @@ class Llama:
{
"text": text,
"index": 0,
"logprobs": None,
"logprobs": logprobs_or_none,
"finish_reason": finish_reason,
}
],
@ -704,3 +793,7 @@ class Llama:
def token_bos() -> llama_cpp.llama_token:
"""Return the beginning-of-sequence token."""
return llama_cpp.llama_token_bos()
@staticmethod
def logit_to_logprob(x: float) -> float:
return math.log(1.0 + math.exp(x))

View file

@ -33,6 +33,7 @@ class Settings(BaseSettings):
use_mlock: bool = False # This causes a silent failure on platforms that don't support mlock (e.g. Windows) took forever to figure out...
embedding: bool = True
last_n_tokens_size: int = 64
logits_all: bool = False
app = FastAPI(
@ -52,6 +53,7 @@ llama = llama_cpp.Llama(
f16_kv=settings.f16_kv,
use_mlock=settings.use_mlock,
embedding=settings.embedding,
logits_all=settings.logits_all,
n_threads=settings.n_threads,
n_batch=settings.n_batch,
n_ctx=settings.n_ctx,