Fix logprobs for completions and implement for streaming logprobs.

parent a634a2453b
commit 17d4271b04

1 changed file with 103 additions and 22 deletions
@@ -710,22 +710,56 @@ class Llama:
                 # We want to avoid yielding any characters from
                 # the generated text if they are part of a stop
                 # sequence.
-                longest = 0
+                first_stop_position = 0
                 for s in stop_sequences:
                     for i in range(len(s), 0, -1):
                         if all_text.endswith(s[:i]):
-                            if i > longest:
-                                longest = i
+                            if i > first_stop_position:
+                                first_stop_position = i
                             break

-                offset = 0
+                token_end_position = 0
                 remaining_tokens = completion_tokens[returned_tokens:]
                 remaining_length = len(self.detokenize(remaining_tokens))
                 for token in remaining_tokens:
-                    offset += len(self.detokenize([token]))
-                    # Check if stop sequence is not in the token
-                    if offset >= (remaining_length - longest - 1):
+                    token_end_position += len(self.detokenize([token]))
+                    # Check if stop sequence is in the token
+                    if token_end_position >= (remaining_length - first_stop_position - 1):
                         break
+                    logprobs_or_none: Optional[CompletionLogprobs] = None
+                    if logprobs is not None:
+                        token_str = self.detokenize([token]).decode(
+                            "utf-8", errors="ignore"
+                        )
+                        text_offset = len(prompt) + len(
+                            self.detokenize(completion_tokens[:returned_tokens])
+                        )
+                        token_offset = len(prompt_tokens) + returned_tokens
+                        logits = self.eval_logits[token_offset - 1]
+                        current_logprobs = Llama.logits_to_logprobs(logits)
+                        sorted_logprobs = list(
+                            sorted(
+                                zip(current_logprobs, range(len(current_logprobs))),
+                                reverse=True,
+                            )
+                        )
+                        top_logprob = {
+                            self.detokenize([llama_cpp.llama_token(i)]).decode(
+                                "utf-8", errors="ignore"
+                            ): logprob
+                            for logprob, i in sorted_logprobs[:logprobs]
+                        }
+                        top_logprob.update({token_str: current_logprobs[int(token)]})
+                        logprobs_or_none = {
+                            "tokens": [
+                                self.detokenize([token]).decode(
+                                    "utf-8", errors="ignore"
+                                )
+                            ],
+                            "text_offset": [text_offset],
+                            "token_logprobs": [sorted_logprobs[int(token)][0]],
+                            "top_logprobs": [top_logprob],
+                        }
                     returned_tokens += 1
                     yield {
                         "id": completion_id,
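Editor's note: the streaming branch above mirrors the non-streaming logprob computation. Below is a minimal standalone sketch of that pattern, assuming Llama.logits_to_logprobs is a plain log-softmax over one row of logits; token ids stand in for the detokenized strings used in the diff:

import math
from typing import Dict, List, Tuple

def logits_to_logprobs(logits: List[float]) -> List[float]:
    # Stand-in for Llama.logits_to_logprobs: numerically stable log-softmax.
    max_logit = max(logits)
    log_sum = max_logit + math.log(sum(math.exp(x - max_logit) for x in logits))
    return [x - log_sum for x in logits]

def top_logprobs_for_token(
    logits: List[float], token_id: int, n: int
) -> Tuple[float, Dict[int, float]]:
    logprobs = logits_to_logprobs(logits)
    # Same pattern as the diff: pair each logprob with its token id and
    # sort descending, so the first n entries are the n most likely tokens.
    sorted_logprobs = sorted(zip(logprobs, range(len(logprobs))), reverse=True)
    top = {i: lp for lp, i in sorted_logprobs[:n]}
    # The sampled token is always reported, even when outside the top n.
    top[token_id] = logprobs[token_id]
    return logprobs[token_id], top

lp, top = top_logprobs_for_token([2.0, 1.0, 0.1, -1.0], token_id=2, n=2)
# top maps ids 0 and 1 (the two most likely tokens) plus the sampled id 2.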
@@ -738,7 +772,7 @@ class Llama:
                                 "utf-8", errors="ignore"
                             ),
                             "index": 0,
-                            "logprobs": None,
+                            "logprobs": logprobs_or_none,
                             "finish_reason": None,
                         }
                     ],
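For reference, a chunk yielded by the streaming path now carries per-token logprobs instead of a hard-coded None. A rough sketch of one chunk's shape; the values and the metadata fields beyond "id" and "object" are illustrative, not taken from this diff:

chunk = {
    "id": "cmpl-xxxxxxxx",
    "object": "text_completion",
    "choices": [
        {
            "text": " world",
            "index": 0,
            # Populated only when the request sets logprobs=N; otherwise None.
            "logprobs": {
                "tokens": [" world"],
                "text_offset": [11],
                "token_logprobs": [-0.42],
                "top_logprobs": [{" world": -0.42, "!": -1.9}],
            },
            "finish_reason": None,
        }
    ],
}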
@@ -766,13 +800,48 @@ class Llama:
                 else:
                     end = len(all_text)

-                offset = 0
+                token_end_position = 0
                 for token in remaining_tokens:
-                    offset += len(self.detokenize([token]))
-                    if offset >= end:
+                    token_end_position += len(self.detokenize([token]))
+
+                    logprobs_or_none: Optional[CompletionLogprobs] = None
+                    if logprobs is not None:
+                        token_str = self.detokenize([token]).decode(
+                            "utf-8", errors="ignore"
+                        )
+                        text_offset = len(prompt) + len(
+                            self.detokenize(completion_tokens[:returned_tokens])
+                        )
+                        token_offset = len(prompt_tokens) + returned_tokens - 1
+                        logits = self.eval_logits[token_offset]
+                        current_logprobs = Llama.logits_to_logprobs(logits)
+                        sorted_logprobs = list(
+                            sorted(
+                                zip(current_logprobs, range(len(current_logprobs))),
+                                reverse=True,
+                            )
+                        )
+                        top_logprob = {
+                            self.detokenize([llama_cpp.llama_token(i)]).decode(
+                                "utf-8", errors="ignore"
+                            ): logprob
+                            for logprob, i in sorted_logprobs[:logprobs]
+                        }
+                        top_logprob.update({token_str: current_logprobs[int(token)]})
+                        logprobs_or_none = {
+                            "tokens": [
+                                self.detokenize([token]).decode("utf-8", errors="ignore")
+                            ],
+                            "text_offset": [text_offset],
+                            "token_logprobs": [sorted_logprobs[int(token)][0]],
+                            "top_logprobs": [top_logprob],
+                        }
+
+                    if token_end_position >= end:
                         last_text = self.detokenize([token])
-                        if offset == end - 1:
+                        if token_end_position == end - 1:
                             break
+                        returned_tokens += 1
                         yield {
                             "id": completion_id,
                             "object": "text_completion",
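The token_end_position/end arithmetic above trims the stop sequence out of the final streamed chunk. A standalone sketch of just that trimming step, with made-up byte strings:

def trim_last_chunk(last_text: bytes, token_end_position: int, end: int) -> bytes:
    # `end` is the byte position where generation must stop; the current
    # token overshoots it by (token_end_position - end) bytes.
    overshoot = token_end_position - end
    return last_text[: len(last_text) - overshoot]

# The final token b"ld###" ends at byte 12, but the stop sequence starts
# at byte 9, so only the first two bytes are yielded:
assert trim_last_chunk(b"ld###", token_end_position=12, end=9) == b"ld"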
@@ -781,10 +850,10 @@ class Llama:
                             "choices": [
                                 {
                                     "text": last_text[
-                                        : len(last_text) - (offset - end)
+                                        : len(last_text) - (token_end_position - end)
                                     ].decode("utf-8", errors="ignore"),
                                     "index": 0,
-                                    "logprobs": None,
+                                    "logprobs": logprobs_or_none,
                                     "finish_reason": finish_reason,
                                 }
                             ],
@@ -802,7 +871,7 @@ class Llama:
                             "utf-8", errors="ignore"
                         ),
                         "index": 0,
-                        "logprobs": None,
+                        "logprobs": logprobs_or_none,
                         "finish_reason": finish_reason
                         if returned_tokens == len(completion_tokens)
                         else None,
@@ -821,13 +890,19 @@ class Llama:

         logprobs_or_none: Optional[CompletionLogprobs] = None
         if logprobs is not None:
-            text_offset = 0
+            text_offset = 0 if echo else len(prompt)
+            token_offset = 0 if echo else len(prompt_tokens[1:])
             text_offsets: List[int] = []
-            token_logprobs: List[float] = []
+            token_logprobs: List[Optional[float]] = []
             tokens: List[str] = []
-            top_logprobs: List[Dict[str, float]] = []
+            top_logprobs: List[Optional[Dict[str, float]]] = []

-            all_tokens = prompt_tokens + completion_tokens
+            if echo:
+                # Remove leading BOS token
+                all_tokens = prompt_tokens[1:] + completion_tokens
+            else:
+                all_tokens = completion_tokens
+
             all_token_strs = [
                 self.detokenize([token]).decode("utf-8", errors="ignore")
                 for token in all_tokens
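Sketch of what the echo fix changes, with hypothetical token ids, assuming the tokenizer prepends a BOS token that itself produces no text:

prompt_tokens = [1, 15043, 3186]   # BOS, "Hello", " world" (hypothetical ids)
completion_tokens = [29991]        # "!" (hypothetical id)

echo = True
token_offset = 0 if echo else len(prompt_tokens[1:])        # 0 when echoing
all_tokens = prompt_tokens[1:] + completion_tokens          # BOS dropped
# With echo=False instead: token_offset == 2 and all_tokens == [29991],
# so logprobs are reported for completion tokens only.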
@@ -835,7 +910,7 @@ class Llama:
             all_logprobs = [
                 Llama.logits_to_logprobs(list(map(float, row)))
                 for row in self.eval_logits
-            ]
+            ][token_offset:]
             for token, token_str, logprobs_token in zip(
                 all_tokens, all_token_strs, all_logprobs
             ):
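A toy consistency check for the [token_offset:] slice, under the assumption (not stated in this diff) that eval_logits[k] holds the distribution over the token at position k + 1 of the evaluated sequence:

prompt_tokens = [1, 15043, 3186]   # hypothetical ids, BOS first
completion_tokens = [29991, 2]

for echo in (True, False):
    token_offset = 0 if echo else len(prompt_tokens[1:])
    all_tokens = (prompt_tokens[1:] + completion_tokens) if echo else completion_tokens
    # Under the assumption above, row token_offset + k predicts all_tokens[k],
    # and both modes consume rows up through the same final index:
    assert token_offset + len(all_tokens) == len(prompt_tokens) + len(completion_tokens) - 1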
@@ -848,14 +923,20 @@ class Llama:
                     )
                 )
                 token_logprobs.append(sorted_logprobs[int(token)][0])
-                top_logprob = {
+                top_logprob: Optional[Dict[str, float]] = {
                     self.detokenize([llama_cpp.llama_token(i)]).decode(
                         "utf-8", errors="ignore"
                     ): logprob
                     for logprob, i in sorted_logprobs[:logprobs]
                 }
-                top_logprob.update({token_str: sorted_logprobs[int(token)][0]})
+                top_logprob.update({token_str: logprobs_token[int(token)]})
                 top_logprobs.append(top_logprob)
+            # Weird idiosyncrasy of the OpenAI API where
+            # token_logprobs and top_logprobs are null for
+            # the first token.
+            if echo and len(all_tokens) > 0:
+                token_logprobs[0] = None
+                top_logprobs[0] = None
             logprobs_or_none = {
                 "tokens": tokens,
                 "text_offset": text_offsets,
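With echo enabled, the returned logprobs object now follows the OpenAI convention of null entries for the first token. An illustrative payload (all strings and values made up):

logprobs_or_none = {
    "tokens": ["Hello", " world", "!"],
    "text_offset": [0, 5, 11],
    "token_logprobs": [None, -0.42, -1.37],   # first entry nulled when echoing
    "top_logprobs": [None, {" world": -0.42}, {"!": -1.37, ".": -2.0}],
}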