Merge branch 'main' of github.com:abetlen/llama_cpp_python into main
This commit is contained in:
commit
9d7c8307cd
3 changed files with 54 additions and 2 deletions
|
@ -1046,6 +1046,8 @@ class Llama:
|
|||
self,
|
||||
top_k: int = 40,
|
||||
top_p: float = 0.95,
|
||||
min_p: float = 0.05,
|
||||
typical_p: float = 1.0,
|
||||
temp: float = 0.80,
|
||||
repeat_penalty: float = 1.1,
|
||||
frequency_penalty: float = 0.0,
|
||||
|
@ -1108,7 +1110,10 @@ class Llama:
|
|||
grammar=grammar,
|
||||
)
|
||||
|
||||
if temp == 0.0:
|
||||
if temp < 0.0:
|
||||
self._ctx.sample_softmax(candidates=self._candidates)
|
||||
id = self._candidates.candidates.data[0].id
|
||||
elif temp == 0.0:
|
||||
id = self._ctx.sample_token_greedy(candidates=self._candidates)
|
||||
elif mirostat_mode == 1:
|
||||
self._ctx.sample_temp(candidates=self._candidates, temp=temp)
|
||||
|
@ -1130,8 +1135,9 @@ class Llama:
|
|||
else:
|
||||
self._ctx.sample_top_k(candidates=self._candidates, k=top_k, min_keep=1)
|
||||
self._ctx.sample_tail_free(candidates=self._candidates, z=tfs_z, min_keep=1)
|
||||
self._ctx.sample_typical(candidates=self._candidates, p=1.0, min_keep=1)
|
||||
self._ctx.sample_typical(candidates=self._candidates, p=typical_p, min_keep=1)
|
||||
self._ctx.sample_top_p(candidates=self._candidates, p=top_p, min_keep=1)
|
||||
self._ctx.sample_min_p(candidates=self._candidates, p=min_p, min_keep=1)
|
||||
self._ctx.sample_temp(candidates=self._candidates, temp=temp)
|
||||
id = self._ctx.sample_token(candidates=self._candidates)
|
||||
if grammar is not None:
|
||||
|
@ -1143,6 +1149,8 @@ class Llama:
|
|||
tokens: Sequence[int],
|
||||
top_k: int = 40,
|
||||
top_p: float = 0.95,
|
||||
min_p: float = 0.05,
|
||||
typical_p: float = 1.0,
|
||||
temp: float = 0.80,
|
||||
repeat_penalty: float = 1.1,
|
||||
reset: bool = True,
|
||||
|
@ -1200,6 +1208,8 @@ class Llama:
|
|||
token = self.sample(
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
min_p=min_p,
|
||||
typical_p=typical_p,
|
||||
temp=temp,
|
||||
repeat_penalty=repeat_penalty,
|
||||
frequency_penalty=frequency_penalty,
|
||||
|
@ -1298,6 +1308,8 @@ class Llama:
|
|||
max_tokens: Optional[int] = 16,
|
||||
temperature: float = 0.8,
|
||||
top_p: float = 0.95,
|
||||
min_p: float = 0.05,
|
||||
typical_p: float = 1.0,
|
||||
logprobs: Optional[int] = None,
|
||||
echo: bool = False,
|
||||
stop: Optional[Union[str, List[str]]] = [],
|
||||
|
@ -1398,6 +1410,8 @@ class Llama:
|
|||
prompt_tokens,
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
min_p=min_p,
|
||||
typical_p=typical_p,
|
||||
temp=temperature,
|
||||
tfs_z=tfs_z,
|
||||
mirostat_mode=mirostat_mode,
|
||||
|
@ -1772,6 +1786,8 @@ class Llama:
|
|||
max_tokens: Optional[int] = 16,
|
||||
temperature: float = 0.8,
|
||||
top_p: float = 0.95,
|
||||
min_p: float = 0.05,
|
||||
typical_p: float = 1.0,
|
||||
logprobs: Optional[int] = None,
|
||||
echo: bool = False,
|
||||
stop: Optional[Union[str, List[str]]] = [],
|
||||
|
@ -1818,6 +1834,8 @@ class Llama:
|
|||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
min_p=min_p,
|
||||
typical_p=typical_p,
|
||||
logprobs=logprobs,
|
||||
echo=echo,
|
||||
stop=stop,
|
||||
|
@ -1849,6 +1867,8 @@ class Llama:
|
|||
max_tokens: int = 128,
|
||||
temperature: float = 0.8,
|
||||
top_p: float = 0.95,
|
||||
min_p: float = 0.05,
|
||||
typical_p: float = 1.0,
|
||||
logprobs: Optional[int] = None,
|
||||
echo: bool = False,
|
||||
stop: Optional[Union[str, List[str]]] = [],
|
||||
|
@ -1895,6 +1915,8 @@ class Llama:
|
|||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
min_p=min_p,
|
||||
typical_p=typical_p,
|
||||
logprobs=logprobs,
|
||||
echo=echo,
|
||||
stop=stop,
|
||||
|
@ -1924,6 +1946,8 @@ class Llama:
|
|||
temperature: float = 0.2,
|
||||
top_p: float = 0.95,
|
||||
top_k: int = 40,
|
||||
min_p: float = 0.05,
|
||||
typical_p: float = 1.0,
|
||||
stream: bool = False,
|
||||
stop: Optional[Union[str, List[str]]] = [],
|
||||
seed: Optional[int] = None,
|
||||
|
@ -1970,6 +1994,8 @@ class Llama:
|
|||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
min_p=min_p,
|
||||
typical_p=typical_p,
|
||||
stream=stream,
|
||||
stop=stop,
|
||||
seed=seed,
|
||||
|
|
|
@ -26,6 +26,8 @@ class LlamaChatCompletionHandler(Protocol):
|
|||
temperature: float = 0.2,
|
||||
top_p: float = 0.95,
|
||||
top_k: int = 40,
|
||||
min_p: float = 0.05,
|
||||
typical_p: float = 1.0,
|
||||
stream: bool = False,
|
||||
stop: Optional[Union[str, List[str]]] = [],
|
||||
seed: Optional[int] = None,
|
||||
|
@ -287,6 +289,8 @@ def register_chat_format(name: str):
|
|||
temperature: float = 0.2,
|
||||
top_p: float = 0.95,
|
||||
top_k: int = 40,
|
||||
min_p: float = 0.05,
|
||||
typical_p: float = 1.0,
|
||||
stream: bool = False,
|
||||
stop: Optional[Union[str, List[str]]] = [],
|
||||
seed: Optional[int] = None,
|
||||
|
@ -330,6 +334,8 @@ def register_chat_format(name: str):
|
|||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
min_p=min_p,
|
||||
typical_p=typical_p,
|
||||
stream=stream,
|
||||
stop=stop,
|
||||
seed=seed,
|
||||
|
@ -579,6 +585,8 @@ def functionary_chat_handler(
|
|||
temperature: float = 0.2,
|
||||
top_p: float = 0.95,
|
||||
top_k: int = 40,
|
||||
min_p: float = 0.05,
|
||||
typical_p: float = 1.0,
|
||||
stream: bool = False,
|
||||
stop: Optional[Union[str, List[str]]] = [],
|
||||
response_format: Optional[llama_types.ChatCompletionRequestResponseFormat] = None,
|
||||
|
@ -761,6 +769,8 @@ def functionary_chat_handler(
|
|||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
min_p=min_p,
|
||||
typical_p=typical_p,
|
||||
stream=stream,
|
||||
stop=["user:", "</s>"],
|
||||
max_tokens=max_tokens,
|
||||
|
@ -831,6 +841,8 @@ def functionary_chat_handler(
|
|||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
min_p=min_p,
|
||||
typical_p=typical_p,
|
||||
presence_penalty=presence_penalty,
|
||||
frequency_penalty=frequency_penalty,
|
||||
repeat_penalty=repeat_penalty,
|
||||
|
@ -929,6 +941,8 @@ class Llava15ChatHandler:
|
|||
temperature: float = 0.2,
|
||||
top_p: float = 0.95,
|
||||
top_k: int = 40,
|
||||
min_p: float = 0.05,
|
||||
typical_p: float = 1.0,
|
||||
stream: bool = False,
|
||||
stop: Optional[Union[str, List[str]]] = [],
|
||||
response_format: Optional[
|
||||
|
@ -1045,6 +1059,8 @@ class Llava15ChatHandler:
|
|||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
min_p=min_p,
|
||||
typical_p=typical_p,
|
||||
stream=stream,
|
||||
stop=stop,
|
||||
max_tokens=max_tokens,
|
||||
|
|
|
@ -521,6 +521,14 @@ top_p_field = Field(
|
|||
+ "Top-p sampling, also known as nucleus sampling, is another text generation method that selects the next token from a subset of tokens that together have a cumulative probability of at least p. This method provides a balance between diversity and quality by considering both the probabilities of tokens and the number of tokens to sample from. A higher value for top_p (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text.",
|
||||
)
|
||||
|
||||
min_p_field = Field(
|
||||
default=0.05,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Sets a minimum base probability threshold for token selection.\n\n"
|
||||
+ "The Min-P sampling method was designed as an alternative to Top-P, and aims to ensure a balance of quality and variety. The parameter min_p represents the minimum probability for a token to be considered, relative to the probability of the most likely token. For example, with min_p=0.05 and the most likely token having a probability of 0.9, logits with a value less than 0.045 are filtered out.",
|
||||
)
|
||||
|
||||
stop_field = Field(
|
||||
default=None,
|
||||
description="A list of tokens at which to stop generation. If None, no stop tokens are used.",
|
||||
|
@ -593,6 +601,7 @@ class CreateCompletionRequest(BaseModel):
|
|||
max_tokens: int = max_tokens_field
|
||||
temperature: float = temperature_field
|
||||
top_p: float = top_p_field
|
||||
min_p: float = min_p_field
|
||||
echo: bool = Field(
|
||||
default=False,
|
||||
description="Whether to echo the prompt in the generated text. Useful for chatbots.",
|
||||
|
@ -788,6 +797,7 @@ class CreateChatCompletionRequest(BaseModel):
|
|||
)
|
||||
temperature: float = temperature_field
|
||||
top_p: float = top_p_field
|
||||
min_p: float = min_p_field
|
||||
stop: Optional[List[str]] = stop_field
|
||||
stream: bool = stream_field
|
||||
presence_penalty: Optional[float] = presence_penalty_field
|
||||
|
|
Loading…
Reference in a new issue