Move grammar to function call argument

Andrei Betlen 2023-08-08 15:08:54 -04:00
parent 1e844d3238
commit 66fb0345e8


@@ -227,7 +227,6 @@ class Llama:
        tensor_split: Optional[List[float]] = None,
        rope_freq_base: float = 10000.0,
        rope_freq_scale: float = 1.0,
-        grammar: Optional[Union[str, Path]] = None,
        n_gqa: Optional[int] = None,  # (TEMPORARY) must be 8 for llama2 70b
        rms_norm_eps: Optional[float] = None,  # (TEMPORARY)
        mul_mat_q: Optional[bool] = None,  # (TEMPORARY)
@@ -254,7 +253,6 @@ class Llama:
            tensor_split: List of floats to split the model across multiple GPUs. If None, the model is not split.
            rope_freq_base: Base frequency for rope sampling.
            rope_freq_scale: Scale factor for rope sampling.
-            grammar: Path to a BNF grammar file to use for grammar based sampling.
            verbose: Print verbose output to stderr.

        Raises:
@@ -383,12 +381,6 @@ class Llama:
        self.scores: npt.NDArray[np.single] = np.ndarray(
            (n_ctx, self._n_vocab), dtype=np.single
        )
-        if grammar is not None:
-            self.grammar = LlamaGrammar.from_file(
-                grammar, verbose=verbose
-            )  # type: Optional[LlamaGrammar]
-        else:
-            self.grammar = None

    @property
    def _input_ids(self) -> npt.NDArray[np.intc]:
@@ -527,6 +519,7 @@ class Llama:
        mirostat_eta: llama_cpp.c_float,
        penalize_nl: bool = True,
        logits_processor: Optional[LogitsProcessorList] = None,
+        grammar: Optional[LlamaGrammar] = None,
    ):
        assert self.ctx is not None
        assert self.n_tokens > 0
@@ -574,11 +567,11 @@ class Llama:
        if not penalize_nl:
            candidates.data[self._token_nl].logit = llama_cpp.c_float(nl_logit)

-        if self.grammar is not None:
+        if grammar is not None:
            llama_cpp.llama_sample_grammar(
                ctx=self.ctx,
                candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
-                grammar=self.grammar.grammar,
+                grammar=grammar.grammar,
            )

        if temp.value == 0.0:
@@ -650,10 +643,10 @@ class Llama:
                ctx=self.ctx,
                candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
            )
-        if self.grammar is not None:
+        if grammar is not None:
            llama_cpp.llama_grammar_accept_token(
                ctx=self.ctx,
-                grammar=self.grammar.grammar,
+                grammar=grammar.grammar,
                token=llama_cpp.ctypes.c_int(id),
            )
        return id
@@ -672,6 +665,7 @@ class Llama:
        mirostat_tau: float = 5.0,
        penalize_nl: bool = True,
        logits_processor: Optional[LogitsProcessorList] = None,
+        grammar: Optional[LlamaGrammar] = None,
    ):
        """Sample a token from the model.
@@ -705,6 +699,7 @@ class Llama:
            mirostat_eta=llama_cpp.c_float(mirostat_eta),
            penalize_nl=penalize_nl,
            logits_processor=logits_processor,
+            grammar=grammar,
        )

    def generate(
@@ -723,6 +718,7 @@ class Llama:
        mirostat_eta: float = 0.1,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
+        grammar: Optional[LlamaGrammar] = None,
    ) -> Generator[int, Optional[Sequence[int]], None]:
        """Create a generator of tokens from a prompt.
@@ -761,8 +757,8 @@ class Llama:
        if reset:
            self.reset()

-        if self.grammar is not None:
-            self.grammar.reset()
+        if grammar is not None:
+            grammar.reset()

        while True:
            self.eval(tokens)
@@ -778,6 +774,7 @@ class Llama:
                mirostat_tau=mirostat_tau,
                mirostat_eta=mirostat_eta,
                logits_processor=logits_processor,
+                grammar=grammar,
            )
            if stopping_criteria is not None and stopping_criteria(
                self._input_ids.tolist(), self._scores[-1, :].tolist()
@@ -880,6 +877,7 @@ class Llama:
        model: Optional[str] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
+        grammar: Optional[LlamaGrammar] = None,
    ) -> Union[Iterator[Completion], Iterator[CompletionChunk]]:
        assert self.ctx is not None
@@ -957,6 +955,7 @@ class Llama:
            repeat_penalty=repeat_penalty,
            stopping_criteria=stopping_criteria,
            logits_processor=logits_processor,
+            grammar=grammar,
        ):
            if token == self._token_eos:
                text = self.detokenize(completion_tokens)
@@ -1301,6 +1300,7 @@ class Llama:
        model: Optional[str] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
+        grammar: Optional[LlamaGrammar] = None,
    ) -> Union[Completion, Iterator[CompletionChunk]]:
        """Generate text from a prompt.
@@ -1345,6 +1345,7 @@ class Llama:
            model=model,
            stopping_criteria=stopping_criteria,
            logits_processor=logits_processor,
+            grammar=grammar
        )
        if stream:
            chunks: Iterator[CompletionChunk] = completion_or_chunks
@@ -1374,6 +1375,7 @@ class Llama:
        model: Optional[str] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
+        grammar: Optional[LlamaGrammar] = None,
    ) -> Union[Completion, Iterator[CompletionChunk]]:
        """Generate text from a prompt.
@@ -1418,6 +1420,7 @@ class Llama:
            model=model,
            stopping_criteria=stopping_criteria,
            logits_processor=logits_processor,
+            grammar=grammar,
        )

    def _convert_text_completion_to_chat(
@@ -1498,6 +1501,7 @@ class Llama:
        mirostat_eta: float = 0.1,
        model: Optional[str] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
+        grammar: Optional[LlamaGrammar] = None,
    ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
        """Generate a chat completion from a list of messages.
@@ -1540,6 +1544,7 @@ class Llama:
            mirostat_eta=mirostat_eta,
            model=model,
            logits_processor=logits_processor,
+            grammar=grammar,
        )
        if stream:
            chunks: Iterator[CompletionChunk] = completion_or_chunks  # type: ignore
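
After this change, grammar-constrained sampling is configured per call rather than at construction: the Llama constructor no longer accepts a grammar path, and the sampling and completion methods take an optional LlamaGrammar object instead. Below is a minimal usage sketch, not an authoritative example: the model path, grammar file name, prompt, and the top-level import path are placeholders/assumptions; LlamaGrammar.from_file is the same loader the old constructor used, and create_completion is assumed as the public entry point whose signature this diff extends.

from llama_cpp import Llama, LlamaGrammar

# The grammar is no longer a constructor argument.
llm = Llama(model_path="./models/7B/ggml-model.bin")

# Load a GBNF grammar (placeholder file name) and pass it per call instead.
grammar = LlamaGrammar.from_file("json.gbnf")

output = llm.create_completion(
    "Three prime numbers as a JSON list: ",
    max_tokens=64,
    grammar=grammar,  # forwarded down to generate()/sample() per this diff
)
print(output["choices"][0]["text"])

Passing the grammar per call means different requests against the same loaded model can use different grammars (or none), which is why the cached self.grammar attribute and its reset in generate() are replaced by the argument here.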