diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index 7a8c25b..32d5424 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -649,7 +649,10 @@ class Llama:
             self.detokenize([token]).decode("utf-8", errors="ignore")
             for token in all_tokens
         ]
-        all_logprobs = [Llama.logits_to_logprobs(list(map(float, row))) for row in self.eval_logits]
+        all_logprobs = [
+            Llama.logits_to_logprobs(list(map(float, row)))
+            for row in self.eval_logits
+        ]
         for token, token_str, logprobs_token in zip(
             all_tokens, all_token_strs, all_logprobs
         ):
@@ -968,7 +971,10 @@ class Llama:
         llama_state_compact = (llama_cpp.c_uint8 * int(n_bytes))()
         llama_cpp.ctypes.memmove(llama_state_compact, llama_state, int(n_bytes))
         if self.verbose:
-            print(f"Llama.save_state: saving {n_bytes} bytes of llama state", file=sys.stderr)
+            print(
+                f"Llama.save_state: saving {n_bytes} bytes of llama state",
+                file=sys.stderr,
+            )
         return LlamaState(
             eval_tokens=self.eval_tokens.copy(),
             eval_logits=self.eval_logits.copy(),
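
For context, the comprehension reformatted in the first hunk calls `Llama.logits_to_logprobs` once per row of `self.eval_logits`. A minimal standalone sketch of that conversion, assuming it implements the usual numerically stable log-softmax (the real static method lives on the `Llama` class and may differ in detail; `eval_logits` below is an illustrative stand-in):

import math
from typing import List

def logits_to_logprobs(logits: List[float]) -> List[float]:
    # Subtract the max logit before exponentiating so exp() cannot overflow.
    max_logit = max(logits)
    shifted = [x - max_logit for x in logits]
    # log-softmax: x_i - log(sum_j exp(x_j)), computed on the shifted values.
    log_sum_exp = math.log(sum(math.exp(x) for x in shifted))
    return [x - log_sum_exp for x in shifted]

# Mirrors the diff's usage: one list of logprobs per evaluated row.
eval_logits = [[0.5, 1.5, -0.25]]  # stand-in for self.eval_logits
all_logprobs = [logits_to_logprobs(list(map(float, row))) for row in eval_logits]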