Add OpenAI frequency and presence penalty parameters. Closes #169
parent 75d8619b1a
commit 65d9cc050c
2 changed files with 36 additions and 6 deletions
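
Background: llama_cpp.llama_sample_frequency_and_presence_penalties, called in the diff below, applies the penalty rule described in the OpenAI API documentation: each candidate token's logit is reduced by count * alpha_frequency + (count > 0) * alpha_presence, where count is how often the token occurs in the recent context. A minimal, self-contained Python sketch of that rule, for illustration only (this is not the library's implementation):

    # Illustrative sketch of the OpenAI-style penalty rule; not the library's code.
    # logit[token] -= count[token] * alpha_frequency + (count[token] > 0) * alpha_presence
    from collections import Counter
    from typing import Dict, List

    def apply_penalties(
        logits: Dict[int, float],
        last_tokens: List[int],
        alpha_frequency: float,
        alpha_presence: float,
    ) -> Dict[int, float]:
        counts = Counter(last_tokens)  # occurrences of each token in the recent window
        penalized = dict(logits)
        for token, count in counts.items():
            if token in penalized:
                # The frequency penalty grows with the repeat count; the presence
                # penalty is a flat cost charged once the token has appeared at all.
                penalized[token] -= count * alpha_frequency + alpha_presence
        return penalized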
@@ -261,7 +261,7 @@ class Llama:
             ]
             self.eval_logits.extend(logits)

-    def _sample_top_p_top_k(
+    def _sample(
         self,
         last_n_tokens_data,  # type: llama_cpp.Array[llama_cpp.llama_token]
         last_n_tokens_size: llama_cpp.c_int,
@@ -269,6 +269,8 @@ class Llama:
         top_p: llama_cpp.c_float,
         temp: llama_cpp.c_float,
         repeat_penalty: llama_cpp.c_float,
+        frequency_penalty: llama_cpp.c_float,
+        presence_penalty: llama_cpp.c_float,
     ):
         assert self.ctx is not None
         assert len(self.eval_logits) > 0
@@ -298,6 +300,14 @@ class Llama:
             candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
             penalty=repeat_penalty,
         )
+        llama_cpp.llama_sample_frequency_and_presence_penalties(
+            ctx=self.ctx,
+            candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
+            last_tokens_data=last_n_tokens_data,
+            last_tokens_size=last_n_tokens_size,
+            alpha_frequency=frequency_penalty,
+            alpha_presence=presence_penalty,
+        )
         if float(temp.value) == 0.0:
             return llama_cpp.llama_sample_token_greedy(
                 ctx=self.ctx,
@@ -344,6 +354,8 @@ class Llama:
         top_p: float,
         temp: float,
         repeat_penalty: float,
+        frequency_penalty: float = 0.0,
+        presence_penalty: float = 0.0,
     ):
         """Sample a token from the model.

@@ -360,7 +372,7 @@ class Llama:
         last_n_tokens_data = [llama_cpp.llama_token(0)] * max(
             0, self.last_n_tokens_size - len(self.eval_tokens)
         ) + list(self.eval_tokens)[-self.last_n_tokens_size :]
-        return self._sample_top_p_top_k(
+        return self._sample(
             last_n_tokens_data=(llama_cpp.llama_token * self.last_n_tokens_size)(
                 *last_n_tokens_data
             ),
@@ -369,6 +381,8 @@ class Llama:
             top_p=llama_cpp.c_float(top_p),
             temp=llama_cpp.c_float(temp),
             repeat_penalty=llama_cpp.c_float(repeat_penalty),
+            frequency_penalty=llama_cpp.c_float(frequency_penalty),
+            presence_penalty=llama_cpp.c_float(presence_penalty),
         )

     def generate(
@@ -378,6 +392,8 @@ class Llama:
         top_p: float,
         temp: float,
         repeat_penalty: float,
+        frequency_penalty: float = 0.0,
+        presence_penalty: float = 0.0,
         reset: bool = True,
     ) -> Generator[
         llama_cpp.llama_token, Optional[Sequence[llama_cpp.llama_token]], None
@@ -431,6 +447,8 @@ class Llama:
                 top_k=top_k,
                 top_p=top_p,
                 temp=temp,
+                frequency_penalty=frequency_penalty,
+                presence_penalty=presence_penalty,
                 repeat_penalty=repeat_penalty,
             )
             tokens_or_none = yield token
@@ -505,6 +523,8 @@ class Llama:
         logprobs: Optional[int] = None,
         echo: bool = False,
         stop: Optional[List[str]] = [],
+        frequency_penalty: float = 0.0,
+        presence_penalty: float = 0.0,
         repeat_penalty: float = 1.1,
         top_k: int = 40,
         stream: bool = False,
@@ -563,6 +583,8 @@ class Llama:
             top_k=top_k,
             top_p=top_p,
             temp=temperature,
+            frequency_penalty=frequency_penalty,
+            presence_penalty=presence_penalty,
             repeat_penalty=repeat_penalty,
         ):
             if token == llama_cpp.llama_token_eos():
@@ -737,6 +759,8 @@ class Llama:
         logprobs: Optional[int] = None,
         echo: bool = False,
         stop: Optional[List[str]] = [],
+        frequency_penalty: float = 0.0,
+        presence_penalty: float = 0.0,
         repeat_penalty: float = 1.1,
         top_k: int = 40,
         stream: bool = False,
@@ -772,6 +796,8 @@ class Llama:
             logprobs=logprobs,
             echo=echo,
             stop=stop,
+            frequency_penalty=frequency_penalty,
+            presence_penalty=presence_penalty,
             repeat_penalty=repeat_penalty,
             top_k=top_k,
             stream=stream,
@@ -792,6 +818,8 @@ class Llama:
         logprobs: Optional[int] = None,
         echo: bool = False,
         stop: Optional[List[str]] = [],
+        frequency_penalty: float = 0.0,
+        presence_penalty: float = 0.0,
         repeat_penalty: float = 1.1,
         top_k: int = 40,
         stream: bool = False,
@@ -827,6 +855,8 @@ class Llama:
             logprobs=logprobs,
             echo=echo,
             stop=stop,
+            frequency_penalty=frequency_penalty,
+            presence_penalty=presence_penalty,
             repeat_penalty=repeat_penalty,
             top_k=top_k,
             stream=stream,
@@ -899,6 +929,8 @@ class Llama:
         stream: bool = False,
         stop: Optional[List[str]] = [],
         max_tokens: int = 256,
+        presence_penalty: float = 0.0,
+        frequency_penalty: float = 0.0,
         repeat_penalty: float = 1.1,
     ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
         """Generate a chat completion from a list of messages.
@@ -932,6 +964,8 @@ class Llama:
             stream=stream,
             max_tokens=max_tokens,
             repeat_penalty=repeat_penalty,
+            presence_penalty=presence_penalty,
+            frequency_penalty=frequency_penalty,
         )
         if stream:
             chunks: Iterator[CompletionChunk] = completion_or_chunks  # type: ignore
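
With the changes above, the two parameters are accepted end to end, from create_completion and create_chat_completion down to _sample. A hypothetical usage sketch (model path and prompt are placeholders, not part of this commit):

    # Hypothetical usage after this commit; model path and prompt are placeholders.
    from llama_cpp import Llama

    llm = Llama(model_path="./models/7B/ggml-model.bin")
    output = llm.create_completion(
        "Q: Name the planets in the solar system. A:",
        max_tokens=64,
        frequency_penalty=0.5,  # scale the penalty by how often a token has repeated
        presence_penalty=0.5,   # flat penalty for any token already generated
    )
    print(output["choices"][0]["text"])

The second changed file's diff follows.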
@@ -214,8 +214,6 @@ def create_completion(
             exclude={
                 "model",
                 "n",
-                "frequency_penalty",
-                "presence_penalty",
                 "best_of",
                 "logit_bias",
                 "user",
@@ -315,8 +313,6 @@ def create_chat_completion(
             exclude={
                 "model",
                 "n",
-                "presence_penalty",
-                "frequency_penalty",
                 "logit_bias",
                 "user",
             }
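
These two hunks drop frequency_penalty and presence_penalty from the server's exclude sets, so OpenAI-style clients can now pass them through to the model. A hypothetical client-side sketch, assuming the bundled OpenAI-compatible server is running locally (e.g. via python -m llama_cpp.server) on its default port 8000:

    # Hypothetical client sketch; endpoint and port reflect the server's defaults.
    import json
    import urllib.request

    payload = {
        "prompt": "Once upon a time",
        "max_tokens": 32,
        "frequency_penalty": 0.7,  # now forwarded to the model instead of excluded
        "presence_penalty": 0.3,
    }
    req = urllib.request.Request(
        "http://localhost:8000/v1/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(req) as resp:
        print(json.load(resp)["choices"][0]["text"])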