Bugfix: Missing logits_to_logprobs
This commit is contained in:
parent
d594892fd4
commit
329297fafb
1 changed files with 3 additions and 3 deletions
|
@ -639,7 +639,7 @@ class Llama:
|
|||
self.detokenize([token]).decode("utf-8", errors="ignore")
|
||||
for token in all_tokens
|
||||
]
|
||||
all_logprobs = [Llama._logits_to_logprobs(row) for row in self.eval_logits]
|
||||
all_logprobs = [Llama.logits_to_logprobs(list(map(float, row))) for row in self.eval_logits]
|
||||
for token, token_str, logprobs_token in zip(
|
||||
all_tokens, all_token_strs, all_logprobs
|
||||
):
|
||||
|
@ -985,7 +985,7 @@ class Llama:
|
|||
return llama_cpp.llama_token_bos()
|
||||
|
||||
@staticmethod
|
||||
def logits_to_logprobs(logits: List[llama_cpp.c_float]) -> List[llama_cpp.c_float]:
|
||||
def logits_to_logprobs(logits: List[float]) -> List[float]:
|
||||
exps = [math.exp(float(x)) for x in logits]
|
||||
sum_exps = sum(exps)
|
||||
return [llama_cpp.c_float(math.log(x / sum_exps)) for x in exps]
|
||||
return [math.log(x / sum_exps) for x in exps]
|
||||
|
|
Loading…
Reference in a new issue