Make Llama instance pickleable. Closes #27

This commit is contained in:
Andrei Betlen 2023-04-05 06:52:17 -04:00
parent 152e4695c3
commit e96a5c5722
2 changed files with 56 additions and 0 deletions

View file

@ -651,6 +651,45 @@ class Llama:
llama_cpp.llama_free(self.ctx) llama_cpp.llama_free(self.ctx)
self.ctx = None self.ctx = None
def __getstate__(self):
return dict(
verbose=self.verbose,
model_path=self.model_path,
n_ctx=self.params.n_ctx,
n_parts=self.params.n_parts,
seed=self.params.seed,
f16_kv=self.params.f16_kv,
logits_all=self.params.logits_all,
vocab_only=self.params.vocab_only,
use_mlock=self.params.use_mlock,
embedding=self.params.embedding,
last_n_tokens_size=self.last_n_tokens_size,
last_n_tokens_data=self.last_n_tokens_data,
tokens_consumed=self.tokens_consumed,
n_batch=self.n_batch,
n_threads=self.n_threads,
)
def __setstate__(self, state):
self.__init__(
model_path=state["model_path"],
n_ctx=state["n_ctx"],
n_parts=state["n_parts"],
seed=state["seed"],
f16_kv=state["f16_kv"],
logits_all=state["logits_all"],
vocab_only=state["vocab_only"],
use_mlock=state["use_mlock"],
embedding=state["embedding"],
n_threads=state["n_threads"],
n_batch=state["n_batch"],
last_n_tokens_size=state["last_n_tokens_size"],
verbose=state["verbose"],
)
self.last_n_tokens_data=state["last_n_tokens_data"]
self.tokens_consumed=state["tokens_consumed"]
@staticmethod @staticmethod
def token_eos() -> llama_cpp.llama_token: def token_eos() -> llama_cpp.llama_token:
"""Return the end-of-sequence token.""" """Return the end-of-sequence token."""

View file

@ -77,3 +77,20 @@ def test_llama_patch(monkeypatch):
chunks = llama.create_completion(text, max_tokens=2, stream=True) chunks = llama.create_completion(text, max_tokens=2, stream=True)
assert "".join(chunk["choices"][0]["text"] for chunk in chunks) == " j" assert "".join(chunk["choices"][0]["text"] for chunk in chunks) == " j"
assert completion["choices"][0]["finish_reason"] == "length" assert completion["choices"][0]["finish_reason"] == "length"
def test_llama_pickle():
import pickle
import tempfile
fp = tempfile.TemporaryFile()
llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True)
pickle.dump(llama, fp)
fp.seek(0)
llama = pickle.load(fp)
assert llama
assert llama.ctx is not None
text = b"Hello World"
assert llama.detokenize(llama.tokenize(text)) == text