Merge pull request #453 from wu-qing-157/main

Fix incorrect token_logprobs (due to indexing after sorting)
This commit is contained in:
Andrei 2023-07-08 02:31:52 -04:00 committed by GitHub
commit b8e0bed295
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@@ -964,7 +964,7 @@ class Llama:
) )
], ],
"text_offset": [text_offset], "text_offset": [text_offset],
"token_logprobs": [sorted_logprobs[int(token)][0]], "token_logprobs": [current_logprobs[int(token)]],
"top_logprobs": [top_logprob], "top_logprobs": [top_logprob],
} }
returned_tokens += 1 returned_tokens += 1
@@ -1039,7 +1039,7 @@ class Llama:
self.detokenize([token]).decode("utf-8", errors="ignore") self.detokenize([token]).decode("utf-8", errors="ignore")
], ],
"text_offset": [text_offset], "text_offset": [text_offset],
"token_logprobs": [sorted_logprobs[int(token)][0]], "token_logprobs": [current_logprobs[int(token)]],
"top_logprobs": [top_logprob], "top_logprobs": [top_logprob],
} }
@@ -1163,7 +1163,7 @@ class Llama:
zip(logprobs_token, range(len(logprobs_token))), reverse=True zip(logprobs_token, range(len(logprobs_token))), reverse=True
) )
) )
token_logprobs.append(sorted_logprobs[int(token)][0]) token_logprobs.append(logprobs_token[int(token)])
top_logprob: Optional[Dict[str, float]] = { top_logprob: Optional[Dict[str, float]] = {
self.detokenize([i]).decode("utf-8", errors="ignore"): logprob self.detokenize([i]).decode("utf-8", errors="ignore"): logprob
for logprob, i in sorted_logprobs[:logprobs] for logprob, i in sorted_logprobs[:logprobs]