From 4b01a873efdef1c0f52dad87f11b9340ec8adc13 Mon Sep 17 00:00:00 2001
From: swg
Date: Fri, 22 Dec 2023 14:05:13 -0500
Subject: [PATCH] server: Support none defaulting to infinity for completions
 (#111)

* Support defaulting to infinity or -1 for chat completions

* Check if completion_tokens is none in error handler.

* fix: max_tokens in create completion should match openai spec

* Fix __call__

---------

Co-authored-by: Andrei Betlen
---
 llama_cpp/llama.py         | 4 ++--
 llama_cpp/server/errors.py | 2 +-
 llama_cpp/server/types.py  | 4 +++-
 3 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index ef5e347..788732b 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -1917,7 +1917,7 @@ class Llama:
         completion_or_chunks = self._create_completion(
             prompt=prompt,
             suffix=suffix,
-            max_tokens=max_tokens,
+            max_tokens=-1 if max_tokens is None else max_tokens,
             temperature=temperature,
             top_p=top_p,
             min_p=min_p,
@@ -1951,7 +1951,7 @@ class Llama:
         self,
         prompt: str,
         suffix: Optional[str] = None,
-        max_tokens: int = 128,
+        max_tokens: Optional[int] = 16,
         temperature: float = 0.8,
         top_p: float = 0.95,
         min_p: float = 0.05,
diff --git a/llama_cpp/server/errors.py b/llama_cpp/server/errors.py
index febe3e3..9d3d355 100644
--- a/llama_cpp/server/errors.py
+++ b/llama_cpp/server/errors.py
@@ -72,7 +72,7 @@ class ErrorResponseFormatters:
         return 400, ErrorResponse(
             message=message.format(
                 context_window,
-                completion_tokens + prompt_tokens,
+                (completion_tokens or 0) + prompt_tokens,
                 prompt_tokens,
                 completion_tokens,
             ),  # type: ignore
diff --git a/llama_cpp/server/types.py b/llama_cpp/server/types.py
index f0867bc..f0827d7 100644
--- a/llama_cpp/server/types.py
+++ b/llama_cpp/server/types.py
@@ -110,7 +110,9 @@ class CreateCompletionRequest(BaseModel):
         default=None,
         description="A suffix to append to the generated text. If None, no suffix is appended. Useful for chatbots.",
     )
-    max_tokens: int = max_tokens_field
+    max_tokens: Optional[int] = Field(
+        default=16, ge=0, description="The maximum number of tokens to generate."
+    )
     temperature: float = temperature_field
     top_p: float = top_p_field
     min_p: float = min_p_field
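
A minimal sketch of the behavior introduced by this patch, using hypothetical standalone helpers (resolve_max_tokens and tokens_requested are illustration only, not part of llama_cpp): a None max_tokens is coerced to -1, which _create_completion treats as "generate until stop", and the error formatter guards the token sum so a None completion_tokens no longer raises.

    # Sketch of the coercion and the None-safe sum from the hunks above.
    from typing import Optional

    def resolve_max_tokens(max_tokens: Optional[int]) -> int:
        # Mirrors the __call__ change: None means "no limit", expressed as -1.
        return -1 if max_tokens is None else max_tokens

    def tokens_requested(prompt_tokens: int, completion_tokens: Optional[int]) -> int:
        # Mirrors the errors.py change: treat a missing completion budget as 0.
        return (completion_tokens or 0) + prompt_tokens

    assert resolve_max_tokens(None) == -1       # request unbounded generation
    assert resolve_max_tokens(16) == 16         # explicit cap is preserved
    assert tokens_requested(512, None) == 512   # no TypeError when completion_tokens is None

With the server-side type change, a request that omits max_tokens still defaults to 16 (matching the OpenAI completions spec), while an explicit null flows through as None and hits the -1 path above.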