From 4924455decd79273c8c695a8ff796306ac0df30d Mon Sep 17 00:00:00 2001
From: tc-wolf <50339167+tc-wolf@users.noreply.github.com>
Date: Wed, 17 Apr 2024 09:06:50 -0500
Subject: [PATCH] feat: Make saved state more compact on-disk (#1296)

* State load/save changes

- Only store up to `n_tokens` logits instead of the full
  `(n_ctx, n_vocab)`-sized array.
- Reduces the saved state for an example prompt with ~300 tokens from
  ~1500MB to ~350MB.
- Auto-formatting changes

* Back out formatting changes
---
 llama_cpp/llama.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index dfac9bb..5a0111b 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -18,6 +18,7 @@ from typing import (
     Iterator,
     Deque,
     Callable,
+    Dict,
 )
 from collections import deque
 from pathlib import Path
@@ -1791,7 +1792,7 @@ class Llama:
                 file=sys.stderr,
             )
         return LlamaState(
-            scores=self.scores.copy(),
+            scores=self._scores.copy(),
             input_ids=self.input_ids.copy(),
             n_tokens=self.n_tokens,
             llama_state=bytes(llama_state_compact),
@@ -1800,7 +1801,9 @@ class Llama:
 
     def load_state(self, state: LlamaState) -> None:
         assert self._ctx.ctx is not None
-        self.scores = state.scores.copy()
+        # Only filling in up to `n_tokens` and then zero-ing out the rest
+        self.scores[: state.n_tokens, :] = state.scores.copy()
+        self.scores[state.n_tokens :, :] = 0.0
         self.input_ids = state.input_ids.copy()
         self.n_tokens = state.n_tokens
         state_size = state.llama_state_size
@@ -1951,7 +1954,6 @@ class Llama:
                 local_dir_use_symlinks=local_dir_use_symlinks,
                 cache_dir=cache_dir,
                 local_files_only=True,
-
             )
         else:
             model_path = os.path.join(local_dir, filename)
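
Illustration (not part of the patch): a minimal NumPy sketch of the idea behind the change, under the assumption of stand-in shapes. Only the first `n_tokens` rows of the `(n_ctx, n_vocab)` scores buffer are persisted on save, and the unused rows are zero-filled on load. The names `scores`, `saved_scores`, and `restored` are hypothetical; the real logic lives in `Llama.save_state` / `Llama.load_state` above.

```python
import numpy as np

# Stand-in values for illustration; the real ones come from the model/context.
n_ctx, n_vocab, n_tokens = 8, 5, 3

# Full logits buffer as it exists in memory: one row per context slot.
scores = np.random.rand(n_ctx, n_vocab).astype(np.single)

# Save: persist only the rows that were actually evaluated.
saved_scores = scores[:n_tokens, :].copy()  # shape (n_tokens, n_vocab)

# Load: restore those rows and zero-fill the remainder of the buffer.
restored = np.empty((n_ctx, n_vocab), dtype=np.single)
restored[:n_tokens, :] = saved_scores
restored[n_tokens:, :] = 0.0

assert np.allclose(restored[:n_tokens], scores[:n_tokens])
print(f"in-memory buffer: {scores.nbytes} bytes, saved: {saved_scores.nbytes} bytes")
```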