Move grammar to function call argument
parent 1e844d3238
commit 66fb0345e8
1 changed file with 19 additions and 14 deletions
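For context, a minimal usage sketch of the API after this change: the grammar is no longer loaded inside Llama.__init__; the caller builds a LlamaGrammar once and passes it to each call. The model path, grammar file, prompt, and the import location of LlamaGrammar below are assumptions for illustration, not taken from this commit.

# Sketch only: paths and prompt are placeholders.
from llama_cpp import Llama
from llama_cpp.llama_grammar import LlamaGrammar  # import path assumed

llm = Llama(model_path="./model.bin")  # placeholder model path

# Before this commit, a grammar file path was passed to Llama(..., grammar=...).
# Now the caller loads the grammar and hands the object to each call instead.
grammar = LlamaGrammar.from_file("./json.gbnf", verbose=False)  # placeholder grammar file

output = llm.create_completion(
    "Describe a cat as a JSON object.",
    max_tokens=128,
    grammar=grammar,  # per-call grammar argument added by this commit
)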
@@ -227,7 +227,6 @@ class Llama:
         tensor_split: Optional[List[float]] = None,
         rope_freq_base: float = 10000.0,
         rope_freq_scale: float = 1.0,
-        grammar: Optional[Union[str, Path]] = None,
         n_gqa: Optional[int] = None,  # (TEMPORARY) must be 8 for llama2 70b
         rms_norm_eps: Optional[float] = None,  # (TEMPORARY)
         mul_mat_q: Optional[bool] = None,  # (TEMPORARY)
@@ -254,7 +253,6 @@ class Llama:
             tensor_split: List of floats to split the model across multiple GPUs. If None, the model is not split.
             rope_freq_base: Base frequency for rope sampling.
             rope_freq_scale: Scale factor for rope sampling.
-            grammar: Path to a BNF grammar file to use for grammar based sampling.
             verbose: Print verbose output to stderr.

         Raises:
@@ -383,12 +381,6 @@ class Llama:
         self.scores: npt.NDArray[np.single] = np.ndarray(
             (n_ctx, self._n_vocab), dtype=np.single
         )
-        if grammar is not None:
-            self.grammar = LlamaGrammar.from_file(
-                grammar, verbose=verbose
-            )  # type: Optional[LlamaGrammar]
-        else:
-            self.grammar = None

     @property
     def _input_ids(self) -> npt.NDArray[np.intc]:
@@ -527,6 +519,7 @@ class Llama:
         mirostat_eta: llama_cpp.c_float,
         penalize_nl: bool = True,
         logits_processor: Optional[LogitsProcessorList] = None,
+        grammar: Optional[LlamaGrammar] = None,
     ):
         assert self.ctx is not None
         assert self.n_tokens > 0
@@ -574,11 +567,11 @@ class Llama:
         if not penalize_nl:
             candidates.data[self._token_nl].logit = llama_cpp.c_float(nl_logit)

-        if self.grammar is not None:
+        if grammar is not None:
             llama_cpp.llama_sample_grammar(
                 ctx=self.ctx,
                 candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
-                grammar=self.grammar.grammar,
+                grammar=grammar.grammar,
             )

         if temp.value == 0.0:
@@ -650,10 +643,10 @@ class Llama:
                 ctx=self.ctx,
                 candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
             )
-            if self.grammar is not None:
+            if grammar is not None:
                 llama_cpp.llama_grammar_accept_token(
                     ctx=self.ctx,
-                    grammar=self.grammar.grammar,
+                    grammar=grammar.grammar,
                     token=llama_cpp.ctypes.c_int(id),
                 )
             return id
@@ -672,6 +665,7 @@ class Llama:
         mirostat_tau: float = 5.0,
         penalize_nl: bool = True,
         logits_processor: Optional[LogitsProcessorList] = None,
+        grammar: Optional[LlamaGrammar] = None,
     ):
         """Sample a token from the model.

@@ -705,6 +699,7 @@ class Llama:
             mirostat_eta=llama_cpp.c_float(mirostat_eta),
             penalize_nl=penalize_nl,
             logits_processor=logits_processor,
+            grammar=grammar,
         )

     def generate(
@@ -723,6 +718,7 @@ class Llama:
         mirostat_eta: float = 0.1,
         logits_processor: Optional[LogitsProcessorList] = None,
         stopping_criteria: Optional[StoppingCriteriaList] = None,
+        grammar: Optional[LlamaGrammar] = None,
     ) -> Generator[int, Optional[Sequence[int]], None]:
         """Create a generator of tokens from a prompt.

@@ -761,8 +757,8 @@ class Llama:
         if reset:
             self.reset()

-        if self.grammar is not None:
-            self.grammar.reset()
+        if grammar is not None:
+            grammar.reset()

         while True:
             self.eval(tokens)
@@ -778,6 +774,7 @@ class Llama:
                 mirostat_tau=mirostat_tau,
                 mirostat_eta=mirostat_eta,
                 logits_processor=logits_processor,
+                grammar=grammar,
             )
             if stopping_criteria is not None and stopping_criteria(
                 self._input_ids.tolist(), self._scores[-1, :].tolist()
@@ -880,6 +877,7 @@ class Llama:
         model: Optional[str] = None,
         stopping_criteria: Optional[StoppingCriteriaList] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
+        grammar: Optional[LlamaGrammar] = None,
     ) -> Union[Iterator[Completion], Iterator[CompletionChunk]]:
         assert self.ctx is not None

@@ -957,6 +955,7 @@ class Llama:
             repeat_penalty=repeat_penalty,
             stopping_criteria=stopping_criteria,
             logits_processor=logits_processor,
+            grammar=grammar,
         ):
             if token == self._token_eos:
                 text = self.detokenize(completion_tokens)
@@ -1301,6 +1300,7 @@ class Llama:
         model: Optional[str] = None,
         stopping_criteria: Optional[StoppingCriteriaList] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
+        grammar: Optional[LlamaGrammar] = None,
     ) -> Union[Completion, Iterator[CompletionChunk]]:
         """Generate text from a prompt.

@@ -1345,6 +1345,7 @@ class Llama:
             model=model,
             stopping_criteria=stopping_criteria,
             logits_processor=logits_processor,
+            grammar=grammar
         )
         if stream:
             chunks: Iterator[CompletionChunk] = completion_or_chunks
@@ -1374,6 +1375,7 @@ class Llama:
         model: Optional[str] = None,
         stopping_criteria: Optional[StoppingCriteriaList] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
+        grammar: Optional[LlamaGrammar] = None,
     ) -> Union[Completion, Iterator[CompletionChunk]]:
         """Generate text from a prompt.

@@ -1418,6 +1420,7 @@ class Llama:
             model=model,
             stopping_criteria=stopping_criteria,
             logits_processor=logits_processor,
+            grammar=grammar,
         )

     def _convert_text_completion_to_chat(
@@ -1498,6 +1501,7 @@ class Llama:
         mirostat_eta: float = 0.1,
         model: Optional[str] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
+        grammar: Optional[LlamaGrammar] = None,
     ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
         """Generate a chat completion from a list of messages.

@@ -1540,6 +1544,7 @@ class Llama:
             mirostat_eta=mirostat_eta,
             model=model,
             logits_processor=logits_processor,
+            grammar=grammar,
         )
         if stream:
             chunks: Iterator[CompletionChunk] = completion_or_chunks  # type: ignore