From e7962d2c733cbbeec5a37392c81f64185a9a39e8 Mon Sep 17 00:00:00 2001
From: Andrei Betlen
Date: Fri, 10 Nov 2023 02:49:27 -0500
Subject: [PATCH] Fix: default max_tokens matches openai api (16 for completion, max length for chat completion)

---
 llama_cpp/llama.py      | 12 ++++++------
 llama_cpp/server/app.py |  4 +++-
 2 files changed, 9 insertions(+), 7 deletions(-)

diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index 2703970..1e78221 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -1296,7 +1296,7 @@ class Llama:
         self,
         prompt: Union[str, List[int]],
         suffix: Optional[str] = None,
-        max_tokens: int = 16,
+        max_tokens: Optional[int] = 16,
         temperature: float = 0.8,
         top_p: float = 0.95,
         logprobs: Optional[int] = None,
@@ -1350,7 +1350,7 @@
                 f"Requested tokens ({len(prompt_tokens)}) exceed context window of {llama_cpp.llama_n_ctx(self.ctx)}"
             )
 
-        if max_tokens <= 0:
+        if max_tokens is None or max_tokens <= 0:
             # Unlimited, depending on n_ctx.
             max_tokens = self._n_ctx - len(prompt_tokens)
 
@@ -1762,7 +1762,7 @@ class Llama:
         self,
         prompt: Union[str, List[int]],
         suffix: Optional[str] = None,
-        max_tokens: int = 128,
+        max_tokens: Optional[int] = 16,
         temperature: float = 0.8,
         top_p: float = 0.95,
         logprobs: Optional[int] = None,
@@ -1788,7 +1788,7 @@ class Llama:
         Args:
             prompt: The prompt to generate text from.
             suffix: A suffix to append to the generated text. If None, no suffix is appended.
-            max_tokens: The maximum number of tokens to generate. If max_tokens <= 0, the maximum number of tokens to generate is unlimited and depends on n_ctx.
+            max_tokens: The maximum number of tokens to generate. If max_tokens <= 0 or None, the maximum number of tokens to generate is unlimited and depends on n_ctx.
             temperature: The temperature to use for sampling.
             top_p: The top-p value to use for sampling.
             logprobs: The number of logprobs to return. If None, no logprobs are returned.
@@ -1921,7 +1921,7 @@ class Llama:
         stop: Optional[Union[str, List[str]]] = [],
         seed: Optional[int] = None,
         response_format: Optional[ChatCompletionRequestResponseFormat] = None,
-        max_tokens: int = 256,
+        max_tokens: Optional[int] = None,
         presence_penalty: float = 0.0,
         frequency_penalty: float = 0.0,
         repeat_penalty: float = 1.1,
@@ -1944,7 +1944,7 @@ class Llama:
             top_k: The top-k value to use for sampling.
             stream: Whether to stream the results.
             stop: A list of strings to stop generation when encountered.
-            max_tokens: The maximum number of tokens to generate. If max_tokens <= 0, the maximum number of tokens to generate is unlimited and depends on n_ctx.
+            max_tokens: The maximum number of tokens to generate. If max_tokens <= 0 or None, the maximum number of tokens to generate is unlimited and depends on n_ctx.
             repeat_penalty: The penalty to apply to repeated tokens.
 
         Returns:
diff --git a/llama_cpp/server/app.py b/llama_cpp/server/app.py
index 3baeee8..2a6aed8 100644
--- a/llama_cpp/server/app.py
+++ b/llama_cpp/server/app.py
@@ -783,7 +783,9 @@ class CreateChatCompletionRequest(BaseModel):
         default=None,
         description="A tool to apply to the generated completions.",
     )  # TODO: verify
-    max_tokens: int = max_tokens_field
+    max_tokens: Optional[int] = Field(
+        default=None, description="The maximum number of tokens to generate. Defaults to inf"
+    )
     temperature: float = temperature_field
     top_p: float = top_p_field
    stop: Optional[List[str]] = stop_field
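
A minimal usage sketch of the new defaults against the high-level Python API follows. The model path is a placeholder, not part of the patch, and the exact generations depend on the loaded model and sampler settings.

from llama_cpp import Llama

# Placeholder model path; substitute any local GGUF model.
llm = Llama(model_path="./models/model.gguf")

# Text completion: omitting max_tokens now generates at most 16 tokens,
# matching the OpenAI completions API default.
out = llm("Q: Name the planets in the solar system. A:")
print(out["choices"][0]["text"])

# Passing None (or a value <= 0) lifts the cap; generation is then bounded
# only by the remaining context window (n_ctx minus the prompt length).
out = llm("Q: Name the planets in the solar system. A:", max_tokens=None)

# Chat completion: max_tokens now defaults to None, i.e. limited only by the
# context length, matching the OpenAI chat completions API.
chat = llm.create_chat_completion(
    messages=[{"role": "user", "content": "Hello!"}]
)
print(chat["choices"][0]["message"]["content"])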