diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index a643f51..8d5df10 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -257,6 +257,11 @@ class Llama: top_k: llama_cpp.c_int, top_p: llama_cpp.c_float, temp: llama_cpp.c_float, + mirostat_mode: llama_cpp.c_int, + mirostat_tau: llama_cpp.c_float, + mirostat_eta: llama_cpp.c_float, + mirostat_mu: llama_cpp.c_float, + mirostat_m: llama_cpp.c_int, repeat_penalty: llama_cpp.c_float, ): assert self.ctx is not None @@ -287,7 +292,34 @@ class Llama: candidates=llama_cpp.ctypes.pointer(candidates), penalty=repeat_penalty, ) - if float(temp.value) == 0.0: + if mirostat_mode == 1: + llama_cpp.llama_sample_temperature( + ctx=self.ctx, + candidates=llama_cpp.ctypes.pointer(candidates), + temp=temp, + ) + llama_cpp.llama_sample_token_mirostat( + ctx=self.ctx, + candidates=llama_cpp.ctypes.pointer(candidates), + tau=mirostat_tau, + eta=mirostat_eta, + mu=mirostat_mu, + m=mirostat_m + ) + elif mirostat_mode == 2: + llama_cpp.llama_sample_temperature( + ctx=self.ctx, + candidates=llama_cpp.ctypes.pointer(candidates), + temp=temp, + ) + llama_cpp.llama_sample_token_mirostat_v2( + ctx=self.ctx, + candidates=llama_cpp.ctypes.pointer(candidates), + tau=mirostat_tau, + eta=mirostat_eta, + mu=mirostat_mu + ) + elif float(temp.value) == 0.0: return llama_cpp.llama_sample_token_greedy( ctx=self.ctx, candidates=llama_cpp.ctypes.pointer(candidates), @@ -328,6 +360,11 @@ class Llama: top_k: int, top_p: float, temp: float, + mirostat_mode: int, + mirostat_tau: float, + mirostat_eta: float, + mirostat_mu: float, + mirostat_m: int, repeat_penalty: float, ): """Sample a token from the model. @@ -353,6 +390,11 @@ class Llama: top_k=llama_cpp.c_int(top_k), top_p=llama_cpp.c_float(top_p), temp=llama_cpp.c_float(temp), + mirostat=llama_cpp.c_int(mirostat_mode), + mirostat_mu=llama_cpp.c_float(mirostat_mu), + mirostat_tau=llama_cpp.c_float(mirostat_tau), + mirostat_eta=llama_cpp.c_float(mirostat_eta), + mirostat_m=llama_cpp.c_int(mirostat_m), repeat_penalty=llama_cpp.c_float(repeat_penalty), ) @@ -362,6 +404,11 @@ class Llama: top_k: int, top_p: float, temp: float, + mirostat: int, + mirostat_tau: float, + mirostat_eta: float, + mirostat_mu: float, + mirostat_m: int, repeat_penalty: float, reset: bool = True, ) -> Generator[ @@ -416,6 +463,11 @@ class Llama: top_k=top_k, top_p=top_p, temp=temp, + mirostat_mode=mirostat_mode, + mirostat_tau=mirostat_tau, + mirostat_eta=mirostat_eta, + mirostat_mu=mirostat_mu, + mirostat_m=mirostat_m, repeat_penalty=repeat_penalty, ) tokens_or_none = yield token @@ -486,6 +538,11 @@ class Llama: suffix: Optional[str] = None, max_tokens: int = 16, temperature: float = 0.8, + mirostat_mode: int = 0, + mirostat_tau: float = 5.0, + mirostat_eta: float = 0.1, + mirostat_mu: float = 10, + mirostat_m: int = 100, top_p: float = 0.95, logprobs: Optional[int] = None, echo: bool = False, @@ -536,6 +593,11 @@ class Llama: top_k=top_k, top_p=top_p, temp=temperature, + mirostat_mode=mirostat_mode, + mirostat_tau=mirostat_tau, + mirostat_eta=mirostat_eta, + mirostat_mu=mirostat_mu, + mirostat_m=mirostat_m, repeat_penalty=repeat_penalty, ): if token == llama_cpp.llama_token_eos(): @@ -707,6 +769,11 @@ class Llama: suffix: Optional[str] = None, max_tokens: int = 128, temperature: float = 0.8, + mirostat_mode: int = 0, + mirostat_tau: float = 5.0, + mirostat_eta: float = 0.1, + mirostat_mu: float = 10, + mirostat_m: int = 100, top_p: float = 0.95, logprobs: Optional[int] = None, echo: bool = False, @@ -742,6 +809,11 @@ class Llama: suffix=suffix, max_tokens=max_tokens, temperature=temperature, + mirostat_mode=mirostat_mode, + mirostat_tau=mirostat_tau, + mirostat_eta=mirostat_eta, + mirostat_mu=mirostat_mu, + mirostat_m=mirostat_m, top_p=top_p, logprobs=logprobs, echo=echo, @@ -762,6 +834,11 @@ class Llama: suffix: Optional[str] = None, max_tokens: int = 128, temperature: float = 0.8, + mirostat_mode: int = 0, + mirostat_tau: float = 5.0, + mirostat_eta: float = 0.1, + mirostat_mu: float = 10, + mirostat_m: int = 100, top_p: float = 0.95, logprobs: Optional[int] = None, echo: bool = False, @@ -797,6 +874,11 @@ class Llama: suffix=suffix, max_tokens=max_tokens, temperature=temperature, + mirostat_mode=mirostat_mode, + mirostat_tau=mirostat_tau, + mirostat_eta=mirostat_eta, + mirostat_mu=mirostat_mu, + mirostat_m=mirostat_m, top_p=top_p, logprobs=logprobs, echo=echo,