Fix session loading and saving in low level example chat
This commit is contained in:
parent
ed66a469c9
commit
2c0d9b182c
1 changed files with 14 additions and 14 deletions
|
@ -112,16 +112,17 @@ specified) expect poor results""", file=sys.stderr)
|
|||
|
||||
if (path.exists(self.params.path_session)):
|
||||
_session_tokens = (llama_cpp.llama_token * (self.params.n_ctx))()
|
||||
_n_token_count_out = llama_cpp.c_int()
|
||||
_n_token_count_out = llama_cpp.c_size_t()
|
||||
if (llama_cpp.llama_load_session_file(
|
||||
self.ctx,
|
||||
self.params.path_session.encode("utf8"),
|
||||
_session_tokens,
|
||||
self.params.n_ctx,
|
||||
ctypes.byref(_n_token_count_out)
|
||||
) != 0):
|
||||
) != 1):
|
||||
print(f"error: failed to load session file '{self.params.path_session}'", file=sys.stderr)
|
||||
return
|
||||
_n_token_count_out = _n_token_count_out.value
|
||||
self.session_tokens = _session_tokens[:_n_token_count_out]
|
||||
print(f"loaded a session with prompt size of {_n_token_count_out} tokens", file=sys.stderr)
|
||||
else:
|
||||
|
@ -135,19 +136,21 @@ specified) expect poor results""", file=sys.stderr)
|
|||
raise RuntimeError(f"error: prompt is too long ({len(self.embd_inp)} tokens, max {self.params.n_ctx - 4})")
|
||||
|
||||
# debug message about similarity of saved session, if applicable
|
||||
n_matching_session_tokens = 0
|
||||
self.n_matching_session_tokens = 0
|
||||
if len(self.session_tokens) > 0:
|
||||
for id in self.session_tokens:
|
||||
if n_matching_session_tokens >= len(self.embd_inp) or id != self.embd_inp[n_matching_session_tokens]:
|
||||
if self.n_matching_session_tokens >= len(self.embd_inp) or id != self.embd_inp[self.n_matching_session_tokens]:
|
||||
break
|
||||
n_matching_session_tokens += 1
|
||||
self.n_matching_session_tokens += 1
|
||||
|
||||
if n_matching_session_tokens >= len(self.embd_inp):
|
||||
if self.n_matching_session_tokens >= len(self.embd_inp):
|
||||
print(f"session file has exact match for prompt!")
|
||||
elif n_matching_session_tokens < (len(self.embd_inp) / 2):
|
||||
print(f"warning: session file has low similarity to prompt ({n_matching_session_tokens} / {len(self.embd_inp)} tokens); will mostly be reevaluated")
|
||||
elif self.n_matching_session_tokens < (len(self.embd_inp) / 2):
|
||||
print(f"warning: session file has low similarity to prompt ({self.n_matching_session_tokens} / {len(self.embd_inp)} tokens); will mostly be reevaluated")
|
||||
else:
|
||||
print(f"session file matches {n_matching_session_tokens} / {len(self.embd_inp)} tokens of prompt")
|
||||
print(f"session file matches {self.n_matching_session_tokens} / {len(self.embd_inp)} tokens of prompt")
|
||||
|
||||
self.need_to_save_session = len(self.params.path_session) > 0 and self.n_matching_session_tokens < (len(self.embd_inp) * 3 / 4)
|
||||
|
||||
# number of tokens to keep when resetting context
|
||||
if (self.params.n_keep < 0 or self.params.n_keep > len(self.embd_inp) or self.params.instruct):
|
||||
|
@ -232,9 +235,6 @@ n_keep = {self.params.n_keep}
|
|||
""", file=sys.stderr)
|
||||
self.set_color(util.CONSOLE_COLOR_PROMPT)
|
||||
|
||||
self.need_to_save_session = len(self.params.path_session) > 0 and n_matching_session_tokens < (len(self.embd_inp) * 3 / 4)
|
||||
|
||||
|
||||
# tokenize a prompt
|
||||
def _tokenize(self, prompt, bos=True):
|
||||
_arr = (llama_cpp.llama_token * ((len(prompt) + 1) * 4))()
|
||||
|
@ -302,7 +302,7 @@ n_keep = {self.params.n_keep}
|
|||
) != 0):
|
||||
raise Exception("Failed to llama_eval!")
|
||||
|
||||
if len(self.embd) > 0 and not len(self.params.path_session) > 0:
|
||||
if len(self.embd) > 0 and len(self.params.path_session) > 0:
|
||||
self.session_tokens.extend(self.embd)
|
||||
self.n_session_consumed = len(self.session_tokens)
|
||||
|
||||
|
@ -319,7 +319,7 @@ n_keep = {self.params.n_keep}
|
|||
llama_cpp.llama_save_session_file(
|
||||
self.ctx,
|
||||
self.params.path_session.encode("utf8"),
|
||||
self.session_tokens,
|
||||
(llama_cpp.llama_token * len(self.session_tokens))(*self.session_tokens),
|
||||
len(self.session_tokens)
|
||||
)
|
||||
|
||||
|
|
Loading…
Reference in a new issue