Fix logprob calculation. Fixes #134
This commit is contained in:
parent
c088a2b3a7
commit
b6747f722e
1 changed files with 5 additions and 3 deletions
|
@ -638,7 +638,7 @@ class Llama:
|
||||||
for token in all_tokens
|
for token in all_tokens
|
||||||
]
|
]
|
||||||
all_logprobs = [
|
all_logprobs = [
|
||||||
[Llama.logit_to_logprob(logit) for logit in row]
|
Llama._logits_to_logprobs(row)
|
||||||
for row in self.eval_logits
|
for row in self.eval_logits
|
||||||
]
|
]
|
||||||
for token, token_str, logprobs_token in zip(
|
for token, token_str, logprobs_token in zip(
|
||||||
|
@ -980,5 +980,7 @@ class Llama:
|
||||||
return llama_cpp.llama_token_bos()
|
return llama_cpp.llama_token_bos()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def logit_to_logprob(x: float) -> float:
|
def logits_to_logprobs(logits: List[llama_cpp.c_float]) -> List[llama_cpp.c_float]:
|
||||||
return math.log(1.0 + math.exp(x))
|
exps = [math.exp(float(x)) for x in logits]
|
||||||
|
sum_exps = sum(exps)
|
||||||
|
return [llama_cpp.c_float(math.log(x / sum_exps)) for x in exps]
|
||||||
|
|
Loading…
Add table
Reference in a new issue