Implement sampling as in llama.cpp main example

Andrei Betlen 2023-05-08 21:21:25 -04:00
parent 93a9019bb1
commit d957422bf4

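Taken together, the hunks below restructure Llama._sample to follow the sampling loop in llama.cpp's main example: the repetition and frequency/presence penalties always run first, temp == 0.0 short-circuits to greedy argmax selection, the two mirostat modes become explicit branches that derive mu internally as 2.0 * tau (instead of accepting mirostat_mu and mirostat_m from the caller), and everything else falls through to the top-k, tail-free, typical, top-p, temperature chain. Below is a minimal pure-Python sketch of that dispatch, not the library code: it operates on a plain list of logits, implements only the greedy and mirostat-v2 branches, and uses a simple multinomial draw as a stand-in for the full filtering chain.

    import math
    import random

    def _softmax(logits):
        # Numerically stable softmax over a plain list of floats.
        m = max(logits)
        exps = [math.exp(x - m) for x in logits]
        total = sum(exps)
        return [e / total for e in exps]

    def sample_token(logits, temp=0.80, mirostat_mode=0,
                     mirostat_tau=5.0, mirostat_eta=0.1):
        # temp == 0.0 short-circuits to greedy argmax, as in the reworked _sample.
        if temp == 0.0:
            return max(range(len(logits)), key=logits.__getitem__)

        probs = _softmax([x / temp for x in logits])

        if mirostat_mode == 2:
            # mu starts at 2 * tau on every call; the diff below derives it the
            # same way instead of taking mirostat_mu as a parameter.
            mu = 2.0 * mirostat_tau
            # Mirostat v2 truncation: keep candidates whose surprise (-log2 p)
            # does not exceed mu, then renormalize and draw.
            kept = [(i, p) for i, p in enumerate(probs) if -math.log2(p) <= mu]
            if not kept:  # always keep at least the most likely token
                kept = [max(enumerate(probs), key=lambda ip: ip[1])]
            total = sum(p for _, p in kept)
            r = random.random() * total
            chosen = kept[-1][0]
            for i, p in kept:
                r -= p
                if r <= 0.0:
                    chosen = i
                    break
            # The learning step nudges mu toward the target surprise tau; it has
            # no lasting effect here because mu is rebuilt from tau on each call.
            mu -= mirostat_eta * (-math.log2(probs[chosen]) - mirostat_tau)
            return chosen

        # Fallback: plain multinomial draw, standing in for the library's
        # top-k -> tail-free -> typical -> top-p -> temperature chain.
        r = random.random()
        for i, p in enumerate(probs):
            r -= p
            if r <= 0.0:
                return i
        return len(probs) - 1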

@@ -268,14 +268,13 @@ class Llama:
         top_k: llama_cpp.c_int,
         top_p: llama_cpp.c_float,
         temp: llama_cpp.c_float,
-        mirostat_mode: llama_cpp.c_int,
-        mirostat_tau: llama_cpp.c_float,
-        mirostat_eta: llama_cpp.c_float,
-        mirostat_mu: llama_cpp.c_float,
-        mirostat_m: llama_cpp.c_int,
+        tfs_z: llama_cpp.c_float,
         repeat_penalty: llama_cpp.c_float,
         frequency_penalty: llama_cpp.c_float,
         presence_penalty: llama_cpp.c_float,
+        mirostat_mode: llama_cpp.c_int,
+        mirostat_tau: llama_cpp.c_float,
+        mirostat_eta: llama_cpp.c_float,
     ):
         assert self.ctx is not None
         assert len(self.eval_logits) > 0
@@ -305,33 +304,6 @@ class Llama:
             candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
             penalty=repeat_penalty,
         )
-        if mirostat_mode.value == 1:
-            llama_cpp.llama_sample_temperature(
-                ctx=self.ctx,
-                candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
-                temp=temp,
-            )
-            llama_cpp.llama_sample_token_mirostat(
-                ctx=self.ctx,
-                candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
-                tau=mirostat_tau,
-                eta=mirostat_eta,
-                mu=llama_cpp.ctypes.byref(mirostat_mu),  # type: ignore
-                m=mirostat_m
-            )
-        elif mirostat_mode.value == 2:
-            llama_cpp.llama_sample_temperature(
-                ctx=self.ctx,
-                candidates=llama_cpp.ctypes.pointer(candidates),
-                temp=temp,
-            )
-            llama_cpp.llama_sample_token_mirostat_v2(
-                ctx=self.ctx,
-                candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
-                tau=mirostat_tau,
-                eta=mirostat_eta,
-                mu=llama_cpp.ctypes.byref(mirostat_mu)  # type: ignore
-            )
         llama_cpp.llama_sample_frequency_and_presence_penalties(
             ctx=self.ctx,
             candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
@@ -340,11 +312,41 @@ class Llama:
             alpha_frequency=frequency_penalty,
             alpha_presence=presence_penalty,
         )
-        if float(temp.value) == 0.0:
+        if temp.value == 0.0:
             return llama_cpp.llama_sample_token_greedy(
                 ctx=self.ctx,
                 candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
             )
+        elif mirostat_mode.value == 1:
+            mirostat_mu = llama_cpp.c_float(2.0 * mirostat_tau.value)
+            mirostat_m = llama_cpp.c_int(100)
+            llama_cpp.llama_sample_temperature(
+                ctx=self.ctx,
+                candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
+                temp=temp,
+            )
+            return llama_cpp.llama_sample_token_mirostat(
+                ctx=self.ctx,
+                candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
+                tau=mirostat_tau,
+                eta=mirostat_eta,
+                mu=llama_cpp.ctypes.byref(mirostat_mu),  # type: ignore
+                m=mirostat_m,
+            )
+        elif mirostat_mode.value == 2:
+            mirostat_mu = llama_cpp.c_float(2.0 * mirostat_tau.value)
+            llama_cpp.llama_sample_temperature(
+                ctx=self.ctx,
+                candidates=llama_cpp.ctypes.pointer(candidates),
+                temp=temp,
+            )
+            return llama_cpp.llama_sample_token_mirostat_v2(
+                ctx=self.ctx,
+                candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
+                tau=mirostat_tau,
+                eta=mirostat_eta,
+                mu=llama_cpp.ctypes.byref(mirostat_mu),  # type: ignore
+            )
         else:
             llama_cpp.llama_sample_top_k(
                 ctx=self.ctx,
@@ -355,7 +357,7 @@ class Llama:
             llama_cpp.llama_sample_tail_free(
                 ctx=self.ctx,
                 candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
-                z=llama_cpp.c_float(1.0),
+                z=tfs_z,
                 min_keep=llama_cpp.c_size_t(1),
             )
             llama_cpp.llama_sample_typical(
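Two defaults are worth noting in these hunks. The removed mirostat_mu and mirostat_m keyword arguments survive as internal values: mu is re-derived as 2.0 * tau on every call and m is pinned at 100, so with the default tau of 5.0 the starting mu equals the old public default of 10. Similarly, tail-free sampling's z was previously hardcoded to 1.0 (which disables the filter) and is now the caller-supplied tfs_z with the same 1.0 default. The arithmetic, spelled out:

    mirostat_tau = 5.0                # new default exposed on sample()
    mirostat_mu = 2.0 * mirostat_tau  # == 10.0, the removed mirostat_mu default
    mirostat_m = 100                  # pinned constant, the removed mirostat_m default
    tfs_z = 1.0                       # new default; 1.0 leaves tail-free sampling a no-op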
@@ -382,17 +384,16 @@ class Llama:

     def sample(
         self,
-        top_k: int,
-        top_p: float,
-        temp: float,
-        mirostat_mode: int,
-        mirostat_tau: float,
-        mirostat_eta: float,
-        mirostat_mu: float,
-        mirostat_m: int,
-        repeat_penalty: float,
+        top_k: int = 40,
+        top_p: float = 0.95,
+        temp: float = 0.80,
+        repeat_penalty: float = 1.1,
         frequency_penalty: float = 0.0,
         presence_penalty: float = 0.0,
+        tfs_z: float = 1.0,
+        mirostat_mode: int = 0,
+        mirostat_eta: float = 0.1,
+        mirostat_tau: float = 5.0,
     ):
         """Sample a token from the model.

@@ -417,14 +418,13 @@ class Llama:
             top_k=llama_cpp.c_int(top_k),
             top_p=llama_cpp.c_float(top_p),
             temp=llama_cpp.c_float(temp),
-            mirostat_mode=llama_cpp.c_int(mirostat_mode),
-            mirostat_mu=llama_cpp.c_float(mirostat_mu),
-            mirostat_tau=llama_cpp.c_float(mirostat_tau),
-            mirostat_eta=llama_cpp.c_float(mirostat_eta),
-            mirostat_m=llama_cpp.c_int(mirostat_m),
+            tfs_z=llama_cpp.c_float(tfs_z),
             repeat_penalty=llama_cpp.c_float(repeat_penalty),
             frequency_penalty=llama_cpp.c_float(frequency_penalty),
             presence_penalty=llama_cpp.c_float(presence_penalty),
+            mirostat_mode=llama_cpp.c_int(mirostat_mode),
+            mirostat_tau=llama_cpp.c_float(mirostat_tau),
+            mirostat_eta=llama_cpp.c_float(mirostat_eta),
         )

     def generate(
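With every sampling parameter now carrying a default, sample() can be called bare. A hypothetical usage sketch (the model path is a placeholder, and eval() must already have been run on a prompt):

    llm = Llama(model_path="./models/7B/ggml-model.bin")  # placeholder path
    llm.eval(llm.tokenize(b"The quick brown fox"))

    token = llm.sample()                 # default top-k/top-p/temperature chain
    token = llm.sample(temp=0.0)         # greedy argmax
    token = llm.sample(mirostat_mode=2)  # mirostat v2 with tau=5.0, eta=0.1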
@@ -433,15 +433,13 @@ class Llama:
         top_k: int,
         top_p: float,
         temp: float,
-        mirostat_mode: int,
-        mirostat_tau: float,
-        mirostat_eta: float,
-        mirostat_mu: float,
-        mirostat_m: int,
         repeat_penalty: float,
-        reset: bool = True,
         frequency_penalty: float = 0.0,
         presence_penalty: float = 0.0,
+        reset: bool = True,
+        mirostat_mode: int = 0,
+        mirostat_tau: float = 5.0,
+        mirostat_eta: float = 0.1,
     ) -> Generator[
         llama_cpp.llama_token, Optional[Sequence[llama_cpp.llama_token]], None
     ]:
@@ -494,14 +492,12 @@ class Llama:
                 top_k=top_k,
                 top_p=top_p,
                 temp=temp,
-                repeat_penalty=repeat_penalty,
-                frequency_penalty=frequency_penalty,
-                presence_penalty=presence_penalty,
                 mirostat_mode=mirostat_mode,
                 mirostat_tau=mirostat_tau,
                 mirostat_eta=mirostat_eta,
-                mirostat_mu=mirostat_mu,
-                mirostat_m=mirostat_m,
+                frequency_penalty=frequency_penalty,
+                presence_penalty=presence_penalty,
+                repeat_penalty=repeat_penalty,
             )
             tokens_or_none = yield token
             tokens = [token]
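generate() keeps top_k, top_p, temp, and repeat_penalty required while the mirostat knobs default to off. A hypothetical streaming loop, reusing the llm instance from the sketch above:

    prompt_tokens = llm.tokenize(b"Q: Name the planets in the solar system? A: ")
    for token in llm.generate(
        prompt_tokens, top_k=40, top_p=0.95, temp=0.8,
        repeat_penalty=1.1, mirostat_mode=1,
    ):
        if token == llama_cpp.llama_token_eos():
            break
        print(llm.detokenize([token]).decode("utf-8"), end="", flush=True)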
@@ -571,11 +567,6 @@ class Llama:
         suffix: Optional[str] = None,
         max_tokens: int = 16,
         temperature: float = 0.8,
-        mirostat_mode: int = 0,
-        mirostat_tau: float = 5.0,
-        mirostat_eta: float = 0.1,
-        mirostat_mu: float = 10,
-        mirostat_m: int = 100,
         top_p: float = 0.95,
         logprobs: Optional[int] = None,
         echo: bool = False,
@@ -585,6 +576,9 @@ class Llama:
         repeat_penalty: float = 1.1,
         top_k: int = 40,
         stream: bool = False,
+        mirostat_mode: int = 0,
+        mirostat_tau: float = 5.0,
+        mirostat_eta: float = 0.1,
     ) -> Union[Iterator[Completion], Iterator[CompletionChunk]]:
         assert self.ctx is not None
         completion_id: str = f"cmpl-{str(uuid.uuid4())}"
@@ -643,8 +637,6 @@ class Llama:
             mirostat_mode=mirostat_mode,
             mirostat_tau=mirostat_tau,
             mirostat_eta=mirostat_eta,
-            mirostat_mu=mirostat_mu,
-            mirostat_m=mirostat_m,
             frequency_penalty=frequency_penalty,
             presence_penalty=presence_penalty,
             repeat_penalty=repeat_penalty,
@@ -817,11 +809,6 @@ class Llama:
         suffix: Optional[str] = None,
         max_tokens: int = 128,
         temperature: float = 0.8,
-        mirostat_mode: int = 0,
-        mirostat_tau: float = 5.0,
-        mirostat_eta: float = 0.1,
-        mirostat_mu: float = 10,
-        mirostat_m: int = 100,
         top_p: float = 0.95,
         logprobs: Optional[int] = None,
         echo: bool = False,
@@ -831,6 +818,9 @@ class Llama:
         repeat_penalty: float = 1.1,
         top_k: int = 40,
         stream: bool = False,
+        mirostat_mode: int = 0,
+        mirostat_tau: float = 5.0,
+        mirostat_eta: float = 0.1,
     ) -> Union[Completion, Iterator[CompletionChunk]]:
         """Generate text from a prompt.

@@ -859,11 +849,6 @@ class Llama:
             suffix=suffix,
             max_tokens=max_tokens,
             temperature=temperature,
-            mirostat_mode=mirostat_mode,
-            mirostat_tau=mirostat_tau,
-            mirostat_eta=mirostat_eta,
-            mirostat_mu=mirostat_mu,
-            mirostat_m=mirostat_m,
             top_p=top_p,
             logprobs=logprobs,
             echo=echo,
@@ -873,6 +858,9 @@ class Llama:
             repeat_penalty=repeat_penalty,
             top_k=top_k,
             stream=stream,
+            mirostat_mode=mirostat_mode,
+            mirostat_tau=mirostat_tau,
+            mirostat_eta=mirostat_eta,
         )
         if stream:
             chunks: Iterator[CompletionChunk] = completion_or_chunks
@@ -886,11 +874,6 @@ class Llama:
         suffix: Optional[str] = None,
         max_tokens: int = 128,
         temperature: float = 0.8,
-        mirostat_mode: int = 0,
-        mirostat_tau: float = 5.0,
-        mirostat_eta: float = 0.1,
-        mirostat_mu: float = 10,
-        mirostat_m: int = 100,
         top_p: float = 0.95,
         logprobs: Optional[int] = None,
         echo: bool = False,
@@ -900,6 +883,9 @@ class Llama:
         repeat_penalty: float = 1.1,
         top_k: int = 40,
         stream: bool = False,
+        mirostat_mode: int = 0,
+        mirostat_tau: float = 5.0,
+        mirostat_eta: float = 0.1,
     ) -> Union[Completion, Iterator[CompletionChunk]]:
         """Generate text from a prompt.

@@ -928,11 +914,6 @@ class Llama:
             suffix=suffix,
             max_tokens=max_tokens,
             temperature=temperature,
-            mirostat_mode=mirostat_mode,
-            mirostat_tau=mirostat_tau,
-            mirostat_eta=mirostat_eta,
-            mirostat_mu=mirostat_mu,
-            mirostat_m=mirostat_m,
             top_p=top_p,
             logprobs=logprobs,
             echo=echo,
@@ -942,6 +923,9 @@ class Llama:
             repeat_penalty=repeat_penalty,
             top_k=top_k,
             stream=stream,
+            mirostat_mode=mirostat_mode,
+            mirostat_tau=mirostat_tau,
+            mirostat_eta=mirostat_eta,
         )

     def _convert_text_completion_to_chat(
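The same three pass-through parameters now reach the high-level completion entry points. A hypothetical call through __call__ (the stop and max_tokens arguments belong to the surrounding signature, not shown in these hunks):

    output = llm(
        "Q: Name the planets in the solar system? A: ",
        max_tokens=32,
        stop=["Q:", "\n"],
        mirostat_mode=2,
        mirostat_tau=5.0,
        mirostat_eta=0.1,
    )
    print(output["choices"][0]["text"])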
@@ -1014,6 +998,9 @@ class Llama:
         presence_penalty: float = 0.0,
         frequency_penalty: float = 0.0,
         repeat_penalty: float = 1.1,
+        mirostat_mode: int = 0,
+        mirostat_tau: float = 5.0,
+        mirostat_eta: float = 0.1,
     ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
         """Generate a chat completion from a list of messages.

@@ -1048,6 +1035,9 @@ class Llama:
             repeat_penalty=repeat_penalty,
             presence_penalty=presence_penalty,
             frequency_penalty=frequency_penalty,
+            mirostat_mode=mirostat_mode,
+            mirostat_tau=mirostat_tau,
+            mirostat_eta=mirostat_eta,
         )
         if stream:
             chunks: Iterator[CompletionChunk] = completion_or_chunks  # type: ignore
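Finally, the chat wrapper forwards the new parameters down to create_completion. A hypothetical chat call:

    result = llm.create_chat_completion(
        messages=[{"role": "user", "content": "Name the planets in the solar system."}],
        mirostat_mode=2,
    )
    print(result["choices"][0]["message"]["content"])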