Add extra logits_processor and stopping_criteria

parent 30bf8ec557
commit 433a2e3e8a

1 changed file with 17 additions and 0 deletions
@@ -677,6 +677,8 @@ class Llama:
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
         model: Optional[str] = None,
+        stopping_criteria: Optional[StoppingCriteriaList] = None,
+        logits_processor: Optional[LogitsProcessorList] = None,
     ) -> Union[Iterator[Completion], Iterator[CompletionChunk]]:
         assert self.ctx is not None

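The two new parameters are typed StoppingCriteriaList and LogitsProcessorList, but those type definitions are not part of this diff. Judging from the call site added further down, stopping_criteria(list(self.eval_tokens), self.eval_logits[-1]), each entry appears to be a plain callable that receives the token ids so far plus the most recent logits. A minimal sketch under that assumption (the alias names and the example processor are illustrative, not the library's definitions):

from typing import Callable, List

# Assumed callable shapes, inferred from the call sites in this diff; the
# diff does not show how llama_cpp actually defines these types.
LogitsProcessor = Callable[[List[int], List[float]], List[float]]
StoppingCriteria = Callable[[List[int], List[float]], bool]

def no_repeat_last_token(input_ids: List[int], scores: List[float]) -> List[float]:
    """Toy logits processor: forbid sampling the previous token again."""
    if input_ids:
        scores = list(scores)
        scores[input_ids[-1]] = float("-inf")
    return scores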
@@ -739,6 +741,8 @@ class Llama:
             frequency_penalty=frequency_penalty,
             presence_penalty=presence_penalty,
             repeat_penalty=repeat_penalty,
+            stopping_criteria=stopping_criteria,
+            logits_processor=logits_processor,
         ):
             if token == self._token_eos:
                 text = self.detokenize(completion_tokens)
@@ -848,6 +852,11 @@ class Llama:
                 finish_reason = "length"
                 break
+
+        if stopping_criteria is not None and stopping_criteria(
+            list(self.eval_tokens), self.eval_logits[-1]
+        ):
+            finish_reason = "stop"

         if self.verbose:
             llama_cpp.llama_print_timings(self.ctx)

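The criterion is forwarded into the token-generation call in the second hunk (presumably self.generate()) and re-checked here after the sampling loop, so the completion can report finish_reason = "stop" rather than "length". In both places the call shape is stopping_criteria(list(self.eval_tokens), self.eval_logits[-1]). A toy criterion with that signature (the token budget is illustrative, not from the commit):

from typing import List

MAX_EVAL_TOKENS = 256  # illustrative budget, not part of this commit

def stop_on_budget(input_ids: List[int], logits: List[float]) -> bool:
    """Toy stopping criterion matching the call above.

    Receives list(self.eval_tokens) and self.eval_logits[-1]; returning
    True makes the completion finish with finish_reason = "stop".
    """
    return len(input_ids) >= MAX_EVAL_TOKENS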
@@ -1049,6 +1058,8 @@ class Llama:
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
         model: Optional[str] = None,
+        stopping_criteria: Optional[StoppingCriteriaList] = None,
+        logits_processor: Optional[LogitsProcessorList] = None,
     ) -> Union[Completion, Iterator[CompletionChunk]]:
         """Generate text from a prompt.

@@ -1091,6 +1102,8 @@ class Llama:
             mirostat_tau=mirostat_tau,
             mirostat_eta=mirostat_eta,
             model=model,
+            stopping_criteria=stopping_criteria,
+            logits_processor=logits_processor,
         )
         if stream:
             chunks: Iterator[CompletionChunk] = completion_or_chunks
@@ -1118,6 +1131,8 @@ class Llama:
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
         model: Optional[str] = None,
+        stopping_criteria: Optional[StoppingCriteriaList] = None,
+        logits_processor: Optional[LogitsProcessorList] = None,
     ) -> Union[Completion, Iterator[CompletionChunk]]:
         """Generate text from a prompt.

@@ -1160,6 +1175,8 @@ class Llama:
             mirostat_tau=mirostat_tau,
             mirostat_eta=mirostat_eta,
             model=model,
+            stopping_criteria=stopping_criteria,
+            logits_processor=logits_processor,
         )

     def _convert_text_completion_to_chat(
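End to end, the new keywords thread from the public completion methods down into _create_completion(). A usage sketch, assuming StoppingCriteriaList and LogitsProcessorList wrap plain callables and are importable from the package (neither definition is shown in this diff), with ./model.bin as a placeholder path:

from llama_cpp import Llama, LogitsProcessorList, StoppingCriteriaList

def suppress_token_0(input_ids, scores):
    """Toy processor: never sample token id 0."""
    scores = list(scores)
    scores[0] = float("-inf")
    return scores

def stop_after_200(input_ids, logits):
    """Toy criterion: halt once 200 tokens have been evaluated."""
    return len(input_ids) >= 200

llm = Llama(model_path="./model.bin")  # placeholder model path

output = llm.create_completion(
    "Q: Name the planets in the solar system. A: ",
    max_tokens=64,
    logits_processor=LogitsProcessorList([suppress_token_0]),
    stopping_criteria=StoppingCriteriaList([stop_after_200]),
)
print(output["choices"][0]["text"])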