Fix: default max_tokens matches OpenAI API (16 for completion, max length for chat completion)
parent 82072802ea
commit e7962d2c73
2 changed files with 9 additions and 7 deletions
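A quick sketch of what the new defaults mean in practice. This is illustrative usage of the public llama-cpp-python API rather than code from the commit; the model path and prompt are placeholders.

# Hedged usage sketch: the completion default now mirrors the OpenAI completions API.
from llama_cpp import Llama

llm = Llama(model_path="model.gguf")  # placeholder path

# Text completion: max_tokens defaults to 16 new tokens.
short = llm.create_completion("Q: Name the planets in the solar system. A:")

# None (or any value <= 0) now means "as many tokens as the remaining context allows".
long = llm.create_completion(
    "Q: Name the planets in the solar system. A:", max_tokens=None
)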
@@ -1296,7 +1296,7 @@ class Llama:
         self,
         prompt: Union[str, List[int]],
         suffix: Optional[str] = None,
-        max_tokens: int = 16,
+        max_tokens: Optional[int] = 16,
         temperature: float = 0.8,
         top_p: float = 0.95,
         logprobs: Optional[int] = None,
@@ -1350,7 +1350,7 @@ class Llama:
                 f"Requested tokens ({len(prompt_tokens)}) exceed context window of {llama_cpp.llama_n_ctx(self.ctx)}"
             )

-        if max_tokens <= 0:
+        if max_tokens is None or max_tokens <= 0:
             # Unlimited, depending on n_ctx.
             max_tokens = self._n_ctx - len(prompt_tokens)

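The None handling above boils down to a small clamping rule. A self-contained sketch follows; resolve_max_tokens is a hypothetical helper for illustration, not part of the library.

from typing import List, Optional

def resolve_max_tokens(
    max_tokens: Optional[int], prompt_tokens: List[int], n_ctx: int
) -> int:
    # None or a non-positive value means "unlimited": use whatever context remains.
    if max_tokens is None or max_tokens <= 0:
        return n_ctx - len(prompt_tokens)
    return max_tokens

assert resolve_max_tokens(None, [1] * 10, 2048) == 2038  # fill the remaining window
assert resolve_max_tokens(16, [1] * 10, 2048) == 16      # explicit cap wins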
@@ -1762,7 +1762,7 @@ class Llama:
         self,
         prompt: Union[str, List[int]],
         suffix: Optional[str] = None,
-        max_tokens: int = 128,
+        max_tokens: Optional[int] = 16,
         temperature: float = 0.8,
         top_p: float = 0.95,
         logprobs: Optional[int] = None,
@@ -1788,7 +1788,7 @@ class Llama:
         Args:
             prompt: The prompt to generate text from.
             suffix: A suffix to append to the generated text. If None, no suffix is appended.
-            max_tokens: The maximum number of tokens to generate. If max_tokens <= 0, the maximum number of tokens to generate is unlimited and depends on n_ctx.
+            max_tokens: The maximum number of tokens to generate. If max_tokens <= 0 or None, the maximum number of tokens to generate is unlimited and depends on n_ctx.
             temperature: The temperature to use for sampling.
             top_p: The top-p value to use for sampling.
             logprobs: The number of logprobs to return. If None, no logprobs are returned.
@@ -1921,7 +1921,7 @@ class Llama:
         stop: Optional[Union[str, List[str]]] = [],
         seed: Optional[int] = None,
         response_format: Optional[ChatCompletionRequestResponseFormat] = None,
-        max_tokens: int = 256,
+        max_tokens: Optional[int] = None,
         presence_penalty: float = 0.0,
         frequency_penalty: float = 0.0,
         repeat_penalty: float = 1.1,
@@ -1944,7 +1944,7 @@ class Llama:
             top_k: The top-k value to use for sampling.
             stream: Whether to stream the results.
             stop: A list of strings to stop generation when encountered.
-            max_tokens: The maximum number of tokens to generate. If max_tokens <= 0, the maximum number of tokens to generate is unlimited and depends on n_ctx.
+            max_tokens: The maximum number of tokens to generate. If max_tokens <= 0 or None, the maximum number of tokens to generate is unlimited and depends on n_ctx.
             repeat_penalty: The penalty to apply to repeated tokens.

         Returns:
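For chat completion the default flips from a fixed 256 to None, matching the OpenAI chat API where generation is bounded by the context length unless the caller caps it. A hedged usage sketch; the method name is the public llama-cpp-python API, while the model path and messages are placeholders.

from llama_cpp import Llama

llm = Llama(model_path="model.gguf")  # placeholder path
msgs = [{"role": "user", "content": "Summarize the solar system in one paragraph."}]

capped = llm.create_chat_completion(messages=msgs, max_tokens=64)  # at most 64 new tokens
open_ended = llm.create_chat_completion(messages=msgs)             # default None: bounded only by n_ctx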
@@ -783,7 +783,9 @@ class CreateChatCompletionRequest(BaseModel):
         default=None,
         description="A tool to apply to the generated completions.",
     ) # TODO: verify
-    max_tokens: int = max_tokens_field
+    max_tokens: Optional[int] = Field(
+        default=None, description="The maximum number of tokens to generate. Defaults to inf"
+    )
     temperature: float = temperature_field
     top_p: float = top_p_field
     stop: Optional[List[str]] = stop_field
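On the server side, the request model makes the same field optional so a request that omits max_tokens validates and arrives as None. A minimal pydantic sketch of the idea; the class name here is illustrative, not the server's actual model.

from typing import Optional
from pydantic import BaseModel, Field

class ChatRequestSketch(BaseModel):
    max_tokens: Optional[int] = Field(
        default=None,
        description="The maximum number of tokens to generate. Defaults to inf",
    )

print(ChatRequestSketch().max_tokens)                # None -> treated as "up to n_ctx"
print(ChatRequestSketch(max_tokens=128).max_tokens)  # 128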