Merge branch 'main' of https://github.com/abetlen/llama-cpp-python into main
This commit is contained in:
commit
e811a81066
3 changed files with 49 additions and 5 deletions
|
@ -410,8 +410,8 @@ class Llama:
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
print(f"Model metadata: {self.metadata}", file=sys.stderr)
|
print(f"Model metadata: {self.metadata}", file=sys.stderr)
|
||||||
|
|
||||||
eos_token_id = int(self.metadata.get("tokenizer.ggml.eos_token_id", self.token_eos()))
|
eos_token_id = self.token_eos()
|
||||||
bos_token_id = int(self.metadata.get("tokenizer.ggml.bos_token_id", self.token_bos()))
|
bos_token_id = self.token_bos()
|
||||||
|
|
||||||
eos_token = self._model.token_get_text(eos_token_id)
|
eos_token = self._model.token_get_text(eos_token_id)
|
||||||
bos_token = self._model.token_get_text(bos_token_id)
|
bos_token = self._model.token_get_text(bos_token_id)
|
||||||
|
@ -961,9 +961,9 @@ class Llama:
|
||||||
|
|
||||||
completion_id: str = f"cmpl-{str(uuid.uuid4())}"
|
completion_id: str = f"cmpl-{str(uuid.uuid4())}"
|
||||||
created: int = int(time.time())
|
created: int = int(time.time())
|
||||||
prefix_token_id: int = int(self.metadata.get("tokenizer.ggml.prefix_token_id", self._model.token_prefix()))
|
prefix_token_id: int = self._model.token_prefix()
|
||||||
middle_token_id: int = int(self.metadata.get("tokenizer.ggml.middle_token_id", self._model.token_middle()))
|
middle_token_id: int = self._model.token_middle()
|
||||||
suffix_token_id: int = int(self.metadata.get("tokenizer.ggml.suffix_token_id", self._model.token_suffix()))
|
suffix_token_id: int = self._model.token_suffix()
|
||||||
# If prompt is empty, initialize completion with BOS token to avoid
|
# If prompt is empty, initialize completion with BOS token to avoid
|
||||||
# detokenization including a space at the beginning of the completion
|
# detokenization including a space at the beginning of the completion
|
||||||
completion_tokens: List[int] = [] if len(prompt) > 0 else [self.token_bos()]
|
completion_tokens: List[int] = [] if len(prompt) > 0 else [self.token_bos()]
|
||||||
|
@ -2084,3 +2084,19 @@ class StoppingCriteriaList(List[StoppingCriteria]):
|
||||||
self, input_ids: npt.NDArray[np.intc], logits: npt.NDArray[np.single]
|
self, input_ids: npt.NDArray[np.intc], logits: npt.NDArray[np.single]
|
||||||
) -> bool:
|
) -> bool:
|
||||||
return any([stopping_criteria(input_ids, logits) for stopping_criteria in self])
|
return any([stopping_criteria(input_ids, logits) for stopping_criteria in self])
|
||||||
|
|
||||||
|
|
||||||
|
class MinTokensLogitsProcessor(LogitsProcessor):
|
||||||
|
def __init__(self, min_tokens: int, token_eos: int):
|
||||||
|
self.min_tokens = min_tokens
|
||||||
|
self.token_eos = token_eos
|
||||||
|
self.prompt_tokens = None
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self, input_ids: npt.NDArray[np.intc], scores: npt.NDArray[np.single]
|
||||||
|
) -> npt.NDArray[np.single]:
|
||||||
|
if self.prompt_tokens is None:
|
||||||
|
self.prompt_tokens = len(input_ids)
|
||||||
|
if len(input_ids) - self.prompt_tokens < self.min_tokens:
|
||||||
|
scores[self.token_eos] = -np.inf
|
||||||
|
return scores
|
||||||
|
|
|
@ -275,6 +275,7 @@ async def create_completion(
|
||||||
"best_of",
|
"best_of",
|
||||||
"logit_bias_type",
|
"logit_bias_type",
|
||||||
"user",
|
"user",
|
||||||
|
"min_tokens",
|
||||||
}
|
}
|
||||||
kwargs = body.model_dump(exclude=exclude)
|
kwargs = body.model_dump(exclude=exclude)
|
||||||
|
|
||||||
|
@ -288,6 +289,15 @@ async def create_completion(
|
||||||
if body.grammar is not None:
|
if body.grammar is not None:
|
||||||
kwargs["grammar"] = llama_cpp.LlamaGrammar.from_string(body.grammar)
|
kwargs["grammar"] = llama_cpp.LlamaGrammar.from_string(body.grammar)
|
||||||
|
|
||||||
|
if body.min_tokens > 0:
|
||||||
|
_min_tokens_logits_processor = llama_cpp.LogitsProcessorList(
|
||||||
|
[llama_cpp.MinTokensLogitsProcessor(body.min_tokens, llama.token_eos())]
|
||||||
|
)
|
||||||
|
if "logits_processor" not in kwargs:
|
||||||
|
kwargs["logits_processor"] = _min_tokens_logits_processor
|
||||||
|
else:
|
||||||
|
kwargs["logits_processor"].extend(_min_tokens_logits_processor)
|
||||||
|
|
||||||
iterator_or_completion: Union[
|
iterator_or_completion: Union[
|
||||||
llama_cpp.CreateCompletionResponse,
|
llama_cpp.CreateCompletionResponse,
|
||||||
Iterator[llama_cpp.CreateCompletionStreamResponse],
|
Iterator[llama_cpp.CreateCompletionStreamResponse],
|
||||||
|
@ -445,6 +455,7 @@ async def create_chat_completion(
|
||||||
"n",
|
"n",
|
||||||
"logit_bias_type",
|
"logit_bias_type",
|
||||||
"user",
|
"user",
|
||||||
|
"min_tokens",
|
||||||
}
|
}
|
||||||
kwargs = body.model_dump(exclude=exclude)
|
kwargs = body.model_dump(exclude=exclude)
|
||||||
llama = llama_proxy(body.model)
|
llama = llama_proxy(body.model)
|
||||||
|
@ -458,6 +469,15 @@ async def create_chat_completion(
|
||||||
if body.grammar is not None:
|
if body.grammar is not None:
|
||||||
kwargs["grammar"] = llama_cpp.LlamaGrammar.from_string(body.grammar)
|
kwargs["grammar"] = llama_cpp.LlamaGrammar.from_string(body.grammar)
|
||||||
|
|
||||||
|
if body.min_tokens > 0:
|
||||||
|
_min_tokens_logits_processor = llama_cpp.LogitsProcessorList(
|
||||||
|
[llama_cpp.MinTokensLogitsProcessor(body.min_tokens, llama.token_eos())]
|
||||||
|
)
|
||||||
|
if "logits_processor" not in kwargs:
|
||||||
|
kwargs["logits_processor"] = _min_tokens_logits_processor
|
||||||
|
else:
|
||||||
|
kwargs["logits_processor"].extend(_min_tokens_logits_processor)
|
||||||
|
|
||||||
iterator_or_completion: Union[
|
iterator_or_completion: Union[
|
||||||
llama_cpp.ChatCompletion, Iterator[llama_cpp.ChatCompletionChunk]
|
llama_cpp.ChatCompletion, Iterator[llama_cpp.ChatCompletionChunk]
|
||||||
] = await run_in_threadpool(llama.create_chat_completion, **kwargs)
|
] = await run_in_threadpool(llama.create_chat_completion, **kwargs)
|
||||||
|
|
|
@ -16,6 +16,12 @@ max_tokens_field = Field(
|
||||||
default=16, ge=1, description="The maximum number of tokens to generate."
|
default=16, ge=1, description="The maximum number of tokens to generate."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
min_tokens_field = Field(
|
||||||
|
default=0,
|
||||||
|
ge=0,
|
||||||
|
description="The minimum number of tokens to generate. It may return fewer tokens if another condition is met (e.g. max_tokens, stop).",
|
||||||
|
)
|
||||||
|
|
||||||
temperature_field = Field(
|
temperature_field = Field(
|
||||||
default=0.8,
|
default=0.8,
|
||||||
description="Adjust the randomness of the generated text.\n\n"
|
description="Adjust the randomness of the generated text.\n\n"
|
||||||
|
@ -111,6 +117,7 @@ class CreateCompletionRequest(BaseModel):
|
||||||
max_tokens: Optional[int] = Field(
|
max_tokens: Optional[int] = Field(
|
||||||
default=16, ge=0, description="The maximum number of tokens to generate."
|
default=16, ge=0, description="The maximum number of tokens to generate."
|
||||||
)
|
)
|
||||||
|
min_tokens: int = min_tokens_field
|
||||||
temperature: float = temperature_field
|
temperature: float = temperature_field
|
||||||
top_p: float = top_p_field
|
top_p: float = top_p_field
|
||||||
min_p: float = min_p_field
|
min_p: float = min_p_field
|
||||||
|
@ -206,6 +213,7 @@ class CreateChatCompletionRequest(BaseModel):
|
||||||
default=None,
|
default=None,
|
||||||
description="The maximum number of tokens to generate. Defaults to inf",
|
description="The maximum number of tokens to generate. Defaults to inf",
|
||||||
)
|
)
|
||||||
|
min_tokens: int = min_tokens_field
|
||||||
logprobs: Optional[bool] = Field(
|
logprobs: Optional[bool] = Field(
|
||||||
default=False,
|
default=False,
|
||||||
description="Whether to output the logprobs or not. Default is True"
|
description="Whether to output the logprobs or not. Default is True"
|
||||||
|
|
Loading…
Reference in a new issue