Fix session loading and saving in low-level example chat
This commit is contained in:
parent
ed66a469c9
commit
2c0d9b182c
1 changed file with 14 additions and 14 deletions
|
@ -112,16 +112,17 @@ specified) expect poor results""", file=sys.stderr)
|
||||||
|
|
||||||
if (path.exists(self.params.path_session)):
|
if (path.exists(self.params.path_session)):
|
||||||
_session_tokens = (llama_cpp.llama_token * (self.params.n_ctx))()
|
_session_tokens = (llama_cpp.llama_token * (self.params.n_ctx))()
|
||||||
_n_token_count_out = llama_cpp.c_int()
|
_n_token_count_out = llama_cpp.c_size_t()
|
||||||
if (llama_cpp.llama_load_session_file(
|
if (llama_cpp.llama_load_session_file(
|
||||||
self.ctx,
|
self.ctx,
|
||||||
self.params.path_session.encode("utf8"),
|
self.params.path_session.encode("utf8"),
|
||||||
_session_tokens,
|
_session_tokens,
|
||||||
self.params.n_ctx,
|
self.params.n_ctx,
|
||||||
ctypes.byref(_n_token_count_out)
|
ctypes.byref(_n_token_count_out)
|
||||||
) != 0):
|
) != 1):
|
||||||
print(f"error: failed to load session file '{self.params.path_session}'", file=sys.stderr)
|
print(f"error: failed to load session file '{self.params.path_session}'", file=sys.stderr)
|
||||||
return
|
return
|
||||||
|
_n_token_count_out = _n_token_count_out.value
|
||||||
self.session_tokens = _session_tokens[:_n_token_count_out]
|
self.session_tokens = _session_tokens[:_n_token_count_out]
|
||||||
print(f"loaded a session with prompt size of {_n_token_count_out} tokens", file=sys.stderr)
|
print(f"loaded a session with prompt size of {_n_token_count_out} tokens", file=sys.stderr)
|
||||||
else:
|
else:
|
||||||
|
@ -135,19 +136,21 @@ specified) expect poor results""", file=sys.stderr)
|
||||||
raise RuntimeError(f"error: prompt is too long ({len(self.embd_inp)} tokens, max {self.params.n_ctx - 4})")
|
raise RuntimeError(f"error: prompt is too long ({len(self.embd_inp)} tokens, max {self.params.n_ctx - 4})")
|
||||||
|
|
||||||
# debug message about similarity of saved session, if applicable
|
# debug message about similarity of saved session, if applicable
|
||||||
n_matching_session_tokens = 0
|
self.n_matching_session_tokens = 0
|
||||||
if len(self.session_tokens) > 0:
|
if len(self.session_tokens) > 0:
|
||||||
for id in self.session_tokens:
|
for id in self.session_tokens:
|
||||||
if n_matching_session_tokens >= len(self.embd_inp) or id != self.embd_inp[n_matching_session_tokens]:
|
if self.n_matching_session_tokens >= len(self.embd_inp) or id != self.embd_inp[self.n_matching_session_tokens]:
|
||||||
break
|
break
|
||||||
n_matching_session_tokens += 1
|
self.n_matching_session_tokens += 1
|
||||||
|
|
||||||
if n_matching_session_tokens >= len(self.embd_inp):
|
if self.n_matching_session_tokens >= len(self.embd_inp):
|
||||||
print(f"session file has exact match for prompt!")
|
print(f"session file has exact match for prompt!")
|
||||||
elif n_matching_session_tokens < (len(self.embd_inp) / 2):
|
elif self.n_matching_session_tokens < (len(self.embd_inp) / 2):
|
||||||
print(f"warning: session file has low similarity to prompt ({n_matching_session_tokens} / {len(self.embd_inp)} tokens); will mostly be reevaluated")
|
print(f"warning: session file has low similarity to prompt ({self.n_matching_session_tokens} / {len(self.embd_inp)} tokens); will mostly be reevaluated")
|
||||||
else:
|
else:
|
||||||
print(f"session file matches {n_matching_session_tokens} / {len(self.embd_inp)} tokens of prompt")
|
print(f"session file matches {self.n_matching_session_tokens} / {len(self.embd_inp)} tokens of prompt")
|
||||||
|
|
||||||
|
self.need_to_save_session = len(self.params.path_session) > 0 and self.n_matching_session_tokens < (len(self.embd_inp) * 3 / 4)
|
||||||
|
|
||||||
# number of tokens to keep when resetting context
|
# number of tokens to keep when resetting context
|
||||||
if (self.params.n_keep < 0 or self.params.n_keep > len(self.embd_inp) or self.params.instruct):
|
if (self.params.n_keep < 0 or self.params.n_keep > len(self.embd_inp) or self.params.instruct):
|
||||||
|
@ -232,9 +235,6 @@ n_keep = {self.params.n_keep}
|
||||||
""", file=sys.stderr)
|
""", file=sys.stderr)
|
||||||
self.set_color(util.CONSOLE_COLOR_PROMPT)
|
self.set_color(util.CONSOLE_COLOR_PROMPT)
|
||||||
|
|
||||||
self.need_to_save_session = len(self.params.path_session) > 0 and n_matching_session_tokens < (len(self.embd_inp) * 3 / 4)
|
|
||||||
|
|
||||||
|
|
||||||
# tokenize a prompt
|
# tokenize a prompt
|
||||||
def _tokenize(self, prompt, bos=True):
|
def _tokenize(self, prompt, bos=True):
|
||||||
_arr = (llama_cpp.llama_token * ((len(prompt) + 1) * 4))()
|
_arr = (llama_cpp.llama_token * ((len(prompt) + 1) * 4))()
|
||||||
|
@ -302,7 +302,7 @@ n_keep = {self.params.n_keep}
|
||||||
) != 0):
|
) != 0):
|
||||||
raise Exception("Failed to llama_eval!")
|
raise Exception("Failed to llama_eval!")
|
||||||
|
|
||||||
if len(self.embd) > 0 and not len(self.params.path_session) > 0:
|
if len(self.embd) > 0 and len(self.params.path_session) > 0:
|
||||||
self.session_tokens.extend(self.embd)
|
self.session_tokens.extend(self.embd)
|
||||||
self.n_session_consumed = len(self.session_tokens)
|
self.n_session_consumed = len(self.session_tokens)
|
||||||
|
|
||||||
|
@ -319,7 +319,7 @@ n_keep = {self.params.n_keep}
|
||||||
llama_cpp.llama_save_session_file(
|
llama_cpp.llama_save_session_file(
|
||||||
self.ctx,
|
self.ctx,
|
||||||
self.params.path_session.encode("utf8"),
|
self.params.path_session.encode("utf8"),
|
||||||
self.session_tokens,
|
(llama_cpp.llama_token * len(self.session_tokens))(*self.session_tokens),
|
||||||
len(self.session_tokens)
|
len(self.session_tokens)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue