# Patch for llama_cpp/llama.py (git-format unified diff).
#
# What the hunks below do:
#   1. Add `from __future__ import annotations` as the first line of the module.
#   2. Remove LlamaState, LogitsProcessor, LogitsProcessorList, StoppingCriteria
#      and StoppingCriteriaList from their position just above `class Llama`.
#   3. Re-add the same five definitions, text unchanged, at the end of the file,
#      after LlamaTokenizer.from_ggml_file.
#
# NOTE(review): presumably the __future__ import is what allows annotations
# earlier in the module to keep naming these now-later-defined classes --
# confirm against the full file.
# NOTE(review): line breaks and indentation here were reconstructed from a
# flattened copy of the patch -- verify with `git apply --check` before use.
diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index f4e5dcd..25abf36 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import os
 import sys
 import uuid
@@ -40,46 +42,6 @@ from ._internals import (
 )
 
 
-class LlamaState:
-    def __init__(
-        self,
-        input_ids: npt.NDArray[np.intc],
-        scores: npt.NDArray[np.single],
-        n_tokens: int,
-        llama_state: bytes,
-        llama_state_size: int,
-    ):
-        self.input_ids = input_ids
-        self.scores = scores
-        self.n_tokens = n_tokens
-        self.llama_state = llama_state
-        self.llama_state_size = llama_state_size
-
-
-LogitsProcessor = Callable[
-    [npt.NDArray[np.intc], npt.NDArray[np.single]], npt.NDArray[np.single]
-]
-
-
-class LogitsProcessorList(List[LogitsProcessor]):
-    def __call__(
-        self, input_ids: npt.NDArray[np.intc], scores: npt.NDArray[np.single]
-    ) -> npt.NDArray[np.single]:
-        for processor in self:
-            scores = processor(input_ids, scores)
-        return scores
-
-
-StoppingCriteria = Callable[[npt.NDArray[np.intc], npt.NDArray[np.single]], bool]
-
-
-class StoppingCriteriaList(List[StoppingCriteria]):
-    def __call__(
-        self, input_ids: npt.NDArray[np.intc], logits: npt.NDArray[np.single]
-    ) -> bool:
-        return any([stopping_criteria(input_ids, logits) for stopping_criteria in self])
-
-
 class Llama:
     """High-level Python wrapper for a llama.cpp model."""
 
@@ -1733,3 +1695,43 @@ class LlamaTokenizer:
     @classmethod
     def from_ggml_file(cls, path: str) -> "LlamaTokenizer":
         return cls(Llama(model_path=path, vocab_only=True))
+
+
+class LlamaState:
+    def __init__(
+        self,
+        input_ids: npt.NDArray[np.intc],
+        scores: npt.NDArray[np.single],
+        n_tokens: int,
+        llama_state: bytes,
+        llama_state_size: int,
+    ):
+        self.input_ids = input_ids
+        self.scores = scores
+        self.n_tokens = n_tokens
+        self.llama_state = llama_state
+        self.llama_state_size = llama_state_size
+
+
+LogitsProcessor = Callable[
+    [npt.NDArray[np.intc], npt.NDArray[np.single]], npt.NDArray[np.single]
+]
+
+
+class LogitsProcessorList(List[LogitsProcessor]):
+    def __call__(
+        self, input_ids: npt.NDArray[np.intc], scores: npt.NDArray[np.single]
+    ) -> npt.NDArray[np.single]:
+        for processor in self:
+            scores = processor(input_ids, scores)
+        return scores
+
+
+StoppingCriteria = Callable[[npt.NDArray[np.intc], npt.NDArray[np.single]], bool]
+
+
+class StoppingCriteriaList(List[StoppingCriteria]):
+    def __call__(
+        self, input_ids: npt.NDArray[np.intc], logits: npt.NDArray[np.single]
+    ) -> bool:
+        return any([stopping_criteria(input_ids, logits) for stopping_criteria in self])