From b8438f70b58c1e4bbfb9581430ce4e032aaa3a8c Mon Sep 17 00:00:00 2001
From: TK-Master
Date: Tue, 21 Nov 2023 06:21:33 +0200
Subject: [PATCH] Added support for min_p (#921)

* Added support for min_p

My small contribution to this great project.

Ref: https://github.com/ggerganov/llama.cpp/pull/3841

Closes: https://github.com/abetlen/llama-cpp-python/issues/911

* Fix for negative temp (sample_softmax)
---
 llama_cpp/llama.py             | 30 ++++++++++++++++++++++++++++--
 llama_cpp/llama_chat_format.py | 16 ++++++++++++++++
 llama_cpp/server/app.py        | 10 ++++++++++
 3 files changed, 54 insertions(+), 2 deletions(-)

diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index 00f2bcb..adb767f 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -1046,6 +1046,8 @@ class Llama:
         self,
         top_k: int = 40,
         top_p: float = 0.95,
+        min_p: float = 0.05,
+        typical_p: float = 1.0,
         temp: float = 0.80,
         repeat_penalty: float = 1.1,
         frequency_penalty: float = 0.0,
@@ -1108,7 +1110,10 @@ class Llama:
             grammar=grammar,
         )
 
-        if temp == 0.0:
+        if temp < 0.0:
+            self._ctx.sample_softmax(candidates=self._candidates)
+            id = self._candidates.candidates.data[0].id
+        elif temp == 0.0:
             id = self._ctx.sample_token_greedy(candidates=self._candidates)
         elif mirostat_mode == 1:
             self._ctx.sample_temp(candidates=self._candidates, temp=temp)
@@ -1130,8 +1135,9 @@ class Llama:
         else:
             self._ctx.sample_top_k(candidates=self._candidates, k=top_k, min_keep=1)
             self._ctx.sample_tail_free(candidates=self._candidates, z=tfs_z, min_keep=1)
-            self._ctx.sample_typical(candidates=self._candidates, p=1.0, min_keep=1)
+            self._ctx.sample_typical(candidates=self._candidates, p=typical_p, min_keep=1)
             self._ctx.sample_top_p(candidates=self._candidates, p=top_p, min_keep=1)
+            self._ctx.sample_min_p(candidates=self._candidates, p=min_p, min_keep=1)
             self._ctx.sample_temp(candidates=self._candidates, temp=temp)
             id = self._ctx.sample_token(candidates=self._candidates)
         if grammar is not None:
@@ -1143,6 +1149,8 @@ class Llama:
         tokens: Sequence[int],
         top_k: int = 40,
         top_p: float = 0.95,
+        min_p: float = 0.05,
+        typical_p: float = 1.0,
         temp: float = 0.80,
         repeat_penalty: float = 1.1,
         reset: bool = True,
@@ -1200,6 +1208,8 @@ class Llama:
             token = self.sample(
                 top_k=top_k,
                 top_p=top_p,
+                min_p=min_p,
+                typical_p=typical_p,
                 temp=temp,
                 repeat_penalty=repeat_penalty,
                 frequency_penalty=frequency_penalty,
@@ -1298,6 +1308,8 @@ class Llama:
         max_tokens: Optional[int] = 16,
         temperature: float = 0.8,
         top_p: float = 0.95,
+        min_p: float = 0.05,
+        typical_p: float = 1.0,
         logprobs: Optional[int] = None,
         echo: bool = False,
         stop: Optional[Union[str, List[str]]] = [],
@@ -1398,6 +1410,8 @@ class Llama:
             prompt_tokens,
             top_k=top_k,
             top_p=top_p,
+            min_p=min_p,
+            typical_p=typical_p,
             temp=temperature,
             tfs_z=tfs_z,
             mirostat_mode=mirostat_mode,
@@ -1772,6 +1786,8 @@ class Llama:
         max_tokens: Optional[int] = 16,
         temperature: float = 0.8,
         top_p: float = 0.95,
+        min_p: float = 0.05,
+        typical_p: float = 1.0,
         logprobs: Optional[int] = None,
         echo: bool = False,
         stop: Optional[Union[str, List[str]]] = [],
@@ -1818,6 +1834,8 @@ class Llama:
             max_tokens=max_tokens,
             temperature=temperature,
             top_p=top_p,
+            min_p=min_p,
+            typical_p=typical_p,
             logprobs=logprobs,
             echo=echo,
             stop=stop,
@@ -1849,6 +1867,8 @@ class Llama:
         max_tokens: int = 128,
         temperature: float = 0.8,
         top_p: float = 0.95,
+        min_p: float = 0.05,
+        typical_p: float = 1.0,
         logprobs: Optional[int] = None,
         echo: bool = False,
         stop: Optional[Union[str, List[str]]] = [],
@@ -1895,6 +1915,8 @@ class Llama:
             max_tokens=max_tokens,
             temperature=temperature,
             top_p=top_p,
+            min_p=min_p,
+            typical_p=typical_p,
             logprobs=logprobs,
             echo=echo,
             stop=stop,
@@ -1924,6 +1946,8 @@ class Llama:
         temperature: float = 0.2,
         top_p: float = 0.95,
         top_k: int = 40,
+        min_p: float = 0.05,
+        typical_p: float = 1.0,
         stream: bool = False,
         stop: Optional[Union[str, List[str]]] = [],
         seed: Optional[int] = None,
@@ -1970,6 +1994,8 @@ class Llama:
             temperature=temperature,
             top_p=top_p,
             top_k=top_k,
+            min_p=min_p,
+            typical_p=typical_p,
             stream=stream,
             stop=stop,
             seed=seed,
diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py
index e0b98f7..efab0b0 100644
--- a/llama_cpp/llama_chat_format.py
+++ b/llama_cpp/llama_chat_format.py
@@ -26,6 +26,8 @@ class LlamaChatCompletionHandler(Protocol):
         temperature: float = 0.2,
         top_p: float = 0.95,
         top_k: int = 40,
+        min_p: float = 0.05,
+        typical_p: float = 1.0,
         stream: bool = False,
         stop: Optional[Union[str, List[str]]] = [],
         seed: Optional[int] = None,
@@ -287,6 +289,8 @@ def register_chat_format(name: str):
         temperature: float = 0.2,
         top_p: float = 0.95,
         top_k: int = 40,
+        min_p: float = 0.05,
+        typical_p: float = 1.0,
         stream: bool = False,
         stop: Optional[Union[str, List[str]]] = [],
         seed: Optional[int] = None,
@@ -330,6 +334,8 @@ def register_chat_format(name: str):
             temperature=temperature,
             top_p=top_p,
             top_k=top_k,
+            min_p=min_p,
+            typical_p=typical_p,
             stream=stream,
             stop=stop,
             seed=seed,
@@ -579,6 +585,8 @@ def functionary_chat_handler(
     temperature: float = 0.2,
     top_p: float = 0.95,
     top_k: int = 40,
+    min_p: float = 0.05,
+    typical_p: float = 1.0,
     stream: bool = False,
     stop: Optional[Union[str, List[str]]] = [],
     response_format: Optional[llama_types.ChatCompletionRequestResponseFormat] = None,
@@ -761,6 +769,8 @@ def functionary_chat_handler(
         temperature=temperature,
         top_p=top_p,
         top_k=top_k,
+        min_p=min_p,
+        typical_p=typical_p,
         stream=stream,
         stop=["user:", "</s>"],
         max_tokens=max_tokens,
@@ -831,6 +841,8 @@ def functionary_chat_handler(
         temperature=temperature,
         top_p=top_p,
         top_k=top_k,
+        min_p=min_p,
+        typical_p=typical_p,
         presence_penalty=presence_penalty,
         frequency_penalty=frequency_penalty,
         repeat_penalty=repeat_penalty,
@@ -929,6 +941,8 @@ class Llava15ChatHandler:
         temperature: float = 0.2,
         top_p: float = 0.95,
         top_k: int = 40,
+        min_p: float = 0.05,
+        typical_p: float = 1.0,
         stream: bool = False,
         stop: Optional[Union[str, List[str]]] = [],
         response_format: Optional[
@@ -1045,6 +1059,8 @@ class Llava15ChatHandler:
             temperature=temperature,
             top_p=top_p,
             top_k=top_k,
+            min_p=min_p,
+            typical_p=typical_p,
             stream=stream,
             stop=stop,
             max_tokens=max_tokens,
diff --git a/llama_cpp/server/app.py b/llama_cpp/server/app.py
index 2a6aed8..b39a462 100644
--- a/llama_cpp/server/app.py
+++ b/llama_cpp/server/app.py
@@ -521,6 +521,14 @@ top_p_field = Field(
     + "Top-p sampling, also known as nucleus sampling, is another text generation method that selects the next token from a subset of tokens that together have a cumulative probability of at least p. This method provides a balance between diversity and quality by considering both the probabilities of tokens and the number of tokens to sample from. A higher value for top_p (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text.",
 )
 
+min_p_field = Field(
+    default=0.05,
+    ge=0.0,
+    le=1.0,
+    description="Sets a minimum base probability threshold for token selection.\n\n"
+    + "The Min-P sampling method was designed as an alternative to Top-P, and aims to ensure a balance of quality and variety. The parameter min_p represents the minimum probability for a token to be considered, relative to the probability of the most likely token. For example, with min_p=0.05 and the most likely token having a probability of 0.9, logits with a value less than 0.045 are filtered out.",
+)
+
 stop_field = Field(
     default=None,
     description="A list of tokens at which to stop generation. If None, no stop tokens are used.",
@@ -593,6 +601,7 @@ class CreateCompletionRequest(BaseModel):
     max_tokens: int = max_tokens_field
     temperature: float = temperature_field
     top_p: float = top_p_field
+    min_p: float = min_p_field
     echo: bool = Field(
         default=False,
         description="Whether to echo the prompt in the generated text. Useful for chatbots.",
@@ -788,6 +797,7 @@ class CreateChatCompletionRequest(BaseModel):
     )
     temperature: float = temperature_field
     top_p: float = top_p_field
+    min_p: float = min_p_field
     stop: Optional[List[str]] = stop_field
     stream: bool = stream_field
     presence_penalty: Optional[float] = presence_penalty_field
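
Note: the sketch below is not part of the patch or of the repository. It is a minimal illustration of the min-p rule described by min_p_field above, plus an assumed call through the high-level API once the patch is applied; the min_p_filter helper and the model path are illustrative placeholders.

# Minimal sketch of the min-p rule on already-normalized probabilities:
# keep tokens whose probability is at least min_p times the probability
# of the most likely token.
def min_p_filter(probs, min_p=0.05):
    threshold = min_p * max(probs)  # e.g. 0.05 * 0.9 = 0.045
    return [(i, p) for i, p in enumerate(probs) if p >= threshold]

print(min_p_filter([0.9, 0.06, 0.04]))  # [(0, 0.9), (1, 0.06)]; 0.04 falls below 0.045

# Assumed usage of the new parameters through the high-level API added by
# this patch (model path is a placeholder):
from llama_cpp import Llama

llm = Llama(model_path="./model.gguf")
out = llm("Q: Name the planets in the solar system. A:", max_tokens=64, min_p=0.05, typical_p=1.0)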