Fix: default max_tokens matches OpenAI API (16 for completion, max length for chat completion)
parent 82072802ea
commit e7962d2c73
2 changed files with 9 additions and 7 deletions
@@ -1296,7 +1296,7 @@ class Llama:
         self,
         prompt: Union[str, List[int]],
         suffix: Optional[str] = None,
-        max_tokens: int = 16,
+        max_tokens: Optional[int] = 16,
         temperature: float = 0.8,
         top_p: float = 0.95,
         logprobs: Optional[int] = None,
@@ -1350,7 +1350,7 @@ class Llama:
                 f"Requested tokens ({len(prompt_tokens)}) exceed context window of {llama_cpp.llama_n_ctx(self.ctx)}"
             )

-        if max_tokens <= 0:
+        if max_tokens is None or max_tokens <= 0:
             # Unlimited, depending on n_ctx.
             max_tokens = self._n_ctx - len(prompt_tokens)

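In effect, the completion path now treats a missing max_tokens the same as a non-positive one: whatever remains of the context window after the prompt becomes the generation budget. A minimal sketch of that resolution logic (resolve_max_tokens is a hypothetical helper, and the context/prompt sizes are assumed example values, not values from this commit):

    from typing import Optional

    def resolve_max_tokens(max_tokens: Optional[int], n_ctx: int, n_prompt_tokens: int) -> int:
        # Mirrors the new check: None or <= 0 means "use the rest of the context window".
        if max_tokens is None or max_tokens <= 0:
            return n_ctx - n_prompt_tokens
        return max_tokens

    assert resolve_max_tokens(None, 2048, 10) == 2038  # chat-style default: bounded only by n_ctx
    assert resolve_max_tokens(16, 2048, 10) == 16      # completion default stays at 16 tokens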
@@ -1762,7 +1762,7 @@ class Llama:
         self,
         prompt: Union[str, List[int]],
         suffix: Optional[str] = None,
-        max_tokens: int = 128,
+        max_tokens: Optional[int] = 16,
         temperature: float = 0.8,
         top_p: float = 0.95,
         logprobs: Optional[int] = None,
@@ -1788,7 +1788,7 @@ class Llama:
         Args:
             prompt: The prompt to generate text from.
             suffix: A suffix to append to the generated text. If None, no suffix is appended.
-            max_tokens: The maximum number of tokens to generate. If max_tokens <= 0, the maximum number of tokens to generate is unlimited and depends on n_ctx.
+            max_tokens: The maximum number of tokens to generate. If max_tokens <= 0 or None, the maximum number of tokens to generate is unlimited and depends on n_ctx.
             temperature: The temperature to use for sampling.
             top_p: The top-p value to use for sampling.
             logprobs: The number of logprobs to return. If None, no logprobs are returned.
@@ -1921,7 +1921,7 @@ class Llama:
         stop: Optional[Union[str, List[str]]] = [],
         seed: Optional[int] = None,
         response_format: Optional[ChatCompletionRequestResponseFormat] = None,
-        max_tokens: int = 256,
+        max_tokens: Optional[int] = None,
         presence_penalty: float = 0.0,
         frequency_penalty: float = 0.0,
         repeat_penalty: float = 1.1,
@@ -1944,7 +1944,7 @@ class Llama:
             top_k: The top-k value to use for sampling.
             stream: Whether to stream the results.
             stop: A list of strings to stop generation when encountered.
-            max_tokens: The maximum number of tokens to generate. If max_tokens <= 0, the maximum number of tokens to generate is unlimited and depends on n_ctx.
+            max_tokens: The maximum number of tokens to generate. If max_tokens <= 0 or None, the maximum number of tokens to generate is unlimited and depends on n_ctx.
             repeat_penalty: The penalty to apply to repeated tokens.

         Returns:

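With the default flipped to None, create_chat_completion called without max_tokens should now generate until a stop condition or the context window runs out, while create_completion keeps the OpenAI-style 16-token default. A hypothetical usage sketch (the model path and messages are placeholders, not part of this commit):

    from llama_cpp import Llama

    llm = Llama(model_path="./models/model.gguf")  # placeholder path

    # max_tokens omitted: with the new default of None, generation is only
    # bounded by the remaining context window, as in the OpenAI chat API.
    chat = llm.create_chat_completion(
        messages=[{"role": "user", "content": "Say hello."}],
    )

    # The plain completion API keeps OpenAI's default of 16 new tokens.
    completion = llm.create_completion(prompt="Say hello.")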
@@ -783,7 +783,9 @@ class CreateChatCompletionRequest(BaseModel):
         default=None,
         description="A tool to apply to the generated completions.",
     )  # TODO: verify
-    max_tokens: int = max_tokens_field
+    max_tokens: Optional[int] = Field(
+        default=None, description="The maximum number of tokens to generate. Defaults to inf"
+    )
     temperature: float = temperature_field
     top_p: float = top_p_field
     stop: Optional[List[str]] = stop_field
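On the server side, a request body that omits max_tokens now validates to None instead of the old numeric field default, which the Llama methods above then treat as "no explicit limit". A self-contained sketch of the same Pydantic pattern (ChatRequestSketch is illustrative, not the server's actual model):

    from typing import Optional
    from pydantic import BaseModel, Field

    class ChatRequestSketch(BaseModel):
        # Same pattern as the diff: an omitted max_tokens becomes None.
        max_tokens: Optional[int] = Field(
            default=None, description="The maximum number of tokens to generate. Defaults to inf"
        )

    print(ChatRequestSketch().max_tokens)  # None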