From 66fb0345e8ade77d8eb3ae8103d1954e9a88d7ef Mon Sep 17 00:00:00 2001
From: Andrei Betlen
Date: Tue, 8 Aug 2023 15:08:54 -0400
Subject: [PATCH] Move grammar to function call argument

---
 llama_cpp/llama.py | 33 +++++++++++++++++++--------------
 1 file changed, 19 insertions(+), 14 deletions(-)

diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index 56143c9..a996d5c 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -227,7 +227,6 @@ class Llama:
         tensor_split: Optional[List[float]] = None,
         rope_freq_base: float = 10000.0,
         rope_freq_scale: float = 1.0,
-        grammar: Optional[Union[str, Path]] = None,
         n_gqa: Optional[int] = None,  # (TEMPORARY) must be 8 for llama2 70b
         rms_norm_eps: Optional[float] = None,  # (TEMPORARY)
         mul_mat_q: Optional[bool] = None,  # (TEMPORARY)
@@ -254,7 +253,6 @@ class Llama:
             tensor_split: List of floats to split the model across multiple GPUs. If None, the model is not split.
             rope_freq_base: Base frequency for rope sampling.
             rope_freq_scale: Scale factor for rope sampling.
-            grammar: Path to a BNF grammar file to use for grammar based sampling.
             verbose: Print verbose output to stderr.
 
         Raises:
@@ -383,12 +381,6 @@ class Llama:
         self.scores: npt.NDArray[np.single] = np.ndarray(
             (n_ctx, self._n_vocab), dtype=np.single
         )
-        if grammar is not None:
-            self.grammar = LlamaGrammar.from_file(
-                grammar, verbose=verbose
-            )  # type: Optional[LlamaGrammar]
-        else:
-            self.grammar = None
 
     @property
     def _input_ids(self) -> npt.NDArray[np.intc]:
@@ -527,6 +519,7 @@ class Llama:
         mirostat_eta: llama_cpp.c_float,
         penalize_nl: bool = True,
         logits_processor: Optional[LogitsProcessorList] = None,
+        grammar: Optional[LlamaGrammar] = None,
     ):
         assert self.ctx is not None
         assert self.n_tokens > 0
@@ -574,11 +567,11 @@ class Llama:
         if not penalize_nl:
             candidates.data[self._token_nl].logit = llama_cpp.c_float(nl_logit)
 
-        if self.grammar is not None:
+        if grammar is not None:
             llama_cpp.llama_sample_grammar(
                 ctx=self.ctx,
                 candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
-                grammar=self.grammar.grammar,
+                grammar=grammar.grammar,
             )
 
         if temp.value == 0.0:
@@ -650,10 +643,10 @@ class Llama:
                 ctx=self.ctx,
                 candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
             )
-            if self.grammar is not None:
+            if grammar is not None:
                 llama_cpp.llama_grammar_accept_token(
                     ctx=self.ctx,
-                    grammar=self.grammar.grammar,
+                    grammar=grammar.grammar,
                     token=llama_cpp.ctypes.c_int(id),
                 )
             return id
@@ -672,6 +665,7 @@ class Llama:
         mirostat_tau: float = 5.0,
         penalize_nl: bool = True,
         logits_processor: Optional[LogitsProcessorList] = None,
+        grammar: Optional[LlamaGrammar] = None,
     ):
         """Sample a token from the model.
 
@@ -705,6 +699,7 @@ class Llama:
             mirostat_eta=llama_cpp.c_float(mirostat_eta),
             penalize_nl=penalize_nl,
             logits_processor=logits_processor,
+            grammar=grammar,
         )
 
     def generate(
@@ -723,6 +718,7 @@ class Llama:
         mirostat_eta: float = 0.1,
         logits_processor: Optional[LogitsProcessorList] = None,
         stopping_criteria: Optional[StoppingCriteriaList] = None,
+        grammar: Optional[LlamaGrammar] = None,
     ) -> Generator[int, Optional[Sequence[int]], None]:
         """Create a generator of tokens from a prompt.
 
@@ -761,8 +757,8 @@ class Llama:
         if reset:
             self.reset()
 
-        if self.grammar is not None:
-            self.grammar.reset()
+        if grammar is not None:
+            grammar.reset()
 
         while True:
             self.eval(tokens)
@@ -778,6 +774,7 @@ class Llama:
                 mirostat_tau=mirostat_tau,
                 mirostat_eta=mirostat_eta,
                 logits_processor=logits_processor,
+                grammar=grammar,
             )
             if stopping_criteria is not None and stopping_criteria(
                 self._input_ids.tolist(), self._scores[-1, :].tolist()
@@ -880,6 +877,7 @@ class Llama:
         model: Optional[str] = None,
         stopping_criteria: Optional[StoppingCriteriaList] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
+        grammar: Optional[LlamaGrammar] = None,
     ) -> Union[Iterator[Completion], Iterator[CompletionChunk]]:
         assert self.ctx is not None
 
@@ -957,6 +955,7 @@ class Llama:
             repeat_penalty=repeat_penalty,
             stopping_criteria=stopping_criteria,
             logits_processor=logits_processor,
+            grammar=grammar,
         ):
             if token == self._token_eos:
                 text = self.detokenize(completion_tokens)
@@ -1301,6 +1300,7 @@ class Llama:
         model: Optional[str] = None,
         stopping_criteria: Optional[StoppingCriteriaList] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
+        grammar: Optional[LlamaGrammar] = None,
     ) -> Union[Completion, Iterator[CompletionChunk]]:
         """Generate text from a prompt.
 
@@ -1345,6 +1345,7 @@ class Llama:
             model=model,
             stopping_criteria=stopping_criteria,
             logits_processor=logits_processor,
+            grammar=grammar
         )
         if stream:
             chunks: Iterator[CompletionChunk] = completion_or_chunks
@@ -1374,6 +1375,7 @@ class Llama:
         model: Optional[str] = None,
         stopping_criteria: Optional[StoppingCriteriaList] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
+        grammar: Optional[LlamaGrammar] = None,
     ) -> Union[Completion, Iterator[CompletionChunk]]:
         """Generate text from a prompt.
 
@@ -1418,6 +1420,7 @@ class Llama:
             model=model,
             stopping_criteria=stopping_criteria,
             logits_processor=logits_processor,
+            grammar=grammar,
         )
 
     def _convert_text_completion_to_chat(
@@ -1498,6 +1501,7 @@ class Llama:
         mirostat_eta: float = 0.1,
         model: Optional[str] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
+        grammar: Optional[LlamaGrammar] = None,
     ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
         """Generate a chat completion from a list of messages.
 
@@ -1540,6 +1544,7 @@ class Llama:
             mirostat_eta=mirostat_eta,
             model=model,
             logits_processor=logits_processor,
+            grammar=grammar,
         )
         if stream:
             chunks: Iterator[CompletionChunk] = completion_or_chunks  # type: ignore
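
Usage sketch (not part of the patch): after this change a LlamaGrammar object is built by the caller and passed per request instead of being loaded once in Llama.__init__. The grammar string, model path, and prompt below are illustrative placeholders; Llama, LlamaGrammar.from_string, and create_completion are the library entry points this patch touches, and LlamaGrammar.from_file (seen in the removed __init__ code) can be used the same way.

    from llama_cpp import Llama, LlamaGrammar

    # Toy GBNF grammar that constrains the completion to "yes" or "no" (placeholder).
    grammar = LlamaGrammar.from_string('root ::= "yes" | "no"\n')

    # Placeholder model path; no grammar argument at construction time anymore.
    llm = Llama(model_path="./models/7B/ggml-model.bin")

    # The grammar is now supplied per call, so different requests can use
    # different grammars (or none) against the same Llama instance.
    out = llm.create_completion(
        prompt="Is water wet? Answer: ",
        max_tokens=8,
        grammar=grammar,
    )
    print(out["choices"][0]["text"])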