From da463e6c8c3c09c7a32bf25d924974d74f3d2776 Mon Sep 17 00:00:00 2001
From: Maximilian-Winter
Date: Thu, 25 May 2023 09:07:16 +0200
Subject: [PATCH] Added types to logit processor list and stop criteria list

---
 llama_cpp/llama.py | 21 ++++++++++-----------
 1 file changed, 10 insertions(+), 11 deletions(-)

diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index c6f540c..8176136 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -4,7 +4,7 @@ import uuid
 import time
 import math
 import multiprocessing
-from typing import List, Optional, Union, Generator, Sequence, Iterator, Deque, Tuple
+from typing import List, Optional, Union, Generator, Sequence, Iterator, Deque, Tuple, Callable
 from collections import deque, OrderedDict
 
 from . import llama_cpp
@@ -316,12 +316,11 @@ class Llama:
         mirostat_tau: llama_cpp.c_float,
         mirostat_eta: llama_cpp.c_float,
         penalize_nl: bool = True,
-        logits_processors=None
+        logits_processors: List[Callable[[List[llama_cpp.c_int], List[llama_cpp.c_float]], List[float]]] = None
     ):
         assert self.ctx is not None
         assert len(self.eval_logits) > 0
-
-        if logits_processors == None:
+        if logits_processors is None:
             logits_processors = []
 
         n_vocab = self.n_vocab()
@@ -445,7 +444,7 @@ class Llama:
         mirostat_eta: float = 0.1,
         mirostat_tau: float = 5.0,
         penalize_nl: bool = True,
-        logits_processors=None
+        logits_processors: List[Callable[[List[llama_cpp.c_int], List[llama_cpp.c_float]], List[float]]] = None
     ):
         """Sample a token from the model.
 
@@ -497,7 +496,7 @@ class Llama:
         mirostat_mode: int = 0,
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
-        logits_processors=None
+        logits_processors: List[Callable[[List[llama_cpp.c_int], List[llama_cpp.c_float]], List[float]]] = None
     ) -> Generator[int, Optional[Sequence[int]], None]:
         """Create a generator of tokens from a prompt.
 
@@ -652,12 +651,12 @@ class Llama:
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
         model: Optional[str] = None,
-        logits_processors=None,
-        stopping_criterias=None
+        logits_processors: List[Callable[[List[llama_cpp.c_int], List[llama_cpp.c_float]], List[float]]] = None,
+        stopping_criterias: List[Callable[[List[int], List[llama_cpp.c_float]], bool]] = None,
     ) -> Union[Iterator[Completion], Iterator[CompletionChunk]]:
         assert self.ctx is not None
 
-        if stopping_criterias == None:
+        if stopping_criterias is None:
             stopping_criterias = []
 
         completion_id: str = f"cmpl-{str(uuid.uuid4())}"
@@ -1036,8 +1035,8 @@ class Llama:
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
         model: Optional[str] = None,
-        logits_processors=None,
-        stopping_criterias=None
+        logits_processors: List[Callable[[List[llama_cpp.c_int], List[llama_cpp.c_float]], List[float]]] = None,
+        stopping_criterias: List[Callable[[List[int], List[llama_cpp.c_float]], bool]] = None
    ) -> Union[Completion, Iterator[CompletionChunk]]:
         """Generate text from a prompt.
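
Below is a minimal usage sketch for the callback signatures this patch annotates. It is illustrative only: the annotations above use the ctypes aliases llama_cpp.c_int and llama_cpp.c_float, but the sketch assumes plain Python ints and floats at the call site, and SUPPRESS_ID and MAX_TOKENS are hypothetical values chosen for the example.

    from typing import List

    # Hypothetical values, for illustration only.
    SUPPRESS_ID = 2    # a token id to suppress (must be < n_vocab)
    MAX_TOKENS = 256   # stop once this many tokens have been generated

    def suppress_token(input_ids: List[int], logits: List[float]) -> List[float]:
        # Logits processor: set one token's logit to -inf so it can never be sampled.
        logits[SUPPRESS_ID] = float("-inf")
        return logits

    def length_cap(tokens: List[int], logits: List[float]) -> bool:
        # Stopping criterion: returning True halts generation.
        return len(tokens) >= MAX_TOKENS

    # Usage sketch, assuming the completion entry point whose signature the
    # last hunk annotates is create_completion:
    # llm = Llama(model_path="./model.bin")
    # llm.create_completion(
    #     "Q: Name the planets in the solar system. A:",
    #     logits_processors=[suppress_token],
    #     stopping_criterias=[length_cap],
    # )

One note on the annotations themselves: because these parameters default to None, the strictly correct spelling would be Optional[List[...]] = None; the implicit-Optional form used here works at runtime but is flagged by type checkers such as mypy.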