feat: Make saved state more compact on-disk (#1296)

* State load/save changes

- Only store up to `n_tokens` logits instead of the full `(n_ctx, n_vocab)`
  sized array.
  - Cuts an example saved state for a prompt of ~300 tokens from ~1500MB
    down to ~350MB (see the size sketch below).
- Auto-formatting changes

* Back out formatting changes
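For a sense of where the savings come from, here is a rough sketch of the size arithmetic. The `n_ctx`/`n_vocab` values are illustrative assumptions only; the commit does not say which model produced the numbers quoted above:

```python
# Back-of-the-envelope sizing for the logits block of the saved state.
# ASSUMED dimensions, for illustration only (not stated in the commit):
n_ctx = 8192      # rows allocated in the scores buffer (context window)
n_vocab = 32000   # columns (vocabulary size)
n_tokens = 300    # tokens actually evaluated in the example prompt

bytes_per_float = 4  # scores are float32

full_mb = n_ctx * n_vocab * bytes_per_float / 1e6       # old: (n_ctx, n_vocab)
compact_mb = n_tokens * n_vocab * bytes_per_float / 1e6  # new: (n_tokens, n_vocab)

print(f"full logits block:    ~{full_mb:.0f} MB")    # ~1049 MB
print(f"compact logits block: ~{compact_mb:.0f} MB")  # ~38 MB
```

The rest of the saved state (the llama.cpp state blob, i.e. KV cache and friends) is saved as before, which is why the on-disk totals in the bullet above are larger than these logits-only figures.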
commit 4924455dec (parent 9842cbf99d)
Author: tc-wolf
Date: 2024-04-17 09:06:50 -05:00
Committed by: GitHub

llama_cpp/llama.py

@@ -18,6 +18,7 @@ from typing import (
     Iterator,
     Deque,
     Callable,
+    Dict,
 )
 from collections import deque
 from pathlib import Path
@@ -1791,7 +1792,7 @@ class Llama:
                 file=sys.stderr,
             )
         return LlamaState(
-            scores=self.scores.copy(),
+            scores=self._scores.copy(),
             input_ids=self.input_ids.copy(),
             n_tokens=self.n_tokens,
             llama_state=bytes(llama_state_compact),
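A note on the hunk above: `self._scores` is presumably the existing property on `Llama` that views only the rows produced so far, along the lines of this sketch (assumed from context, not part of this diff):

```python
@property
def _scores(self) -> npt.NDArray[np.single]:
    # First n_tokens rows of the (n_ctx, n_vocab) scores buffer,
    # i.e. only the logits that have actually been computed.
    return self.scores[: self.n_tokens, :]
```

Copying `_scores` instead of `scores` is the whole trick: the pickled array shrinks from `(n_ctx, n_vocab)` to `(n_tokens, n_vocab)`.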
@@ -1800,7 +1801,9 @@ class Llama:
     def load_state(self, state: LlamaState) -> None:
         assert self._ctx.ctx is not None
-        self.scores = state.scores.copy()
+        # Only filling in up to `n_tokens` and then zero-ing out the rest
+        self.scores[: state.n_tokens, :] = state.scores.copy()
+        self.scores[state.n_tokens :, :] = 0.0
         self.input_ids = state.input_ids.copy()
         self.n_tokens = state.n_tokens
         state_size = state.llama_state_size
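On load, the saved rows are written back and everything past `n_tokens` is zeroed, so stale logits from whatever the buffer held previously cannot survive a restore. A minimal standalone numpy sketch of the same fill-and-zero pattern, with toy dimensions:

```python
import numpy as np

# Toy stand-ins for (n_ctx, n_vocab) and a saved n_tokens.
n_ctx, n_vocab, n_tokens = 8, 4, 3

scores = np.full((n_ctx, n_vocab), 7.0, dtype=np.single)  # pretend stale data
saved = np.ones((n_tokens, n_vocab), dtype=np.single)     # the compact state

scores[:n_tokens, :] = saved  # restore only the rows that were saved...
scores[n_tokens:, :] = 0.0    # ...and clear the rest of the buffer

assert (scores[:n_tokens] == saved).all()
assert (scores[n_tokens:] == 0.0).all()
```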
@@ -1951,7 +1954,6 @@ class Llama:
                 local_dir_use_symlinks=local_dir_use_symlinks,
                 cache_dir=cache_dir,
                 local_files_only=True,
-
             )
         else:
             model_path = os.path.join(local_dir, filename)