Implement logprobs parameter for text completion. Closes #2
This commit is contained in:
parent
2a60eb820f
commit
b3805bb9cc
2 changed files with 111 additions and 16 deletions
|
@ -2,6 +2,7 @@ import os
|
||||||
import sys
|
import sys
|
||||||
import uuid
|
import uuid
|
||||||
import time
|
import time
|
||||||
|
import math
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
from typing import List, Optional, Union, Generator, Sequence, Iterator
|
from typing import List, Optional, Union, Generator, Sequence, Iterator
|
||||||
from collections import deque
|
from collections import deque
|
||||||
|
@ -76,6 +77,9 @@ class Llama:
|
||||||
)
|
)
|
||||||
self.tokens_consumed = 0
|
self.tokens_consumed = 0
|
||||||
self.n_batch = min(n_ctx, n_batch)
|
self.n_batch = min(n_ctx, n_batch)
|
||||||
|
self.n_tokens = 0
|
||||||
|
self.n_past = 0
|
||||||
|
self.all_logits: List[List[float]] = [] # TODO: Use an array instead of a list.
|
||||||
|
|
||||||
self.n_threads = n_threads or max(multiprocessing.cpu_count() // 2, 1)
|
self.n_threads = n_threads or max(multiprocessing.cpu_count() // 2, 1)
|
||||||
|
|
||||||
|
@ -136,6 +140,9 @@ class Llama:
|
||||||
[llama_cpp.llama_token(0)] * self.last_n_tokens_size
|
[llama_cpp.llama_token(0)] * self.last_n_tokens_size
|
||||||
)
|
)
|
||||||
self.tokens_consumed = 0
|
self.tokens_consumed = 0
|
||||||
|
self.n_tokens = 0
|
||||||
|
self.n_past = 0
|
||||||
|
self.all_logits = []
|
||||||
|
|
||||||
def eval(self, tokens: Sequence[llama_cpp.llama_token]):
|
def eval(self, tokens: Sequence[llama_cpp.llama_token]):
|
||||||
"""Evaluate a list of tokens.
|
"""Evaluate a list of tokens.
|
||||||
|
@ -147,18 +154,31 @@ class Llama:
|
||||||
n_ctx = int(llama_cpp.llama_n_ctx(self.ctx))
|
n_ctx = int(llama_cpp.llama_n_ctx(self.ctx))
|
||||||
for i in range(0, len(tokens), self.n_batch):
|
for i in range(0, len(tokens), self.n_batch):
|
||||||
batch = tokens[i : min(len(tokens), i + self.n_batch)]
|
batch = tokens[i : min(len(tokens), i + self.n_batch)]
|
||||||
n_past = min(n_ctx - len(batch), self.tokens_consumed)
|
self.n_past = min(n_ctx - len(batch), self.tokens_consumed)
|
||||||
|
self.n_tokens = len(batch)
|
||||||
return_code = llama_cpp.llama_eval(
|
return_code = llama_cpp.llama_eval(
|
||||||
ctx=self.ctx,
|
ctx=self.ctx,
|
||||||
tokens=(llama_cpp.llama_token * len(batch))(*batch),
|
tokens=(llama_cpp.llama_token * len(batch))(*batch),
|
||||||
n_tokens=llama_cpp.c_int(len(batch)),
|
n_tokens=llama_cpp.c_int(self.n_tokens),
|
||||||
n_past=llama_cpp.c_int(n_past),
|
n_past=llama_cpp.c_int(self.n_past),
|
||||||
n_threads=llama_cpp.c_int(self.n_threads),
|
n_threads=llama_cpp.c_int(self.n_threads),
|
||||||
)
|
)
|
||||||
if int(return_code) != 0:
|
if int(return_code) != 0:
|
||||||
raise RuntimeError(f"llama_eval returned {return_code}")
|
raise RuntimeError(f"llama_eval returned {return_code}")
|
||||||
self.last_n_tokens_data.extend(batch)
|
self.last_n_tokens_data.extend(batch)
|
||||||
self.tokens_consumed += len(batch)
|
self.tokens_consumed += len(batch)
|
||||||
|
if self.params.logits_all:
|
||||||
|
self.all_logits.extend(self._logits())
|
||||||
|
|
||||||
|
def _logits(self) -> List[List[float]]:
|
||||||
|
"""Return the logits from the last call to llama_eval."""
|
||||||
|
assert self.ctx is not None
|
||||||
|
n_vocab = llama_cpp.llama_n_vocab(self.ctx)
|
||||||
|
cols = int(n_vocab)
|
||||||
|
rows = self.n_tokens if self.params.logits_all else 1
|
||||||
|
logits_view = llama_cpp.llama_get_logits(self.ctx)
|
||||||
|
logits = [[logits_view[i * cols + j] for j in range(cols)] for i in range(rows)]
|
||||||
|
return logits
|
||||||
|
|
||||||
def sample(
|
def sample(
|
||||||
self,
|
self,
|
||||||
|
@ -327,14 +347,55 @@ class Llama:
|
||||||
else:
|
else:
|
||||||
stop_sequences = []
|
stop_sequences = []
|
||||||
|
|
||||||
finish_reason = None
|
text_offset = 0
|
||||||
for token in self.generate(
|
text_offsets: List[int] = []
|
||||||
prompt_tokens,
|
token_logprobs: List[float] = []
|
||||||
|
tokens: List[str] = []
|
||||||
|
top_logprobs: List[Dict[str, float]] = []
|
||||||
|
|
||||||
|
self.reset()
|
||||||
|
self.eval(prompt_tokens)
|
||||||
|
|
||||||
|
if logprobs is not None and self.params.logits_all is False:
|
||||||
|
raise ValueError(
|
||||||
|
"logprobs is not supported for models created with logits_all=False"
|
||||||
|
)
|
||||||
|
|
||||||
|
if logprobs is not None:
|
||||||
|
token_strs = [
|
||||||
|
self.detokenize([token]).decode("utf-8") for token in prompt_tokens
|
||||||
|
]
|
||||||
|
logprobs_all = [
|
||||||
|
[Llama.logit_to_logprob(logit) for logit in row]
|
||||||
|
for row in self.all_logits
|
||||||
|
]
|
||||||
|
for token, token_str, logprobs_token in zip(
|
||||||
|
prompt_tokens, token_strs, logprobs_all
|
||||||
|
):
|
||||||
|
text_offsets.append(text_offset)
|
||||||
|
text_offset += len(token_str)
|
||||||
|
tokens.append(token_str)
|
||||||
|
sorted_logprobs = list(
|
||||||
|
sorted(
|
||||||
|
zip(logprobs_token, range(len(logprobs_token))), reverse=True
|
||||||
|
)
|
||||||
|
)
|
||||||
|
token_logprobs.append(sorted_logprobs[int(token)][0])
|
||||||
|
top_logprob = {
|
||||||
|
self.detokenize([llama_cpp.llama_token(i)]).decode("utf-8"): logprob
|
||||||
|
for logprob, i in sorted_logprobs[:logprobs]
|
||||||
|
}
|
||||||
|
top_logprob.update({token_str: sorted_logprobs[int(token)][0]})
|
||||||
|
top_logprobs.append(top_logprob)
|
||||||
|
|
||||||
|
finish_reason = "length"
|
||||||
|
while True:
|
||||||
|
token = self.sample(
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
top_p=top_p,
|
top_p=top_p,
|
||||||
temp=temperature,
|
temp=temperature,
|
||||||
repeat_penalty=repeat_penalty,
|
repeat_penalty=repeat_penalty,
|
||||||
):
|
)
|
||||||
if token == llama_cpp.llama_token_eos():
|
if token == llama_cpp.llama_token_eos():
|
||||||
text = self.detokenize(completion_tokens)
|
text = self.detokenize(completion_tokens)
|
||||||
finish_reason = "stop"
|
finish_reason = "stop"
|
||||||
|
@ -377,13 +438,35 @@ class Llama:
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if logprobs is not None:
|
||||||
|
# TODO: Confirm wether this should happen before or after
|
||||||
|
# next eval.
|
||||||
|
token_str = self.detokenize([token]).decode("utf-8")
|
||||||
|
text_offsets.append(text_offset)
|
||||||
|
text_offset += len(token_str)
|
||||||
|
tokens.append(token_str)
|
||||||
|
logprobs_token = [
|
||||||
|
Llama.logit_to_logprob(logit) for logit in self.all_logits[-1]
|
||||||
|
]
|
||||||
|
sorted_logprobs = list(
|
||||||
|
sorted(
|
||||||
|
zip(logprobs_token, range(len(logprobs_token))), reverse=True
|
||||||
|
)
|
||||||
|
)
|
||||||
|
token_logprobs.append(sorted_logprobs[int(token)][0])
|
||||||
|
top_logprob = {
|
||||||
|
self.detokenize([llama_cpp.llama_token(i)]).decode("utf-8"): logprob
|
||||||
|
for logprob, i in sorted_logprobs[:logprobs]
|
||||||
|
}
|
||||||
|
top_logprob.update({token_str: logprobs_token[int(token)]})
|
||||||
|
top_logprobs.append(top_logprob)
|
||||||
|
|
||||||
if len(completion_tokens) >= max_tokens:
|
if len(completion_tokens) >= max_tokens:
|
||||||
text = self.detokenize(completion_tokens)
|
text = self.detokenize(completion_tokens)
|
||||||
finish_reason = "length"
|
finish_reason = "length"
|
||||||
break
|
break
|
||||||
|
self.eval([token])
|
||||||
if finish_reason is None:
|
|
||||||
finish_reason = "length"
|
|
||||||
|
|
||||||
if stream:
|
if stream:
|
||||||
yield {
|
yield {
|
||||||
|
@ -410,8 +493,14 @@ class Llama:
|
||||||
if suffix is not None:
|
if suffix is not None:
|
||||||
text = text + suffix
|
text = text + suffix
|
||||||
|
|
||||||
|
logprobs_or_none: Optional[CompletionLogprobs] = None
|
||||||
if logprobs is not None:
|
if logprobs is not None:
|
||||||
raise NotImplementedError("logprobs not implemented")
|
logprobs_or_none = {
|
||||||
|
"tokens": tokens,
|
||||||
|
"text_offset": text_offsets,
|
||||||
|
"token_logprobs": token_logprobs,
|
||||||
|
"top_logprobs": top_logprobs,
|
||||||
|
}
|
||||||
|
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
llama_cpp.llama_print_timings(self.ctx)
|
llama_cpp.llama_print_timings(self.ctx)
|
||||||
|
@ -425,7 +514,7 @@ class Llama:
|
||||||
{
|
{
|
||||||
"text": text,
|
"text": text,
|
||||||
"index": 0,
|
"index": 0,
|
||||||
"logprobs": None,
|
"logprobs": logprobs_or_none,
|
||||||
"finish_reason": finish_reason,
|
"finish_reason": finish_reason,
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
@ -704,3 +793,7 @@ class Llama:
|
||||||
def token_bos() -> llama_cpp.llama_token:
|
def token_bos() -> llama_cpp.llama_token:
|
||||||
"""Return the beginning-of-sequence token."""
|
"""Return the beginning-of-sequence token."""
|
||||||
return llama_cpp.llama_token_bos()
|
return llama_cpp.llama_token_bos()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def logit_to_logprob(x: float) -> float:
|
||||||
|
return math.log(1.0 + math.exp(x))
|
||||||
|
|
|
@ -33,6 +33,7 @@ class Settings(BaseSettings):
|
||||||
use_mlock: bool = False # This causes a silent failure on platforms that don't support mlock (e.g. Windows) took forever to figure out...
|
use_mlock: bool = False # This causes a silent failure on platforms that don't support mlock (e.g. Windows) took forever to figure out...
|
||||||
embedding: bool = True
|
embedding: bool = True
|
||||||
last_n_tokens_size: int = 64
|
last_n_tokens_size: int = 64
|
||||||
|
logits_all: bool = False
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI(
|
app = FastAPI(
|
||||||
|
@ -52,6 +53,7 @@ llama = llama_cpp.Llama(
|
||||||
f16_kv=settings.f16_kv,
|
f16_kv=settings.f16_kv,
|
||||||
use_mlock=settings.use_mlock,
|
use_mlock=settings.use_mlock,
|
||||||
embedding=settings.embedding,
|
embedding=settings.embedding,
|
||||||
|
logits_all=settings.logits_all,
|
||||||
n_threads=settings.n_threads,
|
n_threads=settings.n_threads,
|
||||||
n_batch=settings.n_batch,
|
n_batch=settings.n_batch,
|
||||||
n_ctx=settings.n_ctx,
|
n_ctx=settings.n_ctx,
|
||||||
|
|
Loading…
Add table
Reference in a new issue