Bugfix: Missing logits_to_logprobs

This commit is contained in:
Andrei Betlen 2023-05-04 12:18:40 -04:00
parent d594892fd4
commit 329297fafb

View file

@ -639,7 +639,7 @@ class Llama:
self.detokenize([token]).decode("utf-8", errors="ignore")
for token in all_tokens
]
all_logprobs = [Llama._logits_to_logprobs(row) for row in self.eval_logits]
all_logprobs = [Llama.logits_to_logprobs(list(map(float, row))) for row in self.eval_logits]
for token, token_str, logprobs_token in zip(
all_tokens, all_token_strs, all_logprobs
):
@ -985,7 +985,7 @@ class Llama:
return llama_cpp.llama_token_bos()
@staticmethod
def logits_to_logprobs(logits: List[llama_cpp.c_float]) -> List[llama_cpp.c_float]:
def logits_to_logprobs(logits: List[float]) -> List[float]:
exps = [math.exp(float(x)) for x in logits]
sum_exps = sum(exps)
return [llama_cpp.c_float(math.log(x / sum_exps)) for x in exps]
return [math.log(x / sum_exps) for x in exps]