Fix logprob calculation. Fixes #134
This commit is contained in:
parent
c088a2b3a7
commit
b6747f722e
1 changed files with 5 additions and 3 deletions
|
@ -638,7 +638,7 @@ class Llama:
|
|||
for token in all_tokens
|
||||
]
|
||||
all_logprobs = [
|
||||
[Llama.logit_to_logprob(logit) for logit in row]
|
||||
Llama._logits_to_logprobs(row)
|
||||
for row in self.eval_logits
|
||||
]
|
||||
for token, token_str, logprobs_token in zip(
|
||||
|
@ -980,5 +980,7 @@ class Llama:
|
|||
return llama_cpp.llama_token_bos()
|
||||
|
||||
@staticmethod
|
||||
def logit_to_logprob(x: float) -> float:
|
||||
return math.log(1.0 + math.exp(x))
|
||||
def logits_to_logprobs(logits: List[llama_cpp.c_float]) -> List[llama_cpp.c_float]:
|
||||
exps = [math.exp(float(x)) for x in logits]
|
||||
sum_exps = sum(exps)
|
||||
return [llama_cpp.c_float(math.log(x / sum_exps)) for x in exps]
|
||||
|
|
Loading…
Reference in a new issue