Fix mirostat sampling

This commit is contained in:
Andrei Betlen 2024-01-19 08:31:59 -05:00
parent 141293a75b
commit 3babe3512c

View file

@@ -329,6 +329,8 @@ class Llama:
(n_ctx, self._n_vocab), dtype=np.single (n_ctx, self._n_vocab), dtype=np.single
) )
self._mirostat_mu = ctypes.c_float(2.0 * 5.0) # TODO: Move this to sampling context
@property @property
def ctx(self) -> llama_cpp.llama_context_p: def ctx(self) -> llama_cpp.llama_context_p:
assert self._ctx.ctx is not None assert self._ctx.ctx is not None
@@ -516,7 +518,7 @@ class Llama:
candidates=self._candidates, candidates=self._candidates,
tau=mirostat_tau, tau=mirostat_tau,
eta=mirostat_eta, eta=mirostat_eta,
mu=2.0 * mirostat_tau, mu=ctypes.pointer(self._mirostat_mu),
m=100, m=100,
) )
elif mirostat_mode == 2: elif mirostat_mode == 2:
@@ -525,7 +527,7 @@ class Llama:
candidates=self._candidates, candidates=self._candidates,
tau=mirostat_tau, tau=mirostat_tau,
eta=mirostat_eta, eta=mirostat_eta,
mu=2.0 * mirostat_tau, mu=ctypes.pointer(self._mirostat_mu)
) )
else: else:
self._ctx.sample_top_k(candidates=self._candidates, k=top_k, min_keep=1) self._ctx.sample_top_k(candidates=self._candidates, k=top_k, min_keep=1)
@@ -581,6 +583,10 @@ class Llama:
Yields: Yields:
The generated tokens. The generated tokens.
""" """
# Reset mirostat sampling
self._mirostat_mu = ctypes.c_float(2.0 * mirostat_tau)
# Check for kv cache prefix match
if reset and self.n_tokens > 0: if reset and self.n_tokens > 0:
longest_prefix = 0 longest_prefix = 0
for a, b in zip(self._input_ids, tokens[:-1]): for a, b in zip(self._input_ids, tokens[:-1]):
@@ -595,12 +601,15 @@ class Llama:
tokens = tokens[longest_prefix:] tokens = tokens[longest_prefix:]
self.n_tokens = longest_prefix self.n_tokens = longest_prefix
# Reset the model state
if reset: if reset:
self.reset() self.reset()
# Reset the grammar
if grammar is not None: if grammar is not None:
grammar.reset() grammar.reset()
# Eval and sample
while True: while True:
self.eval(tokens) self.eval(tokens)
token = self.sample( token = self.sample(