Use numpy arrays for logits_processors and stopping_criteria. Closes #491

This commit is contained in:
Andrei Betlen 2023-07-18 19:27:41 -04:00
parent 5eab1db0d0
commit 19ba9d3845
2 changed files with 24 additions and 16 deletions

View file

@ -27,6 +27,7 @@ from .llama_types import *
import numpy as np import numpy as np
import numpy.typing as npt import numpy.typing as npt
class BaseLlamaCache(ABC): class BaseLlamaCache(ABC):
"""Base cache class for a llama.cpp model.""" """Base cache class for a llama.cpp model."""
@ -179,21 +180,27 @@ class LlamaState:
self.llama_state_size = llama_state_size self.llama_state_size = llama_state_size
LogitsProcessor = Callable[[List[int], List[float]], List[float]] LogitsProcessor = Callable[
[npt.NDArray[np.intc], npt.NDArray[np.single]], npt.NDArray[np.single]
]
class LogitsProcessorList(List[LogitsProcessor]): class LogitsProcessorList(List[LogitsProcessor]):
def __call__(self, input_ids: List[int], scores: List[float]) -> List[float]: def __call__(
self, input_ids: npt.NDArray[np.intc], scores: npt.NDArray[np.single]
) -> npt.NDArray[np.single]:
for processor in self: for processor in self:
scores = processor(input_ids, scores) scores = processor(input_ids, scores)
return scores return scores
StoppingCriteria = Callable[[List[int], List[float]], bool] StoppingCriteria = Callable[[npt.NDArray[np.intc], npt.NDArray[np.single]], bool]
class StoppingCriteriaList(List[StoppingCriteria]): class StoppingCriteriaList(List[StoppingCriteria]):
def __call__(self, input_ids: List[int], logits: List[float]) -> bool: def __call__(
self, input_ids: npt.NDArray[np.intc], logits: npt.NDArray[np.single]
) -> bool:
return any([stopping_criteria(input_ids, logits) for stopping_criteria in self]) return any([stopping_criteria(input_ids, logits) for stopping_criteria in self])
@ -276,7 +283,9 @@ class Llama:
if self.tensor_split is not None: if self.tensor_split is not None:
# Type conversion and expand the list to the length of LLAMA_MAX_DEVICES # Type conversion and expand the list to the length of LLAMA_MAX_DEVICES
FloatArray = ctypes.c_float * llama_cpp.LLAMA_MAX_DEVICES.value FloatArray = ctypes.c_float * llama_cpp.LLAMA_MAX_DEVICES.value
self._c_tensor_split = FloatArray(*tensor_split) # keep a reference to the array so it is not gc'd self._c_tensor_split = FloatArray(
*tensor_split
) # keep a reference to the array so it is not gc'd
self.params.tensor_split = self._c_tensor_split self.params.tensor_split = self._c_tensor_split
self.params.rope_freq_base = rope_freq_base self.params.rope_freq_base = rope_freq_base
@ -503,11 +512,7 @@ class Llama:
logits: npt.NDArray[np.single] = self._scores[-1, :] logits: npt.NDArray[np.single] = self._scores[-1, :]
if logits_processor is not None: if logits_processor is not None:
logits = np.array( logits[:] = logits_processor(self._input_ids, logits)
logits_processor(self._input_ids.tolist(), logits.tolist()),
dtype=np.single,
)
self._scores[-1, :] = logits
nl_logit = logits[self._token_nl] nl_logit = logits[self._token_nl]
candidates = self._candidates candidates = self._candidates
@ -725,7 +730,7 @@ class Llama:
logits_processor=logits_processor, logits_processor=logits_processor,
) )
if stopping_criteria is not None and stopping_criteria( if stopping_criteria is not None and stopping_criteria(
self._input_ids.tolist(), self._scores[-1, :].tolist() self._input_ids, self._scores[-1, :]
): ):
return return
tokens_or_none = yield token tokens_or_none = yield token
@ -1014,7 +1019,7 @@ class Llama:
break break
if stopping_criteria is not None and stopping_criteria( if stopping_criteria is not None and stopping_criteria(
self._input_ids.tolist(), self._scores[-1, :].tolist() self._input_ids, self._scores[-1, :]
): ):
text = self.detokenize(completion_tokens) text = self.detokenize(completion_tokens)
finish_reason = "stop" finish_reason = "stop"

View file

@ -16,6 +16,9 @@ from pydantic import BaseModel, Field
from pydantic_settings import BaseSettings from pydantic_settings import BaseSettings
from sse_starlette.sse import EventSourceResponse from sse_starlette.sse import EventSourceResponse
import numpy as np
import numpy.typing as npt
class Settings(BaseSettings): class Settings(BaseSettings):
model: str = Field( model: str = Field(
@ -336,9 +339,9 @@ def make_logit_bias_processor(
to_bias[input_id] = score to_bias[input_id] = score
def logit_bias_processor( def logit_bias_processor(
input_ids: List[int], input_ids: npt.NDArray[np.intc],
scores: List[float], scores: npt.NDArray[np.single],
) -> List[float]: ) -> npt.NDArray[np.single]:
new_scores = [None] * len(scores) new_scores = [None] * len(scores)
for input_id, score in enumerate(scores): for input_id, score in enumerate(scores):
new_scores[input_id] = score + to_bias.get(input_id, 0.0) new_scores[input_id] = score + to_bias.get(input_id, 0.0)