Fix: default max_tokens matches OpenAI API (16 for completion, max length for chat completion)

Andrei Betlen 2023-11-10 02:49:27 -05:00
parent 82072802ea
commit e7962d2c73
2 changed files with 9 additions and 7 deletions
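In effect, create_completion keeps an explicit 16-token default (matching the OpenAI completions endpoint), while create_chat_completion now defaults to None, meaning generation is bounded only by the remaining context window. A minimal usage sketch of the new defaults; the model path is hypothetical:

from llama_cpp import Llama

# Hypothetical model path, for illustration only.
llm = Llama(model_path="./models/model.gguf")

# Completion: max_tokens defaults to 16, matching the OpenAI completions API.
completion = llm.create_completion("Q: Name the planets in the solar system. A:")

# Chat completion: max_tokens now defaults to None, so generation runs until
# the context window (n_ctx) is exhausted or a stop condition is hit.
chat = llm.create_chat_completion(
    messages=[{"role": "user", "content": "Name the planets in the solar system."}]
)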

llama_cpp/llama.py

@@ -1296,7 +1296,7 @@ class Llama:
         self,
         prompt: Union[str, List[int]],
         suffix: Optional[str] = None,
-        max_tokens: int = 16,
+        max_tokens: Optional[int] = 16,
         temperature: float = 0.8,
         top_p: float = 0.95,
         logprobs: Optional[int] = None,
@@ -1350,7 +1350,7 @@ class Llama:
                 f"Requested tokens ({len(prompt_tokens)}) exceed context window of {llama_cpp.llama_n_ctx(self.ctx)}"
             )
-        if max_tokens <= 0:
+        if max_tokens is None or max_tokens <= 0:
             # Unlimited, depending on n_ctx.
             max_tokens = self._n_ctx - len(prompt_tokens)
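For clarity, the guard above now collapses both None and non-positive values into "use whatever context remains". A standalone sketch of that clamping logic, with resolve_max_tokens as a hypothetical helper name:

from typing import Optional

def resolve_max_tokens(max_tokens: Optional[int], n_ctx: int, n_prompt_tokens: int) -> int:
    # None or a non-positive value means "unlimited": cap generation at the
    # space left in the context window after the prompt.
    if max_tokens is None or max_tokens <= 0:
        return n_ctx - n_prompt_tokens
    return max_tokens

assert resolve_max_tokens(None, 2048, 10) == 2038  # new chat default
assert resolve_max_tokens(16, 2048, 10) == 16      # completion default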
@@ -1762,7 +1762,7 @@ class Llama:
         self,
         prompt: Union[str, List[int]],
         suffix: Optional[str] = None,
-        max_tokens: int = 128,
+        max_tokens: Optional[int] = 16,
         temperature: float = 0.8,
         top_p: float = 0.95,
         logprobs: Optional[int] = None,
@@ -1788,7 +1788,7 @@ class Llama:
         Args:
             prompt: The prompt to generate text from.
             suffix: A suffix to append to the generated text. If None, no suffix is appended.
-            max_tokens: The maximum number of tokens to generate. If max_tokens <= 0, the maximum number of tokens to generate is unlimited and depends on n_ctx.
+            max_tokens: The maximum number of tokens to generate. If max_tokens <= 0 or None, the maximum number of tokens to generate is unlimited and depends on n_ctx.
             temperature: The temperature to use for sampling.
             top_p: The top-p value to use for sampling.
             logprobs: The number of logprobs to return. If None, no logprobs are returned.
@@ -1921,7 +1921,7 @@ class Llama:
         stop: Optional[Union[str, List[str]]] = [],
         seed: Optional[int] = None,
         response_format: Optional[ChatCompletionRequestResponseFormat] = None,
-        max_tokens: int = 256,
+        max_tokens: Optional[int] = None,
         presence_penalty: float = 0.0,
         frequency_penalty: float = 0.0,
         repeat_penalty: float = 1.1,
@@ -1944,7 +1944,7 @@ class Llama:
             top_k: The top-k value to use for sampling.
             stream: Whether to stream the results.
             stop: A list of strings to stop generation when encountered.
-            max_tokens: The maximum number of tokens to generate. If max_tokens <= 0, the maximum number of tokens to generate is unlimited and depends on n_ctx.
+            max_tokens: The maximum number of tokens to generate. If max_tokens <= 0 or None, the maximum number of tokens to generate is unlimited and depends on n_ctx.
             repeat_penalty: The penalty to apply to repeated tokens.
         Returns:
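Note the behavioral change this hunk implies: the chat method's default drops from a hard cap of 256 to None, so callers who relied on the old implicit limit must now pass max_tokens explicitly. A sketch, assuming the llm instance from above:

# Restore the pre-change behavior (cap output at 256 tokens) explicitly.
chat = llm.create_chat_completion(
    messages=[{"role": "user", "content": "Summarize this commit."}],
    max_tokens=256,
)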

llama_cpp/server/app.py

@@ -783,7 +783,9 @@ class CreateChatCompletionRequest(BaseModel):
         default=None,
         description="A tool to apply to the generated completions.",
     ) # TODO: verify
-    max_tokens: int = max_tokens_field
+    max_tokens: Optional[int] = Field(
+        default=None, description="The maximum number of tokens to generate. Defaults to inf"
+    )
     temperature: float = temperature_field
     top_p: float = top_p_field
     stop: Optional[List[str]] = stop_field
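On the server side, the request model mirrors the same contract: a request body that omits max_tokens deserializes to None and flows through to create_chat_completion unchanged. A trimmed-down, hypothetical model showing just this field, assuming pydantic's BaseModel and Field:

from typing import Optional
from pydantic import BaseModel, Field

class ChatCompletionRequestSketch(BaseModel):
    # Hypothetical, trimmed-down stand-in for CreateChatCompletionRequest.
    max_tokens: Optional[int] = Field(
        default=None,
        description="The maximum number of tokens to generate. Defaults to inf",
    )

print(ChatCompletionRequestSketch().max_tokens)                # None -> unlimited
print(ChatCompletionRequestSketch(max_tokens=256).max_tokens)  # 256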