Bugfix: Missing logits_to_logprobs
This commit is contained in:
parent
d594892fd4
commit
329297fafb
1 changed files with 3 additions and 3 deletions
|
@ -639,7 +639,7 @@ class Llama:
|
||||||
self.detokenize([token]).decode("utf-8", errors="ignore")
|
self.detokenize([token]).decode("utf-8", errors="ignore")
|
||||||
for token in all_tokens
|
for token in all_tokens
|
||||||
]
|
]
|
||||||
all_logprobs = [Llama._logits_to_logprobs(row) for row in self.eval_logits]
|
all_logprobs = [Llama.logits_to_logprobs(list(map(float, row))) for row in self.eval_logits]
|
||||||
for token, token_str, logprobs_token in zip(
|
for token, token_str, logprobs_token in zip(
|
||||||
all_tokens, all_token_strs, all_logprobs
|
all_tokens, all_token_strs, all_logprobs
|
||||||
):
|
):
|
||||||
|
@ -985,7 +985,7 @@ class Llama:
|
||||||
return llama_cpp.llama_token_bos()
|
return llama_cpp.llama_token_bos()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def logits_to_logprobs(logits: List[llama_cpp.c_float]) -> List[llama_cpp.c_float]:
|
def logits_to_logprobs(logits: List[float]) -> List[float]:
|
||||||
exps = [math.exp(float(x)) for x in logits]
|
exps = [math.exp(float(x)) for x in logits]
|
||||||
sum_exps = sum(exps)
|
sum_exps = sum(exps)
|
||||||
return [llama_cpp.c_float(math.log(x / sum_exps)) for x in exps]
|
return [math.log(x / sum_exps) for x in exps]
|
||||||
|
|
Loading…
Reference in a new issue