Implement logprobs parameter for text completion. Closes #2

Andrei Betlen 2023-04-12 14:05:11 -04:00
parent 2a60eb820f
commit b3805bb9cc
2 changed files with 111 additions and 16 deletions


@@ -2,6 +2,7 @@ import os
 import sys
 import uuid
 import time
+import math
 import multiprocessing
 from typing import List, Optional, Union, Generator, Sequence, Iterator
 from collections import deque
@@ -76,6 +77,9 @@ class Llama:
         )
         self.tokens_consumed = 0
         self.n_batch = min(n_ctx, n_batch)
+        self.n_tokens = 0
+        self.n_past = 0
+        self.all_logits: List[List[float]] = []  # TODO: Use an array instead of a list.
         self.n_threads = n_threads or max(multiprocessing.cpu_count() // 2, 1)
@@ -136,6 +140,9 @@ class Llama:
             [llama_cpp.llama_token(0)] * self.last_n_tokens_size
         )
         self.tokens_consumed = 0
+        self.n_tokens = 0
+        self.n_past = 0
+        self.all_logits = []
 
     def eval(self, tokens: Sequence[llama_cpp.llama_token]):
         """Evaluate a list of tokens.
@@ -147,18 +154,31 @@ class Llama:
         n_ctx = int(llama_cpp.llama_n_ctx(self.ctx))
         for i in range(0, len(tokens), self.n_batch):
             batch = tokens[i : min(len(tokens), i + self.n_batch)]
-            n_past = min(n_ctx - len(batch), self.tokens_consumed)
+            self.n_past = min(n_ctx - len(batch), self.tokens_consumed)
+            self.n_tokens = len(batch)
             return_code = llama_cpp.llama_eval(
                 ctx=self.ctx,
                 tokens=(llama_cpp.llama_token * len(batch))(*batch),
-                n_tokens=llama_cpp.c_int(len(batch)),
-                n_past=llama_cpp.c_int(n_past),
+                n_tokens=llama_cpp.c_int(self.n_tokens),
+                n_past=llama_cpp.c_int(self.n_past),
                 n_threads=llama_cpp.c_int(self.n_threads),
             )
             if int(return_code) != 0:
                 raise RuntimeError(f"llama_eval returned {return_code}")
             self.last_n_tokens_data.extend(batch)
             self.tokens_consumed += len(batch)
+            if self.params.logits_all:
+                self.all_logits.extend(self._logits())
+
+    def _logits(self) -> List[List[float]]:
+        """Return the logits from the last call to llama_eval."""
+        assert self.ctx is not None
+        n_vocab = llama_cpp.llama_n_vocab(self.ctx)
+        cols = int(n_vocab)
+        rows = self.n_tokens if self.params.logits_all else 1
+        logits_view = llama_cpp.llama_get_logits(self.ctx)
+        logits = [[logits_view[i * cols + j] for j in range(cols)] for i in range(rows)]
+        return logits
 
     def sample(
         self,
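The new _logits() helper reads llama.cpp's flat logits buffer and slices it into one row of n_vocab values per evaluated token (one row per batch token when logits_all is set, otherwise only the final token). A toy sketch of that row-major indexing, with made-up numbers standing in for the real buffer:

# Illustration only: how a flat buffer of rows * cols floats maps to per-token rows.
rows, cols = 2, 4  # e.g. 2 evaluated tokens over a 4-entry "vocabulary"
flat = [0.1, 0.2, 0.3, 0.4,   # logits for token 0
        0.5, 0.6, 0.7, 0.8]   # logits for token 1
logits = [[flat[i * cols + j] for j in range(cols)] for i in range(rows)]
assert logits == [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]]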
@@ -327,14 +347,55 @@ class Llama:
         else:
             stop_sequences = []
 
-        finish_reason = None
-        for token in self.generate(
-            prompt_tokens,
-            top_k=top_k,
-            top_p=top_p,
-            temp=temperature,
-            repeat_penalty=repeat_penalty,
-        ):
+        text_offset = 0
+        text_offsets: List[int] = []
+        token_logprobs: List[float] = []
+        tokens: List[str] = []
+        top_logprobs: List[Dict[str, float]] = []
+
+        self.reset()
+        self.eval(prompt_tokens)
+
+        if logprobs is not None and self.params.logits_all is False:
+            raise ValueError(
+                "logprobs is not supported for models created with logits_all=False"
+            )
+
+        if logprobs is not None:
+            token_strs = [
+                self.detokenize([token]).decode("utf-8") for token in prompt_tokens
+            ]
+            logprobs_all = [
+                [Llama.logit_to_logprob(logit) for logit in row]
+                for row in self.all_logits
+            ]
+            for token, token_str, logprobs_token in zip(
+                prompt_tokens, token_strs, logprobs_all
+            ):
+                text_offsets.append(text_offset)
+                text_offset += len(token_str)
+                tokens.append(token_str)
+                sorted_logprobs = list(
+                    sorted(
+                        zip(logprobs_token, range(len(logprobs_token))), reverse=True
+                    )
+                )
+                token_logprobs.append(sorted_logprobs[int(token)][0])
+                top_logprob = {
+                    self.detokenize([llama_cpp.llama_token(i)]).decode("utf-8"): logprob
+                    for logprob, i in sorted_logprobs[:logprobs]
+                }
+                top_logprob.update({token_str: sorted_logprobs[int(token)][0]})
+                top_logprobs.append(top_logprob)
+
+        finish_reason = "length"
+        while True:
+            token = self.sample(
+                top_k=top_k,
+                top_p=top_p,
+                temp=temperature,
+                repeat_penalty=repeat_penalty,
+            )
             if token == llama_cpp.llama_token_eos():
                 text = self.detokenize(completion_tokens)
                 finish_reason = "stop"
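Both the prompt pass above and the per-step block later in the loop pick top alternatives the same way: a row of per-token values is zipped with its vocabulary index, sorted in descending order, and the first logprobs entries are detokenized into a dict. A self-contained sketch of that selection on a toy five-entry vocabulary (strings and values are invented):

# Toy example of the sorted-zip selection used for top_logprobs.
logprobs_token = [-2.3, -0.1, -4.0, -1.2, -3.3]
vocab = ["a", "b", "c", "d", "e"]
sorted_logprobs = sorted(zip(logprobs_token, range(len(logprobs_token))), reverse=True)
top_logprob = {vocab[i]: lp for lp, i in sorted_logprobs[:2]}
print(top_logprob)  # {'b': -0.1, 'd': -1.2}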
@@ -377,13 +438,35 @@ class Llama:
                         }
                     ],
                 }
 
+            if logprobs is not None:
+                # TODO: Confirm whether this should happen before or after
+                # next eval.
+                token_str = self.detokenize([token]).decode("utf-8")
+                text_offsets.append(text_offset)
+                text_offset += len(token_str)
+                tokens.append(token_str)
+                logprobs_token = [
+                    Llama.logit_to_logprob(logit) for logit in self.all_logits[-1]
+                ]
+                sorted_logprobs = list(
+                    sorted(
+                        zip(logprobs_token, range(len(logprobs_token))), reverse=True
+                    )
+                )
+                token_logprobs.append(sorted_logprobs[int(token)][0])
+                top_logprob = {
+                    self.detokenize([llama_cpp.llama_token(i)]).decode("utf-8"): logprob
+                    for logprob, i in sorted_logprobs[:logprobs]
+                }
+                top_logprob.update({token_str: logprobs_token[int(token)]})
+                top_logprobs.append(top_logprob)
+
             if len(completion_tokens) >= max_tokens:
                 text = self.detokenize(completion_tokens)
                 finish_reason = "length"
                 break
 
-        if finish_reason is None:
-            finish_reason = "length"
+            self.eval([token])
 
         if stream:
             yield {
@@ -410,8 +493,14 @@ class Llama:
         if suffix is not None:
             text = text + suffix
 
+        logprobs_or_none: Optional[CompletionLogprobs] = None
         if logprobs is not None:
-            raise NotImplementedError("logprobs not implemented")
+            logprobs_or_none = {
+                "tokens": tokens,
+                "text_offset": text_offsets,
+                "token_logprobs": token_logprobs,
+                "top_logprobs": top_logprobs,
+            }
 
         if self.verbose:
             llama_cpp.llama_print_timings(self.ctx)
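For reference, the assembled object mirrors the OpenAI Completions logprobs format. A purely illustrative example of its shape (every value below is invented):

logprobs_or_none = {
    "tokens": ["Hello", ",", " world"],
    "text_offset": [0, 5, 6],
    "token_logprobs": [-0.41, -1.73, -0.08],
    "top_logprobs": [
        {"Hello": -0.41, "Hi": -1.92},
        {",": -1.73, "!": -0.95},
        {" world": -0.08, " there": -2.51},
    ],
}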
@@ -425,7 +514,7 @@ class Llama:
                 {
                     "text": text,
                     "index": 0,
-                    "logprobs": None,
+                    "logprobs": logprobs_or_none,
                     "finish_reason": finish_reason,
                 }
             ],
@@ -704,3 +793,7 @@ class Llama:
     def token_bos() -> llama_cpp.llama_token:
         """Return the beginning-of-sequence token."""
         return llama_cpp.llama_token_bos()
+
+    @staticmethod
+    def logit_to_logprob(x: float) -> float:
+        return math.log(1.0 + math.exp(x))
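Taken together, the changes let a caller request OpenAI-style logprobs from the high-level API, provided the model was loaded with logits_all=True (otherwise the new ValueError is raised). A minimal usage sketch; the model path and prompt are placeholders:

from llama_cpp import Llama

# logits_all=True keeps per-token logits so logprobs can be computed.
llm = Llama(model_path="./models/ggml-model.bin", logits_all=True)
output = llm("Q: Name the planets in the solar system. A:", max_tokens=16, logprobs=2)
choice = output["choices"][0]
print(choice["logprobs"]["tokens"])
print(choice["logprobs"]["top_logprobs"])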


@@ -33,6 +33,7 @@ class Settings(BaseSettings):
     use_mlock: bool = False  # This causes a silent failure on platforms that don't support mlock (e.g. Windows) took forever to figure out...
     embedding: bool = True
     last_n_tokens_size: int = 64
+    logits_all: bool = False
 
 app = FastAPI(
@@ -52,6 +53,7 @@ llama = llama_cpp.Llama(
     f16_kv=settings.f16_kv,
     use_mlock=settings.use_mlock,
     embedding=settings.embedding,
+    logits_all=settings.logits_all,
     n_threads=settings.n_threads,
     n_batch=settings.n_batch,
     n_ctx=settings.n_ctx,
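On the server side logits_all defaults to False, so it has to be switched on (for example via a LOGITS_ALL environment variable, since the settings class is a pydantic BaseSettings) before clients can ask for logprobs. A rough client sketch against the OpenAI-compatible completions endpoint; the host, port, and endpoint path are assumptions for illustration:

import requests

# Assumes the server was started with LOGITS_ALL=true and listens on localhost:8000.
resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={"prompt": "The capital of France is", "max_tokens": 4, "logprobs": 2},
)
choice = resp.json()["choices"][0]
print(choice["text"], choice["logprobs"]["top_logprobs"])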