Add OpenAI frequency and presence penalty parameters. Closes #169

This commit is contained in:
Andrei Betlen 2023-05-08 01:30:18 -04:00
parent 75d8619b1a
commit 65d9cc050c
2 changed files with 36 additions and 6 deletions

View file

@ -261,7 +261,7 @@ class Llama:
] ]
self.eval_logits.extend(logits) self.eval_logits.extend(logits)
def _sample_top_p_top_k( def _sample(
self, self,
last_n_tokens_data, # type: llama_cpp.Array[llama_cpp.llama_token] last_n_tokens_data, # type: llama_cpp.Array[llama_cpp.llama_token]
last_n_tokens_size: llama_cpp.c_int, last_n_tokens_size: llama_cpp.c_int,
@ -269,6 +269,8 @@ class Llama:
top_p: llama_cpp.c_float, top_p: llama_cpp.c_float,
temp: llama_cpp.c_float, temp: llama_cpp.c_float,
repeat_penalty: llama_cpp.c_float, repeat_penalty: llama_cpp.c_float,
frequency_penalty: llama_cpp.c_float,
presence_penalty: llama_cpp.c_float,
): ):
assert self.ctx is not None assert self.ctx is not None
assert len(self.eval_logits) > 0 assert len(self.eval_logits) > 0
@ -298,6 +300,14 @@ class Llama:
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
penalty=repeat_penalty, penalty=repeat_penalty,
) )
llama_cpp.llama_sample_frequency_and_presence_penalties(
ctx=self.ctx,
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
last_tokens_data=last_n_tokens_data,
last_tokens_size=last_n_tokens_size,
alpha_frequency=frequency_penalty,
alpha_presence=presence_penalty,
)
if float(temp.value) == 0.0: if float(temp.value) == 0.0:
return llama_cpp.llama_sample_token_greedy( return llama_cpp.llama_sample_token_greedy(
ctx=self.ctx, ctx=self.ctx,
@ -344,6 +354,8 @@ class Llama:
top_p: float, top_p: float,
temp: float, temp: float,
repeat_penalty: float, repeat_penalty: float,
frequency_penalty: float = 0.0,
presence_penalty: float = 0.0,
): ):
"""Sample a token from the model. """Sample a token from the model.
@ -360,7 +372,7 @@ class Llama:
last_n_tokens_data = [llama_cpp.llama_token(0)] * max( last_n_tokens_data = [llama_cpp.llama_token(0)] * max(
0, self.last_n_tokens_size - len(self.eval_tokens) 0, self.last_n_tokens_size - len(self.eval_tokens)
) + list(self.eval_tokens)[-self.last_n_tokens_size :] ) + list(self.eval_tokens)[-self.last_n_tokens_size :]
return self._sample_top_p_top_k( return self._sample(
last_n_tokens_data=(llama_cpp.llama_token * self.last_n_tokens_size)( last_n_tokens_data=(llama_cpp.llama_token * self.last_n_tokens_size)(
*last_n_tokens_data *last_n_tokens_data
), ),
@ -369,6 +381,8 @@ class Llama:
top_p=llama_cpp.c_float(top_p), top_p=llama_cpp.c_float(top_p),
temp=llama_cpp.c_float(temp), temp=llama_cpp.c_float(temp),
repeat_penalty=llama_cpp.c_float(repeat_penalty), repeat_penalty=llama_cpp.c_float(repeat_penalty),
frequency_penalty=llama_cpp.c_float(frequency_penalty),
presence_penalty=llama_cpp.c_float(presence_penalty),
) )
def generate( def generate(
@ -378,6 +392,8 @@ class Llama:
top_p: float, top_p: float,
temp: float, temp: float,
repeat_penalty: float, repeat_penalty: float,
frequency_penalty: float = 0.0,
presence_penalty: float = 0.0,
reset: bool = True, reset: bool = True,
) -> Generator[ ) -> Generator[
llama_cpp.llama_token, Optional[Sequence[llama_cpp.llama_token]], None llama_cpp.llama_token, Optional[Sequence[llama_cpp.llama_token]], None
@ -431,6 +447,8 @@ class Llama:
top_k=top_k, top_k=top_k,
top_p=top_p, top_p=top_p,
temp=temp, temp=temp,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
repeat_penalty=repeat_penalty, repeat_penalty=repeat_penalty,
) )
tokens_or_none = yield token tokens_or_none = yield token
@ -505,6 +523,8 @@ class Llama:
logprobs: Optional[int] = None, logprobs: Optional[int] = None,
echo: bool = False, echo: bool = False,
stop: Optional[List[str]] = [], stop: Optional[List[str]] = [],
frequency_penalty: float = 0.0,
presence_penalty: float = 0.0,
repeat_penalty: float = 1.1, repeat_penalty: float = 1.1,
top_k: int = 40, top_k: int = 40,
stream: bool = False, stream: bool = False,
@ -563,6 +583,8 @@ class Llama:
top_k=top_k, top_k=top_k,
top_p=top_p, top_p=top_p,
temp=temperature, temp=temperature,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
repeat_penalty=repeat_penalty, repeat_penalty=repeat_penalty,
): ):
if token == llama_cpp.llama_token_eos(): if token == llama_cpp.llama_token_eos():
@ -737,6 +759,8 @@ class Llama:
logprobs: Optional[int] = None, logprobs: Optional[int] = None,
echo: bool = False, echo: bool = False,
stop: Optional[List[str]] = [], stop: Optional[List[str]] = [],
frequency_penalty: float = 0.0,
presence_penalty: float = 0.0,
repeat_penalty: float = 1.1, repeat_penalty: float = 1.1,
top_k: int = 40, top_k: int = 40,
stream: bool = False, stream: bool = False,
@ -772,6 +796,8 @@ class Llama:
logprobs=logprobs, logprobs=logprobs,
echo=echo, echo=echo,
stop=stop, stop=stop,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
repeat_penalty=repeat_penalty, repeat_penalty=repeat_penalty,
top_k=top_k, top_k=top_k,
stream=stream, stream=stream,
@ -792,6 +818,8 @@ class Llama:
logprobs: Optional[int] = None, logprobs: Optional[int] = None,
echo: bool = False, echo: bool = False,
stop: Optional[List[str]] = [], stop: Optional[List[str]] = [],
frequency_penalty: float = 0.0,
presence_penalty: float = 0.0,
repeat_penalty: float = 1.1, repeat_penalty: float = 1.1,
top_k: int = 40, top_k: int = 40,
stream: bool = False, stream: bool = False,
@ -827,6 +855,8 @@ class Llama:
logprobs=logprobs, logprobs=logprobs,
echo=echo, echo=echo,
stop=stop, stop=stop,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
repeat_penalty=repeat_penalty, repeat_penalty=repeat_penalty,
top_k=top_k, top_k=top_k,
stream=stream, stream=stream,
@ -899,6 +929,8 @@ class Llama:
stream: bool = False, stream: bool = False,
stop: Optional[List[str]] = [], stop: Optional[List[str]] = [],
max_tokens: int = 256, max_tokens: int = 256,
presence_penalty: float = 0.0,
frequency_penalty: float = 0.0,
repeat_penalty: float = 1.1, repeat_penalty: float = 1.1,
) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]: ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
"""Generate a chat completion from a list of messages. """Generate a chat completion from a list of messages.
@ -932,6 +964,8 @@ class Llama:
stream=stream, stream=stream,
max_tokens=max_tokens, max_tokens=max_tokens,
repeat_penalty=repeat_penalty, repeat_penalty=repeat_penalty,
presence_penalty=presence_penalty,
frequency_penalty=frequency_penalty,
) )
if stream: if stream:
chunks: Iterator[CompletionChunk] = completion_or_chunks # type: ignore chunks: Iterator[CompletionChunk] = completion_or_chunks # type: ignore

View file

@ -214,8 +214,6 @@ def create_completion(
exclude={ exclude={
"model", "model",
"n", "n",
"frequency_penalty",
"presence_penalty",
"best_of", "best_of",
"logit_bias", "logit_bias",
"user", "user",
@ -315,8 +313,6 @@ def create_chat_completion(
exclude={ exclude={
"model", "model",
"n", "n",
"presence_penalty",
"frequency_penalty",
"logit_bias", "logit_bias",
"user", "user",
} }