Add openai frequency and presence penalty parameters. Closes #169
This commit is contained in:
parent
75d8619b1a
commit
65d9cc050c
2 changed files with 36 additions and 6 deletions
|
@ -261,7 +261,7 @@ class Llama:
|
|||
]
|
||||
self.eval_logits.extend(logits)
|
||||
|
||||
def _sample_top_p_top_k(
|
||||
def _sample(
|
||||
self,
|
||||
last_n_tokens_data, # type: llama_cpp.Array[llama_cpp.llama_token]
|
||||
last_n_tokens_size: llama_cpp.c_int,
|
||||
|
@ -269,6 +269,8 @@ class Llama:
|
|||
top_p: llama_cpp.c_float,
|
||||
temp: llama_cpp.c_float,
|
||||
repeat_penalty: llama_cpp.c_float,
|
||||
frequency_penalty: llama_cpp.c_float,
|
||||
presence_penalty: llama_cpp.c_float,
|
||||
):
|
||||
assert self.ctx is not None
|
||||
assert len(self.eval_logits) > 0
|
||||
|
@ -298,6 +300,14 @@ class Llama:
|
|||
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
|
||||
penalty=repeat_penalty,
|
||||
)
|
||||
llama_cpp.llama_sample_frequency_and_presence_penalties(
|
||||
ctx=self.ctx,
|
||||
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
|
||||
last_tokens_data=last_n_tokens_data,
|
||||
last_tokens_size=last_n_tokens_size,
|
||||
alpha_frequency=frequency_penalty,
|
||||
alpha_presence=presence_penalty,
|
||||
)
|
||||
if float(temp.value) == 0.0:
|
||||
return llama_cpp.llama_sample_token_greedy(
|
||||
ctx=self.ctx,
|
||||
|
@ -344,6 +354,8 @@ class Llama:
|
|||
top_p: float,
|
||||
temp: float,
|
||||
repeat_penalty: float,
|
||||
frequency_penalty: float = 0.0,
|
||||
presence_penalty: float = 0.0,
|
||||
):
|
||||
"""Sample a token from the model.
|
||||
|
||||
|
@ -360,7 +372,7 @@ class Llama:
|
|||
last_n_tokens_data = [llama_cpp.llama_token(0)] * max(
|
||||
0, self.last_n_tokens_size - len(self.eval_tokens)
|
||||
) + list(self.eval_tokens)[-self.last_n_tokens_size :]
|
||||
return self._sample_top_p_top_k(
|
||||
return self._sample(
|
||||
last_n_tokens_data=(llama_cpp.llama_token * self.last_n_tokens_size)(
|
||||
*last_n_tokens_data
|
||||
),
|
||||
|
@ -369,6 +381,8 @@ class Llama:
|
|||
top_p=llama_cpp.c_float(top_p),
|
||||
temp=llama_cpp.c_float(temp),
|
||||
repeat_penalty=llama_cpp.c_float(repeat_penalty),
|
||||
frequency_penalty=llama_cpp.c_float(frequency_penalty),
|
||||
presence_penalty=llama_cpp.c_float(presence_penalty),
|
||||
)
|
||||
|
||||
def generate(
|
||||
|
@ -378,6 +392,8 @@ class Llama:
|
|||
top_p: float,
|
||||
temp: float,
|
||||
repeat_penalty: float,
|
||||
frequency_penalty: float = 0.0,
|
||||
presence_penalty: float = 0.0,
|
||||
reset: bool = True,
|
||||
) -> Generator[
|
||||
llama_cpp.llama_token, Optional[Sequence[llama_cpp.llama_token]], None
|
||||
|
@ -431,6 +447,8 @@ class Llama:
|
|||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
temp=temp,
|
||||
frequency_penalty=frequency_penalty,
|
||||
presence_penalty=presence_penalty,
|
||||
repeat_penalty=repeat_penalty,
|
||||
)
|
||||
tokens_or_none = yield token
|
||||
|
@ -505,6 +523,8 @@ class Llama:
|
|||
logprobs: Optional[int] = None,
|
||||
echo: bool = False,
|
||||
stop: Optional[List[str]] = [],
|
||||
frequency_penalty: float = 0.0,
|
||||
presence_penalty: float = 0.0,
|
||||
repeat_penalty: float = 1.1,
|
||||
top_k: int = 40,
|
||||
stream: bool = False,
|
||||
|
@ -563,6 +583,8 @@ class Llama:
|
|||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
temp=temperature,
|
||||
frequency_penalty=frequency_penalty,
|
||||
presence_penalty=presence_penalty,
|
||||
repeat_penalty=repeat_penalty,
|
||||
):
|
||||
if token == llama_cpp.llama_token_eos():
|
||||
|
@ -737,6 +759,8 @@ class Llama:
|
|||
logprobs: Optional[int] = None,
|
||||
echo: bool = False,
|
||||
stop: Optional[List[str]] = [],
|
||||
frequency_penalty: float = 0.0,
|
||||
presence_penalty: float = 0.0,
|
||||
repeat_penalty: float = 1.1,
|
||||
top_k: int = 40,
|
||||
stream: bool = False,
|
||||
|
@ -772,6 +796,8 @@ class Llama:
|
|||
logprobs=logprobs,
|
||||
echo=echo,
|
||||
stop=stop,
|
||||
frequency_penalty=frequency_penalty,
|
||||
presence_penalty=presence_penalty,
|
||||
repeat_penalty=repeat_penalty,
|
||||
top_k=top_k,
|
||||
stream=stream,
|
||||
|
@ -792,6 +818,8 @@ class Llama:
|
|||
logprobs: Optional[int] = None,
|
||||
echo: bool = False,
|
||||
stop: Optional[List[str]] = [],
|
||||
frequency_penalty: float = 0.0,
|
||||
presence_penalty: float = 0.0,
|
||||
repeat_penalty: float = 1.1,
|
||||
top_k: int = 40,
|
||||
stream: bool = False,
|
||||
|
@ -827,6 +855,8 @@ class Llama:
|
|||
logprobs=logprobs,
|
||||
echo=echo,
|
||||
stop=stop,
|
||||
frequency_penalty=frequency_penalty,
|
||||
presence_penalty=presence_penalty,
|
||||
repeat_penalty=repeat_penalty,
|
||||
top_k=top_k,
|
||||
stream=stream,
|
||||
|
@ -899,6 +929,8 @@ class Llama:
|
|||
stream: bool = False,
|
||||
stop: Optional[List[str]] = [],
|
||||
max_tokens: int = 256,
|
||||
presence_penalty: float = 0.0,
|
||||
frequency_penalty: float = 0.0,
|
||||
repeat_penalty: float = 1.1,
|
||||
) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
|
||||
"""Generate a chat completion from a list of messages.
|
||||
|
@ -932,6 +964,8 @@ class Llama:
|
|||
stream=stream,
|
||||
max_tokens=max_tokens,
|
||||
repeat_penalty=repeat_penalty,
|
||||
presence_penalty=presence_penalty,
|
||||
frequency_penalty=frequency_penalty,
|
||||
)
|
||||
if stream:
|
||||
chunks: Iterator[CompletionChunk] = completion_or_chunks # type: ignore
|
||||
|
|
|
@ -214,8 +214,6 @@ def create_completion(
|
|||
exclude={
|
||||
"model",
|
||||
"n",
|
||||
"frequency_penalty",
|
||||
"presence_penalty",
|
||||
"best_of",
|
||||
"logit_bias",
|
||||
"user",
|
||||
|
@ -315,8 +313,6 @@ def create_chat_completion(
|
|||
exclude={
|
||||
"model",
|
||||
"n",
|
||||
"presence_penalty",
|
||||
"frequency_penalty",
|
||||
"logit_bias",
|
||||
"user",
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue