Fix logits_to_logprobs for 2-D and 3-D logits (#1002)
* Fix logits_to_logprobs for 2-D and 3-D logits * Set dtype to single * Test size
This commit is contained in:
parent
534b1ea9b5
commit
5a8944672f
3 changed files with 44 additions and 10 deletions
|
@ -771,7 +771,7 @@ class Llama:
|
|||
**kwargs, # type: ignore
|
||||
):
|
||||
"""Load a llama.cpp model from `model_path`.
|
||||
|
||||
|
||||
Examples:
|
||||
Basic usage
|
||||
|
||||
|
@ -2280,14 +2280,22 @@ class Llama:
|
|||
return self._model.token_nl()
|
||||
|
||||
@staticmethod
|
||||
def logits_to_logprobs(logits: npt.NDArray[np.single]) -> npt.NDArray[np.single]:
|
||||
maximum = np.max(logits)
|
||||
tmp = np.subtract(logits, maximum, dtype=np.single)
|
||||
np.exp(tmp, out=tmp)
|
||||
normalizer = 1.0 / np.sum(tmp)
|
||||
np.multiply(normalizer, tmp, out=tmp)
|
||||
np.log(tmp, out=tmp)
|
||||
return tmp
|
||||
def logits_to_logprobs(
|
||||
logits: Union[List, npt.NDArray[np.single]], axis: int = -1
|
||||
) -> npt.NDArray[np.single]:
|
||||
# https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.log_softmax.html
|
||||
logits_maxs: np.ndarray = np.amax(logits, axis=axis, keepdims=True)
|
||||
if logits_maxs.ndim > 0:
|
||||
logits_maxs[~np.isfinite(logits_maxs)] = 0
|
||||
elif not np.isfinite(logits_maxs):
|
||||
logits_maxs = 0
|
||||
subtract_maxs = np.subtract(logits, logits_maxs, dtype=np.single)
|
||||
exp = np.exp(subtract_maxs)
|
||||
# Suppress warnings about log of zero
|
||||
with np.errstate(divide='ignore'):
|
||||
summed = np.sum(exp, axis=axis, keepdims=True)
|
||||
out = np.log(summed)
|
||||
return subtract_maxs - out
|
||||
|
||||
@staticmethod
|
||||
def longest_token_prefix(a: Sequence[int], b: Sequence[int]):
|
||||
|
|
|
@ -33,11 +33,12 @@ server = [
|
|||
"fastapi>=0.100.0",
|
||||
"pydantic-settings>=2.0.1",
|
||||
"sse-starlette>=1.6.1",
|
||||
"starlette-context>=0.3.6,<0.4"
|
||||
"starlette-context>=0.3.6,<0.4",
|
||||
]
|
||||
test = [
|
||||
"pytest>=7.4.0",
|
||||
"httpx>=0.24.1",
|
||||
"scipy>=1.10",
|
||||
]
|
||||
dev = [
|
||||
"black>=23.3.0",
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
import ctypes
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from scipy.special import log_softmax
|
||||
|
||||
import llama_cpp
|
||||
|
||||
|
@ -264,5 +266,28 @@ def test_llama_server():
|
|||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"size_and_axis",
|
||||
[
|
||||
((32_000,), -1), # last token's next-token logits
|
||||
((10, 32_000), -1), # many tokens' next-token logits, or batch of last tokens
|
||||
((4, 10, 32_000), -1), # batch of texts
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("convert_to_list", [True, False])
|
||||
def test_logits_to_logprobs(size_and_axis, convert_to_list: bool, atol: float = 1e-7):
|
||||
size, axis = size_and_axis
|
||||
logits: np.ndarray = -np.random.uniform(low=0, high=60, size=size)
|
||||
logits = logits.astype(np.single)
|
||||
if convert_to_list:
|
||||
# Currently, logits are converted from arrays to lists. This may change soon
|
||||
logits = logits.tolist()
|
||||
log_probs = llama_cpp.Llama.logits_to_logprobs(logits, axis=axis)
|
||||
log_probs_correct = log_softmax(logits, axis=axis)
|
||||
assert log_probs.dtype == np.single
|
||||
assert log_probs.shape == size
|
||||
assert np.allclose(log_probs, log_probs_correct, atol=atol)
|
||||
|
||||
|
||||
def test_llama_cpp_version():
|
||||
assert llama_cpp.__version__
|
||||
|
|
Loading…
Reference in a new issue