Added types to logit processor list and stop criteria list

Maximilian-Winter 2023-05-25 09:07:16 +02:00
parent c05fcdf42f
commit da463e6c8c

@@ -4,7 +4,7 @@ import uuid
 import time
 import math
 import multiprocessing
-from typing import List, Optional, Union, Generator, Sequence, Iterator, Deque, Tuple
+from typing import List, Optional, Union, Generator, Sequence, Iterator, Deque, Tuple, Callable
 from collections import deque, OrderedDict
 from . import llama_cpp
@@ -316,12 +316,11 @@ class Llama:
         mirostat_tau: llama_cpp.c_float,
         mirostat_eta: llama_cpp.c_float,
         penalize_nl: bool = True,
-        logits_processors=None
+        logits_processors: List[Callable[[List[llama_cpp.c_int], List[llama_cpp.c_float]], List[float]]] = None
     ):
         assert self.ctx is not None
         assert len(self.eval_logits) > 0
-        if logits_processors == None:
+        if logits_processors is None:
             logits_processors = []
         n_vocab = self.n_vocab()
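
For context, a logits processor under the new annotation is just a callable that takes the current token ids and the raw logits and returns a full list of adjusted logits. A minimal sketch, not part of this commit: the helper name and banned ids are illustrative, and plain Python ints/floats are assumed where the annotation uses ctypes values.

```python
from typing import Callable, List


def make_ban_tokens_processor(
    banned_ids: List[int],
) -> Callable[[List[int], List[float]], List[float]]:
    """Return a processor that pushes the given token ids to -inf."""

    def processor(input_ids: List[int], logits: List[float]) -> List[float]:
        new_logits = list(logits)  # copy so the original logits stay untouched
        for token_id in banned_ids:
            if 0 <= token_id < len(new_logits):
                new_logits[token_id] = float("-inf")  # never sample these tokens
        return new_logits

    return processor
```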
@@ -445,7 +444,7 @@ class Llama:
         mirostat_eta: float = 0.1,
         mirostat_tau: float = 5.0,
         penalize_nl: bool = True,
-        logits_processors=None
+        logits_processors: List[Callable[[List[llama_cpp.c_int], List[llama_cpp.c_float]], List[float]]] = None
     ):
         """Sample a token from the model.
@@ -497,7 +496,7 @@ class Llama:
         mirostat_mode: int = 0,
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
-        logits_processors=None
+        logits_processors: List[Callable[[List[llama_cpp.c_int], List[llama_cpp.c_float]], List[float]]] = None
     ) -> Generator[int, Optional[Sequence[int]], None]:
         """Create a generator of tokens from a prompt.
@@ -652,12 +651,12 @@ class Llama:
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
         model: Optional[str] = None,
-        logits_processors=None,
-        stopping_criterias=None
+        logits_processors: List[Callable[[List[llama_cpp.c_int], List[llama_cpp.c_float]], List[float]]] = None,
+        stopping_criterias: List[Callable[[List[int], List[llama_cpp.c_float]], bool]] = None,
     ) -> Union[Iterator[Completion], Iterator[CompletionChunk]]:
         assert self.ctx is not None
-        if stopping_criterias == None:
+        if stopping_criterias is None:
             stopping_criterias = []
         completion_id: str = f"cmpl-{str(uuid.uuid4())}"
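
Similarly, a stopping criterion under the new `Callable[[List[int], List[llama_cpp.c_float]], bool]` annotation receives the token ids produced so far plus the current logits and returns True to stop generation. A hedged sketch, not taken from the commit; the helper name and length cap are made up for illustration.

```python
from typing import Callable, List


def make_max_tokens_criteria(max_tokens: int) -> Callable[[List[int], List[float]], bool]:
    """Return a criterion that stops once the token list reaches max_tokens."""

    def criteria(input_ids: List[int], logits: List[float]) -> bool:
        return len(input_ids) >= max_tokens  # True means: stop generating

    return criteria
```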
@@ -1036,8 +1035,8 @@ class Llama:
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
         model: Optional[str] = None,
-        logits_processors=None,
-        stopping_criterias=None
+        logits_processors: List[Callable[[List[llama_cpp.c_int], List[llama_cpp.c_float]], List[float]]] = None,
+        stopping_criterias: List[Callable[[List[int], List[llama_cpp.c_float]], bool]] = None
     ) -> Union[Completion, Iterator[CompletionChunk]]:
         """Generate text from a prompt.