Fix mirastat requiring c_float
This commit is contained in:
parent
9797394c81
commit
3ceb47b597
1 changed files with 12 additions and 12 deletions
|
@ -357,7 +357,7 @@ n_keep = {self.params.n_keep}
|
|||
|
||||
# Apply params.logit_bias map
|
||||
for key, value in self.params.logit_bias.items():
|
||||
logits[key] += value
|
||||
logits[key] += llama_cpp.c_float(value)
|
||||
|
||||
_arr = (llama_cpp.llama_token_data * n_vocab)(*[
|
||||
llama_cpp.llama_token_data(token_id, logits[token_id], 0.0)
|
||||
|
@ -372,10 +372,10 @@ n_keep = {self.params.n_keep}
|
|||
_arr = (llama_cpp.llama_token * last_n_repeat)(*self.last_n_tokens[len(self.last_n_tokens) - last_n_repeat:])
|
||||
llama_cpp.llama_sample_repetition_penalty(self.ctx, candidates_p,
|
||||
_arr,
|
||||
last_n_repeat, self.params.repeat_penalty)
|
||||
last_n_repeat, llama_cpp.c_float(self.params.repeat_penalty))
|
||||
llama_cpp.llama_sample_frequency_and_presence_penalties(self.ctx, candidates_p,
|
||||
_arr,
|
||||
last_n_repeat, self.params.frequency_penalty, self.params.presence_penalty)
|
||||
last_n_repeat, llama_cpp.c_float(self.params.frequency_penalty), llama_cpp.c_float(self.params.presence_penalty))
|
||||
|
||||
if not self.params.penalize_nl:
|
||||
logits[llama_cpp.llama_token_nl()] = nl_logit
|
||||
|
@ -387,19 +387,19 @@ n_keep = {self.params.n_keep}
|
|||
if self.params.mirostat == 1:
|
||||
mirostat_mu = 2.0 * self.params.mirostat_tau
|
||||
mirostat_m = 100
|
||||
llama_cpp.llama_sample_temperature(self.ctx, candidates_p, self.params.temp)
|
||||
id = llama_cpp.llama_sample_token_mirostat(self.ctx, candidates_p, self.params.mirostat_tau, self.params.mirostat_eta, mirostat_m, mirostat_mu)
|
||||
llama_cpp.llama_sample_temperature(self.ctx, candidates_p, llama_cpp.c_float(self.params.temp))
|
||||
id = llama_cpp.llama_sample_token_mirostat(self.ctx, candidates_p, llama_cpp.c_float(self.params.mirostat_tau), llama_cpp.c_float(self.params.mirostat_eta), llama_cpp.c_int(mirostat_m), llama_cpp.c_float(mirostat_mu))
|
||||
elif self.params.mirostat == 2:
|
||||
mirostat_mu = 2.0 * self.params.mirostat_tau
|
||||
llama_cpp.llama_sample_temperature(self.ctx, candidates_p, self.params.temp)
|
||||
id = llama_cpp.llama_sample_token_mirostat_v2(self.ctx, candidates_p, self.params.mirostat_tau, self.params.mirostat_eta, mirostat_mu)
|
||||
llama_cpp.llama_sample_temperature(self.ctx, candidates_p, llama_cpp.c_float(self.params.temp))
|
||||
id = llama_cpp.llama_sample_token_mirostat_v2(self.ctx, candidates_p, llama_cpp.c_float(self.params.mirostat_tau), llama_cpp.c_float(self.params.mirostat_eta), llama_cpp.c_float(mirostat_mu))
|
||||
else:
|
||||
# Temperature sampling
|
||||
llama_cpp.llama_sample_top_k(self.ctx, candidates_p, top_k)
|
||||
llama_cpp.llama_sample_tail_free(self.ctx, candidates_p, self.params.tfs_z)
|
||||
llama_cpp.llama_sample_typical(self.ctx, candidates_p, self.params.typical_p)
|
||||
llama_cpp.llama_sample_top_p(self.ctx, candidates_p, self.params.top_p)
|
||||
llama_cpp.llama_sample_temperature(self.ctx, candidates_p, self.params.temp)
|
||||
llama_cpp.llama_sample_tail_free(self.ctx, candidates_p, llama_cpp.c_float(self.params.tfs_z))
|
||||
llama_cpp.llama_sample_typical(self.ctx, candidates_p, llama_cpp.c_float(self.params.typical_p))
|
||||
llama_cpp.llama_sample_top_p(self.ctx, candidates_p, llama_cpp.c_float(self.params.top_p))
|
||||
llama_cpp.llama_sample_temperature(self.ctx, candidates_p, llama_cpp.c_float(self.params.temp))
|
||||
id = llama_cpp.llama_sample_token(self.ctx, candidates_p)
|
||||
# print("`{}`".format(candidates_p.size))
|
||||
|
||||
|
|
Loading…
Reference in a new issue