Fix session loading and saving in low-level example chat

Mug 2023-05-08 15:27:03 +02:00
parent ed66a469c9
commit 2c0d9b182c


@@ -112,16 +112,17 @@ specified) expect poor results""", file=sys.stderr)
         if (path.exists(self.params.path_session)):
             _session_tokens = (llama_cpp.llama_token * (self.params.n_ctx))()
-            _n_token_count_out = llama_cpp.c_int()
+            _n_token_count_out = llama_cpp.c_size_t()
             if (llama_cpp.llama_load_session_file(
                 self.ctx,
                 self.params.path_session.encode("utf8"),
                 _session_tokens,
                 self.params.n_ctx,
                 ctypes.byref(_n_token_count_out)
-            ) != 0):
+            ) != 1):
                 print(f"error: failed to load session file '{self.params.path_session}'", file=sys.stderr)
                 return
+            _n_token_count_out = _n_token_count_out.value
             self.session_tokens = _session_tokens[:_n_token_count_out]
             print(f"loaded a session with prompt size of {_n_token_count_out} tokens", file=sys.stderr)
         else:
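This hunk is the standard ctypes out-parameter pattern: allocate a `c_size_t`, pass it with `ctypes.byref()`, and unwrap it with `.value` before using it as a plain integer. The diff fixes two bugs in that pattern: the call reports success as 1 (hence the corrected `!= 1` check), and the counter was never unwrapped, so slicing `_session_tokens` with a raw `c_size_t` would fail. A minimal self-contained sketch of the same pattern, with the C callee simulated in Python and all names illustrative rather than taken from llama_cpp:

    import ctypes

    # Stand-in for the C loader; a real foreign function would receive
    # ctypes.byref(n_out) and write through the pointer instead.
    def fake_load_session(buf, capacity, count_out):
        data = [101, 102, 103]            # pretend these came from a session file
        for i, tok in enumerate(data[:capacity]):
            buf[i] = tok
        count_out.value = len(data)       # what the C side does via *count_out
        return 1                          # 1 == success, matching the `!= 1` check

    tokens = (ctypes.c_int * 8)()         # fixed-capacity buffer, like _session_tokens
    n_out = ctypes.c_size_t()             # out-parameter, like _n_token_count_out
    assert fake_load_session(tokens, len(tokens), n_out) == 1
    n = n_out.value                       # unwrap to a plain Python int before slicing
    print(tokens[:n])                     # [101, 102, 103]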
@@ -135,19 +136,21 @@ specified) expect poor results""", file=sys.stderr)
             raise RuntimeError(f"error: prompt is too long ({len(self.embd_inp)} tokens, max {self.params.n_ctx - 4})")
 
         # debug message about similarity of saved session, if applicable
-        n_matching_session_tokens = 0
+        self.n_matching_session_tokens = 0
         if len(self.session_tokens) > 0:
             for id in self.session_tokens:
-                if n_matching_session_tokens >= len(self.embd_inp) or id != self.embd_inp[n_matching_session_tokens]:
+                if self.n_matching_session_tokens >= len(self.embd_inp) or id != self.embd_inp[self.n_matching_session_tokens]:
                     break
-                n_matching_session_tokens += 1
+                self.n_matching_session_tokens += 1
 
-            if n_matching_session_tokens >= len(self.embd_inp):
+            if self.n_matching_session_tokens >= len(self.embd_inp):
                 print(f"session file has exact match for prompt!")
-            elif n_matching_session_tokens < (len(self.embd_inp) / 2):
-                print(f"warning: session file has low similarity to prompt ({n_matching_session_tokens} / {len(self.embd_inp)} tokens); will mostly be reevaluated")
+            elif self.n_matching_session_tokens < (len(self.embd_inp) / 2):
+                print(f"warning: session file has low similarity to prompt ({self.n_matching_session_tokens} / {len(self.embd_inp)} tokens); will mostly be reevaluated")
             else:
-                print(f"session file matches {n_matching_session_tokens} / {len(self.embd_inp)} tokens of prompt")
+                print(f"session file matches {self.n_matching_session_tokens} / {len(self.embd_inp)} tokens of prompt")
+
+        self.need_to_save_session = len(self.params.path_session) > 0 and self.n_matching_session_tokens < (len(self.embd_inp) * 3 / 4)
 
         # number of tokens to keep when resetting context
         if (self.params.n_keep < 0 or self.params.n_keep > len(self.embd_inp) or self.params.instruct):
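This hunk promotes the match counter to an instance attribute and moves the `need_to_save_session` assignment up next to it (it was previously set much later, as the next hunk shows). The counting itself is a longest-common-prefix walk over token ids: count how many leading prompt tokens the saved session already covers. A standalone sketch over plain lists, reusing the diff's 3/4 re-save threshold; names are illustrative:

    # How many leading prompt tokens are already covered by the saved session.
    def count_matching_prefix(session_tokens, prompt_tokens):
        n = 0
        for tok in session_tokens:
            if n >= len(prompt_tokens) or tok != prompt_tokens[n]:
                break
            n += 1
        return n

    session = [1, 2, 3, 4, 9]
    prompt = [1, 2, 3, 5, 6, 7]
    n_match = count_matching_prefix(session, prompt)
    print(n_match)                                  # 3
    # Re-save only when the session covers less than 3/4 of the prompt:
    need_to_save = n_match < (len(prompt) * 3 / 4)
    print(need_to_save)                             # True (3 < 4.5)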
@@ -232,9 +235,6 @@ n_keep = {self.params.n_keep}
 """, file=sys.stderr)
         self.set_color(util.CONSOLE_COLOR_PROMPT)
 
-        self.need_to_save_session = len(self.params.path_session) > 0 and n_matching_session_tokens < (len(self.embd_inp) * 3 / 4)
-
     # tokenize a prompt
     def _tokenize(self, prompt, bos=True):
         _arr = (llama_cpp.llama_token * ((len(prompt) + 1) * 4))()
@@ -302,7 +302,7 @@ n_keep = {self.params.n_keep}
             ) != 0):
                 raise Exception("Failed to llama_eval!")
 
-            if len(self.embd) > 0 and not len(self.params.path_session) > 0:
+            if len(self.embd) > 0 and len(self.params.path_session) > 0:
                 self.session_tokens.extend(self.embd)
                 self.n_session_consumed = len(self.session_tokens)
@@ -319,7 +319,7 @@ n_keep = {self.params.n_keep}
                 llama_cpp.llama_save_session_file(
                     self.ctx,
                     self.params.path_session.encode("utf8"),
-                    self.session_tokens,
+                    (llama_cpp.llama_token * len(self.session_tokens))(*self.session_tokens),
                     len(self.session_tokens)
                 )
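The last hunk fixes the save call: ctypes cannot marshal a plain Python list as a C array argument, so the token list must first be converted with the `(ctype * n)(*items)` idiom. A small sketch of that conversion, with `llama_token` assumed to be `c_int` for illustration:

    import ctypes

    llama_token = ctypes.c_int            # assumed element type, for illustration

    session_tokens = [1, 15043, 3186]     # plain Python list of token ids
    arr = (llama_token * len(session_tokens))(*session_tokens)

    print(list(arr))                      # [1, 15043, 3186]
    # `arr` plus len(session_tokens) is what the save call now passes to C.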