feat: Make saved state more compact on-disk (#1296)
* State load/save changes
  - Only store up to `n_tokens` logits instead of the full `(n_ctx, n_vocab)` sized array.
    - Difference between ~350MB and ~1500MB for an example prompt with ~300 tokens (makes sense lol)
  - Auto-formatting changes
* Back out formatting changes
This commit is contained in:
parent
9842cbf99d
commit
4924455dec
1 changed file with 5 additions and 3 deletions
|
@@ -18,6 +18,7 @@ from typing import (
|
||||||
Iterator,
|
Iterator,
|
||||||
Deque,
|
Deque,
|
||||||
Callable,
|
Callable,
|
||||||
|
Dict,
|
||||||
)
|
)
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
@@ -1791,7 +1792,7 @@ class Llama:
|
||||||
file=sys.stderr,
|
file=sys.stderr,
|
||||||
)
|
)
|
||||||
return LlamaState(
|
return LlamaState(
|
||||||
scores=self.scores.copy(),
|
scores=self._scores.copy(),
|
||||||
input_ids=self.input_ids.copy(),
|
input_ids=self.input_ids.copy(),
|
||||||
n_tokens=self.n_tokens,
|
n_tokens=self.n_tokens,
|
||||||
llama_state=bytes(llama_state_compact),
|
llama_state=bytes(llama_state_compact),
|
||||||
|
@@ -1800,7 +1801,9 @@ class Llama:
|
||||||
|
|
||||||
def load_state(self, state: LlamaState) -> None:
|
def load_state(self, state: LlamaState) -> None:
|
||||||
assert self._ctx.ctx is not None
|
assert self._ctx.ctx is not None
|
||||||
self.scores = state.scores.copy()
|
# Only filling in up to `n_tokens` and then zero-ing out the rest
|
||||||
|
self.scores[: state.n_tokens, :] = state.scores.copy()
|
||||||
|
self.scores[state.n_tokens :, :] = 0.0
|
||||||
self.input_ids = state.input_ids.copy()
|
self.input_ids = state.input_ids.copy()
|
||||||
self.n_tokens = state.n_tokens
|
self.n_tokens = state.n_tokens
|
||||||
state_size = state.llama_state_size
|
state_size = state.llama_state_size
|
||||||
|
@@ -1951,7 +1954,6 @@ class Llama:
|
||||||
local_dir_use_symlinks=local_dir_use_symlinks,
|
local_dir_use_symlinks=local_dir_use_symlinks,
|
||||||
cache_dir=cache_dir,
|
cache_dir=cache_dir,
|
||||||
local_files_only=True,
|
local_files_only=True,
|
||||||
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
model_path = os.path.join(local_dir, filename)
|
model_path = os.path.join(local_dir, filename)
|
||||||
|
|
Loading…
Reference in a new issue