Fix mirostat sampling
This commit is contained in:
parent
141293a75b
commit
3babe3512c
1 changed files with 11 additions and 2 deletions
|
@ -329,6 +329,8 @@ class Llama:
|
|||
(n_ctx, self._n_vocab), dtype=np.single
|
||||
)
|
||||
|
||||
self._mirostat_mu = ctypes.c_float(2.0 * 5.0) # TODO: Move this to sampling context
|
||||
|
||||
@property
|
||||
def ctx(self) -> llama_cpp.llama_context_p:
|
||||
assert self._ctx.ctx is not None
|
||||
|
@ -516,7 +518,7 @@ class Llama:
|
|||
candidates=self._candidates,
|
||||
tau=mirostat_tau,
|
||||
eta=mirostat_eta,
|
||||
mu=2.0 * mirostat_tau,
|
||||
mu=ctypes.pointer(self._mirostat_mu),
|
||||
m=100,
|
||||
)
|
||||
elif mirostat_mode == 2:
|
||||
|
@ -525,7 +527,7 @@ class Llama:
|
|||
candidates=self._candidates,
|
||||
tau=mirostat_tau,
|
||||
eta=mirostat_eta,
|
||||
mu=2.0 * mirostat_tau,
|
||||
mu=ctypes.pointer(self._mirostat_mu)
|
||||
)
|
||||
else:
|
||||
self._ctx.sample_top_k(candidates=self._candidates, k=top_k, min_keep=1)
|
||||
|
@ -581,6 +583,10 @@ class Llama:
|
|||
Yields:
|
||||
The generated tokens.
|
||||
"""
|
||||
# Reset mirostat sampling
|
||||
self._mirostat_mu = ctypes.c_float(2.0 * mirostat_tau)
|
||||
|
||||
# Check for kv cache prefix match
|
||||
if reset and self.n_tokens > 0:
|
||||
longest_prefix = 0
|
||||
for a, b in zip(self._input_ids, tokens[:-1]):
|
||||
|
@ -595,12 +601,15 @@ class Llama:
|
|||
tokens = tokens[longest_prefix:]
|
||||
self.n_tokens = longest_prefix
|
||||
|
||||
# Reset the model state
|
||||
if reset:
|
||||
self.reset()
|
||||
|
||||
# Reset the grammar
|
||||
if grammar is not None:
|
||||
grammar.reset()
|
||||
|
||||
# Eval and sample
|
||||
while True:
|
||||
self.eval(tokens)
|
||||
token = self.sample(
|
||||
|
|
Loading…
Add table
Reference in a new issue