Fix logits_to_logprobs for 2-D and 3-D logits (#1002)

* Fix logits_to_logprobs for 2-D and 3-D logits

* Set dtype to single

* Test size
This commit is contained in:
kddubey 2023-12-16 15:59:26 -08:00 committed by GitHub
parent 534b1ea9b5
commit 5a8944672f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 44 additions and 10 deletions

View file

@ -2280,14 +2280,22 @@ class Llama:
return self._model.token_nl() return self._model.token_nl()
@staticmethod @staticmethod
def logits_to_logprobs(logits: npt.NDArray[np.single]) -> npt.NDArray[np.single]: def logits_to_logprobs(
maximum = np.max(logits) logits: Union[List, npt.NDArray[np.single]], axis: int = -1
tmp = np.subtract(logits, maximum, dtype=np.single) ) -> npt.NDArray[np.single]:
np.exp(tmp, out=tmp) # https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.log_softmax.html
normalizer = 1.0 / np.sum(tmp) logits_maxs: np.ndarray = np.amax(logits, axis=axis, keepdims=True)
np.multiply(normalizer, tmp, out=tmp) if logits_maxs.ndim > 0:
np.log(tmp, out=tmp) logits_maxs[~np.isfinite(logits_maxs)] = 0
return tmp elif not np.isfinite(logits_maxs):
logits_maxs = 0
subtract_maxs = np.subtract(logits, logits_maxs, dtype=np.single)
exp = np.exp(subtract_maxs)
# Suppress warnings about log of zero
with np.errstate(divide='ignore'):
summed = np.sum(exp, axis=axis, keepdims=True)
out = np.log(summed)
return subtract_maxs - out
@staticmethod @staticmethod
def longest_token_prefix(a: Sequence[int], b: Sequence[int]): def longest_token_prefix(a: Sequence[int], b: Sequence[int]):

View file

@ -33,11 +33,12 @@ server = [
"fastapi>=0.100.0", "fastapi>=0.100.0",
"pydantic-settings>=2.0.1", "pydantic-settings>=2.0.1",
"sse-starlette>=1.6.1", "sse-starlette>=1.6.1",
"starlette-context>=0.3.6,<0.4" "starlette-context>=0.3.6,<0.4",
] ]
test = [ test = [
"pytest>=7.4.0", "pytest>=7.4.0",
"httpx>=0.24.1", "httpx>=0.24.1",
"scipy>=1.10",
] ]
dev = [ dev = [
"black>=23.3.0", "black>=23.3.0",

View file

@ -1,6 +1,8 @@
import ctypes import ctypes
import numpy as np
import pytest import pytest
from scipy.special import log_softmax
import llama_cpp import llama_cpp
@ -264,5 +266,28 @@ def test_llama_server():
} }
@pytest.mark.parametrize(
"size_and_axis",
[
((32_000,), -1), # last token's next-token logits
((10, 32_000), -1), # many tokens' next-token logits, or batch of last tokens
((4, 10, 32_000), -1), # batch of texts
],
)
@pytest.mark.parametrize("convert_to_list", [True, False])
def test_logits_to_logprobs(size_and_axis, convert_to_list: bool, atol: float = 1e-7):
size, axis = size_and_axis
logits: np.ndarray = -np.random.uniform(low=0, high=60, size=size)
logits = logits.astype(np.single)
if convert_to_list:
# Currently, logits are converted from arrays to lists. This may change soon
logits = logits.tolist()
log_probs = llama_cpp.Llama.logits_to_logprobs(logits, axis=axis)
log_probs_correct = log_softmax(logits, axis=axis)
assert log_probs.dtype == np.single
assert log_probs.shape == size
assert np.allclose(log_probs, log_probs_correct, atol=atol)
def test_llama_cpp_version(): def test_llama_cpp_version():
assert llama_cpp.__version__ assert llama_cpp.__version__