Added support for min_p (#921)

* Added support for min_p

My small contribution to this great project.

Ref: https://github.com/ggerganov/llama.cpp/pull/3841

Closes: https://github.com/abetlen/llama-cpp-python/issues/911

* Fix for negative temp (sample_softmax)
TK-Master 2023-11-21 06:21:33 +02:00 committed by GitHub
parent 63fe1370ed
commit b8438f70b5
3 changed files with 54 additions and 2 deletions

--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py

@@ -1046,6 +1046,8 @@ class Llama:
         self,
         top_k: int = 40,
         top_p: float = 0.95,
+        min_p: float = 0.05,
+        typical_p: float = 1.0,
         temp: float = 0.80,
         repeat_penalty: float = 1.1,
         frequency_penalty: float = 0.0,
@@ -1108,7 +1110,10 @@ class Llama:
             grammar=grammar,
         )
-        if temp == 0.0:
+        if temp < 0.0:
+            self._ctx.sample_softmax(candidates=self._candidates)
+            id = self._candidates.candidates.data[0].id
+        elif temp == 0.0:
             id = self._ctx.sample_token_greedy(candidates=self._candidates)
         elif mirostat_mode == 1:
             self._ctx.sample_temp(candidates=self._candidates, temp=temp)
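
A note on the new `temp < 0.0` branch: in llama.cpp, `sample_softmax` normalizes the candidate logits into probabilities and sorts the candidates in descending order, so reading `candidates.data[0].id` afterwards is effectively greedy (argmax) decoding. A minimal plain-Python sketch of that selection logic, assuming a token-id-to-logit mapping (`pick_token_negative_temp` and its `logits` argument are illustrative names, not library API):

```python
import math

def pick_token_negative_temp(logits: dict[int, float]) -> int:
    """Mirror of the temp < 0.0 path: softmax, sort descending, take index 0."""
    m = max(logits.values())
    exps = {tid: math.exp(x - m) for tid, x in logits.items()}
    total = sum(exps.values())
    probs = {tid: e / total for tid, e in exps.items()}
    # sample_softmax leaves candidates sorted by probability, highest first;
    # data[0].id is therefore the argmax token.
    ranked = sorted(probs.items(), key=lambda kv: kv[1], reverse=True)
    return ranked[0][0]

print(pick_token_negative_temp({7: 2.0, 42: 3.5, 99: 0.1}))  # -> 42
```
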
@@ -1130,8 +1135,9 @@ class Llama:
         else:
             self._ctx.sample_top_k(candidates=self._candidates, k=top_k, min_keep=1)
             self._ctx.sample_tail_free(candidates=self._candidates, z=tfs_z, min_keep=1)
-            self._ctx.sample_typical(candidates=self._candidates, p=1.0, min_keep=1)
+            self._ctx.sample_typical(candidates=self._candidates, p=typical_p, min_keep=1)
             self._ctx.sample_top_p(candidates=self._candidates, p=top_p, min_keep=1)
+            self._ctx.sample_min_p(candidates=self._candidates, p=min_p, min_keep=1)
             self._ctx.sample_temp(candidates=self._candidates, temp=temp)
         id = self._ctx.sample_token(candidates=self._candidates)
         if grammar is not None:
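
For context on `sample_min_p` itself (see the llama.cpp PR linked in the commit message): min-p keeps only the tokens whose probability is at least `min_p` times the probability of the most likely token. A standalone sketch over already-normalized probabilities (`min_p_filter` is an illustrative name, not the library call):

```python
def min_p_filter(probs: list[float], min_p: float = 0.05, min_keep: int = 1) -> list[int]:
    """Return indices of tokens that survive the min-p cut."""
    threshold = min_p * max(probs)
    kept = [i for i, p in enumerate(probs) if p >= threshold]
    if len(kept) < min_keep:
        # Never filter everything away; keep the min_keep most probable tokens.
        kept = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:min_keep]
    return kept

# With a 0.9 top token, the cutoff is 0.05 * 0.9 = 0.045:
print(min_p_filter([0.9, 0.05, 0.04, 0.01]))  # -> [0, 1]
```
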
@@ -1143,6 +1149,8 @@ class Llama:
         tokens: Sequence[int],
         top_k: int = 40,
         top_p: float = 0.95,
+        min_p: float = 0.05,
+        typical_p: float = 1.0,
         temp: float = 0.80,
         repeat_penalty: float = 1.1,
         reset: bool = True,
@@ -1200,6 +1208,8 @@ class Llama:
             token = self.sample(
                 top_k=top_k,
                 top_p=top_p,
+                min_p=min_p,
+                typical_p=typical_p,
                 temp=temp,
                 repeat_penalty=repeat_penalty,
                 frequency_penalty=frequency_penalty,
@@ -1298,6 +1308,8 @@ class Llama:
         max_tokens: Optional[int] = 16,
         temperature: float = 0.8,
         top_p: float = 0.95,
+        min_p: float = 0.05,
+        typical_p: float = 1.0,
         logprobs: Optional[int] = None,
         echo: bool = False,
         stop: Optional[Union[str, List[str]]] = [],
@@ -1398,6 +1410,8 @@ class Llama:
             prompt_tokens,
             top_k=top_k,
             top_p=top_p,
+            min_p=min_p,
+            typical_p=typical_p,
             temp=temperature,
             tfs_z=tfs_z,
             mirostat_mode=mirostat_mode,
@@ -1772,6 +1786,8 @@ class Llama:
         max_tokens: Optional[int] = 16,
         temperature: float = 0.8,
         top_p: float = 0.95,
+        min_p: float = 0.05,
+        typical_p: float = 1.0,
         logprobs: Optional[int] = None,
         echo: bool = False,
         stop: Optional[Union[str, List[str]]] = [],
@@ -1818,6 +1834,8 @@ class Llama:
             max_tokens=max_tokens,
             temperature=temperature,
             top_p=top_p,
+            min_p=min_p,
+            typical_p=typical_p,
             logprobs=logprobs,
             echo=echo,
             stop=stop,
@@ -1849,6 +1867,8 @@ class Llama:
         max_tokens: int = 128,
         temperature: float = 0.8,
         top_p: float = 0.95,
+        min_p: float = 0.05,
+        typical_p: float = 1.0,
         logprobs: Optional[int] = None,
         echo: bool = False,
         stop: Optional[Union[str, List[str]]] = [],
@@ -1895,6 +1915,8 @@ class Llama:
             max_tokens=max_tokens,
             temperature=temperature,
             top_p=top_p,
+            min_p=min_p,
+            typical_p=typical_p,
             logprobs=logprobs,
             echo=echo,
             stop=stop,
@@ -1924,6 +1946,8 @@ class Llama:
         temperature: float = 0.2,
         top_p: float = 0.95,
         top_k: int = 40,
+        min_p: float = 0.05,
+        typical_p: float = 1.0,
         stream: bool = False,
         stop: Optional[Union[str, List[str]]] = [],
         seed: Optional[int] = None,
@@ -1970,6 +1994,8 @@ class Llama:
             temperature=temperature,
             top_p=top_p,
             top_k=top_k,
+            min_p=min_p,
+            typical_p=typical_p,
             stream=stream,
             stop=stop,
             seed=seed,
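
With the changes above, the new parameters are reachable from the high-level API. A hedged usage sketch (the model path is a placeholder; the defaults shown match the diff):

```python
from llama_cpp import Llama

llm = Llama(model_path="./models/model.gguf")  # placeholder path
out = llm.create_completion(
    "Q: Name the planets in the solar system. A:",
    max_tokens=64,
    temperature=0.8,
    min_p=0.05,     # new in this commit; 0.0 disables the filter
    typical_p=1.0,  # 1.0 leaves locally typical sampling disabled
)
print(out["choices"][0]["text"])
```
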

--- a/llama_cpp/llama_chat_format.py
+++ b/llama_cpp/llama_chat_format.py

@@ -26,6 +26,8 @@ class LlamaChatCompletionHandler(Protocol):
         temperature: float = 0.2,
         top_p: float = 0.95,
         top_k: int = 40,
+        min_p: float = 0.05,
+        typical_p: float = 1.0,
         stream: bool = False,
         stop: Optional[Union[str, List[str]]] = [],
         seed: Optional[int] = None,
@@ -287,6 +289,8 @@ def register_chat_format(name: str):
         temperature: float = 0.2,
         top_p: float = 0.95,
         top_k: int = 40,
+        min_p: float = 0.05,
+        typical_p: float = 1.0,
         stream: bool = False,
         stop: Optional[Union[str, List[str]]] = [],
         seed: Optional[int] = None,
@@ -330,6 +334,8 @@ def register_chat_format(name: str):
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
+           min_p=min_p,
+           typical_p=typical_p,
            stream=stream,
            stop=stop,
            seed=seed,
@@ -579,6 +585,8 @@ def functionary_chat_handler(
     temperature: float = 0.2,
     top_p: float = 0.95,
     top_k: int = 40,
+    min_p: float = 0.05,
+    typical_p: float = 1.0,
     stream: bool = False,
     stop: Optional[Union[str, List[str]]] = [],
     response_format: Optional[llama_types.ChatCompletionRequestResponseFormat] = None,
@@ -761,6 +769,8 @@ def functionary_chat_handler(
         temperature=temperature,
         top_p=top_p,
         top_k=top_k,
+        min_p=min_p,
+        typical_p=typical_p,
         stream=stream,
         stop=["user:", "</s>"],
         max_tokens=max_tokens,
@@ -831,6 +841,8 @@ def functionary_chat_handler(
         temperature=temperature,
         top_p=top_p,
         top_k=top_k,
+        min_p=min_p,
+        typical_p=typical_p,
         presence_penalty=presence_penalty,
         frequency_penalty=frequency_penalty,
         repeat_penalty=repeat_penalty,
@@ -929,6 +941,8 @@ class Llava15ChatHandler:
         temperature: float = 0.2,
         top_p: float = 0.95,
         top_k: int = 40,
+        min_p: float = 0.05,
+        typical_p: float = 1.0,
         stream: bool = False,
         stop: Optional[Union[str, List[str]]] = [],
         response_format: Optional[
@@ -1045,6 +1059,8 @@ class Llava15ChatHandler:
             temperature=temperature,
             top_p=top_p,
             top_k=top_k,
+            min_p=min_p,
+            typical_p=typical_p,
             stream=stream,
             stop=stop,
             max_tokens=max_tokens,
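
The chat-format changes are pure parameter threading: each handler forwards `min_p` and `typical_p` down to the underlying completion call. A hedged sketch of the same parameters through the chat API (model path and `chat_format` value are placeholders):

```python
from llama_cpp import Llama

llm = Llama(model_path="./models/model.gguf", chat_format="llama-2")  # placeholders
out = llm.create_chat_completion(
    messages=[{"role": "user", "content": "Say hello in French."}],
    min_p=0.05,     # forwarded by the handler to the completion call
    typical_p=1.0,
)
print(out["choices"][0]["message"]["content"])
```
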

--- a/llama_cpp/server/app.py
+++ b/llama_cpp/server/app.py

@@ -521,6 +521,14 @@ top_p_field = Field(
     + "Top-p sampling, also known as nucleus sampling, is another text generation method that selects the next token from a subset of tokens that together have a cumulative probability of at least p. This method provides a balance between diversity and quality by considering both the probabilities of tokens and the number of tokens to sample from. A higher value for top_p (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text.",
 )
+min_p_field = Field(
+    default=0.05,
+    ge=0.0,
+    le=1.0,
+    description="Sets a minimum base probability threshold for token selection.\n\n"
+    + "The Min-P sampling method was designed as an alternative to Top-P, and aims to ensure a balance of quality and variety. The parameter min_p represents the minimum probability for a token to be considered, relative to the probability of the most likely token. For example, with min_p=0.05 and the most likely token having a probability of 0.9, logits with a value less than 0.045 are filtered out.",
+)
 stop_field = Field(
     default=None,
     description="A list of tokens at which to stop generation. If None, no stop tokens are used.",
@@ -593,6 +601,7 @@ class CreateCompletionRequest(BaseModel):
     max_tokens: int = max_tokens_field
     temperature: float = temperature_field
     top_p: float = top_p_field
+    min_p: float = min_p_field
     echo: bool = Field(
         default=False,
         description="Whether to echo the prompt in the generated text. Useful for chatbots.",
@@ -788,6 +797,7 @@ class CreateChatCompletionRequest(BaseModel):
     )
     temperature: float = temperature_field
     top_p: float = top_p_field
+    min_p: float = min_p_field
     stop: Optional[List[str]] = stop_field
     stream: bool = stream_field
     presence_penalty: Optional[float] = presence_penalty_field
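
Once the server is running (e.g. `python -m llama_cpp.server`), the new request field can be exercised over the OpenAI-compatible endpoint. A hedged sketch; the host and port are local-default assumptions:

```python
import requests

resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "prompt": "The capital of France is",
        "max_tokens": 16,
        "min_p": 0.05,  # new field from this commit
    },
)
print(resp.json()["choices"][0]["text"])
```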