Fix: default max_tokens matches OpenAI API (16 for completion, max length for chat completion)

Andrei Betlen 2023-11-10 02:49:27 -05:00
parent 82072802ea
commit e7962d2c73
2 changed files with 9 additions and 7 deletions


@@ -1296,7 +1296,7 @@ class Llama:
         self,
         prompt: Union[str, List[int]],
         suffix: Optional[str] = None,
-        max_tokens: int = 16,
+        max_tokens: Optional[int] = 16,
         temperature: float = 0.8,
         top_p: float = 0.95,
         logprobs: Optional[int] = None,
@@ -1350,7 +1350,7 @@ class Llama:
                 f"Requested tokens ({len(prompt_tokens)}) exceed context window of {llama_cpp.llama_n_ctx(self.ctx)}"
             )
-        if max_tokens <= 0:
+        if max_tokens is None or max_tokens <= 0:
             # Unlimited, depending on n_ctx.
             max_tokens = self._n_ctx - len(prompt_tokens)
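A minimal standalone sketch (not the library code) of the resolution logic this hunk introduces: an omitted (None) or non-positive max_tokens is treated as "unlimited" and expands to the remaining context window. The n_ctx and prompt values below are illustrative only.

    from typing import List, Optional

    def resolve_max_tokens(max_tokens: Optional[int], n_ctx: int, prompt_tokens: List[int]) -> int:
        # None or <= 0 means "unlimited": generate up to whatever context remains.
        if max_tokens is None or max_tokens <= 0:
            return n_ctx - len(prompt_tokens)
        return max_tokens

    # With the new defaults: create_completion keeps 16, create_chat_completion passes None.
    print(resolve_max_tokens(16, n_ctx=4096, prompt_tokens=[1] * 10))    # 16
    print(resolve_max_tokens(None, n_ctx=4096, prompt_tokens=[1] * 10))  # 4086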
@@ -1762,7 +1762,7 @@ class Llama:
         self,
         prompt: Union[str, List[int]],
         suffix: Optional[str] = None,
-        max_tokens: int = 128,
+        max_tokens: Optional[int] = 16,
         temperature: float = 0.8,
         top_p: float = 0.95,
         logprobs: Optional[int] = None,
@@ -1788,7 +1788,7 @@ class Llama:
         Args:
             prompt: The prompt to generate text from.
             suffix: A suffix to append to the generated text. If None, no suffix is appended.
-            max_tokens: The maximum number of tokens to generate. If max_tokens <= 0, the maximum number of tokens to generate is unlimited and depends on n_ctx.
+            max_tokens: The maximum number of tokens to generate. If max_tokens <= 0 or None, the maximum number of tokens to generate is unlimited and depends on n_ctx.
             temperature: The temperature to use for sampling.
             top_p: The top-p value to use for sampling.
             logprobs: The number of logprobs to return. If None, no logprobs are returned.
@@ -1921,7 +1921,7 @@ class Llama:
         stop: Optional[Union[str, List[str]]] = [],
         seed: Optional[int] = None,
         response_format: Optional[ChatCompletionRequestResponseFormat] = None,
-        max_tokens: int = 256,
+        max_tokens: Optional[int] = None,
         presence_penalty: float = 0.0,
         frequency_penalty: float = 0.0,
         repeat_penalty: float = 1.1,
@@ -1944,7 +1944,7 @@ class Llama:
             top_k: The top-k value to use for sampling.
             stream: Whether to stream the results.
             stop: A list of strings to stop generation when encountered.
-            max_tokens: The maximum number of tokens to generate. If max_tokens <= 0, the maximum number of tokens to generate is unlimited and depends on n_ctx.
+            max_tokens: The maximum number of tokens to generate. If max_tokens <= 0 or None, the maximum number of tokens to generate is unlimited and depends on n_ctx.
             repeat_penalty: The penalty to apply to repeated tokens.
         Returns:
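For callers, the visible effect of the new defaults is roughly the following; the two method calls match the diffed signatures, but the model path is a placeholder and the prompts are illustrative.

    from llama_cpp import Llama

    llm = Llama(model_path="./model.gguf")  # placeholder path

    # Completion: max_tokens now defaults to 16, matching the OpenAI completions API.
    out = llm.create_completion("Q: Name the planets in the solar system. A:")

    # Chat completion: max_tokens now defaults to None, i.e. limited only by the
    # remaining context window rather than a fixed 256.
    chat = llm.create_chat_completion(
        messages=[{"role": "user", "content": "Name the planets in the solar system."}]
    )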


@@ -783,7 +783,9 @@ class CreateChatCompletionRequest(BaseModel):
         default=None,
         description="A tool to apply to the generated completions.",
     )  # TODO: verify
-    max_tokens: int = max_tokens_field
+    max_tokens: Optional[int] = Field(
+        default=None, description="The maximum number of tokens to generate. Defaults to inf"
+    )
     temperature: float = temperature_field
     top_p: float = top_p_field
     stop: Optional[List[str]] = stop_field
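On the server side, the same change in isolation behaves roughly like this reduced sketch; ChatRequestSketch is a hypothetical stand-in for the much larger CreateChatCompletionRequest model.

    from typing import Optional
    from pydantic import BaseModel, Field

    class ChatRequestSketch(BaseModel):
        # Omitting max_tokens in a chat request now yields None instead of a
        # fixed server-side default, so generation is capped by context length.
        max_tokens: Optional[int] = Field(
            default=None,
            description="The maximum number of tokens to generate. Defaults to inf",
        )

    print(ChatRequestSketch().max_tokens)                 # None -> unlimited downstream
    print(ChatRequestSketch(max_tokens=256).max_tokens)   # an explicit cap is still honored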