diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index 292378d..d5cf401 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -771,7 +771,7 @@ class Llama:
         **kwargs,  # type: ignore
     ):
         """Load a llama.cpp model from `model_path`.
-
+
         Examples:
             Basic usage
 
@@ -2280,14 +2280,22 @@ class Llama:
         return self._model.token_nl()
 
     @staticmethod
-    def logits_to_logprobs(logits: npt.NDArray[np.single]) -> npt.NDArray[np.single]:
-        maximum = np.max(logits)
-        tmp = np.subtract(logits, maximum, dtype=np.single)
-        np.exp(tmp, out=tmp)
-        normalizer = 1.0 / np.sum(tmp)
-        np.multiply(normalizer, tmp, out=tmp)
-        np.log(tmp, out=tmp)
-        return tmp
+    def logits_to_logprobs(
+        logits: Union[List, npt.NDArray[np.single]], axis: int = -1
+    ) -> npt.NDArray[np.single]:
+        # https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.log_softmax.html
+        logits_maxs: np.ndarray = np.amax(logits, axis=axis, keepdims=True)
+        if logits_maxs.ndim > 0:
+            logits_maxs[~np.isfinite(logits_maxs)] = 0
+        elif not np.isfinite(logits_maxs):
+            logits_maxs = 0
+        subtract_maxs = np.subtract(logits, logits_maxs, dtype=np.single)
+        exp = np.exp(subtract_maxs)
+        # Suppress warnings about log of zero
+        with np.errstate(divide='ignore'):
+            summed = np.sum(exp, axis=axis, keepdims=True)
+            out = np.log(summed)
+        return subtract_maxs - out
 
     @staticmethod
     def longest_token_prefix(a: Sequence[int], b: Sequence[int]):
diff --git a/pyproject.toml b/pyproject.toml
index 6c10225..b5affaa 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -33,11 +33,12 @@ server = [
     "fastapi>=0.100.0",
     "pydantic-settings>=2.0.1",
     "sse-starlette>=1.6.1",
-    "starlette-context>=0.3.6,<0.4"
+    "starlette-context>=0.3.6,<0.4",
 ]
 test = [
     "pytest>=7.4.0",
     "httpx>=0.24.1",
+    "scipy>=1.10",
 ]
 dev = [
     "black>=23.3.0",
diff --git a/tests/test_llama.py b/tests/test_llama.py
index c98148e..dac33b7 100644
--- a/tests/test_llama.py
+++ b/tests/test_llama.py
@@ -1,6 +1,8 @@
 import ctypes
 
+import numpy as np
 import pytest
+from scipy.special import log_softmax
 
 import llama_cpp
 
@@ -264,5 +266,28 @@ def test_llama_server():
     }
 
 
+@pytest.mark.parametrize(
+    "size_and_axis",
+    [
+        ((32_000,), -1),  # last token's next-token logits
+        ((10, 32_000), -1),  # many tokens' next-token logits, or batch of last tokens
+        ((4, 10, 32_000), -1),  # batch of texts
+    ],
+)
+@pytest.mark.parametrize("convert_to_list", [True, False])
+def test_logits_to_logprobs(size_and_axis, convert_to_list: bool, atol: float = 1e-7):
+    size, axis = size_and_axis
+    logits: np.ndarray = -np.random.uniform(low=0, high=60, size=size)
+    logits = logits.astype(np.single)
+    if convert_to_list:
+        # Currently, logits are converted from arrays to lists. This may change soon
+        logits = logits.tolist()
+    log_probs = llama_cpp.Llama.logits_to_logprobs(logits, axis=axis)
+    log_probs_correct = log_softmax(logits, axis=axis)
+    assert log_probs.dtype == np.single
+    assert log_probs.shape == size
+    assert np.allclose(log_probs, log_probs_correct, atol=atol)
+
+
 def test_llama_cpp_version():
     assert llama_cpp.__version__
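
A minimal usage sketch (not part of the patch; assumes the package is installed with the `test` extras so `scipy` is available). The reworked `Llama.logits_to_logprobs` computes a numerically stable log-softmax, log(softmax(x)) = x - max(x) - log(sum(exp(x - max(x)))), so it should agree with `scipy.special.log_softmax` for any input shape and axis:

    import numpy as np
    from scipy.special import log_softmax

    import llama_cpp

    # Hypothetical batch of next-token logits: 10 positions over a 32,000-token vocabulary.
    logits = np.random.standard_normal((10, 32_000)).astype(np.single)

    # Works on arrays or plain Python lists; axis selects the vocabulary dimension.
    log_probs = llama_cpp.Llama.logits_to_logprobs(logits, axis=-1)
    assert np.allclose(log_probs, log_softmax(logits, axis=-1), atol=1e-7)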