Re-order classes in llama.py
This commit is contained in:
parent
cc4630e66f
commit
7b46bb5a78
1 changed files with 42 additions and 40 deletions
|
@ -1,3 +1,5 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import uuid
|
import uuid
|
||||||
|
@ -40,46 +42,6 @@ from ._internals import (
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class LlamaState:
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
input_ids: npt.NDArray[np.intc],
|
|
||||||
scores: npt.NDArray[np.single],
|
|
||||||
n_tokens: int,
|
|
||||||
llama_state: bytes,
|
|
||||||
llama_state_size: int,
|
|
||||||
):
|
|
||||||
self.input_ids = input_ids
|
|
||||||
self.scores = scores
|
|
||||||
self.n_tokens = n_tokens
|
|
||||||
self.llama_state = llama_state
|
|
||||||
self.llama_state_size = llama_state_size
|
|
||||||
|
|
||||||
|
|
||||||
LogitsProcessor = Callable[
|
|
||||||
[npt.NDArray[np.intc], npt.NDArray[np.single]], npt.NDArray[np.single]
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
class LogitsProcessorList(List[LogitsProcessor]):
|
|
||||||
def __call__(
|
|
||||||
self, input_ids: npt.NDArray[np.intc], scores: npt.NDArray[np.single]
|
|
||||||
) -> npt.NDArray[np.single]:
|
|
||||||
for processor in self:
|
|
||||||
scores = processor(input_ids, scores)
|
|
||||||
return scores
|
|
||||||
|
|
||||||
|
|
||||||
StoppingCriteria = Callable[[npt.NDArray[np.intc], npt.NDArray[np.single]], bool]
|
|
||||||
|
|
||||||
|
|
||||||
class StoppingCriteriaList(List[StoppingCriteria]):
|
|
||||||
def __call__(
|
|
||||||
self, input_ids: npt.NDArray[np.intc], logits: npt.NDArray[np.single]
|
|
||||||
) -> bool:
|
|
||||||
return any([stopping_criteria(input_ids, logits) for stopping_criteria in self])
|
|
||||||
|
|
||||||
|
|
||||||
class Llama:
|
class Llama:
|
||||||
"""High-level Python wrapper for a llama.cpp model."""
|
"""High-level Python wrapper for a llama.cpp model."""
|
||||||
|
|
||||||
|
@ -1733,3 +1695,43 @@ class LlamaTokenizer:
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_ggml_file(cls, path: str) -> "LlamaTokenizer":
|
def from_ggml_file(cls, path: str) -> "LlamaTokenizer":
|
||||||
return cls(Llama(model_path=path, vocab_only=True))
|
return cls(Llama(model_path=path, vocab_only=True))
|
||||||
|
|
||||||
|
|
||||||
|
class LlamaState:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
input_ids: npt.NDArray[np.intc],
|
||||||
|
scores: npt.NDArray[np.single],
|
||||||
|
n_tokens: int,
|
||||||
|
llama_state: bytes,
|
||||||
|
llama_state_size: int,
|
||||||
|
):
|
||||||
|
self.input_ids = input_ids
|
||||||
|
self.scores = scores
|
||||||
|
self.n_tokens = n_tokens
|
||||||
|
self.llama_state = llama_state
|
||||||
|
self.llama_state_size = llama_state_size
|
||||||
|
|
||||||
|
|
||||||
|
LogitsProcessor = Callable[
|
||||||
|
[npt.NDArray[np.intc], npt.NDArray[np.single]], npt.NDArray[np.single]
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class LogitsProcessorList(List[LogitsProcessor]):
|
||||||
|
def __call__(
|
||||||
|
self, input_ids: npt.NDArray[np.intc], scores: npt.NDArray[np.single]
|
||||||
|
) -> npt.NDArray[np.single]:
|
||||||
|
for processor in self:
|
||||||
|
scores = processor(input_ids, scores)
|
||||||
|
return scores
|
||||||
|
|
||||||
|
|
||||||
|
StoppingCriteria = Callable[[npt.NDArray[np.intc], npt.NDArray[np.single]], bool]
|
||||||
|
|
||||||
|
|
||||||
|
class StoppingCriteriaList(List[StoppingCriteria]):
|
||||||
|
def __call__(
|
||||||
|
self, input_ids: npt.NDArray[np.intc], logits: npt.NDArray[np.single]
|
||||||
|
) -> bool:
|
||||||
|
return any([stopping_criteria(input_ids, logits) for stopping_criteria in self])
|
||||||
|
|
Loading…
Reference in a new issue