Added support for min_p (#921)

* Added support for min_p

  My small contribution to this great project.

  Ref: https://github.com/ggerganov/llama.cpp/pull/3841
  Closes: https://github.com/abetlen/llama-cpp-python/issues/911

* Fix for negative temp (sample_softmax)
Parent: 63fe1370ed
Commit: b8438f70b5
3 changed files with 54 additions and 2 deletions
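
For context, min-p sampling (introduced upstream in llama.cpp PR #3841, referenced above) keeps only those tokens whose probability is at least min_p times the probability of the most likely token. A minimal pure-Python sketch of the idea (illustrative only, not code from this commit):

import math

def min_p_keep(logits, min_p=0.05):
    # Softmax over raw logits (max subtraction for numerical stability).
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Keep token ids whose probability reaches min_p * P(best token).
    threshold = min_p * max(probs)
    return [i for i, p in enumerate(probs) if p >= threshold]

# The last logit is far below the others, so min-p filters it out.
print(min_p_keep([3.0, 2.5, -4.0]))  # -> [0, 1]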
@@ -1046,6 +1046,8 @@ class Llama:
         self,
         top_k: int = 40,
         top_p: float = 0.95,
+        min_p: float = 0.05,
+        typical_p: float = 1.0,
         temp: float = 0.80,
         repeat_penalty: float = 1.1,
         frequency_penalty: float = 0.0,
@@ -1108,7 +1110,10 @@ class Llama:
             grammar=grammar,
         )

-        if temp == 0.0:
+        if temp < 0.0:
+            self._ctx.sample_softmax(candidates=self._candidates)
+            id = self._candidates.candidates.data[0].id
+        elif temp == 0.0:
             id = self._ctx.sample_token_greedy(candidates=self._candidates)
         elif mirostat_mode == 1:
             self._ctx.sample_temp(candidates=self._candidates, temp=temp)
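
The second change above is the negative-temperature fix from the commit message: temp < 0.0 no longer falls through to the full sampler chain. Assuming llama.cpp's sample_softmax sorts candidates by probability in descending order (its documented behavior), data[0].id is the most probable token, so a negative temperature effectively samples greedily. A rough pure-Python stand-in for that path (a sketch of the assumed behavior, not the ctypes-backed implementation):

import math

def softmax_sorted(logits):
    # Normalize logits to probabilities and sort descending,
    # mirroring what llama.cpp's sample_softmax is assumed to do.
    m = max(logits)
    exps = [(i, math.exp(x - m)) for i, x in enumerate(logits)]
    total = sum(e for _, e in exps)
    return sorted(((i, e / total) for i, e in exps),
                  key=lambda pair: pair[1], reverse=True)

top_id, top_prob = softmax_sorted([1.2, 3.4, 0.7])[0]
print(top_id)  # -> 1, the argmax, matching the new temp < 0.0 branch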
@@ -1130,8 +1135,9 @@ class Llama:
         else:
             self._ctx.sample_top_k(candidates=self._candidates, k=top_k, min_keep=1)
             self._ctx.sample_tail_free(candidates=self._candidates, z=tfs_z, min_keep=1)
-            self._ctx.sample_typical(candidates=self._candidates, p=1.0, min_keep=1)
+            self._ctx.sample_typical(candidates=self._candidates, p=typical_p, min_keep=1)
             self._ctx.sample_top_p(candidates=self._candidates, p=top_p, min_keep=1)
+            self._ctx.sample_min_p(candidates=self._candidates, p=min_p, min_keep=1)
             self._ctx.sample_temp(candidates=self._candidates, temp=temp)
             id = self._ctx.sample_token(candidates=self._candidates)
         if grammar is not None:
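
Note where the new call sits: sample_min_p runs after the top-k, tail-free, typical, and top-p filters and before temperature scaling, and sample_typical now receives typical_p instead of the hard-coded 1.0. A simplified view of that ordering as plain filters over a probability list (illustrative; tail-free and typical filtering are omitted):

def filter_chain(probs, top_k=40, top_p=0.95, min_p=0.05):
    probs = sorted(probs, reverse=True)[:top_k]  # top-k
    kept, cum = [], 0.0
    for p in probs:                              # top-p (nucleus)
        kept.append(p)
        cum += p
        if cum >= top_p:
            break
    threshold = min_p * kept[0]                  # min-p cutoff
    return [p for p in kept if p >= threshold]

# With min_p=0.1 the 0.04 token survives top-p but not min-p.
print(filter_chain([0.5, 0.3, 0.14, 0.04, 0.02], min_p=0.1))
# -> [0.5, 0.3, 0.14]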
@@ -1143,6 +1149,8 @@ class Llama:
         tokens: Sequence[int],
         top_k: int = 40,
         top_p: float = 0.95,
+        min_p: float = 0.05,
+        typical_p: float = 1.0,
         temp: float = 0.80,
         repeat_penalty: float = 1.1,
         reset: bool = True,
@@ -1200,6 +1208,8 @@ class Llama:
             token = self.sample(
                 top_k=top_k,
                 top_p=top_p,
+                min_p=min_p,
+                typical_p=typical_p,
                 temp=temp,
                 repeat_penalty=repeat_penalty,
                 frequency_penalty=frequency_penalty,
@@ -1298,6 +1308,8 @@ class Llama:
         max_tokens: Optional[int] = 16,
         temperature: float = 0.8,
         top_p: float = 0.95,
+        min_p: float = 0.05,
+        typical_p: float = 1.0,
         logprobs: Optional[int] = None,
         echo: bool = False,
         stop: Optional[Union[str, List[str]]] = [],
@@ -1398,6 +1410,8 @@ class Llama:
             prompt_tokens,
             top_k=top_k,
             top_p=top_p,
+            min_p=min_p,
+            typical_p=typical_p,
             temp=temperature,
             tfs_z=tfs_z,
             mirostat_mode=mirostat_mode,
@@ -1772,6 +1786,8 @@ class Llama:
         max_tokens: Optional[int] = 16,
         temperature: float = 0.8,
         top_p: float = 0.95,
+        min_p: float = 0.05,
+        typical_p: float = 1.0,
         logprobs: Optional[int] = None,
         echo: bool = False,
         stop: Optional[Union[str, List[str]]] = [],
@@ -1818,6 +1834,8 @@ class Llama:
             max_tokens=max_tokens,
             temperature=temperature,
             top_p=top_p,
+            min_p=min_p,
+            typical_p=typical_p,
             logprobs=logprobs,
             echo=echo,
             stop=stop,
@@ -1849,6 +1867,8 @@ class Llama:
         max_tokens: int = 128,
         temperature: float = 0.8,
         top_p: float = 0.95,
+        min_p: float = 0.05,
+        typical_p: float = 1.0,
         logprobs: Optional[int] = None,
         echo: bool = False,
         stop: Optional[Union[str, List[str]]] = [],
@@ -1895,6 +1915,8 @@ class Llama:
             max_tokens=max_tokens,
             temperature=temperature,
             top_p=top_p,
+            min_p=min_p,
+            typical_p=typical_p,
             logprobs=logprobs,
             echo=echo,
             stop=stop,
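
With both create_completion and __call__ forwarding the new arguments, they can be set from the high-level API. A hypothetical usage sketch (the model path is a placeholder):

from llama_cpp import Llama

llm = Llama(model_path="./models/7B/model.gguf")
out = llm(
    "Q: Name the planets in the solar system. A:",
    max_tokens=64,
    temperature=0.8,
    min_p=0.05,     # drop tokens below 5% of the top token's probability
    typical_p=1.0,  # 1.0 leaves locally-typical sampling disabled
)
print(out["choices"][0]["text"])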
@@ -1924,6 +1946,8 @@ class Llama:
         temperature: float = 0.2,
         top_p: float = 0.95,
         top_k: int = 40,
+        min_p: float = 0.05,
+        typical_p: float = 1.0,
         stream: bool = False,
         stop: Optional[Union[str, List[str]]] = [],
         seed: Optional[int] = None,
@@ -1970,6 +1994,8 @@ class Llama:
             temperature=temperature,
             top_p=top_p,
             top_k=top_k,
+            min_p=min_p,
+            typical_p=typical_p,
             stream=stream,
             stop=stop,
             seed=seed,
@@ -26,6 +26,8 @@ class LlamaChatCompletionHandler(Protocol):
         temperature: float = 0.2,
         top_p: float = 0.95,
         top_k: int = 40,
+        min_p: float = 0.05,
+        typical_p: float = 1.0,
         stream: bool = False,
         stop: Optional[Union[str, List[str]]] = [],
         seed: Optional[int] = None,
@@ -287,6 +289,8 @@ def register_chat_format(name: str):
         temperature: float = 0.2,
         top_p: float = 0.95,
         top_k: int = 40,
+        min_p: float = 0.05,
+        typical_p: float = 1.0,
         stream: bool = False,
         stop: Optional[Union[str, List[str]]] = [],
         seed: Optional[int] = None,
@@ -330,6 +334,8 @@ def register_chat_format(name: str):
             temperature=temperature,
             top_p=top_p,
             top_k=top_k,
+            min_p=min_p,
+            typical_p=typical_p,
             stream=stream,
             stop=stop,
             seed=seed,
@@ -579,6 +585,8 @@ def functionary_chat_handler(
    temperature: float = 0.2,
    top_p: float = 0.95,
    top_k: int = 40,
+   min_p: float = 0.05,
+   typical_p: float = 1.0,
    stream: bool = False,
    stop: Optional[Union[str, List[str]]] = [],
    response_format: Optional[llama_types.ChatCompletionRequestResponseFormat] = None,
@@ -761,6 +769,8 @@ def functionary_chat_handler(
         temperature=temperature,
         top_p=top_p,
         top_k=top_k,
+        min_p=min_p,
+        typical_p=typical_p,
         stream=stream,
         stop=["user:", "</s>"],
         max_tokens=max_tokens,
@@ -831,6 +841,8 @@ def functionary_chat_handler(
         temperature=temperature,
         top_p=top_p,
         top_k=top_k,
+        min_p=min_p,
+        typical_p=typical_p,
         presence_penalty=presence_penalty,
         frequency_penalty=frequency_penalty,
         repeat_penalty=repeat_penalty,
@@ -929,6 +941,8 @@ class Llava15ChatHandler:
         temperature: float = 0.2,
         top_p: float = 0.95,
         top_k: int = 40,
+        min_p: float = 0.05,
+        typical_p: float = 1.0,
         stream: bool = False,
         stop: Optional[Union[str, List[str]]] = [],
         response_format: Optional[
@@ -1045,6 +1059,8 @@ class Llava15ChatHandler:
             temperature=temperature,
             top_p=top_p,
             top_k=top_k,
+            min_p=min_p,
+            typical_p=typical_p,
             stream=stream,
             stop=stop,
             max_tokens=max_tokens,
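
The chat handlers above only thread min_p and typical_p through to the underlying completion call, so the chat API picks them up as well. A hypothetical example (model path and chat_format value are placeholders):

from llama_cpp import Llama

llm = Llama(model_path="./models/7B/model.gguf", chat_format="llama-2")
resp = llm.create_chat_completion(
    messages=[{"role": "user", "content": "Explain min-p sampling briefly."}],
    min_p=0.05,
    typical_p=1.0,
)
print(resp["choices"][0]["message"]["content"])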
@@ -521,6 +521,14 @@ top_p_field = Field(
     + "Top-p sampling, also known as nucleus sampling, is another text generation method that selects the next token from a subset of tokens that together have a cumulative probability of at least p. This method provides a balance between diversity and quality by considering both the probabilities of tokens and the number of tokens to sample from. A higher value for top_p (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text.",
 )

+min_p_field = Field(
+    default=0.05,
+    ge=0.0,
+    le=1.0,
+    description="Sets a minimum base probability threshold for token selection.\n\n"
+    + "The Min-P sampling method was designed as an alternative to Top-P, and aims to ensure a balance of quality and variety. The parameter min_p represents the minimum probability for a token to be considered, relative to the probability of the most likely token. For example, with min_p=0.05 and the most likely token having a probability of 0.9, logits with a value less than 0.045 are filtered out.",
+)
+
 stop_field = Field(
     default=None,
     description="A list of tokens at which to stop generation. If None, no stop tokens are used.",
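
A worked instance of the cutoff described in the new field's docstring, where min_p = 0.05 against a 0.9-probability top token yields a threshold of 0.045:

probs = {"the": 0.90, "a": 0.06, "xylophone": 0.01}
min_p = 0.05
threshold = min_p * max(probs.values())  # 0.05 * 0.90 = 0.045
kept = {tok: p for tok, p in probs.items() if p >= threshold}
print(kept)  # {'the': 0.9, 'a': 0.06}; 'xylophone' (0.01 < 0.045) is dropped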
@@ -593,6 +601,7 @@ class CreateCompletionRequest(BaseModel):
     max_tokens: int = max_tokens_field
     temperature: float = temperature_field
     top_p: float = top_p_field
+    min_p: float = min_p_field
     echo: bool = Field(
         default=False,
         description="Whether to echo the prompt in the generated text. Useful for chatbots.",
@@ -788,6 +797,7 @@ class CreateChatCompletionRequest(BaseModel):
     )
     temperature: float = temperature_field
     top_p: float = top_p_field
+    min_p: float = min_p_field
     stop: Optional[List[str]] = stop_field
     stream: bool = stream_field
     presence_penalty: Optional[float] = presence_penalty_field
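
Since both request models now declare min_p: float = min_p_field, the server validates it to the range [0.0, 1.0] and forwards it to the sampler. A hypothetical client call, assuming the llama-cpp-python server is running with its defaults on localhost:8000:

import requests

resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "prompt": "Q: What is min-p sampling? A:",
        "max_tokens": 64,
        "min_p": 0.05,  # range-checked by min_p_field (ge=0.0, le=1.0)
    },
)
print(resp.json()["choices"][0]["text"])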