Fix mirostat sampling
This commit is contained in:
parent
141293a75b
commit
3babe3512c
1 changed file with 11 additions and 2 deletions
|
@@ -329,6 +329,8 @@ class Llama:
|
||||||
(n_ctx, self._n_vocab), dtype=np.single
|
(n_ctx, self._n_vocab), dtype=np.single
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self._mirostat_mu = ctypes.c_float(2.0 * 5.0) # TODO: Move this to sampling context
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def ctx(self) -> llama_cpp.llama_context_p:
|
def ctx(self) -> llama_cpp.llama_context_p:
|
||||||
assert self._ctx.ctx is not None
|
assert self._ctx.ctx is not None
|
||||||
|
@@ -516,7 +518,7 @@ class Llama:
|
||||||
candidates=self._candidates,
|
candidates=self._candidates,
|
||||||
tau=mirostat_tau,
|
tau=mirostat_tau,
|
||||||
eta=mirostat_eta,
|
eta=mirostat_eta,
|
||||||
mu=2.0 * mirostat_tau,
|
mu=ctypes.pointer(self._mirostat_mu),
|
||||||
m=100,
|
m=100,
|
||||||
)
|
)
|
||||||
elif mirostat_mode == 2:
|
elif mirostat_mode == 2:
|
||||||
|
@@ -525,7 +527,7 @@ class Llama:
|
||||||
candidates=self._candidates,
|
candidates=self._candidates,
|
||||||
tau=mirostat_tau,
|
tau=mirostat_tau,
|
||||||
eta=mirostat_eta,
|
eta=mirostat_eta,
|
||||||
mu=2.0 * mirostat_tau,
|
mu=ctypes.pointer(self._mirostat_mu)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self._ctx.sample_top_k(candidates=self._candidates, k=top_k, min_keep=1)
|
self._ctx.sample_top_k(candidates=self._candidates, k=top_k, min_keep=1)
|
||||||
|
@@ -581,6 +583,10 @@ class Llama:
|
||||||
Yields:
|
Yields:
|
||||||
The generated tokens.
|
The generated tokens.
|
||||||
"""
|
"""
|
||||||
|
# Reset mirostat sampling
|
||||||
|
self._mirostat_mu = ctypes.c_float(2.0 * mirostat_tau)
|
||||||
|
|
||||||
|
# Check for kv cache prefix match
|
||||||
if reset and self.n_tokens > 0:
|
if reset and self.n_tokens > 0:
|
||||||
longest_prefix = 0
|
longest_prefix = 0
|
||||||
for a, b in zip(self._input_ids, tokens[:-1]):
|
for a, b in zip(self._input_ids, tokens[:-1]):
|
||||||
|
@@ -595,12 +601,15 @@ class Llama:
|
||||||
tokens = tokens[longest_prefix:]
|
tokens = tokens[longest_prefix:]
|
||||||
self.n_tokens = longest_prefix
|
self.n_tokens = longest_prefix
|
||||||
|
|
||||||
|
# Reset the model state
|
||||||
if reset:
|
if reset:
|
||||||
self.reset()
|
self.reset()
|
||||||
|
|
||||||
|
# Reset the grammar
|
||||||
if grammar is not None:
|
if grammar is not None:
|
||||||
grammar.reset()
|
grammar.reset()
|
||||||
|
|
||||||
|
# Eval and sample
|
||||||
while True:
|
while True:
|
||||||
self.eval(tokens)
|
self.eval(tokens)
|
||||||
token = self.sample(
|
token = self.sample(
|
||||||
|
|
Loading…
Reference in a new issue