From 3babe3512cb95743108f2b595210c38ed6f1b904 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Fri, 19 Jan 2024 08:31:59 -0500 Subject: [PATCH] Fix mirostat sampling --- llama_cpp/llama.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 6cdc1eb..32eb3fe 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -329,6 +329,8 @@ class Llama: (n_ctx, self._n_vocab), dtype=np.single ) + self._mirostat_mu = ctypes.c_float(2.0 * 5.0) # TODO: Move this to sampling context + @property def ctx(self) -> llama_cpp.llama_context_p: assert self._ctx.ctx is not None @@ -516,7 +518,7 @@ class Llama: candidates=self._candidates, tau=mirostat_tau, eta=mirostat_eta, - mu=2.0 * mirostat_tau, + mu=ctypes.pointer(self._mirostat_mu), m=100, ) elif mirostat_mode == 2: @@ -525,7 +527,7 @@ class Llama: candidates=self._candidates, tau=mirostat_tau, eta=mirostat_eta, - mu=2.0 * mirostat_tau, + mu=ctypes.pointer(self._mirostat_mu) ) else: self._ctx.sample_top_k(candidates=self._candidates, k=top_k, min_keep=1) @@ -581,6 +583,10 @@ class Llama: Yields: The generated tokens. """ + # Reset mirostat sampling + self._mirostat_mu = ctypes.c_float(2.0 * mirostat_tau) + + # Check for kv cache prefix match if reset and self.n_tokens > 0: longest_prefix = 0 for a, b in zip(self._input_ids, tokens[:-1]): @@ -595,12 +601,15 @@ class Llama: tokens = tokens[longest_prefix:] self.n_tokens = longest_prefix + # Reset the model state if reset: self.reset() + # Reset the grammar if grammar is not None: grammar.reset() + # Eval and sample while True: self.eval(tokens) token = self.sample(