Added types to logit processor list and stop criteria list
This commit is contained in:
parent
c05fcdf42f
commit
da463e6c8c
1 changed files with 10 additions and 11 deletions
|
@ -4,7 +4,7 @@ import uuid
|
||||||
import time
|
import time
|
||||||
import math
|
import math
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
from typing import List, Optional, Union, Generator, Sequence, Iterator, Deque, Tuple
|
from typing import List, Optional, Union, Generator, Sequence, Iterator, Deque, Tuple, Callable
|
||||||
from collections import deque, OrderedDict
|
from collections import deque, OrderedDict
|
||||||
|
|
||||||
from . import llama_cpp
|
from . import llama_cpp
|
||||||
|
@ -316,12 +316,11 @@ class Llama:
|
||||||
mirostat_tau: llama_cpp.c_float,
|
mirostat_tau: llama_cpp.c_float,
|
||||||
mirostat_eta: llama_cpp.c_float,
|
mirostat_eta: llama_cpp.c_float,
|
||||||
penalize_nl: bool = True,
|
penalize_nl: bool = True,
|
||||||
logits_processors=None
|
logits_processors: List[Callable[[List[llama_cpp.c_int], List[llama_cpp.c_float]], List[float]]] = None
|
||||||
):
|
):
|
||||||
assert self.ctx is not None
|
assert self.ctx is not None
|
||||||
assert len(self.eval_logits) > 0
|
assert len(self.eval_logits) > 0
|
||||||
|
if logits_processors is None:
|
||||||
if logits_processors == None:
|
|
||||||
logits_processors = []
|
logits_processors = []
|
||||||
|
|
||||||
n_vocab = self.n_vocab()
|
n_vocab = self.n_vocab()
|
||||||
|
@ -445,7 +444,7 @@ class Llama:
|
||||||
mirostat_eta: float = 0.1,
|
mirostat_eta: float = 0.1,
|
||||||
mirostat_tau: float = 5.0,
|
mirostat_tau: float = 5.0,
|
||||||
penalize_nl: bool = True,
|
penalize_nl: bool = True,
|
||||||
logits_processors=None
|
logits_processors: List[Callable[[List[llama_cpp.c_int], List[llama_cpp.c_float]], List[float]]] = None
|
||||||
|
|
||||||
):
|
):
|
||||||
"""Sample a token from the model.
|
"""Sample a token from the model.
|
||||||
|
@ -497,7 +496,7 @@ class Llama:
|
||||||
mirostat_mode: int = 0,
|
mirostat_mode: int = 0,
|
||||||
mirostat_tau: float = 5.0,
|
mirostat_tau: float = 5.0,
|
||||||
mirostat_eta: float = 0.1,
|
mirostat_eta: float = 0.1,
|
||||||
logits_processors=None
|
logits_processors: List[Callable[[List[llama_cpp.c_int], List[llama_cpp.c_float]], List[float]]] = None
|
||||||
) -> Generator[int, Optional[Sequence[int]], None]:
|
) -> Generator[int, Optional[Sequence[int]], None]:
|
||||||
"""Create a generator of tokens from a prompt.
|
"""Create a generator of tokens from a prompt.
|
||||||
|
|
||||||
|
@ -652,12 +651,12 @@ class Llama:
|
||||||
mirostat_tau: float = 5.0,
|
mirostat_tau: float = 5.0,
|
||||||
mirostat_eta: float = 0.1,
|
mirostat_eta: float = 0.1,
|
||||||
model: Optional[str] = None,
|
model: Optional[str] = None,
|
||||||
logits_processors=None,
|
logits_processors: List[Callable[[List[llama_cpp.c_int], List[llama_cpp.c_float]], List[float]]] = None,
|
||||||
stopping_criterias=None
|
stopping_criterias: List[Callable[[List[int], List[llama_cpp.c_float]], bool]] = None,
|
||||||
) -> Union[Iterator[Completion], Iterator[CompletionChunk]]:
|
) -> Union[Iterator[Completion], Iterator[CompletionChunk]]:
|
||||||
assert self.ctx is not None
|
assert self.ctx is not None
|
||||||
|
|
||||||
if stopping_criterias == None:
|
if stopping_criterias is None:
|
||||||
stopping_criterias = []
|
stopping_criterias = []
|
||||||
|
|
||||||
completion_id: str = f"cmpl-{str(uuid.uuid4())}"
|
completion_id: str = f"cmpl-{str(uuid.uuid4())}"
|
||||||
|
@ -1036,8 +1035,8 @@ class Llama:
|
||||||
mirostat_tau: float = 5.0,
|
mirostat_tau: float = 5.0,
|
||||||
mirostat_eta: float = 0.1,
|
mirostat_eta: float = 0.1,
|
||||||
model: Optional[str] = None,
|
model: Optional[str] = None,
|
||||||
logits_processors=None,
|
logits_processors: List[Callable[[List[llama_cpp.c_int], List[llama_cpp.c_float]], List[float]]] = None,
|
||||||
stopping_criterias=None
|
stopping_criterias: List[Callable[[List[int], List[llama_cpp.c_float]], bool]] = None
|
||||||
) -> Union[Completion, Iterator[CompletionChunk]]:
|
) -> Union[Completion, Iterator[CompletionChunk]]:
|
||||||
"""Generate text from a prompt.
|
"""Generate text from a prompt.
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue