Andrei Betlen 2023-05-04 21:58:36 -04:00
parent 97c6372350
commit 853dc711cc

@@ -649,7 +649,10 @@ class Llama:
             self.detokenize([token]).decode("utf-8", errors="ignore")
             for token in all_tokens
         ]
-        all_logprobs = [Llama.logits_to_logprobs(list(map(float, row))) for row in self.eval_logits]
+        all_logprobs = [
+            Llama.logits_to_logprobs(list(map(float, row)))
+            for row in self.eval_logits
+        ]
         for token, token_str, logprobs_token in zip(
             all_tokens, all_token_strs, all_logprobs
         ):
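
For context, Llama.logits_to_logprobs converts one row of raw logits into the log-probabilities used for the logprobs output. A minimal sketch of an equivalent log-softmax, assuming a plain list of floats as input (not necessarily the library's exact implementation):

import math
from typing import List

def logits_to_logprobs(logits: List[float]) -> List[float]:
    # Shift by the max logit for numerical stability before exponentiating.
    max_logit = max(logits)
    exps = [math.exp(x - max_logit) for x in logits]
    total = sum(exps)
    # log(softmax(x)) for every entry in the row.
    return [math.log(e / total) for e in exps]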
@@ -968,7 +971,10 @@ class Llama:
         llama_state_compact = (llama_cpp.c_uint8 * int(n_bytes))()
         llama_cpp.ctypes.memmove(llama_state_compact, llama_state, int(n_bytes))
         if self.verbose:
-            print(f"Llama.save_state: saving {n_bytes} bytes of llama state", file=sys.stderr)
+            print(
+                f"Llama.save_state: saving {n_bytes} bytes of llama state",
+                file=sys.stderr,
+            )
         return LlamaState(
             eval_tokens=self.eval_tokens.copy(),
             eval_logits=self.eval_logits.copy(),
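
The verbose message above is printed when a caller snapshots the model's evaluation context. A minimal usage sketch of save_state together with load_state, assuming a hypothetical local model path:

from llama_cpp import Llama

llm = Llama(model_path="./models/7B/ggml-model.bin", verbose=True)  # hypothetical path
llm("Q: What is the capital of France? A:", max_tokens=8)

state = llm.save_state()  # logs "Llama.save_state: saving <n_bytes> bytes of llama state"
llm("Q: And of Spain? A:", max_tokens=8)
llm.load_state(state)     # restore the earlier eval_tokens / eval_logits snapshot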