From 3ceb47b597a8819db3afa851df4ae3211f2cb680 Mon Sep 17 00:00:00 2001 From: Mug <2797716+SagsMug@users.noreply.github.com> Date: Sat, 6 May 2023 13:35:50 +0200 Subject: [PATCH] Fix mirastat requiring c_float --- .../low_level_api/low_level_api_chat_cpp.py | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/examples/low_level_api/low_level_api_chat_cpp.py b/examples/low_level_api/low_level_api_chat_cpp.py index 72ced2b..55b24cd 100644 --- a/examples/low_level_api/low_level_api_chat_cpp.py +++ b/examples/low_level_api/low_level_api_chat_cpp.py @@ -357,7 +357,7 @@ n_keep = {self.params.n_keep} # Apply params.logit_bias map for key, value in self.params.logit_bias.items(): - logits[key] += value + logits[key] += llama_cpp.c_float(value) _arr = (llama_cpp.llama_token_data * n_vocab)(*[ llama_cpp.llama_token_data(token_id, logits[token_id], 0.0) @@ -372,14 +372,14 @@ n_keep = {self.params.n_keep} _arr = (llama_cpp.llama_token * last_n_repeat)(*self.last_n_tokens[len(self.last_n_tokens) - last_n_repeat:]) llama_cpp.llama_sample_repetition_penalty(self.ctx, candidates_p, _arr, - last_n_repeat, self.params.repeat_penalty) + last_n_repeat, llama_cpp.c_float(self.params.repeat_penalty)) llama_cpp.llama_sample_frequency_and_presence_penalties(self.ctx, candidates_p, _arr, - last_n_repeat, self.params.frequency_penalty, self.params.presence_penalty) + last_n_repeat, llama_cpp.c_float(self.params.frequency_penalty), llama_cpp.c_float(self.params.presence_penalty)) if not self.params.penalize_nl: logits[llama_cpp.llama_token_nl()] = nl_logit - + if self.params.temp <= 0: # Greedy sampling id = llama_cpp.llama_sample_token_greedy(self.ctx, candidates_p) @@ -387,19 +387,19 @@ n_keep = {self.params.n_keep} if self.params.mirostat == 1: mirostat_mu = 2.0 * self.params.mirostat_tau mirostat_m = 100 - llama_cpp.llama_sample_temperature(self.ctx, candidates_p, self.params.temp) - id = llama_cpp.llama_sample_token_mirostat(self.ctx, candidates_p, self.params.mirostat_tau, self.params.mirostat_eta, mirostat_m, mirostat_mu) + llama_cpp.llama_sample_temperature(self.ctx, candidates_p, llama_cpp.c_float(self.params.temp)) + id = llama_cpp.llama_sample_token_mirostat(self.ctx, candidates_p, llama_cpp.c_float(self.params.mirostat_tau), llama_cpp.c_float(self.params.mirostat_eta), llama_cpp.c_int(mirostat_m), llama_cpp.c_float(mirostat_mu)) elif self.params.mirostat == 2: mirostat_mu = 2.0 * self.params.mirostat_tau - llama_cpp.llama_sample_temperature(self.ctx, candidates_p, self.params.temp) - id = llama_cpp.llama_sample_token_mirostat_v2(self.ctx, candidates_p, self.params.mirostat_tau, self.params.mirostat_eta, mirostat_mu) + llama_cpp.llama_sample_temperature(self.ctx, candidates_p, llama_cpp.c_float(self.params.temp)) + id = llama_cpp.llama_sample_token_mirostat_v2(self.ctx, candidates_p, llama_cpp.c_float(self.params.mirostat_tau), llama_cpp.c_float(self.params.mirostat_eta), llama_cpp.c_float(mirostat_mu)) else: # Temperature sampling llama_cpp.llama_sample_top_k(self.ctx, candidates_p, top_k) - llama_cpp.llama_sample_tail_free(self.ctx, candidates_p, self.params.tfs_z) - llama_cpp.llama_sample_typical(self.ctx, candidates_p, self.params.typical_p) - llama_cpp.llama_sample_top_p(self.ctx, candidates_p, self.params.top_p) - llama_cpp.llama_sample_temperature(self.ctx, candidates_p, self.params.temp) + llama_cpp.llama_sample_tail_free(self.ctx, candidates_p, llama_cpp.c_float(self.params.tfs_z)) + llama_cpp.llama_sample_typical(self.ctx, candidates_p, llama_cpp.c_float(self.params.typical_p)) + llama_cpp.llama_sample_top_p(self.ctx, candidates_p, llama_cpp.c_float(self.params.top_p)) + llama_cpp.llama_sample_temperature(self.ctx, candidates_p, llama_cpp.c_float(self.params.temp)) id = llama_cpp.llama_sample_token(self.ctx, candidates_p) # print("`{}`".format(candidates_p.size))