Fix logprobs for completions and implement for streaming logprobs.
parent a634a2453b
commit 17d4271b04
1 changed file with 103 additions and 22 deletions
@@ -710,22 +710,56 @@ class Llama:
                 # We want to avoid yielding any characters from
                 # the generated text if they are part of a stop
                 # sequence.
-                longest = 0
+                first_stop_position = 0
                 for s in stop_sequences:
                     for i in range(len(s), 0, -1):
                         if all_text.endswith(s[:i]):
-                            if i > longest:
-                                longest = i
+                            if i > first_stop_position:
+                                first_stop_position = i
                             break
 
-                offset = 0
+                token_end_position = 0
                 remaining_tokens = completion_tokens[returned_tokens:]
                 remaining_length = len(self.detokenize(remaining_tokens))
                 for token in remaining_tokens:
-                    offset += len(self.detokenize([token]))
-                    # Check if stop sequence is not in the token
-                    if offset >= (remaining_length - longest - 1):
+                    token_end_position += len(self.detokenize([token]))
+                    # Check if stop sequence is in the token
+                    if token_end_position >= (remaining_length - first_stop_position - 1):
                         break
+                    logprobs_or_none: Optional[CompletionLogprobs] = None
+                    if logprobs is not None:
+                        token_str = self.detokenize([token]).decode(
+                            "utf-8", errors="ignore"
+                        )
+                        text_offset = len(prompt) + len(
+                            self.detokenize(completion_tokens[:returned_tokens])
+                        )
+                        token_offset = len(prompt_tokens) + returned_tokens
+                        logits = self.eval_logits[token_offset - 1]
+                        current_logprobs = Llama.logits_to_logprobs(logits)
+                        sorted_logprobs = list(
+                            sorted(
+                                zip(current_logprobs, range(len(current_logprobs))),
+                                reverse=True,
+                            )
+                        )
+                        top_logprob = {
+                            self.detokenize([llama_cpp.llama_token(i)]).decode(
+                                "utf-8", errors="ignore"
+                            ): logprob
+                            for logprob, i in sorted_logprobs[:logprobs]
+                        }
+                        top_logprob.update({token_str: current_logprobs[int(token)]})
+                        logprobs_or_none = {
+                            "tokens": [
+                                self.detokenize([token]).decode(
+                                    "utf-8", errors="ignore"
+                                )
+                            ],
+                            "text_offset": [text_offset],
+                            "token_logprobs": [sorted_logprobs[int(token)][0]],
+                            "top_logprobs": [top_logprob],
+                        }
                     returned_tokens += 1
                     yield {
                         "id": completion_id,
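The renamed first_stop_position scan holds back any trailing characters that could still grow into a stop sequence. A standalone sketch of that suffix-matching logic (hypothetical helper name; the algorithm mirrors the loop above):

    def first_stop_position(all_text: str, stop_sequences: list) -> int:
        """How many trailing characters of all_text could begin a stop sequence."""
        longest = 0
        for s in stop_sequences:
            # Try progressively shorter prefixes of each stop sequence.
            for i in range(len(s), 0, -1):
                if all_text.endswith(s[:i]):
                    if i > longest:
                        longest = i
                    break
        return longest

    # "\n\n" could be the start of "\n\nUser:", so hold back two characters.
    assert first_stop_position("Hello\n\n", ["\n\nUser:"]) == 2
    assert first_stop_position("Hello", ["\n\nUser:"]) == 0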
@@ -738,7 +772,7 @@ class Llama:
                                 "utf-8", errors="ignore"
                             ),
                             "index": 0,
-                            "logprobs": None,
+                            "logprobs": logprobs_or_none,
                             "finish_reason": None,
                         }
                     ],
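With logprobs_or_none threaded into each streamed chunk instead of a hard-coded None, a client can read per-token logprobs while iterating. A usage sketch against llama-cpp-python's completion call (the model path is a placeholder):

    from llama_cpp import Llama

    llm = Llama(model_path="./models/7B/ggml-model.bin")  # placeholder path

    for chunk in llm("Q: Name the planets. A:", max_tokens=16, logprobs=5, stream=True):
        choice = chunk["choices"][0]
        print(choice["text"], choice["logprobs"])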
@@ -766,13 +800,48 @@ class Llama:
                 else:
                     end = len(all_text)
-                offset = 0
+                token_end_position = 0
                 for token in remaining_tokens:
-                    offset += len(self.detokenize([token]))
-                    if offset >= end:
+                    token_end_position += len(self.detokenize([token]))
+
+                    logprobs_or_none: Optional[CompletionLogprobs] = None
+                    if logprobs is not None:
+                        token_str = self.detokenize([token]).decode(
+                            "utf-8", errors="ignore"
+                        )
+                        text_offset = len(prompt) + len(
+                            self.detokenize(completion_tokens[:returned_tokens])
+                        )
+                        token_offset = len(prompt_tokens) + returned_tokens - 1
+                        logits = self.eval_logits[token_offset]
+                        current_logprobs = Llama.logits_to_logprobs(logits)
+                        sorted_logprobs = list(
+                            sorted(
+                                zip(current_logprobs, range(len(current_logprobs))),
+                                reverse=True,
+                            )
+                        )
+                        top_logprob = {
+                            self.detokenize([llama_cpp.llama_token(i)]).decode(
+                                "utf-8", errors="ignore"
+                            ): logprob
+                            for logprob, i in sorted_logprobs[:logprobs]
+                        }
+                        top_logprob.update({token_str: current_logprobs[int(token)]})
+                        logprobs_or_none = {
+                            "tokens": [
+                                self.detokenize([token]).decode("utf-8", errors="ignore")
+                            ],
+                            "text_offset": [text_offset],
+                            "token_logprobs": [sorted_logprobs[int(token)][0]],
+                            "top_logprobs": [top_logprob],
+                        }
+
+                    if token_end_position >= end:
                         last_text = self.detokenize([token])
-                        if offset == end - 1:
+                        if token_end_position == end - 1:
                             break
                         returned_tokens += 1
                         yield {
                             "id": completion_id,
                             "object": "text_completion",
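Both streaming paths turn a row of raw logits into log-probabilities with Llama.logits_to_logprobs, i.e. a log-softmax. A minimal equivalent, written here with the usual max-subtraction for numerical stability (the library's own implementation may differ):

    import math
    from typing import List

    def logits_to_logprobs(logits: List[float]) -> List[float]:
        # log-softmax: logprob_i = logit_i - log(sum_j exp(logit_j)),
        # with the max subtracted first to avoid overflow in exp().
        m = max(logits)
        log_z = m + math.log(sum(math.exp(x - m) for x in logits))
        return [x - log_z for x in logits]

    lps = logits_to_logprobs([2.0, 1.0, 0.5])
    # The resulting probabilities sum to 1.
    assert abs(sum(math.exp(lp) for lp in lps) - 1.0) < 1e-9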
@@ -781,10 +850,10 @@ class Llama:
                             "choices": [
                                 {
                                     "text": last_text[
-                                        : len(last_text) - (offset - end)
+                                        : len(last_text) - (token_end_position - end)
                                     ].decode("utf-8", errors="ignore"),
                                     "index": 0,
-                                    "logprobs": None,
+                                    "logprobs": logprobs_or_none,
                                     "finish_reason": finish_reason,
                                 }
                             ],
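The corrected slice trims the bytes by which the final token overshoots the stop position end; token_end_position counts detokenized bytes up to and including the current token. A worked example with invented values:

    last_text = b" worldly"   # detokenized final token, 8 bytes
    token_end_position = 12   # bytes emitted up to and including this token
    end = 10                  # byte position where the completion must stop

    # Drop the (token_end_position - end) = 2 overshooting bytes.
    trimmed = last_text[: len(last_text) - (token_end_position - end)]
    assert trimmed == b" world"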
@@ -802,7 +871,7 @@ class Llama:
                         "utf-8", errors="ignore"
                     ),
                     "index": 0,
-                    "logprobs": None,
+                    "logprobs": logprobs_or_none,
                     "finish_reason": finish_reason
                     if returned_tokens == len(completion_tokens)
                     else None,
@@ -821,13 +890,19 @@ class Llama:
 
         logprobs_or_none: Optional[CompletionLogprobs] = None
         if logprobs is not None:
-            text_offset = 0
+            text_offset = 0 if echo else len(prompt)
+            token_offset = 0 if echo else len(prompt_tokens[1:])
             text_offsets: List[int] = []
-            token_logprobs: List[float] = []
+            token_logprobs: List[Optional[float]] = []
             tokens: List[str] = []
-            top_logprobs: List[Dict[str, float]] = []
+            top_logprobs: List[Optional[Dict[str, float]]] = []
 
-            all_tokens = prompt_tokens + completion_tokens
+            if echo:
+                # Remove leading BOS token
+                all_tokens = prompt_tokens[1:] + completion_tokens
+            else:
+                all_tokens = completion_tokens
+
             all_token_strs = [
                 self.detokenize([token]).decode("utf-8", errors="ignore")
                 for token in all_tokens
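The echo branch drops the leading BOS token because no logits row ever predicts it, so it can never receive a logprob. A small sketch of the alignment, with made-up token ids:

    bos = 1
    prompt_tokens = [bos, 4521, 318]   # made-up ids for BOS, "Hello", " world"
    completion_tokens = [262, 845]     # made-up ids

    echo = True
    # BOS has no preceding context, so no row of logits ever scores it.
    all_tokens = prompt_tokens[1:] + completion_tokens if echo else completion_tokens
    token_offset = 0 if echo else len(prompt_tokens[1:])
    assert all_tokens == [4521, 318, 262, 845] and token_offset == 0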
@@ -835,7 +910,7 @@ class Llama:
             all_logprobs = [
                 Llama.logits_to_logprobs(list(map(float, row)))
                 for row in self.eval_logits
-            ]
+            ][token_offset:]
             for token, token_str, logprobs_token in zip(
                 all_tokens, all_token_strs, all_logprobs
             ):
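The new [token_offset:] slice keeps logprob rows paired with the tokens they score. Assuming, as this class appears to, that eval_logits[i] is produced after evaluating token i and therefore scores token i + 1, the pairing works out like this (stand-in values):

    prompt_tokens = ["<s>", "Once", "upon"]   # stand-in tokens
    completion_tokens = ["a", "time"]
    # One logits row per evaluated token; row i scores token i + 1.
    eval_rows = ["row0", "row1", "row2", "row3", "row4"]

    echo = False
    all_tokens = completion_tokens
    token_offset = len(prompt_tokens[1:])     # 2

    # row2 was produced after "upon" and scores "a"; row3 scores "time".
    # zip() drops row4, which would score the token after the completion.
    assert list(zip(all_tokens, eval_rows[token_offset:])) == [
        ("a", "row2"),
        ("time", "row3"),
    ]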
@@ -848,14 +923,20 @@ class Llama:
                     )
                 )
                 token_logprobs.append(sorted_logprobs[int(token)][0])
-                top_logprob = {
+                top_logprob: Optional[Dict[str, float]] = {
                     self.detokenize([llama_cpp.llama_token(i)]).decode(
                         "utf-8", errors="ignore"
                     ): logprob
                     for logprob, i in sorted_logprobs[:logprobs]
                 }
-                top_logprob.update({token_str: sorted_logprobs[int(token)][0]})
+                top_logprob.update({token_str: logprobs_token[int(token)]})
                 top_logprobs.append(top_logprob)
+            # Weird idiosyncrasy of the OpenAI API where
+            # token_logprobs and top_logprobs are null for
+            # the first token.
+            if echo and len(all_tokens) > 0:
+                token_logprobs[0] = None
+                top_logprobs[0] = None
             logprobs_or_none = {
                 "tokens": tokens,
                 "text_offset": text_offsets,
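Taken together, the non-streaming changes produce an OpenAI-style logprobs object whose first entries are null when echo is set. A shape sketch with invented numbers:

    logprobs_or_none = {
        "tokens": ["Hello", " world", "!"],
        "text_offset": [0, 5, 11],
        # The first token is never predicted, so its entries are null.
        "token_logprobs": [None, -0.41, -1.73],
        "top_logprobs": [None, {" world": -0.41, " there": -1.9}, {"!": -1.73}],
    }
    assert logprobs_or_none["token_logprobs"][0] is None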