Added types to logit processor list and stop criteria list

This commit is contained in:
Maximilian-Winter 2023-05-25 09:07:16 +02:00
parent c05fcdf42f
commit da463e6c8c

View file

@ -4,7 +4,7 @@ import uuid
import time import time
import math import math
import multiprocessing import multiprocessing
from typing import List, Optional, Union, Generator, Sequence, Iterator, Deque, Tuple from typing import List, Optional, Union, Generator, Sequence, Iterator, Deque, Tuple, Callable
from collections import deque, OrderedDict from collections import deque, OrderedDict
from . import llama_cpp from . import llama_cpp
@ -316,12 +316,11 @@ class Llama:
mirostat_tau: llama_cpp.c_float, mirostat_tau: llama_cpp.c_float,
mirostat_eta: llama_cpp.c_float, mirostat_eta: llama_cpp.c_float,
penalize_nl: bool = True, penalize_nl: bool = True,
logits_processors=None logits_processors: List[Callable[[List[llama_cpp.c_int], List[llama_cpp.c_float]], List[float]]] = None
): ):
assert self.ctx is not None assert self.ctx is not None
assert len(self.eval_logits) > 0 assert len(self.eval_logits) > 0
if logits_processors is None:
if logits_processors == None:
logits_processors = [] logits_processors = []
n_vocab = self.n_vocab() n_vocab = self.n_vocab()
@ -445,7 +444,7 @@ class Llama:
mirostat_eta: float = 0.1, mirostat_eta: float = 0.1,
mirostat_tau: float = 5.0, mirostat_tau: float = 5.0,
penalize_nl: bool = True, penalize_nl: bool = True,
logits_processors=None logits_processors: List[Callable[[List[llama_cpp.c_int], List[llama_cpp.c_float]], List[float]]] = None
): ):
"""Sample a token from the model. """Sample a token from the model.
@ -497,7 +496,7 @@ class Llama:
mirostat_mode: int = 0, mirostat_mode: int = 0,
mirostat_tau: float = 5.0, mirostat_tau: float = 5.0,
mirostat_eta: float = 0.1, mirostat_eta: float = 0.1,
logits_processors=None logits_processors: List[Callable[[List[llama_cpp.c_int], List[llama_cpp.c_float]], List[float]]] = None
) -> Generator[int, Optional[Sequence[int]], None]: ) -> Generator[int, Optional[Sequence[int]], None]:
"""Create a generator of tokens from a prompt. """Create a generator of tokens from a prompt.
@ -652,12 +651,12 @@ class Llama:
mirostat_tau: float = 5.0, mirostat_tau: float = 5.0,
mirostat_eta: float = 0.1, mirostat_eta: float = 0.1,
model: Optional[str] = None, model: Optional[str] = None,
logits_processors=None, logits_processors: List[Callable[[List[llama_cpp.c_int], List[llama_cpp.c_float]], List[float]]] = None,
stopping_criterias=None stopping_criterias: List[Callable[[List[int], List[llama_cpp.c_float]], bool]] = None,
) -> Union[Iterator[Completion], Iterator[CompletionChunk]]: ) -> Union[Iterator[Completion], Iterator[CompletionChunk]]:
assert self.ctx is not None assert self.ctx is not None
if stopping_criterias == None: if stopping_criterias is None:
stopping_criterias = [] stopping_criterias = []
completion_id: str = f"cmpl-{str(uuid.uuid4())}" completion_id: str = f"cmpl-{str(uuid.uuid4())}"
@ -1036,8 +1035,8 @@ class Llama:
mirostat_tau: float = 5.0, mirostat_tau: float = 5.0,
mirostat_eta: float = 0.1, mirostat_eta: float = 0.1,
model: Optional[str] = None, model: Optional[str] = None,
logits_processors=None, logits_processors: List[Callable[[List[llama_cpp.c_int], List[llama_cpp.c_float]], List[float]]] = None,
stopping_criterias=None stopping_criterias: List[Callable[[List[int], List[llama_cpp.c_float]], bool]] = None
) -> Union[Completion, Iterator[CompletionChunk]]: ) -> Union[Completion, Iterator[CompletionChunk]]:
"""Generate text from a prompt. """Generate text from a prompt.