Fix logits_to_logprobs for 2-D and 3-D logits (#1002)

* Fix logits_to_logprobs for 2-D and 3-D logits

* Set dtype to single

* Test size
This commit is contained in:
kddubey 2023-12-16 15:59:26 -08:00 committed by GitHub
parent 534b1ea9b5
commit 5a8944672f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 44 additions and 10 deletions

View file

@ -771,7 +771,7 @@ class Llama:
**kwargs, # type: ignore
):
"""Load a llama.cpp model from `model_path`.
Examples:
Basic usage
@ -2280,14 +2280,22 @@ class Llama:
return self._model.token_nl()
@staticmethod
def logits_to_logprobs(
    logits: npt.NDArray[np.single], axis: int = -1
) -> npt.NDArray[np.single]:
    """Convert logits to log-probabilities (log-softmax) along `axis`.

    The reductions are taken along `axis` with `keepdims=True`, so every
    slice along that axis is normalized independently. For 1-D input this
    is the same as normalizing over the whole array.

    Args:
        logits: Array-like of logits.
        axis: Axis along which the probabilities should sum to one.
            Defaults to the last axis.

    Returns:
        A float32 array with the same shape as `logits`.
    """
    # Subtract the per-slice maximum so the largest exponent is exp(0) == 1.
    maximum = np.max(logits, axis=axis, keepdims=True)
    tmp = np.subtract(logits, maximum, dtype=np.single)
    # The remaining steps reuse `tmp` in place to avoid extra allocations.
    np.exp(tmp, out=tmp)
    normalizer = 1.0 / np.sum(tmp, axis=axis, keepdims=True)
    np.multiply(normalizer, tmp, out=tmp)
    # NOTE: entries whose exp() underflowed to 0 come out as -inf here.
    np.log(tmp, out=tmp)
    return tmp
def logits_to_logprobs(
    logits: Union[List, npt.NDArray[np.single]], axis: int = -1
) -> npt.NDArray[np.single]:
    """Convert logits to log-probabilities with a numerically stable log-softmax.

    Follows the approach of `scipy.special.log_softmax`:
    https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.log_softmax.html

    Args:
        logits: Logits as an ndarray or a (nested) list of numbers.
        axis: Axis along which the probabilities should sum to one.
            Defaults to the last axis.

    Returns:
        A float32 array with the same shape as `logits`.
    """
    # Convert once up front so list input is not re-parsed by every NumPy
    # call below and all arithmetic is done in single precision.
    logits_arr = np.asarray(logits, dtype=np.single)
    logits_maxs = np.amax(logits_arr, axis=axis, keepdims=True)
    # A non-finite maximum is replaced by 0 so that subtracting it cannot
    # produce `inf - inf` for slices that are entirely -inf.
    logits_maxs = np.where(np.isfinite(logits_maxs), logits_maxs, np.single(0))
    shifted = logits_arr - logits_maxs
    exp = np.exp(shifted)
    # Suppress warnings about log of zero
    with np.errstate(divide="ignore"):
        log_norm = np.log(np.sum(exp, axis=axis, keepdims=True))
    return shifted - log_norm
@staticmethod
def longest_token_prefix(a: Sequence[int], b: Sequence[int]):

View file

@ -33,11 +33,12 @@ server = [
"fastapi>=0.100.0",
"pydantic-settings>=2.0.1",
"sse-starlette>=1.6.1",
"starlette-context>=0.3.6,<0.4"
"starlette-context>=0.3.6,<0.4",
]
test = [
"pytest>=7.4.0",
"httpx>=0.24.1",
"scipy>=1.10",
]
dev = [
"black>=23.3.0",

View file

@ -1,6 +1,8 @@
import ctypes
import numpy as np
import pytest
from scipy.special import log_softmax
import llama_cpp
@ -264,5 +266,28 @@ def test_llama_server():
}
@pytest.mark.parametrize(
    "size_and_axis",
    [
        ((32_000,), -1),  # last token's next-token logits
        ((10, 32_000), -1),  # many tokens' next-token logits, or batch of last tokens
        ((4, 10, 32_000), -1),  # batch of texts
    ],
)
@pytest.mark.parametrize("convert_to_list", [True, False])
def test_logits_to_logprobs(size_and_axis, convert_to_list: bool, atol: float = 1e-7):
    """Check `Llama.logits_to_logprobs` against `scipy.special.log_softmax`.

    Covers 1-D, 2-D and 3-D inputs, passed either as a float32 array or as
    nested lists, and verifies the dtype, shape and values of the result.
    """
    size, axis = size_and_axis
    # Seeded generator so that a failure can be reproduced exactly.
    rng = np.random.default_rng(seed=0)
    logits: np.ndarray = -rng.uniform(low=0, high=60, size=size)
    logits = logits.astype(np.single)
    if convert_to_list:
        # Currently, logits are converted from arrays to lists. This may change soon
        logits = logits.tolist()
    log_probs = llama_cpp.Llama.logits_to_logprobs(logits, axis=axis)
    log_probs_correct = log_softmax(logits, axis=axis)
    assert log_probs.dtype == np.single
    assert log_probs.shape == size
    assert np.allclose(log_probs, log_probs_correct, atol=atol)
def test_llama_cpp_version():
    """The package must expose a truthy `__version__` attribute."""
    version = llama_cpp.__version__
    assert version