Add OpenAI frequency and presence penalty parameters. Closes #169

Andrei Betlen 2023-05-08 01:30:18 -04:00
parent 75d8619b1a
commit 65d9cc050c
2 changed files with 36 additions and 6 deletions
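
The parameters follow the OpenAI API semantics implemented natively by llama.cpp's llama_sample_frequency_and_presence_penalties: the frequency penalty scales with how many times a token has already been generated, while the presence penalty is a flat cost once a token has appeared at all. A minimal pure-Python sketch of that math (illustrative only, not part of this commit):

    # Illustrative only -- pure-Python model of the OpenAI-style penalties
    # that llama.cpp applies natively to the logits.
    from collections import Counter
    from typing import Dict, Sequence

    def apply_penalties(
        logits: Dict[int, float],
        last_tokens: Sequence[int],
        alpha_frequency: float,
        alpha_presence: float,
    ) -> Dict[int, float]:
        counts = Counter(last_tokens)
        for token, count in counts.items():
            if token in logits:
                # frequency term grows with each repetition; presence term
                # is a one-time cost for the token having appeared at all
                logits[token] -= count * alpha_frequency + alpha_presence
        return logits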

@@ -261,7 +261,7 @@ class Llama:
         ]
         self.eval_logits.extend(logits)

-    def _sample_top_p_top_k(
+    def _sample(
         self,
         last_n_tokens_data,  # type: llama_cpp.Array[llama_cpp.llama_token]
         last_n_tokens_size: llama_cpp.c_int,
@@ -269,6 +269,8 @@
         top_p: llama_cpp.c_float,
         temp: llama_cpp.c_float,
         repeat_penalty: llama_cpp.c_float,
+        frequency_penalty: llama_cpp.c_float,
+        presence_penalty: llama_cpp.c_float,
     ):
         assert self.ctx is not None
         assert len(self.eval_logits) > 0
@@ -298,6 +300,14 @@
             candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
             penalty=repeat_penalty,
         )
+        llama_cpp.llama_sample_frequency_and_presence_penalties(
+            ctx=self.ctx,
+            candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
+            last_tokens_data=last_n_tokens_data,
+            last_tokens_size=last_n_tokens_size,
+            alpha_frequency=frequency_penalty,
+            alpha_presence=presence_penalty,
+        )
         if float(temp.value) == 0.0:
             return llama_cpp.llama_sample_token_greedy(
                 ctx=self.ctx,
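
The new call operates on the same candidates array immediately after the existing multiplicative repeat penalty, so the two adjustments compose; with the 0.0 defaults it is a no-op and existing callers see unchanged behavior. A tiny worked example with hypothetical values, assuming a token seen 3 times in the last_n window:

    logit = 2.2
    logit /= 1.1            # repeat penalty: llama.cpp divides positive logits
    logit -= 3 * 0.5 + 0.4  # frequency (0.5 per occurrence) + presence (flat 0.4)
    print(round(logit, 3))  # 0.1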
@@ -344,6 +354,8 @@
         top_p: float,
         temp: float,
         repeat_penalty: float,
+        frequency_penalty: float = 0.0,
+        presence_penalty: float = 0.0,
     ):
         """Sample a token from the model.

@@ -360,7 +372,7 @@
         last_n_tokens_data = [llama_cpp.llama_token(0)] * max(
             0, self.last_n_tokens_size - len(self.eval_tokens)
         ) + list(self.eval_tokens)[-self.last_n_tokens_size :]
-        return self._sample_top_p_top_k(
+        return self._sample(
             last_n_tokens_data=(llama_cpp.llama_token * self.last_n_tokens_size)(
                 *last_n_tokens_data
             ),
@@ -369,6 +381,8 @@
             top_p=llama_cpp.c_float(top_p),
             temp=llama_cpp.c_float(temp),
             repeat_penalty=llama_cpp.c_float(repeat_penalty),
+            frequency_penalty=llama_cpp.c_float(frequency_penalty),
+            presence_penalty=llama_cpp.c_float(presence_penalty),
         )

     def generate(
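
A hedged usage sketch of the updated high-level sample() method; the model path and prompt are placeholders, and omitting the new kwargs keeps the previous behavior thanks to the 0.0 defaults:

    llm = Llama(model_path="./models/7B/ggml-model.bin")  # placeholder path
    llm.eval(llm.tokenize(b"Q: Name the planets in the solar system. A:"))
    token = llm.sample(
        top_k=40,
        top_p=0.95,
        temp=0.8,
        repeat_penalty=1.1,
        frequency_penalty=0.5,  # new in this commit
        presence_penalty=0.5,   # new in this commit
    )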
@@ -378,6 +392,8 @@
         top_p: float,
         temp: float,
         repeat_penalty: float,
+        frequency_penalty: float = 0.0,
+        presence_penalty: float = 0.0,
         reset: bool = True,
     ) -> Generator[
         llama_cpp.llama_token, Optional[Sequence[llama_cpp.llama_token]], None
@@ -431,6 +447,8 @@
                 top_k=top_k,
                 top_p=top_p,
                 temp=temp,
+                frequency_penalty=frequency_penalty,
+                presence_penalty=presence_penalty,
                 repeat_penalty=repeat_penalty,
             )
             tokens_or_none = yield token
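
The generator path now threads the penalties into every sampling step; a hedged sketch, with prompt and values purely illustrative:

    for token in llm.generate(
        llm.tokenize(b"One, two,"),
        top_k=40,
        top_p=0.95,
        temp=0.8,
        repeat_penalty=1.1,
        frequency_penalty=0.7,
        presence_penalty=0.1,
    ):
        if token == llama_cpp.llama_token_eos():
            break
        print(llm.detokenize([token]).decode("utf-8", errors="ignore"), end="")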
@@ -505,6 +523,8 @@
         logprobs: Optional[int] = None,
         echo: bool = False,
         stop: Optional[List[str]] = [],
+        frequency_penalty: float = 0.0,
+        presence_penalty: float = 0.0,
         repeat_penalty: float = 1.1,
         top_k: int = 40,
         stream: bool = False,
@@ -563,6 +583,8 @@
             top_k=top_k,
             top_p=top_p,
             temp=temperature,
+            frequency_penalty=frequency_penalty,
+            presence_penalty=presence_penalty,
             repeat_penalty=repeat_penalty,
         ):
             if token == llama_cpp.llama_token_eos():
@@ -737,6 +759,8 @@
         logprobs: Optional[int] = None,
         echo: bool = False,
         stop: Optional[List[str]] = [],
+        frequency_penalty: float = 0.0,
+        presence_penalty: float = 0.0,
         repeat_penalty: float = 1.1,
         top_k: int = 40,
         stream: bool = False,
@@ -772,6 +796,8 @@
             logprobs=logprobs,
             echo=echo,
             stop=stop,
+            frequency_penalty=frequency_penalty,
+            presence_penalty=presence_penalty,
             repeat_penalty=repeat_penalty,
             top_k=top_k,
             stream=stream,
@@ -792,6 +818,8 @@
         logprobs: Optional[int] = None,
         echo: bool = False,
         stop: Optional[List[str]] = [],
+        frequency_penalty: float = 0.0,
+        presence_penalty: float = 0.0,
         repeat_penalty: float = 1.1,
         top_k: int = 40,
         stream: bool = False,
@@ -827,6 +855,8 @@
             logprobs=logprobs,
             echo=echo,
             stop=stop,
+            frequency_penalty=frequency_penalty,
+            presence_penalty=presence_penalty,
             repeat_penalty=repeat_penalty,
             top_k=top_k,
             stream=stream,
@@ -899,6 +929,8 @@
         stream: bool = False,
         stop: Optional[List[str]] = [],
         max_tokens: int = 256,
+        presence_penalty: float = 0.0,
+        frequency_penalty: float = 0.0,
         repeat_penalty: float = 1.1,
     ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
         """Generate a chat completion from a list of messages.
@@ -932,6 +964,8 @@
             stream=stream,
             max_tokens=max_tokens,
             repeat_penalty=repeat_penalty,
+            presence_penalty=presence_penalty,
+            frequency_penalty=frequency_penalty,
         )
         if stream:
             chunks: Iterator[CompletionChunk] = completion_or_chunks  # type: ignore
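
End to end, the two parameters now flow from the OpenAI-compatible surface down to the sampler. A hedged sketch against the chat API, with message content and values purely illustrative:

    result = llm.create_chat_completion(
        messages=[{"role": "user", "content": "List five fruits."}],
        max_tokens=64,
        presence_penalty=0.6,
        frequency_penalty=0.6,
    )
    print(result["choices"][0]["message"]["content"])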

@@ -214,8 +214,6 @@ def create_completion(
             exclude={
                 "model",
                 "n",
-                "frequency_penalty",
-                "presence_penalty",
                 "best_of",
                 "logit_bias",
                 "user",
@@ -315,8 +313,6 @@ def create_chat_completion(
             exclude={
                 "model",
                 "n",
-                "presence_penalty",
-                "frequency_penalty",
                 "logit_bias",
                 "user",
             }
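
On the server side, removing the two fields from the exclude sets means they are forwarded to the Llama call instead of being silently dropped. A hedged client sketch; host, port, and route follow the project's usual defaults and are assumptions here:

    import requests

    resp = requests.post(
        "http://localhost:8000/v1/completions",
        json={
            "prompt": "Q: Name the planets in the solar system. A:",
            "max_tokens": 64,
            "frequency_penalty": 0.5,  # previously stripped by the server
            "presence_penalty": 0.5,
        },
    )
    print(resp.json()["choices"][0]["text"])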