Refactor internal state for Llama class

parent 02cf881317
commit 86f8e5ad91

1 changed file with 23 additions and 40 deletions
@@ -84,16 +84,9 @@ class Llama:
         self.params.embedding = embedding

         self.last_n_tokens_size = last_n_tokens_size
-        self.last_n_tokens_data = deque(
-            [llama_cpp.llama_token(0)] * self.last_n_tokens_size,
-            maxlen=self.last_n_tokens_size,
-        )
-        self.tokens_consumed = 0
-        self.tokens: List[llama_cpp.llama_token] = []
         self.n_batch = min(n_ctx, n_batch)
-        self.n_tokens = 0
-        self.n_past = 0
-        self.all_logits: List[List[float]] = []  # TODO: Use an array instead of a list.
+        self.eval_tokens: deque[llama_cpp.llama_token] = deque(maxlen=n_ctx)
+        self.eval_logits: deque[List[float]] = deque(maxlen=n_ctx)

         ### HACK: This is a hack to work around the fact that the llama.cpp API does not yet support
         ###       saving and restoring state, this allows us to continue a completion if the last
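Note on the hunk above: five fields (`tokens`, `last_n_tokens_data`, `tokens_consumed`, `n_tokens`, `n_past`) collapse into two deques bounded by the context size. A minimal standalone sketch (toy sizes and token values, not the library's API) of why `deque(maxlen=n_ctx)` makes the manual bookkeeping unnecessary:

from collections import deque

# A bounded deque evicts from the left once maxlen is reached, so the
# newest n_ctx tokens are retained without counters or rollover logic.
n_ctx = 4  # toy context size
eval_tokens: deque = deque(maxlen=n_ctx)

eval_tokens.extend([1, 2, 3, 4])
eval_tokens.extend([5, 6])   # 1 and 2 are silently dropped
print(list(eval_tokens))     # [3, 4, 5, 6]
print(len(eval_tokens))      # 4 -- plays the role of the old tokens_consumed,
                             #      capped at the context size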
@@ -181,14 +174,8 @@ class Llama:

     def reset(self):
         """Reset the model state."""
-        self.last_n_tokens_data.extend(
-            [llama_cpp.llama_token(0)] * self.last_n_tokens_size
-        )
-        self.tokens_consumed = 0
-        self.tokens.clear()
-        self.n_tokens = 0
-        self.n_past = 0
-        self.all_logits.clear()
+        self.eval_tokens.clear()
+        self.eval_logits.clear()

     def eval(self, tokens: Sequence[llama_cpp.llama_token]):
         """Evaluate a list of tokens.
@@ -200,32 +187,25 @@ class Llama:
         n_ctx = int(llama_cpp.llama_n_ctx(self.ctx))
         for i in range(0, len(tokens), self.n_batch):
             batch = tokens[i : min(len(tokens), i + self.n_batch)]
-            self.n_past = min(n_ctx - len(batch), self.tokens_consumed)
-            self.n_tokens = len(batch)
+            n_past = min(n_ctx - len(batch), len(self.eval_tokens))
+            n_tokens = len(batch)
             return_code = llama_cpp.llama_eval(
                 ctx=self.ctx,
                 tokens=(llama_cpp.llama_token * len(batch))(*batch),
-                n_tokens=llama_cpp.c_int(self.n_tokens),
-                n_past=llama_cpp.c_int(self.n_past),
+                n_tokens=llama_cpp.c_int(n_tokens),
+                n_past=llama_cpp.c_int(n_past),
                 n_threads=llama_cpp.c_int(self.n_threads),
             )
             if int(return_code) != 0:
                 raise RuntimeError(f"llama_eval returned {return_code}")
-            self.tokens.extend(batch)
-            self.last_n_tokens_data.extend(batch)
-            self.tokens_consumed += len(batch)
+            self.eval_tokens.extend(batch)
             if self.params.logits_all:
-                self.all_logits.extend(self._logits())
-
-    def _logits(self) -> List[List[float]]:
-        """Return the logits from the last call to llama_eval."""
-        assert self.ctx is not None
-        n_vocab = llama_cpp.llama_n_vocab(self.ctx)
-        cols = int(n_vocab)
-        rows = self.n_tokens if self.params.logits_all else 1
-        logits_view = llama_cpp.llama_get_logits(self.ctx)
-        logits = [[logits_view[i * cols + j] for j in range(cols)] for i in range(rows)]
-        return logits
+                n_vocab = llama_cpp.llama_n_vocab(self.ctx)
+                cols = int(n_vocab)
+                rows = n_tokens
+                logits_view = llama_cpp.llama_get_logits(self.ctx)
+                logits = [[logits_view[i * cols + j] for j in range(cols)] for i in range(rows)]
+                self.eval_logits.extend(logits)

     def sample(
         self,
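The `_logits` helper disappears above; its row-major read of llama.cpp's flat logits buffer is inlined under `logits_all`, with `rows` now the batch length rather than `self.n_tokens`. A toy sketch of that indexing, using a plain Python list in place of the ctypes float pointer returned by `llama_get_logits` (all values illustrative):

# Row i (one evaluated token) lives at offsets [i * cols, (i + 1) * cols).
rows, cols = 2, 3   # toy n_tokens x n_vocab
logits_view = [0.1, 0.2, 0.3,   # logits for token 0
               0.4, 0.5, 0.6]   # logits for token 1

logits = [[logits_view[i * cols + j] for j in range(cols)] for i in range(rows)]
assert logits == [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]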
@@ -246,10 +226,13 @@ class Llama:
             The sampled token.
         """
         assert self.ctx is not None
+        last_n_tokens_data = [llama_cpp.llama_token(0)] * max(
+            0, self.last_n_tokens_size - len(self.eval_tokens)
+        ) + list(self.eval_tokens)[-self.last_n_tokens_size :]
         return llama_cpp.llama_sample_top_p_top_k(
             ctx=self.ctx,
             last_n_tokens_data=(llama_cpp.llama_token * self.last_n_tokens_size)(
-                *self.last_n_tokens_data
+                *last_n_tokens_data
             ),
             last_n_tokens_size=llama_cpp.c_int(self.last_n_tokens_size),
             top_k=llama_cpp.c_int(top_k),
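With the persistent `last_n_tokens_data` deque gone, the window handed to `llama_sample_top_p_top_k` is rebuilt on each call: left-padded with token 0 while fewer than `last_n_tokens_size` tokens have been evaluated, then filled with the most recent tail of the history. A toy reproduction of that expression (illustrative sizes only):

from collections import deque

last_n_tokens_size = 5                      # toy repeat-penalty window
eval_tokens = deque([7, 8, 9], maxlen=16)   # only three tokens evaluated so far

# Same shape as the expression in the hunk: pad on the left until the
# window is full, then keep only the window-sized tail of the history.
last_n_tokens_data = [0] * max(
    0, last_n_tokens_size - len(eval_tokens)
) + list(eval_tokens)[-last_n_tokens_size:]

assert last_n_tokens_data == [0, 0, 7, 8, 9]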
@@ -293,13 +276,13 @@ class Llama:
         if (
             reset
             and self._cache
-            and len(self.tokens) > 0
-            and self.tokens == tokens[: len(self.tokens)]
+            and len(self.eval_tokens) > 0
+            and self.eval_tokens == tokens[: len(self.eval_tokens)]
         ):
             if self.verbose:
                 print("generate cache hit", file=sys.stderr)
             reset = False
-            tokens = tokens[len(self.tokens) :]
+            tokens = tokens[len(self.eval_tokens) :]
         ###
         if reset:
             self.reset()
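One caveat worth flagging in the condition above: `self.eval_tokens` is now a `deque`, and in CPython a deque never compares equal to a list, so `self.eval_tokens == tokens[: len(self.eval_tokens)]` is always False as written; one side has to be converted for the cache check to fire:

from collections import deque

tokens = [1, 2, 3]
eval_tokens = deque(tokens)

print(eval_tokens == tokens)        # False -- deque vs. list never match
print(list(eval_tokens) == tokens)  # True  -- compare like with like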
@@ -537,7 +520,7 @@ class Llama:
             ]
             all_logprobs = [
                 [Llama.logit_to_logprob(logit) for logit in row]
-                for row in self.all_logits
+                for row in self.eval_logits
             ]
             for token, token_str, logprobs_token in zip(
                 all_tokens, all_token_strs, all_logprobs