Make Llama instance pickleable. Closes #27
parent 152e4695c3
commit e96a5c5722
2 changed files with 56 additions and 0 deletions
@@ -651,6 +651,45 @@ class Llama:
         llama_cpp.llama_free(self.ctx)
         self.ctx = None

+    def __getstate__(self):
+        return dict(
+            verbose=self.verbose,
+            model_path=self.model_path,
+            n_ctx=self.params.n_ctx,
+            n_parts=self.params.n_parts,
+            seed=self.params.seed,
+            f16_kv=self.params.f16_kv,
+            logits_all=self.params.logits_all,
+            vocab_only=self.params.vocab_only,
+            use_mlock=self.params.use_mlock,
+            embedding=self.params.embedding,
+            last_n_tokens_size=self.last_n_tokens_size,
+            last_n_tokens_data=self.last_n_tokens_data,
+            tokens_consumed=self.tokens_consumed,
+            n_batch=self.n_batch,
+            n_threads=self.n_threads,
+        )
+
+    def __setstate__(self, state):
+        self.__init__(
+            model_path=state["model_path"],
+            n_ctx=state["n_ctx"],
+            n_parts=state["n_parts"],
+            seed=state["seed"],
+            f16_kv=state["f16_kv"],
+            logits_all=state["logits_all"],
+            vocab_only=state["vocab_only"],
+            use_mlock=state["use_mlock"],
+            embedding=state["embedding"],
+            n_threads=state["n_threads"],
+            n_batch=state["n_batch"],
+            last_n_tokens_size=state["last_n_tokens_size"],
+            verbose=state["verbose"],
+        )
+        self.last_n_tokens_data = state["last_n_tokens_data"]
+        self.tokens_consumed = state["tokens_consumed"]
+
+
     @staticmethod
     def token_eos() -> llama_cpp.llama_token:
         """Return the end-of-sequence token."""
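The hunk above follows the standard recipe for pickling an object that wraps an unpicklable C handle: __getstate__ returns only the constructor arguments plus the small mutable sampling state (last_n_tokens_data, tokens_consumed), and __setstate__ rebuilds the llama.cpp context by calling __init__ again, which reloads the weights from model_path. A minimal standalone sketch of the same idea; the Heavy class and its fields are illustrative, not part of llama-cpp-python:

    import pickle

    class Heavy:
        """Stand-in for a class that holds an unpicklable C handle."""

        def __init__(self, path, n_ctx=512):
            self.path = path
            self.n_ctx = n_ctx
            self.handle = object()  # stand-in for a ctypes context pointer

        def __getstate__(self):
            # Persist only the constructor arguments, never the raw handle.
            return dict(path=self.path, n_ctx=self.n_ctx)

        def __setstate__(self, state):
            # Rebuild the handle by re-running __init__ with the saved arguments.
            self.__init__(path=state["path"], n_ctx=state["n_ctx"])

    obj = pickle.loads(pickle.dumps(Heavy("model.bin")))
    assert obj.handle is not None
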
@@ -77,3 +77,20 @@ def test_llama_patch(monkeypatch):
     chunks = llama.create_completion(text, max_tokens=2, stream=True)
     assert "".join(chunk["choices"][0]["text"] for chunk in chunks) == " j"
     assert completion["choices"][0]["finish_reason"] == "length"
+
+
+def test_llama_pickle():
+    import pickle
+    import tempfile
+
+    fp = tempfile.TemporaryFile()
+    llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True)
+    pickle.dump(llama, fp)
+    fp.seek(0)
+    llama = pickle.load(fp)
+
+    assert llama
+    assert llama.ctx is not None
+
+    text = b"Hello World"
+
+    assert llama.detokenize(llama.tokenize(text)) == text
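Note the trade-off this design makes: the pickle stores the model path rather than the weights, so it stays tiny, but __setstate__ re-runs __init__, which means the file at model_path must still exist wherever the object is unpickled. The test passes vocab_only=True so llama.cpp loads only the vocabulary, keeping the round trip fast. An equivalent in-memory round trip, assuming MODEL points at a model file as in the test suite:

    import pickle

    import llama_cpp

    llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True)
    clone = pickle.loads(pickle.dumps(llama))
    assert clone.detokenize(clone.tokenize(b"Hello World")) == b"Hello World"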