Added mirostat sampling to the high level API.
This commit is contained in:
parent
2f2ea00a3d
commit
aa203a0d65
1 changed files with 83 additions and 1 deletions
|
@ -257,6 +257,11 @@ class Llama:
|
|||
top_k: llama_cpp.c_int,
|
||||
top_p: llama_cpp.c_float,
|
||||
temp: llama_cpp.c_float,
|
||||
mirostat_mode: llama_cpp.c_int,
|
||||
mirostat_tau: llama_cpp.c_float,
|
||||
mirostat_eta: llama_cpp.c_float,
|
||||
mirostat_mu: llama_cpp.c_float,
|
||||
mirostat_m: llama_cpp.c_int,
|
||||
repeat_penalty: llama_cpp.c_float,
|
||||
):
|
||||
assert self.ctx is not None
|
||||
|
@ -287,7 +292,34 @@ class Llama:
|
|||
candidates=llama_cpp.ctypes.pointer(candidates),
|
||||
penalty=repeat_penalty,
|
||||
)
|
||||
if float(temp.value) == 0.0:
|
||||
if mirostat_mode == 1:
|
||||
llama_cpp.llama_sample_temperature(
|
||||
ctx=self.ctx,
|
||||
candidates=llama_cpp.ctypes.pointer(candidates),
|
||||
temp=temp,
|
||||
)
|
||||
llama_cpp.llama_sample_token_mirostat(
|
||||
ctx=self.ctx,
|
||||
candidates=llama_cpp.ctypes.pointer(candidates),
|
||||
tau=mirostat_tau,
|
||||
eta=mirostat_eta,
|
||||
mu=mirostat_mu,
|
||||
m=mirostat_m
|
||||
)
|
||||
elif mirostat_mode == 2:
|
||||
llama_cpp.llama_sample_temperature(
|
||||
ctx=self.ctx,
|
||||
candidates=llama_cpp.ctypes.pointer(candidates),
|
||||
temp=temp,
|
||||
)
|
||||
llama_cpp.llama_sample_token_mirostat_v2(
|
||||
ctx=self.ctx,
|
||||
candidates=llama_cpp.ctypes.pointer(candidates),
|
||||
tau=mirostat_tau,
|
||||
eta=mirostat_eta,
|
||||
mu=mirostat_mu
|
||||
)
|
||||
elif float(temp.value) == 0.0:
|
||||
return llama_cpp.llama_sample_token_greedy(
|
||||
ctx=self.ctx,
|
||||
candidates=llama_cpp.ctypes.pointer(candidates),
|
||||
|
@ -328,6 +360,11 @@ class Llama:
|
|||
top_k: int,
|
||||
top_p: float,
|
||||
temp: float,
|
||||
mirostat_mode: int,
|
||||
mirostat_tau: float,
|
||||
mirostat_eta: float,
|
||||
mirostat_mu: float,
|
||||
mirostat_m: int,
|
||||
repeat_penalty: float,
|
||||
):
|
||||
"""Sample a token from the model.
|
||||
|
@ -353,6 +390,11 @@ class Llama:
|
|||
top_k=llama_cpp.c_int(top_k),
|
||||
top_p=llama_cpp.c_float(top_p),
|
||||
temp=llama_cpp.c_float(temp),
|
||||
mirostat=llama_cpp.c_int(mirostat_mode),
|
||||
mirostat_mu=llama_cpp.c_float(mirostat_mu),
|
||||
mirostat_tau=llama_cpp.c_float(mirostat_tau),
|
||||
mirostat_eta=llama_cpp.c_float(mirostat_eta),
|
||||
mirostat_m=llama_cpp.c_int(mirostat_m),
|
||||
repeat_penalty=llama_cpp.c_float(repeat_penalty),
|
||||
)
|
||||
|
||||
|
@ -362,6 +404,11 @@ class Llama:
|
|||
top_k: int,
|
||||
top_p: float,
|
||||
temp: float,
|
||||
mirostat: int,
|
||||
mirostat_tau: float,
|
||||
mirostat_eta: float,
|
||||
mirostat_mu: float,
|
||||
mirostat_m: int,
|
||||
repeat_penalty: float,
|
||||
reset: bool = True,
|
||||
) -> Generator[
|
||||
|
@ -416,6 +463,11 @@ class Llama:
|
|||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
temp=temp,
|
||||
mirostat_mode=mirostat_mode,
|
||||
mirostat_tau=mirostat_tau,
|
||||
mirostat_eta=mirostat_eta,
|
||||
mirostat_mu=mirostat_mu,
|
||||
mirostat_m=mirostat_m,
|
||||
repeat_penalty=repeat_penalty,
|
||||
)
|
||||
tokens_or_none = yield token
|
||||
|
@ -486,6 +538,11 @@ class Llama:
|
|||
suffix: Optional[str] = None,
|
||||
max_tokens: int = 16,
|
||||
temperature: float = 0.8,
|
||||
mirostat_mode: int = 0,
|
||||
mirostat_tau: float = 5.0,
|
||||
mirostat_eta: float = 0.1,
|
||||
mirostat_mu: float = 10,
|
||||
mirostat_m: int = 100,
|
||||
top_p: float = 0.95,
|
||||
logprobs: Optional[int] = None,
|
||||
echo: bool = False,
|
||||
|
@ -536,6 +593,11 @@ class Llama:
|
|||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
temp=temperature,
|
||||
mirostat_mode=mirostat_mode,
|
||||
mirostat_tau=mirostat_tau,
|
||||
mirostat_eta=mirostat_eta,
|
||||
mirostat_mu=mirostat_mu,
|
||||
mirostat_m=mirostat_m,
|
||||
repeat_penalty=repeat_penalty,
|
||||
):
|
||||
if token == llama_cpp.llama_token_eos():
|
||||
|
@ -707,6 +769,11 @@ class Llama:
|
|||
suffix: Optional[str] = None,
|
||||
max_tokens: int = 128,
|
||||
temperature: float = 0.8,
|
||||
mirostat_mode: int = 0,
|
||||
mirostat_tau: float = 5.0,
|
||||
mirostat_eta: float = 0.1,
|
||||
mirostat_mu: float = 10,
|
||||
mirostat_m: int = 100,
|
||||
top_p: float = 0.95,
|
||||
logprobs: Optional[int] = None,
|
||||
echo: bool = False,
|
||||
|
@ -742,6 +809,11 @@ class Llama:
|
|||
suffix=suffix,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
mirostat_mode=mirostat_mode,
|
||||
mirostat_tau=mirostat_tau,
|
||||
mirostat_eta=mirostat_eta,
|
||||
mirostat_mu=mirostat_mu,
|
||||
mirostat_m=mirostat_m,
|
||||
top_p=top_p,
|
||||
logprobs=logprobs,
|
||||
echo=echo,
|
||||
|
@ -762,6 +834,11 @@ class Llama:
|
|||
suffix: Optional[str] = None,
|
||||
max_tokens: int = 128,
|
||||
temperature: float = 0.8,
|
||||
mirostat_mode: int = 0,
|
||||
mirostat_tau: float = 5.0,
|
||||
mirostat_eta: float = 0.1,
|
||||
mirostat_mu: float = 10,
|
||||
mirostat_m: int = 100,
|
||||
top_p: float = 0.95,
|
||||
logprobs: Optional[int] = None,
|
||||
echo: bool = False,
|
||||
|
@ -797,6 +874,11 @@ class Llama:
|
|||
suffix=suffix,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
mirostat_mode=mirostat_mode,
|
||||
mirostat_tau=mirostat_tau,
|
||||
mirostat_eta=mirostat_eta,
|
||||
mirostat_mu=mirostat_mu,
|
||||
mirostat_m=mirostat_m,
|
||||
top_p=top_p,
|
||||
logprobs=logprobs,
|
||||
echo=echo,
|
||||
|
|
Loading…
Reference in a new issue