Replace logits_to_logprobs implementation with numpy equivalent to llama.cpp (#991)
See #990. This change makes the `logits_to_logprobs` function equivalent to the version in the llama.cpp repository. It uses NumPy, so it is much faster than the previous version.
This commit is contained in:
parent
ac35f68e4d
commit
ef22e478db
1 changed file with 8 additions and 4 deletions
|
@ -2280,10 +2280,14 @@ class Llama:
|
|||
return self._model.token_nl()
|
||||
|
||||
@staticmethod
|
||||
def logits_to_logprobs(logits: List[float]) -> List[float]:
|
||||
exps = [math.exp(float(x)) for x in logits]
|
||||
sum_exps = sum(exps)
|
||||
return [math.log(x / sum_exps) for x in exps]
|
||||
def logits_to_logprobs(logits: npt.NDArray[np.single]) -> npt.NDArray[np.single]:
    """Convert raw logits to log-probabilities (log-softmax).

    The normalisation runs over every element of ``logits`` (no axis is
    selected), so the exponentials of the returned values sum to 1 across
    the whole array.

    Args:
        logits: Array of unnormalised scores. It is not modified.

    Returns:
        A new single-precision array with the same shape as ``logits``
        holding ``logits - logsumexp(logits)``.
    """
    # Shift by the maximum so the largest exponent is exp(0) == 1; this
    # prevents overflow in np.exp for large logits.
    maximum = np.max(logits)
    shifted = np.subtract(logits, maximum, dtype=np.single)

    # Use the log-sum-exp form: log(exp(s_i) / sum) == s_i - log(sum).
    # Taking the log of the *sum* (which is >= 1 for finite input) instead of
    # each normalised probability avoids log(0) == -inf (and the accompanying
    # divide-by-zero warning) when exp(s_i) underflows to 0 in float32.
    log_normalizer = np.log(np.sum(np.exp(shifted)))
    np.subtract(shifted, log_normalizer, out=shifted)
    return shifted
|
||||
|
||||
@staticmethod
|
||||
def longest_token_prefix(a: Sequence[int], b: Sequence[int]):
|
||||
|
|
Loading…
Add table
Reference in a new issue