Implement sampling as in llama.cpp main example
This commit is contained in:
parent
93a9019bb1
commit
d957422bf4
1 changed files with 76 additions and 86 deletions
|
@ -268,14 +268,13 @@ class Llama:
|
||||||
top_k: llama_cpp.c_int,
|
top_k: llama_cpp.c_int,
|
||||||
top_p: llama_cpp.c_float,
|
top_p: llama_cpp.c_float,
|
||||||
temp: llama_cpp.c_float,
|
temp: llama_cpp.c_float,
|
||||||
mirostat_mode: llama_cpp.c_int,
|
tfs_z: llama_cpp.c_float,
|
||||||
mirostat_tau: llama_cpp.c_float,
|
|
||||||
mirostat_eta: llama_cpp.c_float,
|
|
||||||
mirostat_mu: llama_cpp.c_float,
|
|
||||||
mirostat_m: llama_cpp.c_int,
|
|
||||||
repeat_penalty: llama_cpp.c_float,
|
repeat_penalty: llama_cpp.c_float,
|
||||||
frequency_penalty: llama_cpp.c_float,
|
frequency_penalty: llama_cpp.c_float,
|
||||||
presence_penalty: llama_cpp.c_float,
|
presence_penalty: llama_cpp.c_float,
|
||||||
|
mirostat_mode: llama_cpp.c_int,
|
||||||
|
mirostat_tau: llama_cpp.c_float,
|
||||||
|
mirostat_eta: llama_cpp.c_float,
|
||||||
):
|
):
|
||||||
assert self.ctx is not None
|
assert self.ctx is not None
|
||||||
assert len(self.eval_logits) > 0
|
assert len(self.eval_logits) > 0
|
||||||
|
@ -305,33 +304,6 @@ class Llama:
|
||||||
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
|
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
|
||||||
penalty=repeat_penalty,
|
penalty=repeat_penalty,
|
||||||
)
|
)
|
||||||
if mirostat_mode.value == 1:
|
|
||||||
llama_cpp.llama_sample_temperature(
|
|
||||||
ctx=self.ctx,
|
|
||||||
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
|
|
||||||
temp=temp,
|
|
||||||
)
|
|
||||||
llama_cpp.llama_sample_token_mirostat(
|
|
||||||
ctx=self.ctx,
|
|
||||||
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
|
|
||||||
tau=mirostat_tau,
|
|
||||||
eta=mirostat_eta,
|
|
||||||
mu=llama_cpp.ctypes.byref(mirostat_mu), # type: ignore
|
|
||||||
m=mirostat_m
|
|
||||||
)
|
|
||||||
elif mirostat_mode.value == 2:
|
|
||||||
llama_cpp.llama_sample_temperature(
|
|
||||||
ctx=self.ctx,
|
|
||||||
candidates=llama_cpp.ctypes.pointer(candidates),
|
|
||||||
temp=temp,
|
|
||||||
)
|
|
||||||
llama_cpp.llama_sample_token_mirostat_v2(
|
|
||||||
ctx=self.ctx,
|
|
||||||
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
|
|
||||||
tau=mirostat_tau,
|
|
||||||
eta=mirostat_eta,
|
|
||||||
mu=llama_cpp.ctypes.byref(mirostat_mu) # type: ignore
|
|
||||||
)
|
|
||||||
llama_cpp.llama_sample_frequency_and_presence_penalties(
|
llama_cpp.llama_sample_frequency_and_presence_penalties(
|
||||||
ctx=self.ctx,
|
ctx=self.ctx,
|
||||||
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
|
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
|
||||||
|
@ -340,11 +312,41 @@ class Llama:
|
||||||
alpha_frequency=frequency_penalty,
|
alpha_frequency=frequency_penalty,
|
||||||
alpha_presence=presence_penalty,
|
alpha_presence=presence_penalty,
|
||||||
)
|
)
|
||||||
if float(temp.value) == 0.0:
|
if temp.value == 0.0:
|
||||||
return llama_cpp.llama_sample_token_greedy(
|
return llama_cpp.llama_sample_token_greedy(
|
||||||
ctx=self.ctx,
|
ctx=self.ctx,
|
||||||
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
|
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
|
||||||
)
|
)
|
||||||
|
elif mirostat_mode.value == 1:
|
||||||
|
mirostat_mu = llama_cpp.c_float(2.0 * mirostat_tau.value)
|
||||||
|
mirostat_m = llama_cpp.c_int(100)
|
||||||
|
llama_cpp.llama_sample_temperature(
|
||||||
|
ctx=self.ctx,
|
||||||
|
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
|
||||||
|
temp=temp,
|
||||||
|
)
|
||||||
|
return llama_cpp.llama_sample_token_mirostat(
|
||||||
|
ctx=self.ctx,
|
||||||
|
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
|
||||||
|
tau=mirostat_tau,
|
||||||
|
eta=mirostat_eta,
|
||||||
|
mu=llama_cpp.ctypes.byref(mirostat_mu), # type: ignore
|
||||||
|
m=mirostat_m,
|
||||||
|
)
|
||||||
|
elif mirostat_mode.value == 2:
|
||||||
|
mirostat_mu = llama_cpp.c_float(2.0 * mirostat_tau.value)
|
||||||
|
llama_cpp.llama_sample_temperature(
|
||||||
|
ctx=self.ctx,
|
||||||
|
candidates=llama_cpp.ctypes.pointer(candidates),
|
||||||
|
temp=temp,
|
||||||
|
)
|
||||||
|
return llama_cpp.llama_sample_token_mirostat_v2(
|
||||||
|
ctx=self.ctx,
|
||||||
|
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
|
||||||
|
tau=mirostat_tau,
|
||||||
|
eta=mirostat_eta,
|
||||||
|
mu=llama_cpp.ctypes.byref(mirostat_mu), # type: ignore
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
llama_cpp.llama_sample_top_k(
|
llama_cpp.llama_sample_top_k(
|
||||||
ctx=self.ctx,
|
ctx=self.ctx,
|
||||||
|
@ -355,7 +357,7 @@ class Llama:
|
||||||
llama_cpp.llama_sample_tail_free(
|
llama_cpp.llama_sample_tail_free(
|
||||||
ctx=self.ctx,
|
ctx=self.ctx,
|
||||||
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
|
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
|
||||||
z=llama_cpp.c_float(1.0),
|
z=tfs_z,
|
||||||
min_keep=llama_cpp.c_size_t(1),
|
min_keep=llama_cpp.c_size_t(1),
|
||||||
)
|
)
|
||||||
llama_cpp.llama_sample_typical(
|
llama_cpp.llama_sample_typical(
|
||||||
|
@ -382,17 +384,16 @@ class Llama:
|
||||||
|
|
||||||
def sample(
|
def sample(
|
||||||
self,
|
self,
|
||||||
top_k: int,
|
top_k: int = 40,
|
||||||
top_p: float,
|
top_p: float = 0.95,
|
||||||
temp: float,
|
temp: float = 0.80,
|
||||||
mirostat_mode: int,
|
repeat_penalty: float = 1.1,
|
||||||
mirostat_tau: float,
|
|
||||||
mirostat_eta: float,
|
|
||||||
mirostat_mu: float,
|
|
||||||
mirostat_m: int,
|
|
||||||
repeat_penalty: float,
|
|
||||||
frequency_penalty: float = 0.0,
|
frequency_penalty: float = 0.0,
|
||||||
presence_penalty: float = 0.0,
|
presence_penalty: float = 0.0,
|
||||||
|
tfs_z: float = 1.0,
|
||||||
|
mirostat_mode: int = 0,
|
||||||
|
mirostat_eta: float = 0.1,
|
||||||
|
mirostat_tau: float = 5.0,
|
||||||
):
|
):
|
||||||
"""Sample a token from the model.
|
"""Sample a token from the model.
|
||||||
|
|
||||||
|
@ -417,14 +418,13 @@ class Llama:
|
||||||
top_k=llama_cpp.c_int(top_k),
|
top_k=llama_cpp.c_int(top_k),
|
||||||
top_p=llama_cpp.c_float(top_p),
|
top_p=llama_cpp.c_float(top_p),
|
||||||
temp=llama_cpp.c_float(temp),
|
temp=llama_cpp.c_float(temp),
|
||||||
mirostat_mode=llama_cpp.c_int(mirostat_mode),
|
tfs_z=llama_cpp.c_float(tfs_z),
|
||||||
mirostat_mu=llama_cpp.c_float(mirostat_mu),
|
|
||||||
mirostat_tau=llama_cpp.c_float(mirostat_tau),
|
|
||||||
mirostat_eta=llama_cpp.c_float(mirostat_eta),
|
|
||||||
mirostat_m=llama_cpp.c_int(mirostat_m),
|
|
||||||
repeat_penalty=llama_cpp.c_float(repeat_penalty),
|
repeat_penalty=llama_cpp.c_float(repeat_penalty),
|
||||||
frequency_penalty=llama_cpp.c_float(frequency_penalty),
|
frequency_penalty=llama_cpp.c_float(frequency_penalty),
|
||||||
presence_penalty=llama_cpp.c_float(presence_penalty),
|
presence_penalty=llama_cpp.c_float(presence_penalty),
|
||||||
|
mirostat_mode=llama_cpp.c_int(mirostat_mode),
|
||||||
|
mirostat_tau=llama_cpp.c_float(mirostat_tau),
|
||||||
|
mirostat_eta=llama_cpp.c_float(mirostat_eta),
|
||||||
)
|
)
|
||||||
|
|
||||||
def generate(
|
def generate(
|
||||||
|
@ -433,15 +433,13 @@ class Llama:
|
||||||
top_k: int,
|
top_k: int,
|
||||||
top_p: float,
|
top_p: float,
|
||||||
temp: float,
|
temp: float,
|
||||||
mirostat_mode: int,
|
|
||||||
mirostat_tau: float,
|
|
||||||
mirostat_eta: float,
|
|
||||||
mirostat_mu: float,
|
|
||||||
mirostat_m: int,
|
|
||||||
repeat_penalty: float,
|
repeat_penalty: float,
|
||||||
|
reset: bool = True,
|
||||||
frequency_penalty: float = 0.0,
|
frequency_penalty: float = 0.0,
|
||||||
presence_penalty: float = 0.0,
|
presence_penalty: float = 0.0,
|
||||||
reset: bool = True,
|
mirostat_mode: int = 0,
|
||||||
|
mirostat_tau: float = 5.0,
|
||||||
|
mirostat_eta: float = 0.1,
|
||||||
) -> Generator[
|
) -> Generator[
|
||||||
llama_cpp.llama_token, Optional[Sequence[llama_cpp.llama_token]], None
|
llama_cpp.llama_token, Optional[Sequence[llama_cpp.llama_token]], None
|
||||||
]:
|
]:
|
||||||
|
@ -494,14 +492,12 @@ class Llama:
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
top_p=top_p,
|
top_p=top_p,
|
||||||
temp=temp,
|
temp=temp,
|
||||||
|
repeat_penalty=repeat_penalty,
|
||||||
|
frequency_penalty=frequency_penalty,
|
||||||
|
presence_penalty=presence_penalty,
|
||||||
mirostat_mode=mirostat_mode,
|
mirostat_mode=mirostat_mode,
|
||||||
mirostat_tau=mirostat_tau,
|
mirostat_tau=mirostat_tau,
|
||||||
mirostat_eta=mirostat_eta,
|
mirostat_eta=mirostat_eta,
|
||||||
mirostat_mu=mirostat_mu,
|
|
||||||
mirostat_m=mirostat_m,
|
|
||||||
frequency_penalty=frequency_penalty,
|
|
||||||
presence_penalty=presence_penalty,
|
|
||||||
repeat_penalty=repeat_penalty,
|
|
||||||
)
|
)
|
||||||
tokens_or_none = yield token
|
tokens_or_none = yield token
|
||||||
tokens = [token]
|
tokens = [token]
|
||||||
|
@ -571,11 +567,6 @@ class Llama:
|
||||||
suffix: Optional[str] = None,
|
suffix: Optional[str] = None,
|
||||||
max_tokens: int = 16,
|
max_tokens: int = 16,
|
||||||
temperature: float = 0.8,
|
temperature: float = 0.8,
|
||||||
mirostat_mode: int = 0,
|
|
||||||
mirostat_tau: float = 5.0,
|
|
||||||
mirostat_eta: float = 0.1,
|
|
||||||
mirostat_mu: float = 10,
|
|
||||||
mirostat_m: int = 100,
|
|
||||||
top_p: float = 0.95,
|
top_p: float = 0.95,
|
||||||
logprobs: Optional[int] = None,
|
logprobs: Optional[int] = None,
|
||||||
echo: bool = False,
|
echo: bool = False,
|
||||||
|
@ -585,6 +576,9 @@ class Llama:
|
||||||
repeat_penalty: float = 1.1,
|
repeat_penalty: float = 1.1,
|
||||||
top_k: int = 40,
|
top_k: int = 40,
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
|
mirostat_mode: int = 0,
|
||||||
|
mirostat_tau: float = 5.0,
|
||||||
|
mirostat_eta: float = 0.1,
|
||||||
) -> Union[Iterator[Completion], Iterator[CompletionChunk]]:
|
) -> Union[Iterator[Completion], Iterator[CompletionChunk]]:
|
||||||
assert self.ctx is not None
|
assert self.ctx is not None
|
||||||
completion_id: str = f"cmpl-{str(uuid.uuid4())}"
|
completion_id: str = f"cmpl-{str(uuid.uuid4())}"
|
||||||
|
@ -643,8 +637,6 @@ class Llama:
|
||||||
mirostat_mode=mirostat_mode,
|
mirostat_mode=mirostat_mode,
|
||||||
mirostat_tau=mirostat_tau,
|
mirostat_tau=mirostat_tau,
|
||||||
mirostat_eta=mirostat_eta,
|
mirostat_eta=mirostat_eta,
|
||||||
mirostat_mu=mirostat_mu,
|
|
||||||
mirostat_m=mirostat_m,
|
|
||||||
frequency_penalty=frequency_penalty,
|
frequency_penalty=frequency_penalty,
|
||||||
presence_penalty=presence_penalty,
|
presence_penalty=presence_penalty,
|
||||||
repeat_penalty=repeat_penalty,
|
repeat_penalty=repeat_penalty,
|
||||||
|
@ -817,11 +809,6 @@ class Llama:
|
||||||
suffix: Optional[str] = None,
|
suffix: Optional[str] = None,
|
||||||
max_tokens: int = 128,
|
max_tokens: int = 128,
|
||||||
temperature: float = 0.8,
|
temperature: float = 0.8,
|
||||||
mirostat_mode: int = 0,
|
|
||||||
mirostat_tau: float = 5.0,
|
|
||||||
mirostat_eta: float = 0.1,
|
|
||||||
mirostat_mu: float = 10,
|
|
||||||
mirostat_m: int = 100,
|
|
||||||
top_p: float = 0.95,
|
top_p: float = 0.95,
|
||||||
logprobs: Optional[int] = None,
|
logprobs: Optional[int] = None,
|
||||||
echo: bool = False,
|
echo: bool = False,
|
||||||
|
@ -831,6 +818,9 @@ class Llama:
|
||||||
repeat_penalty: float = 1.1,
|
repeat_penalty: float = 1.1,
|
||||||
top_k: int = 40,
|
top_k: int = 40,
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
|
mirostat_mode: int = 0,
|
||||||
|
mirostat_tau: float = 5.0,
|
||||||
|
mirostat_eta: float = 0.1,
|
||||||
) -> Union[Completion, Iterator[CompletionChunk]]:
|
) -> Union[Completion, Iterator[CompletionChunk]]:
|
||||||
"""Generate text from a prompt.
|
"""Generate text from a prompt.
|
||||||
|
|
||||||
|
@ -859,11 +849,6 @@ class Llama:
|
||||||
suffix=suffix,
|
suffix=suffix,
|
||||||
max_tokens=max_tokens,
|
max_tokens=max_tokens,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
mirostat_mode=mirostat_mode,
|
|
||||||
mirostat_tau=mirostat_tau,
|
|
||||||
mirostat_eta=mirostat_eta,
|
|
||||||
mirostat_mu=mirostat_mu,
|
|
||||||
mirostat_m=mirostat_m,
|
|
||||||
top_p=top_p,
|
top_p=top_p,
|
||||||
logprobs=logprobs,
|
logprobs=logprobs,
|
||||||
echo=echo,
|
echo=echo,
|
||||||
|
@ -873,6 +858,9 @@ class Llama:
|
||||||
repeat_penalty=repeat_penalty,
|
repeat_penalty=repeat_penalty,
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
stream=stream,
|
stream=stream,
|
||||||
|
mirostat_mode=mirostat_mode,
|
||||||
|
mirostat_tau=mirostat_tau,
|
||||||
|
mirostat_eta=mirostat_eta,
|
||||||
)
|
)
|
||||||
if stream:
|
if stream:
|
||||||
chunks: Iterator[CompletionChunk] = completion_or_chunks
|
chunks: Iterator[CompletionChunk] = completion_or_chunks
|
||||||
|
@ -886,11 +874,6 @@ class Llama:
|
||||||
suffix: Optional[str] = None,
|
suffix: Optional[str] = None,
|
||||||
max_tokens: int = 128,
|
max_tokens: int = 128,
|
||||||
temperature: float = 0.8,
|
temperature: float = 0.8,
|
||||||
mirostat_mode: int = 0,
|
|
||||||
mirostat_tau: float = 5.0,
|
|
||||||
mirostat_eta: float = 0.1,
|
|
||||||
mirostat_mu: float = 10,
|
|
||||||
mirostat_m: int = 100,
|
|
||||||
top_p: float = 0.95,
|
top_p: float = 0.95,
|
||||||
logprobs: Optional[int] = None,
|
logprobs: Optional[int] = None,
|
||||||
echo: bool = False,
|
echo: bool = False,
|
||||||
|
@ -900,6 +883,9 @@ class Llama:
|
||||||
repeat_penalty: float = 1.1,
|
repeat_penalty: float = 1.1,
|
||||||
top_k: int = 40,
|
top_k: int = 40,
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
|
mirostat_mode: int = 0,
|
||||||
|
mirostat_tau: float = 5.0,
|
||||||
|
mirostat_eta: float = 0.1,
|
||||||
) -> Union[Completion, Iterator[CompletionChunk]]:
|
) -> Union[Completion, Iterator[CompletionChunk]]:
|
||||||
"""Generate text from a prompt.
|
"""Generate text from a prompt.
|
||||||
|
|
||||||
|
@ -928,11 +914,6 @@ class Llama:
|
||||||
suffix=suffix,
|
suffix=suffix,
|
||||||
max_tokens=max_tokens,
|
max_tokens=max_tokens,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
mirostat_mode=mirostat_mode,
|
|
||||||
mirostat_tau=mirostat_tau,
|
|
||||||
mirostat_eta=mirostat_eta,
|
|
||||||
mirostat_mu=mirostat_mu,
|
|
||||||
mirostat_m=mirostat_m,
|
|
||||||
top_p=top_p,
|
top_p=top_p,
|
||||||
logprobs=logprobs,
|
logprobs=logprobs,
|
||||||
echo=echo,
|
echo=echo,
|
||||||
|
@ -942,6 +923,9 @@ class Llama:
|
||||||
repeat_penalty=repeat_penalty,
|
repeat_penalty=repeat_penalty,
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
stream=stream,
|
stream=stream,
|
||||||
|
mirostat_mode=mirostat_mode,
|
||||||
|
mirostat_tau=mirostat_tau,
|
||||||
|
mirostat_eta=mirostat_eta,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _convert_text_completion_to_chat(
|
def _convert_text_completion_to_chat(
|
||||||
|
@ -1014,6 +998,9 @@ class Llama:
|
||||||
presence_penalty: float = 0.0,
|
presence_penalty: float = 0.0,
|
||||||
frequency_penalty: float = 0.0,
|
frequency_penalty: float = 0.0,
|
||||||
repeat_penalty: float = 1.1,
|
repeat_penalty: float = 1.1,
|
||||||
|
mirostat_mode: int = 0,
|
||||||
|
mirostat_tau: float = 5.0,
|
||||||
|
mirostat_eta: float = 0.1,
|
||||||
) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
|
) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
|
||||||
"""Generate a chat completion from a list of messages.
|
"""Generate a chat completion from a list of messages.
|
||||||
|
|
||||||
|
@ -1048,6 +1035,9 @@ class Llama:
|
||||||
repeat_penalty=repeat_penalty,
|
repeat_penalty=repeat_penalty,
|
||||||
presence_penalty=presence_penalty,
|
presence_penalty=presence_penalty,
|
||||||
frequency_penalty=frequency_penalty,
|
frequency_penalty=frequency_penalty,
|
||||||
|
mirostat_mode=mirostat_mode,
|
||||||
|
mirostat_tau=mirostat_tau,
|
||||||
|
mirostat_eta=mirostat_eta,
|
||||||
)
|
)
|
||||||
if stream:
|
if stream:
|
||||||
chunks: Iterator[CompletionChunk] = completion_or_chunks # type: ignore
|
chunks: Iterator[CompletionChunk] = completion_or_chunks # type: ignore
|
||||||
|
|
Loading…
Reference in a new issue