From 2c0d9b182cd417338a85396660d9828070b3373f Mon Sep 17 00:00:00 2001 From: Mug <2797716+SagsMug@users.noreply.github.com> Date: Mon, 8 May 2023 15:27:03 +0200 Subject: [PATCH] Fix session loading and saving in low level example chat --- .../low_level_api/low_level_api_chat_cpp.py | 28 +++++++++---------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/examples/low_level_api/low_level_api_chat_cpp.py b/examples/low_level_api/low_level_api_chat_cpp.py index 272b454..b86d723 100644 --- a/examples/low_level_api/low_level_api_chat_cpp.py +++ b/examples/low_level_api/low_level_api_chat_cpp.py @@ -112,16 +112,17 @@ specified) expect poor results""", file=sys.stderr) if (path.exists(self.params.path_session)): _session_tokens = (llama_cpp.llama_token * (self.params.n_ctx))() - _n_token_count_out = llama_cpp.c_int() + _n_token_count_out = llama_cpp.c_size_t() if (llama_cpp.llama_load_session_file( self.ctx, self.params.path_session.encode("utf8"), _session_tokens, self.params.n_ctx, ctypes.byref(_n_token_count_out) - ) != 0): + ) != 1): print(f"error: failed to load session file '{self.params.path_session}'", file=sys.stderr) return + _n_token_count_out = _n_token_count_out.value self.session_tokens = _session_tokens[:_n_token_count_out] print(f"loaded a session with prompt size of {_n_token_count_out} tokens", file=sys.stderr) else: @@ -135,19 +136,21 @@ specified) expect poor results""", file=sys.stderr) raise RuntimeError(f"error: prompt is too long ({len(self.embd_inp)} tokens, max {self.params.n_ctx - 4})") # debug message about similarity of saved session, if applicable - n_matching_session_tokens = 0 + self.n_matching_session_tokens = 0 if len(self.session_tokens) > 0: for id in self.session_tokens: - if n_matching_session_tokens >= len(self.embd_inp) or id != self.embd_inp[n_matching_session_tokens]: + if self.n_matching_session_tokens >= len(self.embd_inp) or id != self.embd_inp[self.n_matching_session_tokens]: break - n_matching_session_tokens += 1 + self.n_matching_session_tokens += 1 - if n_matching_session_tokens >= len(self.embd_inp): + if self.n_matching_session_tokens >= len(self.embd_inp): print(f"session file has exact match for prompt!") - elif n_matching_session_tokens < (len(self.embd_inp) / 2): - print(f"warning: session file has low similarity to prompt ({n_matching_session_tokens} / {len(self.embd_inp)} tokens); will mostly be reevaluated") + elif self.n_matching_session_tokens < (len(self.embd_inp) / 2): + print(f"warning: session file has low similarity to prompt ({self.n_matching_session_tokens} / {len(self.embd_inp)} tokens); will mostly be reevaluated") else: - print(f"session file matches {n_matching_session_tokens} / {len(self.embd_inp)} tokens of prompt") + print(f"session file matches {self.n_matching_session_tokens} / {len(self.embd_inp)} tokens of prompt") + + self.need_to_save_session = len(self.params.path_session) > 0 and self.n_matching_session_tokens < (len(self.embd_inp) * 3 / 4) # number of tokens to keep when resetting context if (self.params.n_keep < 0 or self.params.n_keep > len(self.embd_inp) or self.params.instruct): @@ -232,9 +235,6 @@ n_keep = {self.params.n_keep} """, file=sys.stderr) self.set_color(util.CONSOLE_COLOR_PROMPT) - self.need_to_save_session = len(self.params.path_session) > 0 and n_matching_session_tokens < (len(self.embd_inp) * 3 / 4) - - # tokenize a prompt def _tokenize(self, prompt, bos=True): _arr = (llama_cpp.llama_token * ((len(prompt) + 1) * 4))() @@ -302,7 +302,7 @@ n_keep = {self.params.n_keep} ) != 0): raise Exception("Failed to llama_eval!") - if len(self.embd) > 0 and not len(self.params.path_session) > 0: + if len(self.embd) > 0 and len(self.params.path_session) > 0: self.session_tokens.extend(self.embd) self.n_session_consumed = len(self.session_tokens) @@ -319,7 +319,7 @@ n_keep = {self.params.n_keep} llama_cpp.llama_save_session_file( self.ctx, self.params.path_session.encode("utf8"), - self.session_tokens, + (llama_cpp.llama_token * len(self.session_tokens))(*self.session_tokens), len(self.session_tokens) )